sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +49 -7
- sglang/lang/chat_template.py +24 -0
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/model_config.py +5 -0
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/conversation.py +29 -4
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/layers/attention/flashattention_backend.py +678 -83
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +49 -3
- sglang/srt/layers/quantization/__init__.py +5 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8.py +3 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/moe_wna16.py +503 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/quantization/w8a8_int8.py +2 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +63 -12
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +7 -26
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -128
- sglang/srt/managers/scheduler.py +4 -4
- sglang/srt/managers/tokenizer_manager.py +1 -1
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +8 -6
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +60 -57
- sglang/srt/model_loader/loader.py +8 -0
- sglang/srt/models/clip.py +12 -7
- sglang/srt/models/deepseek_janus_pro.py +10 -15
- sglang/srt/models/deepseek_v2.py +212 -121
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_mm.py +14 -80
- sglang/srt/models/llama.py +16 -5
- sglang/srt/models/llama4.py +420 -0
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/mllama4.py +154 -0
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +18 -6
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +99 -14
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +93 -24
- sglang/srt/utils.py +104 -51
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +13 -26
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -30,22 +30,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import
+from transformers import Qwen2VLConfig
 from transformers.activations import ACT2FN
 from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
-from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
-    Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
 )
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-    Qwen2_5_VLForConditionalGeneration,
-)
 
-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -57,7 +48,7 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
@@ -513,19 +504,24 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
-    def pad_input_ids(self, input_ids: List[int],
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int =
-        im_end_id: int =
+        im_start_id: int = mm_inputs.im_start_id
+        im_end_id: int = mm_inputs.im_end_id
 
         media_token_pairs = [(im_start_id, im_end_id)]
         pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
 
-
-
-
-
-
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
+            self.visual.dtype
+        )
+        image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert image_grid_thws.dim() == 2, image_grid_thws.dim()
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
         return image_embeds
 
     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
@@ -570,18 +566,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             f"(3, seq_len) positions, but got {positions.size()}"
         )
 
-
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-
-
-        )
-
-        hidden_states = self.model(
-            input_ids=None,
+            language_model=self.model,
+            image_data_embedding_func=self.get_image_feature,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
         )
 
         if not get_embedding:
@@ -594,9 +584,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
             ("gate_up_proj", "up_proj", 1),
             ("gate_up_proj", "gate_proj", 0),
         ]
sglang/srt/models/qwen2_vl.py
CHANGED
@@ -45,7 +45,7 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
@@ -472,18 +472,24 @@ class Qwen2VLForConditionalGeneration(nn.Module):
 
     # Use grid_t * grid_w * grid_h to pad tokens for each image
    # add replaced padding by unique image hash
-    def pad_input_ids(self, input_ids: List[int],
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int =
-        im_end_id: int =
+        im_start_id: int = mm_inputs.im_start_id
+        im_end_id: int = mm_inputs.im_end_id
 
         media_token_pairs = [(im_start_id, im_end_id)]
         pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
-        return pattern.pad_input_tokens(input_ids,
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
 
-    def get_image_feature(self,
-
-
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
+            self.visual.dtype
+        )
+        image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert image_grid_thws.dim() == 2, image_grid_thws.dim()
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
         return image_embeds
 
     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
@@ -527,27 +533,20 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             "multimodal section rotary embedding requires "
             f"(3, seq_len) positions, but got {positions.size()}"
         )
-
-        inputs_embeds = general_mm_embed_routine(
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-
-
-        )
-
-        hidden_states = self.model(
-            input_ids=None,
+            language_model=self.model,
+            image_data_embedding_func=self.get_image_feature,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
         )
 
-        if
+        if get_embedding:
+            return self.pooler(hidden_states, forward_batch)
+        else:
             return self.logits_processor(
                 input_ids, hidden_states, self.lm_head, forward_batch
             )
-        else:
-            return self.pooler(hidden_states, forward_batch)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
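
Both Qwen VL diffs above converge on the same contract: a model no longer builds `inputs_embeds` itself and then calls `self.model(...)`; it hands its language model and a `get_image_feature` callback to `general_mm_embed_routine`, which performs the embedding lookup, scatters the image embeddings, and runs the language model. The sketch below only illustrates that contract, using the import paths that appear in the hunks (it assumes sglang 0.4.5 is installed); the class and the method bodies are illustrative, not the actual sglang sources.

```python
from typing import List

import torch

# Import paths taken from the diffs above; assumes sglang 0.4.5 is installed.
from sglang.srt.managers.mm_utils import general_mm_embed_routine
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class VLModelSketch(torch.nn.Module):
    """Illustrative skeleton of the 0.4.5 multimodal-model contract."""

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        # Special-token IDs now come from MultimodalInputs rather than the model config.
        ...

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        # Batch all image items, run the vision tower once, and return the embeddings
        # that general_mm_embed_routine will scatter into the text embeddings.
        ...

    def forward(self, input_ids, positions, forward_batch: ForwardBatch):
        # The routine owns embedding lookup, image scatter, and the language-model call.
        return general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.model,
            image_data_embedding_func=self.get_image_feature,
            positions=positions,
        )
```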
sglang/srt/openai_api/adapter.py
CHANGED
@@ -897,6 +897,7 @@ def v1_chat_generate_request(
     request_ids: List[str] = None,
 ):
     input_ids = []
+    prompts = []
     sampling_params_list = []
     image_data_list = []
     audio_data_list = []
@@ -916,6 +917,7 @@ def v1_chat_generate_request(
         # - audio_data: None or a list of audio strings (URLs).
         # None skips any image processing in GenerateReqInput.
         strict_tag = None
+        prompt = ""
         if not isinstance(request.messages, str):
             # Apply chat template and its stop strings.
             tools = None
@@ -1005,11 +1007,13 @@ def v1_chat_generate_request(
             image_data = None
             audio_data = None
             modalities = []
+            prompt = request.messages
         input_ids.append(prompt_ids)
         return_logprobs.append(request.logprobs)
         logprob_start_lens.append(-1)
         top_logprobs_nums.append(request.top_logprobs or 0)
         lora_paths.append(request.lora_path)
+        prompts.append(prompt)
 
         sampling_params = {
             "temperature": request.temperature,
@@ -1063,10 +1067,14 @@
         audio_data_list.append(audio_data)
         modalities_list.append(modalities)
     if len(all_requests) == 1:
-        if
-
+        if tokenizer_manager.model_config.is_multimodal:
+            # processor will need text input
+            prompt_kwargs = {"text": prompts[0]}
         else:
-
+            if isinstance(input_ids[0], str):
+                prompt_kwargs = {"text": input_ids[0]}
+            else:
+                prompt_kwargs = {"input_ids": input_ids[0]}
         sampling_params_list = sampling_params_list[0]
         image_data_list = image_data_list[0]
         audio_data_list = audio_data_list[0]
@@ -1076,10 +1084,14 @@
         modalities_list = modalities_list[0]
         lora_paths = lora_paths[0]
     else:
-        if
-
+        if tokenizer_manager.model_config.is_multimodal:
+            # processor will need text input
+            prompt_kwargs = {"text": prompts}
         else:
-
+            if isinstance(input_ids[0], str):
+                prompt_kwargs = {"text": input_ids}
+            else:
+                prompt_kwargs = {"input_ids": input_ids}
 
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
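
The adapter change keeps the chat-templated prompt string (`prompts`) alongside the tokenized `input_ids`, and for multimodal models it forwards text rather than token IDs so the processor can tokenize the prompt together with the images. Below is a standalone sketch of that selection logic; the function name is hypothetical and the multimodal flag is passed in explicitly (in the diff it comes from `tokenizer_manager.model_config.is_multimodal`).

```python
from typing import Dict, List, Union


def build_prompt_kwargs(
    is_multimodal: bool,
    prompts: List[str],
    input_ids: List[Union[str, List[int]]],
) -> Dict[str, object]:
    # Mirror of the branch added in the hunks above, folded into one helper.
    single = len(input_ids) == 1
    if is_multimodal:
        # Multimodal processors need raw text so they can tokenize it with the images.
        return {"text": prompts[0] if single else prompts}
    batch = input_ids[0] if single else input_ids
    if isinstance(input_ids[0], str):
        # Text-only models can still accept a raw prompt string.
        return {"text": batch}
    # Otherwise pass the pre-tokenized IDs straight through.
    return {"input_ids": batch}
```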
sglang/srt/platforms/interface.py
ADDED
@@ -0,0 +1,371 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/v0.8.2/vllm/platforms/interface.py
+
+import enum
+import platform
+import random
+from platform import uname
+from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
+
+if TYPE_CHECKING:
+    from sglang.srt.server_args import ServerArgs
+    from sglang.srt.configs.model_config import ModelConfig
+
+import logging
+
+import numpy as np
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def in_wsl() -> bool:
+    # Reference: https://github.com/microsoft/WSL/issues/4071
+    return "microsoft" in " ".join(uname()).lower()
+
+
+class PlatformEnum(enum.Enum):
+    CUDA = enum.auto()
+    ROCM = enum.auto()
+    HPU = enum.auto()
+    XPU = enum.auto()
+    CPU = enum.auto()
+    OOT = enum.auto()
+    UNSPECIFIED = enum.auto()
+
+
+class CpuArchEnum(enum.Enum):
+    X86 = enum.auto()
+    ARM = enum.auto()
+    POWERPC = enum.auto()
+    OTHER = enum.auto()
+    UNKNOWN = enum.auto()
+
+
+class DeviceCapability(NamedTuple):
+    major: int
+    minor: int
+
+    def as_version_str(self) -> str:
+        return f"{self.major}.{self.minor}"
+
+    def to_int(self) -> int:
+        """
+        Express device capability as an integer ``<major><minor>``.
+
+        It is assumed that the minor version is always a single digit.
+        """
+        assert 0 <= self.minor < 10
+        return self.major * 10 + self.minor
+
+
+class Platform:
+    _enum: PlatformEnum
+
+    # Real device name of current platform.
+    device_name: str
+
+    # For specifying torch device for cuda alike platform's capability.
+    device_type: str
+
+    # The torch.distributed backend on current platform
+    torch_distributed_backend: str
+
+    # The torch.compile backend for compiling simple and
+    # standalone functions. The default value is "inductor" to keep
+    # the same behavior as PyTorch.
+    torch_compile_backend: str = "inductor"
+
+    supported_quantization: list[str] = []
+
+    supported_speculative_algorithm: list[str] = []
+
+    # Use first element as default dtype
+    supported_dtype: list[str] = []
+
+    # Use first element as default backend
+    supported_attntion_backend: list[str] = []
+
+    # Use first element as default backend
+    supported_sampling_backend: list[str] = []
+
+    # Use first element as default backend
+    supported_lora_backend: list[str] = []
+
+    def is_cuda(self) -> bool:
+        return self._enum == PlatformEnum.CUDA
+
+    def is_rocm(self) -> bool:
+        return self._enum == PlatformEnum.ROCM
+
+    def is_hpu(self) -> bool:
+        return self._enum == PlatformEnum.HPU
+
+    def is_xpu(self) -> bool:
+        return self._enum == PlatformEnum.XPU
+
+    def is_cpu(self) -> bool:
+        return self._enum == PlatformEnum.CPU
+
+    def is_out_of_tree(self) -> bool:
+        return self._enum == PlatformEnum.OOT
+
+    def is_cuda_alike(self) -> bool:
+        """Stateless version of :func:`torch.cuda.is_available`."""
+        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
+
+    @classmethod
+    def get_device_capability(
+        cls,
+        device_id: int = 0,
+    ) -> Optional[DeviceCapability]:
+        """Stateless version of :func:`torch.cuda.get_device_capability`."""
+        return None
+
+    @classmethod
+    def has_device_capability(
+        cls,
+        capability: Union[Tuple[int, int], int],
+        device_id: int = 0,
+    ) -> bool:
+        """
+        Test whether this platform is compatible with a device capability.
+
+        The ``capability`` argument can either be:
+
+        - A tuple ``(major, minor)``.
+        - An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
+        """
+        current_capability = cls.get_device_capability(device_id=device_id)
+        if current_capability is None:
+            return False
+
+        if isinstance(capability, tuple):
+            return current_capability >= capability
+
+        return current_capability.to_int() >= capability
+
+    @classmethod
+    def get_device_module(cls) -> Any:
+        """Get `torch.device_module` like `torch.cuda` of current platform."""
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_sku(cls, device_id: int = 0) -> str:
+        """Get the SKU name of a device."""
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_uuid(cls, device_id: int = 0) -> str:
+        """Get the uuid of a device, e.g. the PCI bus ID."""
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_core_count(cls, device_id: int = 0) -> str:
+        """Get the core count of a device, e.g. SMs of CUDA, CUs of ROCM."""
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_count(cls) -> int:
+        """Get device count on current platform"""
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0, distributed=False) -> float:
+        """
+        Get total memory for device_type:device_id device in gigabytes.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_available_memory(
+        cls, device_id: int = 0, distributed=False, empty_cache=True
+    ) -> float:
+        """
+        Get available memory for device_type:device_id device in gigabytes.
+        When distributed is True, the available memory is the minimum available memory of all GPUs.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def supports_overlap_scheduler(cls) -> bool:
+        """
+        Check if the current platform supports overlap scheduler
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def seed_everything(cls, seed: Optional[int] = None) -> None:
+        """
+        Set the seed of each random module.
+        `torch.manual_seed` will set seed on all devices.
+
+        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
+        """
+        if seed is not None:
+            random.seed(seed)
+            np.random.seed(seed)
+            torch.manual_seed(seed)
+
+    @classmethod
+    def check_and_update_server_args(cls, server_args: ServerArgs) -> None:
+        """
+        Check and update the server arguments for the current platform.
+
+        It can raise an exception if the configuration is not compatible with
+        the current platform, or it can update the configuration to make it
+        compatible with the current platform.
+
+        The config is passed by reference, so it can be modified in place.
+        """
+        pass
+
+    @classmethod
+    def check_and_update_model_dtype(cls, model_config: ModelConfig, dtype: str) -> str:
+        """
+        Check and update the model's dtype for the current platform.
+        """
+        if cls.supported_dtype and dtype not in cls.supported_dtype:
+            logger.warning(
+                f"dtype {dtype} is currently not supported in "
+                f"{cls.device_name}. use {cls.supported_dtype[0]} instead"
+            )
+            return cls.supported_dtype[0]
+        return dtype
+
+    @classmethod
+    def check_and_update_attntion_backend(
+        cls, model_config: ModelConfig, backend: str
+    ) -> str:
+        """
+        Check and update the attntion backend for the current platform.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def check_and_update_sampling_backend(cls, backend: str) -> str:
+        """
+        Check and update the sampling backend for the current platform.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def check_and_update_lora_backend(cls, backend: str) -> str:
+        """
+        Check and update the lora backend for the current platform.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def verify_model_arch(cls, model_arch: str) -> None:
+        """
+        Verify whether the current platform supports the specified model
+        architecture.
+
+        - This will raise an Error or Warning based on the model support on
+        the current platform.
+        - By default all models are considered supported.
+        """
+        pass
+
+    @classmethod
+    def verify_quantization(cls, quant: str) -> None:
+        """
+        Verify whether the quantization is supported by the current platform.
+        """
+        if cls.supported_quantization and quant not in cls.supported_quantization:
+            raise ValueError(
+                f"{quant} quantization is currently not supported in "
+                f"{cls.device_name}."
+            )
+
+    @classmethod
+    def verify_speculative_algorithm(cls, algo: str) -> None:
+        """
+        Verify whether the speculative algorithm is supported by the current platform.
+        """
+        if (
+            cls.supported_speculative_algorithm
+            and algo not in cls.supported_speculative_algorithm
+        ):
+            raise ValueError(
+                f"speculative algorithm {algo} is currently not supported in "
+                f"{cls.device_name}."
+            )
+
+    @classmethod
+    def get_cpu_architecture(cls) -> CpuArchEnum:
+        """
+        Determine the CPU architecture of the current system.
+        Returns CpuArchEnum indicating the architecture type.
+        """
+        machine = platform.machine().lower()
+
+        if machine in ("x86_64", "amd64", "i386", "i686"):
+            return CpuArchEnum.X86
+        elif machine.startswith("arm") or machine.startswith("aarch"):
+            return CpuArchEnum.ARM
+        elif machine.startswith("ppc"):
+            return CpuArchEnum.POWERPC
+
+        return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN
+
+    @classmethod
+    def is_pin_memory_available(cls) -> bool:
+        """Checks whether pin memory is available on the current platform."""
+        if in_wsl():
+            # Pinning memory in WSL is not supported.
+            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
+            logger.warning(
+                "Using 'pin_memory=False' as WSL is detected. "
+                "This may slow down the performance."
+            )
+            return False
+        return True
+
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        """
+        Get device specific communicator class for distributed communication.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def supports_fp8(cls) -> bool:
+        return False
+
+    @classmethod
+    def fp8_dtype(cls) -> torch.dtype:
+        """
+        Returns the preferred FP8 type on the current platform.
+        """
+        return torch.float8_e4m3fn
+
+    @classmethod
+    def fp8_min_max(cls) -> Tuple[float, float]:
+        """
+        Returns the preferred FP8 max value on the current platform.
+        """
+        fp8_max = torch.finfo(cls.fp8_dtype()).max
+        return (-fp8_max, fp8_max)
+
+    @classmethod
+    def is_triton_avaliable(cls) -> bool:
+        raise NotImplementedError
+
+    @classmethod
+    def init_environments(cls) -> None:
+        """
+        Init environments on current platform.
+
+        - Init platform specific env vars.
+        - Init platform specific patches.
+        """
+        pass
+
+
+class UnspecifiedPlatform(Platform):
+    _enum = PlatformEnum.UNSPECIFIED
+    device_type = ""