xagm 0.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xagm-0.1.1/PKG-INFO +18 -0
- xagm-0.1.1/README.md +10 -0
- xagm-0.1.1/pyproject.toml +18 -0
- xagm-0.1.1/setup.cfg +4 -0
- xagm-0.1.1/src/xagm/__init__.py +0 -0
- xagm-0.1.1/src/xagm/basis/__init__.py +4 -0
- xagm-0.1.1/src/xagm/basis/linear.py +33 -0
- xagm-0.1.1/src/xagm/basis/metrics.py +67 -0
- xagm-0.1.1/src/xagm/geoutils.py +27 -0
- xagm-0.1.1/src/xagm/manifolds/__init__.py +7 -0
- xagm-0.1.1/src/xagm/manifolds/calc.py +103 -0
- xagm-0.1.1/src/xagm/manifolds/vectors.py +56 -0
- xagm-0.1.1/src/xagm.egg-info/PKG-INFO +18 -0
- xagm-0.1.1/src/xagm.egg-info/SOURCES.txt +15 -0
- xagm-0.1.1/src/xagm.egg-info/dependency_links.txt +1 -0
- xagm-0.1.1/src/xagm.egg-info/requires.txt +3 -0
- xagm-0.1.1/src/xagm.egg-info/top_level.txt +1 -0
xagm-0.1.1/PKG-INFO
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: xagm
|
|
3
|
+
Version: 0.1.1
|
|
4
|
+
Description-Content-Type: text/markdown
|
|
5
|
+
Requires-Dist: jax
|
|
6
|
+
Requires-Dist: jaxlib
|
|
7
|
+
Requires-Dist: jaxtyping
|
|
8
|
+
|
|
9
|
+
XAGM is a Riemannian Differentiable Geometry engine which stands for Accelerated Autodiff Geometry Multi-dimensional. It deals exclusively in Riemannian SPD metrics, and it is MANDATORY the metrics are Symmetric Positive Definite (SPD) for it to work.
|
|
10
|
+
It offers a vast array of functions, with 4 modules to call upon, them being metrics, linear, vectors, and calc. Vectors deal mainly with linear algebra adjacent functions with respect to the metric tensor. Speaking of the metric tensor, XAGM allows you to use fwdmet to create a pullback metric.
|
|
11
|
+
|
|
12
|
+
The crown jewels of XAGM would be christoffel(), geoexp_solver(), geolog_solver(), and geodist(), with geoexp_solver consistently performing at sub millisecond speeds, and geolog_solver being in the comfortable range of 2-20ms each run depending on how many steps are given to the solver.
|
|
13
|
+
|
|
14
|
+
XAGM has been benchmarked (quite unofficially so you are free to do your own runtime checks) and observed to outperform basically every other geometry application in numpy and the dominating Geometry powerhouses. You are highly encouraged, however, to confirm that yourself too.
|
|
15
|
+
|
|
16
|
+
XAGM is a bit hard to use at first since it expects a decent background in maths for most of the functions and a clear understanding of how to use JAX native functions like vmap and jit along with static_argnums and static_argnames, but, overall, if you behave nicely and pass clean arrays into it, it will reward you. Documentation on this project will be coming soon! (or never at all. No in between.)
|
|
17
|
+
|
|
18
|
+
|
xagm-0.1.1/README.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
XAGM, which stands for Accelerated Autodiff Geometry Multi-dimensional, is a Riemannian differential geometry engine. It works exclusively with Riemannian metrics, and those metrics MUST be symmetric positive definite (SPD) for it to work.
|
|
2
|
+
It offers a wide range of functions across four modules: metrics, linear, vectors, and calc. The vectors module mainly provides linear-algebra operations (projections, rejections, normalization and orthonormalization) with respect to the metric tensor. For the metric tensor itself, XAGM provides fwdmet to build a pullback metric from a map.
|
|
3
|
+
|
|
4
|
+
The crown jewels of XAGM are christoffel(), geoexp_solver(), geolog_solver(), and geodist(). geoexp_solver consistently runs in under a millisecond, and geolog_solver typically takes 2–20 ms per run, depending on how many steps are given to the solver.
|
|
5
|
+
|
|
6
|
+
XAGM has been benchmarked (quite unofficially, so feel free to run your own timing checks) and observed to outperform basically every other NumPy-based geometry tool, as well as the dominant geometry libraries. You are nevertheless strongly encouraged to confirm this yourself.
|
|
7
|
+
|
|
8
|
+
XAGM can be hard to use at first. Most functions assume a solid mathematical background and a clear understanding of JAX-native transformations such as vmap and jit, including static_argnums and static_argnames. Overall, though, if you behave nicely and pass clean arrays into it, it will reward you. Documentation for this project is coming soon! (Or never at all. No in between.)
|
|
9
|
+
|
|
10
|
+
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=61.0"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "xagm"
|
|
7
|
+
version = "0.1.1"
|
|
8
|
+
dependencies = [
    "jax", "jaxlib", "jaxtyping", "diffrax"
]
|
|
11
|
+
readme = "README.md"
|
|
12
|
+
# Runtime dependencies are declared in the `dependencies` array above.
|
|
13
|
+
|
|
14
|
+
[tool.setuptools.packages.find]
|
|
15
|
+
# This tells setuptools to look for your code inside the 'src' folder
|
|
16
|
+
where = ["src"]
|
|
17
|
+
|
|
18
|
+
[tool.setuptools]
|
xagm-0.1.1/setup.cfg
ADDED
|
File without changes
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import geoutils as us
|
|
2
|
+
import jax
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
from geoutils import Vector, Matrix, Scalar, Tensor, JAXArray
|
|
5
|
+
from basis import metrics as mtc
|
|
6
|
+
|
|
7
|
+
def grid(idx: JAXArray, dimens: tuple):
    """Map flat index/indices onto grid coordinates in reversed axis order.

    ``jnp.unravel_index`` yields one coordinate array per axis of ``dimens``
    (row-major order). Those arrays are reversed and stacked along a new
    trailing axis, so for ``dimens=(rows, cols)`` each result is ``(col, row)``.

    Args:
        idx: Flat index, or array of flat indices, into a grid of shape ``dimens``.
        dimens: Grid shape.

    Returns:
        Array of shape ``idx.shape + (len(dimens),)``.
    """
    per_axis = jnp.unravel_index(idx, dimens)
    flipped = tuple(reversed(per_axis))
    return jnp.stack(flipped, axis=-1)
|
|
13
|
+
|
|
14
|
+
# NOTE: plain module-level tuple, not a decorator -- it is NOT applied to
# ``line``. ``segs`` fixes the output shape, so it must be static when the
# caller jits, e.g. ``jax.jit(line, static_argnums=(2,))``.
static_argnums = (2,)


def line(p1: Vector, p2: Vector, segs: int) -> Matrix:
    """Return ``segs`` evenly spaced points on the straight segment ``p1 -> p2``.

    Args:
        p1: Start point.
        p2: End point (same shape as ``p1``).
        segs: Number of sample points (the segment is cut into ``segs - 1``
            pieces); ``segs == 1`` returns just ``p1``.

    Returns:
        Array of shape ``(segs,) + p1.shape`` whose first row is exactly
        ``p1`` and whose last row is exactly ``p2``.
    """
    t = jnp.linspace(0, 1, segs)[:, jnp.newaxis]
    # (1 - t) * p1 + t * p2 hits both endpoints exactly (t is exactly 0 and 1
    # there); the former p1 + t * (p2 - p1) could miss p2 through rounding
    # of (p2 - p1).
    return (1.0 - t) * p1 + t * p2
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def ang(g: Matrix, u: Vector, v: Vector) -> Scalar:
    """Return the angle (radians, in ``[0, pi]``) between ``u`` and ``v`` under metric ``g``.

    Uses Kahan's half-angle form on the g-unit vectors::

        theta = 2 * atan2(|u_hat - v_hat|_g, |u_hat + v_hat|_g)

    instead of ``arccos(<u, v>_g / (|u|_g |v|_g))``. ``arccos`` has an infinite
    derivative at +/-1. The previous workaround clipped the cosine to
    ``+/-(1 - 1e-8)``, which biased (anti)parallel inputs by about
    ``sqrt(2e-8) ~ 1.4e-4`` rad. ``mtc.norm`` floors its squared argument, so
    both ``atan2`` operands stay positive and gradients stay finite.

    A zero ``u`` or ``v`` still gives ``pi / 2``, as before.
    """
    # g-normalise; mtc.norm never returns 0, so plain division is safe.
    # The trailing axis lets batched inputs of shape (..., N) broadcast.
    u_hat = u / mtc.norm(g, u)[..., None]
    v_hat = v / mtc.norm(g, v)[..., None]

    # For g-unit vectors: |u_hat - v_hat| = 2 sin(theta/2),
    #                     |u_hat + v_hat| = 2 cos(theta/2).
    half_sin = mtc.norm(g, u_hat - v_hat)
    half_cos = mtc.norm(g, u_hat + v_hat)

    return 2.0 * jnp.arctan2(half_sin, half_cos)
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
import geoutils as us
|
|
2
|
+
import jax
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
from geoutils import Vector, Matrix, Scalar, Tensor, JAXArray
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def euclid(x: Vector) -> Matrix:
    """Return the flat Euclidean metric (identity matrix) for points like ``x``.

    Only the size of the last axis of ``x`` is used; its values are ignored.
    """
    n = x.shape[-1]
    return jnp.identity(n)
|
|
9
|
+
|
|
10
|
+
def iprod(g: Matrix, u: Vector|Matrix, v: Vector|Matrix) -> Vector:
    """Return the metric inner product ``u^T g v`` taken over the last axis.

    Leading axes broadcast (einsum ``...``): batches of vectors and/or metrics
    give one inner product per batch element, and a single pair of 1-D
    vectors gives a 0-D result.
    """
    return jnp.einsum('...a,...ab,...b->...', u, g, v)
|
|
12
|
+
|
|
13
|
+
def norm(g: Matrix, u: Vector) -> Scalar:
    """Return the length ``sqrt(u^T g u)`` of ``u`` under metric ``g``.

    The squared length is floored at 1e-16 before the square root. The result
    is therefore never below 1e-8, and the gradient stays finite at ``u = 0``.
    """
    sq_len = iprod(g, u, u)
    return jnp.sqrt(jnp.maximum(sq_len, 1e-16))
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# NOTE: plain module-level tuple, not a decorator -- it is NOT applied to
# ``fwdmet``. ``f`` is a Python callable, so it must be static when the
# caller jits, e.g. ``jax.jit(fwdmet, static_argnums=(0,))``.
static_argnums = (0,)


def fwdmet(f, v: Vector) -> Matrix:
    """Return the pullback metric ``J^T J`` of ``f`` at ``v`` (forward-mode Jacobian).

    Args:
        f: Map from the parameter space to an output of any shape.
        v: 1-D point at which the metric is evaluated.

    Returns:
        ``(N, N)`` matrix with ``N = v.shape[-1]``. All output axes of ``f``
        are flattened into one before contracting.
    """
    jac = jax.jacfwd(f)(v)
    rows = jac.reshape(-1, v.shape[-1])
    return jnp.einsum('ra,rb->ab', rows, rows)
|
|
22
|
+
|
|
23
|
+
# NOTE: plain module-level tuple, not a decorator -- it is NOT applied to
# ``revmet``. ``f`` is a Python callable, so it must be static when the
# caller jits, e.g. ``jax.jit(revmet, static_argnums=(0,))``.
static_argnums = (0,)


def revmet(f, v: Vector) -> Matrix:
    """Return the pullback metric ``J^T J`` of ``f`` at ``v`` (reverse-mode Jacobian).

    Same result as ``fwdmet`` but uses ``jax.jacrev``. That is usually
    preferable when ``f`` has far fewer outputs than inputs.

    Args:
        f: Map from the parameter space to an output of any shape.
        v: 1-D point at which the metric is evaluated.

    Returns:
        ``(N, N)`` matrix with ``N = v.shape[-1]``.
    """
    jac = jax.jacrev(f)(v)
    rows = jac.reshape(-1, v.shape[-1])
    return jnp.einsum('ra,rb->ab', rows, rows)
|
|
28
|
+
|
|
29
|
+
def metinv(g: Matrix) -> Matrix:
    """Return the inverse metric ``g^{-1}`` via a symmetric eigendecomposition.

    Eigenvalues are floored at 1e-12 before taking reciprocals. A
    (near-)singular input therefore yields a bounded, regularised inverse
    instead of inf/nan. Assumes ``g`` is symmetric, as ``eigh`` requires.
    """
    eigvals, eigvecs = jnp.linalg.eigh(g)
    recip = us.div(1.0, jnp.maximum(eigvals, 1e-12))
    # Reassemble V diag(1 / lambda) V^T.
    return jnp.einsum('ik,k,jk->ij', eigvecs, recip, eigvecs)
|
|
35
|
+
|
|
36
|
+
def metinterp(g0: Matrix, v0: Vector,
              g1: Matrix, v1: Vector,
              target: Vector) -> Matrix:
    """Log-Euclidean interpolation between metric ``g0`` at ``v0`` and ``g1`` at ``v1``.

    ``target`` is projected onto the segment ``v0 -> v1`` to get a parameter
    ``t`` clipped to ``[0, 1]``. The matrix logs of the two metrics are then
    blended linearly and mapped back with the matrix exponential.

    Eigenvalues are floored at 1e-7 before the log. If ``v0 == v1`` the
    projection denominator is 0, ``us.div`` returns 0, and the result is
    (up to that floor) ``g0``.
    """
    def _sym_log(mat):
        # Matrix log of a symmetric matrix: V diag(log(max(lambda, 1e-7))) V^T.
        w, q = jnp.linalg.eigh(mat)
        return jnp.einsum('ik, k, jk -> ij', q, jnp.log(jnp.maximum(w, 1e-7)), q)

    log_g0 = _sym_log(g0)
    log_g1 = _sym_log(g1)

    seg = v1 - v0
    offset = target - v0
    t = jnp.clip(us.div(jnp.dot(offset, seg), jnp.dot(seg, seg)), 0.0, 1.0)

    blended = (1.0 - t) * log_g0 + (t * log_g1)

    # Matrix exponential of the (symmetric) blend.
    w, q = jnp.linalg.eigh(blended)
    return jnp.einsum('ik, k, jk -> ij', q, jnp.exp(w), q)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
from jax import config
|
|
3
|
+
|
|
4
|
+
import os
|
|
5
|
+
# Ask XLA not to preallocate a large share of accelerator memory up front.
# NOTE(review): the variable is read when JAX initialises its backend, so this
# only takes effect if no JAX computation ran before this module is imported.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
|
|
6
|
+
# Enable 64-bit floats in JAX globally; importing this module changes the
# setting for the whole process. The Float64 aliases declared below assume it.
config.update("jax_enable_x64", True)
|
|
7
|
+
|
|
8
|
+
import jax
|
|
9
|
+
import jax.numpy as jnp
|
|
10
|
+
from jaxtyping import Float64, Array
|
|
11
|
+
|
|
12
|
+
@jax.jit
def div(a, b):
    """Divide ``a`` by ``b`` elementwise, returning 0 wherever ``b == 0``.

    Zero denominators are replaced by 1 *before* dividing, so no inf/nan is
    ever formed. This keeps gradients finite too (the "double where" pattern):
    the masked-out branch is still differentiated, so it must not produce
    inf/nan.
    """
    nonzero = b != 0
    quotient = a / jnp.where(nonzero, b, 1.0)
    return jnp.where(nonzero, quotient, 0.0)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Small tolerance constant exposed by this module. Nothing visible in this
# module references it; the other modules use literal epsilons instead.
eps = 1e-15
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
type Scalar = Float64[Array, ""] # 0D [1]
|
|
24
|
+
type Vector = Float64[Array, "N"] # 1D [1, 2]
|
|
25
|
+
type Matrix = Float64[Array, "M N"] # 2D [[1, 2], [2, 1]]
|
|
26
|
+
type Tensor = Float64[Array, "*batch M N O"] # 3-D+
|
|
27
|
+
type JAXArray = Float64[Array]
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
import geoutils as us
|
|
2
|
+
import jax
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
from geoutils import Vector, Matrix, Scalar, Tensor, JAXArray
|
|
5
|
+
|
|
6
|
+
from basis import metrics as mtc
|
|
7
|
+
|
|
8
|
+
def christoffel(func, x: Vector) -> Tensor:
    """Return the Christoffel symbols of the pullback metric of ``func`` at ``x``.

    Computes ``gamma[k, i, j] = 0.5 * g^{kl} (d_i g_lj + d_j g_li - d_l g_ij)``
    with ``g = mtc.fwdmet(func, x)``. A single ``jax.linearize`` call yields
    both the metric and its directional derivatives, so the metric is
    evaluated once rather than twice.

    Args:
        func: Map whose pullback metric defines the geometry.
        x: 1-D point of dimension ``n``.

    Returns:
        Array of shape ``(n, n, n)`` indexed ``[k, i, j]``. (Rank 3; the
        previous ``Matrix`` return annotation was wrong.)
    """
    metric_fn = lambda v: mtc.fwdmet(func, v)

    # g: metric at x; g_jvp(e) = directional derivative of the metric along e.
    g, g_jvp = jax.linearize(metric_fn, x)
    ginv = mtc.metinv(g)

    # Push every coordinate basis vector (same dtype as x) through the
    # linearised map: dg_raw[k, i, j] = d_k g_ij  ->  dg[i, j, k] = d_k g_ij.
    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    dg = jnp.moveaxis(jax.vmap(g_jvp)(basis), 0, -1)

    term1 = jnp.transpose(dg, axes=[1, 2, 0])  # [l, i, j] = d_i g_lj
    term2 = dg                                  # [l, i, j] = d_j g_li
    term3 = jnp.transpose(dg, axes=[2, 0, 1])  # [l, i, j] = d_l g_ij

    return jnp.einsum('kl, lij -> kij', 0.5 * ginv, term1 + term2 - term3)
|
|
27
|
+
|
|
28
|
+
import diffrax
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def geoexp_term(t, state, args) -> Vector:
    """Right-hand side of the geodesic + parallel-transport ODE, for diffrax.

    ``state`` packs ``[x, v, y]`` (position, velocity, transported vector),
    each of length ``state.shape[0] // 3``. The return value is::

        [dx/dt, dv/dt, dy/dt] = [v, -Gamma^k_ij v^i v^j, -Gamma^k_ij v^i y^j]

    where Gamma holds the Christoffel symbols of the pullback metric of
    ``args['func']``. ``t`` is unused because the system is autonomous.
    """
    n = state.shape[0] // 3
    pos = state[:n]
    vel = state[n:2 * n]
    carried = state[2 * n:]

    gamma = christoffel(args['func'], pos)

    accel = -jnp.einsum('kij, i, j -> k', gamma, vel, vel)
    transport = -jnp.einsum('kij, i, j -> k', gamma, vel, carried)

    return jnp.concatenate([vel, accel, transport])
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def geoexp_solver(
    p: Vector, v: Vector, mapped_func, vt: Vector, steps=4096
) -> tuple[Vector, Vector, Vector]:
    """Evaluate the exponential map ``exp_p(v)`` and parallel-transport ``vt``.

    Integrates the geodesic equation for the pullback metric of
    ``mapped_func`` from ``t = 0`` to ``t = 1``, starting at ``p`` with
    velocity ``v``. It transports ``vt`` along the same curve at the same
    time. Uses Tsit5 with an adaptive PID step-size controller.

    Args:
        p: 1-D start point.
        v: Initial velocity at ``p``.
        mapped_func: Map whose pullback metric defines the geometry.
        vt: Vector at ``p`` to parallel-transport along the geodesic.
        steps: ``max_steps`` budget for the adaptive solver.

    Returns:
        ``(final_pos, final_vel, transported_v)`` at ``t = 1``. The previous
        annotation said ``Vector``, but a 3-tuple is returned.

    Note:
        ``throw=False`` means a solver failure, such as exhausting
        ``max_steps``, does not raise. diffrax reports it only through
        ``solution.result``, which is discarded here, so the returned values
        may not be a converged solution.
    """
    state = jnp.concatenate([p, v, vt])

    solution = diffrax.diffeqsolve(
        terms=diffrax.ODETerm(geoexp_term),
        solver=diffrax.Tsit5(),
        t0=0,
        t1=1,
        dt0=1e-2,
        y0=state,
        args={'func': mapped_func},
        stepsize_controller=diffrax.PIDController(rtol=1e-8, atol=1e-10),
        saveat=diffrax.SaveAt(t1=True),
        adjoint=diffrax.DirectAdjoint(),
        max_steps=steps,
        throw=False,
    )

    # SaveAt(t1=True) stores a single time point.
    result = solution.ys[0]

    dim = p.shape[0]
    final_pos = result[:dim]
    final_vel = result[dim:2 * dim]
    transported_v = result[2 * dim:]

    return final_pos, final_vel, transported_v
|
|
76
|
+
|
|
77
|
+
def geolog_solver(p: Vector, q: Vector, mapped_func, steps: int) -> Vector:
    """Approximate the logarithm map ``log_p(q)`` by shooting.

    Looks for a velocity ``v`` at ``p`` with ``exp_p(v) == q``, starting from
    the Euclidean guess ``q - p``.

    The Jacobian of the shooting map is computed ONCE, at that initial guess,
    and reused in every update. This is a chord (simplified Newton) iteration:
    one Jacobian evaluation instead of one per step, at the cost of linear
    rather than quadratic convergence. Exactly ``steps`` updates are
    performed; there is no convergence test.
    """
    def shoot(v_guess):
        # Endpoint of the geodesic from p with initial velocity v_guess.
        endpoint, _, _ = geoexp_solver(p, v_guess, mapped_func, jnp.zeros_like(p))
        return endpoint

    v_init = q - p
    jac = jax.jacobian(shoot)(v_init)

    def chord_step(_, v_cur):
        residual = shoot(v_cur) - q
        return v_cur - jnp.linalg.solve(jac, residual)

    return jax.lax.fori_loop(0, steps, chord_step, v_init)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def geodist(p: Vector, q: Vector, mapped_func, steps: int) -> Scalar:
    """Return the geodesic distance from ``p`` to ``q``.

    Computed as the length of ``log_p(q)`` (from ``geolog_solver``) under the
    pullback metric at ``p``. Geodesics have constant speed, so this length is
    the curve's length over ``t in [0, 1]``. Because ``mtc.norm`` floors its
    squared argument, ``geodist(p, p)`` is 1e-8 rather than 0.
    """
    tangent = geolog_solver(p, q, mapped_func, steps)
    return mtc.norm(mtc.fwdmet(mapped_func, p), tangent)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
|
|
2
|
+
import geoutils as us
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
from geoutils import Vector, Matrix, Scalar, Tensor, JAXArray
|
|
6
|
+
|
|
7
|
+
from basis import metrics as mtc
|
|
8
|
+
|
|
9
|
+
def nrml(g: Matrix, basis: Matrix) -> Matrix:
    """Orthonormalise the rows of ``basis`` with respect to metric ``g``.

    Method:
        1. Factor ``g = L^T L`` with ``L = diag(sqrt(lambda)) V^T``, using
           eigenvalues clamped at 0.
        2. Map the rows into whitened coordinates and QR-orthonormalise them
           there (Gram-Schmidt up to per-vector signs).
        3. Map back with ``L^{-1} = V diag(1/sqrt(lambda))``.
        4. Flip the sign of the first output row when
           ``det(ortho @ L^T)`` is not positive, so the frame is positively
           oriented in whitened coordinates.

    The rows of the result are g-orthonormal.

    NOTE(review): the determinant step needs a square ``(n, n)`` basis. Zero
    eigenvalues make ``us.div`` return 0 in ``L^{-1}``, giving a degenerate
    frame.
    """
    eigvals, eigvecs = jnp.linalg.eigh(g)
    eigvals = jnp.maximum(eigvals, 0.0)
    root = jnp.sqrt(eigvals)

    # L = diag(sqrt(lambda)) V^T, so that L^T L = g.
    factor = root[:, None] * eigvecs.T
    whitened = basis @ factor.T

    q, _ = jnp.linalg.qr(whitened.T)
    # L^{-1} = V diag(1 / sqrt(lambda)); us.div maps a zero root to 0.
    factor_inv = us.div(eigvecs, root)

    frame = q.T @ factor_inv.T

    positive = jnp.linalg.det(frame @ factor.T) > 0
    sign = jnp.where(positive, 1.0, -1.0)
    return frame.at[0, :].multiply(sign)
|
|
27
|
+
|
|
28
|
+
#dot product territory
|
|
29
|
+
|
|
30
|
+
def scalproj(g: Matrix, a: Vector, b: Vector) -> Scalar:
    """Return the scalar projection of ``a`` onto ``b``: ``<a, b>_g / |b|_g``.

    ``mtc.norm`` never returns 0, so ``us.div`` always takes its divide branch.
    """
    numerator = mtc.iprod(g, a, b)
    return us.div(numerator, mtc.norm(g, b))
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def vectproj(g: Matrix, a: Vector, b: Vector) -> Vector:
    """Return the g-orthogonal projection of ``a`` onto ``b``.

    The projection is ``(<a, b>_g / <b, b>_g) * b``. A zero ``b`` gives a
    zero projection, because ``us.div`` returns 0 on a zero denominator.

    The coefficient gets a trailing axis before scaling ``b``. Batched inputs
    of shape ``(..., N)`` therefore broadcast correctly; previously they
    raised, or silently mixed batch entries when the batch size equalled
    ``N``. 1-D inputs give the same result as before.
    """
    coeff = us.div(mtc.iprod(g, a, b), mtc.iprod(g, b, b))
    return coeff[..., None] * b
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def rejvect(g: Matrix, a: Vector, b: Vector) -> Vector:
    """Return the rejection of ``a`` from ``b``.

    This is the component of ``a`` that is g-orthogonal to ``b``:
    ``a - vectproj(g, a, b)``.
    """
    return a - vectproj(g, a, b)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def unitize(g: Matrix, u: Vector) -> Vector:
    """Return ``u`` scaled to unit length under metric ``g``.

    The norm gets a trailing axis so that batched inputs of shape ``(..., N)``
    are divided row by row. Previously a batch whose size equalled ``N``
    broadcast silently against the wrong axis. 1-D inputs behave as before.

    ``mtc.norm`` floors at 1e-8, so a zero ``u`` stays zero.
    """
    return us.div(u, mtc.norm(g, u)[..., None])
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: xagm
|
|
3
|
+
Version: 0.1.1
|
|
4
|
+
Description-Content-Type: text/markdown
|
|
5
|
+
Requires-Dist: jax
|
|
6
|
+
Requires-Dist: jaxlib
|
|
7
|
+
Requires-Dist: jaxtyping
|
|
8
|
+
|
|
9
|
+
XAGM is a Riemannian Differentiable Geometry engine which stands for Accelerated Autodiff Geometry Multi-dimensional. It deals exclusively in Riemannian SPD metrics, and it is MANDATORY the metrics are Symmetric Positive Definite (SPD) for it to work.
|
|
10
|
+
It offers a vast array of functions, with 4 modules to call upon, them being metrics, linear, vectors, and calc. Vectors deal mainly with linear algebra adjacent functions with respect to the metric tensor. Speaking of the metric tensor, XAGM allows you to use fwdmet to create a pullback metric.
|
|
11
|
+
|
|
12
|
+
The crown jewels of XAGM would be christoffel(), geoexp_solver(), geolog_solver(), and geodist(), with geoexp_solver consistently performing at sub millisecond speeds, and geolog_solver being in the comfortable range of 2-20ms each run depending on how many steps are given to the solver.
|
|
13
|
+
|
|
14
|
+
XAGM has been benchmarked (quite unofficially so you are free to do your own runtime checks) and observed to outperform basically every other geometry application in numpy and the dominating Geometry powerhouses. You are highly encouraged, however, to confirm that yourself too.
|
|
15
|
+
|
|
16
|
+
XAGM is a bit hard to use at first since it expects a decent background in maths for most of the functions and a clear understanding of how to use JAX native functions like vmap and jit along with static_argnums and static_argnames, but, overall, if you behave nicely and pass clean arrays into it, it will reward you. Documentation on this project will be coming soon! (or never at all. No in between.)
|
|
17
|
+
|
|
18
|
+
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
pyproject.toml
|
|
3
|
+
src/xagm/__init__.py
|
|
4
|
+
src/xagm/geoutils.py
|
|
5
|
+
src/xagm.egg-info/PKG-INFO
|
|
6
|
+
src/xagm.egg-info/SOURCES.txt
|
|
7
|
+
src/xagm.egg-info/dependency_links.txt
|
|
8
|
+
src/xagm.egg-info/requires.txt
|
|
9
|
+
src/xagm.egg-info/top_level.txt
|
|
10
|
+
src/xagm/basis/__init__.py
|
|
11
|
+
src/xagm/basis/linear.py
|
|
12
|
+
src/xagm/basis/metrics.py
|
|
13
|
+
src/xagm/manifolds/__init__.py
|
|
14
|
+
src/xagm/manifolds/calc.py
|
|
15
|
+
src/xagm/manifolds/vectors.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
xagm
|