sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py

@@ -26,10 +26,10 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 
-from sglang.srt import debug_utils
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
     get_tp_group,
@@ -40,6 +40,19 @@ from sglang.srt.distributed import (
     set_mscclpp_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
+from sglang.srt.eplb.eplb_manager import EPLBManager
+from sglang.srt.eplb.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+    set_global_expert_distribution_recorder,
+)
+from sglang.srt.eplb.expert_location import (
+    ExpertLocationMetadata,
+    compute_initial_expert_location_metadata,
+    get_global_expert_location_metadata,
+    set_global_expert_location_metadata,
+)
+from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
 from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
@@ -55,35 +68,27 @@ from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.lora.lora_manager import LoRAManager
-from sglang.srt.managers.eplb_manager import EPLBManager
-from sglang.srt.managers.expert_distribution import (
-    ExpertDistributionRecorder,
-    get_global_expert_distribution_recorder,
-    set_global_expert_distribution_recorder,
-)
-from sglang.srt.managers.expert_location import (
-    ExpertLocationMetadata,
-    compute_initial_expert_location_metadata,
-    get_global_expert_location_metadata,
-    set_global_expert_location_metadata,
-)
 from sglang.srt.managers.schedule_batch import (
     GLOBAL_SERVER_ARGS_KEYS,
     global_server_args_dict,
 )
 from sglang.srt.mem_cache.allocator import (
+    AscendPagedTokenToKVPoolAllocator,
     BaseTokenToKVPoolAllocator,
     PagedTokenToKVPoolAllocator,
+    SWATokenToKVPoolAllocator,
    TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.memory_pool import (
+    AscendMLAPagedTokenToKVPool,
+    AscendTokenToKVPool,
     DoubleSparseTokenToKVPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
+    SWAKVPool,
 )
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
@@ -101,6 +106,7 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
     get_bool_env_var,
+    get_cpu_ids_by_node,
     init_custom_process_group,
     is_cuda,
     is_fa3_default_architecture,
@@ -108,6 +114,7 @@ from sglang.srt.utils import (
     is_hip,
     is_hopper_with_cuda_12_3,
     is_no_spec_infer_or_topk_one,
+    is_npu,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cpu_offload_max_bytes,
@@ -115,6 +122,7 @@ from sglang.srt.utils import (
 )
 
 _is_hip = is_hip()
+_is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 
 # Use a small KV cache pool size for tests in CI
@@ -158,7 +166,6 @@ class ModelRunner:
         token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
     ):
         # Parse args
-        self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
         self.device = server_args.device
         self.gpu_id = gpu_id
@@ -171,6 +178,7 @@ class ModelRunner:
         self.dp_size = server_args.dp_size
         self.pp_rank = pp_rank
         self.pp_size = pp_size
+        self.model_config = model_config
         self.dist_port = nccl_port
         self.server_args = server_args
         self.is_draft_worker = is_draft_worker
@@ -185,6 +193,7 @@ class ModelRunner:
         self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.is_hybrid = model_config.is_hybrid
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
 
@@ -209,6 +218,10 @@ class ModelRunner:
         # CPU offload
         set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
 
+        # Init OpenMP threads binding for CPU
+        if self.device == "cpu":
+            self.init_threads_binding()
+
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
 
@@ -223,6 +236,7 @@ class ModelRunner:
         self.support_pp = (
             "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
         )
+        self._model_update_group = {}
 
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
@@ -300,11 +314,31 @@ class ModelRunner:
             self.init_cuda_graphs()
         else:
             self.cuda_graph_runner = None
+            self.cuda_graph_mem_usage = 0
         self.init_attention_backend()
 
         # auxiliary hidden capture mode. TODO: expose this to server args?
         if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
-            self.model.set_eagle3_layers_to_capture()
+            # load draft config
+            draft_model_config = ModelConfig.from_server_args(
+                server_args,
+                model_path=(server_args.speculative_draft_model_path),
+                is_draft_model=True,
+            )
+
+            try:
+                # get the aux layer from draft model config
+                eagle_config = getattr(
+                    draft_model_config.hf_config, "eagle_config", None
+                )
+                eagle_aux_hidden_state_layer_ids = eagle_config[
+                    "eagle_aux_hidden_state_layer_ids"
+                ]
+            except:
+                # if there is no aux layer, set to None
+                eagle_aux_hidden_state_layer_ids = None
+
+            self.model.set_eagle3_layers_to_capture(eagle_aux_hidden_state_layer_ids)
 
     def model_specific_adjustment(self):
         server_args = self.server_args
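The new EAGLE3 branch above pulls `eagle_aux_hidden_state_layer_ids` out of the draft model's `eagle_config` and falls back to `None` through a bare `except`. A minimal sketch of the same lookup with explicit checks, using only the names that appear in the hunk (the standalone helper itself is illustrative, not part of sglang):

    def get_eagle_aux_layer_ids(hf_config):
        # Return eagle_aux_hidden_state_layer_ids if the draft config carries one, else None.
        eagle_config = getattr(hf_config, "eagle_config", None)
        if isinstance(eagle_config, dict):
            return eagle_config.get("eagle_aux_hidden_state_layer_ids")
        return None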
@@ -342,6 +376,8 @@ class ModelRunner:
                     server_args.attention_backend = "fa3"
                 elif _is_hip:
                     server_args.attention_backend = "aiter"
+                elif _is_npu:
+                    server_args.attention_backend = "ascend"
                 else:
                     server_args.attention_backend = (
                         "flashinfer" if is_flashinfer_available() else "triton"
@@ -361,6 +397,8 @@ class ModelRunner:
                         server_args.attention_backend = "aiter"
                     else:
                         server_args.attention_backend = "triton"
+                elif _is_npu:
+                    server_args.attention_backend = "ascend"
                 else:
                     server_args.attention_backend = "triton"
             logger.info(
@@ -375,6 +413,7 @@ class ModelRunner:
                 "triton",
                 "flashmla",
                 "cutlass_mla",
+                "ascend",
             ]:
                 logger.info(
                     f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
@@ -412,11 +451,6 @@ class ModelRunner:
             self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
 
         if self.is_multimodal:
-            self.mem_fraction_static *= 0.90
-            logger.info(
-                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                f"because this is a multimodal model."
-            )
             if not self.is_multimodal_chunked_prefill_supported:
                 server_args.chunked_prefill_size = -1
                 logger.info(
@@ -437,6 +471,10 @@ class ModelRunner:
             if self.model_config.context_len > 8192:
                 self.mem_fraction_static *= 0.85
 
+        if self.is_hybrid and not server_args.disable_radix_cache:
+            logger.info("Automatically disable radix cache for hybrid cache.")
+            server_args.disable_radix_cache = True
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
 
@@ -471,6 +509,19 @@ class ModelRunner:
         set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
 
         if not self.is_draft_worker:
+            if self.device == "cpu":
+                if _is_cpu_amx_available:
+                    # Bind OpenMP threads to CPU cores
+                    torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)
+
+                    # Set local size to hint SGLang to use shared memory based AllReduce
+                    os.environ["LOCAL_SIZE"] = str(self.tp_size)
+                    torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
+                else:
+                    logger.warning(
+                        "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
+                    )
+
             # Only initialize the distributed environment on the target model worker.
             init_distributed_environment(
                 backend=backend,
@@ -549,6 +600,10 @@ class ModelRunner:
             download_dir=self.server_args.download_dir,
             model_loader_extra_config=self.server_args.model_loader_extra_config,
         )
+        if self.device == "cpu":
+            self.model_config = adjust_config_with_unaligned_cpu_tp(
+                self.model_config, self.load_config, self.tp_size
+            )
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()
 
@@ -598,12 +653,13 @@ class ModelRunner:
         self.dtype = self.model_config.dtype
 
         after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
+        self.weight_load_mem_usage = before_avail_memory - after_avail_memory
         logger.info(
             f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
             f"avail mem={after_avail_memory:.2f} GB, "
-            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
+            f"mem usage={self.weight_load_mem_usage:.2f} GB."
         )
 
         # Handle the case where some ranks do not finish loading.
@@ -718,7 +774,7 @@ class ModelRunner:
         )
 
         try:
-            self._model_update_group = init_custom_process_group(
+            self._model_update_group[group_name] = init_custom_process_group(
                 backend=backend,
                 init_method=f"tcp://{master_address}:{master_port}",
                 world_size=world_size,
@@ -731,7 +787,7 @@ class ModelRunner:
             logger.error(message)
             return False, message
 
-    def update_weights_from_distributed(self, name, dtype, shape):
+    def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
         """
         Update specific parameter in the model weights online
         through `_model_update_group` process group.
@@ -741,19 +797,34 @@ class ModelRunner:
             dtype: the data type of the parameter to be updated.
             shape: the shape of the parameter to be updated.
         """
-        target_dtype = (
-            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
-        )
 
-        assert (
-            self._model_update_group is not None
-        ), "model update group must be initialized"
+        assert group_name in self._model_update_group, (
+            f"Group {group_name} not in {list(self._model_update_group.keys())}. "
+            "Please call `init_weights_update_group` first."
+        )
 
         try:
-            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
-            torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
-            self.model.load_weights([(name, weights)])
-            return True, f"Succeeded to update parameter {name} online."
+            weights = []
+            handles = []
+            for name, dtype, shape in zip(names, dtypes, shapes):
+                target_dtype = (
+                    dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
+                )
+                weight = torch.empty(shape, dtype=target_dtype, device=self.device)
+                handles.append(
+                    torch.distributed.broadcast(
+                        weight,
+                        src=0,
+                        group=self._model_update_group[group_name],
+                        async_op=True,
+                    )
+                )
+                weights.append((name, weight))
+            for handle in handles:
+                handle.wait()
+
+            self.model.load_weights(weights)
+            return True, f"Succeeded to update parameter online."
 
         except Exception as e:
             error_msg = (
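`update_weights_from_distributed` now takes lists of names, dtypes, and shapes plus a `group_name`, issues every broadcast asynchronously, waits on all handles, and loads the weights with a single `load_weights` call. A sender-side sketch under the assumption that rank 0 (for example a trainer process) holds the tensors and has joined the same custom process group; `group` and `named_tensors` are placeholders, not part of the diff:

    import torch.distributed as dist

    def broadcast_named_tensors(named_tensors, group):
        # Mirror the batched receiver: kick off all broadcasts first, then wait.
        handles = [
            dist.broadcast(tensor, src=0, group=group, async_op=True)
            for _, tensor in named_tensors
        ]
        for handle in handles:
            handle.wait()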
@@ -812,8 +883,47 @@ class ModelRunner:
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
         )
-        self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
-        logger.info("LoRA manager ready.")
+        result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
+        if result.success:
+            logger.info(
+                f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}"
+            )
+        else:
+            raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}")
+
+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        """Load a new lora adapter from disk or huggingface."""
+
+        logger.info(
+            f"LoRA adapter loading starts: name={lora_name}, path={lora_path}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+
+        result = self.lora_manager.load_lora_adapter(lora_name, lora_path)
+
+        logger.info(
+            f"LoRA adapter loading completes: name={lora_name}, path={lora_path}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+
+        return result
+
+    def unload_lora_adapter(self, lora_name: str):
+        """Unload a lora adapter that was previously loaded during initialization or dynamic loading."""
+
+        logger.info(
+            f"LoRA adapter unloading starts: name={lora_name}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+
+        result = self.lora_manager.unload_lora_adapter(lora_name)
+
+        logger.info(
+            f"LoRA adapter unloading completes: name={lora_name}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+
+        return result
 
     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
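The two new methods make LoRA adapters attachable and detachable at runtime instead of only at startup. A hypothetical call site, assuming `runner` is an initialized `ModelRunner` and the adapter path exists (neither is defined in this diff):

    # Hypothetical names; only load_lora_adapter/unload_lora_adapter come from the diff.
    result = runner.load_lora_adapter("demo-adapter", "/path/to/adapter")
    # ... serve requests that reference "demo-adapter" ...
    runner.unload_lora_adapter("demo-adapter")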
@@ -852,6 +962,40 @@ class ModelRunner:
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token
 
+    def set_num_token_hybrid(self):
+        if (
+            "Llama4ForConditionalGeneration"
+            in self.model_config.hf_config.architectures
+        ):
+            temp_ratio = (
+                (1 - self.is_hybrid)
+                + self.is_hybrid
+                * self.attention_chunk_size
+                / self.model_config.context_len
+            )
+            self.swa_max_total_num_tokens = (
+                4 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
+            )
+            self.full_max_total_num_tokens = (
+                4 * self.max_total_num_tokens
+                - 12 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
+            )
+            self.swa_max_total_num_tokens = int(
+                self.swa_max_total_num_tokens
+                // self.server_args.page_size
+                * self.server_args.page_size
+            )
+            self.full_max_total_num_tokens = int(
+                self.full_max_total_num_tokens
+                // self.server_args.page_size
+                * self.server_args.page_size
+            )
+            self.max_total_num_tokens = self.full_max_total_num_tokens
+        else:
+            raise ValueError(
+                f"Unsupported model for hybrid cache: {self.model_config.hf_config.architectures}."
+            )
+
     def init_memory_pool(
         self,
         total_gpu_memory: int,
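For the Llama-4 hybrid case `self.is_hybrid` is truthy, so `temp_ratio` reduces to `attention_chunk_size / context_len`, and the two budgets satisfy `full + 3 * swa ≈ 4 * max_total_num_tokens` before page-size alignment. A small worked sketch with made-up numbers (none of the values below come from the diff):

    # N = max_total_num_tokens, r = attention_chunk_size / context_len
    N, chunk, ctx, page = 100_000, 8192, 131_072, 16
    r = chunk / ctx                               # 0.0625
    swa = 4 * N * r // (3 * r + 1)                # sliding-window token budget
    full = 4 * N - 12 * N * r // (3 * r + 1)      # full-attention token budget
    swa = int(swa // page * page)                 # align both budgets to the page size
    full = int(full // page * page)
    print(swa, full, full + 3 * swa)              # full + 3*swa stays close to 4*N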
@@ -929,6 +1073,10 @@ class ModelRunner:
             * self.server_args.page_size
         )
 
+        # create token size for hybrid cache
+        if self.is_hybrid:
+            self.set_num_token_hybrid()
+
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
                 "Not enough memory. Please try to increase --mem-fraction-static."
@@ -959,8 +1107,19 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker
 
-        if self.use_mla_backend:
-            self.token_to_kv_pool = MLATokenToKVPool(
+        if self.server_args.attention_backend == "ascend" and not self.use_mla_backend:
+            self.token_to_kv_pool = AscendTokenToKVPool(
+                self.max_total_num_tokens,
+                page_size=self.page_size,
+                dtype=self.kv_cache_dtype,
+                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
+                head_dim=self.model_config.head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+            )
+        elif self.server_args.attention_backend == "ascend" and self.use_mla_backend:
+            self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
                 self.max_total_num_tokens,
                 page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
@@ -976,22 +1135,25 @@ class ModelRunner:
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
             )
-        elif self.server_args.enable_double_sparsity:
-            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
+        elif self.use_mla_backend:
+            self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
                 page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
-                head_dim=self.model_config.head_dim,
-                layer_num=self.num_effective_layers,
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                layer_num=(
+                    self.model_config.num_hidden_layers
+                    if not self.is_draft_worker
+                    else self.model_config.hf_config.num_nextn_predict_layers
+                ),  # PP is not compatible with mla backend
                 device=self.device,
-                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
             )
-        else:
-            self.token_to_kv_pool = MHATokenToKVPool(
+        elif self.server_args.enable_double_sparsity:
+            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                 self.max_total_num_tokens,
                 page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
@@ -999,27 +1161,76 @@ class ModelRunner:
                 head_dim=self.model_config.head_dim,
                 layer_num=self.num_effective_layers,
                 device=self.device,
+                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
             )
-
-        if self.token_to_kv_pool_allocator is None:
-            if self.page_size == 1:
-                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
-                    self.max_total_num_tokens,
+        else:
+            if self.is_hybrid:
+                self.token_to_kv_pool = SWAKVPool(
+                    size=self.full_max_total_num_tokens,
+                    size_swa=self.swa_max_total_num_tokens,
                     dtype=self.kv_cache_dtype,
+                    head_num=self.model_config.get_num_kv_heads(
+                        get_attention_tp_size()
+                    ),
+                    head_dim=self.model_config.head_dim,
+                    swa_attention_layer_ids=self.model_config.swa_attention_layer_ids,
+                    full_attention_layer_ids=self.model_config.full_attention_layer_ids,
+                    enable_kvcache_transpose=False,
                     device=self.device,
-                    kvcache=self.token_to_kv_pool,
                 )
             else:
-                self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
+                self.token_to_kv_pool = MHATokenToKVPool(
                     self.max_total_num_tokens,
                     page_size=self.page_size,
                     dtype=self.kv_cache_dtype,
+                    head_num=self.model_config.get_num_kv_heads(
+                        get_attention_tp_size()
+                    ),
+                    head_dim=self.model_config.head_dim,
+                    layer_num=self.num_effective_layers,
                     device=self.device,
-                    kvcache=self.token_to_kv_pool,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    start_layer=self.start_layer,
+                    end_layer=self.end_layer,
                 )
+
+        if self.token_to_kv_pool_allocator is None:
+            if self.page_size == 1:
+                if self.is_hybrid:
+                    self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
+                        self.full_max_total_num_tokens,
+                        self.swa_max_total_num_tokens,
+                        dtype=self.kv_cache_dtype,
+                        device=self.device,
+                        kvcache=self.token_to_kv_pool,
+                    )
+                else:
+                    self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                        self.max_total_num_tokens,
+                        dtype=self.kv_cache_dtype,
+                        device=self.device,
+                        kvcache=self.token_to_kv_pool,
+                    )
+            else:
+                if _is_npu:
+                    self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
+                        self.max_total_num_tokens,
+                        page_size=self.page_size,
+                        dtype=self.kv_cache_dtype,
+                        device=self.device,
+                        kvcache=self.token_to_kv_pool,
+                    )
+                else:
+                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
+                        self.max_total_num_tokens,
+                        page_size=self.page_size,
+                        dtype=self.kv_cache_dtype,
+                        device=self.device,
+                        kvcache=self.token_to_kv_pool,
+                    )
         else:
             assert self.is_draft_worker
 
@@ -1039,7 +1250,7 @@ class ModelRunner:
 
     def init_attention_backend(self):
         """Init attention kernel backend."""
-        if self.server_args.enable_two_batch_overlap:
+        if self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
             self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
         else:
             self.attn_backend = self._get_attention_backend()
@@ -1066,6 +1277,10 @@ class ModelRunner:
             from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
 
             return AiterAttnBackend(self)
+        elif self.server_args.attention_backend == "ascend":
+            from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
+
+            return AscendAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert not self.model_config.is_encoder_decoder, (
                 "Cross attention is not supported in the triton attention backend. "
@@ -1141,6 +1356,7 @@ class ModelRunner:
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None
+        self.cuda_graph_mem_usage = 0
 
         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
@@ -1156,11 +1372,36 @@ class ModelRunner:
         )
         self.cuda_graph_runner = CudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
+        self.cuda_graph_mem_usage = before_mem - after_mem
         logger.info(
             f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
-            f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
+            f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
         )
 
+    def init_threads_binding(self):
+        omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
+        if omp_cpuids == "all":
+            cpu_ids_by_node = get_cpu_ids_by_node()
+            n_numa_node = len(cpu_ids_by_node)
+
+            assert self.tp_size <= n_numa_node, (
+                f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, "
+                f"tp_size {self.tp_size} should be smaller than or equal to number of numa node on the machine {n_numa_node}. "
+                f"If you need tp_size to be larger than number of numa node, please set the CPU cores for each tp rank via SGLANG_CPU_OMP_THREADS_BIND explicitly. "
+                f"For example, on a machine with 2 numa nodes, where core 0-31 are on numa node 0 and core 32-63 are on numa node 1, "
+                f"it is suggested to use -tp 2 and bind tp rank 0 to core 0-31 and tp rank 1 to core 32-63. "
+                f"This is the default behavior if SGLANG_CPU_OMP_THREADS_BIND is not set and it is the same as setting SGLANG_CPU_OMP_THREADS_BIND=0-31|32-63. "
+                f"If you do need tp_size to be larger than the number of numa nodes, you could set SGLANG_CPU_OMP_THREADS_BIND explicitly for example SGLANG_CPU_OMP_THREADS_BIND=0-15|16-31|32-47|48-63 and run with -tp 4. "
+                f"If you don't want each tp rank to use all the cores on one numa node, you could set for example SGLANG_CPU_OMP_THREADS_BIND=0-15|32-47 and run with -tp 2."
+            )
+            if self.tp_size < n_numa_node:
+                logger.warning(
+                    f"Detected the current machine has {n_numa_node} numa nodes available, but tp_size is set to {self.tp_size}, so only {self.tp_size} numa nodes are used."
+                )
+            self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
+        else:
+            self.local_omp_cpuid = omp_cpuids.split("|")[self.tp_rank]
+
     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
         from sglang.srt.model_parallel import tensor_parallel
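`init_threads_binding` either assigns each TP rank the cores of one NUMA node or, when `SGLANG_CPU_OMP_THREADS_BIND` is set explicitly, splits its value on `|` with one segment per rank. A minimal sketch of the explicit form; the helper below is illustrative and not part of sglang:

    import os

    def local_cpu_segment(tp_rank: int) -> str:
        # Pick this rank's core range from e.g. SGLANG_CPU_OMP_THREADS_BIND="0-15|16-31".
        omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
        if omp_cpuids == "all":
            raise ValueError("expects an explicit per-rank binding such as 0-15|16-31")
        segments = omp_cpuids.split("|")
        assert tp_rank < len(segments), "need one segment per TP rank"
        return segments[tp_rank]  # e.g. "16-31" for tp_rank=1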
sglang/srt/model_loader/loader.py

@@ -124,6 +124,9 @@ def _get_quantization_config(
         quant_config = get_quant_config(
             model_config, load_config, packed_modules_mapping
         )
+        # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
+        if quant_config is None:
+            return None
         major, minor = get_device_capability()
 
         if major is not None and minor is not None:
@@ -534,6 +537,12 @@ class DummyModelLoader(BaseModelLoader):
         model_config: ModelConfig,
         device_config: DeviceConfig,
     ) -> nn.Module:
+
+        if get_bool_env_var("SGL_CPU_QUANTIZATION"):
+            return load_model_with_cpu_quantization(
+                self, model_config=model_config, device_config=device_config
+            )
+
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(
@@ -1464,6 +1473,38 @@ class RemoteModelLoader(BaseModelLoader):
         return model.eval()
 
 
+def load_model_with_cpu_quantization(
+    self,
+    *,
+    model_config: ModelConfig,
+    device_config: DeviceConfig,
+) -> nn.Module:
+    target_device = torch.device(device_config.device)
+    with set_default_torch_dtype(model_config.dtype):
+        model = _initialize_model(
+            model_config,
+            self.load_config,
+        )
+
+        if not isinstance(self, DummyModelLoader):
+            model.load_weights(self._get_all_weights(model_config, model))
+
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                # When quant methods need to process weights after loading
+                # (for repacking, quantizing, etc), they expect parameters
+                # to be on the global target device. This scope is for the
+                # case where cpu offloading is used, where we will move the
+                # parameters onto device for processing and back off after.
+                with device_loading_context(module, target_device):
+                    quant_method.process_weights_after_loading(module)
+
+        model.to(target_device)
+
+        return model.eval()
+
+
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""