sglang 0.4.8__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. sglang/srt/configs/model_config.py +1 -0
  2. sglang/srt/conversation.py +1 -0
  3. sglang/srt/custom_op.py +7 -1
  4. sglang/srt/disaggregation/base/conn.py +2 -0
  5. sglang/srt/disaggregation/decode.py +1 -1
  6. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  7. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  8. sglang/srt/disaggregation/nixl/conn.py +94 -46
  9. sglang/srt/disaggregation/prefill.py +3 -2
  10. sglang/srt/disaggregation/utils.py +12 -11
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/openai/protocol.py +47 -4
  13. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  14. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  15. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  16. sglang/srt/layers/activation.py +7 -0
  17. sglang/srt/layers/attention/flashattention_backend.py +24 -14
  18. sglang/srt/layers/layernorm.py +15 -0
  19. sglang/srt/layers/linear.py +18 -1
  20. sglang/srt/layers/logits_processor.py +12 -3
  21. sglang/srt/layers/moe/ep_moe/layer.py +79 -12
  22. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  23. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  24. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -2
  25. sglang/srt/layers/moe/fused_moe_triton/layer.py +73 -14
  26. sglang/srt/layers/moe/topk.py +26 -0
  27. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  28. sglang/srt/layers/rotary_embedding.py +103 -11
  29. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  30. sglang/srt/managers/expert_distribution.py +21 -0
  31. sglang/srt/managers/io_struct.py +10 -2
  32. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  33. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  34. sglang/srt/managers/schedule_batch.py +9 -1
  35. sglang/srt/managers/scheduler.py +42 -6
  36. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  37. sglang/srt/model_executor/model_runner.py +5 -2
  38. sglang/srt/model_loader/loader.py +45 -10
  39. sglang/srt/model_loader/weight_utils.py +89 -0
  40. sglang/srt/models/deepseek_nextn.py +7 -4
  41. sglang/srt/models/deepseek_v2.py +147 -4
  42. sglang/srt/models/gemma3n_audio.py +949 -0
  43. sglang/srt/models/gemma3n_causal.py +1009 -0
  44. sglang/srt/models/gemma3n_mm.py +511 -0
  45. sglang/srt/models/hunyuan.py +771 -0
  46. sglang/srt/server_args.py +16 -2
  47. sglang/srt/two_batch_overlap.py +4 -1
  48. sglang/srt/utils.py +71 -0
  49. sglang/version.py +1 -1
  50. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +1 -1
  51. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +54 -49
  52. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  53. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  54. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py
@@ -214,6 +214,10 @@ class MultimodalDataItem:
     audio_feature_lens: Optional[List[torch.Tensor]] = None
     audio_offsets: Optional[List[Tuple[int, int]]] = None
 
+    # gemma3n related
+    input_features: Optional[torch.Tensor] = None
+    input_features_mask: Optional[torch.Tensor] = None
+
     precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
 
     @staticmethod
@@ -277,7 +281,10 @@ class MultimodalDataItem:
         if self.precomputed_features is not None:
             self.hash = hash_feature(self.precomputed_features)
         elif self.is_audio():
-            self.hash = hash_feature(self.audio_features)
+            if self.audio_features is not None:
+                self.hash = hash_feature(self.audio_features)
+            elif self.input_features is not None:
+                self.hash = hash_feature(self.input_features)
         else:
             self.hash = hash_feature(self.pixel_values)
 
@@ -288,6 +295,7 @@ class MultimodalDataItem:
         return (self.modality == Modality.AUDIO) and (
             self.precomputed_features is not None
             or not MultimodalDataItem.is_empty_list(self.audio_features)
+            or not MultimodalDataItem.is_empty_list(self.input_features)
         )
 
     def is_image(self):
sglang/srt/managers/scheduler.py
@@ -182,6 +182,18 @@ class EmbeddingBatchResult:
     bid: int
 
 
+class KvMetrics:
+    def __init__(self):
+        self.request_active_slots = None
+        self.request_total_slots = None
+        self.kv_active_blocks = None
+        self.kv_total_blocks = None
+        self.num_requests_waiting = None
+        self.gpu_cache_usage_perc = None
+        self.gpu_prefix_cache_hit_rate = None
+        self.data_parallel_rank = None
+
+
 class IdleSleeper:
     """
     In setups which have long inactivity periods it is desirable to reduce
@@ -223,6 +235,7 @@ class Scheduler(
         self.server_args = server_args
         self.tp_rank = tp_rank
         self.pp_rank = pp_rank
+        self.dp_rank = dp_rank
         self.tp_size = server_args.tp_size
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
@@ -261,6 +274,9 @@ class Scheduler(
             self.send_to_tokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.tokenizer_ipc_name, False
             )
+            self.send_metrics_from_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.metrics_ipc_name, False
+            )
 
             if server_args.skip_tokenizer_init:
                 # Directly send to the TokenizerManager
@@ -286,6 +302,7 @@ class Scheduler(
         else:
             self.recv_from_tokenizer = None
             self.recv_from_rpc = None
+            self.send_metrics_from_scheduler = None
             self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
@@ -1239,6 +1256,22 @@ class Scheduler(
             req.logprob_start_len = len(req.origin_input_ids) - 1
         self._add_request_to_queue(req)
 
+    def _emit_kv_metrics(self):
+        kv_metrics = KvMetrics()
+        kv_metrics.request_active_slots = self.stats.num_running_reqs
+        kv_metrics.request_total_slots = self.max_running_requests
+        kv_metrics.kv_active_blocks = int(
+            self.stats.token_usage * self.max_total_num_tokens
+        )
+        kv_metrics.kv_total_blocks = self.max_total_num_tokens
+        kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
+        kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
+        kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
+        kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
+
+        if not self.send_metrics_from_scheduler.closed:
+            self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
+
     def log_prefill_stats(
         self,
         adder: PrefillAdder,
@@ -1291,6 +1324,7 @@ class Scheduler(
             self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
 
         self.metrics_collector.log_stats(self.stats)
+        self._emit_kv_metrics()
         self._publish_kv_events()
 
     def log_decode_stats(
@@ -1352,6 +1386,7 @@ class Scheduler(
         self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
         self.stats.spec_accept_length = spec_accept_length
         self.metrics_collector.log_stats(self.stats)
+        self._emit_kv_metrics()
         self._publish_kv_events()
 
     def check_memory(self):
@@ -2201,8 +2236,8 @@ class Scheduler(
         """In-place update of the weights from disk."""
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
         if success:
-            flash_cache_success = self.flush_cache()
-            assert flash_cache_success, "Cache flush failed after updating weights"
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
         return UpdateWeightFromDiskReqOutput(success, message, 0)
@@ -2219,8 +2254,8 @@ class Scheduler(
         """Update the online model parameter."""
         success, message = self.tp_worker.update_weights_from_distributed(recv_req)
         if success:
-            flash_cache_success = self.flush_cache()
-            assert flash_cache_success, "Cache flush failed after updating weights"
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
         return UpdateWeightsFromDistributedReqOutput(success, message)
@@ -2231,10 +2266,11 @@ class Scheduler(
         # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
         if success:
             if recv_req.flush_cache:
-                flash_cache_success = self.flush_cache()
-                assert flash_cache_success, "Cache flush failed after updating weights"
+                flush_cache_success = self.flush_cache()
+                assert flush_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
+        barrier(group=self.tp_cpu_group)
         return UpdateWeightsFromTensorReqOutput(success, message)
 
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
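
Note on the scheduler changes above: _emit_kv_metrics pushes pickled KvMetrics objects over a new ZMQ PUSH socket connected to port_args.metrics_ipc_name. A minimal consumer sketch follows; the IPC address is a placeholder, and in the real server the receiving end lives in the tokenizer/metrics process rather than a standalone script:

import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.PULL)
# Hypothetical IPC address; the real one comes from port_args.metrics_ipc_name.
sock.bind("ipc:///tmp/sglang_metrics_example")

while True:
    kv_metrics = sock.recv_pyobj()  # a KvMetrics instance sent with send_pyobj
    print(
        f"running={kv_metrics.request_active_slots}/{kv_metrics.request_total_slots} "
        f"waiting={kv_metrics.num_requests_waiting} "
        f"kv_usage={kv_metrics.gpu_cache_usage_perc:.2%}"
    )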
sglang/srt/model_executor/cuda_graph_runner.py
@@ -421,7 +421,7 @@ class CudaGraphRunner:
                 empty_cache=False,
             )
             capture_range.set_description(
-                f"Capturing batches ({avail_mem=:.2f} GB)"
+                f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
             )
 
             with patch_model(
sglang/srt/model_executor/model_runner.py
@@ -239,7 +239,7 @@ class ModelRunner:
             "SGLANG_LOG_EXPERT_LOCATION_METADATA"
         ):
             logger.info(
-                f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}"
+                f"Initial expert_location_metadata: {get_global_expert_location_metadata()}"
             )
 
         set_global_expert_distribution_recorder(
@@ -547,6 +547,7 @@ class ModelRunner:
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
+            model_loader_extra_config=self.server_args.model_loader_extra_config,
         )
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()
@@ -865,7 +866,9 @@ class ModelRunner:
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
         elif self.server_args.kv_cache_dtype == "fp8_e4m3":
-            if is_cuda():
+            if _is_hip:  # Using natively supported format
+                self.kv_cache_dtype = torch.float8_e4m3fnuz
+            else:
                 self.kv_cache_dtype = torch.float8_e4m3fn
         else:
             raise ValueError(
sglang/srt/model_loader/loader.py
@@ -2,6 +2,7 @@
 
 # ruff: noqa: SIM117
 import collections
+import concurrent
 import dataclasses
 import fnmatch
 import glob
@@ -11,14 +12,17 @@ import math
 import os
 import time
 from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
 
 import huggingface_hub
 import numpy as np
+import safetensors.torch
 import torch
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
+from tqdm.auto import tqdm
 from transformers import AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
@@ -41,6 +45,7 @@ from sglang.srt.model_loader.utils import (
     set_default_torch_dtype,
 )
 from sglang.srt.model_loader.weight_utils import (
+    _BAR_FORMAT,
     download_safetensors_index_file_from_hf,
     download_weights_from_hf,
     filter_duplicate_safetensors_files,
@@ -49,6 +54,8 @@ from sglang.srt.model_loader.weight_utils import (
     get_quant_config,
     gguf_quant_weights_iterator,
     initialize_dummy_weights,
+    multi_thread_pt_weights_iterator,
+    multi_thread_safetensors_weights_iterator,
     np_cache_weights_iterator,
     pt_weights_iterator,
     safetensors_weights_iterator,
@@ -181,6 +188,9 @@ class BaseModelLoader(ABC):
 class DefaultModelLoader(BaseModelLoader):
     """Model loader that can load different file types from disk."""
 
+    # Default number of threads when multithreaded weight loading is enabled
+    DEFAULT_NUM_THREADS = 8
+
     @dataclasses.dataclass
     class Source:
         """A source for weights."""
@@ -208,10 +218,15 @@ class DefaultModelLoader(BaseModelLoader):
 
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
-        if load_config.model_loader_extra_config:
+        extra_config = load_config.model_loader_extra_config
+        allowed_keys = {"enable_multithread_load", "num_threads"}
+        unexpected_keys = set(extra_config.keys()) - allowed_keys
+
+        if unexpected_keys:
             raise ValueError(
-                f"Model loader extra config is not supported for "
-                f"load format {load_config.load_format}"
+                f"Unexpected extra config keys for load format "
+                f"{load_config.load_format}: "
+                f"{unexpected_keys}"
             )
 
     def _maybe_download_from_modelscope(
@@ -324,6 +339,7 @@ class DefaultModelLoader(BaseModelLoader):
         self, source: "Source"
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
+        extra_config = self.load_config.model_loader_extra_config
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
             source.model_or_path, source.revision, source.fall_back_to_pt
         )
@@ -342,11 +358,30 @@ class DefaultModelLoader(BaseModelLoader):
             weight_loader_disable_mmap = global_server_args_dict.get(
                 "weight_loader_disable_mmap"
             )
-            weights_iterator = safetensors_weights_iterator(
-                hf_weights_files, disable_mmap=weight_loader_disable_mmap
-            )
+
+            if extra_config.get("enable_multithread_load"):
+                weights_iterator = multi_thread_safetensors_weights_iterator(
+                    hf_weights_files,
+                    max_workers=extra_config.get(
+                        "num_threads", self.DEFAULT_NUM_THREADS
+                    ),
+                    disable_mmap=weight_loader_disable_mmap,
+                )
+            else:
+                weights_iterator = safetensors_weights_iterator(
+                    hf_weights_files, disable_mmap=weight_loader_disable_mmap
+                )
+
         else:
-            weights_iterator = pt_weights_iterator(hf_weights_files)
+            if extra_config.get("enable_multithread_load"):
+                weights_iterator = multi_thread_pt_weights_iterator(
+                    hf_weights_files,
+                    max_workers=extra_config.get(
+                        "num_threads", self.DEFAULT_NUM_THREADS
+                    ),
+                )
+            else:
+                weights_iterator = pt_weights_iterator(hf_weights_files)
 
         # Apply the prefix.
         return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
@@ -385,9 +420,9 @@ class DefaultModelLoader(BaseModelLoader):
             self.load_config,
         )
 
-            self.load_weights_and_postprocess(
-                model, self._get_all_weights(model_config, model), target_device
-            )
+        self.load_weights_and_postprocess(
+            model, self._get_all_weights(model_config, model), target_device
+        )
 
         return model.eval()
 
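Note on the loader changes above: DefaultModelLoader now accepts two keys in model_loader_extra_config, enable_multithread_load and num_threads, and ModelRunner forwards that config from the server args (see the model_runner.py hunk). A minimal opt-in sketch, assuming the offline sglang.Engine entrypoint forwards model_loader_extra_config the same way the server does; the model path is a placeholder:

import sglang as sgl

# The JSON keys are the ones DefaultModelLoader checks; everything else here is illustrative.
engine = sgl.Engine(
    model_path="/models/my-model",
    model_loader_extra_config='{"enable_multithread_load": true, "num_threads": 8}',
)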
sglang/srt/model_loader/weight_utils.py
@@ -1,12 +1,14 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py
 
 """Utilities for downloading and initializing model weights."""
+import concurrent.futures
 import fnmatch
 import glob
 import hashlib
 import json
 import logging
 import os
+import queue
 import tempfile
 from collections import defaultdict
 from typing import (
@@ -453,6 +455,60 @@ def safetensors_weights_iterator(
         yield name, param
 
 
+def multi_thread_safetensors_weights_iterator(
+    hf_weights_files: List[str],
+    is_all_weights_sharded: bool = False,
+    decryption_key: Optional[str] = None,
+    max_workers: int = 4,
+    disable_mmap: bool = False,
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model safetensors files using multiple threads.
+
+    If is_all_weights_sharded is True, it uses a more optimized read path by
+    reading an entire file at once instead of reading each tensor one by one.
+    """
+    if decryption_key:
+        logger.warning(
+            "Multi-threaded loading does not work for encrypted safetensors weights."
+        )
+        yield from safetensors_encrypted_weights_iterator(
+            hf_weights_files, is_all_weights_sharded, decryption_key
+        )
+        return
+
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    def _load_file(st_file: str):
+        if disable_mmap:
+            with open(st_file, "rb") as f:
+                result = safetensors.torch.load(f.read())
+        else:
+            result = safetensors.torch.load_file(st_file, device="cpu")
+
+        return result
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = [executor.submit(_load_file, st_file) for st_file in hf_weights_files]
+
+        if enable_tqdm:
+            futures_iter = tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(hf_weights_files),
+                desc="Multi-thread loading shards",
+                disable=not enable_tqdm,
+                bar_format=_BAR_FORMAT,
+            )
+        else:
+            futures_iter = concurrent.futures.as_completed(futures)
+
+        for future in futures_iter:
+            state_dict = future.result()
+            for name, param in state_dict.items():
+                yield name, param
+
+
 def pt_weights_iterator(
     hf_weights_files: List[str],
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
@@ -471,6 +527,39 @@ def pt_weights_iterator(
     del state
 
 
+def multi_thread_pt_weights_iterator(
+    hf_weights_files: List[str],
+    max_workers: int = 4,
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model bin/pt files using multiple threads."""
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    def _load_file(bin_file: str):
+        return torch.load(bin_file, map_location="cpu", weights_only=True)
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = [
+            executor.submit(_load_file, bin_file) for bin_file in hf_weights_files
+        ]
+
+        if enable_tqdm:
+            futures_iter = tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(hf_weights_files),
+                desc="Multi-thread loading pt checkpoint shards",
+                disable=not enable_tqdm,
+                bar_format=_BAR_FORMAT,
+            )
+        else:
+            futures_iter = concurrent.futures.as_completed(futures)
+
+        for future in futures_iter:
+            state = future.result()
+            yield from state.items()
+
+
 def get_gguf_extra_tensor_names(
     gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
 ) -> List[str]:
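
For illustration only (not part of the diff): the multi-threaded iterators yield the same (name, tensor) pairs as their single-threaded counterparts, just in shard-completion order rather than file order. A sketch with placeholder shard paths:

from sglang.srt.model_loader.weight_utils import (
    multi_thread_safetensors_weights_iterator,
)

# Placeholder paths; point these at real .safetensors shards.
shards = [
    "/models/demo/model-00001-of-00002.safetensors",
    "/models/demo/model-00002-of-00002.safetensors",
]
for name, tensor in multi_thread_safetensors_weights_iterator(shards, max_workers=8):
    print(name, tuple(tensor.shape))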
sglang/srt/models/deepseek_nextn.py
@@ -28,6 +28,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
@@ -82,7 +85,6 @@ class DeepseekModelNextN(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-
         zero_allocator = BumpAllocator(
             buffer_size=2,
             dtype=torch.float32,
@@ -108,9 +110,10 @@ class DeepseekModelNextN(nn.Module):
         )
 
         residual = None
-        hidden_states, residual = self.decoder(
-            positions, hidden_states, forward_batch, residual, zero_allocator
-        )
+        with get_global_expert_distribution_recorder().disable_this_region():
+            hidden_states, residual = self.decoder(
+                positions, hidden_states, forward_batch, residual, zero_allocator
+            )
 
         if not forward_batch.forward_mode.is_idle():
             if residual is not None:
sglang/srt/models/deepseek_v2.py
@@ -93,6 +93,7 @@ from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
     LazyValue,
+    PackWeightMethod,
     add_prefix,
     bind_or_assign,
     cpu_has_amx_support,
@@ -124,8 +125,6 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )
 
-    if _use_aiter:
-        from aiter.rotary_embedding import get_rope
 
 logger = logging.getLogger(__name__)
 
@@ -144,6 +143,9 @@ class AttnForwardMethod(IntEnum):
     # Use MLA but with fused RoPE
     MLA_FUSED_ROPE = auto()
 
+    # Use MLA with fused RoPE kernel for CPU
+    MLA_FUSED_ROPE_CPU = auto()
+
 
 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -212,8 +214,18 @@ class MoEGate(nn.Module):
             )
         else:
             self.e_score_correction_bias = None
+        if _is_cpu and _is_cpu_amx_available:
+            self.quant_method = PackWeightMethod(weight_names=["weight"])
 
     def forward(self, hidden_states):
+        if getattr(self, "use_intel_amx_backend", False):
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                hidden_states,
+                self.weight,
+                None,  # bias
+                True,  # is_vnni
+            )
+
         logits = F.linear(hidden_states, self.weight, None)
         return logits
 
@@ -388,7 +400,8 @@ class DeepseekV2MoE(nn.Module):
         final_hidden_states = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )
-        if not _is_cuda:
+        if not _is_cuda and not _use_aiter:
+            # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
@@ -777,6 +790,37 @@ class DeepseekV2AttentionMLA(nn.Module):
             "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
         )
 
+        # If we have self.fused_qkv_a_proj_with_mqa and we're running on CPU, we choose
+        # the torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel,
+        # which requires self.w_kc and self.w_vc to be packed.
+        # Otherwise we use torch.bmm, and the weights should not be packed in that case.
+        if (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and _is_cpu
+            and _is_cpu_amx_available
+        ):
+            self.quant_method = PackWeightMethod(
+                weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
+            )
+
+        self.qkv_proj_with_rope_is_int8 = (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
+        )
+        self.qkv_proj_with_rope_is_fp8 = (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
+        )
+
+        self.weight_block_size = None
+        if self.qkv_proj_with_rope_is_fp8:
+            assert (
+                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                == self.q_b_proj.quant_method.quant_config.weight_block_size
+            )
+            self.weight_block_size = (
+                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+            )
+
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
@@ -790,7 +834,12 @@ class DeepseekV2AttentionMLA(nn.Module):
             else:
                 return AttnForwardMethod.MLA
         else:
-            return AttnForwardMethod.MLA
+            if hasattr(self, "fused_qkv_a_proj_with_mqa") and getattr(
+                self, "use_intel_amx_backend", False
+            ):
+                return AttnForwardMethod.MLA_FUSED_ROPE_CPU
+            else:
+                return AttnForwardMethod.MLA
 
         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
@@ -904,6 +953,10 @@ class DeepseekV2AttentionMLA(nn.Module):
             inner_state = self.forward_absorb_fused_mla_rope_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
+            inner_state = self.forward_absorb_fused_mla_rope_cpu_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
         else:
             raise NotImplementedError
         return None, attn_forward_method, forward_batch, inner_state
@@ -923,6 +976,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_absorb_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
             return self.forward_absorb_fused_mla_rope_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
+            return self.forward_absorb_fused_mla_rope_cpu_core(*inner_state)
         else:
             raise NotImplementedError
 
@@ -1240,6 +1295,57 @@ class DeepseekV2AttentionMLA(nn.Module):
             zero_allocator,
         )
 
+    def forward_absorb_fused_mla_rope_cpu_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        assert self.q_lora_rank is not None and getattr(
+            self, "use_intel_amx_backend", False
+        ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"
+
+        q_input, k_input, v_input = (
+            torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight(
+                hidden_states,
+                self.fused_qkv_a_proj_with_mqa.weight,
+                self.q_b_proj.weight,
+                self.w_kc,
+                self.q_a_layernorm.weight,
+                self.kv_a_layernorm.weight,
+                positions,
+                self.rotary_emb.cos_sin_cache,
+                self.kv_a_layernorm.variance_epsilon,
+                self.qkv_proj_with_rope_is_int8,
+                self.qkv_proj_with_rope_is_fp8,
+                (
+                    self.fused_qkv_a_proj_with_mqa.weight_scale
+                    if self.qkv_proj_with_rope_is_int8
+                    else (
+                        self.fused_qkv_a_proj_with_mqa.weight_scale_inv
+                        if self.qkv_proj_with_rope_is_fp8
+                        else None
+                    )
+                ),
+                (
+                    self.q_b_proj.weight_scale
+                    if self.qkv_proj_with_rope_is_int8
+                    else (
+                        self.q_b_proj.weight_scale_inv
+                        if self.qkv_proj_with_rope_is_fp8
+                        else None
+                    )
+                ),
+                True,  # is_vnni
+                self.weight_block_size,
+                self.q_lora_rank,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+            )
+        )
+        return (q_input, k_input, v_input, forward_batch, zero_allocator)
+
     def forward_absorb_fused_mla_rope_core(
         self,
         q_input,
@@ -1313,6 +1419,43 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         return output
 
+    def forward_absorb_fused_mla_rope_cpu_core(
+        self, q_input, k_input, v_input, forward_batch, zero_allocator
+    ):
+        assert self.q_lora_rank is not None and getattr(
+            self, "use_intel_amx_backend", False
+        ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"
+
+        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+        # [Note] Align shapes of bmm inputs.
+        # Shapes of inputs:
+        #   q_nope: [M, B, K]
+        #   original self.w_kc: [B, K, N]
+        #   current self.w_kc (converted in PackWeightMethod): [B, N, K]
+        # Shapes of inputs to sgl_kernel.cpu.bmm:
+        #   out: [B, M, N]
+        #   mat1: [B, M, K]
+        #   mat2: [B, N, K]
+        B = self.w_vc.size(0)
+        N = self.w_vc.size(1)
+        M = attn_output.size(0)
+        output = torch.empty([M, int(B * N)], dtype=attn_output.dtype)
+        attn_bmm_output = output.view([M, B, N]).transpose_(0, 1)
+        torch.ops.sgl_kernel.bmm_cpu(
+            attn_bmm_output,
+            attn_output.transpose(0, 1),
+            self.w_vc,
+            True,  # is_vnni
+            None,  # scale
+        )
+        attn_output = output
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
     def _chunked_prefix_attn_mha(
         self,
         q: torch.Tensor,