sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff shows the differences between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (172)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +133 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +32 -21
  49. sglang/srt/layers/layernorm.py +24 -2
  50. sglang/srt/layers/linear.py +17 -5
  51. sglang/srt/layers/logits_processor.py +25 -7
  52. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  53. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  54. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  55. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  62. sglang/srt/layers/moe/topk.py +31 -18
  63. sglang/srt/layers/parameter.py +1 -1
  64. sglang/srt/layers/quantization/__init__.py +184 -126
  65. sglang/srt/layers/quantization/base_config.py +5 -0
  66. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  67. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  69. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  70. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  71. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  73. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  74. sglang/srt/layers/quantization/fp8.py +76 -34
  75. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  76. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  77. sglang/srt/layers/quantization/gptq.py +36 -9
  78. sglang/srt/layers/quantization/kv_cache.py +98 -0
  79. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  80. sglang/srt/layers/quantization/utils.py +153 -0
  81. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  82. sglang/srt/layers/rotary_embedding.py +66 -87
  83. sglang/srt/layers/sampler.py +1 -1
  84. sglang/srt/lora/layers.py +68 -0
  85. sglang/srt/lora/lora.py +2 -22
  86. sglang/srt/lora/lora_manager.py +47 -23
  87. sglang/srt/lora/mem_pool.py +110 -51
  88. sglang/srt/lora/utils.py +12 -1
  89. sglang/srt/managers/cache_controller.py +2 -5
  90. sglang/srt/managers/data_parallel_controller.py +30 -8
  91. sglang/srt/managers/expert_distribution.py +81 -0
  92. sglang/srt/managers/io_struct.py +39 -3
  93. sglang/srt/managers/mm_utils.py +373 -0
  94. sglang/srt/managers/multimodal_processor.py +68 -0
  95. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  96. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  97. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  98. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  99. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  100. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  101. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  102. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  103. sglang/srt/managers/schedule_batch.py +133 -30
  104. sglang/srt/managers/scheduler.py +273 -20
  105. sglang/srt/managers/session_controller.py +1 -1
  106. sglang/srt/managers/tokenizer_manager.py +59 -23
  107. sglang/srt/managers/tp_worker.py +1 -1
  108. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  109. sglang/srt/managers/utils.py +6 -1
  110. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  111. sglang/srt/mem_cache/memory_pool.py +255 -98
  112. sglang/srt/mem_cache/paged_allocator.py +2 -2
  113. sglang/srt/mem_cache/radix_cache.py +4 -4
  114. sglang/srt/model_executor/cuda_graph_runner.py +27 -13
  115. sglang/srt/model_executor/forward_batch_info.py +68 -11
  116. sglang/srt/model_executor/model_runner.py +70 -6
  117. sglang/srt/model_loader/loader.py +160 -2
  118. sglang/srt/model_loader/weight_utils.py +45 -0
  119. sglang/srt/models/deepseek_janus_pro.py +29 -86
  120. sglang/srt/models/deepseek_nextn.py +22 -10
  121. sglang/srt/models/deepseek_v2.py +208 -77
  122. sglang/srt/models/deepseek_vl2.py +358 -0
  123. sglang/srt/models/gemma3_causal.py +684 -0
  124. sglang/srt/models/gemma3_mm.py +462 -0
  125. sglang/srt/models/llama.py +47 -7
  126. sglang/srt/models/llama_eagle.py +1 -0
  127. sglang/srt/models/llama_eagle3.py +196 -0
  128. sglang/srt/models/llava.py +3 -3
  129. sglang/srt/models/llavavid.py +3 -3
  130. sglang/srt/models/minicpmo.py +1995 -0
  131. sglang/srt/models/minicpmv.py +62 -137
  132. sglang/srt/models/mllama.py +4 -4
  133. sglang/srt/models/phi3_small.py +1 -1
  134. sglang/srt/models/qwen2.py +3 -0
  135. sglang/srt/models/qwen2_5_vl.py +68 -146
  136. sglang/srt/models/qwen2_classification.py +75 -0
  137. sglang/srt/models/qwen2_moe.py +9 -1
  138. sglang/srt/models/qwen2_vl.py +25 -63
  139. sglang/srt/openai_api/adapter.py +124 -28
  140. sglang/srt/openai_api/protocol.py +23 -2
  141. sglang/srt/sampling/sampling_batch_info.py +1 -1
  142. sglang/srt/sampling/sampling_params.py +6 -6
  143. sglang/srt/server_args.py +99 -9
  144. sglang/srt/speculative/build_eagle_tree.py +7 -347
  145. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  146. sglang/srt/speculative/eagle_utils.py +208 -252
  147. sglang/srt/speculative/eagle_worker.py +139 -53
  148. sglang/srt/speculative/spec_info.py +6 -1
  149. sglang/srt/torch_memory_saver_adapter.py +22 -0
  150. sglang/srt/utils.py +182 -21
  151. sglang/test/__init__.py +0 -0
  152. sglang/test/attention/__init__.py +0 -0
  153. sglang/test/attention/test_flashattn_backend.py +312 -0
  154. sglang/test/runners.py +2 -0
  155. sglang/test/test_activation.py +2 -1
  156. sglang/test/test_block_fp8.py +5 -4
  157. sglang/test/test_block_fp8_ep.py +2 -1
  158. sglang/test/test_dynamic_grad_mode.py +58 -0
  159. sglang/test/test_layernorm.py +3 -2
  160. sglang/test/test_utils.py +55 -4
  161. sglang/utils.py +31 -0
  162. sglang/version.py +1 -1
  163. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  164. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
  165. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  166. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  167. sglang/srt/managers/image_processor.py +0 -55
  168. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  169. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  170. sglang/srt/managers/multi_modality_padding.py +0 -134
  171. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  172. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py CHANGED
@@ -36,12 +36,13 @@ import tempfile
 import threading
 import time
 import warnings
+from contextlib import contextmanager
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from importlib.util import find_spec
 from io import BytesIO
-from multiprocessing import Pool
 from multiprocessing.reduction import ForkingPickler
+from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union

 import numpy as np
@@ -54,13 +55,13 @@ import triton
 import zmq
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
-from packaging.version import Version, parse
+from PIL import Image
 from starlette.routing import Mount
 from torch import nn
 from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
-from torch.utils.cpp_extension import CUDA_HOME
+from torch.utils._contextlib import _DecoratorContextManager
 from triton.runtime.cache import (
     FileCacheManager,
     default_cache_dir,
@@ -76,6 +77,11 @@ time_infos = {}
 HIP_FP8_E4M3_FNUZ_MAX = 224.0


+def get_bool_env_var(name: str, default: str = "false") -> bool:
+    value = os.getenv(name, default)
+    return value.lower() in ("true", "1")
+
+
 # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
 def is_hip() -> bool:
     return torch.version.hip is not None
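get_bool_env_var now lives near the top of the module so that module-level flags such as _ENABLE_TORCH_INFERENCE_MODE (added in the next hunk) can call it at import time; the copy that previously sat further down in the file is removed in a later hunk. A minimal sketch of the parsing rule it implements (the flag names below are examples only, not flags defined by this release):

import os

os.environ["SGLANG_EXAMPLE_FLAG"] = "1"  # hypothetical flag used only for illustration

from sglang.srt.utils import get_bool_env_var

print(get_bool_env_var("SGLANG_EXAMPLE_FLAG"))                 # True: "1" and "true" (any case) are accepted
print(get_bool_env_var("SGLANG_UNSET_FLAG", default="false"))  # False when the variable is absent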
@@ -126,6 +132,63 @@ def is_cuda_available():
     return is_cuda()


+_ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var(
+    "SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
+)
+
+
+class DynamicGradMode(_DecoratorContextManager):
+    """
+    A combination of torch.no_grad and torch.inference_mode,
+    with their behavior controlled by an environment variable. Just refer to them.
+    """
+
+    @staticmethod
+    def set_inference_mode(mode: bool):
+        if isinstance(mode, bool):
+            global _ENABLE_TORCH_INFERENCE_MODE
+
+            _ENABLE_TORCH_INFERENCE_MODE = mode
+        else:
+            logger.warning("mode is not a boolean object")
+
+    def __init__(self, mode=True):
+        if not torch._jit_internal.is_scripting():
+            super().__init__()
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            self.mode = mode
+        else:
+            self.prev = False
+
+    def __new__(cls, mode_or_orig_func=True if _ENABLE_TORCH_INFERENCE_MODE else None):
+        if mode_or_orig_func is None or isinstance(mode_or_orig_func, bool):
+            return super().__new__(cls)
+        return cls()(mode_or_orig_func)
+
+    def __enter__(self) -> None:
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            self._inference_mode_context = torch._C._InferenceMode(self.mode)
+            self._inference_mode_context.__enter__()
+        else:
+            self.prev = torch.is_grad_enabled()
+            torch.set_grad_enabled(False)
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            self._inference_mode_context.__exit__(exc_type, exc_value, traceback)
+        else:
+            torch.set_grad_enabled(self.prev)
+
+    def clone(self) -> "DynamicGradMode":
+        r"""
+        Create a copy of this class
+        """
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            return self.__class__(self.mode)
+        else:
+            return self.__class__()
+
+
 def enable_show_time_cost():
     global show_time_cost
     show_time_cost = True
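DynamicGradMode acts as a drop-in replacement for torch.no_grad that switches to torch.inference_mode semantics when SGLANG_ENABLE_TORCH_INFERENCE_MODE is set before sglang.srt.utils is imported. A minimal usage sketch (the model and inputs are placeholders, not code from this diff):

import os

os.environ["SGLANG_ENABLE_TORCH_INFERENCE_MODE"] = "true"  # must be set before sglang.srt.utils is imported

import torch

from sglang.srt.utils import DynamicGradMode


@DynamicGradMode()  # inference_mode semantics here; plain no_grad when the env var is unset
def run_forward(model, x):
    return model(x)


out = run_forward(torch.nn.Linear(4, 4), torch.randn(2, 4))
print(out.requires_grad)  # False under either mode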
@@ -198,7 +261,7 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
     When distributed is True, the available memory is the minimum available memory of all GPUs.
     """
     if device == "cuda":
-        num_gpus = torch.cuda.device_count()
+        num_gpus = cuda_device_count_stateless()
         assert gpu_id < num_gpus

         if torch.cuda.current_device() != gpu_id:
@@ -443,17 +506,46 @@ def decode_video_base64(video_base64):
         )  # Return an empty array and size tuple if no frames were found


-def load_image(image_file: Union[str, bytes]):
-    from PIL import Image
+def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarray:
+    # Use soundfile here, since librosa use it under the hood,
+    # and librosa will not support audio loading in the future
+    import soundfile as sf
+    from scipy.signal import resample
+
+    # print(f"loading {audio_file}")
+    # Load audio data
+    if isinstance(audio_file, bytes):
+        audio, original_sr = sf.read(BytesIO(audio_file))
+    elif audio_file.startswith("data:"):
+        audio_file = audio_file.split(",")[1]
+        audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
+    elif isinstance(audio_file, str):
+        audio, original_sr = sf.read(audio_file)
+    else:
+        raise ValueError(f"Invalid audio format: {audio_file}")
+
+    # Resample audio if the original sample rate is different from the desired sample rate
+    if original_sr != sr:
+        num_samples = int(len(audio) * float(sr) / original_sr)
+        audio = resample(audio, num_samples)

+    # Convert to mono if requested and audio is stereo
+    if mono and len(audio.shape) > 1:
+        audio = np.mean(audio, axis=1)
+
+    return audio
+
+
+def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]:
     image = image_size = None

     if isinstance(image_file, bytes):
         image = Image.open(BytesIO(image_file))
     elif image_file.startswith("http://") or image_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
-        response = requests.get(image_file, timeout=timeout)
-        image = Image.open(BytesIO(response.content))
+        response = requests.get(image_file, stream=True, timeout=timeout).raw
+        image = Image.open(response)
+        response.close()
     elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
         image = Image.open(image_file)
     elif image_file.startswith("data:"):
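load_audio gives the new multimodal processors a single entry point for raw bytes, data: URIs, and file paths, resampling with scipy and collapsing stereo to mono. A short sketch, assuming soundfile and scipy are installed and sample.wav is a local file (the file name is illustrative):

import numpy as np

from sglang.srt.utils import load_audio

waveform = load_audio("sample.wav", sr=16000, mono=True)  # resampled to 16 kHz, averaged to one channel
assert isinstance(waveform, np.ndarray) and waveform.ndim == 1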
@@ -471,7 +563,10 @@ def load_image(image_file: Union[str, bytes]):


 def suppress_other_loggers():
-    from vllm.logger import logger as vllm_default_logger
+    try:
+        from vllm.logger import logger as vllm_default_logger
+    except ImportError:
+        return

     vllm_default_logger.setLevel(logging.WARN)
     logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
@@ -480,6 +575,7 @@ def suppress_other_loggers():
     logging.getLogger("vllm.distributed.device_communicators.shm_broadcast").setLevel(
         logging.WARN
     )
+    logging.getLogger("vllm.config").setLevel(logging.ERROR)

     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -527,6 +623,10 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N

     if include_parent:
         try:
+            if parent_pid == os.getpid():
+                itself.kill()
+                sys.exit(0)
+
             itself.kill()

             # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
@@ -555,11 +655,14 @@ def monkey_patch_p2p_access_check():


 def monkey_patch_vllm_gguf_config():
-    from vllm.model_executor.layers.quantization.gguf import (
-        GGUFConfig,
-        GGUFEmbeddingMethod,
-        GGUFLinearMethod,
-    )
+    try:
+        from vllm.model_executor.layers.quantization.gguf import (
+            GGUFConfig,
+            GGUFEmbeddingMethod,
+            GGUFLinearMethod,
+        )
+    except ImportError:
+        return

     from sglang.srt.layers.linear import LinearBase
     from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -651,6 +754,16 @@ def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):


 def configure_logger(server_args, prefix: str = ""):
+    if SGLANG_LOGGING_CONFIG_PATH := os.getenv("SGLANG_LOGGING_CONFIG_PATH"):
+        if not os.path.exists(SGLANG_LOGGING_CONFIG_PATH):
+            raise Exception(
+                "Setting SGLANG_LOGGING_CONFIG_PATH from env with "
+                f"{SGLANG_LOGGING_CONFIG_PATH} but it does not exist!"
+            )
+        with open(SGLANG_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
+            custom_config = json.loads(file.read())
+        logging.config.dictConfig(custom_config)
+        return
     format = f"[%(asctime)s{prefix}] %(message)s"
     # format = f"[%(asctime)s.%(msecs)03d{prefix}] %(message)s"
     logging.basicConfig(
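configure_logger now short-circuits when SGLANG_LOGGING_CONFIG_PATH points to a JSON file in logging.config.dictConfig format. A sketch of producing such a file (the formatter and handler names are arbitrary examples, not part of this release):

import json
import os

logging_config = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {"plain": {"format": "[%(asctime)s] %(levelname)s %(message)s"}},
    "handlers": {"console": {"class": "logging.StreamHandler", "formatter": "plain"}},
    "root": {"level": "INFO", "handlers": ["console"]},
}

with open("sglang_logging.json", "w", encoding="utf-8") as f:
    json.dump(logging_config, f)

# Point sglang at the file; configure_logger will call logging.config.dictConfig and return early.
os.environ["SGLANG_LOGGING_CONFIG_PATH"] = os.path.abspath("sglang_logging.json")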
@@ -774,12 +887,22 @@ def get_zmq_socket(
         buf_size = -1

     socket = context.socket(socket_type)
-    if socket_type == zmq.PUSH:
+
+    def set_send_opt():
         socket.setsockopt(zmq.SNDHWM, 0)
         socket.setsockopt(zmq.SNDBUF, buf_size)
-    elif socket_type == zmq.PULL:
+
+    def set_recv_opt():
         socket.setsockopt(zmq.RCVHWM, 0)
         socket.setsockopt(zmq.RCVBUF, buf_size)
+
+    if socket_type == zmq.PUSH:
+        set_send_opt()
+    elif socket_type == zmq.PULL:
+        set_recv_opt()
+    elif socket_type == zmq.DEALER:
+        set_send_opt()
+        set_recv_opt()
     else:
         raise ValueError(f"Unsupported socket type: {socket_type}")

@@ -910,6 +1033,13 @@ def get_amdgpu_memory_capacity():
         )


+def get_device_sm():
+    if torch.cuda.is_available():
+        major, minor = torch.cuda.get_device_capability()
+        return major * 10 + minor
+    return 0
+
+
 def get_nvgpu_memory_capacity():
     try:
         # Run nvidia-smi and capture the output
@@ -1246,11 +1376,6 @@ def set_gpu_proc_affinity(
     logger.info(f"Process {pid} gpu_id {gpu_id} is running on CPUs: {p.cpu_affinity()}")


-def get_bool_env_var(name: str, default: str = "false") -> bool:
-    value = os.getenv(name, default)
-    return value.lower() in ("true", "1")
-
-
 @lru_cache(maxsize=2)
 def disable_request_logging() -> bool:
     return get_bool_env_var("SGLANG_DISABLE_REQUEST_LOGGING")
@@ -1561,6 +1686,16 @@ def next_power_of_2(n: int):
     setattr(triton, "next_power_of_2", next_power_of_2)


+@contextmanager
+def empty_context(*args, **kwargs):
+    try:
+        # Setup code goes here
+        yield
+    finally:
+        # Cleanup code goes here
+        pass
+
+
 def add_prefix(name: str, prefix: str) -> str:
     """Add a weight path prefix to a module name.

@@ -1572,3 +1707,29 @@ def add_prefix(name: str, prefix: str) -> str:
         The string `prefix.name` if prefix is non-empty, otherwise just `name`.
     """
     return name if not prefix else f"{prefix}.{name}"
+
+
+def is_remote_url(url: Union[str, Path]) -> bool:
+    """
+    Check if the URL is a remote URL of the format:
+    <connector_type>://<host>:<port>/<model_name>
+    """
+    if isinstance(url, Path):
+        return False
+
+    pattern = r"(.+)://(.*)"
+    m = re.match(pattern, url)
+    return m is not None
+
+
+def parse_connector_type(url: str) -> str:
+    """
+    Parse the connector type from the URL of the format:
+    <connector_type>://<path>
+    """
+    pattern = r"(.+)://(.*)"
+    m = re.match(pattern, url)
+    if m is None:
+        return ""
+
+    return m.group(1)
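These helpers support the new sglang/srt/connector loaders (redis, s3) listed above: a local Path is never treated as remote, and the connector type is whatever precedes "://". An illustrative check (the hosts and bucket names below are made up):

from pathlib import Path

from sglang.srt.utils import is_remote_url, parse_connector_type

print(is_remote_url(Path("/models/llama")))            # False: Path objects are always local
print(is_remote_url("redis://cache-host:6379/llama"))  # True: matches <connector_type>://...
print(parse_connector_type("s3://my-bucket/llama"))    # "s3"
print(parse_connector_type("/models/llama"))           # "" when no scheme is present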
sglang/test/__init__.py ADDED
File without changes
sglang/test/attention/__init__.py ADDED
File without changes
sglang/test/attention/test_flashattn_backend.py ADDED
@@ -0,0 +1,312 @@
+import unittest
+
+import torch
+
+from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.test.test_utils import CustomTestCase
+
+
+class MockModelRunner:
+    model_config = type(
+        "ModelConfig", (), {"context_len": 2048, "is_multimodal": False}
+    )
+    sliding_window_size = None
+
+    def __init__(self, device="cuda"):
+        self.device = device
+        # Create a proper req_to_token_pool with the req_to_token attribute
+        self.req_to_token_pool = type(
+            "TokenPool",
+            (),
+            {
+                "size": 160,  # a typical max_bs * max_context_len for cuda graph decode
+                "req_to_token": torch.zeros(
+                    160, 2048, dtype=torch.int32, device=device
+                ),  # Add req_to_token attribute
+            },
+        )
+
+
+class MockReqToTokenPool:
+    def __init__(self, batch_size, seq_len, device):
+        self.req_to_token = (
+            torch.arange(batch_size * seq_len, device=device)
+            .reshape(batch_size, seq_len)
+            .to(torch.int32)
+        )
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
+class TestFlashAttentionBackend(CustomTestCase):
+    def setUp(self):
+        """Set up test fixtures before each test method."""
+        self.model_runner = MockModelRunner()
+        self.backend = FlashAttentionBackend(self.model_runner)
+
+        # Common test parameters
+        self.batch_size = 2
+        self.seq_len = 4
+        self.num_heads = 2
+        self.head_dim = 8
+        self.device = "cuda"
+        self.dtype = torch.float16
+
+    def _create_attention_layer(self):
+        """Helper method to create an attention layer."""
+        return RadixAttention(
+            num_heads=self.num_heads,
+            head_dim=self.head_dim,
+            scaling=1.0,
+            num_kv_heads=self.num_heads,
+            layer_id=0,
+        )
+
+    def _create_kv_pool(self, size):
+        """Helper method to create a KV pool."""
+        return MHATokenToKVPool(
+            size=size,
+            page_size=1,  # only consider page=1 for unit test
+            dtype=self.dtype,
+            head_num=self.num_heads,
+            head_dim=self.head_dim,
+            layer_num=1,  # only consider layer=1 for unit test
+            device=self.device,
+            enable_memory_saver=False,
+        )
+
+    def _create_qkv_tensors(self, tokens_len):
+        """Helper method to create q, k, v tensors."""
+        return (
+            torch.randn(
+                tokens_len,
+                self.num_heads,
+                self.head_dim,
+                dtype=self.dtype,
+                device=self.device,
+            ),
+            torch.randn(
+                tokens_len,
+                self.num_heads,
+                self.head_dim,
+                dtype=self.dtype,
+                device=self.device,
+            ),
+            torch.randn(
+                tokens_len,
+                self.num_heads,
+                self.head_dim,
+                dtype=self.dtype,
+                device=self.device,
+            ),
+        )
+
+    def _verify_output(self, output, expected_shape):
+        """Helper method to verify output."""
+        self.assertEqual(
+            output.shape,
+            expected_shape,
+            f"Expected shape {expected_shape}, got {output.shape}",
+        )
+        self.assertEqual(output.dtype, self.dtype)
+        self.assertEqual(output.device.type, "cuda")
+        self.assertEqual(
+            torch.isnan(output).sum().item(), 0, "Output contains NaN values"
+        )
+
+    def test_forward_extend(self):
+        """Test the standard extend operation."""
+        # Create test inputs
+        q, k, v = self._create_qkv_tensors(self.batch_size * self.seq_len)
+
+        # Create attention layer
+        layer = self._create_attention_layer()
+
+        # Create forward batch
+        forward_batch = ForwardBatch(
+            batch_size=self.batch_size,
+            input_ids=torch.randint(
+                0, 100, (self.batch_size, self.seq_len), device=self.device
+            ),
+            out_cache_loc=torch.arange(
+                self.batch_size * self.seq_len, device=self.device
+            ),
+            seq_lens_sum=self.batch_size * self.seq_len,
+            forward_mode=ForwardMode.EXTEND,
+            req_pool_indices=torch.arange(self.batch_size, device=self.device),
+            seq_lens=torch.tensor([self.seq_len] * self.batch_size, device=self.device),
+            # 0 prefix, 4 extend
+            extend_prefix_lens=torch.tensor([0] * self.batch_size, device=self.device),
+            extend_seq_lens=torch.tensor([4] * self.batch_size, device=self.device),
+            attn_backend=self.backend,
+        )
+
+        # Add token pool and KV cache
+        forward_batch.req_to_token_pool = MockReqToTokenPool(
+            self.batch_size, self.seq_len, self.device
+        )
+        forward_batch.token_to_kv_pool = self._create_kv_pool(
+            self.batch_size * self.seq_len
+        )
+
+        # Initialize forward metadata before running the attention
+        self.backend.init_forward_metadata(forward_batch)
+
+        # Run forward_extend
+        output = self.backend.forward_extend(q, k, v, layer, forward_batch)
+
+        # Verify output
+        expected_shape = (
+            self.batch_size * self.seq_len,
+            self.num_heads * self.head_dim,
+        )
+        self._verify_output(output, expected_shape)
+
+    def test_forward_decode(self):
+        """Test the decode operation with cached tokens."""
+        # For decode, we only have one token per sequence
+        decode_len = 1
+        curr_seq_len = self.seq_len + decode_len
+
+        # Create test inputs
+        q, k, v = self._create_qkv_tensors(self.batch_size * decode_len)
+
+        # Create attention layer
+        layer = self._create_attention_layer()
+
+        # Create forward batch
+        forward_batch = ForwardBatch(
+            batch_size=self.batch_size,
+            input_ids=torch.randint(
+                0, 100, (self.batch_size, decode_len), device=self.device
+            ),
+            out_cache_loc=torch.arange(
+                self.batch_size * self.seq_len,
+                self.batch_size * curr_seq_len,
+                device=self.device,
+            ),
+            seq_lens_sum=self.batch_size * curr_seq_len,
+            forward_mode=ForwardMode.DECODE,
+            req_pool_indices=torch.arange(self.batch_size, device=self.device),
+            seq_lens=torch.tensor([curr_seq_len] * self.batch_size, device=self.device),
+            attn_backend=self.backend,
+        )
+
+        # Add token pool and KV cache
+        forward_batch.req_to_token_pool = MockReqToTokenPool(
+            self.batch_size, curr_seq_len, self.device
+        )
+        forward_batch.token_to_kv_pool = self._create_kv_pool(
+            self.batch_size * curr_seq_len
+        )
+
+        # Pre-fill KV cache
+        cache_k, cache_v, _ = self._create_qkv_tensors(self.batch_size * self.seq_len)
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer,
+            torch.arange(self.batch_size * self.seq_len, device=self.device),
+            cache_k,
+            cache_v,
+            layer.k_scale,
+            layer.v_scale,
+        )
+
+        # Initialize forward metadata before running the attention
+        self.backend.init_forward_metadata(forward_batch)
+
+        # Run forward_decode
+        output = self.backend.forward_decode(q, k, v, layer, forward_batch)
+
+        # Verify output
+        expected_shape = (self.batch_size, self.num_heads * self.head_dim)
+        self._verify_output(output, expected_shape)
+
+    def test_forward_extend_with_prefix(self):
+        """Test extending from cached prefix tokens."""
+        # Define prefix and extend lengths
+        prefix_len = 2
+        extend_len = 2
+        total_len = prefix_len + extend_len
+
+        # Create test inputs for the extend portion
+        q, k, v = self._create_qkv_tensors(self.batch_size * extend_len)
+
+        # Create attention layer
+        layer = self._create_attention_layer()
+
+        # Create forward batch
+        forward_batch = ForwardBatch(
+            batch_size=self.batch_size,
+            input_ids=torch.randint(
+                0, 100, (self.batch_size, extend_len), device=self.device
+            ),
+            out_cache_loc=torch.arange(
+                self.batch_size * prefix_len,
+                self.batch_size * total_len,
+                device=self.device,
+            ),
+            seq_lens_sum=self.batch_size * total_len,
+            forward_mode=ForwardMode.EXTEND,
+            req_pool_indices=torch.arange(self.batch_size, device=self.device),
+            seq_lens=torch.tensor([total_len] * self.batch_size, device=self.device),
+            extend_prefix_lens=torch.tensor(
+                [prefix_len] * self.batch_size, device=self.device
+            ),
+            extend_seq_lens=torch.tensor(
+                [extend_len] * self.batch_size, device=self.device
+            ),
+            attn_backend=self.backend,
+        )
+
+        # Add token pool and KV cache
+        forward_batch.req_to_token_pool = MockReqToTokenPool(
+            self.batch_size, total_len, self.device
+        )
+        forward_batch.token_to_kv_pool = self._create_kv_pool(
+            self.batch_size * total_len
+        )
+
+        # Pre-fill the KV cache for prefix with known values
+        cache_k = torch.ones(
+            self.batch_size * prefix_len,
+            self.num_heads,
+            self.head_dim,
+            dtype=self.dtype,
+            device=self.device,
+        )
+        cache_v = (
+            torch.ones(
+                self.batch_size * prefix_len,
+                self.num_heads,
+                self.head_dim,
+                dtype=self.dtype,
+                device=self.device,
+            )
+            * 2
+        )
+
+        # Set the prefix KV cache
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer,
+            torch.arange(self.batch_size * prefix_len, device=self.device),
+            cache_k,
+            cache_v,
+            layer.k_scale,
+            layer.v_scale,
+        )
+
+        # Initialize forward metadata before running the attention
+        self.backend.init_forward_metadata(forward_batch)
+
+        # Run forward_extend
+        output = self.backend.forward_extend(q, k, v, layer, forward_batch)
+
+        # Verify output
+        expected_shape = (self.batch_size * extend_len, self.num_heads * self.head_dim)
+        self._verify_output(output, expected_shape)
+
+
+if __name__ == "__main__":
+    unittest.main()
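The new file is a self-contained unittest module, so it can be driven with the standard loader; a sketch assuming a CUDA-capable machine (the whole class reports as skipped otherwise):

import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "sglang.test.attention.test_flashattn_backend"
)
unittest.TextTestRunner(verbosity=2).run(suite)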
sglang/test/runners.py CHANGED
@@ -437,6 +437,7 @@ class SRTRunner:
         speculative_eagle_topk: Optional[int] = None,
         speculative_num_draft_tokens: Optional[int] = None,
         disable_overlap_schedule: bool = False,
+        disable_custom_all_reduce: bool = False,
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
@@ -470,6 +471,7 @@ class SRTRunner:
             enable_ep_moe=enable_ep_moe,
             disable_overlap_schedule=disable_overlap_schedule,
             cuda_graph_max_bs=4,
+            disable_custom_all_reduce=disable_custom_all_reduce,
             **spec_kwargs,
         )

sglang/test/test_activation.py CHANGED
@@ -4,9 +4,10 @@ import unittest
 import torch

 from sglang.srt.layers.activation import GeluAndMul
+from sglang.test.test_utils import CustomTestCase


-class TestGeluAndMul(unittest.TestCase):
+class TestGeluAndMul(CustomTestCase):
     DTYPES = [torch.half, torch.bfloat16]
     NUM_TOKENS = [7, 83, 2048]
     D = [512, 4096, 5120, 13824]