sglang 0.4.9.post1__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/configs/model_config.py +24 -1
- sglang/srt/conversation.py +21 -2
- sglang/srt/disaggregation/ascend/__init__.py +6 -0
- sglang/srt/disaggregation/ascend/conn.py +44 -0
- sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
- sglang/srt/disaggregation/mooncake/conn.py +15 -14
- sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
- sglang/srt/disaggregation/utils.py +25 -3
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/entrypoints/http_server.py +1 -0
- sglang/srt/entrypoints/openai/protocol.py +11 -0
- sglang/srt/entrypoints/openai/serving_chat.py +7 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/kimik2_detector.py +220 -0
- sglang/srt/hf_transformers_utils.py +18 -0
- sglang/srt/jinja_template_utils.py +8 -0
- sglang/srt/layers/communicator.py +17 -4
- sglang/srt/layers/linear.py +12 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -2
- sglang/srt/layers/moe/topk.py +8 -2
- sglang/srt/layers/parameter.py +19 -3
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/moe_wna16.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +738 -14
- sglang/srt/managers/io_struct.py +27 -2
- sglang/srt/managers/mm_utils.py +55 -94
- sglang/srt/managers/schedule_batch.py +16 -5
- sglang/srt/managers/scheduler.py +21 -1
- sglang/srt/managers/tokenizer_manager.py +16 -0
- sglang/srt/mem_cache/memory_pool.py +65 -40
- sglang/srt/model_executor/forward_batch_info.py +13 -1
- sglang/srt/model_loader/loader.py +23 -12
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +62 -17
- sglang/srt/models/deepseek_vl2.py +1 -1
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +6 -3
- sglang/srt/models/internvl.py +8 -2
- sglang/srt/models/kimi_vl.py +8 -2
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llava.py +3 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpmo.py +1 -2
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral_quant.py +4 -0
- sglang/srt/models/mllama4.py +13 -4
- sglang/srt/models/phi4mm.py +8 -2
- sglang/srt/models/phimoe.py +553 -0
- sglang/srt/models/qwen2.py +2 -0
- sglang/srt/models/qwen2_5_vl.py +10 -7
- sglang/srt/models/qwen2_vl.py +12 -1
- sglang/srt/models/vila.py +8 -2
- sglang/srt/multimodal/processors/base_processor.py +197 -137
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
- sglang/srt/multimodal/processors/gemma3.py +4 -2
- sglang/srt/multimodal/processors/gemma3n.py +1 -1
- sglang/srt/multimodal/processors/internvl.py +1 -1
- sglang/srt/multimodal/processors/janus_pro.py +1 -1
- sglang/srt/multimodal/processors/kimi_vl.py +1 -1
- sglang/srt/multimodal/processors/minicpm.py +4 -3
- sglang/srt/multimodal/processors/mllama4.py +1 -1
- sglang/srt/multimodal/processors/phi4mm.py +1 -1
- sglang/srt/multimodal/processors/pixtral.py +1 -1
- sglang/srt/multimodal/processors/qwen_vl.py +203 -80
- sglang/srt/multimodal/processors/vila.py +1 -1
- sglang/srt/server_args.py +11 -4
- sglang/srt/utils.py +154 -31
- sglang/version.py +1 -1
- {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +4 -3
- {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +75 -70
- {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
CHANGED
@@ -65,6 +65,8 @@ class GenerateReqInput:
     ] = None
     # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
     audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
+    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
+    video_data: Optional[Union[List[List[str]], List[str], str]] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
@@ -110,7 +112,11 @@ class GenerateReqInput:
     data_parallel_rank: Optional[int] = None

     def contains_mm_input(self) -> bool:
-        return
+        return (
+            has_valid_data(self.image_data)
+            or has_valid_data(self.video_data)
+            or has_valid_data(self.audio_data)
+        )

     def normalize_batch_and_arguments(self):
         """
@@ -232,6 +238,7 @@ class GenerateReqInput:
         self._normalize_rid(num)
         self._normalize_lora_paths(num)
         self._normalize_image_data(num)
+        self._normalize_video_data(num)
         self._normalize_audio_data(num)
         self._normalize_sampling_params(num)
         self._normalize_logprob_params(num)
@@ -300,6 +307,15 @@ class GenerateReqInput:
             self.image_data = wrapped_images * self.parallel_sample_num
             self.modalities = ["image"] * num

+    def _normalize_video_data(self, num):
+        """Normalize video data for batch processing."""
+        if self.video_data is None:
+            self.video_data = [None] * num
+        elif not isinstance(self.video_data, list):
+            self.video_data = [self.video_data] * num
+        elif isinstance(self.video_data, list):
+            self.video_data = self.video_data * self.parallel_sample_num
+
     def _normalize_audio_data(self, num):
         """Normalize audio data for batch processing."""
         if self.audio_data is None:
@@ -408,6 +424,7 @@ class GenerateReqInput:
                     self.input_embeds[i] if self.input_embeds is not None else None
                 ),
                 image_data=self.image_data[i],
+                video_data=self.video_data[i],
                 audio_data=self.audio_data[i],
                 sampling_params=self.sampling_params[i],
                 rid=self.rid[i],
@@ -507,6 +524,8 @@ class EmbeddingReqInput:
     image_data: Optional[
         Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
     ] = None
+    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
+    video_data: Optional[Union[List[str], str]] = None
     # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
     audio_data: Optional[Union[List[str], str]] = None
     # The token ids for text; one can either specify text or input_ids.
@@ -578,7 +597,11 @@ class EmbeddingReqInput:
         return self.rid

     def contains_mm_input(self) -> bool:
-        return
+        return (
+            has_valid_data(self.image_data)
+            or has_valid_data(self.video_data)
+            or has_valid_data(self.audio_data)
+        )

     def __getitem__(self, i):
         if self.is_cross_encoder_request:
@@ -905,6 +928,7 @@ class ProfileReqInput:
     # If set, it profile as many as this number of steps.
     # If it is set, profiling is automatically stopped after this step, and
     # the caller doesn't need to run stop_profile.
+    start_step: Optional[int] = None
     num_steps: Optional[int] = None
     activities: Optional[List[str]] = None
     profile_by_stage: bool = False
@@ -932,6 +956,7 @@ class ExpertDistributionReqOutput:
 class ProfileReq:
     type: ProfileReqType
     output_dir: Optional[str] = None
+    start_step: Optional[int] = None
     num_steps: Optional[int] = None
     activities: Optional[List[str]] = None
     profile_by_stage: bool = False
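The new `video_data` field follows the same request path as image and audio data: it is normalized per request and counted by `contains_mm_input`. A minimal sketch of a batched request using it (file names and parameter values are illustrative, not taken from the package):

```python
from sglang.srt.managers.io_struct import GenerateReqInput

# Hypothetical two-request batch, one video per request.
req = GenerateReqInput(
    text=["Describe the clip.", "Summarize the clip."],
    video_data=["clip_a.mp4", "clip_b.mp4"],  # file names, URLs, or base64 strings
    sampling_params={"max_new_tokens": 64},
)

assert req.contains_mm_input()       # now also True for video-only requests
req.normalize_batch_and_arguments()  # runs _normalize_video_data() alongside image/audio
```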
sglang/srt/managers/mm_utils.py
CHANGED
@@ -4,7 +4,7 @@ Multi-modality utils

 import hashlib
 from abc import abstractmethod
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple

 import numpy as np
 import torch
@@ -76,6 +76,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         This function will replace the data-tokens in between with pad_values accordingly
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
+        print(f"{mm_inputs.mm_items=}")
         data_token_pairs = self.data_token_id_pairs
         mm_inputs.data_offsets = []
         if data_token_pairs is None:
@@ -159,10 +160,10 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
         return ret_input_ids


-embedding_cache = None
+embedding_cache: Optional[MultiModalCache] = None


-def init_embedding_cache(max_size: int):
+def init_embedding_cache(max_size: int = 0):
     global embedding_cache
     embedding_cache = MultiModalCache(max_size)

@@ -255,6 +256,7 @@ def _get_chunked_prefill_embedding(
             continue
         embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
         items_offset = items_offset_list[i]
+        assert items_offset is not None, items_offset
         embedding_items_hash = get_embedding_hash(embedding_items_per_req)
         # if all items has been prefixed, we do not need to calculate embedding
         if all([offset_end < prefix_length[i] for _, offset_end in items_offset]):
@@ -380,11 +382,9 @@ def embed_mm_inputs(
     extend_seq_lens: List[int],
     input_ids: torch.Tensor,
     input_embedding: nn.Embedding,
-    audio_data_embedding_func: Callable[
-        [List[MultimodalDataItem]], torch.Tensor
+    multimodal_model: nn.Module = None,
+    data_embedding_func_mapping: Dict[
+        Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
     placeholder_tokens: dict[Modality, List[int]] = None,
 ) -> Optional[torch.Tensor]:
@@ -397,8 +397,6 @@ def embed_mm_inputs(
         extend_seq_lens: Sequence lengths for each request
         input_ids: Input token IDs tensor
         input_embedding: Embedding layer for text tokens
-        image_data_embedding_func: Function to embed image data
-        audio_data_embedding_func: Function to embed audio data
         placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)

     Returns:
@@ -415,88 +413,53 @@ def embed_mm_inputs(
         item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]

     embeddings, masks = [], []
     # 2. Get multimodal embedding separately
-                device=input_ids.device,
+    # Try get mm embedding if any
+    for modality in Modality.all():
+        items = [
+            item for item in item_flatten_list if item.is_modality(modality=modality)
+        ]
+        embedder = (
+            None
+            if data_embedding_func_mapping is None
+            else data_embedding_func_mapping.get(modality, None)
         )
-                    [
-                        item.image_offsets
-                        for item in mm_inputs.mm_items
-                        if item.is_image()
-                    ]
-                )
+        if embedder is None:
+            # "image", "video", etc
+            modality_id = modality.name.lower()
+            embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
+        if len(items) != 0 and embedder is not None:
+            placeholder_tensor = torch.tensor(
+                [item.pad_value for item in items],
+                device=input_ids.device,
             )
-        embeddings += [embedding]
-        masks += [mask]
-
-    # Try get audio embedding if any
-    if (
-        any(True for item in item_flatten_list if item.is_audio())
-        and audio_data_embedding_func
-    ):
-        items = [item for item in item_flatten_list if item.is_audio()]
-        placeholder_tensor = torch.tensor(
-            [item.pad_value for item in items],
-            device=input_ids.device,
-        )
-        items_offsets = []
-        # calculate per request items length offset
-        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
-        for i, mm_inputs in enumerate(mm_inputs_list):
-            audio_items = [item for item in mm_inputs.mm_items if item.is_audio()]
-            items_size[i + 1] = len(audio_items)
-            items_offsets.append(
-                flatten_nested_list(
-                    [
-                        item.audio_offsets
-                        for item in mm_inputs.mm_items
-                        if item.is_audio()
-                    ]
+            # calculate per request items length offset
+            items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
+            items_offsets = []
+            for i, mm_inputs in enumerate(mm_inputs_list):
+                mm_items = [
+                    item
+                    for item in mm_inputs.mm_items
+                    if item.is_modality(modality=modality)
+                ]
+                items_size[i + 1] = len(mm_items)
+                items_offsets.append(
+                    flatten_nested_list([item.offsets for item in mm_inputs.mm_items])
                 )
+            items_size = torch.cumsum(items_size, dim=0).tolist()
+
+            embedding, mask = get_embedding_and_mask(
+                data_embedding_func=embedder,
+                embedding_items=items,
+                placeholder_tensor=placeholder_tensor,
+                input_ids=input_ids,
+                items_size=items_size,
+                prefix_length=extend_prefix_lens,
+                extend_length=extend_seq_lens,
+                items_offset_list=items_offsets,
            )
-
-        embedding, mask = get_embedding_and_mask(
-            data_embedding_func=audio_data_embedding_func,
-            embedding_items=items,
-            placeholder_tensor=placeholder_tensor,
-            input_ids=input_ids,
-            items_size=items_size,
-            prefix_length=extend_prefix_lens,
-            extend_length=extend_seq_lens,
-            items_offset_list=items_offsets,
-        )
-        embeddings += [embedding]
-        masks += [mask]
+            embeddings += [embedding]
+            masks += [mask]

     # 3. Get input embeddings
     vocab_size = input_embedding.num_embeddings
@@ -523,11 +486,9 @@ def general_mm_embed_routine(
     input_ids: torch.Tensor,
     forward_batch: ForwardBatch,
     language_model: nn.Module,
-    audio_data_embedding_func: Optional[
-        Callable[[List[MultimodalDataItem]], torch.Tensor]
+    multimodal_model: Optional[nn.Module] = None,
+    data_embedding_funcs: Dict[
+        Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
     placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
     **kwargs,
@@ -572,8 +533,8 @@ def general_mm_embed_routine(
         extend_seq_lens=extend_seq_lens,
         input_ids=input_ids,
         input_embedding=embed_tokens,
+        multimodal_model=multimodal_model,
+        data_embedding_func_mapping=data_embedding_funcs,
         placeholder_tokens=placeholder_tokens,
     )
     # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
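The refactor above replaces the per-modality `*_data_embedding_func` arguments with either a `multimodal_model` handle (embedders resolved by name via `get_<modality>_feature`) or an explicit `Modality`-to-callable mapping. A rough, schematic sketch of how a model's forward pass could call the new interface; the class name, `self.language_model` attribute, and the forwarded `positions` kwarg are placeholders, not code from the package:

```python
from torch import nn

from sglang.srt.managers.mm_utils import general_mm_embed_routine
from sglang.srt.managers.schedule_batch import Modality


class MyVLModel(nn.Module):  # hypothetical multimodal model
    def get_image_feature(self, items):
        # items: List[MultimodalDataItem] with Modality.IMAGE
        ...

    def get_video_feature(self, items):
        ...

    def forward(self, input_ids, positions, forward_batch):
        # Either let embed_mm_inputs look up get_<modality>_feature by name ...
        return general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
            multimodal_model=self,
            # ... or pass an explicit mapping instead:
            # data_embedding_funcs={Modality.IMAGE: self.get_image_feature},
            positions=positions,
        )
```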
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -185,6 +185,10 @@ class Modality(Enum):
             f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
         )

+    @staticmethod
+    def all():
+        return [Modality.IMAGE, Modality.VIDEO, Modality.AUDIO]
+

 @dataclasses.dataclass
 class MultimodalDataItem:
@@ -200,7 +204,7 @@ class MultimodalDataItem:
     hash: int = None
     pad_value: int = None
     image_sizes: Tuple[int, int] = None
+    offsets: Optional[list] = None

     # the real data, pixel_values or audio_features
     # data: Union[List[torch.Tensor], List[np.ndarray]]
@@ -253,12 +257,17 @@ class MultimodalDataItem:
                 self.hash = hash_feature(self.audio_features)
             elif self.input_features is not None:
                 self.hash = hash_feature(self.input_features)
+            elif self.is_video():
+                self.hash = hash_feature(self.pixel_values_videos)
             else:
                 self.hash = hash_feature(self.pixel_values)

         assert self.hash is not None
         self.pad_value = self.hash % (1 << 30)

+    def is_modality(self, modality: Modality) -> bool:
+        return self.modality == modality
+
     def is_audio(self):
         return (self.modality == Modality.AUDIO) and (
             self.precomputed_features is not None
@@ -268,7 +277,7 @@ class MultimodalDataItem:

     def is_image(self):
         return (
-            self.
+            self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
         ) and (
             self.precomputed_features is not None
             or not MultimodalDataItem.is_empty_list(self.pixel_values)
@@ -277,7 +286,7 @@ class MultimodalDataItem:

     def is_video(self):
         return (self.modality == Modality.VIDEO) and (
             self.precomputed_features is not None
-            or not MultimodalDataItem.is_empty_list(self.
+            or not MultimodalDataItem.is_empty_list(self.pixel_values_videos)
         )

     def is_valid(self) -> bool:
@@ -351,6 +360,7 @@ class MultimodalInputs:
             "im_token_id",
             "im_start_id",
             "im_end_id",
+            "video_token_id",
             "slice_start_id",
             "slice_end_id",
             "audio_start_id",
@@ -364,11 +374,12 @@ class MultimodalInputs:
         return ret

     def contains_image_inputs(self) -> bool:
-        """ """
         return any(item.is_image() for item in self.mm_items)

+    def contains_video_inputs(self) -> bool:
+        return any(item.is_video() for item in self.mm_items)
+
     def contains_audio_inputs(self) -> bool:
-        """ """
         return any(item.is_audio() for item in self.mm_items)

     def contains_mm_input(self) -> bool:
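The new `Modality.all()` and `MultimodalDataItem.is_modality()` helpers are what lets `embed_mm_inputs` treat image, video, and audio items uniformly. For illustration, the grouping pattern looks roughly like this (the `mm_items` list of `MultimodalDataItem` objects is assumed to exist already):

```python
from sglang.srt.managers.schedule_batch import Modality

# Group a request's multimodal items by modality using the new helpers.
items_by_modality = {
    modality: [item for item in mm_items if item.is_modality(modality)]
    for modality in Modality.all()  # [Modality.IMAGE, Modality.VIDEO, Modality.AUDIO]
}
```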
sglang/srt/managers/scheduler.py
CHANGED
@@ -485,6 +485,8 @@ class Scheduler(
             enable=server_args.enable_memory_saver
         )
         self.init_profier()
+
+        # Init metrics stats
         self.init_metrics()
         self.init_kv_events(server_args.kv_events_config)

@@ -628,6 +630,7 @@ class Scheduler(
         self.torch_profiler_output_dir: Optional[str] = None
         self.profiler_activities: Optional[List[str]] = None
         self.profile_id: Optional[str] = None
+        self.profiler_start_forward_ct: Optional[int] = None
         self.profiler_target_forward_ct: Optional[int] = None
         self.profiler_target_prefill_ct: Optional[int] = None
         self.profiler_target_decode_ct: Optional[int] = None
@@ -2389,9 +2392,10 @@ class Scheduler(

     def profile(self, recv_req: ProfileReq):
         if recv_req.type == ProfileReqType.START_PROFILE:
-            if recv_req.profile_by_stage:
+            if recv_req.profile_by_stage or recv_req.start_step:
                 return self.init_profile(
                     recv_req.output_dir,
+                    recv_req.start_step,
                     recv_req.num_steps,
                     recv_req.activities,
                     recv_req.with_stack,
@@ -2402,6 +2406,7 @@ class Scheduler(
             else:
                 self.init_profile(
                     recv_req.output_dir,
+                    recv_req.start_step,
                     recv_req.num_steps,
                     recv_req.activities,
                     recv_req.with_stack,
@@ -2416,6 +2421,7 @@ class Scheduler(
     def init_profile(
         self,
         output_dir: Optional[str],
+        start_step: Optional[int],
         num_steps: Optional[int],
         activities: Optional[List[str]],
         with_stack: Optional[bool],
@@ -2442,6 +2448,9 @@ class Scheduler(
         self.profiler_activities = activities
         self.profile_id = profile_id

+        if start_step:
+            self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)
+
         if num_steps:
             self.profile_steps = num_steps
             if self.profile_by_stage:
@@ -2449,6 +2458,10 @@ class Scheduler(
                 self.profiler_target_decode_ct = num_steps
                 self.profiler_prefill_ct = 0
                 self.profiler_decode_ct = 0
+            elif start_step:
+                self.profiler_target_forward_ct = (
+                    self.profiler_start_forward_ct + num_steps
+                )
             else:
                 self.profiler_target_forward_ct = self.forward_ct + num_steps
                 # The caller will be notified when reaching profiler_target_forward_ct
@@ -2521,6 +2534,7 @@ class Scheduler(

         if "CUDA_PROFILER" in activities:
             torch.cuda.cudart().cudaProfilerStart()
+            self.profile_in_progress = True

         return ProfileReqOutput(success=True, message="Succeeded")

@@ -2584,6 +2598,7 @@ class Scheduler(
         )
         self.torch_profiler = None
         self.profile_in_progress = False
+        self.profiler_start_forward_ct = None

         return ProfileReqOutput(success=True, message="Succeeded.")

@@ -2617,6 +2632,11 @@ class Scheduler(
             and self.profiler_target_forward_ct <= self.forward_ct
         ):
             self.stop_profile()
+        if (
+            self.profiler_start_forward_ct
+            and self.profiler_start_forward_ct == self.forward_ct
+        ):
+            self.start_profile()

     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
         if recv_req == ExpertDistributionReq.START_RECORD:
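Concretely, if a start request arrives at forward_ct = 40 with start_step = 100 and num_steps = 20, init_profile sets profiler_start_forward_ct = max(100, 41) = 100 and profiler_target_forward_ct = 100 + 20 = 120; the per-step check above then calls start_profile() when forward_ct reaches 100 and stop_profile() once forward_ct reaches 120.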
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -285,6 +285,20 @@ class TokenizerManager:
             self.bootstrap_server = kv_bootstrap_server_class(
                 self.server_args.disaggregation_bootstrap_port
             )
+            is_create_store = (
+                self.server_args.node_rank == 0
+                and self.server_args.disaggregation_transfer_backend == "ascend"
+            )
+            if is_create_store:
+                try:
+                    from mf_adapter import create_config_store
+
+                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
+                    create_config_store(ascend_url)
+                except Exception as e:
+                    error_message = f"Failed create mf store, invalid ascend_url."
+                    error_message += f" With exception {e}"
+                    raise error_message

         # For load balancing
         self.current_load = 0
@@ -863,6 +877,7 @@ class TokenizerManager:
     async def start_profile(
         self,
         output_dir: Optional[str] = None,
+        start_step: Optional[int] = None,
         num_steps: Optional[int] = None,
         activities: Optional[List[str]] = None,
         with_stack: Optional[bool] = None,
@@ -875,6 +890,7 @@ class TokenizerManager:
         req = ProfileReq(
             type=ProfileReqType.START_PROFILE,
             output_dir=output_dir,
+            start_step=start_step,
             num_steps=num_steps,
             activities=activities,
             with_stack=with_stack,
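With the plumbing above, a profiling request can specify a delayed start. A hedged sketch of triggering it through the server's profiling endpoint; host, port, and the chosen activity names are illustrative, while `start_step` is the field added to `ProfileReqInput` in this release:

```python
import requests

# Ask the server to begin torch profiling at forward step 100, capture 20 steps,
# and stop automatically (no explicit stop_profile call needed).
requests.post(
    "http://localhost:30000/start_profile",  # assumed local sglang server
    json={"start_step": 100, "num_steps": 20, "activities": ["CPU", "GPU"]},
)
```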
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -34,16 +34,18 @@ import torch
 import torch.distributed as dist
 import triton
 import triton.language as tl
-from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla

 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2

 logger = logging.getLogger(__name__)

 GB = 1024 * 1024 * 1024
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+if not _is_npu:
+    from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla


 class ReqToTokenPool:
@@ -602,32 +604,49 @@ class AscendTokenToKVPool(MHATokenToKVPool):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # [size, head_num, head_dim] for each layer
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            )
+            # Continuous memory improves the efficiency of Ascend`s transmission backend,
+            # while other backends remain unchanged.
+            self.kv_buffer = torch.zeros(
+                (
+                    2,
+                    self.layer_num,
+                    self.size // self.page_size + 1,
+                    self.page_size,
+                    self.head_num,
+                    self.head_dim,
+                ),
+                dtype=self.store_dtype,
+                device=self.device,
+            )
+            self.k_buffer = self.kv_buffer[0]
+            self.v_buffer = self.kv_buffer[1]
+
+    # for disagg
+    def get_contiguous_buf_infos(self):
+        # layer_num x [seq_len, head_num, head_dim]
+        # layer_num x [page_num, page_size, head_num, head_dim]
+        kv_data_ptrs = [
+            self.get_key_buffer(i).data_ptr()
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i).data_ptr()
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
+        kv_data_lens = [
+            self.get_key_buffer(i).nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i).nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
+        kv_item_lens = [
+            self.get_key_buffer(i)[0].nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i)[0].nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
+        return kv_data_ptrs, kv_data_lens, kv_item_lens

     def set_kv_buffer(
         self,
@@ -967,18 +986,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):

         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.kv_buffer =
-                for _ in range(layer_num)
-            ]
+            self.kv_buffer = torch.zeros(
+                (
+                    layer_num,
+                    self.size // self.page_size + 1,
+                    self.page_size,
+                    self.kv_lora_rank + self.qk_rope_head_dim,
+                ),
+                dtype=self.store_dtype,
+                device=self.device,
+            )

         self.layer_transfer_counter = None

@@ -988,6 +1005,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         )
         self.mem_usage = kv_size / GB

+    # for disagg
+    def get_contiguous_buf_infos(self):
+        # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
+        kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
+        kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
+        kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
+        return kv_data_ptrs, kv_data_lens, kv_item_lens
+
     def set_kv_buffer(
         self,
         layer: RadixAttention,
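Both Ascend pools now expose `get_contiguous_buf_infos()` so the disaggregation layer can register their KV memory for transfer. A rough sketch of how a consumer could interpret the returned triples; `kv_pool`, `transfer_engine`, and the `register_buffer` call are hypothetical stand-ins, not the package's API:

```python
# kv_pool: an AscendTokenToKVPool or AscendMLAPagedTokenToKVPool instance (assumed).
kv_data_ptrs, kv_data_lens, kv_item_lens = kv_pool.get_contiguous_buf_infos()

for ptr, total_nbytes, item_nbytes in zip(kv_data_ptrs, kv_data_lens, kv_item_lens):
    num_pages = total_nbytes // item_nbytes  # one item is a single KV page for these pools
    transfer_engine.register_buffer(ptr, total_nbytes)  # hypothetical registration call
```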
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -453,8 +453,20 @@ class ForwardBatch:
             for mm_input in self.mm_inputs
         )

+    def contains_video_inputs(self) -> bool:
+        if self.mm_inputs is None:
+            return False
+        return any(
+            mm_input is not None and mm_input.contains_video_inputs()
+            for mm_input in self.mm_inputs
+        )
+
     def contains_mm_inputs(self) -> bool:
-        return
+        return (
+            self.contains_audio_inputs()
+            or self.contains_video_inputs()
+            or self.contains_image_inputs()
+        )

     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
|