hypertensor 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,17 @@
1
from .formats import PrecisionFormat
from .weight_options import WeightOption
from .quantize import scan_model
from .modules import HyperLinear, wrap_model_with_hypertensor
from .hyper_attention import HyperAttention
from .serialize import save_hypertensors, load_hypertensors

# Public API of the hypertensor package; keep this list in sync with the
# imports above so `from hypertensor import *` exposes exactly these names.
__all__ = [
    "PrecisionFormat",
    "WeightOption",
    "scan_model",
    "HyperLinear",
    "HyperAttention",
    "wrap_model_with_hypertensor",
    "save_hypertensors",
    "load_hypertensors",
]
@@ -0,0 +1,8 @@
1
+ from typing import Tuple
2
+
3
+
4
def compute_block_grid(shape: Tuple[int, int], block_size: int) -> Tuple[int, int]:
    """Return how many blocks of size ``block_size`` tile a 2-D weight.

    ``shape`` is ``(out_features, in_features)``; each axis count is the
    ceiling of the dimension divided by ``block_size``.
    """
    rows, cols = shape
    # -(-x // b) is ceiling division for positive integers.
    return -(-rows // block_size), -(-cols // block_size)
@@ -0,0 +1,147 @@
1
import os
import torch
from torch.utils.cpp_extension import load

# Directory holding the CUDA/C++ sources shipped alongside this module.
_src_dir = os.path.dirname(__file__)
_ext_name = "hypertensor_kernels"

# JIT-compile the custom GEMM kernels at import time.
# NOTE(review): this runs the compiler on first import and presumably
# requires a CUDA toolchain (nvcc) to be present — confirm import-time
# failure behavior on CPU-only machines.
_hypertensor_ext = load(
    name=_ext_name,
    sources=[
        os.path.join(_src_dir, "kernels.cu"),
        os.path.join(_src_dir, "binding.cpp"),
    ],
    verbose=False,
    extra_cflags=[
        "-O3",
        "-DTORCH_DISABLE_DYNAMO",
    ],
    extra_cuda_cflags=[
        "-O3",
        "--use_fast_math",  # trades strict IEEE float semantics for speed
        "-DTORCH_DISABLE_DYNAMO",
        # Re-enable half/bfloat16 operators and conversions in device code.
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        # NOTE(review): only sm_89 (Ada) is targeted; GPUs of other compute
        # capabilities get no matching cubin — confirm this is intentional.
        "-gencode=arch=compute_89,code=sm_89",
    ],
)
30
+
31
+
32
+ def _ensure_same_device(a: torch.Tensor, b: torch.Tensor):
33
+ if a.device != b.device:
34
+ b = b.to(a.device)
35
+ return a, b
36
+
37
+
38
+ # ---------- FP32 / FP16 / BF16: let cuBLAS do its job ----------
39
+
40
def matmul_fp32(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """FP32 matrix product; cuBLAS already handles this well, so defer to torch."""
    if b.device != a.device:
        b = b.to(a.device)
    return a @ b
43
+
44
+
45
def matmul_fp16(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """FP16 matrix product; half precision is native in cuBLAS, so defer to torch."""
    if b.device != a.device:
        b = b.to(a.device)
    return a @ b
48
+
49
+
50
def matmul_bf16(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """BF16 matrix product; bfloat16 is native in cuBLAS, so defer to torch."""
    if b.device != a.device:
        b = b.to(a.device)
    return a @ b
53
+
54
+
55
+ # ---------- INT8 / INT4 weight GEMM (for HyperLinear) ----------
56
+
57
def matmul_int8(a: torch.Tensor, w_int8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Multiply activations ``a`` by an INT8 weight matrix.

    On CUDA (int8 weights, fp32 scales) the custom tensor-core kernel is
    used and ``a`` is cast to fp16 as the kernel expects.  Otherwise the
    weights are dequantized to fp32 and a plain matmul is performed.

    Fix: the original cast ``a`` to fp16 unconditionally, so the CPU
    fallback round-tripped activations through half precision for no
    reason, losing accuracy.  The cast now happens only on the kernel path.

    scale: scalar, or per-output-column (one scale per column of w_int8).
    """
    if w_int8.device != a.device:
        w_int8 = w_int8.to(a.device)
    scale = scale.to(a.device)

    use_kernel = (
        a.is_cuda
        and w_int8.is_cuda
        and scale.is_cuda
        and w_int8.dtype == torch.int8
        and scale.dtype == torch.float32
    )
    if use_kernel:
        # The tensor-core kernel consumes fp16 activations.
        if a.dtype != torch.float16:
            a = a.to(torch.float16)
        return _hypertensor_ext.gemm_int8_fp16_tc(a, w_int8, scale)

    # Fallback: dequantize to fp32 and use a regular matmul.
    w_f = w_int8.to(torch.float32)
    if scale.numel() == 1:
        w_f = w_f * scale.view(1, 1)
    else:
        w_f = w_f * scale.view(1, -1)

    return torch.matmul(a.to(torch.float32), w_f)
80
+
81
+
82
def matmul_int4(a: torch.Tensor,
                w_int4_packed: torch.Tensor,
                scale: torch.Tensor,
                out_features: int) -> torch.Tensor:
    """Multiply activations ``a`` by a 4-bit-packed weight matrix.

    ``w_int4_packed`` is ``[K, ceil(N/2)]`` int8, with two signed 4-bit
    values per byte: low nibble = even output column, high nibble = odd
    column.  ``scale`` is a scalar or per-output-column fp32 scale, and
    ``out_features`` (N) truncates the final padded column when N is odd.

    Fixes vs. the original:
      * the fallback unpacked weights with a per-element Python loop calling
        ``.item()`` (O(K*N) interpreter round trips) — now fully vectorized;
      * ``a`` was cast to fp16 even on the fallback path, needlessly losing
        precision before the fp32 matmul — the cast is now kernel-path only.
    """
    if w_int4_packed.device != a.device:
        w_int4_packed = w_int4_packed.to(a.device)
    scale = scale.to(a.device)

    if (
        a.is_cuda
        and w_int4_packed.is_cuda
        and scale.is_cuda
        and w_int4_packed.dtype == torch.int8
        and scale.dtype == torch.float32
    ):
        # The tensor-core kernel consumes fp16 activations.
        if a.dtype != torch.float16:
            a = a.to(torch.float16)
        return _hypertensor_ext.gemm_int4_fp16_tc(a, w_int4_packed, scale)

    # ---- vectorized fallback unpack ----
    K, n_packed = w_int4_packed.shape
    # Widen to int16 and mask to get an unsigned byte view (int8 is signed).
    bytes16 = w_int4_packed.to(torch.int16) & 0xFF
    low = bytes16 & 0x0F
    high = (bytes16 >> 4) & 0x0F
    # Arithmetic sign-extension of the 4-bit values: [8, 15] -> [-8, -1].
    low = torch.where(low >= 8, low - 16, low)
    high = torch.where(high >= 8, high - 16, high)

    # Interleave nibbles back into column order, then drop any pad column.
    w_full = torch.empty((K, 2 * n_packed), dtype=torch.float32, device=a.device)
    w_full[:, 0::2] = low.to(torch.float32)
    w_full[:, 1::2] = high.to(torch.float32)
    w_full = w_full[:, :out_features]

    if scale.numel() == 1:
        w_full = w_full * scale.view(1, 1)
    else:
        w_full = w_full * scale.view(1, -1)

    return torch.matmul(a.to(torch.float32), w_full)
128
+
129
+
130
+
131
def matmul_bw_dx_fp16(grad_out: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Backward input-gradient GEMM in FP16 via the custom kernel.

    NOTE(review): exact layout/transpose semantics live in the
    ``gemm_bw_dx_fp16`` binding, not visible here — confirm before reuse.
    """
    g, wt = _ensure_same_device(grad_out, w)
    return _hypertensor_ext.gemm_bw_dx_fp16(g, wt)
134
+
135
+
136
def matmul_bw_dw_fp16(grad_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Backward weight-gradient GEMM in FP16 via the custom kernel.

    NOTE(review): exact layout/transpose semantics live in the
    ``gemm_bw_dw_fp16`` binding, not visible here — confirm before reuse.
    """
    g, act = _ensure_same_device(grad_out, x)
    return _hypertensor_ext.gemm_bw_dw_fp16(g, act)
139
+
140
def matmul_bw_dx_bf16(grad_out: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Backward input-gradient GEMM in BF16 via the custom kernel.

    NOTE(review): exact layout/transpose semantics live in the
    ``gemm_bw_dx_bf16`` binding, not visible here — confirm before reuse.
    """
    g, wt = _ensure_same_device(grad_out, w)
    return _hypertensor_ext.gemm_bw_dx_bf16(g, wt)
143
+
144
+
145
def matmul_bw_dw_bf16(grad_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Backward weight-gradient GEMM in BF16 via the custom kernel.

    NOTE(review): exact layout/transpose semantics live in the
    ``gemm_bw_dw_bf16`` binding, not visible here — confirm before reuse.
    """
    g, act = _ensure_same_device(grad_out, x)
    return _hypertensor_ext.gemm_bw_dw_bf16(g, act)
@@ -0,0 +1,118 @@
1
+ from enum import IntEnum
2
+ import torch
3
+
4
+
5
class PrecisionFormat(IntEnum):
    """Integer codes for the storage format of a weight element.

    Lower codes are the more compact quantized formats.
    NOTE(review): codes 3 and 4 are unassigned — presumably reserved for
    future formats; confirm before reusing those values.
    """

    INT4 = 1
    INT8 = 2
    FP16 = 5
    BF16 = 6
    FP32 = 7
11
+
12
+
13
# Candidate formats ordered from most compact (INT4) to least (FP32).
PRECISION_ORDER = [
    PrecisionFormat.INT4,
    PrecisionFormat.INT8,
    PrecisionFormat.FP16,
    PrecisionFormat.BF16,
    PrecisionFormat.FP32,
]
20
+
21
+
22
+ def _can_be_int4(x: torch.Tensor) -> torch.Tensor:
23
+ return (x == x.round()) & (x >= -8) & (x <= 7)
24
+
25
+
26
+ def _can_be_int8(x: torch.Tensor) -> torch.Tensor:
27
+ return (x == x.round()) & (x >= -128) & (x <= 127)
28
+
29
+
30
+ def _roundtrip_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
31
+ return (x - y).abs()
32
+
33
+
34
+ def _can_be_fp16(x: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
35
+ x16 = x.to(torch.float16)
36
+ x32 = x16.to(torch.float32)
37
+ if eps == 0.0:
38
+ return x32 == x
39
+ else:
40
+ return _roundtrip_error(x, x32) <= eps
41
+
42
+
43
+ def _can_be_bf16(x: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
44
+ x16 = x.to(torch.bfloat16)
45
+ x32 = x16.to(torch.float32)
46
+ if eps == 0.0:
47
+ return x32 == x
48
+ else:
49
+ return _roundtrip_error(x, x32) <= eps
50
+
51
+
52
def choose_format_for_tensor(
    t: torch.Tensor,
    eps_int4: float = 1e-1,
    eps_int8: float = 5e-2,
    eps_fp16: float = 1e-3,
    eps_bf16: float = 1e-3,
    prefer_bf16: bool = False,
) -> torch.Tensor:
    """
    Per-weight format choice based on approximate fit.

    Formats are tried from most to least compact; each element receives the
    first format whose round-trip error is within the matching eps:
      - INT4: round-to-integer error <= eps_int4 (clamped to [-8, 7])
      - INT8: if INT4 too inaccurate, error <= eps_int8 (clamped to [-128, 127])
      - FP16/BF16: if INT8 too inaccurate, float round-trip error within
        eps_fp16/eps_bf16; ``prefer_bf16`` decides which is tried first
      - FP32: fallback if nothing else fits well

    Returns an int8 tensor of PrecisionFormat codes with the same shape as t.

    Fix: the FP16/BF16 branches were duplicated code differing only in order;
    they are consolidated into a single ordered candidate loop.  The redundant
    final ``.to(torch.int8)`` (fmt is created as int8) was dropped.
    """
    w32 = t.to(torch.float32)
    # Everything defaults to FP32; narrower formats overwrite below.
    fmt = torch.full_like(w32, fill_value=int(PrecisionFormat.FP32), dtype=torch.int8)

    def _float_roundtrip_err(dtype: torch.dtype) -> torch.Tensor:
        # Absolute error of a float32 -> dtype -> float32 round trip.
        return (w32 - w32.to(dtype).to(torch.float32)).abs()

    # ----- INT4 -----
    err4 = (w32 - w32.round().clamp(-8, 7)).abs()
    mask_int4 = err4 <= eps_int4
    fmt[mask_int4] = int(PrecisionFormat.INT4)

    # ----- INT8 (only where not INT4) -----
    err8 = (w32 - w32.round().clamp(-128, 127)).abs()
    mask_int8 = (err8 <= eps_int8) & ~mask_int4
    fmt[mask_int8] = int(PrecisionFormat.INT8)

    # ----- FP16 / BF16 (only where no integer format fit) -----
    remaining = ~mask_int4 & ~mask_int8
    if prefer_bf16:
        candidates = [
            (torch.bfloat16, eps_bf16, PrecisionFormat.BF16),
            (torch.float16, eps_fp16, PrecisionFormat.FP16),
        ]
    else:
        candidates = [
            (torch.float16, eps_fp16, PrecisionFormat.FP16),
            (torch.bfloat16, eps_bf16, PrecisionFormat.BF16),
        ]
    for dtype, eps, code in candidates:
        mask = (_float_roundtrip_err(dtype) <= eps) & remaining
        fmt[mask] = int(code)
        remaining = remaining & ~mask

    # Anything that didn't fit any narrower format stays FP32.
    return fmt
@@ -0,0 +1,115 @@
1
+ # HyperTensor/hyper_attention.py
2
+
3
+ import math
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from .modules import HyperLinear
10
+
11
+
12
class HyperAttention(nn.Module):
    """Multi-head self-attention built from HyperLinear Q/K/V/O projections.

    Supports grouped-query attention (``num_kv_heads`` < ``num_heads``) and an
    incremental key/value cache via ``past_key_value`` / ``use_cache``.

    NOTE(review): ``rope_theta`` and ``position_ids`` are accepted but no
    rotary position embedding is applied anywhere in ``forward`` — confirm
    whether RoPE was meant to be wired in.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: Optional[int] = None,
        bias: bool = False,
        rope_theta: float = 10000.0,
        prefer_bf16: bool = False,
        inference_only: bool = False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        # GQA: default to one KV head per query head (plain multi-head attention).
        self.num_kv_heads = num_kv_heads or num_heads
        # Assumes hidden_size is divisible by num_heads — TODO confirm callers validate.
        self.head_dim = hidden_size // num_heads
        self.rope_theta = rope_theta

        self.q_proj = HyperLinear(
            hidden_size,
            num_heads * self.head_dim,
            bias=bias,
            prefer_bf16=prefer_bf16,
            inference_only=inference_only,
        )
        self.k_proj = HyperLinear(
            hidden_size,
            self.num_kv_heads * self.head_dim,
            bias=bias,
            prefer_bf16=prefer_bf16,
            inference_only=inference_only,
        )
        self.v_proj = HyperLinear(
            hidden_size,
            self.num_kv_heads * self.head_dim,
            bias=bias,
            prefer_bf16=prefer_bf16,
            inference_only=inference_only,
        )
        self.o_proj = HyperLinear(
            num_heads * self.head_dim,
            hidden_size,
            bias=bias,
            prefer_bf16=prefer_bf16,
            inference_only=inference_only,
        )

        self.prefer_bf16 = prefer_bf16
        self.inference_only = inference_only

    def _shape(self, x: torch.Tensor, bsz: int, n_heads: int) -> torch.Tensor:
        # [B, T, n_heads * head_dim] -> [B, n_heads, T, head_dim]
        return x.view(bsz, -1, n_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """Attend over ``hidden_states`` of shape [B, T, hidden_size].

        Returns ``(attn_output, present)``, or
        ``(attn_output, present, attn_weights)`` when ``output_attentions``
        is set.  ``present`` is the updated (k, v) cache when ``use_cache``
        else None.

        NOTE(review): ``position_ids`` and extra ``**kwargs`` are ignored.
        """
        bsz, q_len, _ = hidden_states.size()

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        q = self._shape(q, bsz, self.num_heads)  # [B,H,T,D]
        k = self._shape(k, bsz, self.num_kv_heads)
        v = self._shape(v, bsz, self.num_kv_heads)

        # Append cached keys/values along the sequence axis.
        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)

        present = (k, v) if use_cache else None

        # GQA: replicate each KV head so K/V match the query head count.
        # Assumes num_heads is divisible by num_kv_heads — TODO confirm.
        if self.num_kv_heads != self.num_heads:
            repeat_factor = self.num_heads // self.num_kv_heads
            k = k.repeat_interleave(repeat_factor, dim=1)
            v = v.repeat_interleave(repeat_factor, dim=1)

        # Attention math runs in half precision regardless of the input dtype.
        dtype = torch.bfloat16 if self.prefer_bf16 else torch.float16
        q = q.to(dtype)
        k = k.to(dtype)
        v = v.to(dtype)

        # Naive attention: materializes the full [B, H, T, T] score matrix
        # with plain matmuls + softmax (no fused/flash kernel is used here).
        attn_scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)

        # Additive mask (e.g. causal/padding mask with large negatives at
        # masked positions).
        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        # [B, H, T, D] -> [B, T, H*D], then the output projection.
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if output_attentions:
            return attn_output, present, attn_weights
        return attn_output, present