sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py CHANGED
@@ -13,6 +13,8 @@
 # ==============================================================================
 """Common utilities."""
 
+from __future__ import annotations
+
 import base64
 import builtins
 import ctypes
@@ -40,6 +42,7 @@ import threading
 import time
 import traceback
 import warnings
+from collections import OrderedDict, defaultdict
 from contextlib import contextmanager
 from enum import Enum
 from functools import lru_cache
@@ -94,35 +97,6 @@ time_infos = {}
 
 HIP_FP8_E4M3_FNUZ_MAX = 224.0
 
-_warned_bool_env_var_keys = set()
-
-
-def get_bool_env_var(name: str, default: str = "false") -> bool:
-    value = os.getenv(name, default)
-    value = value.lower()
-
-    truthy_values = ("true", "1")
-    falsy_values = ("false", "0")
-
-    if (value not in truthy_values) and (value not in falsy_values):
-        if value not in _warned_bool_env_var_keys:
-            logger.warning(
-                f"get_bool_env_var({name}) see non-understandable value={value} and treat as false"
-            )
-            _warned_bool_env_var_keys.add(value)
-
-    return value in truthy_values
-
-
-def get_int_env_var(name: str, default: int = 0) -> int:
-    value = os.getenv(name)
-    if value is None or not value.strip():
-        return default
-    try:
-        return int(value)
-    except ValueError:
-        return default
-
 
 # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
 def is_hip() -> bool:
@@ -173,6 +147,82 @@ def is_cpu() -> bool:
     return os.getenv("SGLANG_USE_CPU_ENGINE", "0") == "1" and is_host_cpu_x86()
 
 
+def get_cuda_version():
+    if torch.version.cuda:
+        return tuple(map(int, torch.version.cuda.split(".")))
+    return (0, 0)
+
+
+def _check(cc_major):
+    if not is_cuda():
+        return False
+    return torch.cuda.get_device_capability()[0] == cc_major and tuple(
+        map(int, torch.version.cuda.split(".")[:2])
+    ) >= (12, 3)
+
+
+is_ampere_with_cuda_12_3 = lambda: _check(8)
+is_hopper_with_cuda_12_3 = lambda: _check(9)
+
+
+def is_blackwell():
+    if not is_cuda():
+        return False
+    return torch.cuda.get_device_capability()[0] == 10
+
+
+_warned_bool_env_var_keys = set()
+
+
+def get_bool_env_var(name: str, default: str = "false") -> bool:
+    value = os.getenv(name, default)
+    value = value.lower()
+
+    truthy_values = ("true", "1")
+    falsy_values = ("false", "0")
+
+    if (value not in truthy_values) and (value not in falsy_values):
+        if value not in _warned_bool_env_var_keys:
+            logger.warning(
+                f"get_bool_env_var({name}) see non-understandable value={value} and treat as false"
+            )
+            _warned_bool_env_var_keys.add(value)
+
+    return value in truthy_values
+
+
+def get_int_env_var(name: str, default: int = 0) -> int:
+    value = os.getenv(name)
+    if value is None or not value.strip():
+        return default
+    try:
+        return int(value)
+    except ValueError:
+        return default
+
+
+def support_triton(backend: str) -> bool:
+    return backend not in ["torch_native", "intel_amx"]
+
+
+try:
+    import sgl_kernel
+
+    is_intel_amx_backend_available = hasattr(
+        torch.ops.sgl_kernel, "convert_weight_packed"
+    )
+except:
+    is_intel_amx_backend_available = False
+
+
+def cpu_has_amx_support():
+    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
+
+
+def use_intel_amx_backend(layer):
+    return getattr(layer, "use_intel_amx_backend", False)
+
+
 def is_flashinfer_available():
     """
     Check whether flashinfer is available.
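The relocated helpers keep their previous behavior: get_bool_env_var treats only "true"/"1" as truthy and warns once on unrecognized values, get_int_env_var falls back to its default on unset or invalid input, and cpu_has_amx_support requires both CPU AMX tiles and the sgl_kernel packed-weight op. A minimal usage sketch; the SGLANG_ENABLE_FOO / SGLANG_FOO_RETRIES variables are made up for illustration and are not real sglang options:

    import os

    from sglang.srt.utils import get_bool_env_var, get_int_env_var, support_triton

    os.environ["SGLANG_ENABLE_FOO"] = "1"  # hypothetical flag, for illustration only
    enable_foo = get_bool_env_var("SGLANG_ENABLE_FOO")  # True, "1" is truthy
    retries = get_int_env_var("SGLANG_FOO_RETRIES", 3)  # 3: unset/invalid falls back to the default
    assert support_triton("triton") and not support_triton("intel_amx")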
@@ -500,6 +550,46 @@ def set_random_seed(seed: int) -> None:
         torch.cuda.manual_seed_all(seed)
 
 
+def find_process_using_port(port: int) -> Optional[psutil.Process]:
+    for conn in psutil.net_connections(kind="inet"):
+        if conn.laddr.port == port:
+            try:
+                return psutil.Process(conn.pid)
+            except psutil.NoSuchProcess:
+                # It could happen by race condition (the proc dies when psutil.Process is called).
+                pass
+
+    return None
+
+
+def wait_port_available(
+    port: int, port_name: str, timeout_s: int = 30, raise_exception: bool = True
+) -> bool:
+    for i in range(timeout_s):
+        if is_port_available(port):
+            return True
+
+        if i > 10 and i % 5 == 0:
+            process = find_process_using_port(port)
+            if process is None:
+                logger.warning(
+                    f"The port {port} is in use, but we could not find the process that uses it."
+                )
+
+            pid = process.pid
+            error_message = f"{port_name} is used by a process already. {process.name()=}' {process.cmdline()=} {process.status()=} {pid=}"
+            logger.info(
+                f"port {port} is in use. Waiting for {i} seconds for {port_name} to be available. {error_message}"
+            )
+        time.sleep(0.1)
+
+    if raise_exception:
+        raise ValueError(
+            f"{port_name} at {port} is not available in {timeout_s} seconds. {error_message}"
+        )
+    return False
+
+
 def is_port_available(port):
     """Return whether a port is available."""
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
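A sketch of how a launcher might combine the new port helpers before binding an HTTP server; the port number and port name are arbitrary examples, not sglang defaults:

    from sglang.srt.utils import find_process_using_port, is_port_available, wait_port_available

    port = 30000  # example port
    if not is_port_available(port):
        proc = find_process_using_port(port)
        print(f"port {port} busy, held by: {proc.name() if proc else 'unknown'}")

    # Poll for up to 30 s; raises ValueError if the port never frees up.
    wait_port_available(port, port_name="http_server_port", timeout_s=30)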
@@ -514,6 +604,19 @@ def is_port_available(port):
             return False
 
 
+def get_free_port():
+    # try ipv4
+    try:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+    except OSError:
+        # try ipv6
+        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+
+
 def decode_video_base64(video_base64):
     from PIL import Image
 
@@ -816,6 +919,7 @@ def maybe_set_triton_cache_manager() -> None:
 class CustomCacheManager(FileCacheManager):
     # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
     def __init__(self, key, override=False, dump=False):
+        from sglang.srt.distributed.parallel_state import get_tp_group
 
         self.key = key
         self.lock_path = None
@@ -833,7 +937,10 @@ class CustomCacheManager(FileCacheManager):
                 os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
             )
             if self.cache_dir:
-                self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
+                try:
+                    self.cache_dir = f"{self.cache_dir}_{get_tp_group().local_rank}"
+                except:
+                    self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
                 self.cache_dir = os.path.join(self.cache_dir, self.key)
                 self.lock_path = os.path.join(self.cache_dir, "lock")
                 os.makedirs(self.cache_dir, exist_ok=True)
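The cache directory is now keyed by the tensor-parallel local rank instead of the PID, so Triton's JIT cache stays reusable across restarts; the PID suffix remains as a fallback when the TP group is not yet initialized. A rough standalone sketch of the naming logic (the helper and paths are illustrative, not sglang API):

    import os

    def triton_cache_dir(base_dir: str, local_rank=None) -> str:
        # Prefer the tensor-parallel local rank; fall back to the PID as before.
        suffix = local_rank if local_rank is not None else os.getpid()
        return f"{base_dir}_{suffix}"

    print(triton_cache_dir("/root/.triton/cache", local_rank=0))  # /root/.triton/cache_0
    print(triton_cache_dir("/root/.triton/cache"))                # /root/.triton/cache_<pid>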
@@ -997,36 +1104,48 @@ def point_to_point_pyobj(
     src: int = 0,
     dst: int = 1,
 ):
-    """Send data from src to dst in group."""
+    """Send data from src to dst in group using DeviceToDevice communication."""
 
     if rank == src:
         if len(data) == 0:
-            tensor_size = torch.tensor([0], dtype=torch.long)
+            tensor_size = torch.tensor(
+                [0], dtype=torch.long, device=torch.cuda.current_device()
+            )
             dist.send(tensor_size, dst=dst, group=group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
             tensor_data = torch.ByteTensor(
                 np.frombuffer(serialized_data, dtype=np.uint8)
+            ).cuda(
+                device=torch.cuda.current_device()
+            )  # Move to GPU
+            tensor_size = torch.tensor(
+                [size], dtype=torch.long, device=torch.cuda.current_device()
             )
-            tensor_size = torch.tensor([size], dtype=torch.long)
 
             dist.send(tensor_size, dst=dst, group=group)
             dist.send(tensor_data, dst=dst, group=group)
         return data
 
     elif rank == dst:
-        tensor_size = torch.tensor([0], dtype=torch.long)
+        tensor_size = torch.tensor(
+            [0], dtype=torch.long, device=torch.cuda.current_device()
+        )
         dist.recv(tensor_size, src=src, group=group)
         size = tensor_size.item()
 
         if size == 0:
            return []
 
-        tensor_data = torch.empty(size, dtype=torch.uint8)
+        tensor_data = torch.empty(
+            size, dtype=torch.uint8, device=torch.cuda.current_device()
+        )
        dist.recv(tensor_data, src=src, group=group)
 
-        serialized_data = bytes(tensor_data.cpu().numpy())
+        serialized_data = bytes(
+            tensor_data.cpu().numpy()
+        )  # Move back to host for deserialization
         data = pickle.loads(serialized_data)
         return data
 
@@ -1428,6 +1547,15 @@ def is_habana_available() -> bool:
 
 @lru_cache(maxsize=8)
 def get_device(device_id: Optional[int] = None) -> str:
+    if is_cpu():
+        if cpu_has_amx_support():
+            logger.info("Intel AMX is detected, using CPU with Intel AMX support.")
+        else:
+            logger.warning(
+                "CPU device enabled, using torch native backend, low performance expected."
+            )
+        return "cpu"
+
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         if device_id is None:
             return "cuda"
@@ -1456,15 +1584,6 @@ def get_device(device_id: Optional[int] = None) -> str:
             "Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'."
         )
 
-    if is_cpu():
-        if cpu_has_amx_support():
-            logger.info("Intel AMX is detected, using CPU with Intel AMX support.")
-        else:
-            logger.warning(
-                "CPU device enabled, using torch native backend, low performance expected."
-            )
-        return "cpu"
-
     raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.")
 
 
@@ -1917,20 +2036,11 @@ def configure_ipv6(dist_init_addr):
     return port, host
 
 
-def rank0_print(msg: str):
+def rank0_log(msg: str):
     from sglang.srt.distributed import get_tensor_model_parallel_rank
 
     if get_tensor_model_parallel_rank() == 0:
-        print(msg, flush=True)
-
-
-rank0_log = rank0_print
-
-
-def get_cuda_version():
-    if torch.version.cuda:
-        return tuple(map(int, torch.version.cuda.split(".")))
-    return (0, 0)
+        logger.info(msg)
 
 
 def launch_dummy_health_check_server(host, port):
@@ -2092,14 +2202,14 @@ class DeepEPMode(Enum):
     def enable_low_latency(self):
         return self in [DeepEPMode.low_latency, DeepEPMode.auto]
 
-    def resolve(self, forward_mode):
+    def resolve(self, is_extend_in_batch: bool):
         if self != DeepEPMode.auto:
             return self
 
-        if forward_mode.is_decode():
-            return DeepEPMode.low_latency
-        else:
+        if is_extend_in_batch:
             return DeepEPMode.normal
+        else:
+            return DeepEPMode.low_latency
 
 
 def is_non_idle_and_non_empty(forward_mode, hidden_states):
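DeepEPMode.resolve now keys off whether the batch contains extend (prefill) tokens instead of inspecting a forward-mode object. A small sketch of the new dispatch, using the DeepEPMode enum defined in sglang.srt.utils above:

    from sglang.srt.utils import DeepEPMode

    mode = DeepEPMode.auto
    assert mode.resolve(is_extend_in_batch=True) == DeepEPMode.normal        # prefill-heavy batch
    assert mode.resolve(is_extend_in_batch=False) == DeepEPMode.low_latency  # pure decode batch
    # Non-auto modes are returned unchanged.
    assert DeepEPMode.normal.resolve(is_extend_in_batch=False) == DeepEPMode.normal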
@@ -2119,35 +2229,12 @@ def fast_topk(values, topk, dim):
     return torch.topk(values, topk, dim=dim)
 
 
-def _check(cc_major):
-    if not is_cuda():
-        return False
-    return torch.cuda.get_device_capability()[0] == cc_major and tuple(
-        map(int, torch.version.cuda.split(".")[:2])
-    ) >= (12, 3)
-
-
-is_ampere_with_cuda_12_3 = lambda: _check(8)
-is_hopper_with_cuda_12_3 = lambda: _check(9)
-
-
-def is_blackwell():
-    if not is_cuda():
-        return False
-    return torch.cuda.get_device_capability()[0] == 10
-
-
-def get_free_port():
-    # try ipv4
-    try:
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            s.bind(("", 0))
-            return s.getsockname()[1]
-    except OSError:
-        # try ipv6
-        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
-            s.bind(("", 0))
-            return s.getsockname()[1]
+def bind_or_assign(target, source):
+    if target is not None:
+        target.copy_(source)
+        return target
+    else:
+        return source
 
 
 def get_local_ip_auto() -> str:
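bind_or_assign either copies the source into a preallocated buffer or adopts the source tensor when no buffer is given. A minimal sketch:

    import torch

    from sglang.srt.utils import bind_or_assign

    src = torch.arange(4, dtype=torch.float32)

    buf = torch.zeros(4)
    out = bind_or_assign(buf, src)   # copies into buf and returns it
    assert out is buf and torch.equal(buf, src)

    out = bind_or_assign(None, src)  # no target: the source is returned directly
    assert out is src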
@@ -2344,45 +2431,6 @@ def require_mlp_sync(server_args):
     return server_args.enable_dp_attention or require_gathered_buffer(server_args)
 
 
-def merge_bias_tensor(
-    lhs: Optional[torch.Tensor],
-    rhs: Optional[torch.Tensor],
-    bs1: int,
-    bs2: int,
-    device: str,
-    default: float,
-):
-    """Merge two bias tensors for batch merging.
-
-    Args:
-        lhs: Left-hand side tensor
-        rhs: Right-hand side tensor
-        bs1: Batch size of left-hand side tensor
-        bs2: Batch size of right-hand side tensor
-        device: Device to place the merged tensor on
-        default: Default value for missing tensor elements
-
-    Returns:
-        Merged tensor or None if both inputs are None
-    """
-    if lhs is None and rhs is None:
-        return None
-
-    if lhs is not None and rhs is not None:
-        return torch.cat([lhs, rhs])
-    else:
-        if lhs is not None:
-            shape, dtype = lhs.shape[1:], lhs.dtype
-        else:
-            shape, dtype = rhs.shape[1:], rhs.dtype
-
-        if lhs is None:
-            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
-        if rhs is None:
-            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
-        return torch.cat([lhs, rhs])
-
-
 def find_local_repo_dir(repo_id: str, revision: Optional[str] = None) -> Optional[str]:
     import huggingface_hub as hf
 
@@ -2439,22 +2487,75 @@ def bind_or_assign(target, source):
         return source
 
 
-def support_triton(backend: str) -> bool:
-    return backend not in ["torch_native", "intel_amx"]
-
-
-try:
-    import sgl_kernel
-
-    is_intel_amx_backend_available = hasattr(
-        torch.ops.sgl_kernel, "convert_weight_packed"
-    )
-except:
-    is_intel_amx_backend_available = False
-
-
-def cpu_has_amx_support():
-    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
+def prepack_weight_if_needed(weight):
+    if weight.device != torch.device("cpu"):
+        return weight
+    if not cpu_has_amx_support():
+        return weight
+
+    return torch.ops.sgl_kernel.convert_weight_packed(weight)
+
+
+# TODO: currently gemm kernel has the below requirements:
+# OC % TILE_N == 0, where TILE_N = 16
+# IC % TILE_K == 0, where TILE_K = 32
+def dim_is_supported(weight):
+    return weight.size(0) % 16 == 0 and weight.size(1) % 32 == 0
+
+
+def _process_weight_after_loading(module, weight_names, transpose_dims=None) -> None:
+    # Pack weight for get better performance on CPU
+    devices = {getattr(module, weight_name).device for weight_name in weight_names}
+    assert len(devices) == 1, f"Expects all weights to be on the same device"
+    device = devices.pop()
+
+    if transpose_dims:
+        assert len(weight_names) == len(
+            transpose_dims
+        ), "len(weight_names) should be equal to len(transpose_dims)"
+
+    for i, weight_name in enumerate(weight_names):
+        weight_tensor = getattr(module, weight_name)
+
+        # We don't pack weight or use intel amx backend if any weight of this module has unsupported dim.
+        if not dim_is_supported(weight_tensor):
+            logger.warning(
+                f"Expects weight.size(0) % 16 == 0 and weight.size(1) % 32 == 0 "
+                f"but {weight_tensor.size(0)=} and {weight_tensor.size(1)=} in {module}. "
+                f"{module} won't use intel amx backend."
+            )
+            module.use_intel_amx_backend = False
+            return
+
+        if transpose_dims and transpose_dims[i]:
+            weight_tensor = weight_tensor.transpose(*transpose_dims[i])
+
+        packed_weight = torch.nn.Parameter(
+            prepack_weight_if_needed(weight_tensor),
+            requires_grad=False,
+        )
+        packed_weight.__dict__ = weight_tensor.__dict__
+        setattr(module, weight_name, packed_weight)
+
+    module.use_intel_amx_backend = (
+        device == torch.device("cpu") and cpu_has_amx_support()
+    )
+
+    if (
+        module.use_intel_amx_backend
+        and hasattr(module, "bias")
+        and module.bias is not None
+    ):
+        module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)
+
+
+class PackWeightMethod:
+    def __init__(self, weight_names, transpose_dims=None):
+        self.weight_names = weight_names
+        self.transpose_dims = transpose_dims
+
+    def process_weights_after_loading(self, module) -> None:
+        _process_weight_after_loading(module, self.weight_names, self.transpose_dims)
 
 
 class LazyValue:
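A sketch of how a CPU linear module could opt into the AMX weight-prepacking path via PackWeightMethod; the ToyLinear module is hypothetical and chosen so that its dimensions satisfy the gemm kernel requirement noted above (OC % 16 == 0, IC % 32 == 0). On machines without AMX or sgl_kernel the weights are simply left unpacked:

    import torch

    from sglang.srt.utils import PackWeightMethod

    class ToyLinear(torch.nn.Module):  # hypothetical module, for illustration only
        def __init__(self, in_features=64, out_features=32):
            super().__init__()
            self.weight = torch.nn.Parameter(
                torch.randn(out_features, in_features, dtype=torch.bfloat16),
                requires_grad=False,
            )
            self.bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=False)

    layer = ToyLinear()
    packer = PackWeightMethod(weight_names=["weight"])
    packer.process_weights_after_loading(layer)
    print(layer.use_intel_amx_backend)  # True only when cpu_has_amx_support() holds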
@@ -2506,3 +2607,133 @@ def configure_gc_logger():
     )
 
     gc.callbacks.append(gc_callback)
+
+
+# COPIED FROM DeepGEMM
+def align(x: int, y: int) -> int:
+    return ceil_div(x, y) * y
+
+
+# COPIED FROM DeepGEMM
+def ceil_div(x: int, y: int) -> int:
+    return (x + y - 1) // y
+
+
+def parse_lscpu_topology():
+    try:
+        # Get CPU topology: CPU,Core,Socket,Node
+        output = subprocess.check_output(
+            ["lscpu", "-p=CPU,Core,Socket,Node"], text=True
+        )
+    except Exception as e:
+        raise RuntimeError(f"Unexpected error running 'lscpu': {e}")
+
+    # Parse only data lines (skip comments)
+    cpu_info = []
+    for line in output.splitlines():
+        if not line.startswith("#"):
+            cpu, core, socket, node = map(int, line.strip().split(","))
+            cpu_info.append((cpu, core, socket, node))
+
+    # [(0,0,0,0),(1,1,0,0),...,(43,43,0,1),...,(256,0,0,0),...]
+    return cpu_info
+
+
+def get_physical_cpus_by_numa():
+    cpu_info = parse_lscpu_topology()
+
+    # Map NUMA node -> set of (core_id, socket) to avoid duplicates
+    # 0: {(0,0): 0, (1, 0): 1,...}
+    # ...
+    # 5: {(214,1): 214, (215,1): 215}
+    physical_by_node = defaultdict(dict)  # node -> core_id -> cpu_id
+
+    for cpu, core, socket, node in cpu_info:
+        key = (core, socket)
+        if key not in physical_by_node[node]:
+            physical_by_node[node][
+                key
+            ] = cpu  # pick first CPU seen for that physical core
+
+    # Retrieves CPUs that the current process is allowed to run on
+    cpus_allowed_list = psutil.Process().cpu_affinity()
+
+    # Convert to list of physical CPUs per node
+    # 0: [0,1,2,...,42]
+    # ...
+    # 2: [86,87,...,127]
+    # ...
+    # 5: [214,215,...,255]
+    node_to_cpus = {}
+    for node, core_to_cpu in physical_by_node.items():
+        cpus = sorted(core_to_cpu.values())
+        allowed_cpus = set(cpus).intersection(cpus_allowed_list)
+        node_to_cpus[node] = allowed_cpus
+
+    return node_to_cpus
+
+
+# Only physical cores are used. Logical cores are excluded.
+def get_cpu_ids_by_node():
+    node_to_cpus = get_physical_cpus_by_numa()
+    # Sort by NUMA node index
+    cpu_ids = [
+        ",".join(map(str, sorted(node_to_cpus[node]))) for node in sorted(node_to_cpus)
+    ]
+
+    # ['0,1,2,3', '4,5,6,7', '8,9,10,11', '12,13,14,15', '16,17,18,19', '20,21,22,23']
+    return cpu_ids
+
+
+def is_shm_available(dtype, world_size, local_size):
+    return (
+        cpu_has_amx_support()
+        and dtype in [torch.bfloat16, torch.float]
+        and world_size >= 1
+        and world_size == local_size
+    )
+
+
+def lru_cache_frozenset(maxsize=128):
+    def _to_hashable(o):
+        try:
+            hash(o)
+            return o
+        except TypeError:
+            # Not hashable; convert based on type
+            if isinstance(o, (dict)):
+                return frozenset(
+                    (_to_hashable(k), _to_hashable(v)) for k, v in o.items()
+                )
+            elif isinstance(o, set):
+                return frozenset(_to_hashable(v) for v in o)
+            elif isinstance(o, (list, tuple)) or (
+                isinstance(o, Sequence) and not isinstance(o, (str, bytes))
+            ):
+                return tuple(_to_hashable(v) for v in o)
+            else:
+                raise TypeError(f"Cannot make hashable: {type(o)}")
+
+    def decorator(func):
+        cache = OrderedDict()
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            h_args = tuple(_to_hashable(a) for a in args)
+            h_kwargs = frozenset(
+                (_to_hashable(k), _to_hashable(v)) for k, v in kwargs.items()
+            )
+            key = (h_args, h_kwargs)
+            if key in cache:
+                cache.move_to_end(key)
+                return cache[key]
+            result = func(*args, **kwargs)
+            cache[key] = result
+            if maxsize is not None and len(cache) > maxsize:
+                cache.popitem(last=False)
+            return result
+
+        wrapper.cache_clear = cache.clear  # For manual cache clearing
+        return wrapper
+
+    return decorator
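lru_cache_frozenset behaves like functools.lru_cache but also accepts unhashable arguments (dicts, lists, sets) by converting them to frozensets and tuples before building the cache key. A small sketch:

    from sglang.srt.utils import lru_cache_frozenset

    calls = []

    @lru_cache_frozenset(maxsize=2)
    def resolve_config(overrides):
        calls.append(overrides)
        return sorted(overrides.items())

    resolve_config({"tp": 2, "dtype": "bf16"})
    resolve_config({"dtype": "bf16", "tp": 2})  # same dict contents -> cache hit, not recomputed
    assert len(calls) == 1
    resolve_config.cache_clear()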
sglang/srt/warmup.py CHANGED
@@ -4,6 +4,7 @@ from typing import List
 import numpy as np
 import tqdm
 
+from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 
@@ -20,17 +21,21 @@ def warmup(name: str) -> callable:
     return decorator
 
 
-async def execute_warmups(warmup_names: List[str], tokenizer_manager: TokenizerManager):
+async def execute_warmups(
+    disaggregation_mode: str,
+    warmup_names: List[str],
+    tokenizer_manager: TokenizerManager,
+):
     for warmup_name in warmup_names:
         if warmup_name not in _warmup_registry:
             logger.warning(f"Could not find custom warmup {warmup_name}")
             continue
         logger.info(f"Running warmup {warmup_name}")
-        await _warmup_registry[warmup_name](tokenizer_manager)
+        await _warmup_registry[warmup_name](disaggregation_mode, tokenizer_manager)
 
 
 @warmup("voice_chat")
-async def voice_chat(tokenizer_manager: TokenizerManager):
+async def voice_chat(disaggregation_mode: str, tokenizer_manager: TokenizerManager):
     # this warms up the fused_moe triton kernels and caches them
     # if we don't do this we break real time inference for voice chat
     for i in tqdm.trange(1, 512):
@@ -44,4 +49,8 @@ async def voice_chat(tokenizer_manager: TokenizerManager):
             "min_p": 0.0,
         },
     )
+    if disaggregation_mode != "null":
+        generate_req_input.bootstrap_room = 0
+        generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST
+
     await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
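The warmup entry point now threads the disaggregation mode through to each registered warmup, so prefill/decode-split deployments can tag warmup requests with the fake bootstrap host. A sketch of the updated call; the tokenizer_manager object and the server_args fields (disaggregation_mode, warmups) are assumed to come from the caller's server setup and are not defined here:

    import asyncio

    from sglang.srt.warmup import execute_warmups

    asyncio.run(
        execute_warmups(
            disaggregation_mode=server_args.disaggregation_mode,  # "null" when disaggregation is off
            warmup_names=server_args.warmups.split(","),          # e.g. ["voice_chat"]
            tokenizer_manager=tokenizer_manager,
        )
    )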