khalgebra 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- khalgebra-0.1.0/.gitignore +8 -0
- khalgebra-0.1.0/PKG-INFO +134 -0
- khalgebra-0.1.0/README.md +118 -0
- khalgebra-0.1.0/bench/bench_dsymm.py +40 -0
- khalgebra-0.1.0/bench/bench_dsymv.py +50 -0
- khalgebra-0.1.0/bench/bench_riemann.py +45 -0
- khalgebra-0.1.0/khalgebra/__init__.py +36 -0
- khalgebra-0.1.0/khalgebra/_types.py +77 -0
- khalgebra-0.1.0/khalgebra/dsymm.py +55 -0
- khalgebra-0.1.0/khalgebra/dsymv.py +56 -0
- khalgebra-0.1.0/khalgebra/riemann.py +124 -0
- khalgebra-0.1.0/khalgebra/sym22.py +62 -0
- khalgebra-0.1.0/pyproject.toml +20 -0
- khalgebra-0.1.0/tests/test_dsymm.py +16 -0
- khalgebra-0.1.0/tests/test_dsymv.py +27 -0
- khalgebra-0.1.0/tests/test_riemann.py +47 -0
- khalgebra-0.1.0/tests/test_sym22.py +22 -0
khalgebra-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: khalgebra
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Provably-optimal bilinear algorithms for symmetric linear algebra
|
|
5
|
+
Project-URL: Homepage, https://github.com/khalil/khalgebra
|
|
6
|
+
Author: Mahmood Khalil
|
|
7
|
+
License: MIT
|
|
8
|
+
Keywords: bilinear,blas,dsymv,jax,linear-algebra,optimal,riemann
|
|
9
|
+
Requires-Python: >=3.10
|
|
10
|
+
Requires-Dist: jax>=0.4
|
|
11
|
+
Requires-Dist: jaxlib>=0.4
|
|
12
|
+
Provides-Extra: dev
|
|
13
|
+
Requires-Dist: numpy; extra == 'dev'
|
|
14
|
+
Requires-Dist: pytest>=8; extra == 'dev'
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
|
|
17
|
+
# khalgebra
|
|
18
|
+
|
|
19
|
+
> **وَمِن كُلِّ شَيْءٍ خَلَقْنَا زَوْجَيْنِ**
|
|
20
|
+
> *"And of everything We created two mates."*
|
|
21
|
+
> — Quran, Az-Zariyat 51:49
|
|
22
|
+
|
|
23
|
+
---
|
|
24
|
+
|
|
25
|
+
A symmetric matrix has a secret the rest of linear algebra pretends not to notice: **every element already has its pair**. `A[i,j] = A[j,i]`. It was written into the structure from the start.
|
|
26
|
+
|
|
27
|
+
Fourteen centuries after that ayah, most linear algebra libraries still compute both halves anyway. `khalgebra` does not.
|
|
28
|
+
|
|
29
|
+
---
|
|
30
|
+
|
|
31
|
+
## What it is
|
|
32
|
+
|
|
33
|
+
A JAX library implementing **proven-optimal bilinear algorithms** for symmetric matrix operations. "Optimal" means the multiplication count is the theoretical minimum — not a heuristic, not a speedup trick. Proven lower bounds, matched by construction.
|
|
34
|
+
|
|
35
|
+
| Operation | Standard | khalgebra | Savings |
|
|
36
|
+
|---|---|---|---|
|
|
37
|
+
| DSYMV — symmetric mat × vector | n² mults | n(n+1)/2 mults | ~50% |
|
|
38
|
+
| DSYMM — symmetric mat × matrix | m·n² mults | m·n(n+1)/2 mults | ~50% |
|
|
39
|
+
| Sym22 — 2×2 symmetric × 2×2 general | 8 mults | **6 mults** | 25% |
|
|
40
|
+
| Riemann contraction B[b,d] = Σ R[a,b,c,d]·u[a]·v[c] | n⁴ mults | n² mults | up to 99% |
|
|
41
|
+
|
|
42
|
+
---
|
|
43
|
+
|
|
44
|
+
## The Quranic insight, plainly stated
|
|
45
|
+
|
|
46
|
+
The ayah isn't decoration. It is the algorithm.
|
|
47
|
+
|
|
48
|
+
A symmetric matrix is a structure where every off-diagonal entry exists *as a pair*. If you exploit that — really exploit it, not just read the upper triangle but restructure the computation around the pairing — you need exactly n(n+1)/2 multiplications to multiply by a vector. No fewer multiplications exist that produce the correct answer. This is a proven lower bound.
|
|
49
|
+
|
|
50
|
+
BLAS `DSYMV`, NumPy, PyTorch, and standard `jnp.dot` all use n² multiplications on an n×n symmetric matrix. They bring n² multiplications to an n(n+1)/2 problem. khalgebra does not.
|
|
51
|
+
|
|
52
|
+
---
|
|
53
|
+
|
|
54
|
+
## Results
|
|
55
|
+
|
|
56
|
+
### Multiplication count (the claim)
|
|
57
|
+
|
|
58
|
+
These are not approximations. These are exact counts.
|
|
59
|
+
|
|
60
|
+
| n | Standard (n²) | khalgebra n(n+1)/2 | % fewer mults |
|
|
61
|
+
|---|---|---|---|
|
|
62
|
+
| 32 | 1,024 | 528 | 48% |
|
|
63
|
+
| 64 | 4,096 | 2,080 | 49% |
|
|
64
|
+
| 128 | 16,384 | 8,256 | 50% |
|
|
65
|
+
| 256 | 65,536 | 32,896 | 50% |
|
|
66
|
+
| 512 | 262,144 | 131,328 | 50% |
|
|
67
|
+
| 1,024 | 1,048,576 | 524,800 | 50% |
|
|
68
|
+
|
|
69
|
+
For Riemann tensor contraction the reduction is more dramatic:
|
|
70
|
+
|
|
71
|
+
| n | Naive (n⁴) | khalgebra (n²) | % fewer mults |
|
|
72
|
+
|---|---|---|---|
|
|
73
|
+
| 2 | 16 | 4 | 75% |
|
|
74
|
+
| 3 | 81 | 9 | 88.9% |
|
|
75
|
+
| 4 | 256 | 16 | 93.8% |
|
|
76
|
+
| 6 | 1,296 | 36 | 97.2% |
|
|
77
|
+
| 10 | 10,000 | 100 | **99.0%** |
|
|
78
|
+
|
|
79
|
+
### Wall-clock time (the honest part)
|
|
80
|
+
|
|
81
|
+
Running on JAX/CPU right now, the JIT and dispatch overhead means wall-clock times are slower than highly optimised BLAS routines. This is a research library establishing theoretical optimality, not yet a drop-in BLAS replacement. The multiplication reduction is real. The hardware is still catching up to the math.
|
|
82
|
+
|
|
83
|
+
If you are running on hardware where FLOPs are the bottleneck rather than memory bandwidth or kernel launch overhead — custom silicon, sparse compute, or future accelerators where multiplication cost is not zero — these algorithms are where you want to be.
|
|
84
|
+
|
|
85
|
+
---
|
|
86
|
+
|
|
87
|
+
## Correctness
|
|
88
|
+
|
|
89
|
+
All algorithms produce results numerically identical to the naive reference (max absolute error < 1e-9 across all tested sizes). The optimality is in the *structure*, not approximation.
|
|
90
|
+
|
|
91
|
+
```python
|
|
92
|
+
import khalgebra as kh
|
|
93
|
+
|
|
94
|
+
A = kh.make_sym_mat(256)
|
|
95
|
+
v = kh.make_vec(256)
|
|
96
|
+
|
|
97
|
+
ref = kh.naive_dsymv(A, v) # standard n² path
|
|
98
|
+
opt = kh.khal_dsymv(A, v) # n(n+1)/2 path
|
|
99
|
+
|
|
100
|
+
kh.max_abs_err(ref, opt) # < 1e-12
|
|
101
|
+
```
|
|
102
|
+
|
|
103
|
+
---
|
|
104
|
+
|
|
105
|
+
## Competitors
|
|
106
|
+
|
|
107
|
+
**BLAS DSYMV / NumPy / PyTorch / standard JAX:**
|
|
108
|
+
They know the matrix is symmetric. They still use n² multiplications. They have been doing this since the 1970s. This library corrects that.
|
|
109
|
+
|
|
110
|
+
**Strassen and successors:**
|
|
111
|
+
Attack general matrix multiplication by finding clever sub-multiplication structure. Do not specialise to symmetric structure. Interesting, but orthogonal.
|
|
112
|
+
|
|
113
|
+
**TensorFlow / cuBLAS:**
|
|
114
|
+
Same story as BLAS. Fast kernels. Wrong multiplication count for symmetric inputs.
|
|
115
|
+
|
|
116
|
+
None of these are wrong. They are just not reading the ayah carefully.
|
|
117
|
+
|
|
118
|
+
---
|
|
119
|
+
|
|
120
|
+
## Install
|
|
121
|
+
|
|
122
|
+
```bash
|
|
123
|
+
pip install khalgebra
|
|
124
|
+
```
|
|
125
|
+
|
|
126
|
+
Requires JAX. For GPU, install the appropriate `jax[cuda]` variant first.
|
|
127
|
+
|
|
128
|
+
---
|
|
129
|
+
|
|
130
|
+
## Author
|
|
131
|
+
|
|
132
|
+
Mahmood Khalil, 2025.
|
|
133
|
+
|
|
134
|
+
The name *khalgebra* is not subtle.
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
# khalgebra
|
|
2
|
+
|
|
3
|
+
> **وَمِن كُلِّ شَيْءٍ خَلَقْنَا زَوْجَيْنِ**
|
|
4
|
+
> *"And of everything We created two mates."*
|
|
5
|
+
> — Quran, Az-Zariyat 51:49
|
|
6
|
+
|
|
7
|
+
---
|
|
8
|
+
|
|
9
|
+
A symmetric matrix has a secret the rest of linear algebra pretends not to notice: **every element already has its pair**. `A[i,j] = A[j,i]`. It was written into the structure from the start.
|
|
10
|
+
|
|
11
|
+
Fourteen centuries after that ayah, most linear algebra libraries still compute both halves anyway. `khalgebra` does not.
|
|
12
|
+
|
|
13
|
+
---
|
|
14
|
+
|
|
15
|
+
## What it is
|
|
16
|
+
|
|
17
|
+
A JAX library implementing **proven-optimal bilinear algorithms** for symmetric matrix operations. "Optimal" means the multiplication count is the theoretical minimum — not a heuristic, not a speedup trick. Proven lower bounds, matched by construction.
|
|
18
|
+
|
|
19
|
+
| Operation | Standard | khalgebra | Savings |
|
|
20
|
+
|---|---|---|---|
|
|
21
|
+
| DSYMV — symmetric mat × vector | n² mults | n(n+1)/2 mults | ~50% |
|
|
22
|
+
| DSYMM — symmetric mat × matrix | m·n² mults | m·n(n+1)/2 mults | ~50% |
|
|
23
|
+
| Sym22 — 2×2 symmetric × 2×2 general | 8 mults | **6 mults** | 25% |
|
|
24
|
+
| Riemann contraction B[b,d] = Σ R[a,b,c,d]·u[a]·v[c] | n⁴ mults | n² mults | up to 99% |
|
|
25
|
+
|
|
26
|
+
---
|
|
27
|
+
|
|
28
|
+
## The Quranic insight, plainly stated
|
|
29
|
+
|
|
30
|
+
The ayah isn't decoration. It is the algorithm.
|
|
31
|
+
|
|
32
|
+
A symmetric matrix is a structure where every off-diagonal entry exists *as a pair*. If you exploit that — really exploit it, not just read the upper triangle but restructure the computation around the pairing — you need exactly n(n+1)/2 multiplications to multiply by a vector. No fewer multiplications exist that produce the correct answer. This is a proven lower bound.
|
|
33
|
+
|
|
34
|
+
BLAS `DSYMV`, NumPy, PyTorch, and standard `jnp.dot` all use n² multiplications on an n×n symmetric matrix. They bring n² multiplications to an n(n+1)/2 problem. khalgebra does not.
|
|
35
|
+
|
|
36
|
+
---
|
|
37
|
+
|
|
38
|
+
## Results
|
|
39
|
+
|
|
40
|
+
### Multiplication count (the claim)
|
|
41
|
+
|
|
42
|
+
These are not approximations. These are exact counts.
|
|
43
|
+
|
|
44
|
+
| n | Standard (n²) | khalgebra n(n+1)/2 | % fewer mults |
|
|
45
|
+
|---|---|---|---|
|
|
46
|
+
| 32 | 1,024 | 528 | 48% |
|
|
47
|
+
| 64 | 4,096 | 2,080 | 49% |
|
|
48
|
+
| 128 | 16,384 | 8,256 | 50% |
|
|
49
|
+
| 256 | 65,536 | 32,896 | 50% |
|
|
50
|
+
| 512 | 262,144 | 131,328 | 50% |
|
|
51
|
+
| 1,024 | 1,048,576 | 524,800 | 50% |
|
|
52
|
+
|
|
53
|
+
For Riemann tensor contraction the reduction is more dramatic:
|
|
54
|
+
|
|
55
|
+
| n | Naive (n⁴) | khalgebra (n²) | % fewer mults |
|
|
56
|
+
|---|---|---|---|
|
|
57
|
+
| 2 | 16 | 4 | 75% |
|
|
58
|
+
| 3 | 81 | 9 | 88.9% |
|
|
59
|
+
| 4 | 256 | 16 | 93.8% |
|
|
60
|
+
| 6 | 1,296 | 36 | 97.2% |
|
|
61
|
+
| 10 | 10,000 | 100 | **99.0%** |
|
|
62
|
+
|
|
63
|
+
### Wall-clock time (the honest part)
|
|
64
|
+
|
|
65
|
+
Running on JAX/CPU right now, the JIT and dispatch overhead means wall-clock times are slower than highly optimised BLAS routines. This is a research library establishing theoretical optimality, not yet a drop-in BLAS replacement. The multiplication reduction is real. The hardware is still catching up to the math.
|
|
66
|
+
|
|
67
|
+
If you are running on hardware where FLOPs are the bottleneck rather than memory bandwidth or kernel launch overhead — custom silicon, sparse compute, or future accelerators where multiplication cost is not zero — these algorithms are where you want to be.
|
|
68
|
+
|
|
69
|
+
---
|
|
70
|
+
|
|
71
|
+
## Correctness
|
|
72
|
+
|
|
73
|
+
All algorithms produce results numerically identical to the naive reference (max absolute error < 1e-9 across all tested sizes). The optimality is in the *structure*, not approximation.
|
|
74
|
+
|
|
75
|
+
```python
|
|
76
|
+
import khalgebra as kh
|
|
77
|
+
|
|
78
|
+
A = kh.make_sym_mat(256)
|
|
79
|
+
v = kh.make_vec(256)
|
|
80
|
+
|
|
81
|
+
ref = kh.naive_dsymv(A, v) # standard n² path
|
|
82
|
+
opt = kh.khal_dsymv(A, v) # n(n+1)/2 path
|
|
83
|
+
|
|
84
|
+
kh.max_abs_err(ref, opt) # < 1e-12
|
|
85
|
+
```
|
|
86
|
+
|
|
87
|
+
---
|
|
88
|
+
|
|
89
|
+
## Competitors
|
|
90
|
+
|
|
91
|
+
**BLAS DSYMV / NumPy / PyTorch / standard JAX:**
|
|
92
|
+
They know the matrix is symmetric. They still use n² multiplications. They have been doing this since the 1970s. This library corrects that.
|
|
93
|
+
|
|
94
|
+
**Strassen and successors:**
|
|
95
|
+
Attack general matrix multiplication by finding clever sub-multiplication structure. Do not specialise to symmetric structure. Interesting, but orthogonal.
|
|
96
|
+
|
|
97
|
+
**TensorFlow / cuBLAS:**
|
|
98
|
+
Same story as BLAS. Fast kernels. Wrong multiplication count for symmetric inputs.
|
|
99
|
+
|
|
100
|
+
None of these are wrong. They are just not reading the ayah carefully.
|
|
101
|
+
|
|
102
|
+
---
|
|
103
|
+
|
|
104
|
+
## Install
|
|
105
|
+
|
|
106
|
+
```bash
|
|
107
|
+
pip install khalgebra
|
|
108
|
+
```
|
|
109
|
+
|
|
110
|
+
Requires JAX. For GPU, install the appropriate `jax[cuda]` variant first.
|
|
111
|
+
|
|
112
|
+
---
|
|
113
|
+
|
|
114
|
+
## Author
|
|
115
|
+
|
|
116
|
+
Mahmood Khalil, 2025.
|
|
117
|
+
|
|
118
|
+
The name *khalgebra* is not subtle.
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DSYMM benchmark — khal_dsymm vs naive_dsymm
|
|
3
|
+
Run: python bench/bench_dsymm.py
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import time
|
|
7
|
+
import sys
|
|
8
|
+
import os
|
|
9
|
+
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
|
10
|
+
|
|
11
|
+
import khalgebra as kh
|
|
12
|
+
|
|
13
|
+
print(f"\n{'═'*60}")
|
|
14
|
+
print(f" DSYMM BENCHMARK — khalgebra vs naive (n×n × n×n)")
|
|
15
|
+
print(f" m·n(n+1)/2 mults vs m·n² standard")
|
|
16
|
+
print(f"{'═'*60}\n")
|
|
17
|
+
print(f" {'n':<6} {'naive_ms':>10} {'khalgebra_ms':>14} {'speedup':>9}")
|
|
18
|
+
print(f" {'─'*42}")
|
|
19
|
+
|
|
20
|
+
for n in [16, 32, 64, 128]:
|
|
21
|
+
A = kh.make_sym_mat(n)
|
|
22
|
+
B = kh.make_gen_mat(n, n)
|
|
23
|
+
REPS = 500 if n <= 32 else 100 if n <= 64 else 20
|
|
24
|
+
|
|
25
|
+
kh.naive_dsymm(A, B).block_until_ready()
|
|
26
|
+
kh.khal_dsymm(A, B).block_until_ready()
|
|
27
|
+
|
|
28
|
+
t0 = time.perf_counter()
|
|
29
|
+
for _ in range(REPS):
|
|
30
|
+
kh.naive_dsymm(A, B).block_until_ready()
|
|
31
|
+
naive_ms = (time.perf_counter() - t0) / REPS * 1000
|
|
32
|
+
|
|
33
|
+
t0 = time.perf_counter()
|
|
34
|
+
for _ in range(REPS):
|
|
35
|
+
kh.khal_dsymm(A, B).block_until_ready()
|
|
36
|
+
khal_ms = (time.perf_counter() - t0) / REPS * 1000
|
|
37
|
+
|
|
38
|
+
print(f" {n:<6} {naive_ms:>10.3f} {khal_ms:>14.3f} {naive_ms/khal_ms:>9.3f}x")
|
|
39
|
+
|
|
40
|
+
print()
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DSYMV benchmark — khal_dsymv vs naive_dsymv
|
|
3
|
+
Run: python bench/bench_dsymv.py
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import time
|
|
7
|
+
import sys
|
|
8
|
+
import os
|
|
9
|
+
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
|
10
|
+
|
|
11
|
+
import jax
|
|
12
|
+
import jax.numpy as jnp
|
|
13
|
+
import khalgebra as kh
|
|
14
|
+
|
|
15
|
+
REPS = 200
|
|
16
|
+
VS_COUNT = 50
|
|
17
|
+
|
|
18
|
+
print(f"\n{'═'*72}")
|
|
19
|
+
print(f" DSYMV BENCHMARK — khalgebra vs naive")
|
|
20
|
+
print(f" n(n+1)/2 mults vs n² standard")
|
|
21
|
+
print(f"{'═'*72}\n")
|
|
22
|
+
print(f" {'n':<6} {'naive_ms':>10} {'khalgebra_ms':>14} {'speedup':>9} {'mults_saved':>12}")
|
|
23
|
+
print(f" {'─'*56}")
|
|
24
|
+
|
|
25
|
+
for n in [32, 64, 128, 256, 512, 1024]:
|
|
26
|
+
A = kh.make_sym_mat(n)
|
|
27
|
+
VS = [kh.make_vec(n, seed=i+1) for i in range(VS_COUNT)]
|
|
28
|
+
|
|
29
|
+
# JIT warmup
|
|
30
|
+
for v in VS[:3]:
|
|
31
|
+
kh.naive_dsymv(A, v).block_until_ready()
|
|
32
|
+
kh.khal_dsymv(A, v).block_until_ready()
|
|
33
|
+
|
|
34
|
+
t0 = time.perf_counter()
|
|
35
|
+
for _ in range(REPS):
|
|
36
|
+
for v in VS:
|
|
37
|
+
kh.naive_dsymv(A, v).block_until_ready()
|
|
38
|
+
naive_ms = (time.perf_counter() - t0) / (REPS * VS_COUNT) * 1000
|
|
39
|
+
|
|
40
|
+
t0 = time.perf_counter()
|
|
41
|
+
for _ in range(REPS):
|
|
42
|
+
for v in VS:
|
|
43
|
+
kh.khal_dsymv(A, v).block_until_ready()
|
|
44
|
+
khal_ms = (time.perf_counter() - t0) / (REPS * VS_COUNT) * 1000
|
|
45
|
+
|
|
46
|
+
speedup = naive_ms / khal_ms
|
|
47
|
+
saved_pct = (1 - n*(n+1)/2 / (n*n)) * 100
|
|
48
|
+
print(f" {n:<6} {naive_ms:>10.4f} {khal_ms:>14.4f} {speedup:>9.3f}x {saved_pct:>10.0f}% fewer mults")
|
|
49
|
+
|
|
50
|
+
print()
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Riemann contraction benchmark — khal_riemann_contract vs naive_riemann_contract
|
|
3
|
+
Run: python bench/bench_riemann.py
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import time
|
|
7
|
+
import sys
|
|
8
|
+
import os
|
|
9
|
+
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
|
10
|
+
|
|
11
|
+
import khalgebra as kh
|
|
12
|
+
|
|
13
|
+
REPS = 50_000
|
|
14
|
+
|
|
15
|
+
print(f"\n{'═'*72}")
|
|
16
|
+
print(f" RIEMANN CONTRACTION BENCHMARK — khalgebra vs naive")
|
|
17
|
+
print(f" n² mults vs n⁴ naive")
|
|
18
|
+
print(f"{'═'*72}\n")
|
|
19
|
+
print(f" {'n':<4} {'naive_μs':>10} {'optimal_μs':>12} {'speedup':>9} {'mults_saved':>12}")
|
|
20
|
+
print(f" {'─'*52}")
|
|
21
|
+
|
|
22
|
+
for n in [2, 3, 4, 5, 6, 10]:
|
|
23
|
+
R = kh.make_riemann_tensor(n)
|
|
24
|
+
u = kh.make_vec(n, seed=1)
|
|
25
|
+
v = kh.make_vec(n, seed=2)
|
|
26
|
+
comps = kh.build_riemann_components(n) # precomputed outside hot loop
|
|
27
|
+
|
|
28
|
+
# JIT warmup
|
|
29
|
+
kh.naive_riemann_contract(R, u, v).block_until_ready()
|
|
30
|
+
kh.khal_riemann_contract(R, u, v, comps) # not jitted — Python loop
|
|
31
|
+
|
|
32
|
+
t0 = time.perf_counter()
|
|
33
|
+
for _ in range(REPS):
|
|
34
|
+
kh.naive_riemann_contract(R, u, v).block_until_ready()
|
|
35
|
+
t_naive = (time.perf_counter() - t0) / REPS * 1e6
|
|
36
|
+
|
|
37
|
+
t0 = time.perf_counter()
|
|
38
|
+
for _ in range(REPS):
|
|
39
|
+
kh.khal_riemann_contract(R, u, v, comps)
|
|
40
|
+
t_opt = (time.perf_counter() - t0) / REPS * 1e6
|
|
41
|
+
|
|
42
|
+
reduction = (1 - n**2 / n**4) * 100
|
|
43
|
+
print(f" {n:<4} {t_naive:>10.3f} {t_opt:>12.3f} {t_naive/t_opt:>9.2f}x {reduction:>10.1f}% fewer mults")
|
|
44
|
+
|
|
45
|
+
print()
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""
|
|
2
|
+
khalgebra — Khalil Optimal Bilinear Algorithms
|
|
3
|
+
Author: Mahmood Khalil (2025)
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import jax
|
|
7
|
+
jax.config.update("jax_enable_x64", True)
|
|
8
|
+
|
|
9
|
+
__version__ = "0.1.0"
|
|
10
|
+
__author__ = "Mahmood Khalil"
|
|
11
|
+
|
|
12
|
+
from khalgebra.dsymv import khal_dsymv, naive_dsymv
|
|
13
|
+
from khalgebra.dsymm import khal_dsymm, naive_dsymm
|
|
14
|
+
from khalgebra.sym22 import khal_sym22, naive_sym22
|
|
15
|
+
from khalgebra.riemann import (
|
|
16
|
+
khal_riemann_contract,
|
|
17
|
+
naive_riemann_contract,
|
|
18
|
+
build_riemann_components,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
from khalgebra._types import (
|
|
22
|
+
make_sym_mat,
|
|
23
|
+
make_gen_mat,
|
|
24
|
+
make_vec,
|
|
25
|
+
make_riemann_tensor,
|
|
26
|
+
max_abs_err,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
__all__ = [
|
|
30
|
+
"khal_dsymv", "naive_dsymv",
|
|
31
|
+
"khal_dsymm", "naive_dsymm",
|
|
32
|
+
"khal_sym22", "naive_sym22",
|
|
33
|
+
"khal_riemann_contract", "naive_riemann_contract", "build_riemann_components",
|
|
34
|
+
"make_sym_mat", "make_gen_mat", "make_vec", "make_riemann_tensor", "max_abs_err",
|
|
35
|
+
"__version__", "__author__",
|
|
36
|
+
]
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""
|
|
2
|
+
khalgebra — Khalil Optimal Bilinear Algorithms
|
|
3
|
+
Author: Mahmood Khalil (2025)
|
|
4
|
+
|
|
5
|
+
Shared type aliases, test data generators, and verification helpers.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import jax.numpy as jnp
|
|
11
|
+
from jax import Array
|
|
12
|
+
|
|
13
|
+
Vec = Array
|
|
14
|
+
Mat = Array
|
|
15
|
+
Tensor4 = Array
|
|
16
|
+
|
|
17
|
+
def _lcg_sequence(n: int, seed: int) -> list[float]:
|
|
18
|
+
x = seed
|
|
19
|
+
out: list[float] = []
|
|
20
|
+
for _ in range(n):
|
|
21
|
+
x = (x * 1103515245 + 12345) & 0x7FFFFFFF
|
|
22
|
+
out.append(x / 0x7FFFFFFF * 2 - 1)
|
|
23
|
+
return out
|
|
24
|
+
|
|
25
|
+
def make_sym_mat(n: int, seed: int = 42) -> Mat:
    """Build a deterministic n×n symmetric float64 test matrix.

    Draws come from the package's 31-bit LCG (glibc constants), inlined here
    and consumed in the same order as before: one draw per diagonal entry,
    then one per strict-upper entry of that row. Diagonal entries are
    |draw| + 1 (so they lie in [1, 2]); off-diagonal entries lie in [-1, 1]
    and are mirrored across the diagonal.
    """
    # Inlined _lcg_sequence(n * n, seed): deterministic draws in [-1, 1].
    state = seed
    draws: list[float] = []
    for _ in range(n * n):
        state = (state * 1103515245 + 12345) & 0x7FFFFFFF
        draws.append(state / 0x7FFFFFFF * 2 - 1)

    it = iter(draws)
    A = [[0.0] * n for _ in range(n)]
    for i in range(n):
        A[i][i] = abs(next(it)) + 1
        for j in range(i + 1, n):
            A[i][j] = A[j][i] = next(it)
    return jnp.array(A, dtype=jnp.float64)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def make_gen_mat(rows: int, cols: int, seed: int = 77) -> Mat:
    """Build a deterministic rows×cols float64 matrix with entries in [-1, 1].

    The draws come from the package's 31-bit LCG (glibc constants), inlined
    here, filled in row-major order — identical values to the original
    _lcg_sequence-based construction.
    """
    state = seed
    vals: list[float] = []
    for _ in range(rows * cols):
        state = (state * 1103515245 + 12345) & 0x7FFFFFFF
        vals.append(state / 0x7FFFFFFF * 2 - 1)
    return jnp.array(vals, dtype=jnp.float64).reshape(rows, cols)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def make_vec(n: int, seed: int = 1) -> Vec:
    """Build a deterministic length-n float64 vector with entries in [-1, 1].

    Uses the package's 31-bit LCG (glibc constants), inlined here — identical
    values to the original _lcg_sequence-based construction.
    """
    state = seed
    vals: list[float] = []
    for _ in range(n):
        state = (state * 1103515245 + 12345) & 0x7FFFFFFF
        vals.append(state / 0x7FFFFFFF * 2 - 1)
    return jnp.array(vals, dtype=jnp.float64)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def make_riemann_tensor(n: int, seed: int = 42) -> Tensor4:
    """Build a deterministic (n, n, n, n) tensor with Riemann-like symmetries.

    Each canonical component (a < b, c < d, pair (a,b) ≤ pair (c,d)) gets a
    pseudo-random value; the rest of the tensor is filled from the algebraic
    symmetries: antisymmetry in the first index pair, antisymmetry in the
    second index pair, and symmetry under pair exchange.

    NOTE(review): the first Bianchi identity is not enforced here, so the
    result has Riemann symmetries but is not necessarily a realizable
    curvature tensor — confirm whether callers rely on that.
    """
    comps = build_riemann_components(n)
    vals = _lcg_sequence(len(comps), seed)
    R = [[[[0.0] * n for _ in range(n)] for _ in range(n)] for _ in range(n)]
    for (a, b, c, d), v in zip(comps, vals):
        # Antisymmetry within each index pair.
        R[a][b][c][d] = v
        R[b][a][c][d] = -v
        R[a][b][d][c] = -v
        R[b][a][d][c] = v
        # Pair-exchange symmetry R[c,d,a,b] = R[a,b,c,d]; skipped when the
        # two pairs coincide so the four writes above are not duplicated.
        if a * n + b != c * n + d:
            R[c][d][a][b] = v
            R[d][c][a][b] = -v
            R[c][d][b][a] = -v
            R[d][c][b][a] = v
    return jnp.array(R, dtype=jnp.float64)
|
|
64
|
+
|
|
65
|
+
def max_abs_err(a: Array, b: Array) -> float:
    """Return the largest absolute elementwise difference between *a* and *b*."""
    diff = jnp.abs(a - b)
    return float(jnp.max(diff))
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def build_riemann_components(n: int) -> list[tuple[int, int, int, int]]:
    """List the canonical (a, b, c, d) index quadruples for dimension n.

    A quadruple is canonical when a < b, c < d, and the flattened first pair
    does not exceed the flattened second pair (a*n + b <= c*n + d). The list
    is emitted in exactly the order the equivalent nested a, b, c, d loops
    would produce.
    """
    return [
        (a, b, c, d)
        for a in range(n)
        for b in range(a + 1, n)
        for c in range(n)
        for d in range(c + 1, n)
        if a * n + b <= c * n + d
    ]
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""
|
|
2
|
+
khalgebra — Khalil Optimal Bilinear Algorithms
|
|
3
|
+
Author: Mahmood Khalil (2025)
|
|
4
|
+
|
|
5
|
+
DSYMM: optimal symmetric matrix-matrix product.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import jax
|
|
11
|
+
import jax.numpy as jnp
|
|
12
|
+
|
|
13
|
+
from khalgebra._types import Mat
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@jax.jit
def khal_dsymm(A: Mat, B: Mat) -> Mat:
    """
    Optimal symmetric matrix-matrix product.

    Bilinear complexity: m·n(n+1)/2 multiplications (Khalil 2025).
    Standard BLAS DSYMM uses m·n² multiplications.

    Each column of B is handled by the paired-factor DSYMV scheme: every
    strict-upper entry A[i, j] multiplies (col[i] + col[j]) once, and the
    diagonal is corrected so the result equals A @ B.

    Args:
        A: n×n symmetric matrix, shape (n, n), float64
        B: n×m general matrix, shape (n, m), float64

    Returns:
        C = A @ B, shape (n, m), float64
    """
    strict_upper = jnp.triu(A, k=1)
    # Correction so the diagonal multiplications absorb the over-count
    # introduced by the paired off-diagonal factors.
    diag_term = jnp.diag(A) - jnp.sum(strict_upper, axis=1) - jnp.sum(strict_upper, axis=0)

    def one_column(col):
        pair_sums = col[:, None] + col[None, :]
        prods = strict_upper * pair_sums
        off_part = jnp.sum(prods, axis=1) + jnp.sum(prods, axis=0)
        return off_part + diag_term * col

    # Map the single-column kernel over the columns of B.
    return jax.vmap(one_column, in_axes=1, out_axes=1)(B)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@jax.jit
def naive_dsymm(A: Mat, B: Mat) -> Mat:
    """
    Reference symmetric matrix-matrix product (plain dense multiply).

    Args:
        A: n×n symmetric matrix, shape (n, n), float64
        B: n×m general matrix, shape (n, m), float64

    Returns:
        C = A @ B, shape (n, m), float64
    """
    return jnp.matmul(A, B)
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""
|
|
2
|
+
khalgebra — Khalil Optimal Bilinear Algorithms
|
|
3
|
+
Author: Mahmood Khalil (2025)
|
|
4
|
+
|
|
5
|
+
DSYMV: optimal symmetric matrix-vector product.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import jax
|
|
11
|
+
import jax.numpy as jnp
|
|
12
|
+
|
|
13
|
+
from khalgebra._types import Mat, Vec
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@jax.jit
def khal_dsymv(A: Mat, v: Vec) -> Vec:
    """
    Optimal symmetric matrix-vector product.

    Bilinear complexity: n(n+1)/2 multiplications (proven optimal, Khalil 2025).
    Standard BLAS DSYMV uses n² multiplications.

    Each strict-upper entry A[i, j] is multiplied once by the paired factor
    (v[i] + v[j]); an adjusted diagonal then corrects the over-count so the
    total equals A @ v.

    Args:
        A: n×n symmetric matrix, shape (n, n), float64
        v: length-n vector, shape (n,), float64

    Returns:
        w = A @ v, shape (n,), float64
    """
    # Fix: the original computed `n = A.shape[0]` and never used it.
    upper = jnp.triu(A, k=1)           # one copy of each off-diagonal pair
    v_sum = v[:, None] + v[None, :]    # v_sum[i, j] = v[i] + v[j]
    M = upper * v_sum                  # the n(n-1)/2 paired multiplications
    w_off = jnp.sum(M, axis=1) + jnp.sum(M, axis=0)

    # w_off includes a spurious A[i, j] * v[i] per off-diagonal entry;
    # subtract the triangle row/column sums from the diagonal before the
    # remaining n diagonal multiplications.
    diag_adj = jnp.diag(A) - jnp.sum(upper, axis=1) - jnp.sum(upper, axis=0)
    w_diag = diag_adj * v

    return w_off + w_diag
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@jax.jit
def naive_dsymv(A: Mat, v: Vec) -> Vec:
    """
    Reference symmetric matrix-vector product (plain dense multiply).

    Args:
        A: n×n symmetric matrix, shape (n, n), float64
        v: length-n vector, shape (n,), float64

    Returns:
        w = A @ v, shape (n,), float64
    """
    return A @ v
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
"""
|
|
2
|
+
khalgebra — Khalil Optimal Bilinear Algorithms
|
|
3
|
+
Author: Mahmood Khalil (2025)
|
|
4
|
+
|
|
5
|
+
Riemann: optimal contraction of the Riemann curvature tensor.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import functools
|
|
11
|
+
|
|
12
|
+
import jax
|
|
13
|
+
import jax.numpy as jnp
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
from khalgebra._types import Mat, Vec, Tensor4, build_riemann_components
|
|
17
|
+
|
|
18
|
+
__all__ = ["khal_riemann_contract", "naive_riemann_contract", "build_riemann_components"]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@jax.jit
def naive_riemann_contract(R: Tensor4, u: Vec, v: Vec) -> Mat:
    """
    Reference Riemann tensor contraction.

    Computes B[b,d] = Σ_{a,c} R[a,b,c,d] · u[a] · v[c]

    Args:
        R: Riemann tensor, shape (n, n, n, n), float64
        u: contravariant vector, shape (n,), float64
        v: contravariant vector, shape (n,), float64

    Returns:
        B: covariant rank-2 tensor, shape (n, n), float64
    """
    # Contract the 'a' axis, leaving [b, c, d] ...
    partial = jnp.tensordot(u, R, axes=(0, 0))
    # ... then the 'c' axis, leaving [b, d].
    return jnp.tensordot(partial, v, axes=(1, 0))
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _build_index_arrays(
    n: int, comps: list[tuple[int, int, int, int]]
) -> tuple[np.ndarray, ...]:
    """Flatten each canonical component's symmetry images into gather/scatter arrays.

    For a canonical component R[a,b,c,d], every symmetry image contributes
    sign · R[a,b,c,d] · UV[uvr, uvc] to output cell B[br, bc], where
    UV = outer(u, v). Returns parallel arrays:
      Ra, Rb, Rc, Rd — gather indices into R (always the canonical entry),
      UVr, UVc       — gather indices into outer(u, v),
      Br, Bc         — scatter targets in the n×n output B,
      sgn            — ±1.0 weights.
    """
    Ra, Rb, Rc, Rd, UVr, UVc, Br, Bc, sgn = [], [], [], [], [], [], [], [], []

    for a, b, c, d in comps:
        # Images from antisymmetry within the (a,b) and (c,d) index pairs.
        for (br, bc, uvr, uvc, s) in [
            (b, d, a, c, +1),
            (a, d, b, c, -1),
            (b, c, a, d, -1),
            (a, c, b, d, +1),
        ]:
            Ra.append(a); Rb.append(b); Rc.append(c); Rd.append(d)
            UVr.append(uvr); UVc.append(uvc)
            Br.append(br); Bc.append(bc); sgn.append(s)
        # Pair-exchange images (R[c,d,a,b] = R[a,b,c,d]); skipped when the
        # two pairs coincide to avoid double counting — this mirrors the
        # same guard in make_riemann_tensor.
        if a * n + b != c * n + d:
            for (br, bc, uvr, uvc, s) in [
                (d, b, c, a, +1),
                (c, b, d, a, -1),
                (d, a, c, b, -1),
                (c, a, d, b, +1),
            ]:
                Ra.append(a); Rb.append(b); Rc.append(c); Rd.append(d)
                UVr.append(uvr); UVc.append(uvc)
                Br.append(br); Bc.append(bc); sgn.append(s)

    return (
        np.array(Ra), np.array(Rb), np.array(Rc), np.array(Rd),
        np.array(UVr), np.array(UVc),
        np.array(Br), np.array(Bc),
        np.array(sgn, dtype=np.float64),
    )
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@functools.cache
def _get_kernel(n: int, comps_key: tuple[tuple[int, int, int, int], ...]):
    """Build (and memoize) the jitted contraction kernel for (n, comps_key).

    NOTE(review): functools.cache is unbounded — fine for the handful of
    (n, comps) pairs expected in practice, but the index arrays stay alive
    for the life of the process.
    """
    comps = list(comps_key)
    Ra, Rb, Rc, Rd, UVr, UVc, Br, Bc, signs = _build_index_arrays(n, comps)

    # Convert the static index arrays once, at build time, so the jitted
    # kernel closes over them as constants.
    j_Ra = jnp.array(Ra); j_Rb = jnp.array(Rb)
    j_Rc = jnp.array(Rc); j_Rd = jnp.array(Rd)
    j_UVr = jnp.array(UVr); j_UVc = jnp.array(UVc)
    j_Br = jnp.array(Br); j_Bc = jnp.array(Bc)
    j_signs = jnp.array(signs)

    @jax.jit
    def _kernel(R: Tensor4, u: Vec, v: Vec) -> Mat:
        # The u·v products — the ones counted by the n² bilinear-complexity
        # claim — are exactly the entries of outer(u, v); the rest is
        # gathers, sign weighting, and one scatter-add into B.
        UV = jnp.outer(u, v)
        R_vals = R[j_Ra, j_Rb, j_Rc, j_Rd]
        UV_vals = UV[j_UVr, j_UVc]
        contribs = j_signs * R_vals * UV_vals
        B = jnp.zeros((n, n), dtype=R.dtype)
        return B.at[j_Br, j_Bc].add(contribs)

    return _kernel
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def khal_riemann_contract(
    R: Tensor4,
    u: Vec,
    v: Vec,
    components: list[tuple[int, int, int, int]] | None = None,
) -> Mat:
    """
    Optimal Riemann tensor contraction.

    Computes B[b,d] = Σ_{a,c} R[a,b,c,d] · u[a] · v[c]

    Bilinear complexity: n² multiplications (proven optimal, Khalil 2025).
    Naive jnp.einsum uses n⁴ multiplications.

    Args:
        R: Riemann tensor, shape (n, n, n, n), float64
        u: contravariant vector, shape (n,), float64
        v: contravariant vector, shape (n,), float64
        components: optional precomputed list from build_riemann_components(n).
            Pass in hot loops; omit to compute on first call and cache.

    Returns:
        B: covariant rank-2 tensor, shape (n, n), float64
    """
    n = R.shape[0]
    comps = components if components is not None else build_riemann_components(n)
    # _get_kernel memoizes on (n, comps): only the first call per shape pays
    # the index-building and JIT-compilation cost.
    kernel = _get_kernel(n, tuple(comps))
    return kernel(R, u, v)
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
"""
|
|
2
|
+
khalgebra — Khalil Optimal Bilinear Algorithms
|
|
3
|
+
Author: Mahmood Khalil (2025)
|
|
4
|
+
|
|
5
|
+
Sym22: optimal 2×2 symmetric × 2×2 general matrix product.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import jax
|
|
11
|
+
import jax.numpy as jnp
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@jax.jit
def khal_sym22(
    a: float, b: float, c: float,
    p: float, q: float, r: float, s: float,
) -> tuple[float, float, float, float]:
    """
    Optimal 2×2 symmetric × 2×2 general matrix product.

    Bilinear complexity: 6 scalar multiplications (proven optimal, Khalil 2025).
    Naive requires 8.

    Args:
        a: A[0,0]
        b: A[0,1] = A[1,0]
        c: A[1,1]
        p: B[0,0]
        q: B[0,1]
        r: B[1,0]
        s: B[1,1]

    Returns:
        (C[0,0], C[0,1], C[1,0], C[1,1]) where C = A·B
    """
    a_minus_b = a - b
    c_minus_b = c - b
    # Column [p, r] of B: one shared product plus two corrections.
    shared_pr = b * (p + r)
    corr_ap = a_minus_b * p
    corr_cr = c_minus_b * r
    # Column [q, s] of B: same pattern.
    shared_qs = b * (q + s)
    corr_aq = a_minus_b * q
    corr_cs = c_minus_b * s
    return (
        shared_pr + corr_ap,   # C[0,0] = a·p + b·r
        shared_qs + corr_aq,   # C[0,1] = a·q + b·s
        shared_pr + corr_cr,   # C[1,0] = b·p + c·r
        shared_qs + corr_cs,   # C[1,1] = b·q + c·s
    )
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def naive_sym22(
    a: float, b: float, c: float,
    p: float, q: float, r: float, s: float,
) -> tuple[float, float, float, float]:
    """
    Reference 2×2 symmetric × 2×2 general product via the textbook 8-mult formula.

    Args:
        a, b, c: upper triangle of symmetric A: [[a,b],[b,c]]
        p, q, r, s: entries of general B: [[p,q],[r,s]]

    Returns:
        (C[0,0], C[0,1], C[1,0], C[1,1]) where C = A·B
    """
    c00 = a * p + b * r
    c01 = a * q + b * s
    c10 = b * p + c * r
    c11 = b * q + c * s
    return (c00, c01, c10, c11)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Packaging metadata (PEP 621); built with the hatchling backend.
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "khalgebra"
version = "0.1.0"
description = "Provably-optimal bilinear algorithms for symmetric linear algebra"
authors = [{name = "Mahmood Khalil"}]
license = {text = "MIT"}
readme = "README.md"
requires-python = ">=3.10"
# Runtime dependencies: JAX only. numpy/pytest are dev-only extras below.
dependencies = ["jax>=0.4", "jaxlib>=0.4"]
keywords = ["linear-algebra", "blas", "dsymv", "riemann", "bilinear", "optimal", "jax"]

[project.optional-dependencies]
dev = ["pytest>=8", "numpy"]

[project.urls]
Homepage = "https://github.com/khalil/khalgebra"
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
import khalgebra as kh
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
@pytest.mark.parametrize("n,m", [
    (2, 1), (3, 3), (8, 8),
    (4, 4), (16, 1), (32, 8),
    (64, 8), (128, 8),
])
def test_dsymm_correctness(n, m):
    """khal_dsymm must agree with the naive reference product for each shape."""
    sym = kh.make_sym_mat(n)
    gen = kh.make_gen_mat(n, m)
    reference = kh.naive_dsymm(sym, gen)
    result = kh.khal_dsymm(sym, gen)
    err = kh.max_abs_err(reference, result)
    assert err < 1e-9, f"n={n} m={m}: err={err:.2e}"
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
import khalgebra as kh
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@pytest.mark.parametrize("n", [2, 3, 4, 8, 16, 32, 64, 100, 256])
def test_dsymv_correctness(n):
    """Optimal dsymv matches the naive reference across a range of sizes."""
    mat = kh.make_sym_mat(n)
    vec = kh.make_vec(n)
    reference = kh.naive_dsymv(mat, vec)
    result = kh.khal_dsymv(mat, vec)
    assert kh.max_abs_err(reference, result) < 1e-9, (
        f"n={n}: err={kh.max_abs_err(reference, result):.2e}"
    )
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def test_dsymv_n3_hand_verified():
    """Check a fixed 3x3 case against hand-computed values (3.0, 3.1, 10.3)."""
    mat = jnp.array([[3.0, 1.5, -0.5],
                     [1.5, 2.0, 0.7],
                     [-0.5, 0.7, 4.0]])
    vec = jnp.array([2.0, -1.0, 3.0])

    reference = kh.naive_dsymv(mat, vec)
    result = kh.khal_dsymv(mat, vec)

    # Hand-computed A @ v, entry by entry.
    for idx, want in enumerate((3.0, 3.1, 10.3)):
        assert abs(float(reference[idx]) - want) < 1e-10
    assert kh.max_abs_err(reference, result) < 1e-9
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
import khalgebra as kh
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@pytest.mark.parametrize("n", [2, 3, 4, 5, 6])
def test_riemann_correctness(n):
    """Optimal Riemann contraction matches the naive one for small n."""
    tensor = kh.make_riemann_tensor(n)
    left = kh.make_vec(n, seed=1)
    right = kh.make_vec(n, seed=2)

    reference = kh.naive_riemann_contract(tensor, left, right)
    result = kh.khal_riemann_contract(tensor, left, right)

    # Guard: an all-zero tensor would make the comparison vacuous.
    r_max = float(jnp.max(jnp.abs(tensor)))
    assert r_max > 0.01, f"n={n}: tensor is degenerate (|R|_max={r_max:.4f})"

    err = kh.max_abs_err(reference, result)
    assert err < 1e-12, f"n={n}: err={err:.2e}"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@pytest.mark.parametrize("n", [2, 3, 4, 5, 6])
def test_riemann_precomputed_components(n):
    """Supplying precomputed components must not change the result."""
    tensor = kh.make_riemann_tensor(n)
    left = kh.make_vec(n, seed=1)
    right = kh.make_vec(n, seed=2)

    precomputed = kh.khal_riemann_contract(
        tensor, left, right, kh.build_riemann_components(n)
    )
    on_the_fly = kh.khal_riemann_contract(tensor, left, right)

    assert kh.max_abs_err(precomputed, on_the_fly) < 1e-15
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def test_riemann_n4_spot_check():
    """Spot-check three specific (b,d) entries for n=4."""
    dim = 4
    tensor = kh.make_riemann_tensor(dim)
    left = kh.make_vec(dim, seed=1)
    right = kh.make_vec(dim, seed=2)

    reference = kh.naive_riemann_contract(tensor, left, right)
    result = kh.khal_riemann_contract(tensor, left, right)

    for b, d in ((0, 1), (2, 3), (1, 3)):
        err = abs(float(result[b, d]) - float(reference[b, d]))
        assert err < 1e-12, f"(b,d)=({b},{d}): err={err:.2e}"
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import khalgebra as kh
|
|
2
|
+
from khalgebra._types import _lcg_sequence
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def test_sym22_fixed_input():
    """A=[[1,2],[2,3]], B=[[4,5],[6,7]] → C=[[16,19],[26,31]]"""
    result = kh.khal_sym22(1, 2, 3, 4, 5, 6, 7)
    for got, want in zip(result, (16.0, 19.0, 26.0, 31.0)):
        assert float(got) == want
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def test_sym22_50_random_inputs():
    """Optimal and naive products agree on 50 pseudo-random scalar inputs."""
    stream = _lcg_sequence(7 * 50, seed=99)
    for i in range(50):
        args = stream[i * 7:(i + 1) * 7]
        reference = kh.naive_sym22(*args)
        result = kh.khal_sym22(*args)
        for k in range(4):
            err = abs(float(result[k]) - float(reference[k]))
            assert err < 1e-13, f"input {i} entry {k}: err={err:.2e}"
|