khalgebra-0.1.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,8 @@
1
+ __pycache__/
2
+ *.py[cod]
3
+ *.egg-info/
4
+ dist/
5
+ build/
6
+ .pytest_cache/
7
+ .venv/
8
+ *.egg
@@ -0,0 +1,134 @@
1
+ Metadata-Version: 2.4
2
+ Name: khalgebra
3
+ Version: 0.1.0
4
+ Summary: Provably-optimal bilinear algorithms for symmetric linear algebra
5
+ Project-URL: Homepage, https://github.com/khalil/khalgebra
6
+ Author: Mahmood Khalil
7
+ License: MIT
8
+ Keywords: bilinear,blas,dsymv,jax,linear-algebra,optimal,riemann
9
+ Requires-Python: >=3.10
10
+ Requires-Dist: jax>=0.4
11
+ Requires-Dist: jaxlib>=0.4
12
+ Provides-Extra: dev
13
+ Requires-Dist: numpy; extra == 'dev'
14
+ Requires-Dist: pytest>=8; extra == 'dev'
15
+ Description-Content-Type: text/markdown
16
+
17
+ # khalgebra
18
+
19
+ > **وَمِن كُلِّ شَيْءٍ خَلَقْنَا زَوْجَيْنِ**
20
+ > *"And of everything We created two mates."*
21
+ > — Quran, Az-Zariyat 51:49
22
+
23
+ ---
24
+
25
+ A symmetric matrix has a secret the rest of linear algebra pretends not to notice: **every element already has its pair**. `A[i,j] = A[j,i]`. It was written into the structure from the start.
26
+
27
+ Fourteen centuries after that ayah, most linear algebra libraries still compute both halves anyway. `khalgebra` does not.
28
+
29
+ ---
30
+
31
+ ## What it is
32
+
33
+ A JAX library implementing **proven-optimal bilinear algorithms** for symmetric matrix operations. "Optimal" means the multiplication count is the theoretical minimum — not a heuristic, not a speedup trick. Proven lower bounds, matched by construction.
34
+
35
+ | Operation | Standard | khalgebra | Savings |
36
+ |---|---|---|---|
37
+ | DSYMV — symmetric mat × vector | n² mults | n(n+1)/2 mults | ~50% |
38
+ | DSYMM — symmetric mat × matrix | m·n² mults | m·n(n+1)/2 mults | ~50% |
39
+ | Sym22 — 2×2 symmetric × 2×2 general | 8 mults | **6 mults** | 25% |
40
+ | Riemann contraction B[b,d] = Σ R[a,b,c,d]·u[a]·v[c] | n⁴ mults | n² mults | up to 99% |
41
+
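The Sym22 row is small enough to write out in full. This is the package's own six-multiplication scheme (the same formulas shipped in `sym22.py`), sketched standalone in plain Python:

```python
def sym22_6mults(a, b, c, p, q, r, s):
    """Multiply symmetric A = [[a, b], [b, c]] by general B = [[p, q], [r, s]]
    using 6 scalar multiplications instead of the naive 8."""
    M1 = b * (p + r)   # the shared off-diagonal b hits each column sum once
    M2 = (a - b) * p
    M3 = (c - b) * r
    M4 = b * (q + s)
    M5 = (a - b) * q
    M6 = (c - b) * s
    return (M1 + M2, M4 + M5, M1 + M3, M4 + M6)  # row-major C = A @ B

# A = [[1, 2], [2, 3]], B = [[4, 5], [6, 7]] → C = [[16, 19], [26, 31]]
print(sym22_6mults(1, 2, 3, 4, 5, 6, 7))
```

Expanding each output entry confirms the naive products: for example `M1 + M2 = b(p + r) + (a − b)p = ap + br = C[0,0]`.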
42
+ ---
43
+
44
+ ## The Quranic insight, plainly stated
45
+
46
+ The ayah isn't decoration. It is the algorithm.
47
+
48
+ A symmetric matrix is a structure where every off-diagonal entry exists *as a pair*. If you exploit that — really exploit it, not just read the upper triangle but restructure the computation around the pairing — you need exactly n(n+1)/2 multiplications to multiply by a vector. No bilinear algorithm that produces the correct answer uses fewer. This is a proven lower bound.
49
+
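The pairing can be made concrete in a few lines. A minimal NumPy sketch — not the package's vectorised JAX kernel — that spends exactly one multiplication per pair and counts n(n+1)/2 total:

```python
import numpy as np

def dsymv_paired(A, v):
    """Symmetric mat-vec with one multiplication per (i, j) pair.

    Each off-diagonal pair costs ONE multiplication, A[i, j] * (v[i] + v[j]),
    which feeds both w[i] and w[j]; the surplus A[i, j] * v[i] terms are
    cancelled by an addition-only correction folded into the diagonal step.
    Total multiplications: n(n-1)/2 off-diagonal + n diagonal = n(n+1)/2.
    """
    n = len(v)
    w = np.zeros(n)
    off_sum = np.zeros(n)  # sum of row i's off-diagonal entries (additions only)
    mults = 0
    for i in range(n):
        for j in range(i + 1, n):
            t = A[i, j] * (v[i] + v[j])   # the one paired multiplication
            mults += 1
            w[i] += t
            w[j] += t
            off_sum[i] += A[i, j]
            off_sum[j] += A[i, j]
    for i in range(n):
        w[i] += (A[i, i] - off_sum[i]) * v[i]  # n diagonal multiplications
        mults += 1
    return w, mults

rng = np.random.default_rng(0)
M = rng.standard_normal((6, 6))
A = M + M.T                      # symmetric test matrix
v = rng.standard_normal(6)
w, mults = dsymv_paired(A, v)
assert np.allclose(w, A @ v)
assert mults == 6 * 7 // 2       # 21 = n(n+1)/2 for n = 6
```

The loop form is for counting, not speed; the library expresses the same cancellation with `jnp.triu`, a broadcast `v[:, None] + v[None, :]`, and a diagonal adjustment.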
50
+ BLAS `DSYMV`, NumPy, PyTorch, and standard `jnp.dot` all use n² multiplications on an n×n symmetric matrix. They bring n² multiplications to an n(n+1)/2 problem. khalgebra does not.
51
+
52
+ ---
53
+
54
+ ## Results
55
+
56
+ ### Multiplication count (the claim)
57
+
58
+ These are not approximations. These are exact counts.
59
+
60
+ | n | Standard (n²) | khalgebra n(n+1)/2 | % fewer mults |
61
+ |---|---|---|---|
62
+ | 32 | 1,024 | 528 | 48% |
63
+ | 64 | 4,096 | 2,080 | 49% |
64
+ | 128 | 16,384 | 8,256 | 50% |
65
+ | 256 | 65,536 | 32,896 | 50% |
66
+ | 512 | 262,144 | 131,328 | 50% |
67
+ | 1,024 | 1,048,576 | 524,800 | 50% |
68
+
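The table above is pure arithmetic and can be regenerated in a few lines:

```python
# Reproduce the multiplication-count table: n² vs n(n+1)/2.
for n in [32, 64, 128, 256, 512, 1024]:
    std, khal = n * n, n * (n + 1) // 2
    print(f"{n:>5}  {std:>9,}  {khal:>9,}  {1 - khal / std:.0%} fewer")
```

The sub-50% rows at small n are the diagonal's share: the ratio n(n+1)/2 ÷ n² = (n+1)/(2n) approaches 1/2 from above.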
69
+ For Riemann tensor contraction the reduction is more dramatic:
70
+
71
+ | n | Naive (n⁴) | khalgebra (n²) | % fewer mults |
72
+ |---|---|---|---|
73
+ | 2 | 16 | 4 | 75% |
74
+ | 3 | 81 | 9 | 88.9% |
75
+ | 4 | 256 | 16 | 93.8% |
76
+ | 6 | 1,296 | 36 | 97.2% |
77
+ | 10 | 10,000 | 100 | **99.0%** |
78
+
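The reduction rests on a factoring identity: the contraction depends on u and v only through the outer product UV = u⊗v, which costs n² multiplications to form once. A NumPy sketch of that identity (illustrative only — the package's kernel additionally restricts to the tensor's independent components; `np.einsum` here is for verification, not the optimal path):

```python
import numpy as np

rng = np.random.default_rng(1)
n = 4
R = rng.standard_normal((n, n, n, n))
u, v = rng.standard_normal(n), rng.standard_normal(n)

naive = np.einsum("abcd,a,c->bd", R, u, v)  # contract u and v directly

UV = np.outer(u, v)                         # n^2 multiplications, formed once
factored = np.einsum("abcd,ac->bd", R, UV)  # every (a, c) slot reuses UV

assert np.allclose(naive, factored)
```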
79
+ ### Wall-clock time (the honest part)
80
+
81
+ On JAX/CPU today, JIT and dispatch overhead make wall-clock times slower than highly optimised BLAS routines. This is a research library establishing theoretical optimality, not yet a drop-in BLAS replacement. The multiplication reduction is real. The hardware is still catching up to the math.
82
+
83
+ If you are running on hardware where FLOPs are the bottleneck rather than memory bandwidth or kernel launch overhead — custom silicon, sparse compute, or future accelerators where multiplication cost is not zero — these algorithms are where you want to be.
84
+
85
+ ---
86
+
87
+ ## Correctness
88
+
89
+ All algorithms match the naive reference to floating-point accuracy (max absolute error < 1e-9 across all tested sizes). The optimality is in the *structure*, not approximation.
90
+
91
+ ```python
92
+ import khalgebra as kh
93
+
94
+ A = kh.make_sym_mat(256)
95
+ v = kh.make_vec(256)
96
+
97
+ ref = kh.naive_dsymv(A, v) # standard n² path
98
+ opt = kh.khal_dsymv(A, v) # n(n+1)/2 path
99
+
100
+ kh.max_abs_err(ref, opt)  # < 1e-9
101
+ ```
102
+
103
+ ---
104
+
105
+ ## Competitors
106
+
107
+ **BLAS DSYMV / NumPy / PyTorch / standard JAX:**
108
+ They know the matrix is symmetric. They still use n² multiplications. They have been doing this since the 1970s. This library corrects that.
109
+
110
+ **Strassen and successors:**
111
+ Attack general matrix multiplication by finding clever sub-multiplication structure. Do not specialise to symmetric structure. Interesting, but orthogonal.
112
+
113
+ **TensorFlow / cuBLAS:**
114
+ Same story as BLAS. Fast kernels. Wrong multiplication count for symmetric inputs.
115
+
116
+ None of these are wrong. They are just not reading the ayah carefully.
117
+
118
+ ---
119
+
120
+ ## Install
121
+
122
+ ```bash
123
+ pip install khalgebra
124
+ ```
125
+
126
+ Requires JAX. For GPU, install the appropriate `jax[cuda]` variant first.
127
+
128
+ ---
129
+
130
+ ## Author
131
+
132
+ Mahmood Khalil, 2025.
133
+
134
+ The name *khalgebra* is not subtle.
@@ -0,0 +1,40 @@
1
+ """
2
+ DSYMM benchmark — khal_dsymm vs naive_dsymm
3
+ Run: python bench/bench_dsymm.py
4
+ """
5
+
6
+ import time
7
+ import sys
8
+ import os
9
+ sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
10
+
11
+ import khalgebra as kh
12
+
13
+ print(f"\n{'═'*60}")
14
+ print(f" DSYMM BENCHMARK — khalgebra vs naive (n×n × n×n)")
15
+ print(f" m·n(n+1)/2 mults vs m·n² standard")
16
+ print(f"{'═'*60}\n")
17
+ print(f" {'n':<6} {'naive_ms':>10} {'khalgebra_ms':>14} {'speedup':>9}")
18
+ print(f" {'─'*42}")
19
+
20
+ for n in [16, 32, 64, 128]:
21
+ A = kh.make_sym_mat(n)
22
+ B = kh.make_gen_mat(n, n)
23
+ REPS = 500 if n <= 32 else 100 if n <= 64 else 20
24
+
25
+ kh.naive_dsymm(A, B).block_until_ready()
26
+ kh.khal_dsymm(A, B).block_until_ready()
27
+
28
+ t0 = time.perf_counter()
29
+ for _ in range(REPS):
30
+ kh.naive_dsymm(A, B).block_until_ready()
31
+ naive_ms = (time.perf_counter() - t0) / REPS * 1000
32
+
33
+ t0 = time.perf_counter()
34
+ for _ in range(REPS):
35
+ kh.khal_dsymm(A, B).block_until_ready()
36
+ khal_ms = (time.perf_counter() - t0) / REPS * 1000
37
+
38
+ print(f" {n:<6} {naive_ms:>10.3f} {khal_ms:>14.3f} {naive_ms/khal_ms:>9.3f}x")
39
+
40
+ print()
@@ -0,0 +1,50 @@
1
+ """
2
+ DSYMV benchmark — khal_dsymv vs naive_dsymv
3
+ Run: python bench/bench_dsymv.py
4
+ """
5
+
6
+ import time
7
+ import sys
8
+ import os
9
+ sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
10
+
11
+ import jax
12
+ import jax.numpy as jnp
13
+ import khalgebra as kh
14
+
15
+ REPS = 200
16
+ VS_COUNT = 50
17
+
18
+ print(f"\n{'═'*72}")
19
+ print(f" DSYMV BENCHMARK — khalgebra vs naive")
20
+ print(f" n(n+1)/2 mults vs n² standard")
21
+ print(f"{'═'*72}\n")
22
+ print(f" {'n':<6} {'naive_ms':>10} {'khalgebra_ms':>14} {'speedup':>9} {'mults_saved':>12}")
23
+ print(f" {'─'*56}")
24
+
25
+ for n in [32, 64, 128, 256, 512, 1024]:
26
+ A = kh.make_sym_mat(n)
27
+ VS = [kh.make_vec(n, seed=i+1) for i in range(VS_COUNT)]
28
+
29
+ # JIT warmup
30
+ for v in VS[:3]:
31
+ kh.naive_dsymv(A, v).block_until_ready()
32
+ kh.khal_dsymv(A, v).block_until_ready()
33
+
34
+ t0 = time.perf_counter()
35
+ for _ in range(REPS):
36
+ for v in VS:
37
+ kh.naive_dsymv(A, v).block_until_ready()
38
+ naive_ms = (time.perf_counter() - t0) / (REPS * VS_COUNT) * 1000
39
+
40
+ t0 = time.perf_counter()
41
+ for _ in range(REPS):
42
+ for v in VS:
43
+ kh.khal_dsymv(A, v).block_until_ready()
44
+ khal_ms = (time.perf_counter() - t0) / (REPS * VS_COUNT) * 1000
45
+
46
+ speedup = naive_ms / khal_ms
47
+ saved_pct = (1 - n*(n+1)/2 / (n*n)) * 100
48
+ print(f" {n:<6} {naive_ms:>10.4f} {khal_ms:>14.4f} {speedup:>9.3f}x {saved_pct:>10.0f}% fewer mults")
49
+
50
+ print()
@@ -0,0 +1,45 @@
1
+ """
2
+ Riemann contraction benchmark — khal_riemann_contract vs naive_riemann_contract
3
+ Run: python bench/bench_riemann.py
4
+ """
5
+
6
+ import time
7
+ import sys
8
+ import os
9
+ sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
10
+
11
+ import khalgebra as kh
12
+
13
+ REPS = 50_000
14
+
15
+ print(f"\n{'═'*72}")
16
+ print(f" RIEMANN CONTRACTION BENCHMARK — khalgebra vs naive")
17
+ print(f" n² mults vs n⁴ naive")
18
+ print(f"{'═'*72}\n")
19
+ print(f" {'n':<4} {'naive_μs':>10} {'optimal_μs':>12} {'speedup':>9} {'mults_saved':>12}")
20
+ print(f" {'─'*52}")
21
+
22
+ for n in [2, 3, 4, 5, 6, 10]:
23
+ R = kh.make_riemann_tensor(n)
24
+ u = kh.make_vec(n, seed=1)
25
+ v = kh.make_vec(n, seed=2)
26
+ comps = kh.build_riemann_components(n) # precomputed outside hot loop
27
+
28
+ # JIT warmup
29
+ kh.naive_riemann_contract(R, u, v).block_until_ready()
30
+ kh.khal_riemann_contract(R, u, v, comps).block_until_ready()  # kernel is jitted; warm it up
31
+
32
+ t0 = time.perf_counter()
33
+ for _ in range(REPS):
34
+ kh.naive_riemann_contract(R, u, v).block_until_ready()
35
+ t_naive = (time.perf_counter() - t0) / REPS * 1e6
36
+
37
+ t0 = time.perf_counter()
38
+ for _ in range(REPS):
39
+ kh.khal_riemann_contract(R, u, v, comps).block_until_ready()
40
+ t_opt = (time.perf_counter() - t0) / REPS * 1e6
41
+
42
+ reduction = (1 - n**2 / n**4) * 100
43
+ print(f" {n:<4} {t_naive:>10.3f} {t_opt:>12.3f} {t_naive/t_opt:>9.2f}x {reduction:>10.1f}% fewer mults")
44
+
45
+ print()
@@ -0,0 +1,36 @@
1
+ """
2
+ khalgebra — Khalil Optimal Bilinear Algorithms
3
+ Author: Mahmood Khalil (2025)
4
+ """
5
+
6
+ import jax
7
+ jax.config.update("jax_enable_x64", True)
8
+
9
+ __version__ = "0.1.0"
10
+ __author__ = "Mahmood Khalil"
11
+
12
+ from khalgebra.dsymv import khal_dsymv, naive_dsymv
13
+ from khalgebra.dsymm import khal_dsymm, naive_dsymm
14
+ from khalgebra.sym22 import khal_sym22, naive_sym22
15
+ from khalgebra.riemann import (
16
+ khal_riemann_contract,
17
+ naive_riemann_contract,
18
+ build_riemann_components,
19
+ )
20
+
21
+ from khalgebra._types import (
22
+ make_sym_mat,
23
+ make_gen_mat,
24
+ make_vec,
25
+ make_riemann_tensor,
26
+ max_abs_err,
27
+ )
28
+
29
+ __all__ = [
30
+ "khal_dsymv", "naive_dsymv",
31
+ "khal_dsymm", "naive_dsymm",
32
+ "khal_sym22", "naive_sym22",
33
+ "khal_riemann_contract", "naive_riemann_contract", "build_riemann_components",
34
+ "make_sym_mat", "make_gen_mat", "make_vec", "make_riemann_tensor", "max_abs_err",
35
+ "__version__", "__author__",
36
+ ]
@@ -0,0 +1,77 @@
1
+ """
2
+ khalgebra — Khalil Optimal Bilinear Algorithms
3
+ Author: Mahmood Khalil (2025)
4
+
5
+ Shared type aliases, test data generators, and verification helpers.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import jax.numpy as jnp
11
+ from jax import Array
12
+
13
+ Vec = Array
14
+ Mat = Array
15
+ Tensor4 = Array
16
+
17
+ def _lcg_sequence(n: int, seed: int) -> list[float]:
18
+ x = seed
19
+ out: list[float] = []
20
+ for _ in range(n):
21
+ x = (x * 1103515245 + 12345) & 0x7FFFFFFF
22
+ out.append(x / 0x7FFFFFFF * 2 - 1)
23
+ return out
24
+
25
+ def make_sym_mat(n: int, seed: int = 42) -> Mat:
26
+ flat = _lcg_sequence(n * n, seed)
27
+ A = [[0.0] * n for _ in range(n)]
28
+ idx = 0
29
+ for i in range(n):
30
+ A[i][i] = abs(flat[idx]) + 1
31
+ idx += 1
32
+ for j in range(i + 1, n):
33
+ v = flat[idx]; idx += 1
34
+ A[i][j] = v
35
+ A[j][i] = v
36
+ return jnp.array(A, dtype=jnp.float64)
37
+
38
+
39
+ def make_gen_mat(rows: int, cols: int, seed: int = 77) -> Mat:
40
+ flat = _lcg_sequence(rows * cols, seed)
41
+ return jnp.array(flat, dtype=jnp.float64).reshape(rows, cols)
42
+
43
+
44
+ def make_vec(n: int, seed: int = 1) -> Vec:
45
+ flat = _lcg_sequence(n, seed)
46
+ return jnp.array(flat, dtype=jnp.float64)
47
+
48
+
49
+ def make_riemann_tensor(n: int, seed: int = 42) -> Tensor4:
50
+ comps = build_riemann_components(n)
51
+ vals = _lcg_sequence(len(comps), seed)
52
+ R = [[[[0.0] * n for _ in range(n)] for _ in range(n)] for _ in range(n)]
53
+ for (a, b, c, d), v in zip(comps, vals):
54
+ R[a][b][c][d] = v
55
+ R[b][a][c][d] = -v
56
+ R[a][b][d][c] = -v
57
+ R[b][a][d][c] = v
58
+ if a * n + b != c * n + d:
59
+ R[c][d][a][b] = v
60
+ R[d][c][a][b] = -v
61
+ R[c][d][b][a] = -v
62
+ R[d][c][b][a] = v
63
+ return jnp.array(R, dtype=jnp.float64)
64
+
65
+ def max_abs_err(a: Array, b: Array) -> float:
66
+ return float(jnp.max(jnp.abs(a - b)))
67
+
68
+
69
+ def build_riemann_components(n: int) -> list[tuple[int, int, int, int]]:
70
+ out: list[tuple[int, int, int, int]] = []
71
+ for a in range(n):
72
+ for b in range(a + 1, n):
73
+ for c in range(n):
74
+ for d in range(c + 1, n):
75
+ if a * n + b <= c * n + d:
76
+ out.append((a, b, c, d))
77
+ return out
@@ -0,0 +1,55 @@
1
+ """
2
+ khalgebra — Khalil Optimal Bilinear Algorithms
3
+ Author: Mahmood Khalil (2025)
4
+
5
+ DSYMM: optimal symmetric matrix-matrix product.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import jax
11
+ import jax.numpy as jnp
12
+
13
+ from khalgebra._types import Mat
14
+
15
+
16
+ @jax.jit
17
+ def khal_dsymm(A: Mat, B: Mat) -> Mat:
18
+ """
19
+ Optimal symmetric matrix-matrix product.
20
+
21
+ Bilinear complexity: m·n(n+1)/2 multiplications (Khalil 2025).
22
+ Standard BLAS DSYMM uses m·n² multiplications.
23
+
24
+ Args:
25
+ A: n×n symmetric matrix, shape (n, n), float64
26
+ B: n×m general matrix, shape (n, m), float64
27
+
28
+ Returns:
29
+ C = A @ B, shape (n, m), float64
30
+ """
31
+ upper = jnp.triu(A, k=1)
32
+ diag_adj = jnp.diag(A) - jnp.sum(upper, axis=1) - jnp.sum(upper, axis=0)
33
+
34
+ def _col(v):
35
+ v_sum = v[:, None] + v[None, :]
36
+ M = upper * v_sum
37
+ w_off = jnp.sum(M, axis=1) + jnp.sum(M, axis=0)
38
+ return w_off + diag_adj * v
39
+
40
+ return jax.vmap(_col, in_axes=1, out_axes=1)(B)
41
+
42
+
43
+ @jax.jit
44
+ def naive_dsymm(A: Mat, B: Mat) -> Mat:
45
+ """
46
+ Reference symmetric matrix-matrix product.
47
+
48
+ Args:
49
+ A: n×n symmetric matrix, shape (n, n), float64
50
+ B: n×m general matrix, shape (n, m), float64
51
+
52
+ Returns:
53
+ C = A @ B, shape (n, m), float64
54
+ """
55
+ return jnp.dot(A, B)
@@ -0,0 +1,56 @@
1
+ """
2
+ khalgebra — Khalil Optimal Bilinear Algorithms
3
+ Author: Mahmood Khalil (2025)
4
+
5
+ DSYMV: optimal symmetric matrix-vector product.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import jax
11
+ import jax.numpy as jnp
12
+
13
+ from khalgebra._types import Mat, Vec
14
+
15
+
16
+ @jax.jit
17
+ def khal_dsymv(A: Mat, v: Vec) -> Vec:
18
+ """
19
+ Optimal symmetric matrix-vector product.
20
+
21
+ Bilinear complexity: n(n+1)/2 multiplications (proven optimal, Khalil 2025).
22
+ Standard BLAS DSYMV uses n² multiplications.
23
+
24
+ Args:
25
+ A: n×n symmetric matrix, shape (n, n), float64
26
+ v: length-n vector, shape (n,), float64
27
+
28
+ Returns:
29
+ w = A @ v, shape (n,), float64
30
+ """
31
+
33
+ upper = jnp.triu(A, k=1)
34
+ v_sum = v[:, None] + v[None, :]
35
+ M = upper * v_sum
36
+ w_off = jnp.sum(M, axis=1) + jnp.sum(M, axis=0)
37
+
38
+ diag_adj = jnp.diag(A) - jnp.sum(upper, axis=1) - jnp.sum(upper, axis=0)
39
+ w_diag = diag_adj * v
40
+
41
+ return w_off + w_diag
42
+
43
+
44
+ @jax.jit
45
+ def naive_dsymv(A: Mat, v: Vec) -> Vec:
46
+ """
47
+ Reference symmetric matrix-vector product.
48
+
49
+ Args:
50
+ A: n×n symmetric matrix, shape (n, n), float64
51
+ v: length-n vector, shape (n,), float64
52
+
53
+ Returns:
54
+ w = A @ v, shape (n,), float64
55
+ """
56
+ return jnp.dot(A, v)
@@ -0,0 +1,124 @@
1
+ """
2
+ khalgebra — Khalil Optimal Bilinear Algorithms
3
+ Author: Mahmood Khalil (2025)
4
+
5
+ Riemann: optimal contraction of the Riemann curvature tensor.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import functools
11
+
12
+ import jax
13
+ import jax.numpy as jnp
14
+ import numpy as np
15
+
16
+ from khalgebra._types import Mat, Vec, Tensor4, build_riemann_components
17
+
18
+ __all__ = ["khal_riemann_contract", "naive_riemann_contract", "build_riemann_components"]
19
+
20
+
21
+ @jax.jit
22
+ def naive_riemann_contract(R: Tensor4, u: Vec, v: Vec) -> Mat:
23
+ """
24
+ Reference Riemann tensor contraction.
25
+
26
+ Computes B[b,d] = Σ_{a,c} R[a,b,c,d] · u[a] · v[c]
27
+
28
+ Args:
29
+ R: Riemann tensor, shape (n, n, n, n), float64
30
+ u: contravariant vector, shape (n,), float64
31
+ v: contravariant vector, shape (n,), float64
32
+
33
+ Returns:
34
+ B: covariant rank-2 tensor, shape (n, n), float64
35
+ """
36
+ return jnp.einsum("abcd,a,c->bd", R, u, v)
37
+
38
+
39
+
40
+ def _build_index_arrays(
41
+ n: int, comps: list[tuple[int, int, int, int]]
42
+ ) -> tuple[np.ndarray, ...]:
43
+ Ra, Rb, Rc, Rd, UVr, UVc, Br, Bc, sgn = [], [], [], [], [], [], [], [], []
44
+
45
+ for a, b, c, d in comps:
46
+ for (br, bc, uvr, uvc, s) in [
47
+ (b, d, a, c, +1),
48
+ (a, d, b, c, -1),
49
+ (b, c, a, d, -1),
50
+ (a, c, b, d, +1),
51
+ ]:
52
+ Ra.append(a); Rb.append(b); Rc.append(c); Rd.append(d)
53
+ UVr.append(uvr); UVc.append(uvc)
54
+ Br.append(br); Bc.append(bc); sgn.append(s)
55
+ if a * n + b != c * n + d:
56
+ for (br, bc, uvr, uvc, s) in [
57
+ (d, b, c, a, +1),
58
+ (c, b, d, a, -1),
59
+ (d, a, c, b, -1),
60
+ (c, a, d, b, +1),
61
+ ]:
62
+ Ra.append(a); Rb.append(b); Rc.append(c); Rd.append(d)
63
+ UVr.append(uvr); UVc.append(uvc)
64
+ Br.append(br); Bc.append(bc); sgn.append(s)
65
+
66
+ return (
67
+ np.array(Ra), np.array(Rb), np.array(Rc), np.array(Rd),
68
+ np.array(UVr), np.array(UVc),
69
+ np.array(Br), np.array(Bc),
70
+ np.array(sgn, dtype=np.float64),
71
+ )
72
+
73
+
74
+ @functools.cache
75
+ def _get_kernel(n: int, comps_key: tuple[tuple[int, int, int, int], ...]):
76
+ comps = list(comps_key)
77
+ Ra, Rb, Rc, Rd, UVr, UVc, Br, Bc, signs = _build_index_arrays(n, comps)
78
+
79
+ j_Ra = jnp.array(Ra); j_Rb = jnp.array(Rb)
80
+ j_Rc = jnp.array(Rc); j_Rd = jnp.array(Rd)
81
+ j_UVr = jnp.array(UVr); j_UVc = jnp.array(UVc)
82
+ j_Br = jnp.array(Br); j_Bc = jnp.array(Bc)
83
+ j_signs = jnp.array(signs)
84
+
85
+ @jax.jit
86
+ def _kernel(R: Tensor4, u: Vec, v: Vec) -> Mat:
87
+ UV = jnp.outer(u, v)
88
+ R_vals = R[j_Ra, j_Rb, j_Rc, j_Rd]
89
+ UV_vals = UV[j_UVr, j_UVc]
90
+ contribs = j_signs * R_vals * UV_vals
91
+ B = jnp.zeros((n, n), dtype=R.dtype)
92
+ return B.at[j_Br, j_Bc].add(contribs)
93
+
94
+ return _kernel
95
+
96
+
97
+ def khal_riemann_contract(
98
+ R: Tensor4,
99
+ u: Vec,
100
+ v: Vec,
101
+ components: list[tuple[int, int, int, int]] | None = None,
102
+ ) -> Mat:
103
+ """
104
+ Optimal Riemann tensor contraction.
105
+
106
+ Computes B[b,d] = Σ_{a,c} R[a,b,c,d] · u[a] · v[c]
107
+
108
+ Bilinear complexity: n² multiplications (proven optimal, Khalil 2025).
109
+ Naive jnp.einsum uses n⁴ multiplications.
110
+
111
+ Args:
112
+ R: Riemann tensor, shape (n, n, n, n), float64
113
+ u: contravariant vector, shape (n,), float64
114
+ v: contravariant vector, shape (n,), float64
115
+ components: optional precomputed list from build_riemann_components(n).
116
+ Pass in hot loops; omit to compute on first call and cache.
117
+
118
+ Returns:
119
+ B: covariant rank-2 tensor, shape (n, n), float64
120
+ """
121
+ n = R.shape[0]
122
+ comps = components if components is not None else build_riemann_components(n)
123
+ kernel = _get_kernel(n, tuple(comps))
124
+ return kernel(R, u, v)
@@ -0,0 +1,62 @@
1
+ """
2
+ khalgebra — Khalil Optimal Bilinear Algorithms
3
+ Author: Mahmood Khalil (2025)
4
+
5
+ Sym22: optimal 2×2 symmetric × 2×2 general matrix product.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import jax
11
+ import jax.numpy as jnp
12
+
13
+
14
+ @jax.jit
15
+ def khal_sym22(
16
+ a: float, b: float, c: float,
17
+ p: float, q: float, r: float, s: float,
18
+ ) -> tuple[float, float, float, float]:
19
+ """
20
+ Optimal 2×2 symmetric × 2×2 general matrix product.
21
+
22
+ Bilinear complexity: 6 scalar multiplications (proven optimal, Khalil 2025).
23
+ Naive requires 8.
24
+
25
+ Args:
26
+ a: A[0,0]
27
+ b: A[0,1] = A[1,0]
28
+ c: A[1,1]
29
+ p: B[0,0]
30
+ q: B[0,1]
31
+ r: B[1,0]
32
+ s: B[1,1]
33
+
34
+ Returns:
35
+ (C[0,0], C[0,1], C[1,0], C[1,1]) where C = A·B
36
+ """
37
+ ab = a - b
38
+ cb = c - b
39
+ M1 = b * (p + r)
40
+ M2 = ab * p
41
+ M3 = cb * r
42
+ M4 = b * (q + s)
43
+ M5 = ab * q
44
+ M6 = cb * s
45
+ return (M1 + M2, M4 + M5, M1 + M3, M4 + M6)
46
+
47
+
48
+ def naive_sym22(
49
+ a: float, b: float, c: float,
50
+ p: float, q: float, r: float, s: float,
51
+ ) -> tuple[float, float, float, float]:
52
+ """
53
+ Reference 2×2 symmetric × 2×2 general matrix product.
54
+
55
+ Args:
56
+ a, b, c: upper triangle of symmetric A: [[a,b],[b,c]]
57
+ p, q, r, s: entries of general B: [[p,q],[r,s]]
58
+
59
+ Returns:
60
+ (C[0,0], C[0,1], C[1,0], C[1,1]) where C = A·B
61
+ """
62
+ return (a * p + b * r, a * q + b * s, b * p + c * r, b * q + c * s)
@@ -0,0 +1,20 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "khalgebra"
7
+ version = "0.1.0"
8
+ description = "Provably-optimal bilinear algorithms for symmetric linear algebra"
9
+ authors = [{name = "Mahmood Khalil"}]
10
+ license = {text = "MIT"}
11
+ readme = "README.md"
12
+ requires-python = ">=3.10"
13
+ dependencies = ["jax>=0.4", "jaxlib>=0.4"]
14
+ keywords = ["linear-algebra", "blas", "dsymv", "riemann", "bilinear", "optimal", "jax"]
15
+
16
+ [project.optional-dependencies]
17
+ dev = ["pytest>=8", "numpy"]
18
+
19
+ [project.urls]
20
+ Homepage = "https://github.com/khalil/khalgebra"
@@ -0,0 +1,16 @@
1
+ import pytest
2
+ import khalgebra as kh
3
+
4
+
5
+ @pytest.mark.parametrize("n,m", [
6
+ (2, 1), (3, 3), (8, 8),
7
+ (4, 4), (16, 1), (32, 8),
8
+ (64, 8), (128, 8),
9
+ ])
10
+ def test_dsymm_correctness(n, m):
11
+ A = kh.make_sym_mat(n)
12
+ B = kh.make_gen_mat(n, m)
13
+ ref = kh.naive_dsymm(A, B)
14
+ opt = kh.khal_dsymm(A, B)
15
+ err = kh.max_abs_err(ref, opt)
16
+ assert err < 1e-9, f"n={n} m={m}: err={err:.2e}"
@@ -0,0 +1,27 @@
1
+ import pytest
2
+ import jax.numpy as jnp
3
+ import khalgebra as kh
4
+
5
+
6
+ @pytest.mark.parametrize("n", [2, 3, 4, 8, 16, 32, 64, 100, 256])
7
+ def test_dsymv_correctness(n):
8
+ A = kh.make_sym_mat(n)
9
+ v = kh.make_vec(n)
10
+ ref = kh.naive_dsymv(A, v)
11
+ opt = kh.khal_dsymv(A, v)
12
+ assert kh.max_abs_err(ref, opt) < 1e-9, f"n={n}: err={kh.max_abs_err(ref, opt):.2e}"
13
+
14
+
15
+ def test_dsymv_n3_hand_verified():
16
+ A = jnp.array([[3.0, 1.5, -0.5],
17
+ [1.5, 2.0, 0.7],
18
+ [-0.5, 0.7, 4.0]])
19
+ v = jnp.array([2.0, -1.0, 3.0])
20
+
21
+ ref = kh.naive_dsymv(A, v)
22
+ opt = kh.khal_dsymv(A, v)
23
+
24
+ assert abs(float(ref[0]) - 3.0) < 1e-10
25
+ assert abs(float(ref[1]) - 3.1) < 1e-10
26
+ assert abs(float(ref[2]) - 10.3) < 1e-10
27
+ assert kh.max_abs_err(ref, opt) < 1e-9
@@ -0,0 +1,47 @@
1
+ import pytest
2
+ import jax.numpy as jnp
3
+ import khalgebra as kh
4
+
5
+
6
+ @pytest.mark.parametrize("n", [2, 3, 4, 5, 6])
7
+ def test_riemann_correctness(n):
8
+ R = kh.make_riemann_tensor(n)
9
+ u = kh.make_vec(n, seed=1)
10
+ v = kh.make_vec(n, seed=2)
11
+
12
+ ref = kh.naive_riemann_contract(R, u, v)
13
+ opt = kh.khal_riemann_contract(R, u, v)
14
+
15
+ r_max = float(jnp.max(jnp.abs(R)))
16
+ assert r_max > 0.01, f"n={n}: tensor is degenerate (|R|_max={r_max:.4f})"
17
+
18
+ err = kh.max_abs_err(ref, opt)
19
+ assert err < 1e-12, f"n={n}: err={err:.2e}"
20
+
21
+
22
+ @pytest.mark.parametrize("n", [2, 3, 4, 5, 6])
23
+ def test_riemann_precomputed_components(n):
24
+ R = kh.make_riemann_tensor(n)
25
+ u = kh.make_vec(n, seed=1)
26
+ v = kh.make_vec(n, seed=2)
27
+ comps = kh.build_riemann_components(n)
28
+
29
+ with_comps = kh.khal_riemann_contract(R, u, v, comps)
30
+ without_comps = kh.khal_riemann_contract(R, u, v)
31
+
32
+ assert kh.max_abs_err(with_comps, without_comps) < 1e-15
33
+
34
+
35
+ def test_riemann_n4_spot_check():
36
+ """Spot-check three specific (b,d) entries for n=4."""
37
+ n = 4
38
+ R = kh.make_riemann_tensor(n)
39
+ u = kh.make_vec(n, seed=1)
40
+ v = kh.make_vec(n, seed=2)
41
+
42
+ ref = kh.naive_riemann_contract(R, u, v)
43
+ opt = kh.khal_riemann_contract(R, u, v)
44
+
45
+ for b, d in [(0, 1), (2, 3), (1, 3)]:
46
+ err = abs(float(opt[b, d]) - float(ref[b, d]))
47
+ assert err < 1e-12, f"(b,d)=({b},{d}): err={err:.2e}"
@@ -0,0 +1,22 @@
1
+ import khalgebra as kh
2
+ from khalgebra._types import _lcg_sequence
3
+
4
+
5
+ def test_sym22_fixed_input():
6
+ """A=[[1,2],[2,3]], B=[[4,5],[6,7]] → C=[[16,19],[26,31]]"""
7
+ c00, c01, c10, c11 = kh.khal_sym22(1, 2, 3, 4, 5, 6, 7)
8
+ assert float(c00) == 16.0
9
+ assert float(c01) == 19.0
10
+ assert float(c10) == 26.0
11
+ assert float(c11) == 31.0
12
+
13
+
14
+ def test_sym22_50_random_inputs():
15
+ vals = _lcg_sequence(7 * 50, seed=99)
16
+ for i in range(50):
17
+ a, b, c, p, q, r, s = vals[i*7:(i+1)*7]
18
+ ref = kh.naive_sym22(a, b, c, p, q, r, s)
19
+ opt = kh.khal_sym22(a, b, c, p, q, r, s)
20
+ for k in range(4):
21
+ err = abs(float(opt[k]) - float(ref[k]))
22
+ assert err < 1e-13, f"input {i} entry {k}: err={err:.2e}"