sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_serving.py +49 -7
  2. sglang/lang/chat_template.py +24 -0
  3. sglang/srt/_custom_ops.py +59 -92
  4. sglang/srt/configs/model_config.py +5 -0
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/conversation.py +29 -4
  7. sglang/srt/custom_op.py +5 -0
  8. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  9. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/layers/attention/flashattention_backend.py +678 -83
  12. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  13. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  14. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  15. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  16. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  17. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  18. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
  30. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  31. sglang/srt/layers/moe/topk.py +49 -3
  32. sglang/srt/layers/quantization/__init__.py +5 -1
  33. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  35. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  36. sglang/srt/layers/quantization/fp8.py +3 -1
  37. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  38. sglang/srt/layers/quantization/moe_wna16.py +503 -0
  39. sglang/srt/layers/quantization/utils.py +1 -1
  40. sglang/srt/layers/quantization/w8a8_int8.py +2 -0
  41. sglang/srt/layers/radix_attention.py +2 -0
  42. sglang/srt/layers/rotary_embedding.py +63 -12
  43. sglang/srt/managers/cache_controller.py +34 -11
  44. sglang/srt/managers/mm_utils.py +202 -156
  45. sglang/srt/managers/multimodal_processor.py +0 -2
  46. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  47. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  48. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  49. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  50. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  51. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  52. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  53. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  54. sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
  55. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  56. sglang/srt/managers/schedule_batch.py +185 -128
  57. sglang/srt/managers/scheduler.py +4 -4
  58. sglang/srt/managers/tokenizer_manager.py +1 -1
  59. sglang/srt/managers/utils.py +1 -6
  60. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  61. sglang/srt/mem_cache/memory_pool.py +72 -6
  62. sglang/srt/mem_cache/paged_allocator.py +39 -0
  63. sglang/srt/metrics/collector.py +23 -53
  64. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  65. sglang/srt/model_executor/forward_batch_info.py +10 -10
  66. sglang/srt/model_executor/model_runner.py +60 -57
  67. sglang/srt/model_loader/loader.py +8 -0
  68. sglang/srt/models/clip.py +12 -7
  69. sglang/srt/models/deepseek_janus_pro.py +10 -15
  70. sglang/srt/models/deepseek_v2.py +212 -121
  71. sglang/srt/models/deepseek_vl2.py +105 -104
  72. sglang/srt/models/gemma3_mm.py +14 -80
  73. sglang/srt/models/llama.py +16 -5
  74. sglang/srt/models/llama4.py +420 -0
  75. sglang/srt/models/llava.py +31 -19
  76. sglang/srt/models/llavavid.py +16 -7
  77. sglang/srt/models/minicpmo.py +63 -147
  78. sglang/srt/models/minicpmv.py +17 -27
  79. sglang/srt/models/mllama.py +29 -14
  80. sglang/srt/models/mllama4.py +154 -0
  81. sglang/srt/models/qwen2.py +9 -6
  82. sglang/srt/models/qwen2_5_vl.py +21 -31
  83. sglang/srt/models/qwen2_vl.py +20 -21
  84. sglang/srt/openai_api/adapter.py +18 -6
  85. sglang/srt/platforms/interface.py +371 -0
  86. sglang/srt/server_args.py +99 -14
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  88. sglang/srt/speculative/eagle_utils.py +140 -28
  89. sglang/srt/speculative/eagle_worker.py +93 -24
  90. sglang/srt/utils.py +104 -51
  91. sglang/test/test_custom_ops.py +55 -0
  92. sglang/test/test_utils.py +13 -26
  93. sglang/utils.py +2 -2
  94. sglang/version.py +1 -1
  95. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
  96. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
  97. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_5_vl.py
@@ -30,22 +30,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import AutoModel, Qwen2VLConfig
+from transformers import Qwen2VLConfig
 from transformers.activations import ACT2FN
 from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
-from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
-    Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
 )
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-    Qwen2_5_VLForConditionalGeneration,
-)

-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -57,7 +48,7 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
@@ -513,19 +504,24 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

-    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int = image_inputs.im_start_id
-        im_end_id: int = image_inputs.im_end_id
+        im_start_id: int = mm_inputs.im_start_id
+        im_end_id: int = mm_inputs.im_end_id

         media_token_pairs = [(im_start_id, im_end_id)]
         pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

-        return pattern.pad_input_tokens(input_ids, image_inputs)
-
-    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
-        pixel_values = image_input.pixel_values.type(self.visual.dtype)
-        image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
+            self.visual.dtype
+        )
+        image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert image_grid_thws.dim() == 2, image_grid_thws.dim()
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
         return image_embeds

     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
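
Note on the hunk above: get_image_feature now takes a list of MultimodalDataItem objects and concatenates their flattened patch tensors so the vision tower runs once over all images in the batch. Below is a minimal, self-contained sketch of that batching invariant; FakeItem and the shapes are illustrative stand-ins, not sglang's actual data structures.

import torch

class FakeItem:
    # Hypothetical stand-in for MultimodalDataItem: pixel_values is
    # (n_patches, patch_dim); image_grid_thws is (n_images, 3) holding (t, h, w).
    def __init__(self, n_images: int, patch_dim: int = 1176):
        self.image_grid_thws = torch.tensor([[1, 2, 2]] * n_images)  # 1*2*2 = 4 patches per image
        self.pixel_values = torch.randn(4 * n_images, patch_dim)

items = [FakeItem(n_images=2), FakeItem(n_images=1)]
pixel_values = torch.cat([it.pixel_values for it in items], dim=0)        # (12, 1176)
image_grid_thws = torch.cat([it.image_grid_thws for it in items], dim=0)  # (3, 3)
# Both stay 2-D, which is what the new asserts check before calling self.visual(...)
assert pixel_values.dim() == 2 and image_grid_thws.dim() == 2
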
@@ -570,18 +566,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 f"(3, seq_len) positions, but got {positions.size()}"
             )

-        inputs_embeds = general_mm_embed_routine(
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            mm_data_embedding_func=self.get_image_feature,
-        )
-
-        hidden_states = self.model(
-            input_ids=None,
+            language_model=self.model,
+            image_data_embedding_func=self.get_image_feature,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
         )

         if not get_embedding:
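
The forward path above now hands both token embedding and image-feature splicing to general_mm_embed_routine, which also invokes the language model. A toy illustration of the splice step such a routine performs (hypothetical tensors and token id, not sglang code): replace the embeddings at image-pad positions with the vision features, then feed the merged sequence to the language model as input_embeds.

import torch

IMAGE_PAD_ID = 151655                  # placeholder id for padded image slots
input_ids = torch.tensor([1, IMAGE_PAD_ID, IMAGE_PAD_ID, 2])
text_embeds = torch.randn(4, 8)        # (seq_len, hidden) from the token embedding table
image_embeds = torch.randn(2, 8)       # one row per image-pad position (from get_image_feature)

merged = text_embeds.clone()
merged[input_ids == IMAGE_PAD_ID] = image_embeds  # splice vision features into the sequence
# `merged` plays the role of input_embeds for the language model
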
@@ -594,9 +584,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
             ("gate_up_proj", "up_proj", 1),
             ("gate_up_proj", "gate_proj", 0),
         ]
sglang/srt/models/qwen2_vl.py
@@ -45,7 +45,7 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
@@ -472,18 +472,24 @@ class Qwen2VLForConditionalGeneration(nn.Module):

     # Use grid_t * grid_w * grid_h to pad tokens for each image
     # add replaced padding by unique image hash
-    def pad_input_ids(self, input_ids: List[int], multi_modal_inputs: MultimodalInputs):
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int = multi_modal_inputs.im_start_id
-        im_end_id: int = multi_modal_inputs.im_end_id
+        im_start_id: int = mm_inputs.im_start_id
+        im_end_id: int = mm_inputs.im_end_id

         media_token_pairs = [(im_start_id, im_end_id)]
         pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
-        return pattern.pad_input_tokens(input_ids, multi_modal_inputs)
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

-    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
-        pixel_values = image_input.pixel_values.type(self.visual.dtype)
-        image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
+            self.visual.dtype
+        )
+        image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert image_grid_thws.dim() == 2, image_grid_thws.dim()
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
         return image_embeds

     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
@@ -527,27 +533,20 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                 "multimodal section rotary embedding requires "
                 f"(3, seq_len) positions, but got {positions.size()}"
             )
-
-        inputs_embeds = general_mm_embed_routine(
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            mm_data_embedding_func=self.get_image_feature,
-        )
-
-        hidden_states = self.model(
-            input_ids=None,
+            language_model=self.model,
+            image_data_embedding_func=self.get_image_feature,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
         )

-        if not get_embedding:
+        if get_embedding:
+            return self.pooler(hidden_states, forward_batch)
+        else:
             return self.logits_processor(
                 input_ids, hidden_states, self.lm_head, forward_batch
             )
-        else:
-            return self.pooler(hidden_states, forward_batch)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/openai_api/adapter.py
@@ -897,6 +897,7 @@ def v1_chat_generate_request(
     request_ids: List[str] = None,
 ):
     input_ids = []
+    prompts = []
     sampling_params_list = []
     image_data_list = []
     audio_data_list = []
@@ -916,6 +917,7 @@
         # - audio_data: None or a list of audio strings (URLs).
         # None skips any image processing in GenerateReqInput.
         strict_tag = None
+        prompt = ""
         if not isinstance(request.messages, str):
             # Apply chat template and its stop strings.
             tools = None
@@ -1005,11 +1007,13 @@
             image_data = None
             audio_data = None
             modalities = []
+            prompt = request.messages
         input_ids.append(prompt_ids)
         return_logprobs.append(request.logprobs)
         logprob_start_lens.append(-1)
         top_logprobs_nums.append(request.top_logprobs or 0)
         lora_paths.append(request.lora_path)
+        prompts.append(prompt)

         sampling_params = {
             "temperature": request.temperature,
@@ -1063,10 +1067,14 @@
         audio_data_list.append(audio_data)
         modalities_list.append(modalities)
     if len(all_requests) == 1:
-        if isinstance(input_ids[0], str):
-            prompt_kwargs = {"text": input_ids[0]}
+        if tokenizer_manager.model_config.is_multimodal:
+            # processor will need text input
+            prompt_kwargs = {"text": prompts[0]}
         else:
-            prompt_kwargs = {"input_ids": input_ids[0]}
+            if isinstance(input_ids[0], str):
+                prompt_kwargs = {"text": input_ids[0]}
+            else:
+                prompt_kwargs = {"input_ids": input_ids[0]}
         sampling_params_list = sampling_params_list[0]
         image_data_list = image_data_list[0]
         audio_data_list = audio_data_list[0]
@@ -1076,10 +1084,14 @@
         modalities_list = modalities_list[0]
         lora_paths = lora_paths[0]
     else:
-        if isinstance(input_ids[0], str):
-            prompt_kwargs = {"text": input_ids}
+        if tokenizer_manager.model_config.is_multimodal:
+            # processor will need text input
+            prompt_kwargs = {"text": prompts}
         else:
-            prompt_kwargs = {"input_ids": input_ids}
+            if isinstance(input_ids[0], str):
+                prompt_kwargs = {"text": input_ids}
+            else:
+                prompt_kwargs = {"input_ids": input_ids}

     adapted_request = GenerateReqInput(
         **prompt_kwargs,
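
The adapter hunks above make multimodal models receive the original chat text (their processor re-tokenizes it together with the image data), while text-only models keep the pre-tokenized ids. Below is a standalone sketch of the single-request selection logic; the function name and parameters are illustrative, not part of the sglang API.

def select_prompt_kwargs(is_multimodal: bool, prompt: str, prompt_ids):
    # Mirrors the new single-request branch in v1_chat_generate_request.
    if is_multimodal:
        # The multimodal processor needs raw text, not token ids.
        return {"text": prompt}
    if isinstance(prompt_ids, str):
        return {"text": prompt_ids}
    return {"input_ids": prompt_ids}

# Text-only model with a pre-tokenized prompt keeps its ids:
assert select_prompt_kwargs(False, "hi", [1, 2, 3]) == {"input_ids": [1, 2, 3]}
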
sglang/srt/platforms/interface.py (new file)
@@ -0,0 +1,371 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/v0.8.2/vllm/platforms/interface.py
+
+import enum
+import platform
+import random
+from platform import uname
+from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
+
+if TYPE_CHECKING:
+    from sglang.srt.server_args import ServerArgs
+    from sglang.srt.configs.model_config import ModelConfig
+
+import logging
+
+import numpy as np
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def in_wsl() -> bool:
+    # Reference: https://github.com/microsoft/WSL/issues/4071
+    return "microsoft" in " ".join(uname()).lower()
+
+
+class PlatformEnum(enum.Enum):
+    CUDA = enum.auto()
+    ROCM = enum.auto()
+    HPU = enum.auto()
+    XPU = enum.auto()
+    CPU = enum.auto()
+    OOT = enum.auto()
+    UNSPECIFIED = enum.auto()
+
+
+class CpuArchEnum(enum.Enum):
+    X86 = enum.auto()
+    ARM = enum.auto()
+    POWERPC = enum.auto()
+    OTHER = enum.auto()
+    UNKNOWN = enum.auto()
+
+
+class DeviceCapability(NamedTuple):
+    major: int
+    minor: int
+
+    def as_version_str(self) -> str:
+        return f"{self.major}.{self.minor}"
+
+    def to_int(self) -> int:
+        """
+        Express device capability as an integer ``<major><minor>``.
+
+        It is assumed that the minor version is always a single digit.
+        """
+        assert 0 <= self.minor < 10
+        return self.major * 10 + self.minor
+
+
+class Platform:
+    _enum: PlatformEnum
+
+    # Real device name of current platform.
+    device_name: str
+
+    # For specifying torch device for cuda alike platform's capability.
+    device_type: str
+
+    # The torch.distributed backend on current platform
+    torch_distributed_backend: str
+
+    # The torch.compile backend for compiling simple and
+    # standalone functions. The default value is "inductor" to keep
+    # the same behavior as PyTorch.
+    torch_compile_backend: str = "inductor"
+
+    supported_quantization: list[str] = []
+
+    supported_speculative_algorithm: list[str] = []
+
+    # Use first element as default dtype
+    supported_dtype: list[str] = []
+
+    # Use first element as default backend
+    supported_attntion_backend: list[str] = []
+
+    # Use first element as default backend
+    supported_sampling_backend: list[str] = []
+
+    # Use first element as default backend
+    supported_lora_backend: list[str] = []
+
+    def is_cuda(self) -> bool:
+        return self._enum == PlatformEnum.CUDA
+
+    def is_rocm(self) -> bool:
+        return self._enum == PlatformEnum.ROCM
+
+    def is_hpu(self) -> bool:
+        return self._enum == PlatformEnum.HPU
+
+    def is_xpu(self) -> bool:
+        return self._enum == PlatformEnum.XPU
+
+    def is_cpu(self) -> bool:
+        return self._enum == PlatformEnum.CPU
+
+    def is_out_of_tree(self) -> bool:
+        return self._enum == PlatformEnum.OOT
+
+    def is_cuda_alike(self) -> bool:
+        """Stateless version of :func:`torch.cuda.is_available`."""
+        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
+
+    @classmethod
+    def get_device_capability(
+        cls,
+        device_id: int = 0,
+    ) -> Optional[DeviceCapability]:
+        """Stateless version of :func:`torch.cuda.get_device_capability`."""
+        return None
+
+    @classmethod
+    def has_device_capability(
+        cls,
+        capability: Union[Tuple[int, int], int],
+        device_id: int = 0,
+    ) -> bool:
+        """
+        Test whether this platform is compatible with a device capability.
+
+        The ``capability`` argument can either be:
+
+        - A tuple ``(major, minor)``.
+        - An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
+        """
+        current_capability = cls.get_device_capability(device_id=device_id)
+        if current_capability is None:
+            return False
+
+        if isinstance(capability, tuple):
+            return current_capability >= capability
+
+        return current_capability.to_int() >= capability
+
+    @classmethod
+    def get_device_module(cls) -> Any:
+        """Get `torch.device_module` like `torch.cuda` of current platform."""
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_sku(cls, device_id: int = 0) -> str:
+        """Get the SKU name of a device."""
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_uuid(cls, device_id: int = 0) -> str:
+        """Get the uuid of a device, e.g. the PCI bus ID."""
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_core_count(cls, device_id: int = 0) -> str:
+        """Get the core count of a device, e.g. SMs of CUDA, CUs of ROCM."""
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_count(cls) -> int:
+        """Get device count on current platform"""
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0, distributed=False) -> float:
+        """
+        Get total memory for device_type:device_id device in gigabytes.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_available_memory(
+        cls, device_id: int = 0, distributed=False, empty_cache=True
+    ) -> float:
+        """
+        Get available memory for device_type:device_id device in gigabytes.
+        When distributed is True, the available memory is the minimum available memory of all GPUs.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def supports_overlap_scheduler(cls) -> bool:
+        """
+        Check if the current platform supports overlap scheduler
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def seed_everything(cls, seed: Optional[int] = None) -> None:
+        """
+        Set the seed of each random module.
+        `torch.manual_seed` will set seed on all devices.
+
+        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
+        """
+        if seed is not None:
+            random.seed(seed)
+            np.random.seed(seed)
+            torch.manual_seed(seed)
+
+    @classmethod
+    def check_and_update_server_args(cls, server_args: ServerArgs) -> None:
+        """
+        Check and update the server arguments for the current platform.
+
+        It can raise an exception if the configuration is not compatible with
+        the current platform, or it can update the configuration to make it
+        compatible with the current platform.
+
+        The config is passed by reference, so it can be modified in place.
+        """
+        pass
+
+    @classmethod
+    def check_and_update_model_dtype(cls, model_config: ModelConfig, dtype: str) -> str:
+        """
+        Check and update the model's dtype for the current platform.
+        """
+        if cls.supported_dtype and dtype not in cls.supported_dtype:
+            logger.warning(
+                f"dtype {dtype} is currently not supported in "
+                f"{cls.device_name}. use {cls.supported_dtype[0]} instead"
+            )
+            return cls.supported_dtype[0]
+        return dtype
+
+    @classmethod
+    def check_and_update_attntion_backend(
+        cls, model_config: ModelConfig, backend: str
+    ) -> str:
+        """
+        Check and update the attntion backend for the current platform.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def check_and_update_sampling_backend(cls, backend: str) -> str:
+        """
+        Check and update the sampling backend for the current platform.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def check_and_update_lora_backend(cls, backend: str) -> str:
+        """
+        Check and update the lora backend for the current platform.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def verify_model_arch(cls, model_arch: str) -> None:
+        """
+        Verify whether the current platform supports the specified model
+        architecture.
+
+        - This will raise an Error or Warning based on the model support on
+          the current platform.
+        - By default all models are considered supported.
+        """
+        pass
+
+    @classmethod
+    def verify_quantization(cls, quant: str) -> None:
+        """
+        Verify whether the quantization is supported by the current platform.
+        """
+        if cls.supported_quantization and quant not in cls.supported_quantization:
+            raise ValueError(
+                f"{quant} quantization is currently not supported in "
+                f"{cls.device_name}."
+            )
+
+    @classmethod
+    def verify_speculative_algorithm(cls, algo: str) -> None:
+        """
+        Verify whether the speculative algorithm is supported by the current platform.
+        """
+        if (
+            cls.supported_speculative_algorithm
+            and algo not in cls.supported_speculative_algorithm
+        ):
+            raise ValueError(
+                f"speculative algorithm {algo} is currently not supported in "
+                f"{cls.device_name}."
+            )
+
+    @classmethod
+    def get_cpu_architecture(cls) -> CpuArchEnum:
+        """
+        Determine the CPU architecture of the current system.
+        Returns CpuArchEnum indicating the architecture type.
+        """
+        machine = platform.machine().lower()
+
+        if machine in ("x86_64", "amd64", "i386", "i686"):
+            return CpuArchEnum.X86
+        elif machine.startswith("arm") or machine.startswith("aarch"):
+            return CpuArchEnum.ARM
+        elif machine.startswith("ppc"):
+            return CpuArchEnum.POWERPC
+
+        return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN
+
+    @classmethod
+    def is_pin_memory_available(cls) -> bool:
+        """Checks whether pin memory is available on the current platform."""
+        if in_wsl():
+            # Pinning memory in WSL is not supported.
+            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
+            logger.warning(
+                "Using 'pin_memory=False' as WSL is detected. "
+                "This may slow down the performance."
+            )
+            return False
+        return True
+
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        """
+        Get device specific communicator class for distributed communication.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def supports_fp8(cls) -> bool:
+        return False
+
+    @classmethod
+    def fp8_dtype(cls) -> torch.dtype:
+        """
+        Returns the preferred FP8 type on the current platform.
+        """
+        return torch.float8_e4m3fn
+
+    @classmethod
+    def fp8_min_max(cls) -> Tuple[float, float]:
+        """
+        Returns the preferred FP8 max value on the current platform.
+        """
+        fp8_max = torch.finfo(cls.fp8_dtype()).max
+        return (-fp8_max, fp8_max)
+
+    @classmethod
+    def is_triton_avaliable(cls) -> bool:
+        raise NotImplementedError
+
+    @classmethod
+    def init_environments(cls) -> None:
+        """
+        Init environments on current platform.
+
+        - Init platform specific env vars.
+        - Init platform specific patches.
+        """
+        pass
+
+
+class UnspecifiedPlatform(Platform):
+    _enum = PlatformEnum.UNSPECIFIED
+    device_type = ""
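
To show how the new Platform interface is meant to be extended, here is a hedged sketch of a CUDA-flavored subclass that fills in a few of the classmethods with standard torch.cuda calls. It is an example only, not the CUDA platform implementation shipped in this release.

import torch

from sglang.srt.platforms.interface import DeviceCapability, Platform, PlatformEnum


class ExampleCudaPlatform(Platform):
    _enum = PlatformEnum.CUDA
    device_name = "cuda"
    device_type = "cuda"
    torch_distributed_backend = "nccl"

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        # torch.cuda.get_device_capability returns a (major, minor) tuple.
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_sku(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_count(cls) -> int:
        return torch.cuda.device_count()

    @classmethod
    def supports_fp8(cls) -> bool:
        # SM 8.9 (Ada) and newer expose native FP8 support.
        return cls.has_device_capability((8, 9))
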