invrs-opt 0.5.2__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,643 @@
1
+ """Defines Gaussian radial basis function levelset parameterization.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
+ import dataclasses
7
+ from typing import Any, Tuple
8
+
9
+ import jax
10
+ import jax.numpy as jnp
11
+ import optax # type: ignore[import-untyped]
12
+ from jax import tree_util
13
+ from totypes import json_utils, symmetry, types
14
+
15
+ from invrs_opt.parameterization import base, transforms
16
+
17
# Alias used in annotations for arbitrary JAX pytrees.
PyTree = Any

# Defaults controlling the Gaussian basis and the evaluation grid.
DEFAULT_LENGTH_SCALE_SPACING_FACTOR: float = 2.0
DEFAULT_LENGTH_SCALE_FWHM_FACTOR: float = 1.0
DEFAULT_SMOOTHING_FACTOR: int = 2

# Defaults controlling the fabrication constraints and their relative weights.
DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR: float = 1.15
DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA: float = 0.333
DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT: float = 1.0
DEFAULT_CURVATURE_CONSTRAINT_WEIGHT: float = 2.0
DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT: float = 10.0

# Defaults for fitting the levelset amplitudes to an initial density.
DEFAULT_INIT_STEPS: int = 50
DEFAULT_INIT_OPTIMIZER: optax.GradientTransformation = optax.adam(0.1)
29
+
30
+
31
@dataclasses.dataclass
class GaussianLevelsetParams(base.ParameterizedDensity2DArrayBase):
    """Latent parameters for a density described by a Gaussian levelset.

    Attributes:
        amplitude: Amplitudes of the Gaussian radial basis functions placed at the
            levelset control points.
        length_scale_spacing_factor: Number of levelset control points per unit of
            the minimum length scale, where the minimum length scale is the mean of
            the density's minimum width and minimum spacing.
        length_scale_fwhm_factor: Ratio of the Gaussian full-width at half-maximum
            to the minimum length scale.
        smoothing_factor: When greater than 1, the density is first evaluated on a
            finer grid and then downsampled, which smooths the resulting geometry.
        density_shape: Shape of the density array produced from these parameters.
        density_metadata: Metadata attached to the density array produced from
            these parameters.
    """

    amplitude: jnp.ndarray
    length_scale_spacing_factor: float
    length_scale_fwhm_factor: float
    smoothing_factor: int
    density_shape: Tuple[int, ...]
    density_metadata: base.Density2DMetadata

    def __post_init__(self) -> None:
        # Store the shape as a tuple; it is carried in the static pytree
        # auxiliary data when these params are flattened.
        self.density_shape = tuple(self.density_shape)

    def example_density(self) -> types.Density2DArray:
        """Builds an all-zero density carrying this object's shape and metadata."""
        with jax.ensure_compile_time_eval():
            metadata_fields = dataclasses.asdict(self.density_metadata)
            example = types.Density2DArray(
                array=jnp.zeros(self.density_shape),
                **metadata_fields,
            )
        return example
65
+
66
+
67
# Static (non-array) part of `GaussianLevelsetParams` used as pytree aux data.
_GaussianLevelsetParamsAux = Tuple[
    float,  # length_scale_spacing_factor
    float,  # length_scale_fwhm_factor
    int,  # smoothing_factor
    Tuple[int, ...],  # density_shape
    tree_util.PyTreeDef,  # treedef of density_metadata
]
70
+
71
+
72
def _flatten_gaussian_levelset_params(
    params: GaussianLevelsetParams,
) -> Tuple[Tuple[jnp.ndarray], _GaussianLevelsetParamsAux]:
    """Splits `params` into pytree children and static auxiliary data.

    Only `amplitude` becomes a child. The scalar settings, the density shape, and
    the treedef of the density metadata are returned as auxiliary data; the
    metadata's own leaves are not kept.
    """
    metadata_treedef = tree_util.tree_structure(params.density_metadata)
    children = (params.amplitude,)
    aux = (
        params.length_scale_spacing_factor,
        params.length_scale_fwhm_factor,
        params.smoothing_factor,
        params.density_shape,
        metadata_treedef,
    )
    return children, aux
86
+
87
+
88
def _unflatten_gaussian_levelset_params(
    aux: _GaussianLevelsetParamsAux,
    children: Tuple[jnp.ndarray],
) -> GaussianLevelsetParams:
    """Reassembles `GaussianLevelsetParams` from its children and aux data."""
    (amplitude,) = children
    spacing_factor, fwhm_factor, smoothing, shape, metadata_treedef = aux
    return GaussianLevelsetParams(
        amplitude=amplitude,
        length_scale_spacing_factor=spacing_factor,
        length_scale_fwhm_factor=fwhm_factor,
        smoothing_factor=smoothing,
        density_shape=tuple(shape),
        # The metadata is rebuilt from its treedef with an empty leaf sequence.
        density_metadata=tree_util.tree_unflatten(metadata_treedef, ()),
    )
110
+
111
+
112
# Register the params as a pytree whose only leaf is `amplitude`, and with
# `json_utils` (presumably to enable JSON (de)serialization of the params).
tree_util.register_pytree_node(
    GaussianLevelsetParams,
    _flatten_gaussian_levelset_params,
    _unflatten_gaussian_levelset_params,
)
json_utils.register_custom_type(GaussianLevelsetParams)
120
+
121
+
122
def gaussian_levelset(
    *,
    length_scale_spacing_factor: float = DEFAULT_LENGTH_SCALE_SPACING_FACTOR,
    length_scale_fwhm_factor: float = DEFAULT_LENGTH_SCALE_FWHM_FACTOR,
    length_scale_constraint_factor: float = DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR,
    smoothing_factor: int = DEFAULT_SMOOTHING_FACTOR,
    length_scale_constraint_beta: float = DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA,
    length_scale_constraint_weight: float = DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT,
    curvature_constraint_weight: float = DEFAULT_CURVATURE_CONSTRAINT_WEIGHT,
    fixed_pixel_constraint_weight: float = DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT,
    init_optimizer: optax.GradientTransformation = DEFAULT_INIT_OPTIMIZER,
    init_steps: int = DEFAULT_INIT_STEPS,
) -> base.Density2DParameterization:
    """Defines a levelset parameterization with Gaussian radial basis functions.

    Args:
        length_scale_spacing_factor: The number of levelset control points per unit of
            minimum length scale (mean of density minimum width and minimum spacing).
        length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum to
            the minimum length scale.
        length_scale_constraint_factor: Multiplies the target length scale in the
            levelset constraints. A value greater than 1 is pessimistic and drives the
            solution to have a larger length scale (relative to smaller values).
        smoothing_factor: For values greater than 1, the density is initially computed
            at higher resolution and then downsampled, yielding smoother geometries.
        length_scale_constraint_beta: Controls relaxation of the length scale
            constraint near the zero level.
        length_scale_constraint_weight: The weight of the length scale constraint in
            the overall fabrication constraint penalty.
        curvature_constraint_weight: The weight of the curvature constraint.
        fixed_pixel_constraint_weight: The weight of the fixed pixel constraint.
        init_optimizer: The optimizer used in the initialization of the levelset
            parameterization. At initialization, the latent parameters are optimized so
            that the initial parameters match the binarized initial density.
        init_steps: The number of optimization steps used in the initialization.

    Returns:
        The `Density2DParameterization`.
    """

    def from_density_fn(density: types.Density2DArray) -> GaussianLevelsetParams:
        """Return level set parameters for the given `density`."""
        length_scale = (density.minimum_width + density.minimum_spacing) / 2
        spacing_factor = length_scale_spacing_factor / length_scale
        shape = density.shape[:-2] + (
            int(jnp.ceil(density.shape[-2] * spacing_factor)),
            int(jnp.ceil(density.shape[-1] * spacing_factor)),
        )

        # Fitting target: the density with fixed pixels applied, binarized to its
        # lower/upper bounds (pixels exactly at the midpoint stay at the midpoint).
        mid_value = 0.5 * (density.lower_bound + density.upper_bound)
        value_range = density.upper_bound - density.lower_bound
        target_array = transforms.apply_fixed_pixels(density).array
        target_array = (
            jnp.sign(target_array - mid_value) * 0.5 * value_range + mid_value
        )

        # Generate the initial amplitude array: the sign of the density relative to
        # its midpoint, resampled onto the control-point grid.
        amplitude = jnp.sign(density.array - mid_value)
        amplitude = transforms.resample(amplitude, shape)

        # If the density is not periodic, ensure there are level set control points
        # beyond the edge of the density array.
        pad_width = ((0, 0),) * (amplitude.ndim - 2)
        pad_width += ((0, 0),) if density.periodic[0] else ((1, 1),)
        pad_width += ((0, 0),) if density.periodic[1] else ((1, 1),)
        amplitude = jnp.pad(amplitude, pad_width, mode="edge")

        density_metadata_dict = dataclasses.asdict(density)
        del density_metadata_dict["array"]
        density_metadata = base.Density2DMetadata(**density_metadata_dict)
        params = GaussianLevelsetParams(
            amplitude=amplitude,
            length_scale_spacing_factor=length_scale_spacing_factor,
            length_scale_fwhm_factor=length_scale_fwhm_factor,
            smoothing_factor=smoothing_factor,
            density_shape=density.shape,
            density_metadata=density_metadata,
        )

        def step_fn(
            _: int,
            params_and_state: Tuple[PyTree, PyTree],
        ) -> Tuple[PyTree, PyTree]:
            def loss_fn(params: GaussianLevelsetParams) -> jnp.ndarray:
                # Gradients are not masked here so every amplitude receives signal.
                density_from_params = to_density_fn(params, mask_gradient=False)
                return jnp.mean((density_from_params.array - target_array) ** 2)

            params, state = params_and_state
            grad = jax.grad(loss_fn)(params)
            updates, state = init_optimizer.update(grad, params=params, state=state)
            params = optax.apply_updates(params, updates)
            return params, state

        # Fit the amplitudes so the thresholded levelset matches the target.
        state = init_optimizer.init(params)
        params, state = jax.lax.fori_loop(
            0, init_steps, body_fun=step_fn, init_val=(params, state)
        )

        # Normalize so the largest amplitude magnitude is 1. A zero maximum (e.g. a
        # density exactly at its midpoint with `init_steps=0`) would otherwise
        # produce `nan` amplitudes, so leave such amplitudes unscaled.
        maxval = jnp.amax(jnp.abs(params.amplitude), axis=(-2, -1), keepdims=True)
        maxval = jnp.where(maxval == 0, 1.0, maxval)
        return dataclasses.replace(params, amplitude=params.amplitude / maxval)

    def to_density_fn(
        params: GaussianLevelsetParams,
        mask_gradient: bool = True,
    ) -> types.Density2DArray:
        """Return a density from the latent parameters."""
        array = _to_array(params, mask_gradient=mask_gradient, pad_pixels=0)

        # `array` lies in [0, 1]; rescale it to the density bounds.
        example_density = params.example_density()
        lb = example_density.lower_bound
        ub = example_density.upper_bound
        array = lb + array * (ub - lb)
        assert array.shape == example_density.shape
        return dataclasses.replace(example_density, array=array)

    def constraints_fn(
        params: GaussianLevelsetParams,
        mask_gradient: bool = True,
        pad_pixels: int = 2,
    ) -> jnp.ndarray:
        """Computes constraints associated with the params."""
        length_scale_constraint, curvature_constraint = _levelset_constraints(
            params,
            beta=length_scale_constraint_beta,
            length_scale_constraint_factor=length_scale_constraint_factor,
            pad_pixels=pad_pixels,
        )
        fixed_pixel_constraint = _fixed_pixel_constraint(
            params,
            mask_gradient=mask_gradient,
            pad_pixels=pad_pixels,
        )

        # Stack the weighted constraints along a new trailing axis.
        constraints = jnp.stack(
            [
                length_scale_constraint * length_scale_constraint_weight,
                curvature_constraint * curvature_constraint_weight,
                fixed_pixel_constraint * fixed_pixel_constraint_weight,
            ],
            axis=-1,
        )

        # Normalize constraints to make them (somewhat) resolution-independent.
        example_density = params.example_density()
        length_scale = 0.5 * (
            example_density.minimum_spacing + example_density.minimum_width
        )
        return constraints / length_scale**2

    return base.Density2DParameterization(
        to_density=to_density_fn,
        from_density=from_density_fn,
        constraints=constraints_fn,
    )
277
+
278
+
279
+ # -----------------------------------------------------------------------------
280
+ # Functions to obtain arrays from the levelset parameterization.
281
+ # -----------------------------------------------------------------------------
282
+
283
+
284
def _to_array(
    params: GaussianLevelsetParams,
    mask_gradient: bool,
    pad_pixels: int,
) -> jnp.ndarray:
    """Return an array from the parameters.

    The levelset is thresholded on a grid refined by `params.smoothing_factor`,
    giving `1` where the levelset is positive and `0` elsewhere. The result is then
    box-downsampled back to the density resolution. Values therefore lie in
    `[0, 1]`, assuming `transforms.box_downsample` averages over each box.
    Intermediate values can appear near feature boundaries when
    `smoothing_factor > 1`. The final density array is obtained by rescaling this
    array to the density bounds, i.e.
    `lower_bound + array * (upper_bound - lower_bound)`.

    Args:
        params: The parameters from which the density is obtained.
        mask_gradient: If `True`, the gradient is masked so that it is nonzero only at
            the borders of features.
        pad_pixels: A non-negative integer giving the additional pixels to be included
            beyond the boundaries of the parameterized density.

    Returns:
        The array, whose trailing spatial axes each exceed the density shape by
        `2 * pad_pixels`.
    """
    example_density = params.example_density()
    periodic: Tuple[bool, bool] = example_density.periodic
    phi = _phi_from_params(
        params=params,
        pad_pixels=pad_pixels,
    )
    array = _levelset_threshold(
        phi=phi,
        periodic=periodic,
        mask_gradient=mask_gradient,
    )
    return _downsample_spatial_dims(array, params.smoothing_factor)
317
+
318
+
319
def _phi_from_params(
    params: GaussianLevelsetParams,
    pad_pixels: int,
) -> jnp.ndarray:
    """Return the levelset function for the given `params`.

    The levelset `phi` is a sum of Gaussian radial basis functions centered at the
    levelset control points and weighted by `params.amplitude`. It is evaluated at
    the pixel centers of a grid refined by `params.smoothing_factor` that extends
    `pad_pixels` density pixels beyond each boundary. Along periodic axes, the
    control points and amplitudes are replicated over the neighboring periods so
    that basis functions wrap across the boundary. The result is symmetrized
    according to the density's `symmetries`.

    Args:
        params: The parameters from which the density is obtained.
        pad_pixels: A non-negative integer giving the additional pixels to be included
            beyond the boundaries of the parameterized density.

    Returns:
        The levelset array `phi`. Each trailing spatial axis has size
        `smoothing_factor * (density_shape[k] + 2 * pad_pixels)`.
    """
    with jax.ensure_compile_time_eval():
        example_density = params.example_density()
        length_scale = 0.5 * (
            example_density.minimum_width + example_density.minimum_spacing
        )
        fwhm = length_scale * params.length_scale_fwhm_factor
        # Standard FWHM-to-sigma conversion for `exp(-x**2 / (2 * sigma**2))`.
        # NOTE(review): the basis in `scan_fn` is `exp(-d**2 / sigma**2)` (no
        # factor of 2), so its actual FWHM is `fwhm / sqrt(2)`. Confirm whether
        # this is intended before changing it; defaults may be tuned to it.
        sigma = fwhm / (2 * jnp.sqrt(2 * jnp.log(2)))

        # Pixel-center coordinates of the refined grid, in units of density
        # pixels, spanning `pad_pixels` beyond each boundary.
        highres_i = (
            0.5
            + jnp.arange(
                params.smoothing_factor * (-pad_pixels),
                params.smoothing_factor * (pad_pixels + example_density.shape[-2]),
            )
        ) / params.smoothing_factor
        highres_j = (
            0.5
            + jnp.arange(
                params.smoothing_factor * (-pad_pixels),
                params.smoothing_factor * (pad_pixels + example_density.shape[-1]),
            )
        ) / params.smoothing_factor

        # Coordinates for the control points of the Gaussian radial basis functions.
        levelset_i, levelset_j = _control_point_coords(
            density_shape=params.density_shape[-2:],  # type: ignore[arg-type]
            levelset_shape=params.amplitude.shape[-2:],  # type: ignore[arg-type]
            periodic=example_density.periodic,
        )

        # Handle periodicity by replicating control points over a 3x3 supercell.
        # The ordering (previous, current, next period) must match the
        # replication of `amplitude` below.
        if example_density.periodic[0]:
            levelset_i = jnp.concatenate(
                [
                    levelset_i - example_density.shape[-2],
                    levelset_i,
                    levelset_i + example_density.shape[-2],
                ],
                axis=-2,
            )
            levelset_j = jnp.concatenate([levelset_j] * 3, axis=-2)
        if example_density.periodic[1]:
            levelset_i = jnp.concatenate([levelset_i] * 3, axis=-1)
            levelset_j = jnp.concatenate(
                [
                    levelset_j - example_density.shape[-1],
                    levelset_j,
                    levelset_j + example_density.shape[-1],
                ],
                axis=-1,
            )

        levelset_i = levelset_i.flatten()
        levelset_j = levelset_j.flatten()

    amplitude = params.amplitude
    if example_density.periodic[0]:
        amplitude = jnp.concat([amplitude] * 3, axis=-2)
    if example_density.periodic[1]:
        amplitude = jnp.concat([amplitude] * 3, axis=-1)

    # Flatten the control-point axes into one trailing axis, matching the flattened
    # `levelset_i`/`levelset_j`. The inserted singleton axis broadcasts against
    # the `highres_j` axis inside `scan_fn`.
    amplitude = amplitude.reshape(amplitude.shape[:-2] + (1, -1))

    # Use a scan operation to compute the array; this lowers memory consumption.
    # Each step computes one row (a fixed refined-grid `i`) for all `highres_j`.
    def scan_fn(_: Tuple[()], i: jnp.ndarray) -> Tuple[Tuple[()], jnp.ndarray]:
        distance_sq = (i - levelset_i) ** 2 + (
            highres_j[:, jnp.newaxis] - levelset_j
        ) ** 2
        basis = jnp.exp(-distance_sq / sigma**2)
        return (), jnp.sum(basis * amplitude, axis=-1)

    _, array = jax.lax.scan(scan_fn, (), xs=highres_i)
    # `scan` stacks rows along axis 0; move it next to the last axis so any batch
    # axes from `amplitude` lead.
    array = jnp.moveaxis(array, 0, -2)

    assert array.shape[-2] % params.smoothing_factor == 0
    assert array.shape[-1] % params.smoothing_factor == 0
    array = symmetry.symmetrize(array, tuple(example_density.symmetries))
    return array
411
+
412
+
413
+ # -----------------------------------------------------------------------------
414
+ # Functions to compute constraints.
415
+ # -----------------------------------------------------------------------------
416
+
417
+
418
def _fixed_pixel_constraint(
    params: GaussianLevelsetParams,
    mask_gradient: bool,
    pad_pixels: int,
) -> jnp.ndarray:
    """Return the fixed pixel constraint array.

    The fixed pixel constraint array is nonzero at locations where the density obtained
    from `params` differs from fixed pixels.

    The fixed-pixel masks are padded by `pad_pixels` so they align with the padded
    array. Along periodic axes the levelset beyond the boundary is (approximately)
    the periodic continuation of the interior, so the masks are wrapped there.
    Along non-periodic axes the edge values are extended outward.

    Args:
        params: The parameters from which the density is obtained.
        mask_gradient: If `True`, the gradient is masked so that it is nonzero only at
            the borders of features.
        pad_pixels: The number of pixels added at borders. Values greater than zero
            help to ensure that sharp features at the borders are avoided.

    Returns:
        The constraints array.
    """
    array = _to_array(params, mask_gradient=mask_gradient, pad_pixels=pad_pixels)

    example_density = params.example_density()
    fixed_solid = jnp.zeros(example_density.shape, dtype=bool)
    fixed_void = jnp.zeros(example_density.shape, dtype=bool)
    if example_density.fixed_solid is not None:
        fixed_solid = jnp.asarray(example_density.fixed_solid)
    if example_density.fixed_void is not None:
        fixed_void = jnp.asarray(example_density.fixed_void)

    periodic = example_density.periodic
    fixed_solid = _pad_spatial_dims(fixed_solid, pad_pixels, periodic)
    fixed_void = _pad_spatial_dims(fixed_void, pad_pixels, periodic)
    fixed = fixed_solid | fixed_void
    # `array` lies in [0, 1], so solid pixels target 1 and void pixels target 0.
    target = jnp.where(fixed_solid, 1, 0)

    return jnp.where(fixed, jnp.abs(array - target), 0.0)


def _pad_spatial_dims(
    x: jnp.ndarray,
    pad_pixels: int,
    periodic: Tuple[bool, bool],
) -> jnp.ndarray:
    """Pads the two trailing axes of `x`, wrapping periodic axes, else edge-padding."""
    for axis, is_periodic in zip((x.ndim - 2, x.ndim - 1), periodic):
        pad_width = [(0, 0)] * x.ndim
        pad_width[axis] = (pad_pixels, pad_pixels)
        x = jnp.pad(x, pad_width, mode="wrap" if is_periodic else "edge")
    return x
458
+
459
+
460
def _levelset_constraints(
    params: GaussianLevelsetParams,
    beta: float,
    length_scale_constraint_factor: float,
    pad_pixels: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute constraints for minimum width, spacing, and radius of curvature.

    The constraints are based on "Analytical level set fabrication constraints for
    inverse design," by D. Vercruysse et al. (2019). Constraints are satisfied when
    they are non-positive.

    https://www.nature.com/articles/s41598-019-45026-0

    Args:
        params: The parameters of the Gaussian levelset.
        beta: Parameter which relaxes the constraint near the zero-plane.
        length_scale_constraint_factor: Multiplies the target length scale in the
            levelset constraints. A value greater than 1 is pessimistic and drives the
            solution to have a larger length scale (relative to smaller values).
        pad_pixels: A non-negative integer giving the additional pixels to be included
            beyond the boundaries of the parameterized density.

    Returns:
        The minimum length scale and minimum curvature constraint arrays. Each is
        downsampled by `params.smoothing_factor` along its two trailing axes.
    """
    example_density = params.example_density()
    minimum_length_scale = 0.5 * (
        example_density.minimum_width + example_density.minimum_spacing
    )

    phi, phi_v, phi_vv, inverse_radius = _phi_derivatives_and_inverse_radius(
        params,
        pad_pixels=pad_pixels,
    )

    d = minimum_length_scale * length_scale_constraint_factor
    # NOTE(review): this denominator vanishes where `phi` and `phi_v` are both
    # zero, giving `inf`/`nan` at such points.
    length_scale_constraint = (
        jnp.abs(phi_vv) / (jnp.pi / d * jnp.abs(phi) + beta * phi_v) - jnp.pi / d
    )
    # `phi_v` is non-negative, so `|r * arctan(phi_v / phi)|` equals
    # `|r| * arctan2(phi_v, |phi|)`. This form avoids dividing by `phi`, which
    # produced `nan` values where `phi` and `phi_v` both vanish and `nan`
    # gradients wherever `phi` is exactly zero.
    curvature_constraint = (
        jnp.abs(inverse_radius) * jnp.arctan2(phi_v, jnp.abs(phi)) - jnp.pi / d
    )

    # Downsample so that constraints shape matches the density shape.
    return (
        _downsample_spatial_dims(length_scale_constraint, params.smoothing_factor),
        _downsample_spatial_dims(curvature_constraint, params.smoothing_factor),
    )
509
+
510
+
511
def _phi_derivatives_and_inverse_radius(
    params: GaussianLevelsetParams,
    pad_pixels: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Compute the levelset function and its first and second derivatives.

    Derivatives are finite differences (`jnp.gradient`) on the refined grid that
    `_phi_from_params` returns. That grid has a spacing of `1 / smoothing_factor`
    density pixels.

    Args:
        params: The parameters of the Gaussian levelset.
        pad_pixels: Additional pixels beyond the density boundaries, forwarded to
            `_phi_from_params`.

    Returns:
        A tuple `(phi, phi_v, phi_vv, inverse_radius)`:
        - `phi`: the levelset function.
        - `phi_v`: the gradient magnitude `|grad phi|` (non-negative).
        - `phi_vv`: the second derivative of `phi` along the unit gradient
          direction.
        - `inverse_radius`: the divergence of `grad phi / |grad phi|`, i.e. the
          curvature of the level curves.
    """

    phi = _phi_from_params(
        params=params,
        pad_pixels=pad_pixels,
    )

    # Grid spacing of the refined grid, in density pixels.
    d = 1 / params.smoothing_factor
    phi_x, phi_y = jnp.gradient(phi, d, axis=(-2, -1))
    # Second outputs are mixed derivatives (d/dy of phi_x, d/dx of phi_y).
    phi_xx, phi_yx = jnp.gradient(phi_x, d, axis=(-2, -1))
    phi_xy, phi_yy = jnp.gradient(phi_y, d, axis=(-2, -1))

    phi_v = _sqrt_safe(phi_x**2 + phi_y**2)

    # Compute "safe" versions of `phi_v` and its square, which are used to
    # normalize quantities below. These are equal to 1 anywhere `phi_v` is
    # close to zero, and take their usual values elsewhere.
    phi_v_near_zero = jnp.isclose(phi_v, 0.0)
    phi_v_squared_safe = jnp.where(phi_v_near_zero, 1.0, phi_v**2)
    phi_v_safe = jnp.where(phi_v_near_zero, 1.0, phi_v)

    # Hessian contracted with the unit gradient direction on both sides.
    weight_xx = phi_x**2 / phi_v_squared_safe
    weight_yy = phi_y**2 / phi_v_squared_safe
    weight_xy = (phi_x * phi_y) / phi_v_squared_safe
    phi_vv = weight_xx * phi_xx + weight_xy * (phi_xy + phi_yx) + weight_yy * phi_yy

    # Divergence of the normalized gradient field.
    qx: jnp.ndarray = jnp.gradient(  # type: ignore[assignment]
        phi_x / jnp.abs(phi_v_safe), d, axis=-2
    )
    qy: jnp.ndarray = jnp.gradient(  # type: ignore[assignment]
        phi_y / jnp.abs(phi_v_safe), d, axis=-1
    )
    inverse_radius = qx + qy

    return phi, phi_v, phi_vv, inverse_radius
550
+
551
+
552
+ # -----------------------------------------------------------------------------
553
+ # Helper functions.
554
+ # -----------------------------------------------------------------------------
555
+
556
+
557
def _control_point_coords(
    density_shape: Tuple[int, int],
    levelset_shape: Tuple[int, int],
    periodic: Tuple[bool, bool],
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Returns the control point coordinates.

    Coordinates are in units of density pixels, with the density spanning
    `[0, density_shape[k]]` along axis `k`.

    - Periodic axis: the control points sit at the centers of `levelset_shape[k]`
      equal cells covering the density.
    - Non-periodic axis: the first and last control points are padding points
      lying half a cell outside the density. The remaining points sit at the
      centers of `levelset_shape[k] - 2` equal cells covering the density.

    Args:
        density_shape: Spatial shape of the density.
        levelset_shape: Spatial shape of the control-point (amplitude) grid.
        periodic: Whether each of the two spatial axes is periodic.

    Returns:
        The `(i, j)` coordinate arrays, each with shape `levelset_shape`, as
        produced by `jnp.meshgrid` with `"ij"` indexing.
    """
    # If the levelset is NOT periodic along an axis, the first and last control
    # points along that axis lie outside the bounds of the density.
    offset_i = 0.5 if periodic[0] else -0.5
    offset_j = 0.5 if periodic[1] else -0.5
    range_i = levelset_shape[-2] - (0 if periodic[0] else 2)
    range_j = levelset_shape[-1] - (0 if periodic[1] else 2)

    factor_i = density_shape[-2] / range_i
    factor_j = density_shape[-1] / range_j
    levelset_i, levelset_j = jnp.meshgrid(
        (offset_i + jnp.arange(levelset_shape[-2])) * factor_i,
        (offset_j + jnp.arange(levelset_shape[-1])) * factor_j,
        indexing="ij",
    )
    return levelset_i, levelset_j
578
+
579
+
580
def _sqrt_safe(x: jnp.ndarray) -> jnp.ndarray:
    """Square root of `x` that is exactly zero, with zero gradient, near zero.

    Entries close to zero (per `jnp.isclose`) are swapped for `1` before the
    square root is taken. Differentiation therefore never evaluates the singular
    derivative of `sqrt` at zero. Those entries are then set to `0` in the output.
    """
    is_tiny = jnp.isclose(x, 0.0)
    root_of_safe = jnp.sqrt(jnp.where(is_tiny, 1, x))
    return jnp.where(is_tiny, 0.0, root_of_safe)
585
+
586
+
587
def _levelset_threshold(
    phi: jnp.ndarray,
    periodic: Tuple[bool, bool],
    mask_gradient: bool,
) -> jnp.ndarray:
    """Binarizes `phi` to `1` where positive and `0` elsewhere.

    This is a straight-through estimator: the forward value is the step function,
    while the gradient with respect to `phi` is that of the identity. When
    `mask_gradient` is `True`, the gradient is kept only at pixels flagged by
    `_interface_pixels` and stopped elsewhere.
    """
    if not mask_gradient:
        phi_for_grad = phi
    else:
        on_interface = _interface_pixels(phi, periodic)
        phi_for_grad = jnp.where(on_interface, phi, jax.lax.stop_gradient(phi))
    step = (phi_for_grad > 0).astype(float)
    # The second term is identically zero in value but carries the gradient.
    return step + (phi_for_grad - jax.lax.stop_gradient(phi_for_grad))
598
+
599
+
600
def _interface_pixels(phi: jnp.ndarray, periodic: Tuple[bool, bool]) -> jnp.ndarray:
    """Returns a boolean mask of pixels on the solid/void interface of `phi`.

    Pixels with `phi > 0` are solid and all others are void. A pixel is on the
    interface if at least one of its four nearest neighbors has the opposite
    phase. Before counting neighbors, the masks are padded by one pixel with
    `transforms.pad2d`: `"wrap"` along periodic axes, `"edge"` otherwise.
    Any leading batch axes are flattened for the computation and then restored.
    """
    batch_shape = phi.shape[:-2]
    flat_phi = phi.reshape((-1,) + phi.shape[-2:])

    modes = tuple("wrap" if periodic[axis] else "edge" for axis in (0, 1))
    # Sums the four edge-adjacent neighbors of each pixel.
    neighbor_kernel = jnp.asarray([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)

    def count_adjacent(mask: jnp.ndarray) -> jnp.ndarray:
        padded = transforms.pad2d(mask, ((1, 1), (1, 1)), modes)
        counts = transforms.conv(
            x=padded[:, jnp.newaxis, :, :].astype(float),
            kernel=neighbor_kernel[jnp.newaxis, jnp.newaxis, :, :],
            padding="VALID",
        )
        return jnp.squeeze(counts, axis=1)

    is_solid = flat_phi > 0
    is_void = ~is_solid
    solid_touching_void = is_solid & (count_adjacent(is_void) > 0)
    void_touching_solid = is_void & (count_adjacent(is_solid) > 0)
    interface = solid_touching_void | void_touching_solid

    return interface.reshape(batch_shape + interface.shape[-2:])
635
+
636
+
637
def _downsample_spatial_dims(x: jnp.ndarray, downsample_factor: int) -> jnp.ndarray:
    """Box-downsamples the last two axes of `x` by the integer `downsample_factor`.

    Leading axes are unchanged; each trailing axis length is floor-divided by
    `downsample_factor` to form the target shape passed to
    `transforms.box_downsample`.
    """
    leading_shape = x.shape[:-2]
    rows, cols = x.shape[-2], x.shape[-1]
    target_shape = leading_shape + (
        rows // downsample_factor,
        cols // downsample_factor,
    )
    return transforms.box_downsample(x, target_shape)
@@ -0,0 +1,45 @@
1
+ """Defines the direct pixel parameterization for density arrays.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
+ import dataclasses
7
+
8
+ import jax.numpy as jnp
9
+ from jax import tree_util
10
+ from totypes import json_utils, types
11
+
12
+ from invrs_opt.parameterization import base
13
+
14
+
15
@dataclasses.dataclass
class PixelParams(base.ParameterizedDensity2DArrayBase):
    """Latent parameters of the direct pixel parameterization.

    Attributes:
        density: The density itself, which is used directly as the latent
            parameters.
    """

    density: types.Density2DArray
20
+
21
+
22
# `density` is the only field and is registered as pytree data; there are no
# static (meta) fields. Also register the type with `json_utils`.
tree_util.register_dataclass(
    PixelParams,
    data_fields=["density"],
    meta_fields=[],
)
json_utils.register_custom_type(PixelParams)
26
+
27
+
28
def pixel() -> base.Density2DParameterization:
    """Return the direct pixel parameterization.

    The latent parameters simply wrap the density, so conversion in either
    direction is the identity and the constraints are a constant scalar zero.
    """

    def from_density_fn(density: types.Density2DArray) -> PixelParams:
        # Wrap the density unchanged.
        return PixelParams(density=density)

    def to_density_fn(params: PixelParams) -> types.Density2DArray:
        # Unwrap the stored density.
        return params.density

    def constraints_fn(params: PixelParams) -> jnp.ndarray:
        # No constraints apply to the pixel parameterization.
        del params
        return jnp.asarray(0.0)

    return base.Density2DParameterization(
        to_density=to_density_fn,
        from_density=from_density_fn,
        constraints=constraints_fn,
    )