invrs-opt 0.3.2__py3-none-any.whl → 0.10.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,671 @@
1
+ """Defines Gaussian radial basis function levelset parameterization.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
+ import dataclasses
7
+ from typing import Any, Tuple
8
+
9
+ import jax
10
+ import jax.numpy as jnp
11
+ import optax # type: ignore[import-untyped]
12
+ from jax import tree_util
13
+ from totypes import json_utils, symmetry, types
14
+
15
+ from invrs_opt.parameterization import base, transforms
16
+
17
+ PyTree = Any
18
+
19
+ DEFAULT_LENGTH_SCALE_SPACING_FACTOR: float = 2.0
20
+ DEFAULT_LENGTH_SCALE_FWHM_FACTOR: float = 1.0
21
+ DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR: float = 1.15
22
+ DEFAULT_SMOOTHING_FACTOR: int = 2
23
+ DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA: float = 0.333
24
+ DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT: float = 1.0
25
+ DEFAULT_CURVATURE_CONSTRAINT_WEIGHT: float = 2.0
26
+ DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT: float = 10.0
27
+ DEFAULT_INIT_STEPS: int = 50
28
+ DEFAULT_INIT_OPTIMIZER: optax.GradientTransformation = optax.adam(1e-1)
29
+
30
+
31
@dataclasses.dataclass
class GaussianLevelsetParams(base.ParameterizedDensity2DArray):
    """Parameters of the Gaussian levelset parameterization.

    Attributes:
        latents: The latent parameters, i.e. the amplitudes of the Gaussian
            basis functions at the levelset control points.
        metadata: Metadata describing the parameterization and the density
            array that the parameters produce.
    """

    latents: "GaussianLevelsetLatents"
    metadata: "GaussianLevelsetMetadata"
37
+
38
+
39
@dataclasses.dataclass
class GaussianLevelsetLatents(base.LatentsBase):
    """Latent parameters of the Gaussian levelset parameterization.

    Attributes:
        amplitude: Amplitudes of the Gaussian radial basis functions, one per
            levelset control point; the two trailing axes index control points.
    """

    amplitude: jnp.ndarray
49
+
50
+
51
@dataclasses.dataclass
class GaussianLevelsetMetadata(base.MetadataBase):
    """Metadata for the Gaussian levelset parameterization.

    Attributes:
        length_scale_spacing_factor: Number of levelset control points per
            minimum length scale, where the minimum length scale is the mean of
            the density minimum width and minimum spacing.
        length_scale_fwhm_factor: Ratio of the Gaussian full-width at
            half-maximum to the minimum length scale.
        smoothing_factor: When greater than 1, the density is first computed at
            this multiple of the density resolution and then downsampled, which
            yields smoother geometries.
        density_shape: Shape of the density array produced from the parameters.
            Always stored as a tuple.
        density_metadata: Metadata of the density array produced from the
            parameters.
    """

    length_scale_spacing_factor: float
    length_scale_fwhm_factor: float
    smoothing_factor: int
    density_shape: Tuple[int, ...]
    density_metadata: base.Density2DMetadata

    def __post_init__(self) -> None:
        # Normalize any sequence to a tuple; `density_shape` is registered as a
        # meta (static) pytree field, and those should be hashable.
        self.density_shape = tuple(self.density_shape)
74
+
75
+
76
# Register the params, latents, and metadata dataclasses as JAX pytrees.
# For the metadata, `density_shape` and `smoothing_factor` are meta (static)
# fields; every other field of every class is a pytree child.
tree_util.register_dataclass(
    GaussianLevelsetParams, data_fields=["latents", "metadata"], meta_fields=[]
)
tree_util.register_dataclass(
    GaussianLevelsetLatents, data_fields=["amplitude"], meta_fields=[]
)
tree_util.register_dataclass(
    GaussianLevelsetMetadata,
    data_fields=[
        "length_scale_spacing_factor",
        "length_scale_fwhm_factor",
        "density_metadata",
    ],
    meta_fields=["density_shape", "smoothing_factor"],
)

# Make the same three types known to `totypes.json_utils` (presumably so they
# can be round-tripped through its JSON serialization).
json_utils.register_custom_type(GaussianLevelsetParams)
json_utils.register_custom_type(GaussianLevelsetLatents)
json_utils.register_custom_type(GaussianLevelsetMetadata)
98
+
99
+
100
def gaussian_levelset(
    *,
    length_scale_spacing_factor: float = DEFAULT_LENGTH_SCALE_SPACING_FACTOR,
    length_scale_fwhm_factor: float = DEFAULT_LENGTH_SCALE_FWHM_FACTOR,
    length_scale_constraint_factor: float = DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR,
    smoothing_factor: int = DEFAULT_SMOOTHING_FACTOR,
    length_scale_constraint_beta: float = DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA,
    length_scale_constraint_weight: float = DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT,
    curvature_constraint_weight: float = DEFAULT_CURVATURE_CONSTRAINT_WEIGHT,
    fixed_pixel_constraint_weight: float = DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT,
    init_optimizer: optax.GradientTransformation = DEFAULT_INIT_OPTIMIZER,
    init_steps: int = DEFAULT_INIT_STEPS,
) -> base.Density2DParameterization:
    """Defines a levelset parameterization with Gaussian radial basis functions.

    Args:
        length_scale_spacing_factor: The number of levelset control points per unit
            of minimum length scale (mean of density minimum width and minimum
            spacing).
        length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum
            to the minimum length scale.
        length_scale_constraint_factor: Multiplies the target length scale in the
            levelset constraints. A value greater than 1 is pessimistic and drives
            the solution to have a larger length scale (relative to smaller values).
        smoothing_factor: For values greater than 1, the density is initially
            computed at higher resolution and then downsampled, yielding smoother
            geometries.
        length_scale_constraint_beta: Controls relaxation of the length scale
            constraint near the zero level.
        length_scale_constraint_weight: The weight of the length scale constraint in
            the overall fabrication constraint penalty.
        curvature_constraint_weight: The weight of the curvature constraint.
        fixed_pixel_constraint_weight: The weight of the fixed pixel constraint.
        init_optimizer: The optimizer used in the initialization of the levelset
            parameterization. At initialization, the latent parameters are optimized
            so that the initial parameters match the binarized initial density.
        init_steps: The number of optimization steps used in the initialization.

    Returns:
        The `Density2DParameterization`.
    """

    def from_density_fn(density: types.Density2DArray) -> GaussianLevelsetParams:
        """Return level set parameters for the given `density`.

        Amplitudes are initialized from the sign of `density` relative to the
        midpoint of its bounds, optimized for `init_steps` steps with
        `init_optimizer` so that the resulting density matches the binarized
        density (with fixed pixels applied), and finally rescaled so that the
        largest-magnitude amplitude is 1.
        """
        length_scale = (density.minimum_width + density.minimum_spacing) / 2
        spacing_factor = length_scale_spacing_factor / length_scale
        shape = density.shape[:-2] + (
            int(jnp.ceil(density.shape[-2] * spacing_factor)),
            int(jnp.ceil(density.shape[-1] * spacing_factor)),
        )

        # Binarized target: each pixel is pushed to the nearer of the bounds.
        mid_value = 0.5 * (density.lower_bound + density.upper_bound)
        value_range = density.upper_bound - density.lower_bound
        target_array = transforms.apply_fixed_pixels(density).array
        target_array = (
            jnp.sign(target_array - mid_value) * 0.5 * value_range + mid_value
        )

        # Generate the initial amplitude array.
        amplitude = jnp.sign(density.array - mid_value)
        amplitude = transforms.resample(amplitude, shape)

        # If the density is not periodic, ensure there are level set control points
        # beyond the edge of the density array.
        pad_width = ((0, 0),) * (amplitude.ndim - 2)
        pad_width += ((0, 0),) if density.periodic[0] else ((1, 1),)
        pad_width += ((0, 0),) if density.periodic[1] else ((1, 1),)
        amplitude = jnp.pad(amplitude, pad_width, mode="edge")

        latents = GaussianLevelsetLatents(amplitude=amplitude)
        metadata = GaussianLevelsetMetadata(
            length_scale_spacing_factor=length_scale_spacing_factor,
            length_scale_fwhm_factor=length_scale_fwhm_factor,
            smoothing_factor=smoothing_factor,
            density_shape=density.shape,
            density_metadata=base.Density2DMetadata.from_density(density),
        )

        def step_fn(
            _: int,
            latents_and_state: Tuple[PyTree, PyTree],
        ) -> Tuple[PyTree, PyTree]:
            """One optimizer step fitting the latents to `target_array`."""

            def loss_fn(latents: GaussianLevelsetLatents) -> jnp.ndarray:
                params = GaussianLevelsetParams(latents, metadata=metadata)
                density_from_params = to_density_fn(params, mask_gradient=False)
                return jnp.mean((density_from_params.array - target_array) ** 2)

            step_latents, opt_state = latents_and_state
            grad = jax.grad(loss_fn)(step_latents)
            updates, opt_state = init_optimizer.update(
                grad, params=step_latents, state=opt_state
            )
            step_latents = optax.apply_updates(step_latents, updates)
            return step_latents, opt_state

        state = init_optimizer.init(latents)
        latents, _ = jax.lax.fori_loop(
            0, init_steps, body_fun=step_fn, init_val=(latents, state)
        )

        # Rescale so the largest-magnitude amplitude is 1 for each batch element.
        # An all-zero amplitude (e.g. a density exactly at its mid value with
        # `init_steps=0`) would otherwise be divided by zero and become NaN.
        maxval = jnp.amax(jnp.abs(latents.amplitude), axis=(-2, -1), keepdims=True)
        maxval = jnp.where(maxval > 0, maxval, 1.0)
        latents = dataclasses.replace(latents, amplitude=latents.amplitude / maxval)
        return GaussianLevelsetParams(latents=latents, metadata=metadata)

    def to_density_fn(
        params: GaussianLevelsetParams,
        mask_gradient: bool = True,
    ) -> types.Density2DArray:
        """Return a density from the latent parameters.

        The thresholded array from `_to_array` is rescaled to the density's
        lower/upper bounds.
        """
        array = _to_array(params, mask_gradient=mask_gradient, pad_pixels=0)

        example_density = _example_density(params)
        lb = example_density.lower_bound
        ub = example_density.upper_bound
        array = lb + array * (ub - lb)
        assert array.shape == example_density.shape
        return dataclasses.replace(example_density, array=array)

    def constraints_fn(
        params: GaussianLevelsetParams,
        mask_gradient: bool = True,
        pad_pixels: int = 2,
    ) -> jnp.ndarray:
        """Computes constraints associated with the params."""
        return analytical_constraints(
            params=params,
            length_scale_constraint_factor=length_scale_constraint_factor,
            length_scale_constraint_beta=length_scale_constraint_beta,
            length_scale_constraint_weight=length_scale_constraint_weight,
            curvature_constraint_weight=curvature_constraint_weight,
            fixed_pixel_constraint_weight=fixed_pixel_constraint_weight,
            mask_gradient=mask_gradient,
            pad_pixels=pad_pixels,
        )

    def update_fn(
        params: GaussianLevelsetParams,
        updates: GaussianLevelsetParams,
        value: jnp.ndarray,
        step: int,
    ) -> GaussianLevelsetParams:
        """Perform updates to `params` required for the given `step`.

        Adds `updates.latents` to `params.latents` leaf-wise and keeps the
        metadata of `params`; `value` and `step` are unused.
        """
        del step, value
        return GaussianLevelsetParams(
            latents=tree_util.tree_map(
                lambda a, b: a + b, params.latents, updates.latents
            ),
            metadata=params.metadata,
        )

    return base.Density2DParameterization(
        to_density=to_density_fn,
        from_density=from_density_fn,
        constraints=constraints_fn,
        update=update_fn,
    )
253
+
254
+
255
+ # -----------------------------------------------------------------------------
256
+ # Functions to obtain arrays from the levelset parameterization.
257
+ # -----------------------------------------------------------------------------
258
+
259
+
260
def _example_density(params: GaussianLevelsetParams) -> types.Density2DArray:
    """Returns a placeholder density with the shape and metadata of `params`.

    The array of the returned density is all zeros; only its shape and metadata
    (taken from `params.metadata`) are meaningful. Construction happens under
    `jax.ensure_compile_time_eval`.
    """
    with jax.ensure_compile_time_eval():
        zeros = jnp.zeros(params.metadata.density_shape)
        metadata_fields = dataclasses.asdict(params.metadata.density_metadata)
        return types.Density2DArray(array=zeros, **metadata_fields)
267
+
268
+
269
def _to_array(
    params: GaussianLevelsetParams,
    mask_gradient: bool,
    pad_pixels: int,
) -> jnp.ndarray:
    """Return an array from the parameters.

    Before downsampling, the array has a value of `1` where the levelset function
    is positive and a value of `0` elsewhere. It is then downsampled by
    `smoothing_factor` with `transforms.box_downsample`, so intermediate values
    can appear when `smoothing_factor` exceeds 1. The final density array can be
    obtained by rescaling this array to have the appropriate upper and lower
    bounds (`lower + array * (upper - lower)`).

    Args:
        params: The parameters from which the density is obtained.
        mask_gradient: If `True`, the gradient is masked so that it is nonzero only
            at the borders of features.
        pad_pixels: A non-negative integer giving the additional pixels to be
            included beyond the boundaries of the parameterized density.

    Returns:
        The array.
    """
    example_density = _example_density(params)
    periodic: Tuple[bool, bool] = example_density.periodic
    phi = _phi_from_params(params=params, pad_pixels=pad_pixels)
    array = _levelset_threshold(phi=phi, periodic=periodic, mask_gradient=mask_gradient)
    return _downsample_spatial_dims(array, params.metadata.smoothing_factor)
295
+
296
+
297
def _phi_from_params(
    params: GaussianLevelsetParams,
    pad_pixels: int,
) -> jnp.ndarray:
    """Evaluate the levelset function `phi` described by `params`.

    `phi` is a sum of Gaussian radial basis functions centered at the levelset
    control points and weighted by `params.latents.amplitude`. It is sampled on a
    grid whose resolution is `smoothing_factor` times that of the density, and
    which extends `pad_pixels` density pixels beyond every edge of the density.
    Along periodic axes the control points are replicated over a 3x3 supercell
    so that the basis functions wrap around the domain.

    Args:
        params: The parameters from which the density is obtained.
        pad_pixels: A non-negative integer giving the additional pixels to be
            included beyond the boundaries of the parameterized density.

    Returns:
        The levelset array `phi`, symmetrized with the density's symmetries.
    """
    with jax.ensure_compile_time_eval():
        example_density = _example_density(params)
        height, width = example_density.shape[-2:]
        periodic_i, periodic_j = example_density.periodic
        s_factor = params.metadata.smoothing_factor

        # Convert the FWHM (a multiple of the minimum length scale) to sigma.
        length_scale = 0.5 * (
            example_density.minimum_width + example_density.minimum_spacing
        )
        fwhm = length_scale * params.metadata.length_scale_fwhm_factor
        sigma = fwhm / (2 * jnp.sqrt(2 * jnp.log(2)))

        def sample_coords(size: int) -> jnp.ndarray:
            # Centers of the high-resolution samples along one axis, in units of
            # density pixels, including `pad_pixels` on either side.
            steps = jnp.arange(
                s_factor * (-pad_pixels),
                s_factor * (pad_pixels + size),
            )
            return (0.5 + steps) / s_factor

        highres_i = sample_coords(height)
        highres_j = sample_coords(width)

        # Coordinates for the control points of the Gaussian radial basis functions.
        levelset_i, levelset_j = _control_point_coords(
            density_shape=params.metadata.density_shape[-2:],  # type: ignore[arg-type]
            levelset_shape=(
                params.latents.amplitude.shape[-2:]  # type: ignore[arg-type]
            ),
            periodic=example_density.periodic,
        )

        # Handle periodicity by replicating control points over a 3x3 supercell.
        if periodic_i:
            levelset_i = jnp.concatenate(
                [levelset_i - height, levelset_i, levelset_i + height], axis=-2
            )
            levelset_j = jnp.concatenate([levelset_j] * 3, axis=-2)
        if periodic_j:
            levelset_i = jnp.concatenate([levelset_i] * 3, axis=-1)
            levelset_j = jnp.concatenate(
                [levelset_j - width, levelset_j, levelset_j + width], axis=-1
            )
        levelset_i = levelset_i.flatten()
        levelset_j = levelset_j.flatten()

    # Tile the amplitudes exactly like the control points, then flatten the two
    # control-point axes into one trailing axis, preceded by a singleton axis
    # that broadcasts against the `j` sample coordinates.
    amplitude = params.latents.amplitude
    if periodic_i:
        amplitude = jnp.concatenate([amplitude] * 3, axis=-2)
    if periodic_j:
        amplitude = jnp.concatenate([amplitude] * 3, axis=-1)
    amplitude = amplitude.reshape(amplitude.shape[:-2] + (1, -1))

    def row_fn(carry: Tuple[()], i: jnp.ndarray) -> Tuple[Tuple[()], jnp.ndarray]:
        # All basis functions evaluated along the row at coordinate `i`.
        distance_sq = (i - levelset_i) ** 2 + (
            highres_j[:, jnp.newaxis] - levelset_j
        ) ** 2
        # NOTE(review): the exponent is `-d**2 / sigma**2` (not `2 * sigma**2`),
        # so the realized FWHM is `fwhm / sqrt(2)`; confirm this is intended.
        basis = jnp.exp(-distance_sq / sigma**2)
        return carry, jnp.sum(basis * amplitude, axis=-1)

    # Scan over rows rather than evaluating every row at once to limit memory.
    _, rows = jax.lax.scan(row_fn, (), xs=highres_i)
    phi = jnp.moveaxis(rows, 0, -2)

    assert phi.shape[-2] % s_factor == 0
    assert phi.shape[-1] % s_factor == 0
    return symmetry.symmetrize(phi, tuple(example_density.symmetries))
392
+
393
+
394
+ # -----------------------------------------------------------------------------
395
+ # Functions to compute constraints.
396
+ # -----------------------------------------------------------------------------
397
+
398
+
399
def analytical_constraints(
    params: GaussianLevelsetParams,
    length_scale_constraint_factor: float,
    length_scale_constraint_beta: float,
    length_scale_constraint_weight: float,
    curvature_constraint_weight: float,
    fixed_pixel_constraint_weight: float,
    mask_gradient: bool,
    pad_pixels: int,
) -> jnp.ndarray:
    """Computes analytical levelset constraints associated with the params.

    Returns:
        An array whose trailing axis has size 3. It holds the weighted length
        scale, curvature, and fixed pixel constraints, in that order, each divided
        by the squared minimum length scale.
    """
    length_scale_c, curvature_c = _levelset_constraints(
        params,
        beta=length_scale_constraint_beta,
        length_scale_constraint_factor=length_scale_constraint_factor,
        pad_pixels=pad_pixels,
    )
    fixed_pixel_c = _fixed_pixel_constraint(
        params,
        mask_gradient=mask_gradient,
        pad_pixels=pad_pixels,
    )
    weighted = [
        length_scale_c * length_scale_constraint_weight,
        curvature_c * curvature_constraint_weight,
        fixed_pixel_c * fixed_pixel_constraint_weight,
    ]
    stacked = jnp.stack(weighted, axis=-1)

    # Normalize constraints to make them (somewhat) resolution-independent.
    density = _example_density(params)
    min_length_scale = 0.5 * (density.minimum_spacing + density.minimum_width)
    return stacked / min_length_scale**2
437
+
438
+
439
def _fixed_pixel_constraint(
    params: GaussianLevelsetParams,
    mask_gradient: bool,
    pad_pixels: int,
) -> jnp.ndarray:
    """Return the fixed pixel constraint array.

    The fixed pixel constraint array is nonzero at locations where the density
    obtained from `params` differs from fixed pixels. The target is `1` at fixed
    solid pixels and `0` at fixed void pixels, matching the 1/0 thresholded array
    from `_to_array`.

    Args:
        params: The parameters from which the density is obtained.
        mask_gradient: If `True`, the gradient is masked so that it is nonzero only
            at the borders of features.
        pad_pixels: The number of pixels added at borders. Values greater than zero
            help to ensure that sharp features at the borders are avoided.

    Returns:
        The constraints array.
    """
    array = _to_array(params, mask_gradient=mask_gradient, pad_pixels=pad_pixels)

    example_density = _example_density(params)
    periodic = example_density.periodic
    fixed_solid = jnp.zeros(example_density.shape[-2:], dtype=bool)
    fixed_void = jnp.zeros(example_density.shape[-2:], dtype=bool)
    if example_density.fixed_solid is not None:
        fixed_solid = jnp.asarray(example_density.fixed_solid)
    if example_density.fixed_void is not None:
        fixed_void = jnp.asarray(example_density.fixed_void)

    fixed_solid = _pad_fixed_pixel_mask(fixed_solid, pad_pixels, periodic)
    fixed_void = _pad_fixed_pixel_mask(fixed_void, pad_pixels, periodic)
    fixed = fixed_solid | fixed_void
    target = jnp.where(fixed_solid, 1, 0)

    return jnp.where(fixed, jnp.abs(array - target), 0.0)


def _pad_fixed_pixel_mask(
    mask: jnp.ndarray,
    pad_pixels: int,
    periodic: Tuple[bool, bool],
) -> jnp.ndarray:
    """Pads the two trailing axes of a fixed-pixel mask by `pad_pixels` per side.

    Along periodic axes the levelset function in the padded region is the
    periodic continuation of the density (control points are replicated over a
    3x3 supercell), so the mask is wrapped to match. Along other axes the edge
    values are replicated.
    """
    batch_pad = ((0, 0),) * (mask.ndim - 2)
    both_sides = (pad_pixels, pad_pixels)
    no_pad = (0, 0)
    mask = jnp.pad(
        mask,
        batch_pad + (both_sides, no_pad),
        mode="wrap" if periodic[0] else "edge",
    )
    mask = jnp.pad(
        mask,
        batch_pad + (no_pad, both_sides),
        mode="wrap" if periodic[1] else "edge",
    )
    return mask
483
+
484
+
485
def _levelset_constraints(
    params: GaussianLevelsetParams,
    beta: float,
    length_scale_constraint_factor: float,
    pad_pixels: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute constraints for minimum width, spacing, and radius of curvature.

    The constraints are based on "Analytical level set fabrication constraints for
    inverse design," by D. Vercruysse et al. (2019). Constraints are satisfied when
    they are non-positive.

    https://www.nature.com/articles/s41598-019-45026-0

    Args:
        params: The parameters of the Gaussian levelset.
        beta: Parameter which relaxes the constraint near the zero-plane.
        length_scale_constraint_factor: Multiplies the target length scale in the
            levelset constraints. A value greater than 1 is pessimistic and drives
            the solution to have a larger length scale (relative to smaller values).
        pad_pixels: A non-negative integer giving the additional pixels to be
            included beyond the boundaries of the parameterized density.

    Returns:
        The minimum length scale and minimum curvature constraint arrays, each
        downsampled by `smoothing_factor`.
    """
    density = _example_density(params)
    target_length = (
        0.5 * (density.minimum_width + density.minimum_spacing)
    ) * length_scale_constraint_factor
    pi_over_d = jnp.pi / target_length

    phi, phi_v, phi_vv, inverse_radius = _phi_derivatives_and_inverse_radius(
        params,
        pad_pixels=pad_pixels,
    )

    # Length scale: |phi_vv| / (pi/d * |phi| + beta * phi_v) - pi/d. Where phi_vv
    # is ~0 the denominator is replaced by 1, avoiding 0/0.
    denom = pi_over_d * jnp.abs(phi) + beta * phi_v
    denom = jnp.where(jnp.isclose(phi_vv, 0.0), 1.0, denom)
    length_scale_constraint = jnp.abs(phi_vv) / denom - pi_over_d

    # Curvature: |inverse_radius * arctan(phi_v / phi)| - pi/d.
    # NOTE(review): the guard tests `phi_v`, but the divisor is `phi`; where phi
    # is exactly 0 and phi_v is not, the ratio is infinite. Confirm gradients stay
    # finite there.
    arctan_denom = jnp.where(jnp.isclose(phi_v, 0.0), 1.0, phi)
    curvature_constraint = (
        jnp.abs(inverse_radius * jnp.arctan(phi_v / arctan_denom)) - pi_over_d
    )

    # Downsample so that the constraint shapes match the density shape.
    factor = params.metadata.smoothing_factor
    return (
        _downsample_spatial_dims(length_scale_constraint, factor),
        _downsample_spatial_dims(curvature_constraint, factor),
    )
537
+
538
+
539
def _phi_derivatives_and_inverse_radius(
    params: GaussianLevelsetParams,
    pad_pixels: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Compute the levelset function together with derived quantities.

    Derivatives are finite differences from `jnp.gradient`, using a sample
    spacing of `1 / smoothing_factor` density pixels.

    Returns:
        `(phi, phi_v, phi_vv, inverse_radius)`. These are the levelset function,
        the magnitude of its gradient, its second derivative along the gradient
        direction, and the divergence of the normalized gradient.
    """
    phi = _phi_from_params(params=params, pad_pixels=pad_pixels)

    spacing = 1 / params.metadata.smoothing_factor
    spatial_axes = (-2, -1)
    phi_x, phi_y = jnp.gradient(phi, spacing, axis=spatial_axes)
    phi_xx, phi_yx = jnp.gradient(phi_x, spacing, axis=spatial_axes)
    phi_xy, phi_yy = jnp.gradient(phi_y, spacing, axis=spatial_axes)

    phi_v = _sqrt_safe(phi_x**2 + phi_y**2)

    # Where the gradient magnitude is ~0, divide by 1 instead of by phi_v (or its
    # square), so that the normalizations below stay finite.
    flat = jnp.isclose(phi_v, 0.0)
    phi_v_sq_safe = jnp.where(flat, 1.0, phi_v**2)
    phi_v_safe = jnp.where(flat, 1.0, phi_v)

    # Second derivative along the gradient direction, n^T H n with n = grad/|grad|.
    phi_vv = (
        (phi_x**2 / phi_v_sq_safe) * phi_xx
        + ((phi_x * phi_y) / phi_v_sq_safe) * (phi_xy + phi_yx)
        + (phi_y**2 / phi_v_sq_safe) * phi_yy
    )

    # Divergence of the (safely) normalized gradient.
    normal_x = phi_x / jnp.abs(phi_v_safe)
    normal_y = phi_y / jnp.abs(phi_v_safe)
    div_x: jnp.ndarray = jnp.gradient(normal_x, spacing, axis=-2)  # type: ignore[assignment]
    div_y: jnp.ndarray = jnp.gradient(normal_y, spacing, axis=-1)  # type: ignore[assignment]
    inverse_radius = div_x + div_y

    return phi, phi_v, phi_vv, inverse_radius
578
+
579
+
580
+ # -----------------------------------------------------------------------------
581
+ # Helper functions.
582
+ # -----------------------------------------------------------------------------
583
+
584
+
585
def _control_point_coords(
    density_shape: Tuple[int, int],
    levelset_shape: Tuple[int, int],
    periodic: Tuple[bool, bool],
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Returns the control point coordinates, in units of density pixels.

    Args:
        density_shape: Shape of the density; the two trailing entries are used.
        levelset_shape: Shape of the amplitude array; the two trailing entries
            are used.
        periodic: Whether the density is periodic along each of its two axes.

    Returns:
        `(levelset_i, levelset_j)` coordinate arrays with shape
        `levelset_shape[-2:]`, as produced by `jnp.meshgrid(..., indexing="ij")`.
    """
    # Along periodic axes, control points are evenly spaced and offset from the
    # origin by half a spacing, so all of them lie within the density. Along
    # non-periodic axes, the first and last control points lie outside the bounds
    # of the density, matching the extra ring of control points added by padding
    # the amplitude array along those axes.
    offset_i = 0.5 if periodic[0] else -0.5
    offset_j = 0.5 if periodic[1] else -0.5
    range_i = levelset_shape[-2] - (0 if periodic[0] else 2)
    range_j = levelset_shape[-1] - (0 if periodic[1] else 2)

    factor_i = density_shape[-2] / range_i
    factor_j = density_shape[-1] / range_j
    levelset_i, levelset_j = jnp.meshgrid(
        (offset_i + jnp.arange(levelset_shape[-2])) * factor_i,
        (offset_j + jnp.arange(levelset_shape[-1])) * factor_j,
        indexing="ij",
    )
    return levelset_i, levelset_j
606
+
607
+
608
def _sqrt_safe(x: jnp.ndarray) -> jnp.ndarray:
    """Square root whose gradient stays finite where `x` is close to zero.

    Entries of `x` that are close to zero map to exactly zero. At those entries
    the square root is evaluated on a substitute value of 1, so no infinite
    derivative enters the backward pass.
    """
    near_zero = jnp.isclose(x, 0.0)
    sqrt_of_safe = jnp.sqrt(jnp.where(near_zero, 1, x))
    return jnp.where(near_zero, 0.0, sqrt_of_safe)
613
+
614
+
615
def _levelset_threshold(
    phi: jnp.ndarray,
    periodic: Tuple[bool, bool],
    mask_gradient: bool,
) -> jnp.ndarray:
    """Thresholds a level set function `phi` to 1 where positive, else 0.

    A straight-through construction is used. The forward value is the hard step
    `phi > 0`, while the gradient with respect to `phi` is the identity. When
    `mask_gradient` is set, the gradient is stopped everywhere except at
    interface pixels.
    """
    if mask_gradient:
        on_interface = _interface_pixels(phi, periodic)
        phi = jnp.where(on_interface, phi, jax.lax.stop_gradient(phi))
    frozen_phi = jax.lax.stop_gradient(phi)
    step = (phi > 0).astype(float)
    return step + (phi - frozen_phi)
626
+
627
+
628
def _interface_pixels(phi: jnp.ndarray, periodic: Tuple[bool, bool]) -> jnp.ndarray:
    """Identifies interface pixels of a level set function `phi`.

    A pixel is on the interface when it is solid (`phi > 0`) with at least one
    void 4-neighbor, or void with at least one solid 4-neighbor. Borders are
    padded by wrapping along periodic axes and by edge replication otherwise.
    """
    leading_shape = phi.shape[:-2]
    flat_phi = phi.reshape((-1,) + phi.shape[-2:])

    modes = (
        "wrap" if periodic[0] else "edge",
        "wrap" if periodic[1] else "edge",
    )
    # 4-connected neighborhood: the pixel itself and diagonals are excluded.
    neighbor_kernel = jnp.asarray([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)

    def count_neighbors(mask: jnp.ndarray) -> jnp.ndarray:
        # Number of 4-neighbors of each pixel for which `mask` is true.
        padded = transforms.pad2d(mask, ((1, 1), (1, 1)), modes)
        counts = transforms.conv(
            x=padded[:, jnp.newaxis, :, :].astype(float),
            kernel=neighbor_kernel[jnp.newaxis, jnp.newaxis, :, :],
            padding="VALID",
        )
        return jnp.squeeze(counts, axis=1)

    is_solid = flat_phi > 0
    is_void = ~is_solid
    touches_solid = count_neighbors(is_solid) > 0
    touches_void = count_neighbors(is_void) > 0
    interface = (is_solid & touches_void) | (is_void & touches_solid)

    return interface.reshape(leading_shape + interface.shape[-2:])
663
+
664
+
665
def _downsample_spatial_dims(x: jnp.ndarray, downsample_factor: int) -> jnp.ndarray:
    """Shrinks the two trailing axes of `x` by `downsample_factor`.

    Leading (batch) axes are left unchanged. The trailing sizes are floor-divided
    by the factor and the resampling is delegated to `transforms.box_downsample`.
    """
    *batch_dims, rows, cols = x.shape
    target_shape = (*batch_dims, rows // downsample_factor, cols // downsample_factor)
    return transforms.box_downsample(x, target_shape)