blue-sampler 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blue_sampler-0.1.0/PKG-INFO +71 -0
- blue_sampler-0.1.0/README.MD +55 -0
- blue_sampler-0.1.0/README.md +55 -0
- blue_sampler-0.1.0/pyproject.toml +37 -0
- blue_sampler-0.1.0/src/blue_sampler/__init__.py +24 -0
- blue_sampler-0.1.0/src/blue_sampler/kernels.py +100 -0
- blue_sampler-0.1.0/src/blue_sampler/math_utils.py +254 -0
- blue_sampler-0.1.0/src/blue_sampler/progress.py +84 -0
- blue_sampler-0.1.0/src/blue_sampler/run.py +271 -0
- blue_sampler-0.1.0/src/blue_sampler/viz.py +118 -0
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: blue_sampler
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Stealthy point-pattern sampling on the unit torus
|
|
5
|
+
Project-URL: Repository, https://github.com/For-a-few-DPPs-more/hyperuniform-samplers
|
|
6
|
+
License: MIT
|
|
7
|
+
Keywords: blue noise,jax,point process,sampling,stealthy
|
|
8
|
+
Requires-Python: >=3.10
|
|
9
|
+
Requires-Dist: jax
|
|
10
|
+
Requires-Dist: jaxlib
|
|
11
|
+
Requires-Dist: matplotlib
|
|
12
|
+
Requires-Dist: numpy
|
|
13
|
+
Requires-Dist: requests
|
|
14
|
+
Requires-Dist: squarenet
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
|
|
17
|
+
# blue-sampler
|
|
18
|
+
|
|
19
|
+
Generate **stealthy point patterns** — low-discrepancy, spectrally isotropic
|
|
20
|
+
samples on the unit torus [0, 1)^D.
|
|
21
|
+
|
|
22
|
+
Stealthy patterns suppress long-range density fluctuations while remaining
|
|
23
|
+
aperiodic. They are useful in rendering, quadrature, and computational
|
|
24
|
+
physics wherever quasi-random, isotropic spatial coverage is needed.
|
|
25
|
+
|
|
26
|
+
## Installation
|
|
27
|
+
|
|
28
|
+
```bash
|
|
29
|
+
pip install blue_sampler
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
## Quick start
|
|
33
|
+
|
|
34
|
+
```python
|
|
35
|
+
import blue_sampler as blue
|
|
36
|
+
|
|
37
|
+
# 10 000 points in 2-D
|
|
38
|
+
x = blue.sample(10_000, D=2)
|
|
39
|
+
blue.plot(x)
|
|
40
|
+
blue.plot_structure_factor(x)
|
|
41
|
+
|
|
42
|
+
# 3-D
|
|
43
|
+
x = blue.sample(5_000, D=3)
|
|
44
|
+
```
|
|
45
|
+
## Supported dimensions
|
|
46
|
+
|
|
47
|
+
| D | Notes |
|
|
48
|
+
|---|-------|
|
|
49
|
+
| 2 | Fast, recommended for exploration |
|
|
50
|
+
| 3 | ~3× slower than 2-D |
|
|
51
|
+
| 4 | Requires more iterations (set automatically by `Config.auto`) |
|
|
52
|
+
| 5 | Experimental |
|
|
53
|
+
|
|
54
|
+
## Algorithm overview
|
|
55
|
+
|
|
56
|
+
The pipeline alternates between:
|
|
57
|
+
|
|
58
|
+
1. **Spatial gradient** — short-range Gaussian repulsion via
|
|
59
|
+
neighbour convolution on the torus.
|
|
60
|
+
2. **Spectral gradient** — minimises the structure factor S(k) for k below
|
|
61
|
+
a chosen cut-off, using a set of all the wave vectors within an integer
|
|
62
|
+
half-ball.
|
|
63
|
+
3. **Grid assignment** (SquareNet) — periodic re-assignment to a regular
|
|
64
|
+
grid for efficient sparse local operations.
|
|
65
|
+
|
|
66
|
+
For N ≤ 3 000 a direct O(N²) bootstrap is used. For larger N a
|
|
67
|
+
hierarchical strategy clones and refines a coarser solution.
|
|
68
|
+
|
|
69
|
+
## License
|
|
70
|
+
|
|
71
|
+
MIT
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# blue-sampler
|
|
2
|
+
|
|
3
|
+
Generate **stealthy point patterns** — low-discrepancy, spectrally isotropic
|
|
4
|
+
samples on the unit torus [0, 1)^D.
|
|
5
|
+
|
|
6
|
+
Stealthy patterns suppress long-range density fluctuations while remaining
|
|
7
|
+
aperiodic. They are useful in rendering, quadrature, and computational
|
|
8
|
+
physics wherever quasi-random, isotropic spatial coverage is needed.
|
|
9
|
+
|
|
10
|
+
## Installation
|
|
11
|
+
|
|
12
|
+
```bash
|
|
13
|
+
pip install blue_sampler
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
## Quick start
|
|
17
|
+
|
|
18
|
+
```python
|
|
19
|
+
import blue_sampler as blue
|
|
20
|
+
|
|
21
|
+
# 10 000 points in 2-D
|
|
22
|
+
x = blue.sample(10_000, D=2)
|
|
23
|
+
blue.plot(x)
|
|
24
|
+
blue.plot_structure_factor(x)
|
|
25
|
+
|
|
26
|
+
# 3-D
|
|
27
|
+
x = blue.sample(5_000, D=3)
|
|
28
|
+
```
|
|
29
|
+
## Supported dimensions
|
|
30
|
+
|
|
31
|
+
| D | Notes |
|
|
32
|
+
|---|-------|
|
|
33
|
+
| 2 | Fast, recommended for exploration |
|
|
34
|
+
| 3 | ~3× slower than 2-D |
|
|
35
|
+
| 4 | Requires more iterations (set automatically by `Config.auto`) |
|
|
36
|
+
| 5 | Experimental |
|
|
37
|
+
|
|
38
|
+
## Algorithm overview
|
|
39
|
+
|
|
40
|
+
The pipeline alternates between:
|
|
41
|
+
|
|
42
|
+
1. **Spatial gradient** — short-range Gaussian repulsion via
|
|
43
|
+
neighbour convolution on the torus.
|
|
44
|
+
2. **Spectral gradient** — minimises the structure factor S(k) for k below
|
|
45
|
+
a chosen cut-off, using a set of all the wave vectors within an integer
|
|
46
|
+
half-ball.
|
|
47
|
+
3. **Grid assignment** (SquareNet) — periodic re-assignment to a regular
|
|
48
|
+
grid for efficient sparse local operations.
|
|
49
|
+
|
|
50
|
+
For N ≤ 3 000 a direct O(N²) bootstrap is used. For larger N a
|
|
51
|
+
hierarchical strategy clones and refines a coarser solution.
|
|
52
|
+
|
|
53
|
+
## License
|
|
54
|
+
|
|
55
|
+
MIT
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# blue-sampler
|
|
2
|
+
|
|
3
|
+
Generate **stealthy point patterns** — low-discrepancy, spectrally isotropic
|
|
4
|
+
samples on the unit torus [0, 1)^D.
|
|
5
|
+
|
|
6
|
+
Stealthy patterns suppress long-range density fluctuations while remaining
|
|
7
|
+
aperiodic. They are useful in rendering, quadrature, and computational
|
|
8
|
+
physics wherever quasi-random, isotropic spatial coverage is needed.
|
|
9
|
+
|
|
10
|
+
## Installation
|
|
11
|
+
|
|
12
|
+
```bash
|
|
13
|
+
pip install blue_sampler
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
## Quick start
|
|
17
|
+
|
|
18
|
+
```python
|
|
19
|
+
import blue_sampler as blue
|
|
20
|
+
|
|
21
|
+
# 10 000 points in 2-D
|
|
22
|
+
x = blue.sample(10_000, D=2)
|
|
23
|
+
blue.plot(x)
|
|
24
|
+
blue.plot_structure_factor(x)
|
|
25
|
+
|
|
26
|
+
# 3-D
|
|
27
|
+
x = blue.sample(5_000, D=3)
|
|
28
|
+
```
|
|
29
|
+
## Supported dimensions
|
|
30
|
+
|
|
31
|
+
| D | Notes |
|
|
32
|
+
|---|-------|
|
|
33
|
+
| 2 | Fast, recommended for exploration |
|
|
34
|
+
| 3 | ~3× slower than 2-D |
|
|
35
|
+
| 4 | Requires more iterations (set automatically by `Config.auto`) |
|
|
36
|
+
| 5 | Experimental |
|
|
37
|
+
|
|
38
|
+
## Algorithm overview
|
|
39
|
+
|
|
40
|
+
The pipeline alternates between:
|
|
41
|
+
|
|
42
|
+
1. **Spatial gradient** — short-range Gaussian repulsion via
|
|
43
|
+
neighbour convolution on the torus.
|
|
44
|
+
2. **Spectral gradient** — minimises the structure factor S(k) for k below
|
|
45
|
+
a chosen cut-off, using a set of all the wave vectors within an integer
|
|
46
|
+
half-ball.
|
|
47
|
+
3. **Grid assignment** (SquareNet) — periodic re-assignment to a regular
|
|
48
|
+
grid for efficient sparse local operations.
|
|
49
|
+
|
|
50
|
+
For N ≤ 3 000 a direct O(N²) bootstrap is used. For larger N a
|
|
51
|
+
hierarchical strategy clones and refines a coarser solution.
|
|
52
|
+
|
|
53
|
+
## License
|
|
54
|
+
|
|
55
|
+
MIT
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling >= 1.26"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "blue_sampler"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Stealthy point-pattern sampling on the unit torus"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
license = { text = "MIT" }
|
|
11
|
+
requires-python = ">=3.10"
|
|
12
|
+
|
|
13
|
+
keywords = [
|
|
14
|
+
"blue noise",
|
|
15
|
+
"stealthy",
|
|
16
|
+
"point process",
|
|
17
|
+
"sampling",
|
|
18
|
+
"jax",
|
|
19
|
+
]
|
|
20
|
+
|
|
21
|
+
dependencies = [
|
|
22
|
+
"numpy",
|
|
23
|
+
"jax",
|
|
24
|
+
"jaxlib",
|
|
25
|
+
"matplotlib",
|
|
26
|
+
"requests",
|
|
27
|
+
"squarenet",
|
|
28
|
+
]
|
|
29
|
+
|
|
30
|
+
[project.urls]
|
|
31
|
+
Repository = "https://github.com/For-a-few-DPPs-more/hyperuniform-samplers"
|
|
32
|
+
|
|
33
|
+
[tool.hatch.build.targets.wheel]
|
|
34
|
+
packages = ["src/blue_sampler"]
|
|
35
|
+
|
|
36
|
+
[tool.ruff]
|
|
37
|
+
line-length = 100
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""
|
|
2
|
+
blue_sampler
|
|
3
|
+
============
|
|
4
|
+
|
|
5
|
+
Generate stealthy point patterns — low-discrepancy, spectrally isotropic
|
|
6
|
+
samples on the unit torus [0, 1)^D.
|
|
7
|
+
|
|
8
|
+
Quick start
|
|
9
|
+
-----------
|
|
10
|
+
>>> import blue_sampler as blue
|
|
11
|
+
>>> x = blue.sample(N=10_000, D=2) # (10000, 2) array
|
|
12
|
+
>>> blue.plot(x)
|
|
13
|
+
>>> blue.plot_structure_factor(x)
|
|
14
|
+
"""
|
|
15
|
+
from .run import sample
|
|
16
|
+
from .viz import plot, plot_structure_factor
|
|
17
|
+
from .math_utils import structure_factor
|
|
18
|
+
|
|
19
|
+
__all__ = [
|
|
20
|
+
"sample",
|
|
21
|
+
"plot",
|
|
22
|
+
"plot_structure_factor",
|
|
23
|
+
"structure_factor",
|
|
24
|
+
]
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Kernels for energy gradient
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import jax
|
|
9
|
+
import jax.numpy as jnp
|
|
10
|
+
from .math_utils import torus_delta, clean_grad
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
|
15
|
+
# Kernel functions (JAX)
|
|
16
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
|
17
|
+
|
|
18
|
+
def gauss_kernel(
    x: jnp.ndarray,
    y: jnp.ndarray,
    sigma2: float,
) -> jnp.ndarray:
    """
    Gaussian-weighted displacement between points on the unit torus.

    The displacement ``y - x`` is folded with ``torus_delta`` and multiplied
    by ``exp(-‖displacement‖² / sigma2)``, where the squared norm is summed
    over the last axis.  The product is passed through ``clean_grad`` before
    being returned.

    Parameters
    ----------
    x, y : arrays whose difference ``y - x`` is well defined (broadcasting).
    sigma2 : divisor of the squared distance inside the exponential.

    Returns
    -------
    Array with the shape of ``y - x``.
    """
    offset = torus_delta(y - x)
    # keepdims=True so the scalar weight broadcasts over the coordinate axis.
    weight = jnp.exp(-jnp.sum(offset ** 2, axis=-1, keepdims=True) / sigma2)
    return clean_grad(offset * weight)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def gauss_sin_kernel(
    x: jnp.ndarray,
    y: jnp.ndarray,
    a: float,
    b: float,
    c: float,
) -> jnp.ndarray:
    """
    Trigonometric variant of the Gaussian kernel.

    Intended for situations where sigma2 is not << 1 (e.g. high dimension or
    a low number of points), where the discontinuous ``torus_delta`` would
    become problematic.  No wrapping is applied here: the difference
    ``y - x`` only enters through ``sin`` and ``cos``.

    Parameters
    ----------
    x, y : arrays whose difference ``y - x`` is well defined (broadcasting).
    a : multiplier applied to ``y - x`` before the trigonometric functions.
    b : multiplier of ``1 - cos(...)``; the sum over the last axis is the
        exponent of the weight.
    c : multiplier of ``sin(...)``, the direction term.

    Returns
    -------
    ``clean_grad(c * sin(a (y - x)) * exp(-sum(b * (1 - cos(a (y - x))))))``.
    """
    scaled = a * (y - x)
    exponent = jnp.sum(b * (1.0 - jnp.cos(scaled)), axis=-1, keepdims=True)
    direction = c * jnp.sin(scaled)
    return clean_grad(direction * jnp.exp(-exponent))
|
|
54
|
+
|
|
55
|
+
def spectral_kernel(x, k, k_):
    """
    Contribution of one wave vector to the spectral gradient.

    Directly targets the spectral energy, so it is only usable for small
    subsets of preselected wave vectors.

    Parameters
    ----------
    x : array of points; the phase is ``sum(k * x)`` over the last axis.
    k : multiplier for the phase (``exp`` is applied to ``sum(k * x)``).
    k_ : factor multiplied into the result before the real part is taken.

    Returns
    -------
    Real part of ``(sum over axis 0 of e) * k_ * conj(e)`` where
    ``e = clean_grad(exp(sum(k * x)))``.
    """
    # NaN entries of the exponential are zeroed by clean_grad, so they do not
    # contribute to the sum over axis 0.
    waves = clean_grad(jnp.exp(jnp.sum(k * x, axis=-1, keepdims=True)))
    total = jnp.sum(waves, axis=0, keepdims=True)
    return jnp.real(total * k_ * jnp.conjugate(waves))
|
|
64
|
+
|
|
65
|
+
def reduce_kernel(kernel, x, params, init=None):
    """
    Sum ``kernel(x, param)`` over every entry of *params*.

    Parameters
    ----------
    kernel : callable
        Called as ``kernel(x, param)``; must return something that can be
        added to the accumulator.
    x : array
        State handed unchanged to every kernel call.
    params : PyTree
        Scanned along its leading axis by ``jax.lax.scan``; one kernel call
        per slice.
    init : array, optional
        Starting value of the accumulator.  ``zeros_like(x)`` when omitted.

    Returns
    -------
    array
        ``init`` plus the sum of all kernel contributions.
    """
    start = jnp.zeros_like(x) if init is None else init

    def accumulate(total, param):
        # scan expects (new_carry, per_step_output); no per-step output is kept.
        return total + kernel(x, param), None

    return jax.lax.scan(accumulate, start, params)[0]
|
|
93
|
+
|
|
94
|
+
def micro_shift_kernel(x, shift, kernel, Axes):
    """
    Antisymmetrised pairwise contribution for a single grid shift.

    ``kernel`` is evaluated between ``x`` and a copy of ``x`` rolled by
    ``shift`` along ``Axes``.  The same contribution rolled back by
    ``-shift`` is then subtracted.

    Parameters
    ----------
    x : array
    shift : roll amount handed to ``jnp.roll``.
    kernel : callable ``kernel(x, y) -> contribution``.
    Axes : axis or axes along which ``x`` is rolled.
    """
    contrib = kernel(x, jnp.roll(x, shift, axis=Axes))
    return contrib - jnp.roll(contrib, -shift, axis=Axes)


def micro_grad(x_val, SHIFTS, LR_spatial, S, kernel=None, Axes=None):
    """
    Sum ``micro_shift_kernel`` over all *SHIFTS*, scaled by ``LR_spatial / S``.

    Parameters
    ----------
    x_val : array
        State to differentiate.
    SHIFTS : array
        Scanned along its leading axis; one shift per slice.
    LR_spatial, S : float
        The summed contribution is multiplied by ``LR_spatial / S``.
    kernel : callable, required
        Pairwise kernel ``kernel(x, y)`` forwarded to ``micro_shift_kernel``.
        It is a keyword with a ``None`` default only to keep the original
        four positional parameters unchanged.
    Axes : int or tuple of int, optional
        Axes rolled by each shift.  Defaults to every axis of ``x_val``
        except the last one.

    Raises
    ------
    TypeError
        If ``kernel`` is not supplied.
    """
    # reduce_kernel calls its kernel as kernel(x, param); micro_shift_kernel
    # needs two more arguments, so they are bound here.  Passing
    # micro_shift_kernel directly (as before) always failed with a TypeError.
    if kernel is None:
        raise TypeError("micro_grad() requires a pairwise `kernel` callable")
    if Axes is None:
        Axes = tuple(range(x_val.ndim - 1))

    def shift_kernel(x, shift):
        return micro_shift_kernel(x, shift, kernel, Axes)

    out = reduce_kernel(shift_kernel, x_val, SHIFTS)
    return (LR_spatial / S) * out
|
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Low-level mathematical helpers
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import jax
|
|
9
|
+
import jax.numpy as jnp
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
|
13
|
+
# Lattice helpers
|
|
14
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
|
15
|
+
|
|
16
|
+
def drop_symmetric(directions: np.ndarray) -> np.ndarray:
    """
    Keep only one representative from each direction pair {v, -v}.

    The canonical representative is the one whose *first non-zero component*
    is positive.  All-zero rows have no such component and are dropped.

    Parameters
    ----------
    directions : (M, D) int array

    Returns
    -------
    (K, D) array containing the kept rows, in their original order.
    """
    # argmax over a boolean row gives the index of its first True entry
    # (0 for an all-zero row, whose value there is 0 and fails the > 0 test).
    first_nz_idx = (directions != 0).argmax(axis=1)
    first_nz_val = directions[np.arange(len(directions)), first_nz_idx]
    return directions[first_nz_val > 0]


def integers_in_half_ball(radius: float, D: int) -> np.ndarray:
    """
    Return all non-zero integer lattice vectors inside a sphere of *radius*,
    keeping only one vector per direction pair.

    Parameters
    ----------
    radius : float
        Radii up to 0.9 give an empty result.  Radii in (0.9, 1.9] are
        treated as radius 1 and give the D unit vectors.
    D : int

    Returns
    -------
    (M, D) int32 array
    """
    if radius <= 0.9:
        return np.zeros((0, D), dtype=np.int32)
    if radius <= 1.9:
        return np.eye(D, dtype=np.int32)

    # Enumerate *integer* coordinates only.  Building the range from the raw
    # radius produced half-integer coordinates for a non-integer radius.
    # radius > 1.9 here, so int() truncation equals floor.
    reach = int(radius)
    r = np.arange(-reach, reach + 1, dtype=np.int32)
    pts = np.stack(np.meshgrid(*(r,) * D, indexing="ij"), axis=-1).reshape(-1, D)
    # Squared norms are accumulated in int64 to stay clear of int32 overflow.
    d2 = np.sum(pts.astype(np.int64) ** 2, axis=-1)
    return drop_symmetric(pts[(d2 > 0) & (d2 <= radius ** 2)])
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def simplex(D: int) -> np.ndarray:
    """
    Vertices of a regular simplex centred at the origin in R^D.

    Built recursively: the (D-1)-simplex is embedded in the first D-1
    coordinates, shrunk by ``sqrt(1 - 1/D²)`` and lowered by ``1/D`` along
    the last axis; the new apex is the last unit vector.

    Returns
    -------
    (D+1, D) float64 array
    """
    if D == 1:
        return np.array([-1.0, 1.0])[:, None]

    apex = np.zeros((1, D))
    apex[0, -1] = 1.0

    # Lower-dimensional simplex padded with a zero last coordinate.
    lower = np.hstack((simplex(D - 1), np.zeros((D, 1))))
    shrink = np.sqrt(1.0 - (1.0 / D) ** 2)
    return np.vstack((shrink * lower - apex / D, apex))
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def grid_shape(N: int, D: int) -> tuple[tuple[int, ...], int, tuple[int, ...]]:
    """
    Smallest D-hypercube grid that contains at least *N* points.

    Returns
    -------
    IJK : shape tuple e.g. (32, 32) for D=2
    total : total number of grid slots (I^D)
    axes : tuple(range(D))
    """
    # The floating-point root is only a first guess: N ** (1 / D) can land
    # just above an exact integer root, and a bare ceil() would then pick a
    # grid one cell too wide.  Correct the guess with exact integer powers.
    side = int(round(N ** (1.0 / D)))
    while side ** D < N:
        side += 1
    while side > 0 and (side - 1) ** D >= N:
        side -= 1
    IJK = (side,) * D
    return IJK, side ** D, tuple(range(D))
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
|
94
|
+
# Torus geometry (JAX)
|
|
95
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
|
96
|
+
|
|
97
|
+
def torus_wrap(x: jnp.ndarray) -> jnp.ndarray:
    """Wrap coordinates into [0, 1)^D by removing their integer part (floor)."""
    whole = jnp.floor(x)
    return x - whole
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def torus_delta(delta: jnp.ndarray) -> jnp.ndarray:
    """Shortest signed displacement on the unit torus (nearest integer removed)."""
    nearest = jnp.round(delta)
    return delta - nearest
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
|
108
|
+
# Gradient / status helpers (JAX)
|
|
109
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
|
110
|
+
|
|
111
|
+
def clean_grad(x: jnp.ndarray) -> jnp.ndarray:
    """
    Replace NaN gradient contributions (fictive points) with 0.

    Implemented with ``jnp.nan_to_num``, whose default handling of infinite
    entries is left unchanged.
    """
    cleaned = jnp.nan_to_num(x, nan=0.0)
    return cleaned
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def clean_points(x: jnp.ndarray) -> jnp.ndarray:
    """
    Restore the status flag stored in the last coordinate.

    The last coordinate of each grid slot encodes whether the slot is real
    (0.0) or fictive (NaN).  ``torus_wrap`` can corrupt this flag, so the
    last coordinate is rounded back here; every other coordinate is
    returned untouched.
    """
    status = jnp.round(x[..., -1:])
    return x.at[..., -1:].set(status)
|
|
125
|
+
|
|
126
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
|
127
|
+
# Wave-vector preparation
|
|
128
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
|
129
|
+
|
|
130
|
+
def prepare_wave_vectors(
    Ks: np.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Build the JAX arrays used by the spectral gradient.

    Parameters
    ----------
    Ks : (M, D) integer wave-vector matrix

    Returns
    -------
    K_w : complex array of shape (M, 1, D+1) — ``2πi · Ks`` with one extra
          zero column appended.
    K_  : complex array of shape (M, 1, D+1) — ``-K_w`` divided by the
          squared norm of each row of ``K_w``.
    """
    phases = 2.0 * jnp.pi * Ks * 1j
    # Extra zero column: the wave vectors get the same D+1 width as the
    # point arrays they are multiplied with.
    padding = np.zeros((len(phases), 1))
    K = jnp.concatenate((phases, padding), axis=1)[:, None, :]
    norm2 = jnp.sum(jnp.abs(K) ** 2, axis=-1, keepdims=True)
    return K, -K / norm2
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
|
152
|
+
# Grid initialisation helpers
|
|
153
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
|
154
|
+
|
|
155
|
+
def prepare_points(
|
|
156
|
+
x: np.ndarray | None,
|
|
157
|
+
N_asked: int,
|
|
158
|
+
IJK: tuple[int, ...],
|
|
159
|
+
D: int,
|
|
160
|
+
) -> jnp.ndarray:
|
|
161
|
+
"""
|
|
162
|
+
Pad *N_asked* real points to fill the I^D grid.
|
|
163
|
+
|
|
164
|
+
Fictive slots receive a NaN status coordinate so gradients ignore them.
|
|
165
|
+
|
|
166
|
+
Parameters
|
|
167
|
+
----------
|
|
168
|
+
x : (N_asked, D) array or *None* (random initialisation).
|
|
169
|
+
N_asked : number of real points.
|
|
170
|
+
IJK : grid shape tuple.
|
|
171
|
+
D : spatial dimension.
|
|
172
|
+
|
|
173
|
+
Returns
|
|
174
|
+
-------
|
|
175
|
+
jnp.ndarray of shape (*IJK, D+1)
|
|
176
|
+
"""
|
|
177
|
+
if x is None:
|
|
178
|
+
x = np.random.rand(N_asked, D)
|
|
179
|
+
else:
|
|
180
|
+
x = np.asarray(x).reshape(N_asked, D)
|
|
181
|
+
|
|
182
|
+
total = int(np.prod(IJK))
|
|
183
|
+
xfull = np.random.rand(total, D + 1)
|
|
184
|
+
xfull[:, -1] = 0.0 # status = 0 → real
|
|
185
|
+
xfull[:N_asked, :D] = x
|
|
186
|
+
xfull[N_asked:, D] = np.nan # status = NaN → fictive
|
|
187
|
+
return jnp.array(xfull.reshape(*IJK, D + 1))
|
|
188
|
+
|
|
189
|
+
def random_rotations(x, batch_size, Dout, Din):
    """
    Apply *batch_size* random orthonormal frames to the rows of *x*.

    Each frame is the Q factor of the QR decomposition of a standard-normal
    (Dout, Din) matrix.

    Parameters
    ----------
    x : (K, Din) array of vectors to transform.
    batch_size : number of independent frames to draw.
    Dout, Din : shape of the Gaussian matrix behind each frame.

    Returns
    -------
    (batch_size, K, Dout) array; entry ``[n, k]`` is ``Q[n] @ x[k]``.
    """
    gaussian = np.random.randn(batch_size, Dout, Din)
    Q = np.linalg.qr(gaussian)[0]
    return np.einsum("nij,kj->nki", Q, x)
|
|
195
|
+
|
|
196
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
|
197
|
+
# Structure factor
|
|
198
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
|
199
|
+
|
|
200
|
+
def structure_factor(
    points: np.ndarray,
    nbins: int = 100,
    resolution: float = 30.0,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Estimate the radial structure factor S(k) via scattering intensity.

    Integer wave vectors are drawn at random in ``[-kmax, kmax]^D`` and
    completed with the deterministic set ``integers_in_half_ball(kmed, D)``.
    For each one, ``|sum_j exp(2πi n·x_j)|² / N`` is computed and averaged
    inside radial bins of ``|n|``.

    Parameters
    ----------
    points : (N, D) array of point coordinates in [0, 1)^D.
    nbins : number of bin *edges* between 0 and kmax (hence ``nbins - 1``
        bins).
    resolution : ``int(resolution * nbins)`` random wave vectors are drawn.

    Returns
    -------
    k : (M,) float array — centres of the non-empty bins, in units of the
        integer wave-vector norm ``|n|``.
    S : (M,) float array — mean S(k) per non-empty bin.
    """
    pts = np.asarray(points)
    N, D = pts.shape

    kmed = int(1_000 ** (1.0 / D))
    kmax = int(2 * N ** (1.0 / D))
    edges = np.linspace(0, kmax, nbins)
    n_bins = len(edges) - 1

    # Random + deterministic wave-vector sampling; the zero vector is dropped.
    random_vecs = np.random.randint(-kmax, kmax + 1, size=(int(resolution * nbins), D))
    nvecs = np.concatenate([random_vecs, integers_in_half_ball(kmed, D)], axis=0)
    nvecs = nvecs[np.any(nvecs != 0, axis=1)]

    # Assign each vector to a radial bin and discard those outside the range.
    bin_idx = np.searchsorted(edges, np.linalg.norm(nvecs, axis=1)) - 1
    inside = (bin_idx >= 0) & (bin_idx < n_bins)
    nvecs, bin_idx = nvecs[inside], bin_idx[inside]

    kvecs = jnp.array(2.0 * np.pi * nvecs)
    pts_j = jnp.array(pts)

    def intensity(k: jnp.ndarray) -> jnp.ndarray:
        rho = jnp.sum(jnp.exp(1j * (pts_j @ k)), axis=0)
        return jnp.abs(rho) ** 2 / N

    # lax.map evaluates one wave vector at a time.
    Sk = np.asarray(jax.lax.map(intensity, kvecs))

    S_sum = np.bincount(bin_idx, weights=Sk, minlength=n_bins)
    counts = np.bincount(bin_idx, minlength=n_bins)
    nz = counts > 0

    centres = 0.5 * (edges[:-1] + edges[1:])
    return centres[nz], S_sum[nz] / counts[nz]
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
import time
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ProgressLogger:
    """Hierarchical carriage-return-based progress display for nested pipeline levels."""

    def __init__(self, verbose: int):
        self.verbose = verbose
        # -1 means "no level entered yet"; the first enter_level() makes it 0.
        self.level = -1

    def enter_level(self, N: int, D: int, N_ITER: int) -> "_LevelCtx":
        """Push a new recursion level and return its context."""
        self.level += 1
        return _LevelCtx(self, N, D, N_ITER)

    def exit_level(self) -> None:
        """Pop the current recursion level."""
        self.level -= 1

    def _prefix(self) -> str:
        """Tag identifying the current level, e.g. ``[L0] ``."""
        return f"[L{self.level}] "

    def write(self, msg: str, newline: bool = False) -> None:
        """
        Overwrite the current terminal line with *msg*.

        Nothing is printed when ``verbose < 1``.  With ``newline=True`` the
        line is terminated so later messages start on a fresh line.
        """
        if self.verbose < 1:
            return
        line = f"\r{self._prefix()}{msg} "
        if newline:
            line += "\n"
        sys.stdout.write(line)
        sys.stdout.flush()
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class _LevelCtx:
    """Tracks timing and tick state for a single pipeline level."""

    def __init__(self, logger: "ProgressLogger", N: int, D: int, N_ITER: int):
        self._log = logger
        self.N = N
        self.D = D
        self.N_ITER = N_ITER
        self._tick = 0
        # Time of the first tick; None until tick() has been called once.
        self._t0: float | None = None
        # Duration between the first and second tick, used as the ETA unit.
        self._t_iter: float | None = None

    def on_compile(self) -> None:
        """Announce that compilation is in progress."""
        self._log.write("compiling JAX kernel…")

    def on_bruteforce_start(self) -> None:
        """Announce the start of the brute-force stage."""
        self._log.write(f"bruteforce N={self.N} D={self.D} …")

    def on_bruteforce_done(self) -> None:
        """Announce the end of the brute-force stage and close the line."""
        self._log.write("bruteforce done ✓", newline=True)

    def tick(self) -> None:
        """Called once per gridification callback (= one full iteration). Drives the ETA display."""
        now = time.perf_counter()
        self._tick += 1

        if self._tick == 1:
            # The first tick only starts the clock; there is no ETA yet.
            self._t0 = now
            self._log.write(f"{self._bar()} — calibrating…")
            return

        if self._tick == 2:
            self._t_iter = now - self._t0  # type: ignore[operator]

        self._log.write(f"{self._bar()} — {self._eta(now)} remaining")

    def done(self) -> None:
        """Print the completed bar and close the line."""
        self._log.write(f"{self._bar(done=True)} — done ✓", newline=True)

    def _bar(self, done: bool = False) -> str:
        """
        Render ``"<filled>/<N_ITER> [bar]"`` with a fixed 20-cell bar.

        The number of completed iterations is clamped to ``[0, N_ITER]`` so
        extra ticks cannot overflow the bar.  ``N_ITER <= 0`` renders an
        empty bar (a full one when *done*) instead of dividing by zero.
        """
        W = 20
        total = max(0, self.N_ITER)
        if done:
            filled = total
            n_fill = W
        elif total > 0:
            filled = min(total, max(0, self._tick - 1))
            n_fill = int(W * filled / total)
        else:
            filled = 0
            n_fill = 0
        bar = "▓" * n_fill + "░" * (W - n_fill)
        return f"{filled}/{self.N_ITER} [{bar}]"

    def _eta(self, now: float) -> str:
        """Remaining-time estimate; ``"?"`` until one iteration has been timed."""
        if self._t_iter is None:
            return "?"
        # Never report a negative remaining time when ticks exceed N_ITER.
        left = max(0, self.N_ITER - (self._tick - 1))
        remaining = left * self._t_iter
        return f"~{remaining:.0f}s" if remaining < 60 else f"~{remaining / 60:.1f}min"
|
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
from squarenet import SquareNet
|
|
7
|
+
|
|
8
|
+
from .math_utils import (
|
|
9
|
+
integers_in_half_ball,
|
|
10
|
+
simplex,
|
|
11
|
+
grid_shape,
|
|
12
|
+
torus_wrap,
|
|
13
|
+
clean_points,
|
|
14
|
+
prepare_wave_vectors,
|
|
15
|
+
prepare_points,
|
|
16
|
+
random_rotations,
|
|
17
|
+
)
|
|
18
|
+
from .kernels import (
|
|
19
|
+
gauss_kernel,
|
|
20
|
+
gauss_sin_kernel,
|
|
21
|
+
spectral_kernel,
|
|
22
|
+
)
|
|
23
|
+
from .progress import ProgressLogger, _LevelCtx
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Hyper-parameter presets keyed by the spatial dimension D.
_PRESETS: dict[int, dict] = {
    2: {
        "spatial_radius": 7,
        "spectral_radius": 7,
        "LR_spatial": 0.1,
        "LR_spectral": 0.1,
        "expension_factor": 0.3,
        "S": 1.0,
    },
    3: {
        "spatial_radius": 5,
        "spectral_radius": 5,
        "LR_spatial": 0.1,
        "LR_spectral": 0.1,
        "expension_factor": 0.3,
        "S": 1.0,
    },
    4: {
        "spatial_radius": 3,
        "spectral_radius": 3,
        "LR_spatial": 0.01,
        "LR_spectral": 0.1,
        "expension_factor": 1.0,
        "S": 0.5,
    },
}
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# ── Bruteforce (small N) ──────────────────────────────────────────────────────
|
|
34
|
+
|
|
35
|
+
def _build_bruteforce(N: int, D: int, ctx: _LevelCtx):
    """
    AOT-compile a gradient-descent sampler for N ≤ ~3 000 points.

    Every point is repelled by every other point (all-pairs kernel sum), and
    the positions are wrapped back onto the torus after each step.

    Parameters
    ----------
    N : number of points.
    D : spatial dimension.
    ctx : progress context; its ``on_compile``, ``on_bruteforce_start`` and
        ``on_bruteforce_done`` hooks are called.

    Returns
    -------
    callable
        ``sample_fn(init=None)`` running the compiled descent from *init*
        (uniform random points when omitted) and returning an (N, D) array.
    """
    DX = 1.0 / N ** (1.0 / D)
    S = 1.0
    sigma2 = S * 2.0 * DX ** 2
    # A wide kernel relative to the unit torus selects the trigonometric
    # kernel, which does not rely on the discontinuous torus_delta.
    high_D = sigma2 >= 0.03

    lr_table = {2: 0.4, 3: 0.1, 4: 0.05, 5: 0.01}
    lr = lr_table.get(D, 0.01)
    Niter = 1_000 if high_D else 3_000

    if high_D:
        a = 2.0 * jnp.pi
        b = 2.0 / (sigma2 * a ** 2)
        c = 1.0 / (2.0 * S * jnp.pi)

        def kernel(x, y):
            return gauss_sin_kernel(x, y, a, b, c)
    else:
        def kernel(x, y):
            return gauss_kernel(x, y, sigma2)

    def grad(x):
        # For each point xi: sum of the kernel against all points.
        return jax.vmap(lambda xi: kernel(xi[None], x).sum(axis=0))(x)

    @jax.jit
    def _run(x):
        def step(_, x):
            return torus_wrap(x - lr * grad(x))
        return jax.lax.fori_loop(0, Niter, step, x)

    ctx.on_compile()
    compiled = _run.lower(jax.ShapeDtypeStruct((N, D), jnp.float32)).compile()

    def sample_fn(init: np.ndarray | None = None) -> jnp.ndarray:
        ctx.on_bruteforce_start()
        if init is None:
            init = np.random.rand(N, D)
        # The executable was compiled for float32 input only; cast explicitly
        # so float64 (x64 mode) or integer input does not trip the AOT
        # argument-type check.
        out = compiled(jnp.asarray(init, dtype=jnp.float32))
        out.block_until_ready()
        ctx.on_bruteforce_done()
        return out

    return sample_fn
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# ── Core pipeline ─────────────────────────────────────────────────────────────
|
|
79
|
+
|
|
80
|
+
def _run_pipeline(
    N: int,
    D: int,
    N_ITER: int,
    logger: ProgressLogger,
    *,
    x: np.ndarray | None = None,
    S: float,
    expension_factor: float,
    LR_spatial: float,
    LR_spectral: float,
    spatial_radius: int,
    spectral_radius: int,
    N_PER_STEP: int,
    _is_root: bool = False,
    _is_leaf: bool = True,
) -> np.ndarray:
    """
    Recursive stealthy-sampling pipeline. Spawns child pipelines when N is large.

    Parameters
    ----------
    N : number of real points produced at this level.
    D : spatial dimension.
    N_ITER : number of gridification rounds; doubled for D == 4 and
        multiplied by 6 for D >= 5.
    logger : progress logger shared by all recursion levels.
    x : optional (N, D) initial points; when given, this level starts from
        the brute-force sampler instead of recursing.
    S : kernel-width scale (``sigma2 = 2 S / N^(2/D)``); also divides
        ``LR_spatial``.
    expension_factor : size of the simplex offsets used when cloning
        parents, in units of ``1 / N^(1/D)``.
    LR_spatial, LR_spectral : step sizes of the two gradient terms.
    spatial_radius, spectral_radius : radii passed to
        ``integers_in_half_ball`` for the grid shifts and the wave vectors.
    N_PER_STEP : number of gradient steps between two gridifications.
    _is_root : force the brute-force start at this level.
    _is_leaf : when true, return only the real points as an (N, D) NumPy
        array; otherwise return the full (*IJK, D+1) grid array.
    """
    # Scale the iteration budget BEFORE opening the progress level so the bar
    # and ETA count the same number of gridification rounds the loop runs.
    # NOTE(review): the scaled value is also what gets passed to the child
    # pipeline, which scales it again; confirm the compounding is intended.
    if D == 4:
        N_ITER *= 2
    if D >= 5:
        N_ITER *= 6

    ctx = logger.enter_level(N, D, N_ITER)

    try:
        Dsimp = min(D, 3)
        IJK, _, Axes = grid_shape(N, D)
        Nsqrt = N ** 0.5
        Ncbrt = N ** (1.0 / D)
        is_root = _is_root or (N <= 2_000) or (x is not None)
        sigma2 = S * 2.0 * (1.0 / Ncbrt) ** 2
        high_D = sigma2 >= 0.03

        SHIFTS = integers_in_half_ball(spatial_radius, D)
        Ks = integers_in_half_ball(spectral_radius, D)
        K_w, K_ = prepare_wave_vectors(Ks)
        Clone_simplex = jnp.array(simplex(Dsimp))

        if high_D:
            a = 2.0 * jnp.pi
            b = 2.0 / (sigma2 * a ** 2)
            c = 1.0 / (2.0 * S * jnp.pi)
            micro_kernel = lambda x_val, y_val: gauss_sin_kernel(x_val, y_val, a, b, c)
        else:
            micro_kernel = lambda x_val, y_val: gauss_kernel(x_val, y_val, sigma2)

        def micro_grad(x_val):
            # Short-range term: kernel against grid neighbours reached by
            # rolling the grid, antisymmetrised with the opposite roll.
            def body(acc, shift):
                contrib = micro_kernel(x_val, jnp.roll(x_val, shift, axis=Axes))
                return acc + contrib - jnp.roll(contrib, -shift, axis=Axes), None
            out, _ = jax.lax.scan(body, jnp.zeros_like(x_val), SHIFTS)
            return out

        def macro_grad(x_val):
            # Spectral term: one spectral_kernel contribution per wave vector.
            x_flat = x_val.reshape(-1, D + 1)

            def body(acc, args):
                k, k_ = args
                return acc + spectral_kernel(x_flat, k, k_), None
            out, _ = jax.lax.scan(body, jnp.zeros_like(x_flat), (K_w[:, 0], K_[:, 0]))
            return out.reshape(*IJK, D + 1)

        sn = SquareNet(gridshape=IJK, max_iter=50, verbose=0)

        def _gridify_numpy(x_val: np.ndarray) -> np.ndarray:
            # Host-side callback: shuffle the slots, fit the grid assignment
            # on the D spatial coordinates and map all D+1 columns with it.
            ctx.tick()
            flat = torus_wrap(np.random.permutation(x_val.reshape(-1, D + 1)) - 0.5)
            sn.fit(flat[:, :D], method="ultimate")
            return sn.map(flat)

        def gridify(x_val: jnp.ndarray) -> jnp.ndarray:
            return jax.pure_callback(
                _gridify_numpy,
                jax.ShapeDtypeStruct(x_val.shape, x_val.dtype),
                x_val,
            )

        def clone(x_val: np.ndarray) -> np.ndarray:
            """Expand N//(Dsimp+1) parents into N children via simplex offsets."""
            x_val = x_val.reshape(-1, D + 1)
            # Keep real slots only (finite status) and drop the status column.
            x_val = x_val[np.isfinite(x_val[:, -1]), :D]
            x_val = np.random.permutation(x_val)
            N_parents = N // (Dsimp + 1)
            N_keep = N - (Dsimp + 1) * N_parents
            offsets = random_rotations(Clone_simplex, N_parents, D, Dsimp) * (expension_factor / Ncbrt)
            children = (x_val[:N_parents, None, :] + offsets).reshape(-1, D)
            if N_keep > 0:
                # Parents beyond the first N_parents are carried over as-is so
                # the total is exactly N.
                children = np.concatenate([x_val[N_parents:], children], axis=0)
            return np.asarray(torus_wrap(jnp.array(children)))

        @jax.jit
        def run_iters(x_val: jnp.ndarray) -> jnp.ndarray:
            def step(i, x_val):
                # Re-assign points to the grid every N_PER_STEP steps.
                x_val = jax.lax.cond(
                    i % N_PER_STEP == 0,
                    gridify,
                    lambda val: val,
                    x_val,
                )
                return clean_points(torus_wrap(
                    x_val
                    - (LR_spatial / S) * micro_grad(x_val)
                    - (LR_spectral / (Nsqrt * Ncbrt)) * macro_grad(x_val)
                ))
            return jax.lax.fori_loop(0, N_ITER * N_PER_STEP, step, x_val)

        if is_root:
            xparent = _build_bruteforce(N, D, ctx)(x)
            x_pts = prepare_points(np.asarray(xparent), N, IJK, D)
        else:
            # The child solves a problem (Dsimp+1) times smaller, plus the
            # remainder, so that cloning yields exactly N points.
            N_child = N // (Dsimp + 1) + N % (Dsimp + 1)
            xparent = clone(
                _run_pipeline(
                    N=N_child,
                    D=D,
                    logger=logger,
                    S=S,
                    expension_factor=expension_factor,
                    LR_spatial=LR_spatial,
                    LR_spectral=LR_spectral,
                    spatial_radius=spatial_radius,
                    spectral_radius=spectral_radius,
                    N_ITER=N_ITER,
                    N_PER_STEP=N_PER_STEP,
                    _is_root=False,
                    _is_leaf=False,
                )
            )
            x_pts = prepare_points(xparent, N, IJK, D)

        x_pts = run_iters(x_pts)
        ctx.done()

        if _is_leaf:
            # Drop fictive slots (NaN status) and the status column.
            x_pts = np.array(x_pts.reshape(-1, D + 1))
            x_pts = x_pts[np.isfinite(x_pts[:, -1]), :D]

        return x_pts

    finally:
        logger.exit_level()
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
# ── Public entry point ────────────────────────────────────────────────────────
|
|
223
|
+
|
|
224
|
+
def sample(
    N: int,
    D: int,
    bruteforce: bool = False,
    N_ITER: int = 6,
    verbose: int = 1,
) -> np.ndarray:
    """
    Generate N stealthy points in [0, 1)^D.

    Parameters
    ----------
    N          : number of output points.
    D          : spatial dimension (2–4, or ≥5 falls back to bruteforce).
    bruteforce : when True, always use the sampler built by
                 ``_build_bruteforce`` instead of the pipeline. The same path
                 is taken automatically when N <= 3000 or D >= 5.
    N_ITER     : pipeline iterations — more is better but slower.
                 Only forwarded to the pipeline path.
    verbose    : 0 = silent, 1 = live progress.

    Returns
    -------
    Array of the generated points.
    """
    logger = ProgressLogger(verbose)

    if bruteforce or N <= 3_000 or D >= 5:
        reason = (
            f"D={D} ≥ 5" if D >= 5
            else f"N={N} ≤ 3 000" if N <= 3_000
            else "bruteforce flag"
        )
        ctx = logger.enter_level(N, D, 0, reason)
        # Always leave the logger level, even if the sampler raises; this
        # mirrors the try/finally used on the pipeline path.
        try:
            blue = _build_bruteforce(N, D, ctx)
            return np.array(blue())
        finally:
            logger.exit_level()

    # NOTE(review): a D with no entry in _PRESETS (presumably D < 2) fails on
    # this lookup -- confirm whether such input should be rejected earlier.
    preset = _PRESETS[D]
    return _run_pipeline(
        N=N,
        D=D,
        N_ITER=N_ITER,
        logger=logger,
        x=None,
        S=preset["S"],
        expension_factor=preset["expension_factor"],
        LR_spatial=preset["LR_spatial"],
        LR_spectral=preset["LR_spectral"],
        spatial_radius=preset["spatial_radius"],
        spectral_radius=preset["spectral_radius"],
        N_PER_STEP=10,
        _is_root=False,
        _is_leaf=True,
    )
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Visualisation helpers
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
9
|
+
|
|
10
|
+
from .math_utils import structure_factor as _structure_factor
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def plot(
    points: np.ndarray,
    max_scatter: int = 30_000,
    ax: plt.Axes | None = None,
    **scatter_kw,
) -> plt.Figure:
    """
    Draw a 2-D or 3-D point set as a scatter plot and show it.

    When there are more than *max_scatter* points, only those inside a
    lower-corner sub-box are kept, sized so that roughly *max_scatter*
    points remain for a uniform density.

    Parameters
    ----------
    points : array-like
        Point coordinates; the last axis holds the D coordinates and all
        leading axes are flattened. Only the first 3 coordinates are used.
    max_scatter : int
        Point count above which the view is cropped to the lower corner.
    ax : matplotlib Axes | None
        Existing axes to draw into. When *None* a new 8x8 figure is created
        (with a 3-D projection unless D == 2).
    **scatter_kw
        Extra keyword arguments forwarded to ``ax.scatter``; they override
        the defaults ``s=0.4, color="black"``.

    Returns
    -------
    fig : matplotlib.figure.Figure
    """
    raw = np.asarray(points)
    pts = raw.reshape(-1, raw.shape[-1])
    n_dims = min(pts.shape[-1], 3)
    pts = pts[:, :n_dims]

    n_points = len(pts)
    if n_points > max_scatter:
        # Side length of the corner box holding ~max_scatter points.
        zoom = (max_scatter / n_points) ** (1.0 / n_dims)
        inside = np.all(pts <= zoom, axis=1)
        pts = pts[inside]

    kw = {"s": 0.4, "color": "black", **scatter_kw}

    is_planar = n_dims == 2
    if ax is not None:
        fig = ax.get_figure()
    else:
        fig = plt.figure(figsize=(8, 8))
        if is_planar:
            ax = fig.add_subplot(111)
        else:
            ax = fig.add_subplot(111, projection="3d")

    columns = [pts[:, axis] for axis in range(2 if is_planar else 3)]
    ax.scatter(*columns, **kw)

    ax.set_axis_off()
    plt.tight_layout()
    plt.show()
    return fig
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def plot_structure_factor(
    points: np.ndarray,
    bins: int = 100,
    resolution: float = 30.0,
    ax: plt.Axes | None = None,
    **plot_kw,
) -> plt.Figure:
    """
    Plot the radial structure factor S(k) on log-log axes and show it.

    Parameters
    ----------
    points : (N, D) array
        Point coordinates in [0, 1)^D; leading axes are flattened.
    bins : int
        Passed to the structure-factor estimator as ``nbins``.
    resolution : float
        Passed to the structure-factor estimator as ``resolution``.
    ax : matplotlib Axes | None
        Existing axes to draw into. When *None* a new 7x5 figure is created.
    **plot_kw
        Extra keyword arguments forwarded to ``ax.loglog``; they override
        the defaults ``marker="o", markersize=2, linewidth=1``.

    Returns
    -------
    fig : matplotlib.figure.Figure
    """
    raw = np.asarray(points)
    pts = raw.reshape(-1, raw.shape[-1])
    wavenumbers, factor = _structure_factor(pts, nbins=bins, resolution=resolution)

    kw = {"marker": "o", "markersize": 2, "linewidth": 1, **plot_kw}

    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots(figsize=(7, 5))

    ax.loglog(wavenumbers, factor, **kw)
    ax.set_xlabel("k")
    ax.set_ylabel("S(k)")
    ax.set_title("Structure factor (log-log)")
    ax.grid(True, which="both", alpha=0.4)

    plt.tight_layout()
    plt.show()
    return fig
|