sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py CHANGED
@@ -36,12 +36,13 @@ import tempfile
  import threading
  import time
  import warnings
+ from contextlib import contextmanager
  from functools import lru_cache
  from importlib.metadata import PackageNotFoundError, version
  from importlib.util import find_spec
  from io import BytesIO
- from multiprocessing import Pool
  from multiprocessing.reduction import ForkingPickler
+ from pathlib import Path
  from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union

  import numpy as np
@@ -54,13 +55,13 @@ import triton
  import zmq
  from fastapi.responses import ORJSONResponse
  from packaging import version as pkg_version
- from packaging.version import Version, parse
+ from PIL import Image
  from starlette.routing import Mount
  from torch import nn
  from torch.func import functional_call
  from torch.library import Library
  from torch.profiler import ProfilerActivity, profile, record_function
- from torch.utils.cpp_extension import CUDA_HOME
+ from torch.utils._contextlib import _DecoratorContextManager
  from triton.runtime.cache import (
      FileCacheManager,
      default_cache_dir,
@@ -76,6 +77,11 @@ time_infos = {}
  HIP_FP8_E4M3_FNUZ_MAX = 224.0


+ def get_bool_env_var(name: str, default: str = "false") -> bool:
+     value = os.getenv(name, default)
+     return value.lower() in ("true", "1")
+
+
  # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
  def is_hip() -> bool:
      return torch.version.hip is not None
@@ -126,6 +132,63 @@ def is_cuda_available():
      return is_cuda()


+ _ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var(
+     "SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
+ )
+
+
+ class DynamicGradMode(_DecoratorContextManager):
+     """
+     A combination of torch.no_grad and torch.inference_mode,
+     with their behavior controlled by an environment variable. Just refer to them.
+     """
+
+     @staticmethod
+     def set_inference_mode(mode: bool):
+         if isinstance(mode, bool):
+             global _ENABLE_TORCH_INFERENCE_MODE
+
+             _ENABLE_TORCH_INFERENCE_MODE = mode
+         else:
+             logger.warning("mode is not a boolean object")
+
+     def __init__(self, mode=True):
+         if not torch._jit_internal.is_scripting():
+             super().__init__()
+         if _ENABLE_TORCH_INFERENCE_MODE:
+             self.mode = mode
+         else:
+             self.prev = False
+
+     def __new__(cls, mode_or_orig_func=True if _ENABLE_TORCH_INFERENCE_MODE else None):
+         if mode_or_orig_func is None or isinstance(mode_or_orig_func, bool):
+             return super().__new__(cls)
+         return cls()(mode_or_orig_func)
+
+     def __enter__(self) -> None:
+         if _ENABLE_TORCH_INFERENCE_MODE:
+             self._inference_mode_context = torch._C._InferenceMode(self.mode)
+             self._inference_mode_context.__enter__()
+         else:
+             self.prev = torch.is_grad_enabled()
+             torch.set_grad_enabled(False)
+
+     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+         if _ENABLE_TORCH_INFERENCE_MODE:
+             self._inference_mode_context.__exit__(exc_type, exc_value, traceback)
+         else:
+             torch.set_grad_enabled(self.prev)
+
+     def clone(self) -> "DynamicGradMode":
+         r"""
+         Create a copy of this class
+         """
+         if _ENABLE_TORCH_INFERENCE_MODE:
+             return self.__class__(self.mode)
+         else:
+             return self.__class__()
+
+
  def enable_show_time_cost():
      global show_time_cost
      show_time_cost = True
@@ -198,7 +261,7 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
      When distributed is True, the available memory is the minimum available memory of all GPUs.
      """
      if device == "cuda":
-         num_gpus = torch.cuda.device_count()
+         num_gpus = cuda_device_count_stateless()
          assert gpu_id < num_gpus

          if torch.cuda.current_device() != gpu_id:
@@ -443,17 +506,46 @@ def decode_video_base64(video_base64):
      )  # Return an empty array and size tuple if no frames were found


- def load_image(image_file: Union[str, bytes]):
-     from PIL import Image
+ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarray:
+     # Use soundfile here, since librosa use it under the hood,
+     # and librosa will not support audio loading in the future
+     import soundfile as sf
+     from scipy.signal import resample
+
+     # print(f"loading {audio_file}")
+     # Load audio data
+     if isinstance(audio_file, bytes):
+         audio, original_sr = sf.read(BytesIO(audio_file))
+     elif audio_file.startswith("data:"):
+         audio_file = audio_file.split(",")[1]
+         audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
+     elif isinstance(audio_file, str):
+         audio, original_sr = sf.read(audio_file)
+     else:
+         raise ValueError(f"Invalid audio format: {audio_file}")

+     # Resample audio if the original sample rate is different from the desired sample rate
+     if original_sr != sr:
+         num_samples = int(len(audio) * float(sr) / original_sr)
+         audio = resample(audio, num_samples)
+
+     # Convert to mono if requested and audio is stereo
+     if mono and len(audio.shape) > 1:
+         audio = np.mean(audio, axis=1)
+
+     return audio
+
+
+ def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]:
      image = image_size = None

      if isinstance(image_file, bytes):
          image = Image.open(BytesIO(image_file))
      elif image_file.startswith("http://") or image_file.startswith("https://"):
          timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
-         response = requests.get(image_file, timeout=timeout)
-         image = Image.open(BytesIO(response.content))
+         response = requests.get(image_file, stream=True, timeout=timeout).raw
+         image = Image.open(response)
+         response.close()
      elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
          image = Image.open(image_file)
      elif image_file.startswith("data:"):
@@ -471,7 +563,10 @@ def load_image(image_file: Union[str, bytes]):


  def suppress_other_loggers():
-     from vllm.logger import logger as vllm_default_logger
+     try:
+         from vllm.logger import logger as vllm_default_logger
+     except ImportError:
+         return

      vllm_default_logger.setLevel(logging.WARN)
      logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
@@ -480,6 +575,7 @@ def suppress_other_loggers():
      logging.getLogger("vllm.distributed.device_communicators.shm_broadcast").setLevel(
          logging.WARN
      )
+     logging.getLogger("vllm.config").setLevel(logging.ERROR)

      warnings.filterwarnings(
          "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -527,6 +623,10 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N

      if include_parent:
          try:
+             if parent_pid == os.getpid():
+                 itself.kill()
+                 sys.exit(0)
+
              itself.kill()

              # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
@@ -555,11 +655,14 @@ def monkey_patch_p2p_access_check():


  def monkey_patch_vllm_gguf_config():
-     from vllm.model_executor.layers.quantization.gguf import (
-         GGUFConfig,
-         GGUFEmbeddingMethod,
-         GGUFLinearMethod,
-     )
+     try:
+         from vllm.model_executor.layers.quantization.gguf import (
+             GGUFConfig,
+             GGUFEmbeddingMethod,
+             GGUFLinearMethod,
+         )
+     except ImportError:
+         return

      from sglang.srt.layers.linear import LinearBase
      from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -651,6 +754,16 @@ def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):


  def configure_logger(server_args, prefix: str = ""):
+     if SGLANG_LOGGING_CONFIG_PATH := os.getenv("SGLANG_LOGGING_CONFIG_PATH"):
+         if not os.path.exists(SGLANG_LOGGING_CONFIG_PATH):
+             raise Exception(
+                 "Setting SGLANG_LOGGING_CONFIG_PATH from env with "
+                 f"{SGLANG_LOGGING_CONFIG_PATH} but it does not exist!"
+             )
+         with open(SGLANG_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
+             custom_config = json.loads(file.read())
+         logging.config.dictConfig(custom_config)
+         return
      format = f"[%(asctime)s{prefix}] %(message)s"
      # format = f"[%(asctime)s.%(msecs)03d{prefix}] %(message)s"
      logging.basicConfig(
@@ -774,12 +887,22 @@ def get_zmq_socket(
          buf_size = -1

      socket = context.socket(socket_type)
-     if socket_type == zmq.PUSH:
+
+     def set_send_opt():
          socket.setsockopt(zmq.SNDHWM, 0)
          socket.setsockopt(zmq.SNDBUF, buf_size)
-     elif socket_type == zmq.PULL:
+
+     def set_recv_opt():
          socket.setsockopt(zmq.RCVHWM, 0)
          socket.setsockopt(zmq.RCVBUF, buf_size)
+
+     if socket_type == zmq.PUSH:
+         set_send_opt()
+     elif socket_type == zmq.PULL:
+         set_recv_opt()
+     elif socket_type == zmq.DEALER:
+         set_send_opt()
+         set_recv_opt()
      else:
          raise ValueError(f"Unsupported socket type: {socket_type}")

@@ -910,6 +1033,13 @@ def get_amdgpu_memory_capacity():
          )


+ def get_device_sm():
+     if torch.cuda.is_available():
+         major, minor = torch.cuda.get_device_capability()
+         return major * 10 + minor
+     return 0
+
+
  def get_nvgpu_memory_capacity():
      try:
          # Run nvidia-smi and capture the output
@@ -1246,11 +1376,6 @@ def set_gpu_proc_affinity(
      logger.info(f"Process {pid} gpu_id {gpu_id} is running on CPUs: {p.cpu_affinity()}")


- def get_bool_env_var(name: str, default: str = "false") -> bool:
-     value = os.getenv(name, default)
-     return value.lower() in ("true", "1")
-
-
  @lru_cache(maxsize=2)
  def disable_request_logging() -> bool:
      return get_bool_env_var("SGLANG_DISABLE_REQUEST_LOGGING")
@@ -1477,6 +1602,7 @@ def get_ip() -> str:
  def get_open_port() -> int:
      port = os.getenv("SGLANG_PORT")
      if port is not None:
+         port = int(port)
          while True:
              try:
                  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -1505,6 +1631,38 @@ def is_valid_ipv6_address(address: str) -> bool:
          return False


+ def configure_ipv6(dist_init_addr):
+     addr = dist_init_addr
+     end = addr.find("]")
+     if end == -1:
+         raise ValueError("invalid IPv6 address format: missing ']'")
+
+     host = addr[: end + 1]
+
+     # this only validates the address without brackets: we still need the below checks.
+     # if it's invalid, immediately raise an error so we know it's not formatting issues.
+     if not is_valid_ipv6_address(host[1:end]):
+         raise ValueError(f"invalid IPv6 address: {host}")
+
+     port_str = None
+     if len(addr) > end + 1:
+         if addr[end + 1] == ":":
+             port_str = addr[end + 2 :]
+         else:
+             raise ValueError("received IPv6 address format: expected ':' after ']'")
+
+     if not port_str:
+         raise ValueError(
+             "a port must be specified in IPv6 address (format: [ipv6]:port)"
+         )
+
+     try:
+         port = int(port_str)
+     except ValueError:
+         raise ValueError(f"invalid port in IPv6 address: '{port_str}'")
+     return port, host
+
+
  def rank0_print(msg: str):
      from sglang.srt.distributed import get_tensor_model_parallel_rank

@@ -1561,6 +1719,16 @@ def next_power_of_2(n: int):
      setattr(triton, "next_power_of_2", next_power_of_2)


+ @contextmanager
+ def empty_context(*args, **kwargs):
+     try:
+         # Setup code goes here
+         yield
+     finally:
+         # Cleanup code goes here
+         pass
+
+
  def add_prefix(name: str, prefix: str) -> str:
      """Add a weight path prefix to a module name.

@@ -1572,3 +1740,29 @@ def add_prefix(name: str, prefix: str) -> str:
          The string `prefix.name` if prefix is non-empty, otherwise just `name`.
      """
      return name if not prefix else f"{prefix}.{name}"
+
+
+ def is_remote_url(url: Union[str, Path]) -> bool:
+     """
+     Check if the URL is a remote URL of the format:
+     <connector_type>://<host>:<port>/<model_name>
+     """
+     if isinstance(url, Path):
+         return False
+
+     pattern = r"(.+)://(.*)"
+     m = re.match(pattern, url)
+     return m is not None
+
+
+ def parse_connector_type(url: str) -> str:
+     """
+     Parse the connector type from the URL of the format:
+     <connector_type>://<path>
+     """
+     pattern = r"(.+)://(.*)"
+     m = re.match(pattern, url)
+     if m is None:
+         return ""
+
+     return m.group(1)
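As a usage sketch (not part of the diff itself), this is how the new connector-URL helpers at the end of utils.py behave, assuming sglang 0.4.4.post3 is installed; the bucket and host names are made up for illustration:

    from sglang.srt.utils import is_remote_url, parse_connector_type

    # is_remote_url() only checks for a "<scheme>://" prefix; local paths return False.
    assert is_remote_url("s3://my-bucket/my-model") is True
    assert is_remote_url("/local/path/to/model") is False

    # parse_connector_type() returns the scheme, or "" when the URL has no scheme.
    assert parse_connector_type("redis://host:6379/my-model") == "redis"
    assert parse_connector_type("not-a-url") == ""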
File without changes
File without changes
sglang/test/attention/test_flashattn_backend.py ADDED
@@ -0,0 +1,312 @@
+ import unittest
+
+ import torch
+
+ from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+ from sglang.test.test_utils import CustomTestCase
+
+
+ class MockModelRunner:
+     model_config = type(
+         "ModelConfig", (), {"context_len": 2048, "is_multimodal": False}
+     )
+     sliding_window_size = None
+
+     def __init__(self, device="cuda"):
+         self.device = device
+         # Create a proper req_to_token_pool with the req_to_token attribute
+         self.req_to_token_pool = type(
+             "TokenPool",
+             (),
+             {
+                 "size": 160,  # a typical max_bs * max_context_len for cuda graph decode
+                 "req_to_token": torch.zeros(
+                     160, 2048, dtype=torch.int32, device=device
+                 ),  # Add req_to_token attribute
+             },
+         )
+
+
+ class MockReqToTokenPool:
+     def __init__(self, batch_size, seq_len, device):
+         self.req_to_token = (
+             torch.arange(batch_size * seq_len, device=device)
+             .reshape(batch_size, seq_len)
+             .to(torch.int32)
+         )
+
+
+ @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
+ class TestFlashAttentionBackend(CustomTestCase):
+     def setUp(self):
+         """Set up test fixtures before each test method."""
+         self.model_runner = MockModelRunner()
+         self.backend = FlashAttentionBackend(self.model_runner)
+
+         # Common test parameters
+         self.batch_size = 2
+         self.seq_len = 4
+         self.num_heads = 2
+         self.head_dim = 8
+         self.device = "cuda"
+         self.dtype = torch.float16
+
+     def _create_attention_layer(self):
+         """Helper method to create an attention layer."""
+         return RadixAttention(
+             num_heads=self.num_heads,
+             head_dim=self.head_dim,
+             scaling=1.0,
+             num_kv_heads=self.num_heads,
+             layer_id=0,
+         )
+
+     def _create_kv_pool(self, size):
+         """Helper method to create a KV pool."""
+         return MHATokenToKVPool(
+             size=size,
+             page_size=1,  # only consider page=1 for unit test
+             dtype=self.dtype,
+             head_num=self.num_heads,
+             head_dim=self.head_dim,
+             layer_num=1,  # only consider layer=1 for unit test
+             device=self.device,
+             enable_memory_saver=False,
+         )
+
+     def _create_qkv_tensors(self, tokens_len):
+         """Helper method to create q, k, v tensors."""
+         return (
+             torch.randn(
+                 tokens_len,
+                 self.num_heads,
+                 self.head_dim,
+                 dtype=self.dtype,
+                 device=self.device,
+             ),
+             torch.randn(
+                 tokens_len,
+                 self.num_heads,
+                 self.head_dim,
+                 dtype=self.dtype,
+                 device=self.device,
+             ),
+             torch.randn(
+                 tokens_len,
+                 self.num_heads,
+                 self.head_dim,
+                 dtype=self.dtype,
+                 device=self.device,
+             ),
+         )
+
+     def _verify_output(self, output, expected_shape):
+         """Helper method to verify output."""
+         self.assertEqual(
+             output.shape,
+             expected_shape,
+             f"Expected shape {expected_shape}, got {output.shape}",
+         )
+         self.assertEqual(output.dtype, self.dtype)
+         self.assertEqual(output.device.type, "cuda")
+         self.assertEqual(
+             torch.isnan(output).sum().item(), 0, "Output contains NaN values"
+         )
+
+     def test_forward_extend(self):
+         """Test the standard extend operation."""
+         # Create test inputs
+         q, k, v = self._create_qkv_tensors(self.batch_size * self.seq_len)
+
+         # Create attention layer
+         layer = self._create_attention_layer()
+
+         # Create forward batch
+         forward_batch = ForwardBatch(
+             batch_size=self.batch_size,
+             input_ids=torch.randint(
+                 0, 100, (self.batch_size, self.seq_len), device=self.device
+             ),
+             out_cache_loc=torch.arange(
+                 self.batch_size * self.seq_len, device=self.device
+             ),
+             seq_lens_sum=self.batch_size * self.seq_len,
+             forward_mode=ForwardMode.EXTEND,
+             req_pool_indices=torch.arange(self.batch_size, device=self.device),
+             seq_lens=torch.tensor([self.seq_len] * self.batch_size, device=self.device),
+             # 0 prefix, 4 extend
+             extend_prefix_lens=torch.tensor([0] * self.batch_size, device=self.device),
+             extend_seq_lens=torch.tensor([4] * self.batch_size, device=self.device),
+             attn_backend=self.backend,
+         )
+
+         # Add token pool and KV cache
+         forward_batch.req_to_token_pool = MockReqToTokenPool(
+             self.batch_size, self.seq_len, self.device
+         )
+         forward_batch.token_to_kv_pool = self._create_kv_pool(
+             self.batch_size * self.seq_len
+         )
+
+         # Initialize forward metadata before running the attention
+         self.backend.init_forward_metadata(forward_batch)
+
+         # Run forward_extend
+         output = self.backend.forward_extend(q, k, v, layer, forward_batch)
+
+         # Verify output
+         expected_shape = (
+             self.batch_size * self.seq_len,
+             self.num_heads * self.head_dim,
+         )
+         self._verify_output(output, expected_shape)
+
+     def test_forward_decode(self):
+         """Test the decode operation with cached tokens."""
+         # For decode, we only have one token per sequence
+         decode_len = 1
+         curr_seq_len = self.seq_len + decode_len
+
+         # Create test inputs
+         q, k, v = self._create_qkv_tensors(self.batch_size * decode_len)
+
+         # Create attention layer
+         layer = self._create_attention_layer()
+
+         # Create forward batch
+         forward_batch = ForwardBatch(
+             batch_size=self.batch_size,
+             input_ids=torch.randint(
+                 0, 100, (self.batch_size, decode_len), device=self.device
+             ),
+             out_cache_loc=torch.arange(
+                 self.batch_size * self.seq_len,
+                 self.batch_size * curr_seq_len,
+                 device=self.device,
+             ),
+             seq_lens_sum=self.batch_size * curr_seq_len,
+             forward_mode=ForwardMode.DECODE,
+             req_pool_indices=torch.arange(self.batch_size, device=self.device),
+             seq_lens=torch.tensor([curr_seq_len] * self.batch_size, device=self.device),
+             attn_backend=self.backend,
+         )
+
+         # Add token pool and KV cache
+         forward_batch.req_to_token_pool = MockReqToTokenPool(
+             self.batch_size, curr_seq_len, self.device
+         )
+         forward_batch.token_to_kv_pool = self._create_kv_pool(
+             self.batch_size * curr_seq_len
+         )
+
+         # Pre-fill KV cache
+         cache_k, cache_v, _ = self._create_qkv_tensors(self.batch_size * self.seq_len)
+         forward_batch.token_to_kv_pool.set_kv_buffer(
+             layer,
+             torch.arange(self.batch_size * self.seq_len, device=self.device),
+             cache_k,
+             cache_v,
+             layer.k_scale,
+             layer.v_scale,
+         )
+
+         # Initialize forward metadata before running the attention
+         self.backend.init_forward_metadata(forward_batch)
+
+         # Run forward_decode
+         output = self.backend.forward_decode(q, k, v, layer, forward_batch)
+
+         # Verify output
+         expected_shape = (self.batch_size, self.num_heads * self.head_dim)
+         self._verify_output(output, expected_shape)
+
+     def test_forward_extend_with_prefix(self):
+         """Test extending from cached prefix tokens."""
+         # Define prefix and extend lengths
+         prefix_len = 2
+         extend_len = 2
+         total_len = prefix_len + extend_len
+
+         # Create test inputs for the extend portion
+         q, k, v = self._create_qkv_tensors(self.batch_size * extend_len)
+
+         # Create attention layer
+         layer = self._create_attention_layer()
+
+         # Create forward batch
+         forward_batch = ForwardBatch(
+             batch_size=self.batch_size,
+             input_ids=torch.randint(
+                 0, 100, (self.batch_size, extend_len), device=self.device
+             ),
+             out_cache_loc=torch.arange(
+                 self.batch_size * prefix_len,
+                 self.batch_size * total_len,
+                 device=self.device,
+             ),
+             seq_lens_sum=self.batch_size * total_len,
+             forward_mode=ForwardMode.EXTEND,
+             req_pool_indices=torch.arange(self.batch_size, device=self.device),
+             seq_lens=torch.tensor([total_len] * self.batch_size, device=self.device),
+             extend_prefix_lens=torch.tensor(
+                 [prefix_len] * self.batch_size, device=self.device
+             ),
+             extend_seq_lens=torch.tensor(
+                 [extend_len] * self.batch_size, device=self.device
+             ),
+             attn_backend=self.backend,
+         )
+
+         # Add token pool and KV cache
+         forward_batch.req_to_token_pool = MockReqToTokenPool(
+             self.batch_size, total_len, self.device
+         )
+         forward_batch.token_to_kv_pool = self._create_kv_pool(
+             self.batch_size * total_len
+         )
+
+         # Pre-fill the KV cache for prefix with known values
+         cache_k = torch.ones(
+             self.batch_size * prefix_len,
+             self.num_heads,
+             self.head_dim,
+             dtype=self.dtype,
+             device=self.device,
+         )
+         cache_v = (
+             torch.ones(
+                 self.batch_size * prefix_len,
+                 self.num_heads,
+                 self.head_dim,
+                 dtype=self.dtype,
+                 device=self.device,
+             )
+             * 2
+         )
+
+         # Set the prefix KV cache
+         forward_batch.token_to_kv_pool.set_kv_buffer(
+             layer,
+             torch.arange(self.batch_size * prefix_len, device=self.device),
+             cache_k,
+             cache_v,
+             layer.k_scale,
+             layer.v_scale,
+         )
+
+         # Initialize forward metadata before running the attention
+         self.backend.init_forward_metadata(forward_batch)
+
+         # Run forward_extend
+         output = self.backend.forward_extend(q, k, v, layer, forward_batch)
+
+         # Verify output
+         expected_shape = (self.batch_size * extend_len, self.num_heads * self.head_dim)
+         self._verify_output(output, expected_shape)
+
+
+ if __name__ == "__main__":
+     unittest.main()
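As a usage sketch (not part of the diff), the new test class is skipped unless CUDA is available; one way to run it, assuming the 0.4.4.post3 wheel and a CUDA GPU are present:

    import unittest

    # Load and run every test in the new module; the skipIf decorator
    # silently skips them on machines without CUDA.
    suite = unittest.defaultTestLoader.loadTestsFromName(
        "sglang.test.attention.test_flashattn_backend"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)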