invrs-opt 0.3.2__py3-none-any.whl → 0.10.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invrs_opt/__init__.py +14 -3
- invrs_opt/experimental/client.py +155 -0
- invrs_opt/experimental/labels.py +23 -0
- invrs_opt/optimizers/__init__.py +0 -0
- invrs_opt/{base.py → optimizers/base.py} +16 -1
- invrs_opt/optimizers/lbfgsb.py +939 -0
- invrs_opt/optimizers/wrapped_optax.py +347 -0
- invrs_opt/parameterization/__init__.py +0 -0
- invrs_opt/parameterization/base.py +208 -0
- invrs_opt/parameterization/filter_project.py +138 -0
- invrs_opt/parameterization/gaussian_levelset.py +671 -0
- invrs_opt/parameterization/pixel.py +75 -0
- invrs_opt/{lbfgsb/transform.py → parameterization/transforms.py} +76 -11
- invrs_opt-0.10.3.dist-info/LICENSE +504 -0
- invrs_opt-0.10.3.dist-info/METADATA +560 -0
- invrs_opt-0.10.3.dist-info/RECORD +20 -0
- {invrs_opt-0.3.2.dist-info → invrs_opt-0.10.3.dist-info}/WHEEL +1 -1
- invrs_opt/lbfgsb/lbfgsb.py +0 -670
- invrs_opt-0.3.2.dist-info/LICENSE +0 -21
- invrs_opt-0.3.2.dist-info/METADATA +0 -73
- invrs_opt-0.3.2.dist-info/RECORD +0 -11
- /invrs_opt/{lbfgsb → experimental}/__init__.py +0 -0
- {invrs_opt-0.3.2.dist-info → invrs_opt-0.10.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,671 @@
|
|
1
|
+
"""Defines Gaussian radial basis function levelset parameterization.
|
2
|
+
|
3
|
+
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import dataclasses
|
7
|
+
from typing import Any, Tuple
|
8
|
+
|
9
|
+
import jax
|
10
|
+
import jax.numpy as jnp
|
11
|
+
import optax # type: ignore[import-untyped]
|
12
|
+
from jax import tree_util
|
13
|
+
from totypes import json_utils, symmetry, types
|
14
|
+
|
15
|
+
from invrs_opt.parameterization import base, transforms
|
16
|
+
|
17
|
+
# Alias used to annotate pytree-valued arguments (e.g. the `(params, state)`
# carry of the initialization loop in `gaussian_levelset`).
PyTree = Any

# Default hyperparameters for `gaussian_levelset`; see that function's docstring
# for the meaning of each.
DEFAULT_LENGTH_SCALE_SPACING_FACTOR: float = 2.0
DEFAULT_LENGTH_SCALE_FWHM_FACTOR: float = 1.0
DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR: float = 1.15
DEFAULT_SMOOTHING_FACTOR: int = 2
DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA: float = 0.333
DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT: float = 1.0
DEFAULT_CURVATURE_CONSTRAINT_WEIGHT: float = 2.0
DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT: float = 10.0
# Number of optimizer steps, and the optimizer (Adam with learning rate 0.1),
# used to fit the latents to the binarized initial density in `from_density`.
DEFAULT_INIT_STEPS: int = 50
DEFAULT_INIT_OPTIMIZER: optax.GradientTransformation = optax.adam(1e-1)
|
29
|
+
|
30
|
+
|
31
|
+
@dataclasses.dataclass
class GaussianLevelsetParams(base.ParameterizedDensity2DArray):
    """Stores parameters for the Gaussian levelset parameterization.

    Attributes:
        latents: The latent parameters (Gaussian control point amplitudes) that are
            updated during optimization.
        metadata: Settings needed to reconstruct the density from the latents; the
            parameterization's update function passes these through unchanged.
    """

    latents: "GaussianLevelsetLatents"
    metadata: "GaussianLevelsetMetadata"
|
37
|
+
|
38
|
+
|
39
|
+
@dataclasses.dataclass
class GaussianLevelsetLatents(base.LatentsBase):
    """Stores latent parameters for the Gaussian levelset parameterization.

    Attributes:
        amplitude: Array giving the amplitude of the Gaussian basis function at
            levelset control points. The trailing two axes index the control point
            grid; any leading axes are batch axes.
    """

    amplitude: jnp.ndarray
|
49
|
+
|
50
|
+
|
51
|
+
@dataclasses.dataclass
class GaussianLevelsetMetadata(base.MetadataBase):
    """Stores metadata for the Gaussian levelset parameterization.

    Attributes:
        length_scale_spacing_factor: The number of levelset control points per unit of
            minimum length scale (mean of density minimum width and minimum spacing).
        length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum to
            the minimum length scale.
        smoothing_factor: For values greater than 1, the density is initially computed
            at higher resolution and then downsampled, yielding smoother geometries.
        density_shape: Shape of the density array obtained from the parameters.
        density_metadata: Metadata for the density array obtained from the parameters.
    """

    length_scale_spacing_factor: float
    length_scale_fwhm_factor: float
    smoothing_factor: int
    density_shape: Tuple[int, ...]
    density_metadata: base.Density2DMetadata

    def __post_init__(self) -> None:
        # Normalize to a tuple (e.g. if a list was passed). `density_shape` is
        # registered below as a static (meta) pytree field, and JAX expects static
        # auxiliary data to be hashable.
        self.density_shape = tuple(self.density_shape)
|
74
|
+
|
75
|
+
|
76
|
+
# Register the dataclasses with JAX as pytrees. Fields listed in `data_fields`
# become pytree children (visible to transformations such as `jax.grad` and
# `jax.jit`); fields in `meta_fields` are static auxiliary data.
tree_util.register_dataclass(
    GaussianLevelsetParams,
    data_fields=["latents", "metadata"],
    meta_fields=[],
)
tree_util.register_dataclass(
    GaussianLevelsetLatents,
    data_fields=["amplitude"],
    meta_fields=[],
)
# `density_shape` and `smoothing_factor` determine array shapes (the zero example
# density, the high-resolution grid size and the downsampled shape), so they are
# kept static.
tree_util.register_dataclass(
    GaussianLevelsetMetadata,
    data_fields=[
        "length_scale_spacing_factor",
        "length_scale_fwhm_factor",
        "density_metadata",
    ],
    meta_fields=["density_shape", "smoothing_factor"],
)
# Register the types with `totypes.json_utils` (presumably so that parameters
# can be serialized to / deserialized from JSON).
json_utils.register_custom_type(GaussianLevelsetParams)
json_utils.register_custom_type(GaussianLevelsetLatents)
json_utils.register_custom_type(GaussianLevelsetMetadata)
|
98
|
+
|
99
|
+
|
100
|
+
def gaussian_levelset(
    *,
    length_scale_spacing_factor: float = DEFAULT_LENGTH_SCALE_SPACING_FACTOR,
    length_scale_fwhm_factor: float = DEFAULT_LENGTH_SCALE_FWHM_FACTOR,
    length_scale_constraint_factor: float = DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR,
    smoothing_factor: int = DEFAULT_SMOOTHING_FACTOR,
    length_scale_constraint_beta: float = DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA,
    length_scale_constraint_weight: float = DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT,
    curvature_constraint_weight: float = DEFAULT_CURVATURE_CONSTRAINT_WEIGHT,
    fixed_pixel_constraint_weight: float = DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT,
    init_optimizer: optax.GradientTransformation = DEFAULT_INIT_OPTIMIZER,
    init_steps: int = DEFAULT_INIT_STEPS,
) -> base.Density2DParameterization:
    """Defines a levelset parameterization with Gaussian radial basis functions.

    Args:
        length_scale_spacing_factor: The number of levelset control points per unit of
            minimum length scale (mean of density minimum width and minimum spacing).
        length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum to
            the minimum length scale.
        length_scale_constraint_factor: Multiplies the target length scale in the
            levelset constraints. A value greater than 1 is pessimistic and drives the
            solution to have a larger length scale (relative to smaller values).
        smoothing_factor: For values greater than 1, the density is initially computed
            at higher resolution and then downsampled, yielding smoother geometries.
        length_scale_constraint_beta: Controls relaxation of the length scale
            constraint near the zero level.
        length_scale_constraint_weight: The weight of the length scale constraint in
            the overall fabrication constraint penalty.
        curvature_constraint_weight: The weight of the curvature constraint.
        fixed_pixel_constraint_weight: The weight of the fixed pixel constraint.
        init_optimizer: The optimizer used in the initialization of the levelset
            parameterization. At initialization, the latent parameters are optimized so
            that the initial parameters match the binarized initial density.
        init_steps: The number of optimization steps used in the initialization.

    Returns:
        The `Density2DParameterization`.
    """

    def from_density_fn(density: types.Density2DArray) -> GaussianLevelsetParams:
        """Return level set parameters for the given `density`.

        The control point amplitudes are initialized from the sign of the density
        relative to its midpoint. They are then refined for `init_steps` steps of
        `init_optimizer` so that the resulting density matches the binarized,
        fixed-pixel-corrected target. Finally they are rescaled so that the largest
        amplitude magnitude in each batch element is 1.
        """
        length_scale = (density.minimum_width + density.minimum_spacing) / 2
        # Number of control points per density pixel along each spatial axis.
        spacing_factor = length_scale_spacing_factor / length_scale
        shape = density.shape[:-2] + (
            int(jnp.ceil(density.shape[-2] * spacing_factor)),
            int(jnp.ceil(density.shape[-1] * spacing_factor)),
        )

        mid_value = 0.5 * (density.lower_bound + density.upper_bound)
        value_range = density.upper_bound - density.lower_bound
        # Fitting target: the density with fixed pixels applied, binarized to its
        # lower or upper bound.
        target_array = transforms.apply_fixed_pixels(density).array
        target_array = (
            jnp.sign(target_array - mid_value) * 0.5 * value_range + mid_value
        )

        # Generate the initial amplitude array from the sign of the density
        # relative to its midpoint, resampled onto the control point grid.
        amplitude = jnp.sign(density.array - mid_value)
        amplitude = transforms.resample(amplitude, shape)

        # If the density is not periodic, ensure there are level set control points
        # beyond the edge of the density array.
        pad_width = ((0, 0),) * (amplitude.ndim - 2)
        pad_width += ((0, 0),) if density.periodic[0] else ((1, 1),)
        pad_width += ((0, 0),) if density.periodic[1] else ((1, 1),)
        amplitude = jnp.pad(amplitude, pad_width, mode="edge")

        latents = GaussianLevelsetLatents(amplitude=amplitude)
        metadata = GaussianLevelsetMetadata(
            length_scale_spacing_factor=length_scale_spacing_factor,
            length_scale_fwhm_factor=length_scale_fwhm_factor,
            smoothing_factor=smoothing_factor,
            density_shape=density.shape,
            density_metadata=base.Density2DMetadata.from_density(density),
        )

        def step_fn(
            _: int,
            params_and_state: Tuple[PyTree, PyTree],
        ) -> Tuple[PyTree, PyTree]:
            """Take one `init_optimizer` step toward matching `target_array`."""

            def loss_fn(latents: GaussianLevelsetLatents) -> jnp.ndarray:
                params = GaussianLevelsetParams(latents, metadata=metadata)
                # Use unmasked gradients so that every pixel contributes to the fit,
                # not only pixels at feature borders.
                density_from_params = to_density_fn(params, mask_gradient=False)
                return jnp.mean((density_from_params.array - target_array) ** 2)

            params, state = params_and_state
            grad = jax.grad(loss_fn)(params)
            updates, state = init_optimizer.update(grad, params=params, state=state)
            params = optax.apply_updates(params, updates)
            return params, state

        state = init_optimizer.init(latents)
        latents, _ = jax.lax.fori_loop(
            0, init_steps, body_fun=step_fn, init_val=(latents, state)
        )

        # Rescale so that the largest amplitude magnitude in each batch element is 1.
        # Guard against an all-zero amplitude (e.g. `init_steps=0` with a density
        # lying exactly at its midpoint). Without the guard, 0 / 0 yields `nan`
        # latents; with it, such amplitudes are simply left at zero.
        maxval = jnp.amax(jnp.abs(latents.amplitude), axis=(-2, -1), keepdims=True)
        maxval = jnp.where(maxval == 0, 1.0, maxval)
        latents = dataclasses.replace(latents, amplitude=latents.amplitude / maxval)
        return GaussianLevelsetParams(latents=latents, metadata=metadata)

    def to_density_fn(
        params: GaussianLevelsetParams,
        mask_gradient: bool = True,
    ) -> types.Density2DArray:
        """Return a density from the latent parameters."""
        array = _to_array(params, mask_gradient=mask_gradient, pad_pixels=0)

        # `array` lies in [0, 1]; rescale it to the density's bounds.
        example_density = _example_density(params)
        lb = example_density.lower_bound
        ub = example_density.upper_bound
        array = lb + array * (ub - lb)
        assert array.shape == example_density.shape
        return dataclasses.replace(example_density, array=array)

    def constraints_fn(
        params: GaussianLevelsetParams,
        mask_gradient: bool = True,
        pad_pixels: int = 2,
    ) -> jnp.ndarray:
        """Computes constraints associated with the params."""
        return analytical_constraints(
            params=params,
            length_scale_constraint_factor=length_scale_constraint_factor,
            length_scale_constraint_beta=length_scale_constraint_beta,
            length_scale_constraint_weight=length_scale_constraint_weight,
            curvature_constraint_weight=curvature_constraint_weight,
            fixed_pixel_constraint_weight=fixed_pixel_constraint_weight,
            mask_gradient=mask_gradient,
            pad_pixels=pad_pixels,
        )

    def update_fn(
        params: GaussianLevelsetParams,
        updates: GaussianLevelsetParams,
        value: jnp.ndarray,
        step: int,
    ) -> GaussianLevelsetParams:
        """Perform updates to `params` required for the given `step`."""
        del step, value
        # Only the latents are updated; the metadata is carried over unchanged.
        return GaussianLevelsetParams(
            latents=tree_util.tree_map(
                lambda a, b: a + b, params.latents, updates.latents
            ),
            metadata=params.metadata,
        )

    return base.Density2DParameterization(
        to_density=to_density_fn,
        from_density=from_density_fn,
        constraints=constraints_fn,
        update=update_fn,
    )
|
253
|
+
|
254
|
+
|
255
|
+
# -----------------------------------------------------------------------------
|
256
|
+
# Functions to obtain arrays from the levelset parameterization.
|
257
|
+
# -----------------------------------------------------------------------------
|
258
|
+
|
259
|
+
|
260
|
+
def _example_density(params: GaussianLevelsetParams) -> types.Density2DArray:
    """Build a zero-valued density carrying the shape and metadata of `params`.

    The array contents are irrelevant: callers use the result only to read the
    shape, bounds, length scales, periodicity, symmetries and fixed pixels. It is
    constructed under `jax.ensure_compile_time_eval`, so it is evaluated eagerly
    where possible, even inside traced code.
    """
    metadata = params.metadata
    with jax.ensure_compile_time_eval():
        zeros = jnp.zeros(metadata.density_shape)
        metadata_fields = dataclasses.asdict(metadata.density_metadata)
        return types.Density2DArray(array=zeros, **metadata_fields)
|
267
|
+
|
268
|
+
|
269
|
+
def _to_array(
    params: GaussianLevelsetParams,
    mask_gradient: bool,
    pad_pixels: int,
) -> jnp.ndarray:
    """Return an array from the parameters.

    The levelset function is evaluated on the high-resolution grid and thresholded
    to a value of `1` where it is positive and `0` elsewhere. The result is then
    downsampled by the smoothing factor, so pixels straddling a feature boundary
    may take intermediate values (assuming `box_downsample` averages). The final
    density array is obtained by rescaling this `[0, 1]` array to the density's
    lower and upper bounds.

    Args:
        params: The parameters from which the density is obtained.
        mask_gradient: If `True`, the gradient is masked so that it is nonzero only at
            the borders of features.
        pad_pixels: A non-negative integer giving the additional pixels to be included
            beyond the boundaries of the parameterized density.

    Returns:
        The array. Its trailing two axes have the density's spatial shape, enlarged
        by `2 * pad_pixels` along each axis.
    """
    example_density = _example_density(params)
    periodic: Tuple[bool, bool] = example_density.periodic
    phi = _phi_from_params(params=params, pad_pixels=pad_pixels)
    array = _levelset_threshold(phi=phi, periodic=periodic, mask_gradient=mask_gradient)
    return _downsample_spatial_dims(array, params.metadata.smoothing_factor)
|
295
|
+
|
296
|
+
|
297
|
+
def _phi_from_params(
    params: GaussianLevelsetParams,
    pad_pixels: int,
) -> jnp.ndarray:
    """Return the levelset function for the given `params`.

    The levelset function is a sum of Gaussian radial basis functions centered at
    the control points and weighted by `params.latents.amplitude`. It is evaluated
    at the pixel centers of a grid that is `smoothing_factor` times finer than the
    density grid and extended by `pad_pixels` density pixels on every side. Along
    periodic axes, the control points are replicated one period to either side so
    that basis functions wrap across the boundary. The result is symmetrized
    according to the density's symmetries.

    Args:
        params: The parameters from which the density is obtained.
        pad_pixels: A non-negative integer giving the additional pixels to be included
            beyond the boundaries of the parameterized density.

    Returns:
        The levelset array `phi`. Each trailing axis has length
        `smoothing_factor * (density_size + 2 * pad_pixels)`.
    """
    with jax.ensure_compile_time_eval():
        # Geometry that depends only on the metadata (not on the latents).
        example_density = _example_density(params)
        length_scale = 0.5 * (
            example_density.minimum_width + example_density.minimum_spacing
        )
        fwhm = length_scale * params.metadata.length_scale_fwhm_factor
        # Standard deviation of a Gaussian `exp(-x**2 / (2 * sigma**2))` having
        # full-width at half-maximum `fwhm`.
        # NOTE(review): the basis in `scan_fn` is `exp(-d**2 / sigma**2)`, without
        # the factor of 2, so its realized FWHM is `fwhm / sqrt(2)`. Confirm this is
        # intended before changing it; a fix would alter optimized geometries.
        sigma = fwhm / (2 * jnp.sqrt(2 * jnp.log(2)))

        # Pixel-center coordinates of the high-resolution grid, in units of
        # density pixels, including the padding region.
        s_factor = params.metadata.smoothing_factor
        highres_i = (
            0.5
            + jnp.arange(
                s_factor * (-pad_pixels),
                s_factor * (pad_pixels + example_density.shape[-2]),
            )
        ) / s_factor
        highres_j = (
            0.5
            + jnp.arange(
                s_factor * (-pad_pixels),
                s_factor * (pad_pixels + example_density.shape[-1]),
            )
        ) / s_factor

        # Coordinates for the control points of the Gaussian radial basis functions.
        levelset_i, levelset_j = _control_point_coords(
            density_shape=params.metadata.density_shape[-2:],  # type: ignore[arg-type]
            levelset_shape=(
                params.latents.amplitude.shape[-2:]  # type: ignore[arg-type]
            ),
            periodic=example_density.periodic,
        )

        # Handle periodicity by replicating control points over a 3x3 supercell.
        # The copies are concatenated in the same order as the amplitude copies
        # below, so the flattened coordinates and amplitudes stay aligned.
        if example_density.periodic[0]:
            levelset_i = jnp.concatenate(
                [
                    levelset_i - example_density.shape[-2],
                    levelset_i,
                    levelset_i + example_density.shape[-2],
                ],
                axis=-2,
            )
            levelset_j = jnp.concatenate([levelset_j] * 3, axis=-2)
        if example_density.periodic[1]:
            levelset_i = jnp.concatenate([levelset_i] * 3, axis=-1)
            levelset_j = jnp.concatenate(
                [
                    levelset_j - example_density.shape[-1],
                    levelset_j,
                    levelset_j + example_density.shape[-1],
                ],
                axis=-1,
            )

        levelset_i = levelset_i.flatten()
        levelset_j = levelset_j.flatten()

    amplitude = params.latents.amplitude
    if example_density.periodic[0]:
        amplitude = jnp.concat([amplitude] * 3, axis=-2)
    if example_density.periodic[1]:
        amplitude = jnp.concat([amplitude] * 3, axis=-1)

    # Flatten the control point axes into a trailing axis ordered like
    # `levelset_i` / `levelset_j`. The inserted singleton axis broadcasts against
    # the high-resolution column axis inside `scan_fn`.
    amplitude = amplitude.reshape(amplitude.shape[:-2] + (1, -1))

    # Use a scan operation to compute the array; this lowers memory consumption.
    # Each step handles one high-resolution row `i`, so only a
    # (num_columns, num_control_points) distance array exists per step.
    def scan_fn(_: Tuple[()], i: jnp.ndarray) -> Tuple[Tuple[()], jnp.ndarray]:
        distance_sq = (i - levelset_i) ** 2 + (
            highres_j[:, jnp.newaxis] - levelset_j
        ) ** 2
        basis = jnp.exp(-distance_sq / sigma**2)
        return (), jnp.sum(basis * amplitude, axis=-1)

    _, array = jax.lax.scan(scan_fn, (), xs=highres_i)
    # `scan` stacks the rows along a new leading axis; move it next to the column
    # axis so that any batch axes come first.
    array = jnp.moveaxis(array, 0, -2)

    assert array.shape[-2] % s_factor == 0
    assert array.shape[-1] % s_factor == 0
    array = symmetry.symmetrize(array, tuple(example_density.symmetries))
    return array
|
392
|
+
|
393
|
+
|
394
|
+
# -----------------------------------------------------------------------------
|
395
|
+
# Functions to compute constraints.
|
396
|
+
# -----------------------------------------------------------------------------
|
397
|
+
|
398
|
+
|
399
|
+
def analytical_constraints(
    params: GaussianLevelsetParams,
    length_scale_constraint_factor: float,
    length_scale_constraint_beta: float,
    length_scale_constraint_weight: float,
    curvature_constraint_weight: float,
    fixed_pixel_constraint_weight: float,
    mask_gradient: bool,
    pad_pixels: int,
) -> jnp.ndarray:
    """Computes analytical levelset constraints associated with the params.

    Returns:
        The weighted length scale, curvature and fixed pixel constraint arrays,
        stacked along a new trailing axis (in that order) and divided by the
        squared minimum length scale to make them (somewhat) resolution-independent.
    """
    length_scale_term, curvature_term = _levelset_constraints(
        params,
        beta=length_scale_constraint_beta,
        length_scale_constraint_factor=length_scale_constraint_factor,
        pad_pixels=pad_pixels,
    )
    fixed_pixel_term = _fixed_pixel_constraint(
        params,
        mask_gradient=mask_gradient,
        pad_pixels=pad_pixels,
    )

    weighted_terms = [
        length_scale_term * length_scale_constraint_weight,
        curvature_term * curvature_constraint_weight,
        fixed_pixel_term * fixed_pixel_constraint_weight,
    ]

    density_info = _example_density(params)
    normalization = (
        0.5 * (density_info.minimum_spacing + density_info.minimum_width)
    ) ** 2
    return jnp.stack(weighted_terms, axis=-1) / normalization
|
437
|
+
|
438
|
+
|
439
|
+
def _fixed_pixel_constraint(
    params: GaussianLevelsetParams,
    mask_gradient: bool,
    pad_pixels: int,
) -> jnp.ndarray:
    """Return the fixed pixel constraint array.

    Entries equal `|array - target|` at fixed pixels and zero elsewhere. Here
    `array` is the thresholded `[0, 1]` array obtained from `params` on a grid
    padded by `pad_pixels`, and `target` is `1` at fixed-solid and `0` at
    fixed-void pixels. The fixed-pixel masks are extended into the padding by
    edge replication.

    NOTE(review): edge replication is used even along periodic axes, where
    wrapping might be expected; confirm this is intended.

    Args:
        params: The parameters from which the density is obtained.
        mask_gradient: If `True`, the gradient is masked so that it is nonzero only at
            the borders of features.
        pad_pixels: The number of pixels added at borders. Values greater than zero
            help to ensure that sharp features at the borders are avoided.

    Returns:
        The constraints array.
    """
    array = _to_array(params, mask_gradient=mask_gradient, pad_pixels=pad_pixels)

    density_info = _example_density(params)
    spatial_shape = density_info.shape[-2:]
    spatial_padding = ((pad_pixels, pad_pixels), (pad_pixels, pad_pixels))

    def padded_mask(mask):
        # Missing masks are treated as "nothing fixed".
        if mask is None:
            mask_array = jnp.zeros(spatial_shape, dtype=bool)
        else:
            mask_array = jnp.asarray(mask)
        widths = ((0, 0),) * (mask_array.ndim - 2) + spatial_padding
        return jnp.pad(mask_array, widths, mode="edge")

    solid = padded_mask(density_info.fixed_solid)
    void = padded_mask(density_info.fixed_void)

    target = jnp.where(solid, 1, 0)
    return jnp.where(solid | void, jnp.abs(array - target), 0.0)
|
483
|
+
|
484
|
+
|
485
|
+
def _levelset_constraints(
    params: GaussianLevelsetParams,
    beta: float,
    length_scale_constraint_factor: float,
    pad_pixels: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute constraints for minimum width, spacing, and radius of curvature.

    The constraints are based on "Analytical level set fabrication constraints for
    inverse design," by D. Vercruysse et al. (2019). Constraints are satisfied when
    they are non-positive.

    https://www.nature.com/articles/s41598-019-45026-0

    With `d` the minimum length scale times `length_scale_constraint_factor`, the
    two constraints are:
      length scale: `|phi_vv| / (pi / d * |phi| + beta * phi_v) - pi / d`
      curvature:    `|inverse_radius * arctan(phi_v / phi)| - pi / d`
    Denominators are replaced by 1 where the corresponding numerator-side quantity
    (`phi_vv` or `phi_v`, respectively) is close to zero, to avoid 0 / 0.

    Args:
        params: The parameters of the Gaussian levelset.
        beta: Parameter which relaxes the constraint near the zero-plane.
        length_scale_constraint_factor: Multiplies the target length scale in the
            levelset constraints. A value greater than 1 is pessimistic and drives the
            solution to have a larger length scale (relative to smaller values).
        pad_pixels: A non-negative integer giving the additional pixels to be included
            beyond the boundaries of the parameterized density.

    Returns:
        The minimum length scale and minimum curvature constraint arrays, downsampled
        to the (padded) density resolution.
    """
    density_info = _example_density(params)
    target_length_scale = (
        0.5
        * (density_info.minimum_width + density_info.minimum_spacing)
        * length_scale_constraint_factor
    )
    inverse_length = jnp.pi / target_length_scale

    phi, phi_v, phi_vv, inverse_radius = _phi_derivatives_and_inverse_radius(
        params,
        pad_pixels=pad_pixels,
    )

    width_denominator = jnp.where(
        jnp.isclose(phi_vv, 0.0),
        1.0,
        inverse_length * jnp.abs(phi) + beta * phi_v,
    )
    length_scale_constraint = jnp.abs(phi_vv) / width_denominator - inverse_length

    # NOTE(review): this guard is keyed on `phi_v`, so where `phi` is ~0 but
    # `phi_v` is not, `phi_v / phi` can still blow up (arctan keeps the forward
    # value finite, but gradients may not be); confirm this is acceptable.
    phi_denominator = jnp.where(jnp.isclose(phi_v, 0.0), 1.0, phi)
    curvature_constraint = (
        jnp.abs(inverse_radius * jnp.arctan(phi_v / phi_denominator)) - inverse_length
    )

    # Downsample so that constraints shape matches the density shape.
    factor = params.metadata.smoothing_factor
    return (
        _downsample_spatial_dims(length_scale_constraint, factor),
        _downsample_spatial_dims(curvature_constraint, factor),
    )
|
537
|
+
|
538
|
+
|
539
|
+
def _phi_derivatives_and_inverse_radius(
    params: GaussianLevelsetParams,
    pad_pixels: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Compute the levelset function and its first and second derivatives.

    Derivatives are finite differences on the high-resolution grid, whose spacing
    is `1 / smoothing_factor` density pixels.

    Returns:
        A tuple `(phi, phi_v, phi_vv, inverse_radius)`:
        - `phi`: the levelset function;
        - `phi_v`: the gradient magnitude;
        - `phi_vv`: the second derivative along the gradient direction;
        - `inverse_radius`: the divergence of the normalized gradient (curvature).
        Where `phi_v` is close to zero, it is replaced by 1 in the normalizations.
    """
    phi = _phi_from_params(params=params, pad_pixels=pad_pixels)
    grid_step = 1 / params.metadata.smoothing_factor

    def spatial_gradient(field):
        # Returns the partial derivatives along the last two axes (x, then y).
        return jnp.gradient(field, grid_step, axis=(-2, -1))

    phi_x, phi_y = spatial_gradient(phi)
    phi_xx, phi_yx = spatial_gradient(phi_x)
    phi_xy, phi_yy = spatial_gradient(phi_y)

    phi_v = _sqrt_safe(phi_x**2 + phi_y**2)

    # Normalizers that fall back to 1 wherever the gradient nearly vanishes.
    flat_region = jnp.isclose(phi_v, 0.0)
    norm_squared = jnp.where(flat_region, 1.0, phi_v**2)
    norm = jnp.where(flat_region, 1.0, phi_v)

    phi_vv = (
        (phi_x**2 / norm_squared) * phi_xx
        + ((phi_x * phi_y) / norm_squared) * (phi_xy + phi_yx)
        + (phi_y**2 / norm_squared) * phi_yy
    )

    divergence_x = jnp.gradient(phi_x / jnp.abs(norm), grid_step, axis=-2)
    divergence_y = jnp.gradient(phi_y / jnp.abs(norm), grid_step, axis=-1)

    return phi, phi_v, phi_vv, divergence_x + divergence_y
|
578
|
+
|
579
|
+
|
580
|
+
# -----------------------------------------------------------------------------
|
581
|
+
# Helper functions.
|
582
|
+
# -----------------------------------------------------------------------------
|
583
|
+
|
584
|
+
|
585
|
+
def _control_point_coords(
    density_shape: Tuple[int, int],
    levelset_shape: Tuple[int, int],
    periodic: Tuple[bool, bool],
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Returns the control point coordinates.

    Coordinates are in units of density pixels, for a grid of `levelset_shape`
    control points spanning a density of `density_shape`. They are returned as two
    `levelset_shape` arrays with `"ij"` indexing.

    Along a periodic axis, the `n` control points are spread evenly over one
    period, each offset by half a spacing from the cell origin, so all lie inside
    the density. Along a non-periodic axis, the interior `n - 2` points span the
    density, and the first and last control points lie half a spacing outside it.
    """

    def axis_coords(num_pixels: int, num_points: int, is_periodic: bool) -> jnp.ndarray:
        # Offset (in units of control point spacing) and number of spacings
        # across the density along this axis.
        if is_periodic:
            offset, num_spacings = 0.5, num_points
        else:
            offset, num_spacings = -0.5, num_points - 2
        spacing = num_pixels / num_spacings
        return (offset + jnp.arange(num_points)) * spacing

    coords_i = axis_coords(density_shape[-2], levelset_shape[-2], periodic[0])
    coords_j = axis_coords(density_shape[-1], levelset_shape[-1], periodic[1])
    grid_i, grid_j = jnp.meshgrid(coords_i, coords_j, indexing="ij")
    return grid_i, grid_j
|
606
|
+
|
607
|
+
|
608
|
+
def _sqrt_safe(x: jnp.ndarray) -> jnp.ndarray:
    """Square root that returns 0 (with a zero gradient) for inputs close to zero.

    The argument of `jnp.sqrt` is replaced by 1 wherever `x` is close to zero, so
    the infinite derivative of the square root at 0 never enters the backward pass.
    """
    is_tiny = jnp.isclose(x, 0.0)
    root = jnp.sqrt(jnp.where(is_tiny, 1, x))
    return jnp.where(is_tiny, 0.0, root)
|
613
|
+
|
614
|
+
|
615
|
+
def _levelset_threshold(
    phi: jnp.ndarray,
    periodic: Tuple[bool, bool],
    mask_gradient: bool,
) -> jnp.ndarray:
    """Thresholds a level set function `phi` with a straight-through gradient.

    The forward value is `1.0` where `phi > 0` and `0.0` elsewhere. The added term
    `phi - stop_gradient(phi)` is exactly zero in value, but it passes the
    gradient of `phi` through unchanged. When `mask_gradient` is set, gradients
    are kept only at interface pixels; elsewhere `phi` is wrapped in
    `stop_gradient` first.
    """
    if mask_gradient:
        at_interface = _interface_pixels(phi, periodic)
        phi = jnp.where(at_interface, phi, jax.lax.stop_gradient(phi))
    hard_step = (phi > 0).astype(float)
    return hard_step + (phi - jax.lax.stop_gradient(phi))
|
626
|
+
|
627
|
+
|
628
|
+
def _interface_pixels(phi: jnp.ndarray, periodic: Tuple[bool, bool]) -> jnp.ndarray:
    """Identifies interface pixels of a level set function `phi`.

    A pixel is on the interface if it is solid (`phi > 0`) and has at least one
    void 4-neighbour, or if it is void and has at least one solid 4-neighbour.
    At the array borders, neighbours wrap along periodic axes and are replicated
    from the edge otherwise. Leading (batch) axes are preserved.
    """
    leading_shape = phi.shape[:-2]
    stacked_phi = phi.reshape((-1,) + phi.shape[-2:])

    border_modes = (
        "wrap" if periodic[0] else "edge",
        "wrap" if periodic[1] else "edge",
    )
    # Cross-shaped kernel: counts the four edge-adjacent neighbours.
    neighbour_kernel = jnp.asarray(
        [[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float
    )[jnp.newaxis, jnp.newaxis, :, :]

    def count_neighbours(mask: jnp.ndarray) -> jnp.ndarray:
        # Number of `True` 4-neighbours of each pixel in `mask`.
        padded = transforms.pad2d(mask, ((1, 1), (1, 1)), border_modes)
        counts = transforms.conv(
            x=padded[:, jnp.newaxis, :, :].astype(float),
            kernel=neighbour_kernel,
            padding="VALID",
        )
        return jnp.squeeze(counts, axis=1)

    is_solid = stacked_phi > 0
    is_void = ~is_solid
    interface = (is_solid & (count_neighbours(is_void) > 0)) | (
        is_void & (count_neighbours(is_solid) > 0)
    )
    return interface.reshape(leading_shape + interface.shape[-2:])
|
663
|
+
|
664
|
+
|
665
|
+
def _downsample_spatial_dims(x: jnp.ndarray, downsample_factor: int) -> jnp.ndarray:
    """Shrink the last two axes of `x` by the integer `downsample_factor`.

    Each spatial axis is reduced to `size // downsample_factor`, and leading axes
    are untouched. The reduction itself is delegated to `transforms.box_downsample`.
    """
    rows, cols = x.shape[-2], x.shape[-1]
    target_shape = (
        *x.shape[:-2],
        rows // downsample_factor,
        cols // downsample_factor,
    )
    return transforms.box_downsample(x, target_shape)
|