sglang 0.5.1.post2__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +79 -53
  3. sglang/bench_serving.py +186 -14
  4. sglang/profiler.py +0 -1
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/longcat_flash.py +104 -0
  7. sglang/srt/configs/model_config.py +12 -0
  8. sglang/srt/connector/__init__.py +1 -1
  9. sglang/srt/connector/base_connector.py +1 -2
  10. sglang/srt/connector/redis.py +2 -2
  11. sglang/srt/connector/serde/__init__.py +1 -1
  12. sglang/srt/connector/serde/safe_serde.py +4 -3
  13. sglang/srt/conversation.py +38 -5
  14. sglang/srt/disaggregation/ascend/conn.py +75 -0
  15. sglang/srt/disaggregation/launch_lb.py +0 -13
  16. sglang/srt/disaggregation/mini_lb.py +33 -8
  17. sglang/srt/disaggregation/prefill.py +1 -1
  18. sglang/srt/distributed/parallel_state.py +24 -14
  19. sglang/srt/entrypoints/engine.py +19 -12
  20. sglang/srt/entrypoints/http_server.py +174 -34
  21. sglang/srt/entrypoints/openai/protocol.py +87 -24
  22. sglang/srt/entrypoints/openai/serving_chat.py +50 -9
  23. sglang/srt/entrypoints/openai/serving_completions.py +15 -0
  24. sglang/srt/eplb/eplb_manager.py +26 -2
  25. sglang/srt/eplb/expert_distribution.py +29 -2
  26. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  27. sglang/srt/function_call/function_call_parser.py +2 -0
  28. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  29. sglang/srt/harmony_parser.py +588 -0
  30. sglang/srt/hf_transformers_utils.py +26 -7
  31. sglang/srt/layers/activation.py +12 -0
  32. sglang/srt/layers/attention/ascend_backend.py +374 -136
  33. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  34. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  35. sglang/srt/layers/attention/flashinfer_mla_backend.py +5 -2
  36. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  38. sglang/srt/layers/communicator.py +1 -2
  39. sglang/srt/layers/layernorm.py +28 -3
  40. sglang/srt/layers/linear.py +3 -2
  41. sglang/srt/layers/logits_processor.py +1 -1
  42. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  43. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  44. sglang/srt/layers/moe/ep_moe/layer.py +13 -13
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/topk.py +35 -12
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
  49. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  50. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
  51. sglang/srt/layers/quantization/fp8.py +2 -1
  52. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  53. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  54. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  55. sglang/srt/layers/quantization/mxfp4.py +25 -27
  56. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  57. sglang/srt/layers/quantization/utils.py +13 -0
  58. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  59. sglang/srt/layers/rotary_embedding.py +28 -1
  60. sglang/srt/layers/sampler.py +29 -5
  61. sglang/srt/layers/utils.py +0 -14
  62. sglang/srt/managers/cache_controller.py +237 -204
  63. sglang/srt/managers/detokenizer_manager.py +48 -2
  64. sglang/srt/managers/io_struct.py +57 -0
  65. sglang/srt/managers/mm_utils.py +5 -1
  66. sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
  67. sglang/srt/managers/scheduler.py +94 -9
  68. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  69. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  70. sglang/srt/managers/tokenizer_manager.py +122 -42
  71. sglang/srt/mem_cache/chunk_cache.py +1 -1
  72. sglang/srt/mem_cache/hicache_storage.py +51 -23
  73. sglang/srt/mem_cache/hiradix_cache.py +87 -71
  74. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  75. sglang/srt/mem_cache/memory_pool.py +77 -14
  76. sglang/srt/mem_cache/memory_pool_host.py +4 -5
  77. sglang/srt/mem_cache/radix_cache.py +6 -4
  78. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  79. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +38 -20
  80. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +87 -82
  81. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  82. sglang/srt/model_executor/model_runner.py +6 -5
  83. sglang/srt/model_loader/loader.py +15 -24
  84. sglang/srt/model_loader/utils.py +12 -0
  85. sglang/srt/models/deepseek_v2.py +38 -13
  86. sglang/srt/models/gpt_oss.py +2 -15
  87. sglang/srt/models/llama_eagle3.py +4 -0
  88. sglang/srt/models/longcat_flash.py +1015 -0
  89. sglang/srt/models/longcat_flash_nextn.py +691 -0
  90. sglang/srt/models/qwen2.py +26 -3
  91. sglang/srt/models/qwen2_5_vl.py +66 -41
  92. sglang/srt/models/qwen2_moe.py +22 -2
  93. sglang/srt/models/transformers.py +1 -1
  94. sglang/srt/multimodal/processors/base_processor.py +4 -2
  95. sglang/srt/reasoning_parser.py +56 -300
  96. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  97. sglang/srt/server_args.py +122 -56
  98. sglang/srt/speculative/eagle_worker.py +28 -8
  99. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  100. sglang/srt/utils.py +73 -5
  101. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  102. sglang/version.py +1 -1
  103. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +7 -6
  104. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +107 -99
  105. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
  106. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
  107. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
@@ -11,12 +11,7 @@ from typing import Any, List, Optional, Tuple
 
 import torch
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    is_dp_attention_enabled,
-)
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
 from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
 
 logger = logging.getLogger(__name__)
@@ -130,6 +125,7 @@ class HiCacheHF3FS(HiCacheStorage):
         entries: int,
         dtype: torch.dtype,
         metadata_client: Hf3fsMetadataInterface,
+        is_mla_model: bool = False,
     ):
         self.rank = rank
         self.file_path = file_path
@@ -139,9 +135,13 @@ class HiCacheHF3FS(HiCacheStorage):
         self.entries = entries
         self.dtype = dtype
         self.metadata_client = metadata_client
-
+        self.is_mla_model = is_mla_model
         self.numel = self.bytes_per_page // self.dtype.itemsize
         self.num_pages = self.file_size // self.bytes_per_page
+        self.skip_backup = False
+        if self.is_mla_model and self.rank != 0:
+            self.skip_backup = True
+            self.rank = 0
 
         logger.info(
             f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
@@ -172,19 +172,16 @@ class HiCacheHF3FS(HiCacheStorage):
 
     @staticmethod
     def from_env_config(
-        bytes_per_page: int, dtype: torch.dtype, rank: int = None
+        bytes_per_page: int,
+        dtype: torch.dtype,
+        storage_config: HiCacheStorageConfig = None,
     ) -> "HiCacheHF3FS":
         from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
             Hf3fsGlobalMetadataClient,
             Hf3fsLocalMetadataClient,
         )
 
-        if rank is None:
-            rank = (
-                get_attention_tp_rank()
-                if is_dp_attention_enabled()
-                else get_tensor_model_parallel_rank()
-            )
+        rank = storage_config.tp_rank if storage_config is not None else 0
 
         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
@@ -217,10 +214,14 @@ class HiCacheHF3FS(HiCacheStorage):
             raise ValueError(f"Missing required keys in config: {missing_keys}")
 
         # Choose metadata client based on configuration
+        is_mla_model = False
        if "metadata_server_url" in config and config["metadata_server_url"]:
             # Use global metadata client to connect to metadata server
             metadata_server_url = config["metadata_server_url"]
             metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
+
+            # Enable MLA optimization only when using the global metadata client
+            is_mla_model = storage_config.is_mla_model if storage_config else False
             logger.info(
                 f"Using global metadata client with server url: {metadata_server_url}"
             )
@@ -230,13 +231,15 @@ class HiCacheHF3FS(HiCacheStorage):
 
         return HiCacheHF3FS(
             rank=rank,
-            file_path=f"{config['file_path_prefix']}.{rank}.bin",
+            # Let all ranks use the same file path for MLA model
+            file_path=f"{config['file_path_prefix']}.{rank if not is_mla_model else 0}.bin",
             file_size=int(config["file_size"]),
             numjobs=int(config["numjobs"]),
             bytes_per_page=bytes_per_page,
             entries=int(config["entries"]),
             dtype=dtype,
             metadata_client=metadata_client,
+            is_mla_model=is_mla_model,
         )
 
     def get(
@@ -320,6 +323,10 @@
         target_locations: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
     ) -> bool:
+        # In MLA backend, only one rank needs to backup the KV cache
+        if self.skip_backup:
+            return True
+
         # Todo: Add prefix block's hash key
         key_with_prefix = [(key, "") for key in keys]
         indices = self.metadata_client.reserve_and_allocate_page_indices(
@@ -371,18 +378,29 @@
 
         return all(results)
 
-    @synchronized()
     def delete(self, key: str) -> None:
         self.metadata_client.delete_keys(self.rank, [key])
 
-    @synchronized()
     def exists(self, key: str) -> bool:
         result = self.metadata_client.exists(self.rank, [key])
         return result[0] if result else False
 
-    @synchronized()
-    def clear(self) -> None:
-        self.metadata_client.clear(self.rank)
+    def batch_exists(self, keys: List[str]) -> int:
+        results = self.metadata_client.exists(self.rank, keys)
+        for i in range(len(keys)):
+            if not results[i]:
+                return i
+
+        return len(keys)
+
+    def clear(self) -> bool:
+        try:
+            self.metadata_client.clear(self.rank)
+            logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to clear HiCacheHF3FS: {e}")
+            return False
 
     def close(self) -> None:
         try:
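
A note on the two HF3FS changes above, with a small illustration. For MLA models, every tensor-parallel rank holds the same KV representation, so the cache file path is shared and only rank 0 actually writes backups; the other ranks take the `skip_backup` fast path and report success. The new `batch_exists()` also switches from per-key results to a prefix count: it returns how many keys from the start of the list are present, so callers can resume at the first miss. The snippet below only illustrates that contract; `prefix_hit_count` is a hypothetical helper, not sglang code.

from typing import List

def prefix_hit_count(exists_flags: List[bool]) -> int:
    # Mirror of the batch_exists() contract: count the leading keys that are
    # already stored and stop at the first miss.
    for i, present in enumerate(exists_flags):
        if not present:
            return i
    return len(exists_flags)

assert prefix_hit_count([True, True, False, True]) == 2  # resume from index 2
assert prefix_hit_count([True, True, True]) == 3         # full hit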
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
@@ -10,24 +10,14 @@ import numpy as np
 import torch
 
 from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
 
 DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB
-DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024  # 128 MB
+DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024  # 16 MB
 
 logger = logging.getLogger(__name__)
 
 
-def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
-    prefix_str = ""
-    if prior_hash:
-        prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
-    current_token_ids_bytes = np.array(token_ids).tobytes()
-    current_hash_object = hashlib.sha256(current_token_ids_bytes)
-    current_hash_hex = current_hash_object.hexdigest()
-    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}"
-
-
 @dataclass
 class MooncakeStoreConfig:
     local_hostname: str
@@ -54,9 +44,8 @@ class MooncakeStoreConfig:
             global_segment_size=config.get(
                 "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
             ),
-            local_buffer_size=config.get(
-                "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
-            ),
+            # Zero copy interface does not need local buffer
+            local_buffer_size=DEFAULT_LOCAL_BUFFER_SIZE,
             protocol=config.get("protocol", "tcp"),
             device_name=config.get("device_name", "auto"),
             master_server_address=config.get("master_server_address"),
@@ -79,9 +68,8 @@ class MooncakeStoreConfig:
             global_segment_size=int(
                 os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE)
             ),
-            local_buffer_size=int(
-                os.getenv("MOONCAKE_LOCAL_BUFFER_SIZE", DEFAULT_LOCAL_BUFFER_SIZE)
-            ),
+            # Zero copy interface does not need local buffer
+            local_buffer_size=DEFAULT_LOCAL_BUFFER_SIZE,
             protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"),
             device_name=os.getenv("MOONCAKE_DEVICE", "auto"),
             master_server_address=os.getenv("MOONCAKE_MASTER"),
@@ -96,7 +84,7 @@ class MooncakeStoreConfig:
 
 
 class MooncakeStore(HiCacheStorage):
-    def __init__(self, is_mla: bool = False):
+    def __init__(self, storage_config: HiCacheStorageConfig = None):
         try:
             from mooncake.store import MooncakeDistributedStore
         except ImportError as e:
@@ -126,7 +114,13 @@ class MooncakeStore(HiCacheStorage):
             logger.info("Connect to Mooncake store successfully.")
             self.warmup()
             logger.info("Mooncake store warmup successfully.")
-            self.is_mla = is_mla
+
+            if storage_config is not None:
+                self.is_mla_backend = storage_config.is_mla_model
+                self.local_rank = storage_config.tp_rank
+            else:
+                self.is_mla_backend = False
+                self.local_rank = 0
 
         except ValueError as e:
             logger.error("Configuration loading failed: %s", e)
@@ -137,12 +131,10 @@
 
     def warmup(self):
         warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex
-        # 10 MB
-        warmup_value = bytes(10 * 1024 * 1024)
-        self.store.put(warmup_key, warmup_value)
+        warmup_value = bytes(4 * 1024)  # 4 KB
+        assert self.store.put(warmup_key, warmup_value) == 0
         assert self.store.is_exist(warmup_key) == 1
-        self.store.get(warmup_key)
-        self.store.remove(warmup_key)
+        assert self.store.get(warmup_key) == warmup_value
 
     def register_buffer(self, buffer: torch.Tensor) -> None:
         try:
@@ -162,78 +154,96 @@ class MooncakeStore(HiCacheStorage):
         target_location: Optional[List[int]] = None,
         target_sizes: Optional[List[int]] = None,
     ) -> bool:
-        assert len(key) == len(target_location) == len(target_sizes)
-        if len(key) == 0:
-            return
-
-        for i in range(len(key)):
-            if key[i] is None or target_location[i] is None or target_sizes[i] is None:
-                return
-
-        self._put_batch_zero_copy_impl(key, target_location, target_sizes)
+        return self.batch_set([key], [value], [target_location], [target_sizes])
 
     def batch_set(
         self,
         keys: List[str],
-        value: Optional[Any] = None,
+        values: Optional[List[torch.Tensor]] = None,
         target_location: Optional[List[int]] = None,
         target_sizes: Optional[List[int]] = None,
     ) -> bool:
         assert len(keys) == len(target_location) == len(target_sizes)
         if len(keys) == 0:
-            return
+            return False
 
         for i in range(len(keys)):
             if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
-                return
+                return False
 
-        self._put_batch_zero_copy_impl(keys, target_location, target_sizes)
+        exist_result = self._batch_exist(keys)
+        set_keys = []
+        set_target_locations = []
+        set_target_sizes = []
+        set_indices = []
+        for i in range(len(keys)):
+            if exist_result[i] != 1:
+                set_keys.append(keys[i])
+                set_target_locations.append(target_location[i])
+                set_target_sizes.append(target_sizes[i])
+                set_indices.append(i)
+        # Only set non-existing keys to storage
+        put_result = self._put_batch_zero_copy_impl(
+            set_keys, set_target_locations, set_target_sizes
+        )
+        for i in range(len(set_indices)):
+            if put_result[i] == 0:
+                exist_result[set_indices[i]] = 1
+
+        success_count = 0
+        for i in range(len(keys)):
+            if exist_result[i] == 0:
+                break
+            success_count += 1
+        # TODO: return the number of consecutive successful operations from the start.
+        return success_count == len(keys)
 
     def get(
         self,
         key,
         target_location: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
-    ) -> torch.Tensor | None:
-        assert len(key) == len(target_location) == len(target_sizes)
-        if len(key) == 0:
-            return
-
-        for i in range(len(key)):
-            if key[i] is None or target_location[i] is None or target_sizes[i] is None:
-                return
-
-        return self._get_batch_zero_copy_impl(key, target_location, target_sizes)
+    ) -> bool:
+        return self.batch_get([key], [target_location], [target_sizes]) == 1
 
     def batch_get(
         self,
         keys: List[str],
         target_location: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
-    ) -> torch.Tensor | None:
+    ) -> int:
         assert len(keys) == len(target_location) == len(target_sizes)
         if len(keys) == 0:
-            return
-
+            return 0
+        get_result = self._get_batch_zero_copy_impl(keys, target_location, target_sizes)
+        if self.is_mla_backend:
+            key_multiplier = 1
+        else:
+            key_multiplier = 2
         for i in range(len(keys)):
-            if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
-                return
-
-        return self._get_batch_zero_copy_impl(keys, target_location, target_sizes)
-
-    def exists(self, keys) -> bool | dict:
-        _keys = []
-        local_rank = get_tensor_model_parallel_rank()
-        for key in keys:
-            if key is None:
-                return None
-
-            if self.is_mla:
-                _keys.append(f"{key}_k")
-            else:
-                _keys.append(f"{key}_{local_rank}_k")
-        result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
-        return result
+            if get_result[i] < 0:
+                return i // key_multiplier
+        return len(keys) // key_multiplier
+
+    def exists(self, key) -> bool:
+        return self.batch_exists([key]) > 0
+
+    def batch_exists(self, keys) -> int:
+        if self.is_mla_backend:
+            query_keys = [f"{key}_k" for key in keys]
+            key_multiplier = 1
+        else:
+            query_keys = []
+            for key in keys:
+                query_keys.append(f"{key}_{self.local_rank}_k")
+                query_keys.append(f"{key}_{self.local_rank}_v")
+            key_multiplier = 2
+
+        exist_result = self._batch_exist(query_keys)
+        for i in range(len(query_keys)):
+            if exist_result[i] != 1:
+                return i // key_multiplier
+        return len(query_keys) // key_multiplier
 
     def delete(self, key) -> None:
         raise (NotImplementedError)
@@ -244,22 +254,17 @@
         pass
 
     def clear(self) -> None:
-        raise (NotImplementedError)
+        self.store.remove_all()
 
     def _put_batch_zero_copy_impl(
         self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
-    ) -> None:
-        try:
-            self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
-        except TypeError as err:
-            logger.error("Failed to put value to Mooncake Store: %s", err)
-            raise TypeError("Mooncake Store Put Type Error.") from err
+    ) -> List[int]:
+        return self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
 
     def _get_batch_zero_copy_impl(
         self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
-    ) -> None:
-        try:
-            self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
-        except TypeError as err:
-            logger.error("Failed to get value from Mooncake Store: %s", err)
-            raise TypeError("Mooncake Store Get Type Error.") from err
+    ) -> List[int]:
+        return self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
+
+    def _batch_exist(self, key_strs: List[str]) -> List[int]:
+        return self.store.batch_is_exist(key_strs)
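
The reworked MooncakeStore API above returns integer prefix counts instead of dicts or tensors, and the hit accounting depends on how KV pages are laid out as store objects. The sketch below only illustrates that key layout (a hypothetical helper, not sglang code): MLA models keep a single shared "_k" object per page, while other models keep per-rank "_k"/"_v" pairs, which is why batch_get and batch_exists divide raw object indices by a key multiplier of 1 or 2.

from typing import List

def build_query_keys(keys: List[str], is_mla_backend: bool, local_rank: int) -> List[str]:
    # Mirrors the MooncakeStore.batch_exists layout: one shared "_k" object per
    # page for MLA, per-rank "_k"/"_v" pairs otherwise (two objects per page).
    if is_mla_backend:
        return [f"{k}_k" for k in keys]
    query_keys = []
    for k in keys:
        query_keys.append(f"{k}_{local_rank}_k")
        query_keys.append(f"{k}_{local_rank}_v")
    return query_keys

assert build_query_keys(["page0"], True, 3) == ["page0_k"]
assert build_query_keys(["page0"], False, 3) == ["page0_3_k", "page0_3_v"]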
sglang/srt/mem_cache/swa_radix_cache.py
@@ -464,7 +464,7 @@ class SWARadixCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
 
-    def cache_unfinished_req(self, req: Req) -> None:
+    def cache_unfinished_req(self, req: Req, chunked=False) -> None:
         """Cache request when it is unfinished."""
         if self.disable:
             kv_indices = self.req_to_token_pool.req_to_token[
sglang/srt/model_executor/model_runner.py
@@ -66,7 +66,6 @@ from sglang.srt.layers.quantization import (
 )
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import (
@@ -121,6 +120,7 @@ from sglang.srt.utils import (
     is_hopper_with_cuda_12_3,
     is_no_spec_infer_or_topk_one,
     is_npu,
+    is_sm100_supported,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cuda_arch,
@@ -307,7 +307,10 @@ class ModelRunner:
         model_num_layers = (
             self.model_config.num_nextn_predict_layers
             if self.is_draft_worker and model_has_mtp_layers
-            else self.model_config.num_hidden_layers
+            else max(
+                self.model_config.num_hidden_layers,
+                self.model_config.num_attention_layers,
+            )
         )
         self.start_layer = getattr(self.model, "start_layer", 0)
         self.end_layer = getattr(self.model, "end_layer", model_num_layers)
@@ -1440,14 +1443,12 @@
             else self.server_args.attention_backend
         )
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
-            assert (
-                self.server_args.speculative_algorithm is None
-            ), "Currently HybridAttentionBackend does not support speculative decoding."
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
             )
 
             attn_backend = HybridAttnBackend(
+                self,
                 decode_backend=self._get_attention_backend_from_str(
                     self.decode_attention_backend_str
                 ),
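
The ModelRunner hunks above drop the assertion that blocked speculative decoding when the prefill and decode attention backends differ, and HybridAttnBackend now receives the model runner as its first argument. The toy class below only sketches the routing idea behind a hybrid backend, with batches dispatched according to the decode_attention_backend_str / prefill_attention_backend_str split shown above; it is not the real HybridAttnBackend interface.

class ToyHybridAttention:
    # Toy sketch, not sglang code: dispatch to one of two attention backends
    # depending on whether the batch is a decode or a prefill/extend batch.
    def __init__(self, decode_backend, prefill_backend):
        self.decode_backend = decode_backend
        self.prefill_backend = prefill_backend

    def forward(self, batch):
        backend = self.decode_backend if batch.is_decode else self.prefill_backend
        return backend.forward(batch)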
sglang/srt/model_loader/loader.py
@@ -42,6 +42,7 @@ from sglang.srt.distributed import (
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
+    post_load_weights,
     set_default_torch_dtype,
 )
 from sglang.srt.model_loader.weight_utils import (
@@ -600,18 +601,7 @@ class DummyModelLoader(BaseModelLoader):
             # random values to the weights.
             initialize_dummy_weights(model)
 
-            # Model weight loading consists of two stages:
-            # 1. Initial weight loading.
-            # 2. Post-processing of weights, including assigning specific member variables.
-            # For `dummy_init`, only the second stage is required.
-            if hasattr(model, "post_load_weights"):
-                if (
-                    model_config.hf_config.architectures[0]
-                    == "DeepseekV3ForCausalLMNextN"
-                ):
-                    model.post_load_weights(is_nextn=True)
-                else:
-                    model.post_load_weights()
+            post_load_weights(model, model_config)
 
         return model.eval()
 
@@ -751,6 +741,9 @@ class ShardedStateLoader(BaseModelLoader):
                     state_dict.pop(key)
             if state_dict:
                 raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
+
+            post_load_weights(model, model_config)
+
             return model.eval()
 
     @staticmethod
@@ -1421,18 +1414,16 @@ class RemoteModelLoader(BaseModelLoader):
                 # ignore hidden files
                 if file_name.startswith("."):
                     continue
-                if os.path.splitext(file_name)[1] not in (
-                    ".bin",
-                    ".pt",
-                    ".safetensors",
-                ):
+                if os.path.splitext(file_name)[1] in (".json", ".py"):
                     file_path = os.path.join(root, file_name)
                     with open(file_path, encoding="utf-8") as file:
                         file_content = file.read()
                         f_key = f"{model_name}/files/{file_name}"
                         client.setstr(f_key, file_content)
 
-    def _load_model_from_remote_kv(self, model: nn.Module, client):
+    def _load_model_from_remote_kv(
+        self, model: nn.Module, model_config: ModelConfig, client
+    ):
         for _, module in model.named_modules():
             quant_method = getattr(module, "quant_method", None)
             if quant_method is not None:
@@ -1460,6 +1451,8 @@
         if state_dict:
             raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
 
+        post_load_weights(model, model_config)
+
     def _load_model_from_remote_fs(
         self, model, client, model_config: ModelConfig, device_config: DeviceConfig
     ) -> nn.Module:
@@ -1501,15 +1494,13 @@
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config)
-                for _, module in model.named_modules():
-                    quant_method = getattr(module, "quant_method", None)
-                    if quant_method is not None:
-                        quant_method.process_weights_after_loading(module)
 
-        with create_remote_connector(model_weights, device_config.device) as client:
+        with create_remote_connector(
+            model_weights, device=device_config.device
+        ) as client:
             connector_type = get_connector_type(client)
             if connector_type == ConnectorType.KV:
-                self._load_model_from_remote_kv(model, client)
+                self._load_model_from_remote_kv(model, model_config, client)
             elif connector_type == ConnectorType.FS:
                 self._load_model_from_remote_fs(
                     model, client, model_config, device_config
sglang/srt/model_loader/utils.py
@@ -105,3 +105,15 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
 
 def get_architecture_class_name(model_config: ModelConfig) -> str:
     return get_model_architecture(model_config)[1]
+
+
+def post_load_weights(model: nn.Module, model_config: ModelConfig):
+    # Model weight loading consists of two stages:
+    # 1. Initial weight loading.
+    # 2. Post-processing of weights, including assigning specific member variables.
+    # For `dummy_init`, only the second stage is required.
+    if hasattr(model, "post_load_weights"):
+        if model_config.hf_config.architectures[0] == "DeepseekV3ForCausalLMNextN":
+            model.post_load_weights(is_nextn=True)
+        else:
+            model.post_load_weights()
sglang/srt/models/deepseek_v2.py
@@ -87,8 +87,8 @@ from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
-from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
+from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -114,6 +114,8 @@ from sglang.srt.utils import (
     is_flashinfer_available,
     is_hip,
     is_non_idle_and_non_empty,
+    is_npu,
+    is_sm100_supported,
     log_info_on_rank0,
     make_layers,
     use_intel_amx_backend,
@@ -121,6 +123,7 @@
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -994,7 +997,14 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.current_attention_backend = attention_backend
 
         if attention_backend == "ascend":
-            return AttnForwardMethod.MLA
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
         elif (
             attention_backend == "flashinfer"
             or attention_backend == "fa3"
@@ -1173,13 +1183,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         k[..., : self.qk_nope_head_dim] = k_nope
         k[..., self.qk_nope_head_dim :] = k_pe
 
-        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-        latent_cache[:, :, self.kv_lora_rank :] = k_pe
+        if not _is_npu:
+            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+            latent_cache[:, :, self.kv_lora_rank :] = k_pe
 
-        # Save latent cache
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
-        )
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+            )
+        else:
+            # To reduce a time-costing split operation
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
 
         return q, k, v, forward_batch
 
@@ -1292,6 +1308,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             or self.current_attention_backend == "flashinfer"
             or self.current_attention_backend == "cutlass_mla"
            or self.current_attention_backend == "trtllm_mla"
+            or self.current_attention_backend == "ascend"
         ):
             extra_args = {}
             if self._fuse_rope_for_trtllm_mla(forward_batch):
@@ -2397,18 +2414,26 @@ class DeepseekV2ForCausalLM(nn.Module):
         )
 
         num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+
         for layer_id in range(num_hidden_layers):
             if is_nextn:
                 layer = self.model.decoder
             else:
                 layer = self.model.layers[layer_id]
 
-            for module in [
-                layer.self_attn.fused_qkv_a_proj_with_mqa,
-                layer.self_attn.q_b_proj,
+            module_list = [
                 layer.self_attn.kv_b_proj,
                 layer.self_attn.o_proj,
-            ]:
+            ]
+
+            if self.config.q_lora_rank is not None:
+                module_list.append(layer.self_attn.fused_qkv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_b_proj)
+            else:
+                module_list.append(layer.self_attn.kv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_proj)
+
+            for module in module_list:
                 requant_weight_ue8m0_inplace(
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
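
A small clarification of the last hunk: whether the checkpoint has a low-rank query path decides which attention projections are re-quantized to UE8M0. The helper below is an illustration only (a hypothetical function, not sglang code) and simply restates that selection.

def modules_to_requant(self_attn, q_lora_rank):
    # kv_b_proj and o_proj are always re-quantized; the input projections depend
    # on whether the model uses a low-rank q projection (q_lora_rank set).
    mods = [self_attn.kv_b_proj, self_attn.o_proj]
    if q_lora_rank is not None:
        mods += [self_attn.fused_qkv_a_proj_with_mqa, self_attn.q_b_proj]
    else:
        mods += [self_attn.kv_a_proj_with_mqa, self_attn.q_proj]
    return mods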