xagm 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xagm-0.1.1/PKG-INFO ADDED
@@ -0,0 +1,18 @@
1
+ Metadata-Version: 2.4
2
+ Name: xagm
3
+ Version: 0.1.1
4
+ Description-Content-Type: text/markdown
5
+ Requires-Dist: jax
6
+ Requires-Dist: jaxlib
7
+ Requires-Dist: jaxtyping
8
+
9
+ XAGM (short for Accelerated Autodiff Geometry Multi-dimensional) is a differentiable Riemannian geometry engine. It deals exclusively in Riemannian metrics, and it is MANDATORY that every metric be Symmetric Positive Definite (SPD) for it to work.
10
+ It offers a wide range of functions organized into four modules: metrics, linear, vectors, and calc. The vectors module mainly provides linear-algebra operations taken with respect to the metric tensor. For building the metric tensor itself, XAGM provides fwdmet, which creates the pullback metric J^T J of a map (J being the map's Jacobian).
11
+
12
+ The crown jewels of XAGM are christoffel(), geoexp_solver(), geolog_solver(), and geodist(). In practice, geoexp_solver consistently runs in under a millisecond, and geolog_solver typically takes 2-20 ms per run, depending on how many steps the solver is given.
13
+
14
+ XAGM has been benchmarked (quite unofficially, so feel free to run your own timing checks) and was observed to outperform basically every other NumPy-based geometry library, as well as the dominant geometry powerhouses. You are highly encouraged to confirm this yourself.
15
+
16
+ XAGM can be hard to use at first: most functions assume a decent background in mathematics, plus a clear understanding of JAX-native tools such as vmap and jit (including static_argnums and static_argnames). Overall, though, if you behave nicely and pass clean arrays into it, it will reward you. Documentation for this project is coming soon! (Or never at all. No in between.)
17
+
18
+
xagm-0.1.1/README.md ADDED
@@ -0,0 +1,10 @@
1
+ XAGM (short for Accelerated Autodiff Geometry Multi-dimensional) is a differentiable Riemannian geometry engine. It deals exclusively in Riemannian metrics, and it is MANDATORY that every metric be Symmetric Positive Definite (SPD) for it to work.
2
+ It offers a wide range of functions organized into four modules: metrics, linear, vectors, and calc. The vectors module mainly provides linear-algebra operations taken with respect to the metric tensor. For building the metric tensor itself, XAGM provides fwdmet, which creates the pullback metric J^T J of a map (J being the map's Jacobian).
3
+
4
+ The crown jewels of XAGM are christoffel(), geoexp_solver(), geolog_solver(), and geodist(). In practice, geoexp_solver consistently runs in under a millisecond, and geolog_solver typically takes 2-20 ms per run, depending on how many steps the solver is given.
5
+
6
+ XAGM has been benchmarked (quite unofficially, so feel free to run your own timing checks) and was observed to outperform basically every other NumPy-based geometry library, as well as the dominant geometry powerhouses. You are highly encouraged to confirm this yourself.
7
+
8
+ XAGM can be hard to use at first: most functions assume a decent background in mathematics, plus a clear understanding of JAX-native tools such as vmap and jit (including static_argnums and static_argnames). Overall, though, if you behave nicely and pass clean arrays into it, it will reward you. Documentation for this project is coming soon! (Or never at all. No in between.)
9
+
10
+
@@ -0,0 +1,18 @@
1
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "xagm"
version = "0.1.1"
readme = "README.md"
# geoutils.py uses PEP 695 `type` alias statements, which require Python 3.12+.
requires-python = ">=3.12"
dependencies = [
    "jax",
    "jaxlib",
    "jaxtyping",
    # manifolds/calc.py imports diffrax for the geodesic ODE solver.
    "diffrax",
]

[tool.setuptools.packages.find]
# Look for the package sources inside the 'src' directory (src layout).
where = ["src"]
xagm-0.1.1/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
File without changes
@@ -0,0 +1,4 @@
1
+ from .linear import (grid, line, ang)
2
+
3
+ from .metrics import (euclid, iprod, norm, fwdmet, revmet, metinv,
4
+ metinterp)
@@ -0,0 +1,33 @@
1
+ import geoutils as us
2
+ import jax
3
+ import jax.numpy as jnp
4
+ from geoutils import Vector, Matrix, Scalar, Tensor, JAXArray
5
+ from basis import metrics as mtc
6
+
7
def grid(idx: JAXArray, dimens: tuple):
    """Convert flat index/indices into grid coordinates.

    ``idx`` is unravelled against ``dimens`` with ``jnp.unravel_index``. The
    per-axis coordinates are then stacked on a new last axis in *reversed*
    order, so the last entry of ``dimens`` supplies the first coordinate.
    """
    coords = tuple(reversed(jnp.unravel_index(idx, dimens)))
    return jnp.stack(coords, axis=-1)
13
+
14
# NOTE(review): this bare module-level assignment does not configure jax.jit
# for ``line``. Under jit, ``segs`` must be marked static by the caller,
# e.g. ``jax.jit(line, static_argnums=(2,))``.
static_argnums = (2,)
def line(p1: Vector, p2: Vector, segs: int) -> Matrix:
    """Sample ``segs`` evenly spaced points on the straight segment p1 -> p2.

    Row k of the result is ``p1 + t_k * (p2 - p1)``, where
    ``t = linspace(0, 1, segs)``. ``segs`` fixes an output shape, so it has to
    be a concrete Python int.
    """
    weights = jnp.linspace(0, 1, segs)[:, jnp.newaxis]
    return p1 + weights * (p2 - p1)
21
+
22
+
23
+
24
def ang(g: Matrix, u: Vector, v: Vector) -> Scalar:
    """Angle in radians between ``u`` and ``v`` measured with metric ``g``.

    cos(theta) = <u, v>_g / (|u|_g * |v|_g).

    The cosine is clipped to [-1 + 1e-8, 1 - 1e-8] before ``arccos``, which
    keeps the arccos gradient finite. As a consequence, parallel vectors give
    a small positive angle (about 1.4e-4 in float64) rather than exactly 0.
    """
    lengths = mtc.norm(g, u) * mtc.norm(g, v)
    cosine = us.div(mtc.iprod(g, u, v), lengths)
    return jnp.arccos(jnp.clip(cosine, -1.0 + 1e-8, 1.0 - 1e-8))
@@ -0,0 +1,67 @@
1
+ import geoutils as us
2
+ import jax
3
+ import jax.numpy as jnp
4
+ from geoutils import Vector, Matrix, Scalar, Tensor, JAXArray
5
+
6
+
7
def euclid(x: Vector) -> Matrix:
    """Flat (Euclidean) metric at ``x``.

    This is an identity matrix sized by the last axis of ``x``; the values in
    ``x`` are ignored.
    """
    dim = x.shape[-1]
    return jnp.eye(dim)
9
+
10
def iprod(g: Matrix, u: Vector | Matrix, v: Vector | Matrix) -> Vector:
    """Metric inner product <u, v>_g = u^i g_ij v^j.

    Leading ``...`` axes broadcast (einsum ellipsis), so a batch of vectors
    and/or metrics yields a batch of inner products.
    """
    contraction = '...i, ...ij, ...j -> ...'
    return jnp.einsum(contraction, u, g, v)
12
+
13
def norm(g: Matrix, u: Vector) -> Scalar:
    """Length of ``u`` under metric ``g``: sqrt(<u, u>_g).

    The squared length is floored at 1e-16 before the square root, so the
    result is never below 1e-8 and the sqrt gradient stays finite at u = 0.
    """
    squared = iprod(g, u, u)
    return jnp.sqrt(jnp.maximum(squared, 1e-16))
15
+
16
+
17
# NOTE(review): bare module-level assignment; it does not configure jax.jit
# for ``fwdmet`` (``f`` is a callable and would need to be static under jit).
static_argnums = (0,)
def fwdmet(f, v: Vector) -> Matrix:
    """Pullback metric J^T J of the map ``f`` at ``v``, with J from forward-mode AD.

    All output axes of ``f`` are flattened into one before contracting, so ``f``
    may return an array of any shape. The result is (n, n) with n = v.shape[-1].
    """
    jac = jax.jacfwd(f)(v).reshape(-1, v.shape[-1])
    return jnp.einsum('ai, aj -> ij', jac, jac)
22
+
23
# NOTE(review): bare module-level assignment; it does not configure jax.jit
# for ``revmet`` (``f`` is a callable and would need to be static under jit).
static_argnums = (0,)
def revmet(f, v: Vector) -> Matrix:
    """Pullback metric J^T J of the map ``f`` at ``v``, with J from reverse-mode AD.

    Same result as ``fwdmet``, but the Jacobian comes from ``jax.jacrev``. This
    is usually cheaper when ``f`` has fewer outputs than inputs.
    """
    jac = jax.jacrev(f)(v).reshape(-1, v.shape[-1])
    return jnp.einsum('ai, aj -> ij', jac, jac)
28
+
29
def metinv(g: Matrix) -> Matrix:
    """Inverse metric g^{-1}, computed through a symmetric eigendecomposition.

    Eigenvalues are floored at 1e-12 before being inverted. A numerically
    singular (or slightly indefinite) input therefore gives large but finite
    entries instead of inf/NaN.
    """
    eigvals, eigvecs = jnp.linalg.eigh(g)
    recip = us.div(1.0, jnp.maximum(eigvals, 1e-12))
    # Reassemble V diag(1/lambda) V^T.
    return jnp.einsum('ik, k, jk -> ij', eigvecs, recip, eigvecs)
35
+
36
def metinterp(g0: Matrix, v0: Vector,
              g1: Matrix, v1: Vector,
              target: Vector) -> Matrix:
    """Log-Euclidean interpolation between the SPD metrics ``g0`` (at ``v0``) and ``g1`` (at ``v1``).

    ``target`` is projected onto the segment v0 -> v1, giving a parameter t
    clipped to [0, 1]. The metrics are then blended in log space:

        g(t) = expm((1 - t) * logm(g0) + t * logm(g1))

    Eigenvalues are floored at 1e-7 before taking logs. When ``v0 == v1`` the
    projection denominator is 0, so the safe division yields t = 0.
    """
    def spectral_log(mat):
        # V diag(log lambda) V^T, with lambda floored at 1e-7.
        lam, vec = jnp.linalg.eigh(mat)
        log_lam = jnp.log(jnp.maximum(lam, 1e-7))
        return jnp.einsum('ik, k, jk -> ij', vec, log_lam, vec)

    log_g0 = spectral_log(g0)
    log_g1 = spectral_log(g1)

    direction = v1 - v0
    offset = target - v0
    t = us.div(jnp.dot(offset, direction), jnp.dot(direction, direction))
    t = jnp.clip(t, 0.0, 1.0)

    blended = (1.0 - t) * log_g0 + (t * log_g1)

    # Matrix exponential of the symmetric blend via its eigendecomposition.
    lam, vec = jnp.linalg.eigh(blended)
    return jnp.einsum('ik, k, jk -> ij', vec, jnp.exp(lam), vec)
63
+
64
+
65
+
66
+
67
+
@@ -0,0 +1,27 @@
1
import os

# Disable XLA's up-front GPU memory preallocation. The variable is set before
# JAX is imported so it is in place however early the backend gets initialised.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import jax  # noqa: E402
import jax.numpy as jnp  # noqa: E402
from jax import config  # noqa: E402

# Default to 64-bit floats, matching the Float64 annotations used by this
# module's type aliases.
config.update("jax_enable_x64", True)
10
+ from jaxtyping import Float64, Array
11
+
12
@jax.jit
def div(a, b):
    """Elementwise ``a / b`` that yields 0 wherever ``b`` is exactly 0.

    Zero denominators are replaced by 1.0 *before* dividing (the "double
    where" pattern). Neither the forward value nor its gradient ever sees a
    division by zero.
    """
    nonzero = b != 0
    quotient = a / jnp.where(nonzero, b, 1.0)
    return jnp.where(nonzero, quotient, 0.0)


# Module-level tolerance constant.
eps = 1e-15
21
+
22
+
23
+ type Scalar = Float64[Array, ""] # 0D [1]
24
+ type Vector = Float64[Array, "N"] # 1D [1, 2]
25
+ type Matrix = Float64[Array, "M N"] # 2D [[1, 2], [2, 1]]
26
+ type Tensor = Float64[Array, "*batch M N O"] # 3-D+
27
+ type JAXArray = Float64[Array]
@@ -0,0 +1,7 @@
1
+ from .calc import (christoffel, geoexp_solver, geolog_solver, geodist)
2
+
3
+ from .vectors import (nrml,
4
+ scalproj,
5
+ vectproj,
6
+ rejvect,
7
+ unitize)
@@ -0,0 +1,103 @@
1
+ import geoutils as us
2
+ import jax
3
+ import jax.numpy as jnp
4
+ from geoutils import Vector, Matrix, Scalar, Tensor, JAXArray
5
+
6
+ from basis import metrics as mtc
7
+
8
def christoffel(func, x: Vector) -> Tensor:
    """Christoffel symbols of the second kind for the pullback metric of ``func`` at ``x``.

    The metric is g = J^T J (``mtc.fwdmet``). The result is an (n, n, n) array:

        gamma[k, i, j] = 1/2 * g^{kl} * (d_i g_{jl} + d_j g_{li} - d_l g_{ij})

    Args:
        func: map whose Jacobian defines the metric; must be traceable by JAX.
        x: point (1-D, length n) at which the symbols are evaluated.

    Returns:
        Array ``gamma`` indexed as ``gamma[k, i, j]``.
    """
    def metric_and_copy(point):
        # Return the metric both as the differentiated output and as aux. The
        # metric at ``x`` then comes from the same forward pass as its
        # derivative, instead of a second fwdmet evaluation.
        g_here = mtc.fwdmet(func, point)
        return g_here, g_here

    # dg[a, b, m] = d g_{ab} / d x^m  (jacfwd places the input axis last).
    dg, g = jax.jacfwd(metric_and_copy, has_aux=True)(x)
    ginv = mtc.metinv(g)

    # combo[l, i, j] = d_i g_{jl} + d_j g_{li} - d_l g_{ij}
    combo = jnp.transpose(dg, (1, 2, 0)) + dg - jnp.transpose(dg, (2, 0, 1))
    return jnp.einsum('kl, lij -> kij', 0.5 * ginv, combo)
27
+
28
+ import diffrax
29
+
30
+
31
def geoexp_term(t, state, args) -> Vector:
    """Vector field of the geodesic + parallel-transport ODE, in diffrax's (t, y, args) form.

    ``state`` packs three equal-length blocks ``[x, v, y]``: position,
    velocity, and the vector being transported. With
    gamma = christoffel(args['func'], x):

        dx/dt = v
        dv/dt = -gamma^k_ij v^i v^j
        dy/dt = -gamma^k_ij v^i y^j

    ``t`` is unused (the system is autonomous) but is required by diffrax.
    """
    n = state.shape[0] // 3
    x, v, y = state[:n], state[n:2 * n], state[2 * n:]

    gamma = christoffel(args['func'], x)
    accel = -jnp.einsum('kij, i, j -> k', gamma, v, v)
    transport = -jnp.einsum('kij, i, j -> k', gamma, v, y)

    return jnp.concatenate([v, accel, transport])
47
+
48
+
49
def geoexp_solver(p: Vector, v: Vector, mapped_func, vt: Vector,
                  steps=4096) -> tuple[Vector, Vector, Vector]:
    """Shoot a geodesic from ``p`` with initial velocity ``v``, parallel-transporting ``vt`` along it.

    Integrates ``geoexp_term`` from t = 0 to t = 1 with diffrax: Tsit5, PID
    step-size control (rtol=1e-8, atol=1e-10) and an initial step of 1e-2.

    Args:
        p: start point (1-D, length n).
        v: initial velocity at ``p``.
        mapped_func: map defining the pullback metric; forwarded to ``christoffel``.
        vt: vector at ``p`` to parallel-transport along the geodesic.
        steps: maximum number of solver steps (diffrax ``max_steps``).

    Returns:
        ``(final_pos, final_vel, transported_v)`` at t = 1. ``final_pos`` is the
        exponential map exp_p(v).

    Note:
        The solve uses ``throw=False`` (presumably so that one failing solve
        does not raise inside vmap/jit), and ``solution.result`` is not
        inspected. A failed solve, e.g. one that exhausts ``steps``, therefore
        raises nothing, and the returned values should not be trusted.
    """
    state0 = jnp.concatenate([p, v, vt])

    solution = diffrax.diffeqsolve(
        terms=diffrax.ODETerm(geoexp_term),
        solver=diffrax.Tsit5(),
        t0=0,
        t1=1,
        dt0=1e-2,
        y0=state0,
        args={'func': mapped_func},
        stepsize_controller=diffrax.PIDController(rtol=1e-8, atol=1e-10),
        saveat=diffrax.SaveAt(t1=True),
        adjoint=diffrax.DirectAdjoint(),
        max_steps=steps,
        throw=False,
    )

    # SaveAt(t1=True) stores a single row: the state at t = 1.
    final_state = solution.ys[0]

    n = p.shape[0]
    final_pos = final_state[:n]
    final_vel = final_state[n:2 * n]
    transported_v = final_state[2 * n:]

    return final_pos, final_vel, transported_v
76
+
77
def geolog_solver(p: Vector, q: Vector, mapped_func, steps: int) -> Vector:
    """Riemannian logarithm log_p(q), found by shooting with a chord (frozen-Jacobian) Newton method.

    Looks for the initial velocity v whose geodesic from ``p`` reaches ``q`` at
    t = 1. Starting from v0 = q - p, each of ``steps`` iterations applies

        v <- v - J^{-1} (exp_p(v) - q),

    where J is the Jacobian of v -> exp_p(v), evaluated once at v0 and reused.
    Each iteration therefore costs a single geodesic solve, at the price of
    (typically) linear rather than quadratic convergence.

    Args:
        p: start point (1-D, length n).
        q: target point (1-D, length n).
        mapped_func: map defining the pullback metric; forwarded to ``geoexp_solver``.
        steps: number of Newton iterations. The count is fixed; there is no
            convergence test.

    Returns:
        The estimated initial velocity v = log_p(q).
    """
    from jax.scipy.linalg import lu_factor, lu_solve

    def shoot(v_guess):
        # Endpoint of the geodesic from p with velocity v_guess (nothing transported).
        pos, _, _ = geoexp_solver(p, v_guess, mapped_func, jnp.zeros_like(p))
        return pos

    v_init = q - p

    # J never changes across iterations, so factorise it once rather than
    # letting jnp.linalg.solve refactorise it on every Newton step.
    lu_and_piv = lu_factor(jax.jacobian(shoot)(v_init))

    def newton_step(_, v_cur):
        residual = shoot(v_cur) - q
        return v_cur - lu_solve(lu_and_piv, residual)

    return jax.lax.fori_loop(0, steps, newton_step, v_init)
94
+
95
+
96
def geodist(p: Vector, q: Vector, mapped_func, steps: int) -> Scalar:
    """Geodesic distance d(p, q) = |log_p(q)|, measured with the metric at ``p``.

    ``steps`` is passed to ``geolog_solver`` as its Newton iteration count. The
    metric at ``p`` is the pullback ``mtc.fwdmet(mapped_func, p)``.
    """
    velocity = geolog_solver(p, q, mapped_func, steps)
    return mtc.norm(mtc.fwdmet(mapped_func, p), velocity)
101
+
102
+
103
+
@@ -0,0 +1,56 @@
1
+
2
+ import geoutils as us
3
+ import jax
4
+ import jax.numpy as jnp
5
+ from geoutils import Vector, Matrix, Scalar, Tensor, JAXArray
6
+
7
+ from basis import metrics as mtc
8
+
9
def nrml(g: Matrix, basis: Matrix) -> Matrix:
    """Orthonormalise the rows of ``basis`` with respect to the metric ``g``.

    ``g`` is factored as L^T L with L = diag(sqrt(lambda)) V^T, taken from its
    eigendecomposition (eigenvalues clipped at 0). The basis vectors are mapped
    through L, orthonormalised by QR, and mapped back with L^{-1}. For SPD ``g``,
    the result has g-orthonormal rows. Finally the first row is negated when
    det(ortho @ L^T) <= 0.

    NOTE(review): ``jnp.linalg.det`` needs a square matrix, so ``basis`` must
    be (n, n). Algebraically ortho @ L^T equals Q^T from the QR step, so the
    sign test is on det(Q). It does not by itself fix the handedness of the
    output in the original coordinates (the eigenvector matrix from ``eigh``
    has arbitrary sign). Confirm the intended orientation convention.
    """
    eigvals, eigvecs = jnp.linalg.eigh(g)
    roots = jnp.sqrt(jnp.maximum(eigvals, 0.0))

    # L = diag(sqrt(lambda)) V^T, so that L^T L = g.
    factor = roots[:, None] * eigvecs.T
    # Row i becomes (L b_i)^T: the basis expressed in whitened coordinates.
    whitened = basis @ factor.T

    q, _ = jnp.linalg.qr(whitened.T)

    # L^{-1} = V diag(1/sqrt(lambda)); safe division zeroes columns for zero eigenvalues.
    factor_inv = us.div(eigvecs, roots)
    ortho = q.T @ factor_inv.T

    sign = jnp.where(jnp.linalg.det(ortho @ factor.T) > 0, 1.0, -1.0)
    return ortho.at[0, :].multiply(sign)
27
+
28
+ #dot product territory
29
+
30
def scalproj(g: Matrix, a: Vector, b: Vector) -> Scalar:
    """Scalar projection of ``a`` onto ``b`` under metric ``g``: <a, b>_g / |b|_g.

    The division goes through ``us.div``, which yields 0 where the divisor is 0.
    """
    return us.div(mtc.iprod(g, a, b), mtc.norm(g, b))
36
+
37
+
38
def vectproj(g: Matrix, a: Vector, b: Vector) -> Vector:
    """Vector projection of ``a`` onto ``b`` under metric ``g``: (<a, b>_g / <b, b>_g) * b.

    Returns the zero vector when <b, b>_g is exactly 0 (``us.div`` yields 0).
    Leading batch axes are supported, as in ``mtc.iprod``: each vector in the
    batch is scaled by its own coefficient.
    """
    coeff = us.div(mtc.iprod(g, a, b), mtc.iprod(g, b, b))
    # ``coeff`` has the inputs' batch shape (0-D for single vectors). The
    # trailing axis makes it scale whole vectors. Without it, a batch of
    # vectors would broadcast the batch axis against the component axis,
    # which errors or silently mis-scales when batch size == n.
    return coeff[..., None] * b
45
+
46
+
47
def rejvect(g: Matrix, a: Vector, b: Vector) -> Vector:
    """Vector rejection of ``a`` from ``b``: ``a`` minus its g-projection onto ``b``.

    When <b, b>_g != 0, the result is g-orthogonal to ``b``.
    """
    return a - vectproj(g, a, b)
53
+
54
+
55
def unitize(g: Matrix, u: Vector) -> Vector:
    """Rescale ``u`` to unit length under metric ``g``.

    Leading batch axes are supported, as in ``mtc.norm``: each vector is
    divided by its own length.
    """
    length = mtc.norm(g, u)
    # A trailing axis makes the per-vector length divide that vector's
    # components. Without it, batched input would broadcast against the
    # component axis.
    return us.div(u, length[..., None])
@@ -0,0 +1,18 @@
1
+ Metadata-Version: 2.4
2
+ Name: xagm
3
+ Version: 0.1.1
4
+ Description-Content-Type: text/markdown
5
+ Requires-Dist: jax
6
+ Requires-Dist: jaxlib
7
+ Requires-Dist: jaxtyping
8
+
9
+ XAGM (short for Accelerated Autodiff Geometry Multi-dimensional) is a differentiable Riemannian geometry engine. It deals exclusively in Riemannian metrics, and it is MANDATORY that every metric be Symmetric Positive Definite (SPD) for it to work.
10
+ It offers a wide range of functions organized into four modules: metrics, linear, vectors, and calc. The vectors module mainly provides linear-algebra operations taken with respect to the metric tensor. For building the metric tensor itself, XAGM provides fwdmet, which creates the pullback metric J^T J of a map (J being the map's Jacobian).
11
+
12
+ The crown jewels of XAGM are christoffel(), geoexp_solver(), geolog_solver(), and geodist(). In practice, geoexp_solver consistently runs in under a millisecond, and geolog_solver typically takes 2-20 ms per run, depending on how many steps the solver is given.
13
+
14
+ XAGM has been benchmarked (quite unofficially, so feel free to run your own timing checks) and was observed to outperform basically every other NumPy-based geometry library, as well as the dominant geometry powerhouses. You are highly encouraged to confirm this yourself.
15
+
16
+ XAGM can be hard to use at first: most functions assume a decent background in mathematics, plus a clear understanding of JAX-native tools such as vmap and jit (including static_argnums and static_argnames). Overall, though, if you behave nicely and pass clean arrays into it, it will reward you. Documentation for this project is coming soon! (Or never at all. No in between.)
17
+
18
+
@@ -0,0 +1,15 @@
1
+ README.md
2
+ pyproject.toml
3
+ src/xagm/__init__.py
4
+ src/xagm/geoutils.py
5
+ src/xagm.egg-info/PKG-INFO
6
+ src/xagm.egg-info/SOURCES.txt
7
+ src/xagm.egg-info/dependency_links.txt
8
+ src/xagm.egg-info/requires.txt
9
+ src/xagm.egg-info/top_level.txt
10
+ src/xagm/basis/__init__.py
11
+ src/xagm/basis/linear.py
12
+ src/xagm/basis/metrics.py
13
+ src/xagm/manifolds/__init__.py
14
+ src/xagm/manifolds/calc.py
15
+ src/xagm/manifolds/vectors.py
@@ -0,0 +1,3 @@
1
+ jax
2
+ jaxlib
3
+ jaxtyping
@@ -0,0 +1 @@
1
+ xagm