blue-sampler 0.1.0 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,79 @@
+ Metadata-Version: 2.4
+ Name: blue_sampler
+ Version: 0.1.0
+ Summary: Stealthy point-pattern sampling on the unit torus
+ Project-URL: Repository, https://github.com/For-a-few-DPPs-more/hyperuniform-samplers
+ License: MIT
+ Keywords: blue noise,jax,point process,sampling,stealthy
+ Requires-Python: >=3.10
+ Requires-Dist: jax
+ Requires-Dist: jaxlib
+ Requires-Dist: matplotlib
+ Requires-Dist: numpy
+ Requires-Dist: requests
+ Requires-Dist: squarenet
+ Description-Content-Type: text/markdown
+
+ # blue-sampler
+
+ Generate **stealthy point patterns** — low-discrepancy, spectrally isotropic
+ samples on the unit torus [0, 1)^D.
+
+ Stealthy patterns have a structure factor S(k) that (nearly) vanishes for all
+ wave vectors below a cut-off: long-range density fluctuations are suppressed
+ while the pattern remains aperiodic. They are useful in rendering, quadrature,
+ and computational physics wherever quasi-random, isotropic spatial coverage is
+ needed.
+
+ ## Installation
+
+ ```bash
+ pip install blue_sampler
+ ```
+
+ ## Quick start
+
+ ```python
+ import blue_sampler as blue
+
+ # 10 000 points in 2-D
+ x = blue.sample(10_000, D=2)
+ blue.plot(x)
+ blue.plot_structure_factor(x)
+
+ # 3-D
+ x = blue.sample(5_000, D=3)
+ ```
+
+ ## Supported dimensions
+
+ | D   | Notes |
+ |-----|-------|
+ | 2   | Fast, recommended for exploration |
+ | 3   | ~3× slower than 2-D |
+ | 4   | Needs more iterations (the iteration count is doubled automatically) |
+ | ≥ 5 | Experimental; always uses the direct O(N²) solver |
+
+ ## Algorithm overview
+
+ The pipeline alternates between:
+
+ 1. **Spatial gradient** — short-range Gaussian repulsion via
+    neighbour convolution on the torus.
+ 2. **Spectral gradient** — minimises the structure factor S(k) for every
+    integer wave vector k inside a half-ball of fixed radius (one vector per
+    ±k pair, since S(−k) = S(k)).
+ 3. **Grid assignment** (SquareNet) — periodic re-assignment of the points to a
+    regular grid (every 10 gradient steps by default), so that neighbour
+    interactions reduce to cheap array shifts.
+
+ For N ≤ 3 000 (or D ≥ 5), `sample` runs a direct O(N²) gradient descent
+ instead. For larger N, a hierarchical strategy clones and refines a coarser
+ solution: it recursively solves for about N/(d+1) points, with d = min(D, 3),
+ replaces each of them with d+1 children at the vertices of a randomly rotated
+ simplex, and refines the result with the steps above. The recursion stops once
+ a level has at most 2 000 points.
+
+ ## License
+
+ MIT
@@ -0,0 +1,63 @@
+ # blue-sampler
+
+ Generate **stealthy point patterns** — low-discrepancy, spectrally isotropic
+ samples on the unit torus [0, 1)^D.
+
+ Stealthy patterns have a structure factor S(k) that (nearly) vanishes for all
+ wave vectors below a cut-off: long-range density fluctuations are suppressed
+ while the pattern remains aperiodic. They are useful in rendering, quadrature,
+ and computational physics wherever quasi-random, isotropic spatial coverage is
+ needed.
+
+ ## Installation
+
+ ```bash
+ pip install blue_sampler
+ ```
+
+ ## Quick start
+
+ ```python
+ import blue_sampler as blue
+
+ # 10 000 points in 2-D
+ x = blue.sample(10_000, D=2)
+ blue.plot(x)
+ blue.plot_structure_factor(x)
+
+ # 3-D
+ x = blue.sample(5_000, D=3)
+ ```
+
+ ## Supported dimensions
+
+ | D   | Notes |
+ |-----|-------|
+ | 2   | Fast, recommended for exploration |
+ | 3   | ~3× slower than 2-D |
+ | 4   | Needs more iterations (the iteration count is doubled automatically) |
+ | ≥ 5 | Experimental; always uses the direct O(N²) solver |
+
+ ## Algorithm overview
+
+ The pipeline alternates between:
+
+ 1. **Spatial gradient** — short-range Gaussian repulsion via
+    neighbour convolution on the torus.
+ 2. **Spectral gradient** — minimises the structure factor S(k) for every
+    integer wave vector k inside a half-ball of fixed radius (one vector per
+    ±k pair, since S(−k) = S(k)).
+ 3. **Grid assignment** (SquareNet) — periodic re-assignment of the points to a
+    regular grid (every 10 gradient steps by default), so that neighbour
+    interactions reduce to cheap array shifts.
+
+ For N ≤ 3 000 (or D ≥ 5), `sample` runs a direct O(N²) gradient descent
+ instead. For larger N, a hierarchical strategy clones and refines a coarser
+ solution: it recursively solves for about N/(d+1) points, with d = min(D, 3),
+ replaces each of them with d+1 children at the vertices of a randomly rotated
+ simplex, and refines the result with the steps above. The recursion stops once
+ a level has at most 2 000 points.
+
+ ## License
+
+ MIT
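The hierarchical strategy above shrinks the problem level by level. The sketch below reproduces only the size bookkeeping of `_run_pipeline` (child size `N // (d+1) + N % (d+1)` with `d = min(D, 3)`, stopping once a level has at most 2 000 points). `level_sizes` is a hypothetical helper written for illustration; it is not part of the package and it ignores the separate `N ≤ 3 000` shortcut taken by `sample`.

```python
def level_sizes(N: int, D: int, base: int = 2_000) -> list[int]:
    """Point counts per level, from the requested N down to the coarsest level.

    Hypothetical helper mirroring the hierarchical pipeline's arithmetic:
    N // (d+1) parents are cloned into d+1 children each, and the
    N % (d+1) leftover parents are carried over unchanged, with d = min(D, 3).
    """
    d = min(D, 3)
    sizes = [N]
    while N > base:
        N = N // (d + 1) + N % (d + 1)  # parents needed to produce exactly N points
        sizes.append(N)
    return sizes


print(level_sizes(10_000, 2))   # [10000, 3334, 1112]
print(level_sizes(100_000, 4))  # [100000, 25000, 6250, 1564]
```

Because each level divides the point count by roughly d + 1, the recursion depth grows only logarithmically with N.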
@@ -0,0 +1,37 @@
+ [build-system]
+ requires = ["hatchling >= 1.26"]
+ build-backend = "hatchling.build"
+
+ [project]
+ name = "blue_sampler"
+ version = "0.1.0"
+ description = "Stealthy point-pattern sampling on the unit torus"
+ readme = "README.md"
+ license = { text = "MIT" }
+ requires-python = ">=3.10"
+
+ keywords = [
+     "blue noise",
+     "stealthy",
+     "point process",
+     "sampling",
+     "jax",
+ ]
+
+ dependencies = [
+     "numpy",
+     "jax",
+     "jaxlib",
+     "matplotlib",
+     "requests",
+     "squarenet",
+ ]
+
+ [project.urls]
+ Repository = "https://github.com/For-a-few-DPPs-more/hyperuniform-samplers"
+
+ [tool.hatch.build.targets.wheel]
+ packages = ["src/blue_sampler"]
+
+ [tool.ruff]
+ line-length = 100
@@ -0,0 +1,24 @@
+ """
+ blue_sampler
+ ============
+
+ Generate stealthy point patterns — low-discrepancy, spectrally isotropic
+ samples on the unit torus [0, 1)^D.
+
+ Quick start
+ -----------
+ >>> import blue_sampler as blue
+ >>> x = blue.sample(N=10_000, D=2)  # (10000, 2) array
+ >>> blue.plot(x)
+ >>> blue.plot_structure_factor(x)
+ """
+ from .run import sample
+ from .viz import plot, plot_structure_factor
+ from .math_utils import structure_factor
+
+ __all__ = [
+     "sample",
+     "plot",
+     "plot_structure_factor",
+     "structure_factor",
+ ]
@@ -0,0 +1,124 @@
+ """
+ Kernels for the energy gradient.
+ """
+
+ from __future__ import annotations
+
+ import jax
+ import jax.numpy as jnp
+ from .math_utils import torus_delta, clean_grad
+
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # Kernel functions (JAX)
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ def gauss_kernel(
+     x: jnp.ndarray,
+     y: jnp.ndarray,
+     sigma2: float,
+ ) -> jnp.ndarray:
+     """
+     Isotropic Gaussian repulsion kernel on the torus.
+
+     Returns the *gradient* contribution δ * exp(−‖δ‖² / σ²), where
+     δ = torus_delta(y − x) is the shortest displacement from x to y.
+     """
+     delta = torus_delta(y - x)
+     dist2 = jnp.sum(delta ** 2, axis=-1, keepdims=True)
+     return clean_grad(delta * jnp.exp(-dist2 / sigma2))
+
+
+ def gauss_sin_kernel(
+     x: jnp.ndarray,
+     y: jnp.ndarray,
+     a: float,
+     b: float,
+     c: float,
+ ) -> jnp.ndarray:
+     """
+     Trigonometric (periodic) variant of the Gaussian kernel.
+
+     More stable when sigma² is not ≪ 1 (e.g. in high dimension or with few
+     points), where the discontinuity of ``torus_delta`` at half a period
+     would become problematic.
+
+     Parameters
+     ----------
+     a, b, c : pre-computed scale factors (derived from sigma²).
+     """
+     delta = a * (y - x)
+     cos_term = b * (1.0 - jnp.cos(delta))
+     sin_term = c * jnp.sin(delta)
+     dist2 = jnp.sum(cos_term, axis=-1, keepdims=True)
+     return clean_grad(sin_term * jnp.exp(-dist2))
+
+
+ def spectral_kernel(x, k, k_):
+     """
+     Spectral kernel: gradient contribution that directly targets the
+     spectral energy S(k) of a single wave vector.
+
+     Only practical for small sets of preselected wave vectors.
+
+     Parameters
+     ----------
+     x : (N, D+1) points; the last column is the status flag (NaN = fictive).
+     k : (D+1,) complex phase multiplier 2πi·n, zero-padded for the status column.
+     k_ : (D+1,) complex normalised dual of k (see ``prepare_wave_vectors``).
+     """
+     phase = jnp.sum(k * x, axis=-1, keepdims=True)
+     ek = clean_grad(jnp.exp(phase))
+     Sk = jnp.sum(ek, axis=0, keepdims=True)
+     return jnp.real(Sk * k_ * jnp.conjugate(ek))
+
+
+ def reduce_kernel(kernel, x, params, init=None):
+     """
+     Generic reduction over a kernel.
+
+     Parameters
+     ----------
+     kernel : callable
+         Function of the form kernel(x, param) -> contribution.
+     x : array
+         State.
+     params : PyTree
+         Collection of parameters passed to kernel.
+     init : array, optional
+         Initial accumulator. Defaults to zeros_like(x).
+
+     Returns
+     -------
+     array
+         Sum of all kernel contributions.
+     """
+     if init is None:
+         init = jnp.zeros_like(x)
+
+     def body(acc, param):
+         return acc + kernel(x, param), None
+
+     out, _ = jax.lax.scan(body, init, params)
+     return out
+
+
+ def micro_shift_kernel(x, shift, kernel, Axes):
+     """
+     Contribution of the grid neighbours at offsets +shift and −shift.
+
+     For an antisymmetric kernel (kernel(x, y) = −kernel(y, x)), rolling the
+     contribution back by −shift gives the reciprocal term, so each half-ball
+     shift needs a single kernel evaluation.
+     """
+     contrib = kernel(x, jnp.roll(x, shift, axis=Axes))
+     return contrib - jnp.roll(contrib, -shift, axis=Axes)
+
+
+ def micro_grad(x_val, SHIFTS, LR_spatial, S, kernel, Axes):
+     """Spatial gradient summed over the half-ball SHIFTS, scaled by LR_spatial / S."""
+     def shift_kernel(x, shift):
+         return micro_shift_kernel(x, shift, kernel, Axes)
+
+     out = reduce_kernel(shift_kernel, x_val, SHIFTS)
+     return (LR_spatial / S) * out
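The spatial gradient relies on a shift identity: once the points sit on the periodic grid, the neighbour at grid offset s is reached with a roll, and because the Gaussian pair term is antisymmetric (f(x, y) = −f(y, x)), rolling the per-point contribution back by −s yields the reciprocal term without a second kernel call. The NumPy sketch below is a standalone stand-in for the JAX code; `half_ball_grad` and `full_ball_grad` are illustrative names and the 2-D grid is an assumption. It checks that one kernel evaluation per half-ball shift matches an explicit evaluation of both +s and −s.

```python
import numpy as np


def torus_delta(d: np.ndarray) -> np.ndarray:
    """Shortest signed displacement on the unit torus."""
    return d - np.round(d)


def gauss_kernel(x: np.ndarray, y: np.ndarray, sigma2: float) -> np.ndarray:
    """Antisymmetric pair term f(x, y) = delta * exp(-|delta|^2 / sigma2)."""
    delta = torus_delta(y - x)
    dist2 = np.sum(delta ** 2, axis=-1, keepdims=True)
    return delta * np.exp(-dist2 / sigma2)


def half_ball_grad(x: np.ndarray, shifts, sigma2: float) -> np.ndarray:
    """One kernel call per half-ball shift; the -s partner comes from a roll."""
    axes = (0, 1)
    out = np.zeros_like(x)
    for s in shifts:
        contrib = gauss_kernel(x, np.roll(x, s, axis=axes), sigma2)  # f(x_i, x_{i-s})
        back = tuple(-v for v in s)
        # roll(contrib, -s)[i] = f(x_{i+s}, x_i) = -f(x_i, x_{i+s})
        out += contrib - np.roll(contrib, back, axis=axes)
    return out


def full_ball_grad(x: np.ndarray, shifts, sigma2: float) -> np.ndarray:
    """Reference: evaluate the kernel explicitly for both +s and -s."""
    axes = (0, 1)
    out = np.zeros_like(x)
    for s in shifts:
        for shift in (s, tuple(-v for v in s)):
            out += gauss_kernel(x, np.roll(x, shift, axis=axes), sigma2)
    return out


rng = np.random.default_rng(0)
x = rng.random((8, 8, 2))                   # an 8 x 8 grid holding 2-D points
shifts = [(1, 0), (0, 1), (1, 1), (1, -1)]  # integer half-ball of radius sqrt(2)
print(np.allclose(half_ball_grad(x, shifts, 0.01), full_ball_grad(x, shifts, 0.01)))  # True
```

This identity is what lets the package scan only over the half-ball `SHIFTS`: it halves the number of kernel evaluations, and since every pair contributes equal and opposite terms, the gradient summed over all points vanishes.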
@@ -0,0 +1,273 @@
+ """
+ Low-level mathematical helpers
+ """
+
+ from __future__ import annotations
+
+ import numpy as np
+ import jax
+ import jax.numpy as jnp
+
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # Lattice helpers
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ def drop_symmetric(directions: np.ndarray) -> np.ndarray:
+     """
+     Keep only one representative from each direction pair {v, -v}.
+
+     The canonical representative is the one whose *first non-zero component*
+     is positive.
+
+     Parameters
+     ----------
+     directions : (M, D) int array
+
+     Returns
+     -------
+     (K, D) int array with K ≤ M // 2 + 1
+     """
+     first_nz_idx = (directions != 0).argmax(axis=1)
+     first_nz_val = directions[np.arange(len(directions)), first_nz_idx]
+     return directions[first_nz_val > 0]
+
+
+ def integers_in_half_ball(radius: float, D: int) -> np.ndarray:
+     """
+     Return all non-zero integer lattice vectors inside a ball of *radius*,
+     keeping only one vector per direction pair {v, -v}.
+
+     Parameters
+     ----------
+     radius : float
+     D : int
+
+     Returns
+     -------
+     (M, D) int32 array
+     """
+     if radius <= 0.9:
+         return np.zeros((0, D), dtype=np.int32)
+     if radius <= 1.9:
+         return np.eye(D, dtype=np.int32)
+
+     R = int(np.floor(radius))  # integer bound, so non-integer radii still give integer vectors
+     r = np.arange(-R, R + 1, dtype=np.int32)
+     pts = np.stack(np.meshgrid(*(r,) * D, indexing="ij"), axis=-1).reshape(-1, D)
+     d2 = np.sum(pts ** 2, axis=-1)
+     return drop_symmetric(pts[(d2 > 0) & (d2 <= radius ** 2)])
+
+
+ def simplex(D: int) -> np.ndarray:
+     """
+     Vertices of a regular simplex inscribed in the unit sphere of R^D.
+
+     Returns
+     -------
+     (D+1, D) float64 array
+     """
+     if D == 1:
+         return np.array([-1.0, 1.0])[:, None]
+     null = np.zeros((D, 1))
+     tip = np.zeros((1, D))
+     tip[0, -1] = 1.0
+     base = np.hstack((simplex(D - 1), null))
+     return np.vstack((np.sqrt(1.0 - (1.0 / D) ** 2) * base - tip / D, tip))
+
+
+ def grid_shape(N: int, D: int) -> tuple[tuple[int, ...], int, tuple[int, ...]]:
+     """
+     Smallest D-hypercube grid that contains at least *N* points.
+
+     Returns
+     -------
+     IJK : shape tuple e.g. (32, 32) for D=2
+     total : total number of grid slots (I^D)
+     axes : tuple(range(D))
+     """
+     I = int(round(N ** (1.0 / D)))
+     while I ** D < N:  # correct floating-point error in the D-th root
+         I += 1
+     IJK = (I,) * D
+     return IJK, I ** D, tuple(range(D))
+
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # Torus geometry (JAX)
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ def torus_wrap(x: jnp.ndarray) -> jnp.ndarray:
+     """Wrap coordinates into [0, 1)^D."""
+     return x - jnp.floor(x)
+
+
+ def torus_delta(delta: jnp.ndarray) -> jnp.ndarray:
+     """Shortest signed displacement on the unit torus."""
+     return delta - jnp.round(delta)
+
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # Gradient / status helpers (JAX)
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ def clean_grad(x: jnp.ndarray) -> jnp.ndarray:
+     """Replace NaN gradient contributions (fictive points) with 0."""
+     return jnp.nan_to_num(x, nan=0.0)
+
+
+ def clean_points(x: jnp.ndarray) -> jnp.ndarray:
+     """
+     Preserve the NaN status flag of empty grid slots after a torus wrap.
+
+     The last coordinate of each grid slot encodes whether the slot is real
+     (0.0) or fictive (NaN). ``torus_wrap`` can corrupt this flag, so we
+     re-round it here.
+     """
+     return x.at[..., -1:].set(jnp.round(x[..., -1:]))
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # Wave-vector preparation
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ def prepare_wave_vectors(
+     Ks: np.ndarray,
+ ) -> tuple[jnp.ndarray, jnp.ndarray]:
+     """
+     Build JAX arrays for the spectral gradient.
+
+     Parameters
+     ----------
+     Ks : (M, D) integer wave-vector matrix
+
+     Returns
+     -------
+     K_w : complex array of shape (M, 1, D+1) — phase multipliers
+     K_ : complex array of shape (M, 1, D+1) — normalised duals
+     """
+     K = 2.0 * jnp.pi * Ks * 1j
+     K = jnp.concatenate((K, np.zeros((len(K), 1))), axis=1)[:, None, :]
+     Kn = (jnp.abs(K) ** 2).sum(axis=-1, keepdims=True)
+     return K, -K / Kn
+
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # Grid initialisation helpers
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ def prepare_points(
+     x: np.ndarray | None,
+     N_asked: int,
+     IJK: tuple[int, ...],
+     D: int,
+ ) -> jnp.ndarray:
+     """
+     Pad *N_asked* real points to fill the I^D grid.
+
+     Fictive slots receive a NaN status coordinate so gradients ignore them.
+
+     Parameters
+     ----------
+     x : (N_asked, D) array or *None* (random initialisation).
+     N_asked : number of real points.
+     IJK : grid shape tuple.
+     D : spatial dimension.
+
+     Returns
+     -------
+     jnp.ndarray of shape (*IJK, D+1)
+     """
+     if x is None:
+         x = np.random.rand(N_asked, D)
+     else:
+         x = np.asarray(x).reshape(N_asked, D)
+
+     total = int(np.prod(IJK))
+     xfull = np.random.rand(total, D + 1)
+     xfull[:, -1] = 0.0                # status = 0 → real
+     xfull[:N_asked, :D] = x
+     xfull[N_asked:, D] = np.nan       # status = NaN → fictive
+     return jnp.array(xfull.reshape(*IJK, D + 1))
+
+ def random_rotations(x, batch_size, Dout, Din):
+     """
+     Map the rows of *x* through *batch_size* random orthonormal frames.
+
+     Parameters
+     ----------
+     x : (K, Din) array, e.g. simplex vertices.
+     batch_size : number of independent random frames.
+     Dout, Din : output and input dimensions (Dout ≥ Din).
+
+     Returns
+     -------
+     (batch_size, K, Dout) array
+     """
+     # Q has orthonormal columns, shape (batch_size, Dout, Din)
+     Q, _ = np.linalg.qr(np.random.randn(batch_size, Dout, Din))
+     offsets = np.einsum(
+         "nij,kj->nki", Q, x
+     )
+     return offsets
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # Structure factor
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ def structure_factor(
+     points: np.ndarray,
+     nbins: int = 100,
+     resolution: float = 30.0,
+ ) -> tuple[np.ndarray, np.ndarray]:
+     """
+     Estimate the radial structure factor S(k) via scattering intensity.
+
+     Parameters
+     ----------
+     points : (N, D) array of point coordinates in [0, 1)^D.
+     nbins : number of radial bins.
+     resolution : average number of random wave vectors per bin
+         (resolution × nbins integer vectors are drawn uniformly in [-kmax, kmax]^D).
+
+     Returns
+     -------
+     k : (M,) float array — bin centres, as integer wave-vector norms |n|
+         (the physical wavenumber is 2π|n|).
+     S : (M,) float array — mean S(k) per non-empty bin.
+     """
+     pts = np.asarray(points)
+     N, D = pts.shape
+
+     kmed = int(1_000 ** (1.0 / D))
+     kmax = int(2 * N ** (1.0 / D))
+     bins = np.linspace(0, kmax, nbins + 1)  # nbins + 1 edges → nbins bins
+
+     # Random + deterministic wave-vector sampling
+     nvecs = np.random.randint(-kmax, kmax + 1, size=(int(resolution * nbins), D))
+     nvecs = np.concatenate([nvecs, integers_in_half_ball(kmed, D)], axis=0)
+     nvecs = nvecs[np.any(nvecs != 0, axis=1)]
+
+     knorm = np.linalg.norm(nvecs, axis=1)
+     bin_idx = np.searchsorted(bins, knorm) - 1
+     valid = (bin_idx >= 0) & (bin_idx < len(bins) - 1)
+     nvecs, bin_idx = nvecs[valid], bin_idx[valid]
+
+     kvecs = jnp.array(2.0 * np.pi * nvecs)
+     pts_j = jnp.array(pts)
+
+     def Sk_one(k: jnp.ndarray) -> jnp.ndarray:
+         rho = jnp.sum(jnp.exp(1j * (pts_j @ k)), axis=0)
+         return jnp.abs(rho) ** 2 / N
+
+     Sk = np.asarray(jax.lax.map(Sk_one, kvecs))
+
+     n_bins = len(bins) - 1
+     S_sum = np.bincount(bin_idx, weights=Sk, minlength=n_bins)
+     counts = np.bincount(bin_idx, minlength=n_bins)
+
+     S = np.zeros_like(S_sum, dtype=float)
+     nz = counts > 0
+     S[nz] = S_sum[nz] / counts[nz]
+
+     centres = 0.5 * (bins[:-1] + bins[1:])
+     return centres[nz], S[nz]
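`structure_factor` averages the scattering intensity S(n) = |Σ_j exp(2πi n·x_j)|² / N over integer wave vectors n, binned by |n|. A perfect lattice makes the quantity concrete: on an I × I grid, S(n) vanishes unless every component of n is a multiple of I (the phases are roots of unity and cancel), and it equals N on those Bragg vectors. A stealthy pattern aims for the vanishing part at small |n| without being periodic. The standalone NumPy sketch below evaluates single wave vectors; `scattering_intensity` is an illustrative name, not a package function.

```python
import numpy as np


def scattering_intensity(points: np.ndarray, n) -> float:
    """S(n) = |sum_j exp(2*pi*i n.x_j)|^2 / N for one integer wave vector n."""
    phase = 2.0 * np.pi * (points @ np.asarray(n, dtype=float))
    rho = np.exp(1j * phase).sum()
    return float(np.abs(rho) ** 2 / len(points))


# A 4 x 4 periodic lattice on the unit torus [0, 1)^2
I = 4
g = np.arange(I) / I
lattice = np.stack(np.meshgrid(g, g, indexing="ij"), axis=-1).reshape(-1, 2)

# Prints 0.0 for the first three vectors and 16.0 for the Bragg vector (4, 0)
for n in [(1, 0), (1, 2), (3, 3), (4, 0)]:
    print(n, round(scattering_intensity(lattice, n), 6))
```

In the package, `structure_factor` evaluates such values for many random and small deterministic wave vectors and averages them per radial bin, which is the curve `plot_structure_factor` draws on log-log axes.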
@@ -0,0 +1,84 @@
+ from __future__ import annotations
+
+ import sys
+ import time
+
+
+ class ProgressLogger:
+     """Hierarchical, carriage-return-based progress display for nested pipeline levels."""
+
+     def __init__(self, verbose: int):
+         self.verbose = verbose
+         self.level = -1
+
+     def enter_level(self, N: int, D: int, N_ITER: int) -> _LevelCtx:
+         """Push a new recursion level and return its context."""
+         self.level += 1
+         return _LevelCtx(self, N, D, N_ITER)
+
+     def exit_level(self) -> None:
+         self.level -= 1
+
+     def _prefix(self) -> str:
+         return f"[L{self.level}] "
+
+     def write(self, msg: str, newline: bool = False) -> None:
+         if self.verbose < 1:
+             return
+         sys.stdout.write(f"\r{self._prefix()}{msg} ")
+         if newline:
+             sys.stdout.write("\n")
+         sys.stdout.flush()
+
+
+ class _LevelCtx:
+     """Tracks timing and tick state for a single pipeline level."""
+
+     def __init__(self, logger: ProgressLogger, N: int, D: int, N_ITER: int):
+         self._log = logger
+         self.N = N
+         self.D = D
+         self.N_ITER = N_ITER
+         self._tick = 0
+         self._t0: float | None = None
+         self._t_iter: float | None = None
+
+     def on_compile(self) -> None:
+         self._log.write("compiling JAX kernel…")
+
+     def on_bruteforce_start(self) -> None:
+         self._log.write(f"bruteforce N={self.N} D={self.D} …")
+
+     def on_bruteforce_done(self) -> None:
+         self._log.write("bruteforce done ✓", newline=True)
+
+     def tick(self) -> None:
+         """Called once per gridification callback (= one full iteration). Drives the ETA display."""
+         now = time.perf_counter()
+         self._tick += 1
+
+         if self._tick == 1:
+             self._t0 = now
+             self._log.write(f"{self._bar()} — calibrating…")
+             return
+
+         if self._tick == 2:
+             self._t_iter = now - self._t0  # type: ignore[operator]
+
+         self._log.write(f"{self._bar()} — {self._eta(now)} remaining")
+
+     def done(self) -> None:
+         self._log.write(f"{self._bar(done=True)} — done ✓", newline=True)
+
+     def _bar(self, done: bool = False) -> str:
+         filled = self.N_ITER if done else max(0, self._tick - 1)
+         W = 20
+         n_fill = min(W, int(W * filled / max(self.N_ITER, 1)))  # guard N_ITER = 0
+         bar = "▓" * n_fill + "░" * (W - n_fill)
+         return f"{filled}/{self.N_ITER} [{bar}]"
+
+     def _eta(self, now: float) -> str:
+         if self._t_iter is None:
+             return "?"
+         remaining = max(0, self.N_ITER - (self._tick - 1)) * self._t_iter
+         return f"~{remaining:.0f}s" if remaining < 60 else f"~{remaining / 60:.1f}min"
@@ -0,0 +1,284 @@
+ from __future__ import annotations
+
+ import numpy as np
+ import jax
+ import jax.numpy as jnp
+ from squarenet import SquareNet
+
+ from .math_utils import (
+     integers_in_half_ball,
+     simplex,
+     grid_shape,
+     torus_wrap,
+     clean_points,
+     prepare_wave_vectors,
+     prepare_points,
+     random_rotations,
+ )
+ from .kernels import (
+     gauss_kernel,
+     gauss_sin_kernel,
+     spectral_kernel,
+ )
+ from .progress import ProgressLogger, _LevelCtx
+
+
+ _PRESETS: dict[int, dict] = {
+     2: dict(spatial_radius=7, spectral_radius=7, LR_spatial=0.1, LR_spectral=0.1, expansion_factor=0.3, S=1.0),
+     3: dict(spatial_radius=5, spectral_radius=5, LR_spatial=0.1, LR_spectral=0.1, expansion_factor=0.3, S=1.0),
+     4: dict(spatial_radius=3, spectral_radius=3, LR_spatial=0.01, LR_spectral=0.1, expansion_factor=1.0, S=0.5),
+ }
+
+
+ # ── Bruteforce (small N) ──────────────────────────────────────────────────────
+
+ def _build_bruteforce(N: int, D: int, ctx: _LevelCtx):
+     """AOT-compile a gradient-descent sampler for N ≤ ~3 000 points."""
+     DX = 1.0 / N ** (1.0 / D)
+     S = 1.0
+     sigma2 = S * 2.0 * DX ** 2
+     high_D = sigma2 >= 0.03
+
+     lr_table = {2: 0.4, 3: 0.1, 4: 0.05, 5: 0.01}
+     lr = lr_table.get(D, 0.01)
+     Niter = 1_000 if high_D else 3_000
+
+     if high_D:
+         a = 2.0 * jnp.pi
+         b = 2.0 / (sigma2 * a ** 2)
+         c = 1.0 / (2.0 * S * jnp.pi)
+         kernel = lambda x, y: gauss_sin_kernel(x, y, a, b, c)
+     else:
+         kernel = lambda x, y: gauss_kernel(x, y, sigma2)
+
+     def grad(x):
+         # Dense O(N²) sum of the pairwise repulsion terms acting on each point
+         return jax.vmap(lambda xi: kernel(xi[None], x).sum(axis=0))(x)
+
+     @jax.jit
+     def _run(x):
+         def step(_, x):
+             return torus_wrap(x - lr * grad(x))
+         return jax.lax.fori_loop(0, Niter, step, x)
+
+     ctx.on_compile()
+     compiled = _run.lower(jax.ShapeDtypeStruct((N, D), jnp.float32)).compile()
+
+     def sample_fn(init: np.ndarray | None = None) -> jnp.ndarray:
+         ctx.on_bruteforce_start()
+         if init is None:
+             init = np.random.rand(N, D)
+         # The AOT-compiled function only accepts float32 inputs of shape (N, D)
+         out = compiled(jnp.asarray(init, dtype=jnp.float32).reshape(N, D))
+         out.block_until_ready()
+         ctx.on_bruteforce_done()
+         return out
+
+     return sample_fn
+
+
+ # ── Core pipeline ─────────────────────────────────────────────────────────────
+
+ def _run_pipeline(
+     N: int,
+     D: int,
+     N_ITER: int,
+     logger: ProgressLogger,
+     *,
+     x: np.ndarray | None = None,
+     S: float,
+     expansion_factor: float,
+     LR_spatial: float,
+     LR_spectral: float,
+     spatial_radius: int,
+     spectral_radius: int,
+     N_PER_STEP: int,
+     _is_root: bool = False,
+     _is_leaf: bool = True,
+ ) -> np.ndarray:
+     """Recursive stealthy-sampling pipeline. Spawns child pipelines when N is large."""
+     # Higher D needs more iterations. The multiplier is applied per level and only
+     # the base N_ITER is passed to children, so it does not compound with depth.
+     n_iter_level = N_ITER * (2 if D == 4 else 6 if D >= 5 else 1)
+     ctx = logger.enter_level(N, D, n_iter_level)
+
+     try:
+         Dsimp = min(D, 3)
+         IJK, _, Axes = grid_shape(N, D)
+         Nsqrt = N ** 0.5
+         Ncbrt = N ** (1.0 / D)  # D-th root of N (points per grid axis)
+         is_root = _is_root or (N <= 2_000) or (x is not None)
+         sigma2 = S * 2.0 * (1.0 / Ncbrt) ** 2
+         high_D = sigma2 >= 0.03
+
+         SHIFTS = integers_in_half_ball(spatial_radius, D)
+         Ks = integers_in_half_ball(spectral_radius, D)
+         K_w, K_ = prepare_wave_vectors(Ks)
+         Clone_simplex = simplex(Dsimp)
+
+         if high_D:
+             a = 2.0 * jnp.pi
+             b = 2.0 / (sigma2 * a ** 2)
+             c = 1.0 / (2.0 * S * jnp.pi)
+             micro_kernel = lambda x_val, y_val: gauss_sin_kernel(x_val, y_val, a, b, c)
+         else:
+             micro_kernel = lambda x_val, y_val: gauss_kernel(x_val, y_val, sigma2)
+
+         def micro_grad(x_val):
+             # One kernel call per half-ball shift; the second roll adds the -shift term
+             def body(acc, shift):
+                 contrib = micro_kernel(x_val, jnp.roll(x_val, shift, axis=Axes))
+                 return acc + contrib - jnp.roll(contrib, -shift, axis=Axes), None
+             out, _ = jax.lax.scan(body, jnp.zeros_like(x_val), SHIFTS)
+             return out
+
+         def macro_grad(x_val):
+             x_flat = x_val.reshape(-1, D + 1)
+             def body(acc, args):
+                 k, k_ = args
+                 return acc + spectral_kernel(x_flat, k, k_), None
+             out, _ = jax.lax.scan(body, jnp.zeros_like(x_flat), (K_w[:, 0], K_[:, 0]))
+             return out.reshape(*IJK, D + 1)
+
+         sn = SquareNet(gridshape=IJK, max_iter=50, verbose=0)
+
+         def _gridify_numpy(x_val: np.ndarray) -> np.ndarray:
+             ctx.tick()
+             flat = np.random.permutation(np.asarray(x_val).reshape(-1, D + 1)) - 0.5
+             flat = flat - np.floor(flat)  # NumPy torus wrap (avoid JAX calls in a host callback)
+             sn.fit(flat[:, :D], method="ultimate")
+             return np.asarray(sn.map(flat), dtype=x_val.dtype).reshape(x_val.shape)
+
+         def gridify(x_val: jnp.ndarray) -> jnp.ndarray:
+             return jax.pure_callback(
+                 _gridify_numpy,
+                 jax.ShapeDtypeStruct(x_val.shape, x_val.dtype),
+                 x_val,
+             )
+
+         def clone(x_val) -> np.ndarray:
+             """Expand N//(Dsimp+1) parents into N children via simplex offsets."""
+             x_val = np.asarray(x_val).reshape(-1, D + 1)
+             x_val = x_val[np.isfinite(x_val[:, -1]), :D]
+             x_val = np.random.permutation(x_val)
+             N_parents = N // (Dsimp + 1)
+             N_keep = N - (Dsimp + 1) * N_parents
+             offsets = random_rotations(Clone_simplex, N_parents, D, Dsimp) * (expansion_factor / Ncbrt)
+             children = (x_val[:N_parents, None, :] + offsets).reshape(-1, D)
+             if N_keep > 0:
+                 # The N % (Dsimp+1) remaining parents are kept unchanged
+                 children = np.concatenate([x_val[N_parents:], children], axis=0)
+             return np.asarray(torus_wrap(jnp.array(children)))
+
+         @jax.jit
+         def run_iters(x_val: jnp.ndarray) -> jnp.ndarray:
+             def step(i, x_val):
+                 # Re-assign the points to the grid once every N_PER_STEP gradient steps
+                 x_val = jax.lax.cond(
+                     i % N_PER_STEP == 0,
+                     gridify,
+                     lambda val: val,
+                     x_val,
+                 )
+                 return clean_points(torus_wrap(
+                     x_val
+                     - (LR_spatial / S) * micro_grad(x_val)
+                     - (LR_spectral / (Nsqrt * Ncbrt)) * macro_grad(x_val)
+                 ))
+             return jax.lax.fori_loop(0, n_iter_level * N_PER_STEP, step, x_val)
+
+         if is_root:
+             xparent = _build_bruteforce(N, D, ctx)(x)
+             x_pts = prepare_points(np.asarray(xparent), N, IJK, D)
+         else:
+             N_child = N // (Dsimp + 1) + N % (Dsimp + 1)
+             xparent = clone(
+                 _run_pipeline(
+                     N=N_child,
+                     D=D,
+                     logger=logger,
+                     S=S,
+                     expansion_factor=expansion_factor,
+                     LR_spatial=LR_spatial,
+                     LR_spectral=LR_spectral,
+                     spatial_radius=spatial_radius,
+                     spectral_radius=spectral_radius,
+                     N_ITER=N_ITER,
+                     N_PER_STEP=N_PER_STEP,
+                     _is_root=False,
+                     _is_leaf=False,
+                 )
+             )
+             x_pts = prepare_points(xparent, N, IJK, D)
+         x_pts = run_iters(x_pts)
+         ctx.done()
+
+         if _is_leaf:
+             x_pts = np.array(x_pts.reshape(-1, D + 1))
+             x_pts = x_pts[np.isfinite(x_pts[:, -1]), :D]
+
+         return x_pts
+
+     finally:
+         logger.exit_level()
+
+
+ # ── Public entry point ────────────────────────────────────────────────────────
+
+ def sample(
+     N: int,
+     D: int = 2,
+     bruteforce: bool = False,
+     N_ITER: int = 6,
+     verbose: int = 1,
+ ) -> np.ndarray:
+     """
+     Generate N stealthy points in [0, 1)^D.
+
+     Parameters
+     ----------
+     N : number of output points.
+     D : spatial dimension (default 2). D ∈ {2, 3, 4} uses the hierarchical
+         pipeline when N > 3 000; other dimensions fall back to bruteforce.
+     bruteforce : force the direct O(N²) solver regardless of N and D.
+     N_ITER : pipeline iterations per level — more is better but slower.
+     verbose : 0 = silent, 1 = live progress.
+
+     Returns
+     -------
+     (N, D) array of points in [0, 1)^D.
+     """
+     logger = ProgressLogger(verbose)
+
+     if bruteforce or N <= 3_000 or D not in _PRESETS:
+         reason = (
+             f"no preset for D={D}" if D not in _PRESETS
+             else f"N={N} ≤ 3 000" if N <= 3_000
+             else "bruteforce flag"
+         )
+         ctx = logger.enter_level(N, D, 0)
+         try:
+             logger.write(f"direct solver ({reason})", newline=True)
+             blue = _build_bruteforce(N, D, ctx)
+             out = np.array(blue())
+         finally:
+             logger.exit_level()
+         return out
+
+     preset = _PRESETS[D]
+     return _run_pipeline(
+         N=N,
+         D=D,
+         N_ITER=N_ITER,
+         logger=logger,
+         x=None,
+         S=preset["S"],
+         expansion_factor=preset["expansion_factor"],
+         LR_spatial=preset["LR_spatial"],
+         LR_spectral=preset["LR_spectral"],
+         spatial_radius=preset["spatial_radius"],
+         spectral_radius=preset["spectral_radius"],
+         N_PER_STEP=10,
+         _is_root=False,
+         _is_leaf=True,
+     )
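The `clone` step turns a coarse solution into N points by replacing each parent with the d + 1 vertices of a randomly oriented regular simplex (d = min(D, 3)), scaled by the preset expansion factor divided by N^(1/D). A rotated regular simplex keeps all pairwise child distances equal, so each parent is split evenly. The standalone NumPy sketch below re-implements that construction outside JAX: it copies the recursive `simplex` from the math helpers, while `clone_points` and its `scale` argument are illustrative. As in the package, the QR step yields random orthonormal frames, which are not exactly uniformly (Haar) distributed rotations.

```python
import numpy as np


def simplex(D: int) -> np.ndarray:
    """Vertices of a regular simplex inscribed in the unit sphere of R^D, shape (D+1, D)."""
    if D == 1:
        return np.array([-1.0, 1.0])[:, None]
    base = np.hstack((simplex(D - 1), np.zeros((D, 1))))
    tip = np.zeros((1, D))
    tip[0, -1] = 1.0
    return np.vstack((np.sqrt(1.0 - (1.0 / D) ** 2) * base - tip / D, tip))


def clone_points(parents: np.ndarray, scale: float, rng: np.random.Generator) -> np.ndarray:
    """Replace each parent with d+1 children on a randomly oriented simplex, d = min(D, 3)."""
    n, D = parents.shape
    d = min(D, 3)
    # One random (D x d) matrix with orthonormal columns per parent
    Q, _ = np.linalg.qr(rng.standard_normal((n, D, d)))
    offsets = np.einsum("nij,kj->nki", Q, simplex(d)) * scale  # shape (n, d+1, D)
    children = (parents[:, None, :] + offsets).reshape(-1, D)
    return children - np.floor(children)  # wrap back onto the unit torus


rng = np.random.default_rng(0)
parents = rng.random((5, 2))
kids = clone_points(parents, scale=0.01, rng=rng)
print(kids.shape)  # (15, 2)
```

The package additionally carries the N % (d+1) leftover parents over unchanged so the total is exactly N, then refines the cloned pattern with the spatial and spectral gradient steps.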
@@ -0,0 +1,118 @@
+ """
+ Visualisation helpers
+ """
+
+ from __future__ import annotations
+
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ from .math_utils import structure_factor as _structure_factor
+
+
+ def plot(
+     points: np.ndarray,
+     max_scatter: int = 30_000,
+     ax: plt.Axes | None = None,
+     **scatter_kw,
+ ) -> plt.Figure:
+     """
+     Scatter plot of a 2-D or 3-D point set.
+
+     For large point sets the view is automatically zoomed so that roughly
+     *max_scatter* points are displayed.
+
+     Parameters
+     ----------
+     points : array-like
+         Point coordinates, shape (N, D) with D ∈ {2, 3}.
+         Higher-dimensional arrays are silently projected onto the first 3 axes.
+     max_scatter : int
+         Approximate maximum number of points to draw. Excess points are cropped by
+         zooming into the lower-left corner of the domain.
+     ax : matplotlib Axes | None
+         Existing axes to draw into. When *None* a new figure is created.
+     **scatter_kw
+         Extra keyword arguments forwarded to ``ax.scatter``.
+
+     Returns
+     -------
+     fig : matplotlib.figure.Figure
+     """
+     pts = np.asarray(points).reshape(-1, np.asarray(points).shape[-1])
+     D = min(pts.shape[-1], 3)
+     pts = pts[:, :D]
+
+     if len(pts) > max_scatter:
+         zoom = (max_scatter / len(pts)) ** (1.0 / D)
+         pts = pts[(pts <= zoom).all(axis=1)]
+
+     kw = dict(s=0.4, color="black")
+     kw.update(scatter_kw)
+
+     if ax is None:
+         fig = plt.figure(figsize=(8, 8))
+         if D == 2:
+             ax = fig.add_subplot(111)
+         else:
+             ax = fig.add_subplot(111, projection="3d")
+     else:
+         fig = ax.get_figure()
+
+     if D == 2:
+         ax.scatter(pts[:, 0], pts[:, 1], **kw)
+     else:
+         ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], **kw)
+
+     ax.set_axis_off()
+     fig.tight_layout()
+     plt.show()
+     return fig
+
+
+ def plot_structure_factor(
+     points: np.ndarray,
+     bins: int = 100,
+     resolution: float = 30.0,
+     ax: plt.Axes | None = None,
+     **plot_kw,
+ ) -> plt.Figure:
+     """
+     Log-log plot of the radial structure factor S(k).
+
+     Parameters
+     ----------
+     points : (N, D) array
+         Point coordinates in [0, 1)^D.
+     bins : int
+         Number of radial bins for the structure-factor estimate.
+     resolution : float
+         Average number of random wave vectors per bin for the estimate.
+     ax : matplotlib Axes | None
+         Existing axes to draw into. When *None* a new figure is created.
+     **plot_kw
+         Extra keyword arguments forwarded to ``ax.loglog``.
+
+     Returns
+     -------
+     fig : matplotlib.figure.Figure
+     """
+     pts = np.asarray(points).reshape(-1, np.asarray(points).shape[-1])
+     k, S = _structure_factor(pts, nbins=bins, resolution=resolution)
+
+     kw = dict(marker="o", markersize=2, linewidth=1)
+     kw.update(plot_kw)
+
+     if ax is None:
+         fig, ax = plt.subplots(figsize=(7, 5))
+     else:
+         fig = ax.get_figure()
+
+     ax.loglog(k, S, **kw)
+     ax.set_xlabel("k")
+     ax.set_ylabel("S(k)")
+     ax.set_title("Structure factor (log-log)")
+     ax.grid(True, which="both", alpha=0.4)
+     fig.tight_layout()
+     plt.show()
+     return fig
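`plot` keeps large scatter plots readable by cropping to the corner box [0, z]^D with z = (max_scatter / N)^(1/D). For evenly spread points the box holds a fraction z^D = max_scatter / N of them, i.e. about `max_scatter` points. The standalone check below uses a regular grid as an idealised, perfectly even point set; `crop_for_display` is an illustrative helper that mirrors the cropping lines in `plot`, not a package function.

```python
import numpy as np


def crop_for_display(pts: np.ndarray, max_scatter: int) -> np.ndarray:
    """Keep the points in [0, z]^D with z = (max_scatter / N)^(1/D), as `plot` does."""
    D = pts.shape[1]
    if len(pts) <= max_scatter:
        return pts
    zoom = (max_scatter / len(pts)) ** (1.0 / D)
    return pts[(pts <= zoom).all(axis=1)]


# 100 x 100 cell-centred grid: 10 000 evenly spread points in [0, 1)^2
g = (np.arange(100) + 0.5) / 100
grid = np.stack(np.meshgrid(g, g, indexing="ij"), axis=-1).reshape(-1, 2)

shown = crop_for_display(grid, max_scatter=2_500)
print(len(shown))  # 2500: z = 0.5 keeps 50 x 50 points
```

For a stealthy pattern the count is only approximately `max_scatter`, since the points are evenly spread but not on a grid.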