sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (83)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/bench_one_batch.py +3 -0
  3. sglang/srt/configs/__init__.py +8 -0
  4. sglang/srt/configs/model_config.py +4 -0
  5. sglang/srt/configs/step3_vl.py +172 -0
  6. sglang/srt/conversation.py +23 -0
  7. sglang/srt/disaggregation/decode.py +2 -8
  8. sglang/srt/disaggregation/launch_lb.py +5 -20
  9. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  10. sglang/srt/disaggregation/prefill.py +2 -6
  11. sglang/srt/distributed/parallel_state.py +86 -1
  12. sglang/srt/entrypoints/engine.py +14 -18
  13. sglang/srt/entrypoints/http_server.py +10 -2
  14. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  15. sglang/srt/eplb/expert_distribution.py +5 -0
  16. sglang/srt/eplb/expert_location.py +17 -6
  17. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  18. sglang/srt/eplb/expert_location_updater.py +2 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/step3_detector.py +436 -0
  21. sglang/srt/hf_transformers_utils.py +2 -0
  22. sglang/srt/jinja_template_utils.py +4 -1
  23. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  24. sglang/srt/layers/attention/utils.py +6 -1
  25. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +39 -674
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
  29. sglang/srt/layers/quantization/fp8.py +52 -18
  30. sglang/srt/layers/quantization/unquant.py +0 -8
  31. sglang/srt/layers/quantization/w4afp8.py +1 -0
  32. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  33. sglang/srt/managers/cache_controller.py +165 -67
  34. sglang/srt/managers/data_parallel_controller.py +2 -0
  35. sglang/srt/managers/io_struct.py +0 -2
  36. sglang/srt/managers/scheduler.py +90 -671
  37. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  38. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  39. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  40. sglang/srt/managers/template_manager.py +62 -19
  41. sglang/srt/managers/tokenizer_manager.py +123 -74
  42. sglang/srt/managers/tp_worker.py +4 -0
  43. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  44. sglang/srt/mem_cache/hicache_storage.py +60 -17
  45. sglang/srt/mem_cache/hiradix_cache.py +36 -8
  46. sglang/srt/mem_cache/memory_pool.py +15 -118
  47. sglang/srt/mem_cache/memory_pool_host.py +418 -29
  48. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  49. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  50. sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
  51. sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
  52. sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
  53. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
  54. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  55. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  56. sglang/srt/model_executor/cuda_graph_runner.py +25 -1
  57. sglang/srt/model_executor/model_runner.py +13 -1
  58. sglang/srt/model_loader/weight_utils.py +2 -0
  59. sglang/srt/models/arcee.py +532 -0
  60. sglang/srt/models/deepseek_v2.py +7 -6
  61. sglang/srt/models/glm4_moe.py +6 -4
  62. sglang/srt/models/granitemoe.py +3 -0
  63. sglang/srt/models/grok.py +3 -0
  64. sglang/srt/models/hunyuan.py +1 -0
  65. sglang/srt/models/llama4.py +3 -0
  66. sglang/srt/models/mixtral.py +3 -0
  67. sglang/srt/models/olmoe.py +3 -0
  68. sglang/srt/models/phimoe.py +1 -0
  69. sglang/srt/models/step3_vl.py +991 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/reasoning_parser.py +2 -1
  73. sglang/srt/server_args.py +49 -18
  74. sglang/srt/speculative/eagle_worker.py +2 -0
  75. sglang/srt/utils.py +1 -0
  76. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  77. sglang/utils.py +0 -11
  78. sglang/version.py +1 -1
  79. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
  80. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
  81. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
  82. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
  83. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py (new file)
@@ -0,0 +1,278 @@
+ import atexit
+ import concurrent.futures
+ import json
+ import logging
+ import os
+ import signal
+ import threading
+ from collections import OrderedDict
+ from functools import wraps
+ from typing import List, Optional
+
+ import torch
+
+ from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+ from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
+
+ logger = logging.getLogger(__name__)
+
+
+ class AtomicCounter:
+     def __init__(self, n: int):
+         assert n > 0
+         self.n = n
+         self._value = 0
+         self._lock = threading.Lock()
+
+     def next(self) -> int:
+         with self._lock:
+             current = self._value
+             self._value = (current + 1) % self.n
+             return current
+
+
+ def synchronized():
+     def _decorator(func):
+         @wraps(func)
+         def wrapper(self, *args, **kwargs):
+             with self.lock:
+                 return func(self, *args, **kwargs)
+
+         return wrapper
+
+     return _decorator
+
+
+ class HiCacheHF3FS(HiCacheStorage):
+     default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH"
+
+     def __init__(
+         self,
+         file_path: str,
+         file_size: int,
+         numjobs: int,
+         bytes_per_page: int,
+         entries: int,
+         dtype: torch.dtype,
+     ):
+         self.file_path = file_path
+         self.file_size = file_size
+         self.numjobs = numjobs
+         self.bytes_per_page = bytes_per_page
+         self.entries = entries
+         self.dtype = dtype
+
+         self.numel = self.bytes_per_page // self.dtype.itemsize
+
+         self.num_pages = self.file_size // self.bytes_per_page
+
+         logger.info(
+             "HiCacheHF3FS "
+             f"file_path = {self.file_path}, "
+             f"file_size = {self.file_size/(2**30):.2f} GB, "
+             f"numjobs = {self.numjobs}, "
+             f"bytes_per_page = {self.bytes_per_page/(2**20):.2f} MB, "
+             f"entries = {self.entries}, "
+             f"num_pages = {self.num_pages}"
+         )
+
+         self.ac = AtomicCounter(self.numjobs)
+         self.clients = [
+             Hf3fsClient(
+                 self.file_path, self.file_size, self.bytes_per_page, self.entries
+             )
+             for _ in range(numjobs)
+         ]
+         self.executor = concurrent.futures.ThreadPoolExecutor(
+             max_workers=self.numjobs, thread_name_prefix="HiCacheHF3FS"
+         )
+
+         # Implemented a preliminary single-file page_hash -> file_offset index as interim storage.
+         # Future iterations may adopt a global KVCache manager to coordinate external cache instances
+         # through centralized metadata orchestration.
+         self.lock = threading.RLock()
+         self.free_pages = list(range(self.num_pages))
+         self.key_to_index = OrderedDict()
+
+         atexit.register(self.close)
+
+         signal.signal(signal.SIGINT, lambda sig, frame: self.close())
+         signal.signal(signal.SIGTERM, lambda sig, frame: self.close())
+         signal.signal(signal.SIGQUIT, lambda sig, frame: self.close())
+
+     @staticmethod
+     def from_env_config(
+         rank: int, bytes_per_page: int, dtype: torch.dtype
+     ) -> "HiCacheHF3FS":
+         config_path = os.getenv(HiCacheHF3FS.default_env_var)
+         if not config_path:
+             return HiCacheHF3FS(
+                 file_path=f"/data/hicache.{rank}.bin",
+                 file_size=1 << 40,
+                 numjobs=16,
+                 bytes_per_page=bytes_per_page,
+                 entries=8,
+                 dtype=dtype,
+             )
+
+         try:
+             with open(config_path, "r") as f:
+                 config = json.load(f)
+         except Exception as e:
+             raise RuntimeError(f"Failed to load config from {config_path}: {str(e)}")
+
+         required_keys = {
+             "file_path_prefix",
+             "file_size",
+             "numjobs",
+             "entries",
+         }
+         missing_keys = required_keys - set(config.keys())
+         if missing_keys:
+             raise ValueError(f"Missing required keys in config: {missing_keys}")
+
+         return HiCacheHF3FS(
+             file_path=f"{config['file_path_prefix']}.{rank}.bin",
+             file_size=int(config["file_size"]),
+             numjobs=int(config["numjobs"]),
+             bytes_per_page=bytes_per_page,
+             entries=int(config["entries"]),
+             dtype=dtype,
+         )
+
+     def get(
+         self, key: str, target_location: Optional[torch.Tensor] = None
+     ) -> torch.Tensor | None:
+         return self.batch_get([key], target_location)[0]
+
+     @synchronized()
+     def batch_get(
+         self,
+         keys: List[str],
+         target_locations: Optional[List[torch.Tensor]] = None,
+     ) -> List[torch.Tensor | None]:
+         batch_indices, file_offsets = [], []
+         for i, key in enumerate(keys):
+             if key not in self.key_to_index:
+                 continue
+             batch_indices.append(i)
+             file_offsets.append(self.key_to_index[key] * self.bytes_per_page)
+             self.key_to_index.move_to_end(key)
+         # TODO: target_locations
+         file_results = [
+             torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
+         ]
+
+         futures = [
+             self.executor.submit(
+                 self.clients[self.ac.next()].batch_read,
+                 file_offsets[i : i + self.entries],
+                 file_results[i : i + self.entries],
+             )
+             for i in range(0, len(batch_indices), self.entries)
+         ]
+         read_results = [result for future in futures for result in future.result()]
+
+         results = [None] * len(keys)
+         for batch_index, file_result, read_result in zip(
+             batch_indices, file_results, read_results
+         ):
+             if read_result == self.bytes_per_page:
+                 results[batch_index] = file_result
+             else:
+                 logger.error(f"HiCacheHF3FS get {keys[batch_index]} failed")
+
+         return results
+
+     def set(self, key: str, value: torch.Tensor) -> bool:
+         return self.batch_set([key], [value])
+
+     def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+         indices = self.get_batch_set_indices(keys)
+         batch_indices, file_offsets, file_values = [], [], []
+         for i, (value, (is_written, index)) in enumerate(zip(values, indices)):
+             if is_written or index == -1:
+                 continue
+             batch_indices.append(i)
+             file_offsets.append(index * self.bytes_per_page)
+             file_values.append(value.contiguous())
+
+         futures = [
+             self.executor.submit(
+                 self.clients[self.ac.next()].batch_write,
+                 file_offsets[i : i + self.entries],
+                 file_values[i : i + self.entries],
+             )
+             for i in range(0, len(batch_indices), self.entries)
+         ]
+         write_results = [
+             result == self.bytes_per_page
+             for future in futures
+             for result in future.result()
+         ]
+
+         results = [index[0] for index in indices]
+         for batch_index, write_result in zip(batch_indices, write_results):
+             key = keys[batch_index]
+             index = indices[batch_index][1]
+             if write_result:
+                 self.key_to_index[key] = index
+                 self.key_to_index.move_to_end(key)
+             else:
+                 logger.error(f"HiCacheHF3FS set {key} failed")
+                 self.free_pages.append(index)
+             results[batch_index] = write_result
+         return all(results)
+
+     @synchronized()
+     def get_batch_set_indices(self, keys: List[str]) -> list:
+         ionum = len(keys)
+         # results: tuples of (is_written: bool, page_idx: int)
+         # - is_written: True = hit (no I/O), False = write (miss)
+         # - page_idx: page storing data
+         results = [None] * min(ionum, self.num_pages)
+         if ionum > self.num_pages:
+             results.extend([(False, -1)] * (ionum - self.num_pages))
+
+         new_keys = []
+         for batch_index, key in enumerate(keys[: self.num_pages]):
+             if key in self.key_to_index:
+                 results[batch_index] = (True, self.key_to_index[key])
+                 self.key_to_index.move_to_end(key)
+             else:
+                 new_keys.append((batch_index, key))
+
+         for batch_index, _ in new_keys:
+             index = (
+                 self.free_pages.pop()
+                 if len(self.free_pages) > 0
+                 else self.key_to_index.popitem(last=False)[1]
+             )
+             results[batch_index] = (False, index)
+
+         return results
+
+     @synchronized()
+     def delete(self, key: str) -> None:
+         if key not in self.key_to_index:
+             return
+         index = self.key_to_index.pop(key)
+         self.free_pages.append(index)
+
+     @synchronized()
+     def exists(self, key: str) -> bool:
+         return key in self.key_to_index
+
+     @synchronized()
+     def clear(self) -> None:
+         self.free_pages = list(range(self.num_pages))
+         self.key_to_index.clear()
+
+     def close(self) -> None:
+         try:
+             for c in self.clients:
+                 c.close()
+             self.executor.shutdown(wait=True)
+         except Exception as e:
+             logger.error(f"close HiCacheHF3FS: {e}")
+         logger.info("close HiCacheHF3FS")
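Note: `from_env_config` above falls back to per-rank defaults under `/data` when `SGLANG_HICACHE_HF3FS_CONFIG_PATH` is unset, and otherwise reads a small JSON file. A minimal sketch of wiring that up; only the env var and key names come from the code above, while the paths and sizes are illustrative:

    # Sketch: point HiCacheHF3FS at a JSON config through its env var.
    # Key names mirror required_keys in from_env_config; values are illustrative.
    import json
    import os

    config = {
        "file_path_prefix": "/mnt/hf3fs/hicache",  # per-rank file: <prefix>.<rank>.bin
        "file_size": 1 << 40,  # bytes of backing storage per rank
        "numjobs": 16,  # number of parallel Hf3fsClient instances
        "entries": 8,  # pages per batched read/write handed to one client
    }
    with open("/tmp/hf3fs_config.json", "w") as f:
        json.dump(config, f)
    os.environ["SGLANG_HICACHE_HF3FS_CONFIG_PATH"] = "/tmp/hf3fs_config.json"

    # The cache layer would then construct the backend roughly like:
    # storage = HiCacheHF3FS.from_env_config(
    #     rank=0, bytes_per_page=2 << 20, dtype=torch.bfloat16
    # )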
sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py (new file)
@@ -0,0 +1,43 @@
+ import multiprocessing.shared_memory
+ from pathlib import Path
+
+ import pytest
+ import torch
+ from torch.utils.cpp_extension import load
+ from tqdm import tqdm
+
+ root = Path(__file__).parent.resolve()
+ hf3fs_utils = load(
+     name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"], verbose=True
+ )
+
+
+ def test_rw_shm():
+     numel = 8 << 20
+     dtype = torch.bfloat16
+     page_num = 128
+     page_bytes = numel * dtype.itemsize
+     shm = multiprocessing.shared_memory.SharedMemory(
+         size=page_num * page_bytes, create=True
+     )
+     tshm = torch.frombuffer(shm.buf, dtype=torch.uint8)
+     a = [
+         torch.randn(numel, dtype=dtype)
+         for _ in tqdm(range(page_num), desc="prepare input")
+     ]
+     b = [
+         torch.empty(numel, dtype=dtype)
+         for _ in tqdm(range(page_num), desc="prepare output")
+     ]
+     hf3fs_utils.write_shm(a, tshm)
+     hf3fs_utils.read_shm(tshm, b)
+     for _a, _b in tqdm(zip(a, b), desc="assert_close"):
+         torch.testing.assert_close(_a, _b)
+
+     del tshm
+     shm.close()
+     shm.unlink()
+
+
+ if __name__ == "__main__":
+     pytest.main([__file__])
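This test JIT-compiles `hf3fs_utils.cpp` from the same directory through `torch.utils.cpp_extension.load`, so it needs a working C++ toolchain at runtime. A sketch of invoking it in-process (equivalent to the file's `__main__` block; the path assumes a source checkout):

    # Run the shared-memory round-trip test; the hf3fs_utils extension
    # is compiled on first use.
    import pytest

    pytest.main(["sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py", "-v"])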
sglang/srt/model_executor/cuda_graph_runner.py
@@ -16,6 +16,7 @@
  from __future__ import annotations

  import bisect
+ import gc
  import inspect
  import logging
  import os
@@ -75,6 +76,24 @@ def model_capture_mode():
      is_capture_mode = False


+ @contextmanager
+ def freeze_gc(enable_cudagraph_gc: bool):
+     """
+     Optimize garbage collection during CUDA graph capture.
+     Clean up, then freeze all remaining objects from being included
+     in future collections if GC is disabled during capture.
+     """
+     gc.collect()
+     should_freeze = not enable_cudagraph_gc
+     if should_freeze:
+         gc.freeze()
+     try:
+         yield
+     finally:
+         if should_freeze:
+             gc.unfreeze()
+
+
  def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
      for sub in model._modules.values():
          if isinstance(sub, CustomOp):
@@ -423,7 +442,12 @@ class CudaGraphRunner:
              record_shapes=True,
          )

-         with graph_capture() as graph_capture_context:
+         # Trigger CUDA graph capture for specific shapes.
+         # Capture the large shapes first so that the smaller shapes
+         # can reuse the memory pool allocated for the large shapes.
+         with freeze_gc(
+             self.model_runner.server_args.enable_cudagraph_gc
+         ), graph_capture() as graph_capture_context:
              with profile_context as prof:
                  self.stream = graph_capture_context.stream
                  avail_mem = get_available_gpu_memory(
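The `enable_cudagraph_gc` field referenced above is the opt-out: by default, capture now runs with the heap frozen (`gc.freeze()` moves surviving objects into a permanent generation that collections skip), which avoids GC pauses while graphs are recorded. A hedged sketch of toggling it from the offline engine, assuming `Engine` forwards server-arg fields as keyword arguments the way it does for other flags:

    # Sketch: re-enable Python GC during CUDA graph capture.
    # enable_cudagraph_gc is taken from the diff above; forwarding it
    # through sglang.Engine kwargs is an assumption about the entry point.
    import sglang as sgl

    llm = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model
        enable_cudagraph_gc=True,  # default keeps GC frozen during capture
    )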
sglang/srt/model_executor/model_runner.py
@@ -157,6 +157,8 @@ class ModelRunner:
          gpu_id: int,
          tp_rank: int,
          tp_size: int,
+         moe_ep_rank: int,
+         moe_ep_size: int,
          pp_rank: int,
          pp_size: int,
          nccl_port: int,
@@ -175,6 +177,8 @@ class ModelRunner:
          logger.addFilter(RankZeroFilter(tp_rank == 0))
          self.tp_rank = tp_rank
          self.tp_size = tp_size
+         self.moe_ep_rank = moe_ep_rank
+         self.moe_ep_size = moe_ep_size
          self.dp_size = server_args.dp_size
          self.pp_rank = pp_rank
          self.pp_size = pp_size
@@ -432,6 +436,7 @@ class ModelRunner:
              "triton",
              "flashmla",
              "cutlass_mla",
+             "trtllm_mla",
              "ascend",
          ]:
              logger.info(
@@ -549,6 +554,7 @@ class ModelRunner:
          initialize_model_parallel(
              tensor_model_parallel_size=self.tp_size,
              pipeline_model_parallel_size=self.pp_size,
+             expert_model_parallel_size=self.moe_ep_size,
              duplicate_tp_group=self.server_args.enable_pdmux,
          )
          initialize_dp_attention(
@@ -666,7 +672,7 @@ class ModelRunner:
              self.sliding_window_size = self.model.get_attention_sliding_window_size()
          elif self.model_config.attention_chunk_size is not None:
              self.sliding_window_size = self.model_config.attention_chunk_size
-             print(
+             logger.info(
                  f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}"
              )

@@ -1432,6 +1438,12 @@ class ModelRunner:
              )

              return CutlassMLABackend(self)
+         elif self.server_args.attention_backend == "trtllm_mla":
+             if not self.use_mla_backend:
+                 raise ValueError("trtllm_mla backend can only be used with MLA models.")
+             from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+             return TRTLLMMLABackend(self)
          elif self.server_args.attention_backend == "intel_amx":
              from sglang.srt.layers.attention.intel_amx_backend import (
                  IntelAMXAttnBackend,
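Taken together, these hunks register `trtllm_mla` as a selectable attention backend (MLA models only, per the `ValueError` above) and thread a MoE expert-parallel rank/size into model-parallel initialization. A hedged usage sketch, assuming the backend and expert-parallel size are exposed through the usual `attention_backend` and `ep_size` server arguments:

    # Sketch: select the new TRT-LLM MLA backend on an MLA model.
    # The model name and parallel sizes are illustrative; ep_size feeding
    # moe_ep_size is an assumption about the server-args plumbing.
    import sglang as sgl

    llm = sgl.Engine(
        model_path="deepseek-ai/DeepSeek-V3",  # a DeepSeek-style MLA model
        attention_backend="trtllm_mla",
        tp_size=8,
        ep_size=8,  # expert parallelism for the MoE layers
    )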
sglang/srt/model_loader/weight_utils.py
@@ -229,6 +229,8 @@ def get_quant_config(
                  f"Unsupported quantization config"
                  f" found for {model_config.quantization} in {f}."
              )
+     elif model_config.quantization == "w8a8_int8":
+         config["packed_modules_mapping"] = packed_modules_mapping

      return quant_cls.from_config(config)
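For context, `packed_modules_mapping` conventionally maps a fused module name to the per-projection names it packs, letting the quantization config resolve checkpoint layer names; the exact contents come from the model class, not from this loader. An illustrative shape following the common convention:

    # Illustrative packed_modules_mapping (common convention; not taken
    # verbatim from the sglang sources).
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }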