invrs-opt 0.4.0__py3-none-any.whl → 0.10.3__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,671 @@
1
+ """Defines Gaussian radial basis function levelset parameterization.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
+ import dataclasses
7
+ from typing import Any, Tuple
8
+
9
+ import jax
10
+ import jax.numpy as jnp
11
+ import optax # type: ignore[import-untyped]
12
+ from jax import tree_util
13
+ from totypes import json_utils, symmetry, types
14
+
15
+ from invrs_opt.parameterization import base, transforms
16
+
17
+ PyTree = Any
18
+
19
# Default number of levelset control points per unit of minimum length scale.
DEFAULT_LENGTH_SCALE_SPACING_FACTOR: float = 2.0
# Default ratio of the Gaussian full-width at half-maximum to the length scale.
DEFAULT_LENGTH_SCALE_FWHM_FACTOR: float = 1.0
# Default multiplier on the target length scale used by the levelset
# constraints; values above 1 push designs toward larger features.
DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR: float = 1.15
# Default supersampling factor; densities are computed on a finer grid and then
# downsampled.
DEFAULT_SMOOTHING_FACTOR: int = 2
# Default relaxation of the length scale constraint near the zero level.
DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA: float = 0.333
# Default weights of the individual fabrication constraint terms.
DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT: float = 1.0
DEFAULT_CURVATURE_CONSTRAINT_WEIGHT: float = 2.0
DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT: float = 10.0
# Default number of steps and optimizer used when fitting the latents to the
# binarized initial density in `from_density`.
DEFAULT_INIT_STEPS: int = 50
DEFAULT_INIT_OPTIMIZER: optax.GradientTransformation = optax.adam(learning_rate=1e-1)
29
+
30
+
31
@dataclasses.dataclass
class GaussianLevelsetParams(base.ParameterizedDensity2DArray):
    """Parameters of the Gaussian levelset parameterization.

    Attributes:
        latents: The optimizable latent parameters (basis function amplitudes).
        metadata: Settings and density metadata used, together with `latents`,
            to reconstruct the density.
    """

    latents: "GaussianLevelsetLatents"
    metadata: "GaussianLevelsetMetadata"
37
+
38
+
39
@dataclasses.dataclass
class GaussianLevelsetLatents(base.LatentsBase):
    """Latent (optimizable) parameters of the Gaussian levelset parameterization.

    Attributes:
        amplitude: Amplitudes of the Gaussian radial basis functions, one per
            levelset control point. The two trailing axes index the grid of
            control points; any leading axes are batch axes.
    """

    amplitude: jnp.ndarray
49
+
50
+
51
@dataclasses.dataclass
class GaussianLevelsetMetadata(base.MetadataBase):
    """Settings of the Gaussian levelset parameterization.

    Attributes:
        length_scale_spacing_factor: Number of levelset control points per unit of
            minimum length scale, where the minimum length scale is the mean of
            the density minimum width and minimum spacing.
        length_scale_fwhm_factor: Ratio of the Gaussian full-width at half-maximum
            to the minimum length scale.
        smoothing_factor: For values greater than 1, the density is first computed
            on a finer grid and then downsampled, which yields smoother geometries.
        density_shape: Shape of the density array produced from the parameters.
        density_metadata: Metadata of the density array produced from the
            parameters.
    """

    length_scale_spacing_factor: float
    length_scale_fwhm_factor: float
    smoothing_factor: int
    density_shape: Tuple[int, ...]
    density_metadata: base.Density2DMetadata

    def __post_init__(self) -> None:
        # `density_shape` is registered as a pytree meta field, which must be
        # hashable, so any sequence (e.g. a list) is coerced to a tuple.
        self.density_shape = tuple(self.density_shape)
74
+
75
+
76
# Register the dataclasses as pytrees. `density_shape` and `smoothing_factor`
# determine array shapes and are therefore meta (static) fields; the remaining
# fields are pytree children.
tree_util.register_dataclass(
    GaussianLevelsetParams, data_fields=["latents", "metadata"], meta_fields=[]
)
tree_util.register_dataclass(
    GaussianLevelsetLatents, data_fields=["amplitude"], meta_fields=[]
)
tree_util.register_dataclass(
    GaussianLevelsetMetadata,
    data_fields=[
        "length_scale_spacing_factor",
        "length_scale_fwhm_factor",
        "density_metadata",
    ],
    meta_fields=["density_shape", "smoothing_factor"],
)

# Register the same classes with `totypes.json_utils` (presumably so they can
# be serialized to / deserialized from JSON).
for _custom_type in (
    GaussianLevelsetParams,
    GaussianLevelsetLatents,
    GaussianLevelsetMetadata,
):
    json_utils.register_custom_type(_custom_type)
del _custom_type
98
+
99
+
100
def gaussian_levelset(
    *,
    length_scale_spacing_factor: float = DEFAULT_LENGTH_SCALE_SPACING_FACTOR,
    length_scale_fwhm_factor: float = DEFAULT_LENGTH_SCALE_FWHM_FACTOR,
    length_scale_constraint_factor: float = DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR,
    smoothing_factor: int = DEFAULT_SMOOTHING_FACTOR,
    length_scale_constraint_beta: float = DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA,
    length_scale_constraint_weight: float = DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT,
    curvature_constraint_weight: float = DEFAULT_CURVATURE_CONSTRAINT_WEIGHT,
    fixed_pixel_constraint_weight: float = DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT,
    init_optimizer: optax.GradientTransformation = DEFAULT_INIT_OPTIMIZER,
    init_steps: int = DEFAULT_INIT_STEPS,
) -> base.Density2DParameterization:
    """Defines a levelset parameterization with Gaussian radial basis functions.

    Args:
        length_scale_spacing_factor: The number of levelset control points per unit
            of minimum length scale (mean of density minimum width and minimum
            spacing).
        length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum
            to the minimum length scale.
        length_scale_constraint_factor: Multiplies the target length scale in the
            levelset constraints. A value greater than 1 is pessimistic and drives
            the solution to have a larger length scale (relative to smaller values).
        smoothing_factor: For values greater than 1, the density is initially
            computed at higher resolution and then downsampled, yielding smoother
            geometries.
        length_scale_constraint_beta: Controls relaxation of the length scale
            constraint near the zero level.
        length_scale_constraint_weight: The weight of the length scale constraint
            in the overall fabrication constraint penalty.
        curvature_constraint_weight: The weight of the curvature constraint.
        fixed_pixel_constraint_weight: The weight of the fixed pixel constraint.
        init_optimizer: The optimizer used in the initialization of the levelset
            parameterization. At initialization, the latent parameters are
            optimized so that the initial parameters match the binarized initial
            density.
        init_steps: The number of optimization steps used in the initialization.

    Returns:
        The `Density2DParameterization`.
    """

    def from_density_fn(density: types.Density2DArray) -> GaussianLevelsetParams:
        """Return level set parameters for the given `density`."""
        length_scale = (density.minimum_width + density.minimum_spacing) / 2
        spacing_factor = length_scale_spacing_factor / length_scale
        shape = density.shape[:-2] + (
            int(jnp.ceil(density.shape[-2] * spacing_factor)),
            int(jnp.ceil(density.shape[-1] * spacing_factor)),
        )

        # Target of the initialization fit: the density with fixed pixels applied,
        # binarized to its lower/upper bounds.
        mid_value = 0.5 * (density.lower_bound + density.upper_bound)
        value_range = density.upper_bound - density.lower_bound
        target_array = transforms.apply_fixed_pixels(density).array
        target_array = (
            jnp.sign(target_array - mid_value) * 0.5 * value_range + mid_value
        )

        # Generate the initial amplitude array: +1 above the midpoint, -1 below
        # (0 exactly at it), resampled to the control point grid shape.
        amplitude = jnp.sign(density.array - mid_value)
        amplitude = transforms.resample(amplitude, shape)

        # If the density is not periodic, ensure there are level set control points
        # beyond the edge of the density array.
        pad_width = ((0, 0),) * (amplitude.ndim - 2)
        pad_width += ((0, 0),) if density.periodic[0] else ((1, 1),)
        pad_width += ((0, 0),) if density.periodic[1] else ((1, 1),)
        amplitude = jnp.pad(amplitude, pad_width, mode="edge")

        latents = GaussianLevelsetLatents(amplitude=amplitude)
        metadata = GaussianLevelsetMetadata(
            length_scale_spacing_factor=length_scale_spacing_factor,
            length_scale_fwhm_factor=length_scale_fwhm_factor,
            smoothing_factor=smoothing_factor,
            density_shape=density.shape,
            density_metadata=base.Density2DMetadata.from_density(density),
        )

        def step_fn(
            _: int,
            params_and_state: Tuple[PyTree, PyTree],
        ) -> Tuple[PyTree, PyTree]:
            """One optimizer step fitting the latents to `target_array`."""

            def loss_fn(latents: GaussianLevelsetLatents) -> jnp.ndarray:
                params = GaussianLevelsetParams(latents, metadata=metadata)
                # Unmasked gradient so every pixel contributes to the fit.
                density_from_params = to_density_fn(params, mask_gradient=False)
                return jnp.mean((density_from_params.array - target_array) ** 2)

            params, state = params_and_state
            grad = jax.grad(loss_fn)(params)
            updates, state = init_optimizer.update(grad, params=params, state=state)
            params = optax.apply_updates(params, updates)
            return params, state

        state = init_optimizer.init(latents)
        latents, _ = jax.lax.fori_loop(
            0, init_steps, body_fun=step_fn, init_val=(latents, state)
        )

        # Normalize so the largest amplitude magnitude (per batch element) is 1.
        # An all-zero amplitude array (e.g. `init_steps=0` with a density exactly
        # at the midpoint) would otherwise produce `0 / 0 = nan` latents.
        maxval = jnp.amax(jnp.abs(latents.amplitude), axis=(-2, -1), keepdims=True)
        maxval = jnp.where(maxval == 0, 1.0, maxval)
        latents = dataclasses.replace(latents, amplitude=latents.amplitude / maxval)
        return GaussianLevelsetParams(latents=latents, metadata=metadata)

    def to_density_fn(
        params: GaussianLevelsetParams,
        mask_gradient: bool = True,
    ) -> types.Density2DArray:
        """Return a density from the latent parameters."""
        array = _to_array(params, mask_gradient=mask_gradient, pad_pixels=0)

        # Rescale the [0, 1] array to the density bounds.
        example_density = _example_density(params)
        lb = example_density.lower_bound
        ub = example_density.upper_bound
        array = lb + array * (ub - lb)
        assert array.shape == example_density.shape
        return dataclasses.replace(example_density, array=array)

    def constraints_fn(
        params: GaussianLevelsetParams,
        mask_gradient: bool = True,
        pad_pixels: int = 2,
    ) -> jnp.ndarray:
        """Computes constraints associated with the params."""
        return analytical_constraints(
            params=params,
            length_scale_constraint_factor=length_scale_constraint_factor,
            length_scale_constraint_beta=length_scale_constraint_beta,
            length_scale_constraint_weight=length_scale_constraint_weight,
            curvature_constraint_weight=curvature_constraint_weight,
            fixed_pixel_constraint_weight=fixed_pixel_constraint_weight,
            mask_gradient=mask_gradient,
            pad_pixels=pad_pixels,
        )

    def update_fn(
        params: GaussianLevelsetParams,
        updates: GaussianLevelsetParams,
        value: jnp.ndarray,
        step: int,
    ) -> GaussianLevelsetParams:
        """Perform updates to `params` required for the given `step`."""
        del step, value
        # Only the latents are updated; metadata is carried over unchanged.
        return GaussianLevelsetParams(
            latents=tree_util.tree_map(
                lambda a, b: a + b, params.latents, updates.latents
            ),
            metadata=params.metadata,
        )

    return base.Density2DParameterization(
        to_density=to_density_fn,
        from_density=from_density_fn,
        constraints=constraints_fn,
        update=update_fn,
    )
253
+
254
+
255
+ # -----------------------------------------------------------------------------
256
+ # Functions to obtain arrays from the levelset parameterization.
257
+ # -----------------------------------------------------------------------------
258
+
259
+
260
def _example_density(params: GaussianLevelsetParams) -> types.Density2DArray:
    """Build a placeholder density with the shape and metadata of `params`.

    The array is all zeros. Only its shape and the attached density metadata are
    meaningful; callers use it to read e.g. bounds, length scales and periodicity.
    """
    metadata = params.metadata
    with jax.ensure_compile_time_eval():
        zeros = jnp.zeros(metadata.density_shape)
        density_kwargs = dataclasses.asdict(metadata.density_metadata)
        return types.Density2DArray(array=zeros, **density_kwargs)
267
+
268
+
269
def _to_array(
    params: GaussianLevelsetParams,
    mask_gradient: bool,
    pad_pixels: int,
) -> jnp.ndarray:
    """Return an array from the parameters.

    The levelset function is thresholded on the high-resolution grid to `1` where
    it is positive and `0` elsewhere. It is then downsampled by `smoothing_factor`.
    For smoothing factors greater than 1, pixels along feature boundaries may
    therefore take intermediate values (assuming `transforms.box_downsample`
    averages). The final density array is obtained by rescaling this array to
    the density bounds, i.e. `lower + array * (upper - lower)`.

    Args:
        params: The parameters from which the density is obtained.
        mask_gradient: If `True`, the gradient is masked so that it is nonzero only
            at the borders of features.
        pad_pixels: A non-negative integer giving the additional pixels to be
            included beyond the boundaries of the parameterized density.

    Returns:
        The array, whose two trailing dimensions equal the density's trailing
        dimensions plus `2 * pad_pixels`.
    """
    example_density = _example_density(params)
    periodic: Tuple[bool, bool] = example_density.periodic
    phi = _phi_from_params(params=params, pad_pixels=pad_pixels)
    array = _levelset_threshold(phi=phi, periodic=periodic, mask_gradient=mask_gradient)
    return _downsample_spatial_dims(array, params.metadata.smoothing_factor)
295
+
296
+
297
def _phi_from_params(
    params: GaussianLevelsetParams,
    pad_pixels: int,
) -> jnp.ndarray:
    """Return the levelset function for the given `params`.

    The levelset function is a sum of Gaussian radial basis functions centered at
    the control points and weighted by `params.latents.amplitude`. It is evaluated
    at the pixel centers of a grid `smoothing_factor` times finer than the density
    grid, extended by `pad_pixels` density pixels beyond each boundary.

    Args:
        params: The parameters from which the density is obtained.
        pad_pixels: A non-negative integer giving the additional pixels to be
            included beyond the boundaries of the parameterized density.

    Returns:
        The levelset array `phi`, with trailing dimensions
        `smoothing_factor * (density_shape[-2:] + 2 * pad_pixels)`, symmetrized
        according to the density symmetries.
    """
    with jax.ensure_compile_time_eval():
        example_density = _example_density(params)
        length_scale = 0.5 * (
            example_density.minimum_width + example_density.minimum_spacing
        )
        fwhm = length_scale * params.metadata.length_scale_fwhm_factor
        # Uses FWHM = 2 * sqrt(2 * ln 2) * sigma, the relation for a Gaussian of the
        # form exp(-r**2 / (2 * sigma**2)).
        # NOTE(review): the basis in `scan_fn` is exp(-r**2 / sigma**2) (no factor
        # of 2), whose FWHM is `fwhm / sqrt(2)`. Confirm whether this is intended
        # before treating `length_scale_fwhm_factor` as a literal FWHM ratio.
        sigma = fwhm / (2 * jnp.sqrt(2 * jnp.log(2)))

        s_factor = params.metadata.smoothing_factor
        # Pixel-center coordinates of the high-resolution grid, in units of density
        # pixels, including `pad_pixels` density pixels beyond each edge.
        highres_i = (
            0.5
            + jnp.arange(
                s_factor * (-pad_pixels),
                s_factor * (pad_pixels + example_density.shape[-2]),
            )
        ) / s_factor
        highres_j = (
            0.5
            + jnp.arange(
                s_factor * (-pad_pixels),
                s_factor * (pad_pixels + example_density.shape[-1]),
            )
        ) / s_factor

        # Coordinates for the control points of the Gaussian radial basis functions.
        levelset_i, levelset_j = _control_point_coords(
            density_shape=params.metadata.density_shape[-2:],  # type: ignore[arg-type]
            levelset_shape=(
                params.latents.amplitude.shape[-2:]  # type: ignore[arg-type]
            ),
            periodic=example_density.periodic,
        )

        # Handle periodicity by replicating control points over a 3x3 supercell:
        # copies are shifted by one density period (the density size in pixels)
        # along each periodic axis.
        if example_density.periodic[0]:
            levelset_i = jnp.concatenate(
                [
                    levelset_i - example_density.shape[-2],
                    levelset_i,
                    levelset_i + example_density.shape[-2],
                ],
                axis=-2,
            )
            levelset_j = jnp.concatenate([levelset_j] * 3, axis=-2)
        if example_density.periodic[1]:
            levelset_i = jnp.concatenate([levelset_i] * 3, axis=-1)
            levelset_j = jnp.concatenate(
                [
                    levelset_j - example_density.shape[-1],
                    levelset_j,
                    levelset_j + example_density.shape[-1],
                ],
                axis=-1,
            )

        levelset_i = levelset_i.flatten()
        levelset_j = levelset_j.flatten()

    # Replicate the amplitudes in the same order as the control points above, then
    # flatten the control point axes. The inserted singleton axis broadcasts over
    # the `j` coordinates inside `scan_fn`.
    amplitude = params.latents.amplitude
    if example_density.periodic[0]:
        amplitude = jnp.concat([amplitude] * 3, axis=-2)
    if example_density.periodic[1]:
        amplitude = jnp.concat([amplitude] * 3, axis=-1)

    amplitude = amplitude.reshape(amplitude.shape[:-2] + (1, -1))

    # Use a scan operation to compute the array; this lowers memory consumption.
    # Each scan step evaluates one high-resolution row at coordinate `i`.
    def scan_fn(_: Tuple[()], i: jnp.ndarray) -> Tuple[Tuple[()], jnp.ndarray]:
        distance_sq = (i - levelset_i) ** 2 + (
            highres_j[:, jnp.newaxis] - levelset_j
        ) ** 2
        basis = jnp.exp(-distance_sq / sigma**2)
        return (), jnp.sum(basis * amplitude, axis=-1)

    _, array = jax.lax.scan(scan_fn, (), xs=highres_i)
    # `scan` stacks rows along a new leading axis; move it just before the `j` axis
    # so any batch axes stay in front.
    array = jnp.moveaxis(array, 0, -2)

    assert array.shape[-2] % s_factor == 0
    assert array.shape[-1] % s_factor == 0
    array = symmetry.symmetrize(array, tuple(example_density.symmetries))
    return array
392
+
393
+
394
+ # -----------------------------------------------------------------------------
395
+ # Functions to compute constraints.
396
+ # -----------------------------------------------------------------------------
397
+
398
+
399
def analytical_constraints(
    params: GaussianLevelsetParams,
    length_scale_constraint_factor: float,
    length_scale_constraint_beta: float,
    length_scale_constraint_weight: float,
    curvature_constraint_weight: float,
    fixed_pixel_constraint_weight: float,
    mask_gradient: bool,
    pad_pixels: int,
) -> jnp.ndarray:
    """Computes analytical levelset constraints associated with the params.

    Returns:
        The weighted length scale, curvature and fixed pixel constraints, stacked
        along a new trailing axis in that order and divided by the square of the
        minimum length scale (mean of minimum width and minimum spacing).
    """
    length_scale_term, curvature_term = _levelset_constraints(
        params,
        beta=length_scale_constraint_beta,
        length_scale_constraint_factor=length_scale_constraint_factor,
        pad_pixels=pad_pixels,
    )
    fixed_pixel_term = _fixed_pixel_constraint(
        params, mask_gradient=mask_gradient, pad_pixels=pad_pixels
    )
    weighted_terms = [
        length_scale_term * length_scale_constraint_weight,
        curvature_term * curvature_constraint_weight,
        fixed_pixel_term * fixed_pixel_constraint_weight,
    ]

    # Normalize constraints to make them (somewhat) resolution-independent.
    density = _example_density(params)
    min_length_scale = 0.5 * (density.minimum_spacing + density.minimum_width)
    return jnp.stack(weighted_terms, axis=-1) / min_length_scale**2
437
+
438
+
439
def _fixed_pixel_constraint(
    params: GaussianLevelsetParams,
    mask_gradient: bool,
    pad_pixels: int,
) -> jnp.ndarray:
    """Return the fixed pixel constraint array.

    The result is `|array - target|` at fixed pixels and zero elsewhere. Here
    `array` is the thresholded, downsampled levelset array, and `target` is `1` at
    fixed-solid pixels and `0` at fixed-void pixels. Fixed-pixel masks are
    edge-padded by `pad_pixels` to match the padded array.

    Args:
        params: The parameters from which the density is obtained.
        mask_gradient: If `True`, the gradient is masked so that it is nonzero only
            at the borders of features.
        pad_pixels: The number of pixels added at borders. Values greater than zero
            help to ensure that sharp features at the borders are avoided.

    Returns:
        The constraints array.
    """
    array = _to_array(params, mask_gradient=mask_gradient, pad_pixels=pad_pixels)
    density = _example_density(params)

    def _padded_mask(mask: Any) -> jnp.ndarray:
        # A missing mask means no pixels are fixed. Only the two trailing
        # (spatial) axes are padded, replicating edge values.
        if mask is None:
            mask = jnp.zeros(density.shape[-2:], dtype=bool)
        else:
            mask = jnp.asarray(mask)
        pad_width = ((0, 0),) * (mask.ndim - 2) + ((pad_pixels, pad_pixels),) * 2
        return jnp.pad(mask, pad_width, mode="edge")

    fixed_solid = _padded_mask(density.fixed_solid)
    fixed_void = _padded_mask(density.fixed_void)
    target = jnp.where(fixed_solid, 1, 0)
    return jnp.where(fixed_solid | fixed_void, jnp.abs(array - target), 0.0)
483
+
484
+
485
def _levelset_constraints(
    params: GaussianLevelsetParams,
    beta: float,
    length_scale_constraint_factor: float,
    pad_pixels: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute constraints for minimum width, spacing, and radius of curvature.

    The constraints are based on "Analytical level set fabrication constraints for
    inverse design," by D. Vercruysse et al. (2019). Constraints are satisfied when
    they are non-positive.

    https://www.nature.com/articles/s41598-019-45026-0

    With target length scale `d = minimum_length_scale * length_scale_constraint_factor`,
    the constraints evaluated on the high-resolution grid are

        length_scale = |phi_vv| / (pi / d * |phi| + beta * phi_v) - pi / d
        curvature = |inverse_radius * arctan(phi_v / phi)| - pi / d

    after which both are downsampled by `smoothing_factor`.

    Args:
        params: The parameters of the Gaussian levelset.
        beta: Parameter which relaxes the constraint near the zero-plane.
        length_scale_constraint_factor: Multiplies the target length scale in the
            levelset constraints. A value greater than 1 is pessimistic and drives the
            solution to have a larger length scale (relative to smaller values).
        pad_pixels: A non-negative integer giving the additional pixels to be included
            beyond the boundaries of the parameterized density.

    Returns:
        The minimum length scale and minimum curvature constraint arrays.
    """
    example_density = _example_density(params)
    minimum_length_scale = 0.5 * (
        example_density.minimum_width + example_density.minimum_spacing
    )

    phi, phi_v, phi_vv, inverse_radius = _phi_derivatives_and_inverse_radius(
        params,
        pad_pixels=pad_pixels,
    )

    # Target length scale shared by both constraints.
    d = minimum_length_scale * length_scale_constraint_factor
    denom = jnp.pi / d * jnp.abs(phi) + beta * phi_v
    # Where `phi_vv` is ~0 the numerator is ~0, so the denominator is replaced by 1.
    # NOTE(review): the guard tests `phi_vv`, not `denom`. `denom` vanishes only
    # where `phi` and `phi_v` are both zero, a case not protected here — confirm.
    denom_safe = jnp.where(jnp.isclose(phi_vv, 0.0), 1.0, denom)
    length_scale_constraint = jnp.abs(phi_vv) / denom_safe - jnp.pi / d

    # NOTE(review): the guard tests `phi_v`, but the value used as a denominator is
    # `phi`. Where `phi` is exactly zero and `phi_v` is not, `phi_v / phi` is
    # infinite. `arctan` saturates to pi/2 in the forward pass, but gradients
    # there may be non-finite — confirm this is acceptable.
    curvature_denom_safe = jnp.where(jnp.isclose(phi_v, 0.0), 1.0, phi)
    curvature_constraint = (
        jnp.abs(inverse_radius * jnp.arctan(phi_v / curvature_denom_safe)) - jnp.pi / d
    )

    # Downsample so that constraints shape matches the density shape (plus
    # `pad_pixels` on each side).
    factor = params.metadata.smoothing_factor
    return (
        _downsample_spatial_dims(length_scale_constraint, factor),
        _downsample_spatial_dims(curvature_constraint, factor),
    )
537
+
538
+
539
def _phi_derivatives_and_inverse_radius(
    params: GaussianLevelsetParams,
    pad_pixels: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Compute the levelset function and quantities derived from its derivatives.

    Derivatives are finite differences (`jnp.gradient`) on the high-resolution
    grid, whose spacing is `1 / smoothing_factor` in units of density pixels.

    Returns:
        A tuple `(phi, phi_v, phi_vv, inverse_radius)`:
        - `phi`: the levelset function;
        - `phi_v`: the magnitude of its gradient;
        - `phi_vv`: its second derivative along the gradient direction;
        - `inverse_radius`: the divergence of the normalized gradient.
    """
    phi = _phi_from_params(params=params, pad_pixels=pad_pixels)

    spacing = 1 / params.metadata.smoothing_factor
    phi_x, phi_y = jnp.gradient(phi, spacing, axis=(-2, -1))
    phi_xx, phi_yx = jnp.gradient(phi_x, spacing, axis=(-2, -1))
    phi_xy, phi_yy = jnp.gradient(phi_y, spacing, axis=(-2, -1))

    phi_v = _sqrt_safe(phi_x**2 + phi_y**2)

    # Normalizers equal to `phi_v` (and its square) except where `phi_v` is ~0,
    # where they are 1 to avoid division by zero. `phi_v` is non-negative, so the
    # normalizer is strictly positive.
    flat_gradient = jnp.isclose(phi_v, 0.0)
    norm = jnp.where(flat_gradient, 1.0, phi_v)
    norm_sq = jnp.where(flat_gradient, 1.0, phi_v**2)

    # Second directional derivative along the unit normal n: n^T H n.
    phi_vv = (
        phi_x**2 / norm_sq * phi_xx
        + (phi_x * phi_y) / norm_sq * (phi_xy + phi_yx)
        + phi_y**2 / norm_sq * phi_yy
    )

    # Divergence of the unit normal.
    normal_x = phi_x / norm
    normal_y = phi_y / norm
    dnormal_x: jnp.ndarray = jnp.gradient(  # type: ignore[assignment]
        normal_x, spacing, axis=-2
    )
    dnormal_y: jnp.ndarray = jnp.gradient(  # type: ignore[assignment]
        normal_y, spacing, axis=-1
    )
    inverse_radius = dnormal_x + dnormal_y

    return phi, phi_v, phi_vv, inverse_radius
578
+
579
+
580
+ # -----------------------------------------------------------------------------
581
+ # Helper functions.
582
+ # -----------------------------------------------------------------------------
583
+
584
+
585
def _control_point_coords(
    density_shape: Tuple[int, int],
    levelset_shape: Tuple[int, int],
    periodic: Tuple[bool, bool],
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Returns the control point coordinates.

    Coordinates are in units of density pixels; the density spans
    `[0, density_shape[-2]]` along the first axis and `[0, density_shape[-1]]`
    along the second.

    Along a periodic axis with `n` control points, the points are centered in `n`
    equal intervals spanning the density. Along a non-periodic axis, the first and
    last of the `n` control points lie outside the density, half an interval
    beyond each edge. The remaining `n - 2` points are centered in `n - 2` equal
    intervals spanning the density.

    Args:
        density_shape: Trailing (spatial) shape of the density.
        levelset_shape: Trailing (spatial) shape of the control point grid.
        periodic: Whether each of the two spatial axes is periodic.

    Returns:
        Arrays `(levelset_i, levelset_j)` with shape `levelset_shape[-2:]`, giving
        the coordinate of each control point along the two axes ("ij" indexing).
    """
    # If the levelset is *not* periodic along an axis, the first and last control
    # points along that axis lie outside the bounds of the density. This matches
    # the one-point edge padding applied to non-periodic axes at initialization.
    offset_i = 0.5 if periodic[0] else -0.5
    offset_j = 0.5 if periodic[1] else -0.5
    range_i = levelset_shape[-2] - (0 if periodic[0] else 2)
    range_j = levelset_shape[-1] - (0 if periodic[1] else 2)

    # Spacing between adjacent control points, in density pixels.
    factor_i = density_shape[-2] / range_i
    factor_j = density_shape[-1] / range_j
    levelset_i, levelset_j = jnp.meshgrid(
        (offset_i + jnp.arange(levelset_shape[-2])) * factor_i,
        (offset_j + jnp.arange(levelset_shape[-1])) * factor_j,
        indexing="ij",
    )
    return levelset_i, levelset_j
606
+
607
+
608
def _sqrt_safe(x: jnp.ndarray) -> jnp.ndarray:
    """Square root whose gradient stays finite for inputs at or near zero.

    Entries close to zero (per `jnp.isclose` defaults) yield exactly `0.0`. For
    those entries the square root is taken of a dummy value of `1`, so the unused
    branch of `jnp.where` cannot inject `nan` gradients.
    """
    is_tiny = jnp.isclose(x, 0.0)
    return jnp.where(is_tiny, 0.0, jnp.sqrt(jnp.where(is_tiny, 1, x)))
613
+
614
+
615
def _levelset_threshold(
    phi: jnp.ndarray,
    periodic: Tuple[bool, bool],
    mask_gradient: bool,
) -> jnp.ndarray:
    """Thresholds a level set function `phi` with a straight-through gradient.

    The forward value is `1.0` where `phi > 0` and `0.0` elsewhere. The gradient
    with respect to `phi` is passed straight through (identity). When
    `mask_gradient` is `True`, it is restricted to interface pixels.
    """
    if mask_gradient:
        # Block the gradient everywhere except on solid/void interface pixels.
        on_interface = _interface_pixels(phi, periodic)
        phi = jnp.where(on_interface, phi, jax.lax.stop_gradient(phi))
    binary = (phi > 0).astype(float)
    # `phi - stop_gradient(phi)` is zero-valued but carries the gradient of `phi`.
    return binary + (phi - jax.lax.stop_gradient(phi))
626
+
627
+
628
def _interface_pixels(phi: jnp.ndarray, periodic: Tuple[bool, bool]) -> jnp.ndarray:
    """Identifies interface pixels of a level set function `phi`.

    A pixel is on the interface if it is solid (`phi > 0`) and has at least one
    void 4-neighbor, or is void and has at least one solid 4-neighbor. Borders are
    padded by wrapping along periodic axes and by edge replication otherwise.
    """
    batch_shape = phi.shape[:-2]
    flat_phi = phi.reshape((-1,) + phi.shape[-2:])

    pad_mode = tuple("wrap" if periodic[axis] else "edge" for axis in (0, 1))

    # 4-connected neighborhood kernel (center pixel excluded), shaped for `conv`.
    kernel = jnp.asarray([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
    kernel = kernel[jnp.newaxis, jnp.newaxis, :, :]

    def count_neighbors(mask: jnp.ndarray) -> jnp.ndarray:
        """Counts, for each pixel, the 4-neighbors where `mask` is `True`."""
        padded = transforms.pad2d(mask, ((1, 1), (1, 1)), pad_mode)
        counts = transforms.conv(
            x=padded[:, jnp.newaxis, :, :].astype(float),
            kernel=kernel,
            padding="VALID",
        )
        return jnp.squeeze(counts, axis=1)

    solid = flat_phi > 0
    void = ~solid
    interface = (solid & (count_neighbors(void) > 0)) | (
        void & (count_neighbors(solid) > 0)
    )
    return interface.reshape(batch_shape + interface.shape[-2:])
663
+
664
+
665
def _downsample_spatial_dims(x: jnp.ndarray, downsample_factor: int) -> jnp.ndarray:
    """Box-downsamples the two trailing axes of `x` by `downsample_factor`.

    Leading (batch) axes are unchanged; the trailing sizes are floor-divided by
    the factor.
    """
    *batch, rows, cols = x.shape
    target_shape = (*batch, rows // downsample_factor, cols // downsample_factor)
    return transforms.box_downsample(x, target_shape)