sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
@@ -26,10 +26,10 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 
-from sglang.srt import debug_utils
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
     get_tp_group,
@@ -40,6 +40,19 @@ from sglang.srt.distributed import (
     set_mscclpp_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
+from sglang.srt.eplb.eplb_manager import EPLBManager
+from sglang.srt.eplb.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+    set_global_expert_distribution_recorder,
+)
+from sglang.srt.eplb.expert_location import (
+    ExpertLocationMetadata,
+    compute_initial_expert_location_metadata,
+    get_global_expert_location_metadata,
+    set_global_expert_location_metadata,
+)
+from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
 from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
@@ -55,35 +68,27 @@ from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.lora.lora_manager import LoRAManager
-from sglang.srt.managers.eplb_manager import EPLBManager
-from sglang.srt.managers.expert_distribution import (
-    ExpertDistributionRecorder,
-    get_global_expert_distribution_recorder,
-    set_global_expert_distribution_recorder,
-)
-from sglang.srt.managers.expert_location import (
-    ExpertLocationMetadata,
-    compute_initial_expert_location_metadata,
-    get_global_expert_location_metadata,
-    set_global_expert_location_metadata,
-)
 from sglang.srt.managers.schedule_batch import (
     GLOBAL_SERVER_ARGS_KEYS,
     global_server_args_dict,
 )
 from sglang.srt.mem_cache.allocator import (
+    AscendPagedTokenToKVPoolAllocator,
     BaseTokenToKVPoolAllocator,
     PagedTokenToKVPoolAllocator,
+    SWATokenToKVPoolAllocator,
     TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.memory_pool import (
+    AscendMLAPagedTokenToKVPool,
+    AscendTokenToKVPool,
     DoubleSparseTokenToKVPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
+    SWAKVPool,
 )
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
@@ -101,6 +106,7 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
     get_bool_env_var,
+    get_cpu_ids_by_node,
     init_custom_process_group,
     is_cuda,
     is_fa3_default_architecture,
@@ -108,6 +114,7 @@ from sglang.srt.utils import (
     is_hip,
     is_hopper_with_cuda_12_3,
     is_no_spec_infer_or_topk_one,
+    is_npu,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cpu_offload_max_bytes,
@@ -115,6 +122,7 @@ from sglang.srt.utils import (
 )
 
 _is_hip = is_hip()
+_is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 
 # Use a small KV cache pool size for tests in CI
@@ -158,7 +166,6 @@ class ModelRunner:
         token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
     ):
         # Parse args
-        self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
         self.device = server_args.device
         self.gpu_id = gpu_id
@@ -171,6 +178,7 @@ class ModelRunner:
         self.dp_size = server_args.dp_size
         self.pp_rank = pp_rank
         self.pp_size = pp_size
+        self.model_config = model_config
         self.dist_port = nccl_port
         self.server_args = server_args
         self.is_draft_worker = is_draft_worker
@@ -185,6 +193,7 @@ class ModelRunner:
         self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.is_hybrid = model_config.is_hybrid
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
 
@@ -209,6 +218,10 @@ class ModelRunner:
         # CPU offload
         set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
 
+        # Init OpenMP threads binding for CPU
+        if self.device == "cpu":
+            self.init_threads_binding()
+
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
 
@@ -223,6 +236,7 @@ class ModelRunner:
         self.support_pp = (
             "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
         )
+        self._model_update_group = {}
 
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
@@ -239,7 +253,7 @@ class ModelRunner:
             "SGLANG_LOG_EXPERT_LOCATION_METADATA"
         ):
             logger.info(
-                f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}"
+                f"Initial expert_location_metadata: {get_global_expert_location_metadata()}"
             )
 
         set_global_expert_distribution_recorder(
@@ -300,11 +314,31 @@ class ModelRunner:
             self.init_cuda_graphs()
         else:
             self.cuda_graph_runner = None
+            self.cuda_graph_mem_usage = 0
         self.init_attention_backend()
 
         # auxiliary hidden capture mode. TODO: expose this to server args?
         if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
-            self.model.set_eagle3_layers_to_capture()
+            # load draft config
+            draft_model_config = ModelConfig.from_server_args(
+                server_args,
+                model_path=(server_args.speculative_draft_model_path),
+                is_draft_model=True,
+            )
+
+            try:
+                # get the aux layer from draft model config
+                eagle_config = getattr(
+                    draft_model_config.hf_config, "eagle_config", None
+                )
+                eagle_aux_hidden_state_layer_ids = eagle_config[
+                    "eagle_aux_hidden_state_layer_ids"
+                ]
+            except:
+                # if there is no aux layer, set to None
+                eagle_aux_hidden_state_layer_ids = None
+
+            self.model.set_eagle3_layers_to_capture(eagle_aux_hidden_state_layer_ids)
 
     def model_specific_adjustment(self):
         server_args = self.server_args
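
The EAGLE3 change in the hunk above reads an optional `eagle_config` dict from the draft model's Hugging Face config and passes its `eagle_aux_hidden_state_layer_ids` entry to `set_eagle3_layers_to_capture`, falling back to `None` when the dict or key is missing. A minimal sketch of that lookup in isolation; the config dict and layer ids below are invented for illustration and only the key name comes from the diff:

```python
# Hypothetical draft-model config fragment; only the key name comes from the diff.
eagle_config = {"eagle_aux_hidden_state_layer_ids": [1, 23, 44]}

try:
    # Mirrors the lookup performed in the hunk above.
    eagle_aux_hidden_state_layer_ids = eagle_config["eagle_aux_hidden_state_layer_ids"]
except (TypeError, KeyError):
    # No eagle_config (None) or no such key: let the model pick its default layers.
    eagle_aux_hidden_state_layer_ids = None

print(eagle_aux_hidden_state_layer_ids)  # [1, 23, 44]
```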
@@ -342,6 +376,8 @@ class ModelRunner:
                 server_args.attention_backend = "fa3"
             elif _is_hip:
                 server_args.attention_backend = "aiter"
+            elif _is_npu:
+                server_args.attention_backend = "ascend"
             else:
                 server_args.attention_backend = (
                     "flashinfer" if is_flashinfer_available() else "triton"
@@ -361,6 +397,8 @@ class ModelRunner:
                     server_args.attention_backend = "aiter"
                 else:
                     server_args.attention_backend = "triton"
+            elif _is_npu:
+                server_args.attention_backend = "ascend"
             else:
                 server_args.attention_backend = "triton"
             logger.info(
@@ -375,6 +413,7 @@ class ModelRunner:
                 "triton",
                 "flashmla",
                 "cutlass_mla",
+                "ascend",
             ]:
                 logger.info(
                     f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
@@ -412,11 +451,6 @@ class ModelRunner:
             self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
 
         if self.is_multimodal:
-            self.mem_fraction_static *= 0.90
-            logger.info(
-                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                f"because this is a multimodal model."
-            )
             if not self.is_multimodal_chunked_prefill_supported:
                 server_args.chunked_prefill_size = -1
                 logger.info(
@@ -437,6 +471,10 @@ class ModelRunner:
             if self.model_config.context_len > 8192:
                 self.mem_fraction_static *= 0.85
 
+        if self.is_hybrid and not server_args.disable_radix_cache:
+            logger.info("Automatically disable radix cache for hybrid cache.")
+            server_args.disable_radix_cache = True
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
 
@@ -471,6 +509,19 @@ class ModelRunner:
         set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
 
         if not self.is_draft_worker:
+            if self.device == "cpu":
+                if _is_cpu_amx_available:
+                    # Bind OpenMP threads to CPU cores
+                    torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)
+
+                    # Set local size to hint SGLang to use shared memory based AllReduce
+                    os.environ["LOCAL_SIZE"] = str(self.tp_size)
+                    torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
+                else:
+                    logger.warning(
+                        "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
+                    )
+
             # Only initialize the distributed environment on the target model worker.
             init_distributed_environment(
                 backend=backend,
@@ -547,7 +598,12 @@ class ModelRunner:
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
+            model_loader_extra_config=self.server_args.model_loader_extra_config,
         )
+        if self.device == "cpu":
+            self.model_config = adjust_config_with_unaligned_cpu_tp(
+                self.model_config, self.load_config, self.tp_size
+            )
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()
 
@@ -597,12 +653,13 @@ class ModelRunner:
         self.dtype = self.model_config.dtype
 
         after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
+        self.weight_load_mem_usage = before_avail_memory - after_avail_memory
         logger.info(
             f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
             f"avail mem={after_avail_memory:.2f} GB, "
-            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
+            f"mem usage={self.weight_load_mem_usage:.2f} GB."
         )
 
         # Handle the case where some ranks do not finish loading.
@@ -717,7 +774,7 @@ class ModelRunner:
         )
 
         try:
-            self._model_update_group = init_custom_process_group(
+            self._model_update_group[group_name] = init_custom_process_group(
                 backend=backend,
                 init_method=f"tcp://{master_address}:{master_port}",
                 world_size=world_size,
@@ -730,7 +787,7 @@ class ModelRunner:
             logger.error(message)
             return False, message
 
-    def update_weights_from_distributed(self, name, dtype, shape):
+    def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
         """
         Update specific parameter in the model weights online
         through `_model_update_group` process group.
@@ -740,19 +797,34 @@ class ModelRunner:
             dtype: the data type of the parameter to be updated.
             shape: the shape of the parameter to be updated.
         """
-        target_dtype = (
-            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
-        )
 
-        assert (
-            self._model_update_group is not None
-        ), "model update group must be initialized"
+        assert group_name in self._model_update_group, (
+            f"Group {group_name} not in {list(self._model_update_group.keys())}. "
+            "Please call `init_weights_update_group` first."
+        )
 
         try:
-            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
-            torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
-            self.model.load_weights([(name, weights)])
-            return True, f"Succeeded to update parameter {name} online."
+            weights = []
+            handles = []
+            for name, dtype, shape in zip(names, dtypes, shapes):
+                target_dtype = (
+                    dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
+                )
+                weight = torch.empty(shape, dtype=target_dtype, device=self.device)
+                handles.append(
+                    torch.distributed.broadcast(
+                        weight,
+                        src=0,
+                        group=self._model_update_group[group_name],
+                        async_op=True,
+                    )
+                )
+                weights.append((name, weight))
+            for handle in handles:
+                handle.wait()
+
+            self.model.load_weights(weights)
+            return True, f"Succeeded to update parameter online."
 
         except Exception as e:
             error_msg = (
@@ -811,8 +883,47 @@ class ModelRunner:
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
         )
-        self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
-        logger.info("LoRA manager ready.")
+        result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
+        if result.success:
+            logger.info(
+                f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}"
+            )
+        else:
+            raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}")
+
+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        """Load a new lora adapter from disk or huggingface."""
+
+        logger.info(
+            f"LoRA adapter loading starts: name={lora_name}, path={lora_path}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+
+        result = self.lora_manager.load_lora_adapter(lora_name, lora_path)
+
+        logger.info(
+            f"LoRA adapter loading completes: name={lora_name}, path={lora_path}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+
+        return result
+
+    def unload_lora_adapter(self, lora_name: str):
+        """Unload a lora adapter that was previously loaded during initialization or dynamic loading."""
+
+        logger.info(
+            f"LoRA adapter unloading starts: name={lora_name}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+
+        result = self.lora_manager.unload_lora_adapter(lora_name)
+
+        logger.info(
+            f"LoRA adapter unloading completes: name={lora_name}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+
+        return result
 
     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
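
The two new methods above allow attaching and detaching LoRA adapters on a live runner instead of fixing them at startup via `--lora-paths`. A minimal usage sketch, assuming `model_runner` is an initialized `ModelRunner`; the adapter name and path are placeholders:

```python
# Hypothetical adapter name and path.
result = model_runner.load_lora_adapter("my_adapter", "/path/to/lora_adapter")

# ... serve requests that reference "my_adapter" ...

result = model_runner.unload_lora_adapter("my_adapter")
```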
@@ -851,6 +962,40 @@ class ModelRunner:
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token
 
+    def set_num_token_hybrid(self):
+        if (
+            "Llama4ForConditionalGeneration"
+            in self.model_config.hf_config.architectures
+        ):
+            temp_ratio = (
+                (1 - self.is_hybrid)
+                + self.is_hybrid
+                * self.attention_chunk_size
+                / self.model_config.context_len
+            )
+            self.swa_max_total_num_tokens = (
+                4 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
+            )
+            self.full_max_total_num_tokens = (
+                4 * self.max_total_num_tokens
+                - 12 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
+            )
+            self.swa_max_total_num_tokens = int(
+                self.swa_max_total_num_tokens
+                // self.server_args.page_size
+                * self.server_args.page_size
+            )
+            self.full_max_total_num_tokens = int(
+                self.full_max_total_num_tokens
+                // self.server_args.page_size
+                * self.server_args.page_size
+            )
+            self.max_total_num_tokens = self.full_max_total_num_tokens
+        else:
+            raise ValueError(
+                f"Unsupported model for hybrid cache: {self.model_config.hf_config.architectures}."
+            )
+
     def init_memory_pool(
         self,
         total_gpu_memory: int,
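
To make the `set_num_token_hybrid` arithmetic above concrete, the same formula can be evaluated standalone; the token budget, chunk size, context length, and page size below are made-up inputs, and only the formula itself comes from the diff:

```python
# Illustrative inputs, not taken from any real configuration.
is_hybrid = 1                    # model_config.is_hybrid
max_total_num_tokens = 1_000_000
attention_chunk_size = 8192      # sliding-window span of the SWA layers
context_len = 131072
page_size = 64

# Same expressions as in set_num_token_hybrid above.
temp_ratio = (1 - is_hybrid) + is_hybrid * attention_chunk_size / context_len
swa = 4 * max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
full = 4 * max_total_num_tokens - 12 * max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)

# Round both budgets down to whole pages, as the diff does.
swa = int(swa // page_size * page_size)
full = int(full // page_size * page_size)

print(swa, full)  # SWA-layer budget vs. full-attention-layer budget
```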
@@ -865,7 +1010,9 @@ class ModelRunner:
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
         elif self.server_args.kv_cache_dtype == "fp8_e4m3":
-            if is_cuda():
+            if _is_hip:  # Using natively supported format
+                self.kv_cache_dtype = torch.float8_e4m3fnuz
+            else:
                 self.kv_cache_dtype = torch.float8_e4m3fn
         else:
             raise ValueError(
@@ -926,6 +1073,10 @@ class ModelRunner:
                 * self.server_args.page_size
             )
 
+        # create token size for hybrid cache
+        if self.is_hybrid:
+            self.set_num_token_hybrid()
+
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
                 "Not enough memory. Please try to increase --mem-fraction-static."
@@ -956,8 +1107,19 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker
 
-        if self.use_mla_backend:
-            self.token_to_kv_pool = MLATokenToKVPool(
+        if self.server_args.attention_backend == "ascend" and not self.use_mla_backend:
+            self.token_to_kv_pool = AscendTokenToKVPool(
+                self.max_total_num_tokens,
+                page_size=self.page_size,
+                dtype=self.kv_cache_dtype,
+                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
+                head_dim=self.model_config.head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+            )
+        elif self.server_args.attention_backend == "ascend" and self.use_mla_backend:
+            self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
                 self.max_total_num_tokens,
                 page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
@@ -973,22 +1135,25 @@ class ModelRunner:
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
             )
-        elif self.server_args.enable_double_sparsity:
-            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
+        elif self.use_mla_backend:
+            self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
                 page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
-                head_dim=self.model_config.head_dim,
-                layer_num=self.num_effective_layers,
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                layer_num=(
+                    self.model_config.num_hidden_layers
+                    if not self.is_draft_worker
+                    else self.model_config.hf_config.num_nextn_predict_layers
+                ),  # PP is not compatible with mla backend
                 device=self.device,
-                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
             )
-        else:
-            self.token_to_kv_pool = MHATokenToKVPool(
+        elif self.server_args.enable_double_sparsity:
+            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                 self.max_total_num_tokens,
                 page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
@@ -996,27 +1161,76 @@ class ModelRunner:
                 head_dim=self.model_config.head_dim,
                 layer_num=self.num_effective_layers,
                 device=self.device,
+                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
             )
-
-        if self.token_to_kv_pool_allocator is None:
-            if self.page_size == 1:
-                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
-                    self.max_total_num_tokens,
+        else:
+            if self.is_hybrid:
+                self.token_to_kv_pool = SWAKVPool(
+                    size=self.full_max_total_num_tokens,
+                    size_swa=self.swa_max_total_num_tokens,
                     dtype=self.kv_cache_dtype,
+                    head_num=self.model_config.get_num_kv_heads(
+                        get_attention_tp_size()
+                    ),
+                    head_dim=self.model_config.head_dim,
+                    swa_attention_layer_ids=self.model_config.swa_attention_layer_ids,
+                    full_attention_layer_ids=self.model_config.full_attention_layer_ids,
+                    enable_kvcache_transpose=False,
                     device=self.device,
-                    kvcache=self.token_to_kv_pool,
                 )
             else:
-                self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
+                self.token_to_kv_pool = MHATokenToKVPool(
                     self.max_total_num_tokens,
                     page_size=self.page_size,
                     dtype=self.kv_cache_dtype,
+                    head_num=self.model_config.get_num_kv_heads(
+                        get_attention_tp_size()
+                    ),
+                    head_dim=self.model_config.head_dim,
+                    layer_num=self.num_effective_layers,
                     device=self.device,
-                    kvcache=self.token_to_kv_pool,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    start_layer=self.start_layer,
+                    end_layer=self.end_layer,
                 )
+
+        if self.token_to_kv_pool_allocator is None:
+            if self.page_size == 1:
+                if self.is_hybrid:
+                    self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
+                        self.full_max_total_num_tokens,
+                        self.swa_max_total_num_tokens,
+                        dtype=self.kv_cache_dtype,
+                        device=self.device,
+                        kvcache=self.token_to_kv_pool,
+                    )
+                else:
+                    self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                        self.max_total_num_tokens,
+                        dtype=self.kv_cache_dtype,
+                        device=self.device,
+                        kvcache=self.token_to_kv_pool,
+                    )
+            else:
+                if _is_npu:
+                    self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
+                        self.max_total_num_tokens,
+                        page_size=self.page_size,
+                        dtype=self.kv_cache_dtype,
+                        device=self.device,
+                        kvcache=self.token_to_kv_pool,
+                    )
+                else:
+                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
+                        self.max_total_num_tokens,
+                        page_size=self.page_size,
+                        dtype=self.kv_cache_dtype,
+                        device=self.device,
+                        kvcache=self.token_to_kv_pool,
+                    )
         else:
             assert self.is_draft_worker
 
@@ -1036,7 +1250,7 @@ class ModelRunner:
 
     def init_attention_backend(self):
         """Init attention kernel backend."""
-        if self.server_args.enable_two_batch_overlap:
+        if self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
             self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
         else:
             self.attn_backend = self._get_attention_backend()
@@ -1063,6 +1277,10 @@ class ModelRunner:
             from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
 
             return AiterAttnBackend(self)
+        elif self.server_args.attention_backend == "ascend":
+            from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
+
+            return AscendAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
            assert not self.model_config.is_encoder_decoder, (
                 "Cross attention is not supported in the triton attention backend. "
@@ -1138,6 +1356,7 @@ class ModelRunner:
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None
+        self.cuda_graph_mem_usage = 0
 
         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
@@ -1153,11 +1372,36 @@ class ModelRunner:
         )
         self.cuda_graph_runner = CudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
+        self.cuda_graph_mem_usage = before_mem - after_mem
         logger.info(
             f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
-            f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
+            f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
         )
 
+    def init_threads_binding(self):
+        omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
+        if omp_cpuids == "all":
+            cpu_ids_by_node = get_cpu_ids_by_node()
+            n_numa_node = len(cpu_ids_by_node)
+
+            assert self.tp_size <= n_numa_node, (
+                f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, "
+                f"tp_size {self.tp_size} should be smaller than or equal to number of numa node on the machine {n_numa_node}. "
+                f"If you need tp_size to be larger than number of numa node, please set the CPU cores for each tp rank via SGLANG_CPU_OMP_THREADS_BIND explicitly. "
+                f"For example, on a machine with 2 numa nodes, where core 0-31 are on numa node 0 and core 32-63 are on numa node 1, "
+                f"it is suggested to use -tp 2 and bind tp rank 0 to core 0-31 and tp rank 1 to core 32-63. "
+                f"This is the default behavior if SGLANG_CPU_OMP_THREADS_BIND is not set and it is the same as setting SGLANG_CPU_OMP_THREADS_BIND=0-31|32-63. "
+                f"If you do need tp_size to be larger than the number of numa nodes, you could set SGLANG_CPU_OMP_THREADS_BIND explicitly for example SGLANG_CPU_OMP_THREADS_BIND=0-15|16-31|32-47|48-63 and run with -tp 4. "
+                f"If you don't want each tp rank to use all the cores on one numa node, you could set for example SGLANG_CPU_OMP_THREADS_BIND=0-15|32-47 and run with -tp 2."
+            )
+            if self.tp_size < n_numa_node:
+                logger.warning(
+                    f"Detected the current machine has {n_numa_node} numa nodes available, but tp_size is set to {self.tp_size}, so only {self.tp_size} numa nodes are used."
+                )
+            self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
+        else:
+            self.local_omp_cpuid = omp_cpuids.split("|")[self.tp_rank]
+
     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
         from sglang.srt.model_parallel import tensor_parallel
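
For reference, `init_threads_binding` above chooses each TP rank's OpenMP core list either from the machine's NUMA topology (when `SGLANG_CPU_OMP_THREADS_BIND` is unset or "all") or from a `|`-separated list of core ranges indexed by TP rank. A small sketch of the per-rank split, using the example value quoted in the assertion message above:

```python
import os

# Two TP ranks, one NUMA node each, as in the help text above.
os.environ["SGLANG_CPU_OMP_THREADS_BIND"] = "0-31|32-63"

omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
for tp_rank in range(2):
    # Mirrors `omp_cpuids.split("|")[self.tp_rank]` in init_threads_binding.
    print(tp_rank, omp_cpuids.split("|")[tp_rank])
# 0 0-31
# 1 32-63
```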