sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/petit_utils.py
@@ -0,0 +1,104 @@
+ from typing import Optional
+
+ import torch
+
+ try:
+     from petit_kernel import mul_nvfp4_a16, process_nvfp4_scales, repack_nvfp4
+ except ImportError:
+
+     def _check_petit_nvfp4_supported(
+         quant_method: str, group_size: Optional[int]
+     ) -> tuple[bool, Optional[str]]:
+         return (
+             False,
+             "Petit is not installed. Please install it with `pip install petit-kernel`.",
+         )
+
+     def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None:
+         raise ValueError(
+             "Petit is not installed. Please install it with `pip install petit-kernel`."
+         )
+
+     def apply_petit_nvfp4_linear(
+         input: torch.Tensor,
+         weight: torch.Tensor,
+         weight_scale: torch.Tensor,
+         weight_scale_2: torch.Tensor,
+         size_n: int,
+         size_k: int,
+         bias: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         raise ValueError(
+             "Petit is not installed. Please install it with `pip install petit-kernel`."
+         )
+
+
+ def _check_petit_nvfp4_supported(
+     quant_method: str, group_size: Optional[int]
+ ) -> tuple[bool, Optional[str]]:
+     if quant_method != "NVFP4":
+         return (
+             False,
+             "Petit currently only supports: NVFP4"
+             " quantizations in sglang. Please check the "
+             "`hf_quant_config.json` file for your model's "
+             "quant configuration.",
+         )
+     if group_size is not None and group_size != 16:
+         return (
+             False,
+             "Petit currently only supports: group_size=16" " quantizations.",
+         )
+     return (True, None)
+
+
+ def verify_petit_nvfp4_supported(quant_method: str, group_size: Optional[int]) -> None:
+     supported, error_msg = _check_petit_nvfp4_supported(quant_method, group_size)
+     if not supported:
+         raise ValueError(error_msg)
+
+
+ def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None:
+     # Repack weights to petit format
+     part_size_n = layer.output_size_per_partition
+     part_size_k = layer.input_size_per_partition
+     qweight = layer.weight.view(torch.int32).contiguous()
+     petit_qweight = repack_nvfp4(qweight, size_n=part_size_n, size_k=part_size_k)
+     layer.weight = torch.nn.Parameter(petit_qweight, requires_grad=False)
+
+     # Permute scales
+     weight_scale = process_nvfp4_scales(
+         scales=layer.weight_scale, size_k=part_size_k, size_n=part_size_n
+     )
+     layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
+
+     return
+
+
+ def apply_petit_nvfp4_linear(
+     input: torch.Tensor,
+     weight: torch.Tensor,
+     weight_scale: torch.Tensor,
+     weight_scale_2: torch.Tensor,
+     size_n: int,
+     size_k: int,
+     bias: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     reshaped_x = input.reshape(-1, input.shape[-1])
+     out_shape = input.shape[:-1] + (size_n,)
+
+     # TODO: Use auto-tuning to find the performant solution_id
+     output = mul_nvfp4_a16(
+         a=reshaped_x,
+         b=weight,
+         s=weight_scale,
+         global_scale=weight_scale_2,
+         size_m=reshaped_x.size(0),
+         size_n=size_n,
+         size_k=size_k,
+         solution_id=-1,
+     )
+     if bias is not None:
+         output.add_(bias) # In-place add
+
+     return output.reshape(out_shape)
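
For orientation, here is a minimal usage sketch of how these helpers compose, assuming the optional petit-kernel package is installed. `layer` and `x` are hypothetical stand-ins for a quantized linear module (carrying the attributes the functions above read) and its input activations; they are not objects defined in this diff.

    from sglang.srt.layers.quantization.petit_utils import (
        apply_petit_nvfp4_linear,
        prepare_nvfp4_layer_for_petit,
        verify_petit_nvfp4_supported,
    )

    # 1) Config time: reject unsupported quant schemes early (raises ValueError otherwise).
    verify_petit_nvfp4_supported(quant_method="NVFP4", group_size=16)

    # 2) After weight loading: repack the packed NVFP4 weights and permute the scales
    #    into the layout the petit kernel expects (done once per layer).
    prepare_nvfp4_layer_for_petit(layer)

    # 3) Per forward pass: 16-bit activations times NVFP4 weights -> dense output.
    out = apply_petit_nvfp4_linear(
        input=x,
        weight=layer.weight,
        weight_scale=layer.weight_scale,
        weight_scale_2=layer.weight_scale_2,
        size_n=layer.output_size_per_partition,
        size_k=layer.input_size_per_partition,
        bias=getattr(layer, "bias", None),
    )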
sglang/srt/layers/quantization/qoq.py
@@ -1,16 +1,17 @@
- from typing import Any, Callable, Dict, List, Optional
+ from __future__ import annotations
+
+ from typing import Any, Dict, List, Optional
 
  import torch
  from torch.nn.parameter import Parameter
 
- from sglang.srt.distributed import get_tensor_model_parallel_world_size
- from sglang.srt.layers.linear import LinearMethodBase
  from sglang.srt.layers.parameter import (
      ChannelQuantScaleParameter,
      GroupQuantScaleParameter,
      ModelWeightParameter,
  )
  from sglang.srt.layers.quantization.base_config import (
+     LinearMethodBase,
      QuantizationConfig,
      QuantizeMethodBase,
  )
@@ -71,7 +72,7 @@ class QoQConfig(QuantizationConfig):
          return 80
 
      @classmethod
-     def get_name(self) -> str:
+     def get_name(cls) -> str:
          return "qoq"
 
      @classmethod
@@ -83,7 +84,7 @@ class QoQConfig(QuantizationConfig):
          ]
 
      @classmethod
-     def from_config(cls, config: Dict[str, Any]) -> "QoQConfig":
+     def from_config(cls, config: Dict[str, Any]) -> QoQConfig:
          weight_bits = cls.get_from_keys(config, ["wbits"])
          group_size = cls.get_from_keys(config, ["group_size"])
          return cls(weight_bits, group_size)
@@ -92,7 +93,7 @@ class QoQConfig(QuantizationConfig):
          self,
          layer: torch.nn.Module,
          prefix: str,
-     ) -> Optional["QuantizeMethodBase"]:
+     ) -> Optional[QuantizeMethodBase]:
          from sglang.srt.layers.linear import LinearBase
 
          if isinstance(layer, LinearBase):
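
Two things drive the qoq.py changes: `LinearMethodBase` is now imported from `base_config` rather than `sglang.srt.layers.linear` (the remaining `LinearBase` import is deferred into `get_quant_method`), and `from __future__ import annotations` (PEP 563) keeps annotations as strings until something resolves them, which is why the quotes around `QoQConfig` and `QuantizeMethodBase` in return types can be dropped. A small standalone illustration of the latter, using a throwaway class rather than anything from this diff:

    from __future__ import annotations

    import typing


    class Config:
        @classmethod
        def from_config(cls, d: dict) -> Config:  # bare forward reference is fine
            return cls()


    # The annotation stays a plain string until something asks for it.
    print(Config.from_config.__annotations__["return"])         # 'Config'
    print(typing.get_type_hints(Config.from_config)["return"])  # <class '__main__.Config'>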
sglang/srt/layers/quantization/scalar_type.py
@@ -0,0 +1,352 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+ import functools
+ import struct
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Optional, Union
+
+ _SCALAR_TYPES_ID_MAP = {}
+
+
+ # Mirrors enum in `core/scalar_type.hpp`
+ class NanRepr(Enum):
+     NONE = 0 # nans are not supported
+     IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
+     EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
+
+
+ # This ScalarType class is a parallel implementation of the C++ ScalarType
+ # class found in csrc/core/scalar_type.hpp. These two classes should be kept
+ # in sync until the inductor fully supports custom C++ classes.
+ @dataclass(frozen=True)
+ class ScalarType:
+     """
+     ScalarType can represent a wide range of floating point and integer
+     types, in particular it can be used to represent sub-byte data types
+     (something that torch.dtype currently does not support). It is also
+     capable of representing types with a bias, i.e.:
+     `stored_value = value + bias`,
+     this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
+     of 8). The implementation for this class can be found in
+     csrc/core/scalar_type.hpp, these type signatures should be kept in sync
+     with that file.
+     """
+
+     exponent: int
+     """
+     Number of bits in the exponent if this is a floating point type
+     (zero if this an integer type)
+     """
+
+     mantissa: int
+     """
+     Number of bits in the mantissa if this is a floating point type,
+     or the number bits representing an integer excluding the sign bit if
+     this an integer type.
+     """
+
+     signed: bool
+     "If the type is signed (i.e. has a sign bit)"
+
+     bias: int
+     """
+     bias used to encode the values in this scalar type
+     (value = stored_value - bias, default 0) for example if we store the
+     type as an unsigned integer with a bias of 128 then the value 0 will be
+     stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
+     """
+
+     _finite_values_only: bool = False
+     """
+     Private: if infs are supported, used `has_infs()` instead.
+     """
+
+     nan_repr: NanRepr = NanRepr.IEEE_754
+     """
+     How NaNs are represent in this scalar type, returns NanRepr value.
+     (not applicable for integer types)
+     """
+
+     def _floating_point_max_int(self) -> int:
+         assert (
+             self.mantissa <= 52 and self.exponent <= 11
+         ), f"Cannot represent max/min as a double for type {self.__str__()}"
+
+         max_mantissa = (1 << self.mantissa) - 1
+         if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
+             max_mantissa = max_mantissa - 1
+
+         max_exponent = (1 << self.exponent) - 2
+         if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE:
+             assert (
+                 self.exponent < 11
+             ), f"Cannot represent max/min as a double for type {self.__str__()}"
+             max_exponent = max_exponent + 1
+
+         # adjust the exponent to match that of a double
+         # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
+         # e is the exponent bits), there is some precedent for non-standard
+         # biases, example `float8_e4m3b11fnuz` here:
+         # https://github.com/jax-ml/ml_dtypes but to avoid premature over
+         # complication we are just assuming the standard exponent bias until
+         # there is a need to support non-standard biases
+         exponent_bias = (1 << (self.exponent - 1)) - 1
+         exponent_bias_double = (1 << 10) - 1 # double e = 11
+
+         max_exponent_double = max_exponent - exponent_bias + exponent_bias_double
+
+         # shift the mantissa and exponent into the proper positions for an
+         # IEEE double and bitwise-or them together.
+         return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52)
+
+     def _floating_point_max(self) -> float:
+         double_raw = self._floating_point_max_int()
+         return struct.unpack("!d", struct.pack("!Q", double_raw))[0]
+
+     def _raw_max(self) -> Union[int, float]:
+         if self.is_floating_point():
+             return self._floating_point_max()
+         else:
+             assert (
+                 self.size_bits < 64 or self.size_bits == 64 and self.is_signed()
+             ), "Cannot represent max as an int"
+             return (1 << self.mantissa) - 1
+
+     def _raw_min(self) -> Union[int, float]:
+         if self.is_floating_point():
+             assert (
+                 self.is_signed()
+             ), "We currently assume all floating point types are signed"
+             sign_bit_double = 1 << 63
+
+             max_raw = self._floating_point_max_int()
+             min_raw = max_raw | sign_bit_double
+             return struct.unpack("!d", struct.pack("!Q", min_raw))[0]
+         else:
+             assert (
+                 not self.is_signed() or self.size_bits <= 64
+             ), "Cannot represent min as a int64_t"
+
+             if self.is_signed():
+                 return -(1 << (self.size_bits - 1))
+             else:
+                 return 0
+
+     @functools.cached_property
+     def id(self) -> int:
+         """
+         Convert the ScalarType to an int which can be passed to pytorch custom
+         ops. This layout of the int must be kept in sync with the C++
+         ScalarType's from_id method.
+         """
+         val = 0
+         offset = 0
+
+         def or_and_advance(member, bit_width):
+             nonlocal val
+             nonlocal offset
+             bit_mask = (1 << bit_width) - 1
+             val = val | (int(member) & bit_mask) << offset
+             offset = offset + bit_width
+
+         or_and_advance(self.exponent, 8)
+         or_and_advance(self.mantissa, 8)
+         or_and_advance(self.signed, 1)
+         or_and_advance(self.bias, 32)
+         or_and_advance(self._finite_values_only, 1)
+         or_and_advance(self.nan_repr.value, 8)
+
+         assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64"
+
+         _SCALAR_TYPES_ID_MAP[val] = self
+
+         return val
+
+     @property
+     def size_bits(self) -> int:
+         return self.exponent + self.mantissa + int(self.signed)
+
+     def min(self) -> Union[int, float]:
+         """
+         Min representable value for this scalar type.
+         (accounting for bias if there is one)
+         """
+         return self._raw_min() - self.bias
+
+     def max(self) -> Union[int, float]:
+         """
+         Max representable value for this scalar type.
+         (accounting for bias if there is one)
+         """
+         return self._raw_max() - self.bias
+
+     def is_signed(self) -> bool:
+         """
+         If the type is signed (i.e. has a sign bit), same as `signed`
+         added for consistency with:
+         https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
+         """
+         return self.signed
+
+     def is_floating_point(self) -> bool:
+         "If the type is a floating point type"
+         return self.exponent != 0
+
+     def is_integer(self) -> bool:
+         "If the type is an integer type"
+         return self.exponent == 0
+
+     def has_bias(self) -> bool:
+         "If the type has a non-zero bias"
+         return self.bias != 0
+
+     def has_infs(self) -> bool:
+         "If the type is floating point and supports infinity"
+         return not self._finite_values_only
+
+     def has_nans(self) -> bool:
+         return self.nan_repr != NanRepr.NONE.value
+
+     def is_ieee_754(self) -> bool:
+         """
+         If the type is a floating point type that follows IEEE 754
+         conventions
+         """
+         return self.nan_repr == NanRepr.IEEE_754.value and not self._finite_values_only
+
+     def __str__(self) -> str:
+         """
+         naming generally follows: https://github.com/jax-ml/ml_dtypes
+         for floating point types (leading f) the scheme is:
+         `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
+         flags:
+         - no-flags: means it follows IEEE 754 conventions
+         - f: means finite values only (no infinities)
+         - n: means nans are supported (non-standard encoding)
+         for integer types the scheme is:
+         `[u]int<size_bits>[b<bias>]`
+         - if bias is not present it means its zero
+         """
+         if self.is_floating_point():
+             ret = (
+                 "float"
+                 + str(self.size_bits)
+                 + "_e"
+                 + str(self.exponent)
+                 + "m"
+                 + str(self.mantissa)
+             )
+
+             if not self.is_ieee_754():
+                 if self._finite_values_only:
+                     ret = ret + "f"
+                 if self.nan_repr != NanRepr.NONE:
+                     ret = ret + "n"
+
+             return ret
+         else:
+             ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
+             if self.has_bias():
+                 ret = ret + "b" + str(self.bias)
+             return ret
+
+     def __repr__(self) -> str:
+         return "ScalarType." + self.__str__()
+
+     # __len__ needs to be defined (and has to throw TypeError) for pytorch's
+     # opcheck to work.
+     def __len__(self) -> int:
+         raise TypeError
+
+     #
+     # Convenience Constructors
+     #
+
+     @classmethod
+     def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
+         "Create a signed integer scalar type (size_bits includes sign-bit)."
+         ret = cls(0, size_bits - 1, True, bias if bias else 0)
+         ret.id # noqa B018: make sure the id is cached
+         return ret
+
+     @classmethod
+     def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
+         """Create a unsigned integer scalar type."""
+         ret = cls(0, size_bits, False, bias if bias else 0)
+         ret.id # noqa B018: make sure the id is cached
+         return ret
+
+     @classmethod
+     def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType":
+         """
+         Create a standard floating point type
+         (i.e. follows IEEE 754 conventions).
+         """
+         assert mantissa > 0 and exponent > 0
+         ret = cls(exponent, mantissa, True, 0)
+         ret.id # noqa B018: make sure the id is cached
+         return ret
+
+     @classmethod
+     def float_(
+         cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr
+     ) -> "ScalarType":
+         """
+         Create a non-standard floating point type
+         (i.e. does not follow IEEE 754 conventions).
+         """
+         assert mantissa > 0 and exponent > 0
+         assert nan_repr != NanRepr.IEEE_754, (
+             "use `float_IEEE754` constructor for floating point types that "
+             "follow IEEE 754 conventions"
+         )
+         ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
+         ret.id # noqa B018: make sure the id is cached
+         return ret
+
+     @classmethod
+     def from_id(cls, scalar_type_id: int):
+         if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
+             raise ValueError(f"scalar_type_id {scalar_type_id} doesn't exists.")
+         return _SCALAR_TYPES_ID_MAP[scalar_type_id]
+
+
+ # naming generally follows: https://github.com/jax-ml/ml_dtypes
+ # for floating point types (leading f) the scheme is:
+ # `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
+ # flags:
+ # - no-flags: means it follows IEEE 754 conventions
+ # - f: means finite values only (no infinities)
+ # - n: means nans are supported (non-standard encoding)
+ # for integer types the scheme is:
+ # `[u]int<size_bits>[b<bias>]`
+ # - if bias is not present it means its zero
+
+
+ class scalar_types:
+     int4 = ScalarType.int_(4, None)
+     uint4 = ScalarType.uint(4, None)
+     int8 = ScalarType.int_(8, None)
+     uint8 = ScalarType.uint(8, None)
+     float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
+     float8_e5m2 = ScalarType.float_IEEE754(5, 2)
+     float16_e8m7 = ScalarType.float_IEEE754(8, 7)
+     float16_e5m10 = ScalarType.float_IEEE754(5, 10)
+
+     # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
+     float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
+
+     # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+     float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)
+
+     # "gptq" types
+     uint2b2 = ScalarType.uint(2, 2)
+     uint3b4 = ScalarType.uint(3, 4)
+     uint4b8 = ScalarType.uint(4, 8)
+     uint8b128 = ScalarType.uint(8, 128)
+
+     # colloquial names
+     bfloat16 = float16_e8m7
+     float16 = float16_e5m10
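
As a quick sanity check of the encoding defined above (an illustrative sketch assuming this wheel is installed, not code from the diff): the GPTQ-style `uint4b8` type stores 4 unsigned bits with a bias of 8, so its de-biased range is [-8, 7], while `float4_e2m1f` packs 1 sign, 2 exponent, and 1 mantissa bit with finite values only.

    from sglang.srt.layers.quantization.scalar_type import scalar_types

    t = scalar_types.uint4b8                    # ScalarType.uint(4, 8)
    assert t.size_bits == 4
    assert (t.min(), t.max()) == (-8, 7)        # raw 0..15 shifted by bias 8
    assert str(t) == "uint4b8"

    f4 = scalar_types.float4_e2m1f              # OCP MX FP4: e2m1, finite values only
    assert f4.size_bits == 4 and f4.is_floating_point()
    assert not f4.has_infs()
    assert str(f4) == "float4_e2m1f"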