sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py CHANGED
@@ -13,7 +13,8 @@
 # ==============================================================================
 """Common utilities."""
 
-import base64
+from __future__ import annotations
+
 import builtins
 import ctypes
 import dataclasses
@@ -40,6 +41,7 @@ import threading
 import time
 import traceback
 import warnings
+from collections import OrderedDict, defaultdict
 from contextlib import contextmanager
 from enum import Enum
 from functools import lru_cache
@@ -65,6 +67,7 @@ from typing import (
 
 import numpy as np
 import psutil
+import pybase64
 import requests
 import torch
 import torch.distributed
@@ -80,12 +83,7 @@ from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
-from triton.runtime.cache import (
-    FileCacheManager,
-    default_cache_dir,
-    default_dump_dir,
-    default_override_dir,
-)
+from triton.runtime.cache import FileCacheManager
 
 logger = logging.getLogger(__name__)
 
@@ -94,35 +92,6 @@ time_infos = {}
 
 HIP_FP8_E4M3_FNUZ_MAX = 224.0
 
-_warned_bool_env_var_keys = set()
-
-
-def get_bool_env_var(name: str, default: str = "false") -> bool:
-    value = os.getenv(name, default)
-    value = value.lower()
-
-    truthy_values = ("true", "1")
-    falsy_values = ("false", "0")
-
-    if (value not in truthy_values) and (value not in falsy_values):
-        if value not in _warned_bool_env_var_keys:
-            logger.warning(
-                f"get_bool_env_var({name}) see non-understandable value={value} and treat as false"
-            )
-            _warned_bool_env_var_keys.add(value)
-
-    return value in truthy_values
-
-
-def get_int_env_var(name: str, default: int = 0) -> int:
-    value = os.getenv(name)
-    if value is None or not value.strip():
-        return default
-    try:
-        return int(value)
-    except ValueError:
-        return default
-
 
 # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
 def is_hip() -> bool:
@@ -173,6 +142,82 @@ def is_cpu() -> bool:
     return os.getenv("SGLANG_USE_CPU_ENGINE", "0") == "1" and is_host_cpu_x86()
 
 
+def get_cuda_version():
+    if torch.version.cuda:
+        return tuple(map(int, torch.version.cuda.split(".")))
+    return (0, 0)
+
+
+def _check(cc_major):
+    if not is_cuda():
+        return False
+    return torch.cuda.get_device_capability()[0] == cc_major and tuple(
+        map(int, torch.version.cuda.split(".")[:2])
+    ) >= (12, 3)
+
+
+is_ampere_with_cuda_12_3 = lambda: _check(8)
+is_hopper_with_cuda_12_3 = lambda: _check(9)
+
+
+def is_blackwell():
+    if not is_cuda():
+        return False
+    return torch.cuda.get_device_capability()[0] == 10
+
+
+_warned_bool_env_var_keys = set()
+
+
+def get_bool_env_var(name: str, default: str = "false") -> bool:
+    value = os.getenv(name, default)
+    value = value.lower()
+
+    truthy_values = ("true", "1")
+    falsy_values = ("false", "0")
+
+    if (value not in truthy_values) and (value not in falsy_values):
+        if value not in _warned_bool_env_var_keys:
+            logger.warning(
+                f"get_bool_env_var({name}) see non-understandable value={value} and treat as false"
+            )
+            _warned_bool_env_var_keys.add(value)
+
+    return value in truthy_values
+
+
+def get_int_env_var(name: str, default: int = 0) -> int:
+    value = os.getenv(name)
+    if value is None or not value.strip():
+        return default
+    try:
+        return int(value)
+    except ValueError:
+        return default
+
+
+def support_triton(backend: str) -> bool:
+    return backend not in ["torch_native", "intel_amx"]
+
+
+try:
+    import sgl_kernel
+
+    is_intel_amx_backend_available = hasattr(
+        torch.ops.sgl_kernel, "convert_weight_packed"
+    )
+except:
+    is_intel_amx_backend_available = False
+
+
+def cpu_has_amx_support():
+    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
+
+
+def use_intel_amx_backend(layer):
+    return getattr(layer, "use_intel_amx_backend", False)
+
+
 def is_flashinfer_available():
     """
     Check whether flashinfer is available.
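The relocated environment-variable helpers keep their permissive parsing: a value outside "true"/"1"/"false"/"0" logs a warning and is treated as false, and get_int_env_var returns its default when the variable is unset or not an integer. A minimal usage sketch (the variable names are illustrative, not ones sglang defines):

import os

from sglang.srt.utils import get_bool_env_var, get_int_env_var

os.environ["MY_FLAG"] = "1"        # hypothetical variables, for illustration only
os.environ["MY_WORKERS"] = "8"

assert get_bool_env_var("MY_FLAG") is True            # "true"/"1" count as truthy
assert get_bool_env_var("MY_MISSING_FLAG") is False   # unset: default "false"
assert get_int_env_var("MY_WORKERS", 4) == 8
assert get_int_env_var("MY_MISSING_COUNT", 4) == 4    # unset or non-integer: default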
@@ -500,6 +545,46 @@ def set_random_seed(seed: int) -> None:
         torch.cuda.manual_seed_all(seed)
 
 
+def find_process_using_port(port: int) -> Optional[psutil.Process]:
+    for conn in psutil.net_connections(kind="inet"):
+        if conn.laddr.port == port:
+            try:
+                return psutil.Process(conn.pid)
+            except psutil.NoSuchProcess:
+                # It could happen by race condition (the proc dies when psutil.Process is called).
+                pass
+
+    return None
+
+
+def wait_port_available(
+    port: int, port_name: str, timeout_s: int = 30, raise_exception: bool = True
+) -> bool:
+    for i in range(timeout_s):
+        if is_port_available(port):
+            return True
+
+        if i > 10 and i % 5 == 0:
+            process = find_process_using_port(port)
+            if process is None:
+                logger.warning(
+                    f"The port {port} is in use, but we could not find the process that uses it."
+                )
+
+            pid = process.pid
+            error_message = f"{port_name} is used by a process already. {process.name()=}' {process.cmdline()=} {process.status()=} {pid=}"
+            logger.info(
+                f"port {port} is in use. Waiting for {i} seconds for {port_name} to be available. {error_message}"
+            )
+        time.sleep(0.1)
+
+    if raise_exception:
+        raise ValueError(
+            f"{port_name} at {port} is not available in {timeout_s} seconds. {error_message}"
+        )
+    return False
+
+
 def is_port_available(port):
     """Return whether a port is available."""
@@ -514,11 +599,24 @@ def is_port_available(port):
             return False
 
 
+def get_free_port():
+    # try ipv4
+    try:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+    except OSError:
+        # try ipv6
+        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+
+
 def decode_video_base64(video_base64):
     from PIL import Image
 
     # Decode the base64 string
-    video_bytes = base64.b64decode(video_base64)
+    video_bytes = pybase64.b64decode(video_base64, validate=True)
 
     # Placeholder for the start indices of each PNG image
     img_starts = []
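The two networking helpers added above are thin wrappers around psutil and the socket module: wait_port_available polls is_port_available (sleeping 0.1 s per attempt) and either returns a boolean or raises ValueError when raise_exception is set, while get_free_port binds to port 0 (IPv4, then IPv6 on OSError) and returns the ephemeral port the OS picked; the socket is closed on return, so the port is not reserved. A hedged usage sketch (the port number and name are illustrative):

from sglang.srt.utils import find_process_using_port, get_free_port, wait_port_available

port = get_free_port()
print(f"scratch server will try port {port}")   # another process may still grab it first

holder = find_process_using_port(30000)
if holder is not None:
    print(f"port 30000 is held by pid={holder.pid} ({holder.name()})")

# Poll until port 30000 looks free, or raise ValueError when the loop times out.
wait_port_available(30000, port_name="http_server_port", timeout_s=30)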
@@ -604,7 +702,9 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarray:
         audio, original_sr = sf.read(BytesIO(audio_file))
     elif audio_file.startswith("data:"):
         audio_file = audio_file.split(",")[1]
-        audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
+        audio, original_sr = sf.read(
+            BytesIO(pybase64.b64decode(audio_file, validate=True))
+        )
     elif audio_file.startswith("http://") or audio_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
         response = requests.get(audio_file, stream=True, timeout=timeout)
@@ -673,12 +773,12 @@ def load_image(
         image = Image.open(image_file)
     elif image_file.startswith("data:"):
         image_file = image_file.split(",")[1]
-        image = Image.open(BytesIO(base64.b64decode(image_file)))
+        image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
     elif image_file.startswith("video:"):
        image_file = image_file.replace("video:", "")
         image, image_size = decode_video_base64(image_file)
     elif isinstance(image_file, str):
-        image = Image.open(BytesIO(base64.b64decode(image_file)))
+        image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
     else:
         raise ValueError(f"Invalid image: {image}")
 
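Every stdlib base64 call in this file is switched to pybase64 with validate=True, which rejects characters outside the base64 alphabet instead of silently discarding them, so malformed data URLs now fail fast. A sketch of the behavioral difference (the payloads are made up for illustration):

import base64

import pybase64

assert pybase64.b64decode("aGVsbG8=", validate=True) == b"hello"

corrupted = "aGVs bG8="                             # embedded space
assert base64.b64decode(corrupted) == b"hello"      # stdlib default discards junk
try:
    pybase64.b64decode(corrupted, validate=True)
except Exception:                                   # raises on non-alphabet characters
    pass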
@@ -816,24 +916,51 @@ def maybe_set_triton_cache_manager() -> None:
 class CustomCacheManager(FileCacheManager):
     # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
     def __init__(self, key, override=False, dump=False):
+        from sglang.srt.distributed.parallel_state import get_tp_group
 
         self.key = key
         self.lock_path = None
+
+        try:
+            module_path = "triton.runtime.cache"
+            cache_module = importlib.import_module(module_path)
+
+            default_cache_dir = getattr(cache_module, "default_cache_dir", None)
+            default_dump_dir = getattr(cache_module, "default_dump_dir", None)
+            default_override_dir = getattr(cache_module, "default_override_dir", None)
+        except (ModuleNotFoundError, AttributeError) as e:
+            default_cache_dir = None
+            default_dump_dir = None
+            default_override_dir = None
+
         if dump:
-            self.cache_dir = default_dump_dir()
+            self.cache_dir = (
+                default_dump_dir()
+                if default_dump_dir is not None
+                else os.path.join(Path.home(), ".triton", "dump")
+            )
             self.cache_dir = os.path.join(self.cache_dir, self.key)
             self.lock_path = os.path.join(self.cache_dir, "lock")
             os.makedirs(self.cache_dir, exist_ok=True)
         elif override:
-            self.cache_dir = default_override_dir()
+            self.cache_dir = (
+                default_override_dir()
+                if default_override_dir is not None
+                else os.path.join(Path.home(), ".triton", "override")
+            )
             self.cache_dir = os.path.join(self.cache_dir, self.key)
         else:
             # create cache directory if it doesn't exist
-            self.cache_dir = (
-                os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
+            self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or (
+                default_cache_dir()
+                if default_cache_dir is not None
+                else os.path.join(Path.home(), ".triton", "cache")
             )
             if self.cache_dir:
-                self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
+                try:
+                    self.cache_dir = f"{self.cache_dir}_{get_tp_group().local_rank}"
+                except:
+                    self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
                 self.cache_dir = os.path.join(self.cache_dir, self.key)
                 self.lock_path = os.path.join(self.cache_dir, "lock")
                 os.makedirs(self.cache_dir, exist_ok=True)
@@ -997,36 +1124,48 @@ def point_to_point_pyobj(
     src: int = 0,
     dst: int = 1,
 ):
-    """Send data from src to dst in group."""
+    """Send data from src to dst in group using DeviceToDevice communication."""
 
     if rank == src:
         if len(data) == 0:
-            tensor_size = torch.tensor([0], dtype=torch.long)
+            tensor_size = torch.tensor(
+                [0], dtype=torch.long, device=torch.cuda.current_device()
+            )
             dist.send(tensor_size, dst=dst, group=group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
             tensor_data = torch.ByteTensor(
                 np.frombuffer(serialized_data, dtype=np.uint8)
+            ).cuda(
+                device=torch.cuda.current_device()
+            )  # Move to GPU
+            tensor_size = torch.tensor(
+                [size], dtype=torch.long, device=torch.cuda.current_device()
             )
-            tensor_size = torch.tensor([size], dtype=torch.long)
 
             dist.send(tensor_size, dst=dst, group=group)
             dist.send(tensor_data, dst=dst, group=group)
         return data
 
     elif rank == dst:
-        tensor_size = torch.tensor([0], dtype=torch.long)
+        tensor_size = torch.tensor(
+            [0], dtype=torch.long, device=torch.cuda.current_device()
+        )
         dist.recv(tensor_size, src=src, group=group)
         size = tensor_size.item()
 
         if size == 0:
            return []
 
-        tensor_data = torch.empty(size, dtype=torch.uint8)
+        tensor_data = torch.empty(
+            size, dtype=torch.uint8, device=torch.cuda.current_device()
+        )
         dist.recv(tensor_data, src=src, group=group)
 
-        serialized_data = bytes(tensor_data.cpu().numpy())
+        serialized_data = bytes(
+            tensor_data.cpu().numpy()
+        )  # Move back to host for deserialization
         data = pickle.loads(serialized_data)
         return data
 
@@ -1428,6 +1567,15 @@ def is_habana_available() -> bool:
 
 @lru_cache(maxsize=8)
 def get_device(device_id: Optional[int] = None) -> str:
+    if is_cpu():
+        if cpu_has_amx_support():
+            logger.info("Intel AMX is detected, using CPU with Intel AMX support.")
+        else:
+            logger.warning(
+                "CPU device enabled, using torch native backend, low performance expected."
+            )
+        return "cpu"
+
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         if device_id is None:
             return "cuda"
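get_device now checks is_cpu() first and returns "cpu" before probing CUDA/XPU/HPU; the following hunk removes the old CPU branch that ran only after those probes. A sketch under the assumption of an x86 host, since is_cpu() also requires is_host_cpu_x86():

import os

os.environ["SGLANG_USE_CPU_ENGINE"] = "1"   # the toggle read by is_cpu()

from sglang.srt.utils import get_device

print(get_device())                         # "cpu", regardless of installed accelerators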
@@ -1456,15 +1604,6 @@ def get_device(device_id: Optional[int] = None) -> str:
                 "Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'."
             )
 
-    if is_cpu():
-        if cpu_has_amx_support():
-            logger.info("Intel AMX is detected, using CPU with Intel AMX support.")
-        else:
-            logger.warning(
-                "CPU device enabled, using torch native backend, low performance expected."
-            )
-        return "cpu"
-
     raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.")
 
 
@@ -1729,7 +1868,7 @@ class MultiprocessingSerializer:
 
         if output_str:
             # Convert bytes to base64-encoded string
-            output = base64.b64encode(output).decode("utf-8")
+            output = pybase64.b64encode(output).decode("utf-8")
 
         return output
 
@@ -1746,7 +1885,7 @@ class MultiprocessingSerializer:
         """
         if isinstance(data, str):
             # Decode base64 string to bytes
-            data = base64.b64decode(data)
+            data = pybase64.b64decode(data, validate=True)
 
         return ForkingPickler.loads(data)
 
@@ -1917,20 +2056,11 @@ def configure_ipv6(dist_init_addr):
     return port, host
 
 
-def rank0_print(msg: str):
+def rank0_log(msg: str):
     from sglang.srt.distributed import get_tensor_model_parallel_rank
 
     if get_tensor_model_parallel_rank() == 0:
-        print(msg, flush=True)
-
-
-rank0_log = rank0_print
-
-
-def get_cuda_version():
-    if torch.version.cuda:
-        return tuple(map(int, torch.version.cuda.split(".")))
-    return (0, 0)
+        logger.info(msg)
 
 
 def launch_dummy_health_check_server(host, port):
@@ -2092,14 +2222,14 @@ class DeepEPMode(Enum):
     def enable_low_latency(self):
         return self in [DeepEPMode.low_latency, DeepEPMode.auto]
 
-    def resolve(self, forward_mode):
+    def resolve(self, is_extend_in_batch: bool):
         if self != DeepEPMode.auto:
             return self
 
-        if forward_mode.is_decode():
-            return DeepEPMode.low_latency
-        else:
+        if is_extend_in_batch:
             return DeepEPMode.normal
+        else:
+            return DeepEPMode.low_latency
 
 
 def is_non_idle_and_non_empty(forward_mode, hidden_states):
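DeepEPMode.resolve now takes a plain is_extend_in_batch flag instead of the previous forward_mode argument: auto resolves to normal when the batch contains extend (prefill) work and to low_latency otherwise, while explicit modes return themselves. A minimal sketch:

from sglang.srt.utils import DeepEPMode

assert DeepEPMode.auto.resolve(is_extend_in_batch=True) == DeepEPMode.normal
assert DeepEPMode.auto.resolve(is_extend_in_batch=False) == DeepEPMode.low_latency
assert DeepEPMode.normal.resolve(is_extend_in_batch=False) == DeepEPMode.normal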
@@ -2119,35 +2249,12 @@ def fast_topk(values, topk, dim):
         return torch.topk(values, topk, dim=dim)
 
 
-def _check(cc_major):
-    if not is_cuda():
-        return False
-    return torch.cuda.get_device_capability()[0] == cc_major and tuple(
-        map(int, torch.version.cuda.split(".")[:2])
-    ) >= (12, 3)
-
-
-is_ampere_with_cuda_12_3 = lambda: _check(8)
-is_hopper_with_cuda_12_3 = lambda: _check(9)
-
-
-def is_blackwell():
-    if not is_cuda():
-        return False
-    return torch.cuda.get_device_capability()[0] == 10
-
-
-def get_free_port():
-    # try ipv4
-    try:
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            s.bind(("", 0))
-            return s.getsockname()[1]
-    except OSError:
-        # try ipv6
-        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
-            s.bind(("", 0))
-            return s.getsockname()[1]
+def bind_or_assign(target, source):
+    if target is not None:
+        target.copy_(source)
+        return target
+    else:
+        return source
 
 
 def get_local_ip_auto() -> str:
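bind_or_assign either copies source into a preallocated target tensor and returns that target, or hands back source unchanged when no target is given, letting callers reuse persistent buffers where they exist. A small sketch:

import torch

from sglang.srt.utils import bind_or_assign

source = torch.arange(4, dtype=torch.float32)
buffer = torch.zeros(4)

out = bind_or_assign(buffer, source)            # in-place copy, returns the buffer
assert out is buffer and torch.equal(out, source)

assert bind_or_assign(None, source) is source   # no buffer: plain pass-through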
@@ -2344,45 +2451,6 @@ def require_mlp_sync(server_args):
     return server_args.enable_dp_attention or require_gathered_buffer(server_args)
 
 
-def merge_bias_tensor(
-    lhs: Optional[torch.Tensor],
-    rhs: Optional[torch.Tensor],
-    bs1: int,
-    bs2: int,
-    device: str,
-    default: float,
-):
-    """Merge two bias tensors for batch merging.
-
-    Args:
-        lhs: Left-hand side tensor
-        rhs: Right-hand side tensor
-        bs1: Batch size of left-hand side tensor
-        bs2: Batch size of right-hand side tensor
-        device: Device to place the merged tensor on
-        default: Default value for missing tensor elements
-
-    Returns:
-        Merged tensor or None if both inputs are None
-    """
-    if lhs is None and rhs is None:
-        return None
-
-    if lhs is not None and rhs is not None:
-        return torch.cat([lhs, rhs])
-    else:
-        if lhs is not None:
-            shape, dtype = lhs.shape[1:], lhs.dtype
-        else:
-            shape, dtype = rhs.shape[1:], rhs.dtype
-
-        if lhs is None:
-            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
-        if rhs is None:
-            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
-        return torch.cat([lhs, rhs])
-
-
 def find_local_repo_dir(repo_id: str, revision: Optional[str] = None) -> Optional[str]:
     import huggingface_hub as hf
 
@@ -2439,24 +2507,6 @@ def bind_or_assign(target, source):
         return source
 
 
-def support_triton(backend: str) -> bool:
-    return backend not in ["torch_native", "intel_amx"]
-
-
-try:
-    import sgl_kernel
-
-    is_intel_amx_backend_available = hasattr(
-        torch.ops.sgl_kernel, "convert_weight_packed"
-    )
-except:
-    is_intel_amx_backend_available = False
-
-
-def cpu_has_amx_support():
-    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
-
-
 def prepack_weight_if_needed(weight):
     if weight.device != torch.device("cpu"):
         return weight
@@ -2577,3 +2627,133 @@ def configure_gc_logger():
             )
 
     gc.callbacks.append(gc_callback)
+
+
+# COPIED FROM DeepGEMM
+def align(x: int, y: int) -> int:
+    return ceil_div(x, y) * y
+
+
+# COPIED FROM DeepGEMM
+def ceil_div(x: int, y: int) -> int:
+    return (x + y - 1) // y
+
+
+def parse_lscpu_topology():
+    try:
+        # Get CPU topology: CPU,Core,Socket,Node
+        output = subprocess.check_output(
+            ["lscpu", "-p=CPU,Core,Socket,Node"], text=True
+        )
+    except Exception as e:
+        raise RuntimeError(f"Unexpected error running 'lscpu': {e}")
+
+    # Parse only data lines (skip comments)
+    cpu_info = []
+    for line in output.splitlines():
+        if not line.startswith("#"):
+            cpu, core, socket, node = map(int, line.strip().split(","))
+            cpu_info.append((cpu, core, socket, node))
+
+    # [(0,0,0,0),(1,1,0,0),...,(43,43,0,1),...,(256,0,0,0),...]
+    return cpu_info
+
+
+def get_physical_cpus_by_numa():
+    cpu_info = parse_lscpu_topology()
+
+    # Map NUMA node -> set of (core_id, socket) to avoid duplicates
+    # 0: {(0,0): 0, (1, 0): 1,...}
+    # ...
+    # 5: {(214,1): 214, (215,1): 215}
+    physical_by_node = defaultdict(dict)  # node -> core_id -> cpu_id
+
+    for cpu, core, socket, node in cpu_info:
+        key = (core, socket)
+        if key not in physical_by_node[node]:
+            physical_by_node[node][
+                key
+            ] = cpu  # pick first CPU seen for that physical core
+
+    # Retrieves CPUs that the current process is allowed to run on
+    cpus_allowed_list = psutil.Process().cpu_affinity()
+
+    # Convert to list of physical CPUs per node
+    # 0: [0,1,2,...,42]
+    # ...
+    # 2: [86,87,...,127]
+    # ...
+    # 5: [214,215,...,255]
+    node_to_cpus = {}
+    for node, core_to_cpu in physical_by_node.items():
+        cpus = sorted(core_to_cpu.values())
+        allowed_cpus = set(cpus).intersection(cpus_allowed_list)
+        node_to_cpus[node] = allowed_cpus
+
+    return node_to_cpus
+
+
+# Only physical cores are used. Logical cores are excluded.
+def get_cpu_ids_by_node():
+    node_to_cpus = get_physical_cpus_by_numa()
+    # Sort by NUMA node index
+    cpu_ids = [
+        ",".join(map(str, sorted(node_to_cpus[node]))) for node in sorted(node_to_cpus)
+    ]
+
+    # ['0,1,2,3', '4,5,6,7', '8,9,10,11', '12,13,14,15', '16,17,18,19', '20,21,22,23']
+    return cpu_ids
+
+
+def is_shm_available(dtype, world_size, local_size):
+    return (
+        cpu_has_amx_support()
+        and dtype in [torch.bfloat16, torch.float]
+        and world_size >= 1
+        and world_size == local_size
+    )
+
+
+def lru_cache_frozenset(maxsize=128):
+    def _to_hashable(o):
+        try:
+            hash(o)
+            return o
+        except TypeError:
+            # Not hashable; convert based on type
+            if isinstance(o, (dict)):
+                return frozenset(
+                    (_to_hashable(k), _to_hashable(v)) for k, v in o.items()
+                )
+            elif isinstance(o, set):
+                return frozenset(_to_hashable(v) for v in o)
+            elif isinstance(o, (list, tuple)) or (
+                isinstance(o, Sequence) and not isinstance(o, (str, bytes))
+            ):
+                return tuple(_to_hashable(v) for v in o)
+            else:
+                raise TypeError(f"Cannot make hashable: {type(o)}")
+
+    def decorator(func):
+        cache = OrderedDict()
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            h_args = tuple(_to_hashable(a) for a in args)
+            h_kwargs = frozenset(
+                (_to_hashable(k), _to_hashable(v)) for k, v in kwargs.items()
+            )
+            key = (h_args, h_kwargs)
+            if key in cache:
+                cache.move_to_end(key)
+                return cache[key]
+            result = func(*args, **kwargs)
+            cache[key] = result
+            if maxsize is not None and len(cache) > maxsize:
+                cache.popitem(last=False)
+            return result
+
+        wrapper.cache_clear = cache.clear  # For manual cache clearing
+        return wrapper
+
+    return decorator
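lru_cache_frozenset generalizes functools.lru_cache to unhashable arguments by recursively converting dicts, sets, lists, and other sequences into frozensets and tuples before building the cache key; eviction is least-recently-used over an OrderedDict, and wrapper.cache_clear empties the cache. A minimal sketch (the decorated function and its config dict are illustrative):

from sglang.srt.utils import align, ceil_div, lru_cache_frozenset

calls = []

@lru_cache_frozenset(maxsize=2)
def summarize(config):
    # A dict argument would make functools.lru_cache raise TypeError.
    calls.append(config)
    return sorted(config.items())

summarize({"tp": 2, "dp": 1})
summarize({"tp": 2, "dp": 1})   # cache hit: the body runs only once
assert len(calls) == 1

# The DeepGEMM-copied integer helpers round up to a multiple of y.
assert ceil_div(10, 4) == 3
assert align(10, 4) == 12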