hypertensor 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hypertensor-1.0.0/HyperTensor/__init__.py +17 -0
- hypertensor-1.0.0/HyperTensor/block_utils.py +8 -0
- hypertensor-1.0.0/HyperTensor/cuda_backend.py +147 -0
- hypertensor-1.0.0/HyperTensor/formats.py +118 -0
- hypertensor-1.0.0/HyperTensor/hyper_attention.py +115 -0
- hypertensor-1.0.0/HyperTensor/modules.py +464 -0
- hypertensor-1.0.0/HyperTensor/quantize.py +325 -0
- hypertensor-1.0.0/HyperTensor/serialize.py +39 -0
- hypertensor-1.0.0/HyperTensor/weight_options.py +20 -0
- hypertensor-1.0.0/HyperTensor.egg-info/PKG-INFO +208 -0
- hypertensor-1.0.0/HyperTensor.egg-info/SOURCES.txt +32 -0
- hypertensor-1.0.0/HyperTensor.egg-info/dependency_links.txt +1 -0
- hypertensor-1.0.0/HyperTensor.egg-info/requires.txt +4 -0
- hypertensor-1.0.0/HyperTensor.egg-info/top_level.txt +1 -0
- hypertensor-1.0.0/LICENSE +21 -0
- hypertensor-1.0.0/PKG-INFO +208 -0
- hypertensor-1.0.0/README.md +196 -0
- hypertensor-1.0.0/pyproject.toml +29 -0
- hypertensor-1.0.0/setup.cfg +4 -0
- hypertensor-1.0.0/setup.py +10 -0
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""Public API surface of the HyperTensor package.

Re-exports the precision-format enum, weight options, the model scanner,
the drop-in modules, and the (de)serialization helpers so users can work
from ``import HyperTensor`` directly.
"""

from .formats import PrecisionFormat
from .weight_options import WeightOption
from .quantize import scan_model
from .modules import HyperLinear, wrap_model_with_hypertensor
from .hyper_attention import HyperAttention
from .serialize import save_hypertensors, load_hypertensors

# Names exported via ``from HyperTensor import *``.
__all__ = [
    "PrecisionFormat",
    "WeightOption",
    "scan_model",
    "HyperLinear",
    "HyperAttention",
    "wrap_model_with_hypertensor",
    "save_hypertensors",
    "load_hypertensors",
]
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def compute_block_grid(shape: Tuple[int, int], block_size: int) -> Tuple[int, int]:
    """Return the number of quantization blocks along each weight dimension.

    Args:
        shape: ``(out_features, in_features)`` of a 2-D weight matrix.
        block_size: Side length of one block; must be positive.

    Returns:
        ``(n_blocks_out, n_blocks_in)`` — block counts per dimension, with a
        trailing partial block counted as a full one (ceiling division).

    Raises:
        ValueError: If ``block_size`` is not positive (the original code
            would raise ``ZeroDivisionError`` or return negative counts).
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    out_features, in_features = shape
    # Ceiling division so a remainder still gets its own block.
    n_blocks_out = (out_features + block_size - 1) // block_size
    n_blocks_in = (in_features + block_size - 1) // block_size
    return n_blocks_out, n_blocks_in
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
import os
import torch
from torch.utils.cpp_extension import load

# Directory holding the CUDA/C++ kernel sources that ship with the package.
_src_dir = os.path.dirname(__file__)
_ext_name = "hypertensor_kernels"

# JIT-compile and load the custom GEMM kernels at import time.
# NOTE(review): this blocks the import for the duration of an nvcc build and
# requires a CUDA toolchain at runtime; consider lazy / failure-tolerant
# loading so CPU-only installs can still import the package.
_hypertensor_ext = load(
    name=_ext_name,
    sources=[
        os.path.join(_src_dir, "kernels.cu"),
        os.path.join(_src_dir, "binding.cpp"),
    ],
    verbose=False,
    extra_cflags=[
        "-O3",
        "-DTORCH_DISABLE_DYNAMO",
    ],
    extra_cuda_cflags=[
        "-O3",
        "--use_fast_math",
        "-DTORCH_DISABLE_DYNAMO",
        # Re-enable the half/bfloat16 operators that PyTorch's headers
        # disable by default, so the kernels can use them directly.
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        # NOTE(review): hard-coded for sm_89 (Ada / RTX 40xx) only — GPUs of
        # other compute capabilities get no compatible cubin. Confirm whether
        # this should be derived from the local device instead.
        "-gencode=arch=compute_89,code=sm_89",
    ],
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _ensure_same_device(a: torch.Tensor, b: torch.Tensor):
|
|
33
|
+
if a.device != b.device:
|
|
34
|
+
b = b.to(a.device)
|
|
35
|
+
return a, b
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# ---------- FP32 / FP16 / BF16: let cuBLAS do its job ----------
|
|
39
|
+
|
|
40
|
+
def matmul_fp32(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """FP32 GEMM — delegate straight to torch.matmul (cuBLAS on GPU)."""
    if a.device != b.device:
        b = b.to(a.device)
    return a.matmul(b)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def matmul_fp16(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """FP16 GEMM — delegate straight to torch.matmul (cuBLAS on GPU)."""
    if a.device != b.device:
        b = b.to(a.device)
    return a.matmul(b)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def matmul_bf16(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """BF16 GEMM — delegate straight to torch.matmul (cuBLAS on GPU)."""
    if a.device != b.device:
        b = b.to(a.device)
    return a.matmul(b)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# ---------- INT8 / INT4 weight GEMM (for HyperLinear) ----------
|
|
56
|
+
|
|
57
|
+
def matmul_int8(a: torch.Tensor, w_int8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Activation x int8-weight GEMM.

    Args:
        a: Activations of shape ``(..., K)``.
        w_int8: Quantized weights of shape ``(K, N)``, dtype ``torch.int8``.
        scale: Dequantization scale — a scalar, or a per-output-column
            tensor of length ``N`` (must be float32 for the CUDA kernel).

    Returns:
        ``a @ dequantize(w_int8)``.  On CUDA with int8 weights and fp32
        scales the custom tensor-core kernel is used; otherwise the weights
        are dequantized to float32 and cuBLAS/ATen does the GEMM.
    """
    if w_int8.device != a.device:
        w_int8 = w_int8.to(a.device)
    scale = scale.to(a.device)

    if (
        a.is_cuda
        and w_int8.is_cuda
        and scale.is_cuda
        and w_int8.dtype == torch.int8
        and scale.dtype == torch.float32
    ):
        # The tensor-core kernel consumes fp16 activations.
        if a.dtype != torch.float16:
            a = a.to(torch.float16)
        return _hypertensor_ext.gemm_int8_fp16_tc(a, w_int8, scale)

    # Fallback path (CPU / mismatched dtypes): dequantize and GEMM in fp32.
    # Fix: the fp16 cast now happens only on the kernel path — previously the
    # activations were squeezed through fp16 here too, discarding precision
    # for no benefit.
    w_f = w_int8.to(torch.float32)
    if scale.numel() == 1:
        w_f = w_f * scale.view(1, 1)
    else:
        # Per-output-column scales broadcast across rows.
        w_f = w_f * scale.view(1, -1)

    return torch.matmul(a.to(torch.float32), w_f)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def matmul_int4(a: torch.Tensor,
                w_int4_packed: torch.Tensor,
                scale: torch.Tensor,
                out_features: int) -> torch.Tensor:
    """Activation x packed-int4-weight GEMM.

    Args:
        a: Activations of shape ``(..., K)``.
        w_int4_packed: ``torch.int8`` tensor of shape ``(K, ceil(N/2))``
            holding two signed 4-bit values per byte — column ``n`` lives in
            the low nibble of byte ``n // 2`` when ``n`` is even, in the high
            nibble when odd.
        scale: Scalar or per-output-column dequantization scale.
        out_features: ``N``, the unpacked number of output columns.

    Returns:
        ``a @ dequantize(w_int4_packed)``.  On CUDA the custom tensor-core
        kernel is used; otherwise the weights are unpacked to float32.
    """
    if w_int4_packed.device != a.device:
        w_int4_packed = w_int4_packed.to(a.device)
    scale = scale.to(a.device)

    if (
        a.is_cuda
        and w_int4_packed.is_cuda
        and scale.is_cuda
        and w_int4_packed.dtype == torch.int8
        and scale.dtype == torch.float32
    ):
        # The tensor-core kernel consumes fp16 activations.
        if a.dtype != torch.float16:
            a = a.to(torch.float16)
        return _hypertensor_ext.gemm_int4_fp16_tc(a, w_int4_packed, scale)

    # Fallback: vectorized nibble unpacking.  Fixes two defects in the old
    # per-element Python loop: (1) it was O(K*N) `.item()` round trips, and
    # (2) `torch.tensor(v | 0xF0, dtype=torch.int8)` overflows for negative
    # nibbles (values 0x80..0xFF cannot be built as int8 from a Python int).
    K = w_int4_packed.size(0)
    N = out_features

    low = (w_int4_packed & 0x0F).to(torch.float32)
    high = ((w_int4_packed >> 4) & 0x0F).to(torch.float32)
    # Sign-extend 4-bit two's-complement nibbles into [-8, 7].
    low = torch.where(low >= 8, low - 16, low)
    high = torch.where(high >= 8, high - 16, high)

    # Interleave low/high nibbles back into column order, then trim padding.
    w_full = torch.empty((K, 2 * w_int4_packed.size(1)),
                         dtype=torch.float32, device=a.device)
    w_full[:, 0::2] = low
    w_full[:, 1::2] = high
    w_full = w_full[:, :N]

    if scale.numel() == 1:
        w_full = w_full * scale.view(1, 1)
    else:
        # Per-output-column scales broadcast across rows.
        w_full = w_full * scale.view(1, -1)

    return torch.matmul(a.to(torch.float32), w_full)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def matmul_bw_dx_fp16(grad_out: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Backward dL/dx GEMM (fp16) via the custom CUDA kernel."""
    if grad_out.device != w.device:
        w = w.to(grad_out.device)
    return _hypertensor_ext.gemm_bw_dx_fp16(grad_out, w)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def matmul_bw_dw_fp16(grad_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Backward dL/dW GEMM (fp16) via the custom CUDA kernel."""
    if grad_out.device != x.device:
        x = x.to(grad_out.device)
    return _hypertensor_ext.gemm_bw_dw_fp16(grad_out, x)
|
|
139
|
+
|
|
140
|
+
def matmul_bw_dx_bf16(grad_out: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Backward dL/dx GEMM (bf16) via the custom CUDA kernel."""
    if grad_out.device != w.device:
        w = w.to(grad_out.device)
    return _hypertensor_ext.gemm_bw_dx_bf16(grad_out, w)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def matmul_bw_dw_bf16(grad_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Backward dL/dW GEMM (bf16) via the custom CUDA kernel."""
    if grad_out.device != x.device:
        x = x.to(grad_out.device)
    return _hypertensor_ext.gemm_bw_dw_bf16(grad_out, x)
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
from enum import IntEnum
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class PrecisionFormat(IntEnum):
    """Numeric storage-format codes used in per-weight format maps.

    The integer values are serialized into int8 tensors elsewhere in the
    package, so they must not be renumbered.  The gaps at 3 and 4 are
    presumably reserved for future formats — confirm before reusing them.
    """

    INT4 = 1
    INT8 = 2
    FP16 = 5
    BF16 = 6
    FP32 = 7
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Formats listed from narrowest (fewest bits) to widest.  NOTE(review):
# presumably consulted by callers that search for the cheapest format that
# still fits a weight — confirm at the use sites.
PRECISION_ORDER = [
    PrecisionFormat.INT4,
    PrecisionFormat.INT8,
    PrecisionFormat.FP16,
    PrecisionFormat.BF16,
    PrecisionFormat.FP32,
]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _can_be_int4(x: torch.Tensor) -> torch.Tensor:
|
|
23
|
+
return (x == x.round()) & (x >= -8) & (x <= 7)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _can_be_int8(x: torch.Tensor) -> torch.Tensor:
|
|
27
|
+
return (x == x.round()) & (x >= -128) & (x <= 127)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _roundtrip_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
31
|
+
return (x - y).abs()
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _can_be_fp16(x: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
|
|
35
|
+
x16 = x.to(torch.float16)
|
|
36
|
+
x32 = x16.to(torch.float32)
|
|
37
|
+
if eps == 0.0:
|
|
38
|
+
return x32 == x
|
|
39
|
+
else:
|
|
40
|
+
return _roundtrip_error(x, x32) <= eps
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _can_be_bf16(x: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
|
|
44
|
+
x16 = x.to(torch.bfloat16)
|
|
45
|
+
x32 = x16.to(torch.float32)
|
|
46
|
+
if eps == 0.0:
|
|
47
|
+
return x32 == x
|
|
48
|
+
else:
|
|
49
|
+
return _roundtrip_error(x, x32) <= eps
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _float_cast_error(w32: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Absolute error of a float32 -> ``dtype`` -> float32 round trip."""
    return (w32 - w32.to(dtype).to(torch.float32)).abs()


def choose_format_for_tensor(
    t: torch.Tensor,
    eps_int4: float = 1e-1,
    eps_int8: float = 5e-2,
    eps_fp16: float = 1e-3,
    eps_bf16: float = 1e-3,
    prefer_bf16: bool = False,
) -> torch.Tensor:
    """Per-weight format choice based on approximate fit.

    - INT4: allowed if quantization error <= eps_int4
    - INT8: if INT4 too inaccurate, but error <= eps_int8
    - FP16/BF16: if INT8 too inaccurate, but error <= eps_fp16/eps_bf16
      (``prefer_bf16`` decides which of the two is tried first)
    - FP32: fallback if nothing else fits well

    Returns an int8 tensor of ``PrecisionFormat`` codes, same shape as ``t``.
    """
    w32 = t.to(torch.float32)
    # Default every element to FP32; cheaper formats overwrite below.
    fmt = torch.full_like(w32, fill_value=int(PrecisionFormat.FP32), dtype=torch.int8)

    # ----- INT4 -----
    err4 = (w32 - w32.round().clamp(-8, 7)).abs()
    mask_int4 = err4 <= eps_int4
    fmt[mask_int4] = int(PrecisionFormat.INT4)

    # ----- INT8 (only where not INT4) -----
    err8 = (w32 - w32.round().clamp(-128, 127)).abs()
    mask_int8 = (err8 <= eps_int8) & ~mask_int4
    fmt[mask_int8] = int(PrecisionFormat.INT8)

    # ----- FP16 / BF16 (only where neither integer format fits) -----
    # Fix: the two float branches were duplicated copy-paste; a single loop
    # over the preference-ordered candidates keeps them in sync.
    remaining = ~mask_int4 & ~mask_int8
    if prefer_bf16:
        candidates = [
            (torch.bfloat16, eps_bf16, PrecisionFormat.BF16),
            (torch.float16, eps_fp16, PrecisionFormat.FP16),
        ]
    else:
        candidates = [
            (torch.float16, eps_fp16, PrecisionFormat.FP16),
            (torch.bfloat16, eps_bf16, PrecisionFormat.BF16),
        ]
    for dtype, eps, code in candidates:
        mask = (_float_cast_error(w32, dtype) <= eps) & remaining
        fmt[mask] = int(code)
        remaining = remaining & ~mask

    # Anything that didn't fit any of the above keeps the FP32 default.
    # (The previous trailing ``.to(torch.int8)`` was a no-op — fmt is int8.)
    return fmt
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
# HyperTensor/hyper_attention.py
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from typing import Optional, Tuple
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
|
|
9
|
+
from .modules import HyperLinear
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class HyperAttention(nn.Module):
    """Multi-head attention whose q/k/v/o projections are HyperLinear layers.

    Supports grouped-query attention via ``num_kv_heads``: K/V are projected
    with fewer heads and repeat-interleaved up to ``num_heads`` before the
    score computation.

    NOTE(review): ``rope_theta`` is stored but never used in this class — no
    rotary embedding is applied to q/k here.  Confirm whether RoPE is meant
    to happen elsewhere or is missing.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: Optional[int] = None,
        bias: bool = False,
        rope_theta: float = 10000.0,
        prefer_bf16: bool = False,
        inference_only: bool = False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        # Default to standard multi-head attention when no KV-head count given.
        self.num_kv_heads = num_kv_heads or num_heads
        # NOTE(review): assumes hidden_size is divisible by num_heads.
        self.head_dim = hidden_size // num_heads
        self.rope_theta = rope_theta

        self.q_proj = HyperLinear(
            hidden_size,
            num_heads * self.head_dim,
            bias=bias,
            prefer_bf16=prefer_bf16,
            inference_only=inference_only,
        )
        self.k_proj = HyperLinear(
            hidden_size,
            self.num_kv_heads * self.head_dim,
            bias=bias,
            prefer_bf16=prefer_bf16,
            inference_only=inference_only,
        )
        self.v_proj = HyperLinear(
            hidden_size,
            self.num_kv_heads * self.head_dim,
            bias=bias,
            prefer_bf16=prefer_bf16,
            inference_only=inference_only,
        )
        self.o_proj = HyperLinear(
            num_heads * self.head_dim,
            hidden_size,
            bias=bias,
            prefer_bf16=prefer_bf16,
            inference_only=inference_only,
        )

        self.prefer_bf16 = prefer_bf16
        self.inference_only = inference_only

    def _shape(self, x: torch.Tensor, bsz: int, n_heads: int) -> torch.Tensor:
        # [B, T, n_heads*D] -> [B, n_heads, T, D]
        return x.view(bsz, -1, n_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """Run attention over ``hidden_states`` of shape [B, T, hidden_size].

        Returns ``(attn_output, present)`` — or, when ``output_attentions``
        is set, ``(attn_output, present, attn_weights)``.  ``present`` is the
        ``(k, v)`` cache tuple when ``use_cache`` else ``None``.

        NOTE(review): ``position_ids`` is accepted but unused here.
        """
        bsz, q_len, _ = hidden_states.size()

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        q = self._shape(q, bsz, self.num_heads)  # [B,H,T,D]
        k = self._shape(k, bsz, self.num_kv_heads)
        v = self._shape(v, bsz, self.num_kv_heads)

        # Prepend cached keys/values along the sequence axis (dim 2).
        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)

        # Cache is captured before the GQA repeat, so it stays compact
        # (num_kv_heads, not num_heads).
        present = (k, v) if use_cache else None

        if self.num_kv_heads != self.num_heads:
            # Grouped-query attention: duplicate each KV head across its group.
            # NOTE(review): assumes num_heads is divisible by num_kv_heads.
            repeat_factor = self.num_heads // self.num_kv_heads
            k = k.repeat_interleave(repeat_factor, dim=1)
            v = v.repeat_interleave(repeat_factor, dim=1)

        # All score math runs in half precision (bf16 when preferred).
        dtype = torch.bfloat16 if self.prefer_bf16 else torch.float16
        q = q.to(dtype)
        k = k.to(dtype)
        v = v.to(dtype)

        # Naive scaled-dot-product attention: materializes the full
        # [B, H, T, T] score matrix; the matmuls dispatch to cuBLAS on GPU.
        attn_scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)

        # Additive mask (e.g. -inf at disallowed positions).
        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        # [B, H, T, D] -> [B, T, H*D], then project back to hidden_size.
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if output_attentions:
            return attn_output, present, attn_weights
        return attn_output, present
|