sglang 0.4.9.post1__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. sglang/srt/configs/model_config.py +24 -1
  2. sglang/srt/conversation.py +21 -2
  3. sglang/srt/disaggregation/ascend/__init__.py +6 -0
  4. sglang/srt/disaggregation/ascend/conn.py +44 -0
  5. sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
  6. sglang/srt/disaggregation/mooncake/conn.py +15 -14
  7. sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
  8. sglang/srt/disaggregation/utils.py +25 -3
  9. sglang/srt/entrypoints/engine.py +1 -1
  10. sglang/srt/entrypoints/http_server.py +1 -0
  11. sglang/srt/entrypoints/openai/protocol.py +11 -0
  12. sglang/srt/entrypoints/openai/serving_chat.py +7 -0
  13. sglang/srt/function_call/function_call_parser.py +2 -0
  14. sglang/srt/function_call/kimik2_detector.py +220 -0
  15. sglang/srt/hf_transformers_utils.py +18 -0
  16. sglang/srt/jinja_template_utils.py +8 -0
  17. sglang/srt/layers/communicator.py +17 -4
  18. sglang/srt/layers/linear.py +12 -2
  19. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  20. sglang/srt/layers/moe/ep_moe/layer.py +2 -1
  21. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -2
  22. sglang/srt/layers/moe/topk.py +8 -2
  23. sglang/srt/layers/parameter.py +19 -3
  24. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  25. sglang/srt/layers/quantization/moe_wna16.py +1 -2
  26. sglang/srt/layers/quantization/w8a8_int8.py +738 -14
  27. sglang/srt/managers/io_struct.py +27 -2
  28. sglang/srt/managers/mm_utils.py +55 -94
  29. sglang/srt/managers/schedule_batch.py +16 -5
  30. sglang/srt/managers/scheduler.py +21 -1
  31. sglang/srt/managers/tokenizer_manager.py +16 -0
  32. sglang/srt/mem_cache/memory_pool.py +65 -40
  33. sglang/srt/model_executor/forward_batch_info.py +13 -1
  34. sglang/srt/model_loader/loader.py +23 -12
  35. sglang/srt/models/deepseek_janus_pro.py +1 -1
  36. sglang/srt/models/deepseek_v2.py +62 -17
  37. sglang/srt/models/deepseek_vl2.py +1 -1
  38. sglang/srt/models/gemma3_mm.py +1 -1
  39. sglang/srt/models/gemma3n_mm.py +6 -3
  40. sglang/srt/models/internvl.py +8 -2
  41. sglang/srt/models/kimi_vl.py +8 -2
  42. sglang/srt/models/llama.py +2 -0
  43. sglang/srt/models/llava.py +3 -1
  44. sglang/srt/models/llavavid.py +1 -1
  45. sglang/srt/models/minicpmo.py +1 -2
  46. sglang/srt/models/minicpmv.py +1 -1
  47. sglang/srt/models/mixtral_quant.py +4 -0
  48. sglang/srt/models/mllama4.py +13 -4
  49. sglang/srt/models/phi4mm.py +8 -2
  50. sglang/srt/models/phimoe.py +553 -0
  51. sglang/srt/models/qwen2.py +2 -0
  52. sglang/srt/models/qwen2_5_vl.py +10 -7
  53. sglang/srt/models/qwen2_vl.py +12 -1
  54. sglang/srt/models/vila.py +8 -2
  55. sglang/srt/multimodal/processors/base_processor.py +197 -137
  56. sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
  57. sglang/srt/multimodal/processors/gemma3.py +4 -2
  58. sglang/srt/multimodal/processors/gemma3n.py +1 -1
  59. sglang/srt/multimodal/processors/internvl.py +1 -1
  60. sglang/srt/multimodal/processors/janus_pro.py +1 -1
  61. sglang/srt/multimodal/processors/kimi_vl.py +1 -1
  62. sglang/srt/multimodal/processors/minicpm.py +4 -3
  63. sglang/srt/multimodal/processors/mllama4.py +1 -1
  64. sglang/srt/multimodal/processors/phi4mm.py +1 -1
  65. sglang/srt/multimodal/processors/pixtral.py +1 -1
  66. sglang/srt/multimodal/processors/qwen_vl.py +203 -80
  67. sglang/srt/multimodal/processors/vila.py +1 -1
  68. sglang/srt/server_args.py +11 -4
  69. sglang/srt/utils.py +154 -31
  70. sglang/version.py +1 -1
  71. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +4 -3
  72. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +75 -70
  73. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
  75. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
@@ -65,6 +65,8 @@ class GenerateReqInput:
  ] = None
  # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
  audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
+ # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
+ video_data: Optional[Union[List[List[str]], List[str], str]] = None
  # The sampling_params. See descriptions below.
  sampling_params: Optional[Union[List[Dict], Dict]] = None
  # The request id.
@@ -110,7 +112,11 @@ class GenerateReqInput:
  data_parallel_rank: Optional[int] = None

  def contains_mm_input(self) -> bool:
- return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+ return (
+ has_valid_data(self.image_data)
+ or has_valid_data(self.video_data)
+ or has_valid_data(self.audio_data)
+ )

  def normalize_batch_and_arguments(self):
  """
@@ -232,6 +238,7 @@ class GenerateReqInput:
  self._normalize_rid(num)
  self._normalize_lora_paths(num)
  self._normalize_image_data(num)
+ self._normalize_video_data(num)
  self._normalize_audio_data(num)
  self._normalize_sampling_params(num)
  self._normalize_logprob_params(num)
@@ -300,6 +307,15 @@ class GenerateReqInput:
  self.image_data = wrapped_images * self.parallel_sample_num
  self.modalities = ["image"] * num

+ def _normalize_video_data(self, num):
+ """Normalize video data for batch processing."""
+ if self.video_data is None:
+ self.video_data = [None] * num
+ elif not isinstance(self.video_data, list):
+ self.video_data = [self.video_data] * num
+ elif isinstance(self.video_data, list):
+ self.video_data = self.video_data * self.parallel_sample_num
+
  def _normalize_audio_data(self, num):
  """Normalize audio data for batch processing."""
  if self.audio_data is None:
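For reference, the new _normalize_video_data mirrors the existing image/audio normalization. A minimal standalone sketch of the expansion it performs (pure illustration outside of GenerateReqInput, with hypothetical inputs):

def normalize_video_data(video_data, num, parallel_sample_num=1):
    # Mirrors GenerateReqInput._normalize_video_data from the hunk above.
    if video_data is None:
        return [None] * num                    # no videos for any request
    if not isinstance(video_data, list):
        return [video_data] * num              # one video shared by every request
    return video_data * parallel_sample_num    # per-request list, repeated per parallel sample

print(normalize_video_data(None, 2))                  # [None, None]
print(normalize_video_data("clip.mp4", 2))            # ['clip.mp4', 'clip.mp4']
print(normalize_video_data(["a.mp4", "b.mp4"], 2))    # ['a.mp4', 'b.mp4']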
@@ -408,6 +424,7 @@ class GenerateReqInput:
  self.input_embeds[i] if self.input_embeds is not None else None
  ),
  image_data=self.image_data[i],
+ video_data=self.video_data[i],
  audio_data=self.audio_data[i],
  sampling_params=self.sampling_params[i],
  rid=self.rid[i],
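With video_data threaded through __getitem__ above, each request in a batch receives its own slice, so clients can pass videos the same way they pass images. A hedged usage sketch (assumes the HTTP /generate endpoint forwards these fields into GenerateReqInput unchanged; host, port, and file path are placeholders):

import requests

payload = {
    "text": "Describe what happens in this clip.",
    # Like image_data, an entry can be a file name, a URL, or a base64-encoded string.
    "video_data": "/tmp/clip.mp4",
    "sampling_params": {"max_new_tokens": 64},
}
resp = requests.post("http://localhost:30000/generate", json=payload)
print(resp.json())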
@@ -507,6 +524,8 @@ class EmbeddingReqInput:
  image_data: Optional[
  Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
  ] = None
+ # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
+ video_data: Optional[Union[List[str], str]] = None
  # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
  audio_data: Optional[Union[List[str], str]] = None
  # The token ids for text; one can either specify text or input_ids.
@@ -578,7 +597,11 @@ class EmbeddingReqInput:
  return self.rid

  def contains_mm_input(self) -> bool:
- return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+ return (
+ has_valid_data(self.image_data)
+ or has_valid_data(self.video_data)
+ or has_valid_data(self.audio_data)
+ )

  def __getitem__(self, i):
  if self.is_cross_encoder_request:
@@ -905,6 +928,7 @@ class ProfileReqInput:
  # If set, it profile as many as this number of steps.
  # If it is set, profiling is automatically stopped after this step, and
  # the caller doesn't need to run stop_profile.
+ start_step: Optional[int] = None
  num_steps: Optional[int] = None
  activities: Optional[List[str]] = None
  profile_by_stage: bool = False
@@ -932,6 +956,7 @@ class ExpertDistributionReqOutput:
  class ProfileReq:
  type: ProfileReqType
  output_dir: Optional[str] = None
+ start_step: Optional[int] = None
  num_steps: Optional[int] = None
  activities: Optional[List[str]] = None
  profile_by_stage: bool = False
@@ -4,7 +4,7 @@ Multi-modality utils

  import hashlib
  from abc import abstractmethod
- from typing import Callable, List, Optional, Tuple
+ from typing import Callable, Dict, List, Optional, Tuple

  import numpy as np
  import torch
@@ -76,6 +76,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
  This function will replace the data-tokens in between with pad_values accordingly
  """
  pad_values = [item.pad_value for item in mm_inputs.mm_items]
+ print(f"{mm_inputs.mm_items=}")
  data_token_pairs = self.data_token_id_pairs
  mm_inputs.data_offsets = []
  if data_token_pairs is None:
@@ -159,10 +160,10 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
  return ret_input_ids


- embedding_cache = None
+ embedding_cache: Optional[MultiModalCache] = None


- def init_embedding_cache(max_size: int):
+ def init_embedding_cache(max_size: int = 0):
  global embedding_cache
  embedding_cache = MultiModalCache(max_size)

@@ -255,6 +256,7 @@ def _get_chunked_prefill_embedding(
  continue
  embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
  items_offset = items_offset_list[i]
+ assert items_offset is not None, items_offset
  embedding_items_hash = get_embedding_hash(embedding_items_per_req)
  # if all items has been prefixed, we do not need to calculate embedding
  if all([offset_end < prefix_length[i] for _, offset_end in items_offset]):
@@ -380,11 +382,9 @@ def embed_mm_inputs(
  extend_seq_lens: List[int],
  input_ids: torch.Tensor,
  input_embedding: nn.Embedding,
- image_data_embedding_func: Callable[
- [List[MultimodalDataItem]], torch.Tensor
- ] = None,
- audio_data_embedding_func: Callable[
- [List[MultimodalDataItem]], torch.Tensor
+ multimodal_model: nn.Module = None,
+ data_embedding_func_mapping: Dict[
+ Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
  ] = None,
  placeholder_tokens: dict[Modality, List[int]] = None,
  ) -> Optional[torch.Tensor]:
@@ -397,8 +397,6 @@ def embed_mm_inputs(
  extend_seq_lens: Sequence lengths for each request
  input_ids: Input token IDs tensor
  input_embedding: Embedding layer for text tokens
- image_data_embedding_func: Function to embed image data
- audio_data_embedding_func: Function to embed audio data
  placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)

  Returns:
@@ -415,88 +413,53 @@ def embed_mm_inputs(
  item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]

  embeddings, masks = [], []
-
  # 2. Get multimodal embedding separately
- # TODO: make this more generic
- # Try get image embedding if any
- if (
- any(True for item in item_flatten_list if item.is_image())
- and image_data_embedding_func
- ):
- items = [item for item in item_flatten_list if item.is_image()]
- placeholder_tensor = torch.tensor(
- [item.pad_value for item in items],
- device=input_ids.device,
+ # Try get mm embedding if any
+ for modality in Modality.all():
+ items = [
+ item for item in item_flatten_list if item.is_modality(modality=modality)
+ ]
+ embedder = (
+ None
+ if data_embedding_func_mapping is None
+ else data_embedding_func_mapping.get(modality, None)
  )
- # calculate per request items length offset
- items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
- items_offsets = []
- for i, mm_inputs in enumerate(mm_inputs_list):
- image_items = [item for item in mm_inputs.mm_items if item.is_image()]
- items_size[i + 1] = len(image_items)
- items_offsets.append(
- flatten_nested_list(
- [
- item.image_offsets
- for item in mm_inputs.mm_items
- if item.is_image()
- ]
- )
+ if embedder is None:
+ # "image", "video", etc
+ modality_id = modality.name.lower()
+ embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
+ if len(items) != 0 and embedder is not None:
+ placeholder_tensor = torch.tensor(
+ [item.pad_value for item in items],
+ device=input_ids.device,
  )
- items_size = torch.cumsum(items_size, dim=0).tolist()
-
- embedding, mask = get_embedding_and_mask(
- data_embedding_func=image_data_embedding_func,
- embedding_items=items,
- placeholder_tensor=placeholder_tensor,
- input_ids=input_ids,
- items_size=items_size,
- prefix_length=extend_prefix_lens,
- extend_length=extend_seq_lens,
- items_offset_list=items_offsets,
- )
- embeddings += [embedding]
- masks += [mask]
-
- # Try get audio embedding if any
- if (
- any(True for item in item_flatten_list if item.is_audio())
- and audio_data_embedding_func
- ):
- items = [item for item in item_flatten_list if item.is_audio()]
- placeholder_tensor = torch.tensor(
- [item.pad_value for item in items],
- device=input_ids.device,
- )
- items_offsets = []
- # calculate per request items length offset
- items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
- for i, mm_inputs in enumerate(mm_inputs_list):
- audio_items = [item for item in mm_inputs.mm_items if item.is_audio()]
- items_size[i + 1] = len(audio_items)
- items_offsets.append(
- flatten_nested_list(
- [
- item.audio_offsets
- for item in mm_inputs.mm_items
- if item.is_audio()
- ]
+ # calculate per request items length offset
+ items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
+ items_offsets = []
+ for i, mm_inputs in enumerate(mm_inputs_list):
+ mm_items = [
+ item
+ for item in mm_inputs.mm_items
+ if item.is_modality(modality=modality)
+ ]
+ items_size[i + 1] = len(mm_items)
+ items_offsets.append(
+ flatten_nested_list([item.offsets for item in mm_inputs.mm_items])
  )
+ items_size = torch.cumsum(items_size, dim=0).tolist()
+
+ embedding, mask = get_embedding_and_mask(
+ data_embedding_func=embedder,
+ embedding_items=items,
+ placeholder_tensor=placeholder_tensor,
+ input_ids=input_ids,
+ items_size=items_size,
+ prefix_length=extend_prefix_lens,
+ extend_length=extend_seq_lens,
+ items_offset_list=items_offsets,
  )
- items_size = torch.cumsum(items_size, dim=0)
-
- embedding, mask = get_embedding_and_mask(
- data_embedding_func=audio_data_embedding_func,
- embedding_items=items,
- placeholder_tensor=placeholder_tensor,
- input_ids=input_ids,
- items_size=items_size,
- prefix_length=extend_prefix_lens,
- extend_length=extend_seq_lens,
- items_offset_list=items_offsets,
- )
- embeddings += [embedding]
- masks += [mask]
+ embeddings += [embedding]
+ masks += [mask]

  # 3. Get input embeddings
  vocab_size = input_embedding.num_embeddings
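The loop above resolves one embedder per modality: an explicit entry in data_embedding_func_mapping wins, otherwise the code falls back to a get_<modality>_feature method looked up on multimodal_model, and modalities with no items or no embedder are skipped. A standalone sketch of just that lookup (DummyModel is hypothetical):

from enum import Enum, auto

class Modality(Enum):  # stands in for the IMAGE/VIDEO/AUDIO members returned by Modality.all()
    IMAGE = auto()
    VIDEO = auto()
    AUDIO = auto()

class DummyModel:
    def get_image_feature(self, items):
        return "image embeddings"      # would return a torch.Tensor in practice
    def get_video_feature(self, items):
        return "video embeddings"
    # no get_audio_feature -> audio items are simply skipped

model, mapping = DummyModel(), None
for modality in (Modality.IMAGE, Modality.VIDEO, Modality.AUDIO):
    embedder = None if mapping is None else mapping.get(modality, None)
    if embedder is None:
        embedder = getattr(model, f"get_{modality.name.lower()}_feature", None)
    print(modality.name, "->", embedder)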
@@ -523,11 +486,9 @@ def general_mm_embed_routine(
  input_ids: torch.Tensor,
  forward_batch: ForwardBatch,
  language_model: nn.Module,
- image_data_embedding_func: Optional[
- Callable[[List[MultimodalDataItem]], torch.Tensor]
- ] = None,
- audio_data_embedding_func: Optional[
- Callable[[List[MultimodalDataItem]], torch.Tensor]
+ multimodal_model: Optional[nn.Module] = None,
+ data_embedding_funcs: Dict[
+ Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
  ] = None,
  placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
  **kwargs,
@@ -572,8 +533,8 @@ def general_mm_embed_routine(
  extend_seq_lens=extend_seq_lens,
  input_ids=input_ids,
  input_embedding=embed_tokens,
- image_data_embedding_func=image_data_embedding_func,
- audio_data_embedding_func=audio_data_embedding_func,
+ multimodal_model=multimodal_model,
+ data_embedding_func_mapping=data_embedding_funcs,
  placeholder_tokens=placeholder_tokens,
  )
  # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
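Callers therefore stop passing per-modality embedding functions and hand over the whole multimodal module instead. A hedged sketch of the new call pattern in a model's forward (class and attribute names are illustrative, not copied from a specific sglang model):

import torch.nn as nn
from sglang.srt.managers.mm_utils import general_mm_embed_routine

class MyVLModelForCausalLM(nn.Module):
    # Discovered via getattr(multimodal_model, "get_image_feature") and
    # getattr(multimodal_model, "get_video_feature") inside embed_mm_inputs.
    def get_image_feature(self, items):
        ...
    def get_video_feature(self, items):
        ...

    def forward(self, input_ids, positions, forward_batch):
        # self.language_model would be built in __init__ in a real model.
        return general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
            multimodal_model=self,  # replaces image_data_embedding_func / audio_data_embedding_func
            positions=positions,
        )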
@@ -185,6 +185,10 @@ class Modality(Enum):
  f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
  )

+ @staticmethod
+ def all():
+ return [Modality.IMAGE, Modality.VIDEO, Modality.AUDIO]
+

  @dataclasses.dataclass
  class MultimodalDataItem:
@@ -200,7 +204,7 @@ class MultimodalDataItem:
  hash: int = None
  pad_value: int = None
  image_sizes: Tuple[int, int] = None
- image_offsets: Optional[list] = None
+ offsets: Optional[list] = None

  # the real data, pixel_values or audio_features
  # data: Union[List[torch.Tensor], List[np.ndarray]]
@@ -253,12 +257,17 @@ class MultimodalDataItem:
  self.hash = hash_feature(self.audio_features)
  elif self.input_features is not None:
  self.hash = hash_feature(self.input_features)
+ elif self.is_video():
+ self.hash = hash_feature(self.pixel_values_videos)
  else:
  self.hash = hash_feature(self.pixel_values)

  assert self.hash is not None
  self.pad_value = self.hash % (1 << 30)

+ def is_modality(self, modality: Modality) -> bool:
+ return self.modality == modality
+
  def is_audio(self):
  return (self.modality == Modality.AUDIO) and (
  self.precomputed_features is not None
@@ -268,7 +277,7 @@ class MultimodalDataItem:

  def is_image(self):
  return (
- self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
+ self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
  ) and (
  self.precomputed_features is not None
  or not MultimodalDataItem.is_empty_list(self.pixel_values)
@@ -277,7 +286,7 @@ class MultimodalDataItem:
  def is_video(self):
  return (self.modality == Modality.VIDEO) and (
  self.precomputed_features is not None
- or not MultimodalDataItem.is_empty_list(self.pixel_values)
+ or not MultimodalDataItem.is_empty_list(self.pixel_values_videos)
  )

  def is_valid(self) -> bool:
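Video frames now live in a dedicated pixel_values_videos field, and both is_video() and the hash computation key off it. A rough sketch of a video item (only modality, offsets, pixel_values_videos, and the predicates appear in this diff; the construction below is assumed, not taken from the source):

import torch

item = MultimodalDataItem(modality=Modality.VIDEO)       # hypothetical construction
item.pixel_values_videos = torch.randn(8, 3, 336, 336)   # e.g. 8 sampled frames
item.offsets = [(5, 20)]                                  # (start, end) placeholder span; format assumed

assert item.is_video() and item.is_modality(Modality.VIDEO)
assert not item.is_image()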
@@ -351,6 +360,7 @@ class MultimodalInputs:
  "im_token_id",
  "im_start_id",
  "im_end_id",
+ "video_token_id",
  "slice_start_id",
  "slice_end_id",
  "audio_start_id",
@@ -364,11 +374,12 @@ class MultimodalInputs:
  return ret

  def contains_image_inputs(self) -> bool:
- """ """
  return any(item.is_image() for item in self.mm_items)

+ def contains_video_inputs(self) -> bool:
+ return any(item.is_video() for item in self.mm_items)
+
  def contains_audio_inputs(self) -> bool:
- """ """
  return any(item.is_audio() for item in self.mm_items)

  def contains_mm_input(self) -> bool:
@@ -485,6 +485,8 @@ class Scheduler(
  enable=server_args.enable_memory_saver
  )
  self.init_profier()
+
+ # Init metrics stats
  self.init_metrics()
  self.init_kv_events(server_args.kv_events_config)

@@ -628,6 +630,7 @@ class Scheduler(
  self.torch_profiler_output_dir: Optional[str] = None
  self.profiler_activities: Optional[List[str]] = None
  self.profile_id: Optional[str] = None
+ self.profiler_start_forward_ct: Optional[int] = None
  self.profiler_target_forward_ct: Optional[int] = None
  self.profiler_target_prefill_ct: Optional[int] = None
  self.profiler_target_decode_ct: Optional[int] = None
@@ -2389,9 +2392,10 @@ class Scheduler(

  def profile(self, recv_req: ProfileReq):
  if recv_req.type == ProfileReqType.START_PROFILE:
- if recv_req.profile_by_stage:
+ if recv_req.profile_by_stage or recv_req.start_step:
  return self.init_profile(
  recv_req.output_dir,
+ recv_req.start_step,
  recv_req.num_steps,
  recv_req.activities,
  recv_req.with_stack,
@@ -2402,6 +2406,7 @@ class Scheduler(
  else:
  self.init_profile(
  recv_req.output_dir,
+ recv_req.start_step,
  recv_req.num_steps,
  recv_req.activities,
  recv_req.with_stack,
@@ -2416,6 +2421,7 @@ class Scheduler(
  def init_profile(
  self,
  output_dir: Optional[str],
+ start_step: Optional[int],
  num_steps: Optional[int],
  activities: Optional[List[str]],
  with_stack: Optional[bool],
@@ -2442,6 +2448,9 @@ class Scheduler(
  self.profiler_activities = activities
  self.profile_id = profile_id

+ if start_step:
+ self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)
+

  if num_steps:
  self.profile_steps = num_steps
@@ -2449,6 +2458,10 @@ class Scheduler(
  self.profiler_target_decode_ct = num_steps
  self.profiler_prefill_ct = 0
  self.profiler_decode_ct = 0
+ elif start_step:
+ self.profiler_target_forward_ct = (
+ self.profiler_start_forward_ct + num_steps
+ )
  else:
  self.profiler_target_forward_ct = self.forward_ct + num_steps
  # The caller will be notified when reaching profiler_target_forward_ct
@@ -2521,6 +2534,7 @@ class Scheduler(

  if "CUDA_PROFILER" in activities:
  torch.cuda.cudart().cudaProfilerStart()
+ self.profile_in_progress = True

  return ProfileReqOutput(success=True, message="Succeeded")

@@ -2584,6 +2598,7 @@ class Scheduler(
  )
  self.torch_profiler = None
  self.profile_in_progress = False
+ self.profiler_start_forward_ct = None

  return ProfileReqOutput(success=True, message="Succeeded.")

@@ -2617,6 +2632,11 @@ class Scheduler(
  and self.profiler_target_forward_ct <= self.forward_ct
  ):
  self.stop_profile()
+ if (
+ self.profiler_start_forward_ct
+ and self.profiler_start_forward_ct == self.forward_ct
+ ):
+ self.start_profile()

  def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
  if recv_req == ExpertDistributionReq.START_RECORD:
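Taken together, start_step turns profiling into a delayed window: init_profile records profiler_start_forward_ct = max(start_step, forward_ct + 1), the per-step hook above calls start_profile() once forward_ct reaches that value, and stop_profile() fires when profiler_target_forward_ct = profiler_start_forward_ct + num_steps is reached. A small arithmetic sketch with hypothetical numbers:

# Suppose 40 forward passes have already run when START_PROFILE arrives
# with start_step=100 and num_steps=5.
forward_ct, start_step, num_steps = 40, 100, 5

profiler_start_forward_ct = max(start_step, forward_ct + 1)           # 100
profiler_target_forward_ct = profiler_start_forward_ct + num_steps    # 105

# Profiling therefore covers forward steps 100-104:
# start_profile() fires when forward_ct == 100,
# stop_profile() fires once forward_ct reaches 105.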
@@ -285,6 +285,20 @@ class TokenizerManager:
  self.bootstrap_server = kv_bootstrap_server_class(
  self.server_args.disaggregation_bootstrap_port
  )
+ is_create_store = (
+ self.server_args.node_rank == 0
+ and self.server_args.disaggregation_transfer_backend == "ascend"
+ )
+ if is_create_store:
+ try:
+ from mf_adapter import create_config_store
+
+ ascend_url = os.getenv("ASCEND_MF_STORE_URL")
+ create_config_store(ascend_url)
+ except Exception as e:
+ error_message = f"Failed create mf store, invalid ascend_url."
+ error_message += f" With exception {e}"
+ raise error_message

  # For load balancing
  self.current_load = 0
@@ -863,6 +877,7 @@ class TokenizerManager:
  async def start_profile(
  self,
  output_dir: Optional[str] = None,
+ start_step: Optional[int] = None,
  num_steps: Optional[int] = None,
  activities: Optional[List[str]] = None,
  with_stack: Optional[bool] = None,
@@ -875,6 +890,7 @@ class TokenizerManager:
  req = ProfileReq(
  type=ProfileReqType.START_PROFILE,
  output_dir=output_dir,
+ start_step=start_step,
  num_steps=num_steps,
  activities=activities,
  with_stack=with_stack,
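Since ProfileReqInput, TokenizerManager.start_profile, and ProfileReq all carry the new field, a client can request a delayed profile in one call. A hedged sketch against the HTTP frontend (assuming the /start_profile endpoint forwards ProfileReqInput fields unchanged; all values are placeholders):

import requests

resp = requests.post(
    "http://localhost:30000/start_profile",
    json={
        "output_dir": "/tmp/sglang_traces",
        "start_step": 100,  # begin tracing at forward step 100 (new in this release)
        "num_steps": 5,     # trace five steps, then stop automatically
    },
)
print(resp.status_code)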
@@ -34,16 +34,18 @@ import torch
  import torch.distributed as dist
  import triton
  import triton.language as tl
- from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla

  from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
+ from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2

  logger = logging.getLogger(__name__)

  GB = 1024 * 1024 * 1024
  _is_cuda = is_cuda()
+ _is_npu = is_npu()
+ if not _is_npu:
+ from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla


  class ReqToTokenPool:
@@ -602,32 +604,49 @@ class AscendTokenToKVPool(MHATokenToKVPool):
  with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
  # [size, head_num, head_dim] for each layer
  # The padded slot 0 is used for writing dummy outputs from padded tokens.
- self.k_buffer = [
- torch.zeros(
- (
- self.size // self.page_size + 1,
- self.page_size,
- self.head_num,
- self.head_dim,
- ),
- dtype=self.store_dtype,
- device=self.device,
- )
- for _ in range(self.layer_num)
- ]
- self.v_buffer = [
- torch.zeros(
- (
- self.size // self.page_size + 1,
- self.page_size,
- self.head_num,
- self.head_dim,
- ),
- dtype=self.store_dtype,
- device=self.device,
- )
- for _ in range(self.layer_num)
- ]
+ # Continuous memory improves the efficiency of Ascend`s transmission backend,
+ # while other backends remain unchanged.
+ self.kv_buffer = torch.zeros(
+ (
+ 2,
+ self.layer_num,
+ self.size // self.page_size + 1,
+ self.page_size,
+ self.head_num,
+ self.head_dim,
+ ),
+ dtype=self.store_dtype,
+ device=self.device,
+ )
+ self.k_buffer = self.kv_buffer[0]
+ self.v_buffer = self.kv_buffer[1]
+
+ # for disagg
+ def get_contiguous_buf_infos(self):
+ # layer_num x [seq_len, head_num, head_dim]
+ # layer_num x [page_num, page_size, head_num, head_dim]
+ kv_data_ptrs = [
+ self.get_key_buffer(i).data_ptr()
+ for i in range(self.start_layer, self.start_layer + self.layer_num)
+ ] + [
+ self.get_value_buffer(i).data_ptr()
+ for i in range(self.start_layer, self.start_layer + self.layer_num)
+ ]
+ kv_data_lens = [
+ self.get_key_buffer(i).nbytes
+ for i in range(self.start_layer, self.start_layer + self.layer_num)
+ ] + [
+ self.get_value_buffer(i).nbytes
+ for i in range(self.start_layer, self.start_layer + self.layer_num)
+ ]
+ kv_item_lens = [
+ self.get_key_buffer(i)[0].nbytes
+ for i in range(self.start_layer, self.start_layer + self.layer_num)
+ ] + [
+ self.get_value_buffer(i)[0].nbytes
+ for i in range(self.start_layer, self.start_layer + self.layer_num)
+ ]
+ return kv_data_ptrs, kv_data_lens, kv_item_lens

  def set_kv_buffer(
  self,
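The per-layer K and V allocations are replaced by one contiguous [2, layer_num, ...] tensor, so the Ascend transfer backend can register a single memory region while k_buffer and v_buffer remain zero-copy views. A shape-only sketch of that layout in torch (toy sizes, not the real configuration):

import torch

# Toy stand-ins for (pages + 1, page_size, head_num, head_dim).
layer_num, pages, page_size, head_num, head_dim = 4, 16, 128, 8, 64

kv_buffer = torch.zeros((2, layer_num, pages + 1, page_size, head_num, head_dim))
k_buffer, v_buffer = kv_buffer[0], kv_buffer[1]  # views into the same storage

# Every per-layer buffer lives inside one contiguous region, which is what
# get_contiguous_buf_infos exposes as (data_ptrs, data_lens, item_lens).
assert k_buffer[2].data_ptr() - kv_buffer.data_ptr() == 2 * kv_buffer[0, 0].nbytes
assert v_buffer.data_ptr() - kv_buffer.data_ptr() == kv_buffer[0].nbytes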
@@ -967,18 +986,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):

  with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
  # The padded slot 0 is used for writing dummy outputs from padded tokens.
- self.kv_buffer = [
- torch.zeros(
- (
- self.size // self.page_size + 1,
- self.page_size,
- self.kv_lora_rank + self.qk_rope_head_dim,
- ),
- dtype=self.store_dtype,
- device=self.device,
- )
- for _ in range(layer_num)
- ]
+ self.kv_buffer = torch.zeros(
+ (
+ layer_num,
+ self.size // self.page_size + 1,
+ self.page_size,
+ self.kv_lora_rank + self.qk_rope_head_dim,
+ ),
+ dtype=self.store_dtype,
+ device=self.device,
+ )

  self.layer_transfer_counter = None

@@ -988,6 +1005,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
  )
  self.mem_usage = kv_size / GB

+ # for disagg
+ def get_contiguous_buf_infos(self):
+ # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
+ kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
+ kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
+ kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
+ return kv_data_ptrs, kv_data_lens, kv_item_lens
+
  def set_kv_buffer(
  self,
  layer: RadixAttention,
@@ -453,8 +453,20 @@ class ForwardBatch:
  for mm_input in self.mm_inputs
  )

+ def contains_video_inputs(self) -> bool:
+ if self.mm_inputs is None:
+ return False
+ return any(
+ mm_input is not None and mm_input.contains_video_inputs()
+ for mm_input in self.mm_inputs
+ )
+
  def contains_mm_inputs(self) -> bool:
- return self.contains_audio_inputs() or self.contains_image_inputs()
+ return (
+ self.contains_audio_inputs()
+ or self.contains_video_inputs()
+ or self.contains_image_inputs()
+ )

  def _compute_mrope_positions(
  self, model_runner: ModelRunner, batch: ModelWorkerBatch