sglang 0.4.9.post1__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +33 -8
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +22 -2
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/ascend/__init__.py +6 -0
  11. sglang/srt/disaggregation/ascend/conn.py +44 -0
  12. sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
  13. sglang/srt/disaggregation/decode.py +9 -1
  14. sglang/srt/disaggregation/mooncake/conn.py +59 -70
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
  16. sglang/srt/disaggregation/utils.py +25 -3
  17. sglang/srt/distributed/parallel_state.py +33 -0
  18. sglang/srt/entrypoints/engine.py +30 -26
  19. sglang/srt/entrypoints/http_server.py +1 -0
  20. sglang/srt/entrypoints/openai/protocol.py +11 -0
  21. sglang/srt/entrypoints/openai/serving_chat.py +28 -2
  22. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  23. sglang/srt/function_call/function_call_parser.py +4 -0
  24. sglang/srt/function_call/kimik2_detector.py +220 -0
  25. sglang/srt/function_call/qwen3_detector.py +150 -0
  26. sglang/srt/hf_transformers_utils.py +17 -0
  27. sglang/srt/jinja_template_utils.py +8 -0
  28. sglang/srt/layers/activation.py +13 -0
  29. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  30. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  31. sglang/srt/layers/communicator.py +17 -4
  32. sglang/srt/layers/linear.py +24 -103
  33. sglang/srt/layers/moe/ep_moe/kernels.py +6 -3
  34. sglang/srt/layers/moe/ep_moe/layer.py +24 -402
  35. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  36. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  37. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  43. sglang/srt/layers/moe/fused_moe_triton/layer.py +21 -398
  44. sglang/srt/layers/moe/topk.py +195 -14
  45. sglang/srt/layers/parameter.py +19 -3
  46. sglang/srt/layers/quantization/__init__.py +20 -134
  47. sglang/srt/layers/quantization/awq.py +578 -11
  48. sglang/srt/layers/quantization/awq_triton.py +339 -0
  49. sglang/srt/layers/quantization/base_config.py +85 -10
  50. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  51. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  52. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  53. sglang/srt/layers/quantization/fp8.py +273 -62
  54. sglang/srt/layers/quantization/fp8_kernel.py +212 -48
  55. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  56. sglang/srt/layers/quantization/gptq.py +501 -143
  57. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  58. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  59. sglang/srt/layers/quantization/moe_wna16.py +46 -51
  60. sglang/srt/layers/quantization/petit.py +252 -0
  61. sglang/srt/layers/quantization/petit_utils.py +104 -0
  62. sglang/srt/layers/quantization/qoq.py +7 -6
  63. sglang/srt/layers/quantization/scalar_type.py +352 -0
  64. sglang/srt/layers/quantization/unquant.py +422 -0
  65. sglang/srt/layers/quantization/utils.py +343 -3
  66. sglang/srt/layers/quantization/w4afp8.py +8 -4
  67. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  68. sglang/srt/layers/quantization/w8a8_int8.py +718 -58
  69. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  70. sglang/srt/lora/lora.py +0 -4
  71. sglang/srt/lora/lora_manager.py +87 -53
  72. sglang/srt/lora/mem_pool.py +81 -33
  73. sglang/srt/lora/utils.py +12 -5
  74. sglang/srt/managers/cache_controller.py +241 -0
  75. sglang/srt/managers/io_struct.py +65 -28
  76. sglang/srt/managers/mm_utils.py +61 -101
  77. sglang/srt/managers/schedule_batch.py +162 -111
  78. sglang/srt/managers/schedule_policy.py +68 -27
  79. sglang/srt/managers/scheduler.py +264 -62
  80. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  81. sglang/srt/managers/tokenizer_manager.py +27 -3
  82. sglang/srt/managers/tp_worker.py +14 -0
  83. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  84. sglang/srt/mem_cache/allocator.py +7 -16
  85. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  86. sglang/srt/mem_cache/chunk_cache.py +5 -2
  87. sglang/srt/mem_cache/hicache_storage.py +152 -0
  88. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  89. sglang/srt/mem_cache/memory_pool.py +81 -41
  90. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  91. sglang/srt/mem_cache/radix_cache.py +26 -0
  92. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  93. sglang/srt/metrics/collector.py +9 -0
  94. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  95. sglang/srt/model_executor/forward_batch_info.py +27 -2
  96. sglang/srt/model_executor/model_runner.py +109 -22
  97. sglang/srt/model_loader/loader.py +30 -13
  98. sglang/srt/model_loader/utils.py +4 -4
  99. sglang/srt/models/clip.py +1 -1
  100. sglang/srt/models/deepseek.py +9 -6
  101. sglang/srt/models/deepseek_janus_pro.py +2 -2
  102. sglang/srt/models/deepseek_v2.py +252 -187
  103. sglang/srt/models/deepseek_vl2.py +6 -6
  104. sglang/srt/models/gemma.py +48 -0
  105. sglang/srt/models/gemma2.py +52 -0
  106. sglang/srt/models/gemma3_causal.py +63 -0
  107. sglang/srt/models/gemma3_mm.py +2 -2
  108. sglang/srt/models/gemma3n_mm.py +8 -7
  109. sglang/srt/models/granitemoe.py +385 -0
  110. sglang/srt/models/grok.py +9 -3
  111. sglang/srt/models/hunyuan.py +63 -16
  112. sglang/srt/models/internvl.py +9 -3
  113. sglang/srt/models/kimi_vl.py +9 -3
  114. sglang/srt/models/llama.py +43 -0
  115. sglang/srt/models/llama4.py +11 -11
  116. sglang/srt/models/llava.py +5 -3
  117. sglang/srt/models/llavavid.py +2 -2
  118. sglang/srt/models/minicpm.py +0 -2
  119. sglang/srt/models/minicpmo.py +4 -9
  120. sglang/srt/models/minicpmv.py +2 -2
  121. sglang/srt/models/mistral.py +1 -1
  122. sglang/srt/models/mixtral.py +9 -2
  123. sglang/srt/models/mixtral_quant.py +4 -0
  124. sglang/srt/models/mllama.py +3 -5
  125. sglang/srt/models/mllama4.py +16 -7
  126. sglang/srt/models/olmoe.py +8 -5
  127. sglang/srt/models/persimmon.py +330 -0
  128. sglang/srt/models/phi.py +321 -0
  129. sglang/srt/models/phi4mm.py +52 -6
  130. sglang/srt/models/phi4mm_audio.py +1260 -0
  131. sglang/srt/models/phi4mm_utils.py +1917 -0
  132. sglang/srt/models/phimoe.py +559 -0
  133. sglang/srt/models/qwen.py +37 -0
  134. sglang/srt/models/qwen2.py +43 -0
  135. sglang/srt/models/qwen2_5_vl.py +10 -7
  136. sglang/srt/models/qwen2_audio.py +1 -1
  137. sglang/srt/models/qwen2_moe.py +53 -5
  138. sglang/srt/models/qwen2_vl.py +13 -2
  139. sglang/srt/models/qwen3.py +65 -1
  140. sglang/srt/models/qwen3_moe.py +56 -18
  141. sglang/srt/models/vila.py +9 -3
  142. sglang/srt/multimodal/processors/base_processor.py +273 -219
  143. sglang/srt/multimodal/processors/clip.py +21 -19
  144. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  145. sglang/srt/multimodal/processors/gemma3.py +13 -15
  146. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  147. sglang/srt/multimodal/processors/internvl.py +10 -11
  148. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  149. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  150. sglang/srt/multimodal/processors/llava.py +4 -2
  151. sglang/srt/multimodal/processors/minicpm.py +37 -45
  152. sglang/srt/multimodal/processors/mlama.py +21 -18
  153. sglang/srt/multimodal/processors/mllama4.py +5 -6
  154. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  155. sglang/srt/multimodal/processors/pixtral.py +14 -35
  156. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  157. sglang/srt/multimodal/processors/qwen_vl.py +211 -93
  158. sglang/srt/multimodal/processors/vila.py +14 -14
  159. sglang/srt/sampling/sampling_params.py +8 -1
  160. sglang/srt/server_args.py +404 -234
  161. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  162. sglang/srt/two_batch_overlap.py +1 -0
  163. sglang/srt/utils.py +181 -32
  164. sglang/test/runners.py +14 -3
  165. sglang/test/test_block_fp8.py +8 -3
  166. sglang/test/test_block_fp8_ep.py +1 -1
  167. sglang/test/test_custom_ops.py +12 -7
  168. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  169. sglang/test/test_fp4_moe.py +1 -3
  170. sglang/test/test_marlin_moe.py +286 -0
  171. sglang/test/test_marlin_utils.py +171 -0
  172. sglang/test/test_utils.py +35 -0
  173. sglang/version.py +1 -1
  174. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +10 -9
  175. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +178 -153
  176. sglang/srt/layers/quantization/quant_utils.py +0 -166
  177. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  178. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  179. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py CHANGED
@@ -271,12 +271,13 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
271
271
  batch,
272
272
  dp_size=model_runner.server_args.dp_size,
273
273
  attn_tp_size=1,
274
- tp_cpu_group=model_runner.tp_group.cpu_group,
274
+ tp_group=model_runner.tp_group,
275
275
  get_idle_batch=None,
276
276
  disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
277
277
  spec_algorithm=SpeculativeAlgorithm.NONE,
278
278
  speculative_num_draft_tokens=None,
279
279
  require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
280
+ disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
280
281
  )
281
282
 
282
283
 
@@ -73,6 +73,8 @@ async def benchmark(args):
73
73
 
74
74
  tasks: List[asyncio.Task] = []
75
75
  for idx, ex in enumerate(dataset):
76
+ if idx >= args.num_prompts:
77
+ break
76
78
  tasks.append(
77
79
  asyncio.create_task(
78
80
  fetch_response(
@@ -103,6 +105,8 @@ def analyse(args):
103
105
  hyps: List[str] = []
104
106
  refs: List[str] = []
105
107
  for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
108
+ if idx >= args.num_prompts:
109
+ break
106
110
  pkl_file = output_dir / f"response_{idx}.pkl"
107
111
  if not pkl_file.exists():
108
112
  raise FileNotFoundError(pkl_file)
@@ -150,6 +154,9 @@ if __name__ == "__main__":
150
154
  parser.add_argument(
151
155
  "--output-dir", default="tmp-output-dir", help="Directory for cached responses"
152
156
  )
157
+ parser.add_argument(
158
+ "--num-prompts", type=int, default=10000, help="Number of prompts to run"
159
+ )
153
160
  args = parser.parse_args()
154
161
 
155
162
  asyncio.run(benchmark(args))
@@ -42,6 +42,9 @@ def select_best_resolution(image_size, candidate_resolutions):
42
42
 
43
43
 
44
44
  class DictOutput(object):
45
+ def items(self):
46
+ return self.__dict__.items()
47
+
45
48
  def keys(self):
46
49
  return self.__dict__.keys()
47
50
 
@@ -59,7 +62,9 @@ class DictOutput(object):
59
62
  class VLChatProcessorOutput(DictOutput):
60
63
  input_ids: torch.LongTensor
61
64
  target_ids: torch.LongTensor
62
- images: torch.Tensor
65
+ pixel_values: (
66
+ torch.Tensor
67
+ ) # rename from "images" to "pixel_values" for compatibility
63
68
  images_seq_mask: torch.BoolTensor
64
69
  images_spatial_crop: torch.LongTensor
65
70
 
@@ -312,10 +317,14 @@ class DeepseekVLV2Processor(ProcessorMixin):
312
317
  images = torch.stack(images_list, dim=0)
313
318
  images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
314
319
 
320
+ images_spatial_crop = torch.stack(
321
+ [images_spatial_crop], dim=0
322
+ ) # stack the tensor to make it a batch of 1
323
+
315
324
  prepare = VLChatProcessorOutput(
316
325
  input_ids=input_ids,
317
326
  target_ids=target_ids,
318
- images=images,
327
+ pixel_values=images,
319
328
  images_seq_mask=images_seq_mask,
320
329
  images_spatial_crop=images_spatial_crop,
321
330
  )
@@ -9,6 +9,7 @@ from transformers import (
9
9
  LlamaConfig,
10
10
  PretrainedConfig,
11
11
  PreTrainedTokenizer,
12
+ Qwen2Config,
12
13
  )
13
14
 
14
15
  from sglang.utils import logger
@@ -311,6 +312,8 @@ class InternVLChatConfig(PretrainedConfig):
311
312
  self.llm_config = LlamaConfig(**llm_config)
312
313
  elif llm_config.get("architectures")[0] == "InternLM2ForCausalLM":
313
314
  self.llm_config = InternLM2Config(**llm_config)
315
+ elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM":
316
+ self.llm_config = Qwen2Config(**llm_config)
314
317
  else:
315
318
  raise ValueError(
316
319
  "Unsupported architecture: {}".format(
@@ -284,6 +284,9 @@ class VLMImageProcessor(BaseImageProcessor):
284
284
 
285
285
 
286
286
  class DictOutput(object):
287
+ def items(self):
288
+ return self.__dict__.items()
289
+
287
290
  def keys(self):
288
291
  return self.__dict__.keys()
289
292
 
@@ -25,6 +25,7 @@ from transformers import PretrainedConfig
25
25
  from sglang.srt.hf_transformers_utils import (
26
26
  get_config,
27
27
  get_context_length,
28
+ get_generation_config,
28
29
  get_hf_text_config,
29
30
  )
30
31
  from sglang.srt.layers.quantization import QUANTIZATION_METHODS
@@ -52,7 +53,7 @@ class ModelConfig:
52
53
  trust_remote_code: bool = True,
53
54
  revision: Optional[str] = None,
54
55
  context_length: Optional[int] = None,
55
- model_override_args: Optional[str] = None,
56
+ model_override_args: str = "{}",
56
57
  is_embedding: Optional[bool] = None,
57
58
  enable_multimodal: Optional[bool] = None,
58
59
  dtype: str = "auto",
@@ -60,13 +61,13 @@ class ModelConfig:
60
61
  override_config_file: Optional[str] = None,
61
62
  is_draft_model: bool = False,
62
63
  hybrid_kvcache_ratio: Optional[float] = None,
63
- impl: Union[str, ModelImpl] = ModelImpl.AUTO,
64
+ model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
64
65
  ) -> None:
65
66
 
66
67
  self.model_path = model_path
67
68
  self.revision = revision
68
69
  self.quantization = quantization
69
- self.impl = impl
70
+ self.model_impl = model_impl
70
71
 
71
72
  # Parse args
72
73
  self.maybe_pull_model_tokenizer_from_remote()
@@ -83,6 +84,13 @@ class ModelConfig:
83
84
  **kwargs,
84
85
  )
85
86
 
87
+ self.hf_generation_config = get_generation_config(
88
+ self.model_path,
89
+ trust_remote_code=trust_remote_code,
90
+ revision=revision,
91
+ **kwargs,
92
+ )
93
+
86
94
  self.hf_text_config = get_hf_text_config(self.hf_config)
87
95
  self.attention_chunk_size = getattr(
88
96
  self.hf_text_config, "attention_chunk_size", None
@@ -278,7 +286,7 @@ class ModelConfig:
278
286
  dtype=server_args.dtype,
279
287
  quantization=server_args.quantization,
280
288
  hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
281
- impl=server_args.impl,
289
+ model_impl=server_args.model_impl,
282
290
  **kwargs,
283
291
  )
284
292
 
@@ -383,6 +391,7 @@ class ModelConfig:
383
391
  "compressed-tensors",
384
392
  "fbgemm_fp8",
385
393
  "w8a8_fp8",
394
+ "petit_nvfp4",
386
395
  ]
387
396
  optimized_quantization_methods = [
388
397
  "fp8",
@@ -400,9 +409,11 @@ class ModelConfig:
400
409
  "moe_wna16",
401
410
  "qoq",
402
411
  "w4afp8",
412
+ "petit_nvfp4",
403
413
  ]
404
414
  compatible_quantization_methods = {
405
415
  "modelopt_fp4": ["modelopt"],
416
+ "petit_nvfp4": ["modelopt"],
406
417
  "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
407
418
  "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
408
419
  }
@@ -413,7 +424,9 @@ class ModelConfig:
413
424
  quant_cfg = self._parse_quant_hf_config()
414
425
 
415
426
  if quant_cfg is not None:
416
- quant_method = quant_cfg.get("quant_method", "").lower()
427
+ quant_method = quant_cfg.get(
428
+ "quant_method", "" if not self.quantization else self.quantization
429
+ ).lower()
417
430
 
418
431
  # Detect which checkpoint is it
419
432
  for _, method in QUANTIZATION_METHODS.items():
@@ -465,6 +478,19 @@ class ModelConfig:
465
478
  if eos_ids:
466
479
  # it can be either int or list of int
467
480
  eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
481
+ if eos_ids is None:
482
+ eos_ids = set()
483
+ if self.hf_generation_config:
484
+ generation_eos_ids = getattr(
485
+ self.hf_generation_config, "eos_token_id", None
486
+ )
487
+ if generation_eos_ids:
488
+ generation_eos_ids = (
489
+ {generation_eos_ids}
490
+ if isinstance(generation_eos_ids, int)
491
+ else set(generation_eos_ids)
492
+ )
493
+ eos_ids = eos_ids | generation_eos_ids
468
494
  return eos_ids
469
495
 
470
496
  def maybe_pull_model_tokenizer_from_remote(self) -> None:
@@ -688,7 +714,6 @@ def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int)
688
714
  i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
689
715
  ]
690
716
  else:
691
- raise ValueError(
692
- "get_hybrid_layer_ids is only implemented for Llama4ForConditionalGeneration"
693
- )
717
+ swa_attention_layer_ids = None
718
+ full_attention_layer_ids = None
694
719
  return swa_attention_layer_ids, full_attention_layer_ids
@@ -115,5 +115,7 @@ def adjust_config_with_unaligned_cpu_tp(
115
115
  model_config = update_intermediate_size(
116
116
  model_config, "intermediate_size", intermediate_padding_size
117
117
  )
118
-
118
+ model_config = update_intermediate_size(
119
+ model_config, "intermediate_size_mlp", intermediate_padding_size
120
+ )
119
121
  return model_config
@@ -88,9 +88,11 @@ class Conversation:
88
88
  stop_str: Union[str, List[str]] = None
89
89
  # The string that represents an image token in the prompt
90
90
  image_token: str = "<image>"
91
+ video_token: str = "<video>"
91
92
  audio_token: str = "<audio>"
92
93
 
93
94
  image_data: Optional[List[str]] = None
95
+ video_data: Optional[List[str]] = None
94
96
  modalities: Optional[List[str]] = None
95
97
  stop_token_ids: Optional[int] = None
96
98
 
@@ -380,11 +382,15 @@ class Conversation:
380
382
  self.messages.append([role, message])
381
383
 
382
384
  def append_image(self, image: str):
383
- """Append a new message."""
385
+ """Append a new image."""
384
386
  self.image_data.append(image)
385
387
 
388
+ def append_video(self, video: str):
389
+ """Append a new video."""
390
+ self.video_data.append(video)
391
+
386
392
  def append_audio(self, audio: str):
387
- """Append a new message."""
393
+ """Append a new audio."""
388
394
  self.audio_data.append(audio)
389
395
 
390
396
  def update_last_message(self, message: str):
@@ -433,6 +439,7 @@ class Conversation:
433
439
  sep2=self.sep2,
434
440
  stop_str=self.stop_str,
435
441
  image_token=self.image_token,
442
+ video_token=self.video_token,
436
443
  audio_token=self.audio_token,
437
444
  )
438
445
 
@@ -495,8 +502,12 @@ def generate_embedding_convs(
495
502
  sep2=conv_template.sep2,
496
503
  stop_str=conv_template.stop_str,
497
504
  image_data=[],
505
+ video_data=[],
506
+ audio_data=[],
498
507
  modalities=[],
499
508
  image_token=conv_template.image_token,
509
+ video_token=conv_template.video_token,
510
+ audio_token=conv_template.audio_token,
500
511
  )
501
512
  real_content = ""
502
513
 
@@ -557,10 +568,12 @@ def generate_chat_conv(
557
568
  sep2=conv.sep2,
558
569
  stop_str=conv.stop_str,
559
570
  image_data=[],
571
+ video_data=[],
560
572
  audio_data=[],
561
573
  modalities=[],
562
574
  image_token=conv.image_token,
563
575
  audio_token=conv.audio_token,
576
+ video_token=conv.video_token,
564
577
  )
565
578
 
566
579
  if isinstance(request.messages, str):
@@ -602,6 +615,7 @@ def generate_chat_conv(
602
615
  image_token = ""
603
616
 
604
617
  audio_token = conv.audio_token
618
+ video_token = conv.video_token
605
619
  for content in message.content:
606
620
  if content.type == "text":
607
621
  if num_image_url > 16:
@@ -614,6 +628,9 @@ def generate_chat_conv(
614
628
  else:
615
629
  real_content += image_token
616
630
  conv.append_image(content.image_url.url)
631
+ elif content.type == "video_url":
632
+ real_content += video_token
633
+ conv.append_video(content.video_url.url)
617
634
  elif content.type == "audio_url":
618
635
  real_content += audio_token
619
636
  conv.append_audio(content.audio_url.url)
@@ -712,6 +729,7 @@ register_conv_template(
712
729
  sep="<|end|>",
713
730
  stop_str="<|end|>",
714
731
  image_token="<|endoftext10|>",
732
+ audio_token="<|endoftext11|>",
715
733
  )
716
734
  )
717
735
 
@@ -810,6 +828,7 @@ register_conv_template(
810
828
  sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
811
829
  stop_str=["<|im_end|>"],
812
830
  image_token="<|vision_start|><|image_pad|><|vision_end|>",
831
+ video_token="<|vision_start|><|video_pad|><|vision_end|>",
813
832
  )
814
833
  )
815
834
 
@@ -870,6 +889,7 @@ register_conv_template(
870
889
  sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
871
890
  stop_str=("<|im_end|>", "<|endoftext|>"),
872
891
  image_token="(<image>./</image>)",
892
+ video_token="(<video>./</video>)",
873
893
  )
874
894
  )
875
895
 
sglang/srt/custom_op.py CHANGED
@@ -29,15 +29,18 @@ class CustomOp(nn.Module):
29
29
 
30
30
  self._original_forward_method = self._forward_method
31
31
  # NOTE: Temporarily workaround MoE
32
+ # The performance of torch.compile on this layer is not always good when bs > 1,
33
+ # so we decide to only use torch.compile when bs=1
32
34
  if "FusedMoE" in self.__class__.__name__:
33
35
  if num_tokens == 1:
34
36
  from sglang.srt.layers.moe.fused_moe_native import (
35
37
  fused_moe_forward_native,
36
38
  )
37
39
 
38
- # The performance of torch.compile on this layer is not always good when bs > 1,
39
- # so we decide to only use torch.compile when bs =1
40
40
  self._forward_method = fused_moe_forward_native
41
+ elif "TopK" in self.__class__.__name__:
42
+ if num_tokens == 1:
43
+ self._forward_method = self.forward_native
41
44
  else:
42
45
  self._forward_method = self.forward_native
43
46
  self.is_torch_compile = True
@@ -0,0 +1,6 @@
1
+ from sglang.srt.disaggregation.ascend.conn import (
2
+ AscendKVBootstrapServer,
3
+ AscendKVManager,
4
+ AscendKVReceiver,
5
+ AscendKVSender,
6
+ )
@@ -0,0 +1,44 @@
1
+ import logging
2
+
3
+ from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
4
+ from sglang.srt.disaggregation.mooncake.conn import (
5
+ MooncakeKVBootstrapServer,
6
+ MooncakeKVManager,
7
+ MooncakeKVReceiver,
8
+ MooncakeKVSender,
9
+ )
10
+ from sglang.srt.utils import get_local_ip_by_remote
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class AscendKVManager(MooncakeKVManager):
16
+ def init_engine(self):
17
+ # TransferEngine initialized on ascend.
18
+ local_ip = get_local_ip_by_remote()
19
+ self.engine = AscendTransferEngine(
20
+ hostname=local_ip,
21
+ npu_id=self.kv_args.gpu_id,
22
+ disaggregation_mode=self.disaggregation_mode,
23
+ )
24
+
25
+ def register_buffer_to_engine(self):
26
+ self.engine.register(
27
+ self.kv_args.kv_data_ptrs[0], sum(self.kv_args.kv_data_lens)
28
+ )
29
+ # The Ascend backend optimize batch registration for small memory blocks.
30
+ self.engine.batch_register(
31
+ self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
32
+ )
33
+
34
+
35
+ class AscendKVSender(MooncakeKVSender):
36
+ pass
37
+
38
+
39
+ class AscendKVReceiver(MooncakeKVReceiver):
40
+ pass
41
+
42
+
43
+ class AscendKVBootstrapServer(MooncakeKVBootstrapServer):
44
+ pass
@@ -0,0 +1,58 @@
1
+ import logging
2
+ import os
3
+ from typing import List, Optional
4
+
5
+ from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
6
+ from sglang.srt.disaggregation.utils import DisaggregationMode
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ class AscendTransferEngine(MooncakeTransferEngine):
12
+
13
+ def __init__(
14
+ self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
15
+ ):
16
+ try:
17
+ from mf_adapter import TransferEngine
18
+ except ImportError as e:
19
+ raise ImportError(
20
+ "Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
21
+ ) from e
22
+
23
+ self.engine = TransferEngine()
24
+ self.hostname = hostname
25
+ self.npu_id = npu_id
26
+
27
+ # Centralized storage address of the AscendTransferEngine
28
+ self.store_url = os.getenv("ASCEND_MF_STORE_URL")
29
+ if disaggregation_mode == DisaggregationMode.PREFILL:
30
+ self.role = "Prefill"
31
+ elif disaggregation_mode == DisaggregationMode.DECODE:
32
+ self.role = "Decode"
33
+ else:
34
+ logger.error(f"Unsupported DisaggregationMode: {disaggregation_mode}")
35
+ raise ValueError(f"Unsupported DisaggregationMode: {disaggregation_mode}")
36
+ self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
37
+ self.initialize()
38
+
39
+ def initialize(self) -> None:
40
+ """Initialize the ascend transfer instance."""
41
+ ret_value = self.engine.initialize(
42
+ self.store_url,
43
+ self.session_id,
44
+ self.role,
45
+ self.npu_id,
46
+ )
47
+ if ret_value != 0:
48
+ logger.error("Ascend Transfer Engine initialization failed.")
49
+ raise RuntimeError("Ascend Transfer Engine initialization failed.")
50
+
51
+ def batch_register(self, ptrs: List[int], lengths: List[int]):
52
+ try:
53
+ ret_value = self.engine.batch_register_memory(ptrs, lengths)
54
+ except Exception:
55
+ # Mark register as failed
56
+ ret_value = -1
57
+ if ret_value != 0:
58
+ logger.debug(f"Ascend memory registration for ptr {ptrs} failed.")
@@ -439,7 +439,15 @@ class DecodePreallocQueue:
439
439
  else 0
440
440
  )
441
441
 
442
- allocatable_tokens = self.token_to_kv_pool_allocator.available_size() - max(
442
+ if self.scheduler.model_config.is_hybrid:
443
+ available_size = min(
444
+ self.token_to_kv_pool_allocator.full_available_size(),
445
+ self.token_to_kv_pool_allocator.swa_available_size(),
446
+ )
447
+ else:
448
+ available_size = self.token_to_kv_pool_allocator.available_size()
449
+
450
+ allocatable_tokens = available_size - max(
443
451
  # preserve some space for future decode
444
452
  self.num_reserved_decode_tokens
445
453
  * (