invrs-opt 0.3.2__py3-none-any.whl → 0.10.3__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- invrs_opt/__init__.py +14 -3
- invrs_opt/experimental/client.py +155 -0
- invrs_opt/experimental/labels.py +23 -0
- invrs_opt/optimizers/__init__.py +0 -0
- invrs_opt/{base.py → optimizers/base.py} +16 -1
- invrs_opt/optimizers/lbfgsb.py +939 -0
- invrs_opt/optimizers/wrapped_optax.py +347 -0
- invrs_opt/parameterization/__init__.py +0 -0
- invrs_opt/parameterization/base.py +208 -0
- invrs_opt/parameterization/filter_project.py +138 -0
- invrs_opt/parameterization/gaussian_levelset.py +671 -0
- invrs_opt/parameterization/pixel.py +75 -0
- invrs_opt/{lbfgsb/transform.py → parameterization/transforms.py} +76 -11
- invrs_opt-0.10.3.dist-info/LICENSE +504 -0
- invrs_opt-0.10.3.dist-info/METADATA +560 -0
- invrs_opt-0.10.3.dist-info/RECORD +20 -0
- {invrs_opt-0.3.2.dist-info → invrs_opt-0.10.3.dist-info}/WHEEL +1 -1
- invrs_opt/lbfgsb/lbfgsb.py +0 -670
- invrs_opt-0.3.2.dist-info/LICENSE +0 -21
- invrs_opt-0.3.2.dist-info/METADATA +0 -73
- invrs_opt-0.3.2.dist-info/RECORD +0 -11
- /invrs_opt/{lbfgsb → experimental}/__init__.py +0 -0
- {invrs_opt-0.3.2.dist-info → invrs_opt-0.10.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,671 @@
|
|
1
|
+
"""Defines Gaussian radial basis function levelset parameterization.
|
2
|
+
|
3
|
+
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import dataclasses
|
7
|
+
from typing import Any, Tuple
|
8
|
+
|
9
|
+
import jax
|
10
|
+
import jax.numpy as jnp
|
11
|
+
import optax # type: ignore[import-untyped]
|
12
|
+
from jax import tree_util
|
13
|
+
from totypes import json_utils, symmetry, types
|
14
|
+
|
15
|
+
from invrs_opt.parameterization import base, transforms
|
16
|
+
|
17
|
+
PyTree = Any

# Default hyperparameters for `gaussian_levelset`; each one is documented in that
# function's docstring.

# Geometry of the radial basis functions.
DEFAULT_LENGTH_SCALE_SPACING_FACTOR: float = 2.0
DEFAULT_LENGTH_SCALE_FWHM_FACTOR: float = 1.0

# Fabrication-constraint settings.
DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR: float = 1.15
DEFAULT_SMOOTHING_FACTOR: int = 2
DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA: float = 0.333
DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT: float = 1.0
DEFAULT_CURVATURE_CONSTRAINT_WEIGHT: float = 2.0
DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT: float = 10.0

# Optimization used to fit the latents to an initial density.
DEFAULT_INIT_STEPS: int = 50
DEFAULT_INIT_OPTIMIZER: optax.GradientTransformation = optax.adam(learning_rate=0.1)
|
29
|
+
|
30
|
+
|
31
|
+
@dataclasses.dataclass
class GaussianLevelsetParams(base.ParameterizedDensity2DArray):
    """Container pairing Gaussian levelset latents with their metadata.

    Attributes:
        latents: The latent amplitudes that are updated during optimization.
        metadata: Settings and density information needed to turn the latents
            back into a density array.
    """

    latents: "GaussianLevelsetLatents"
    metadata: "GaussianLevelsetMetadata"
|
37
|
+
|
38
|
+
|
39
|
+
@dataclasses.dataclass
class GaussianLevelsetLatents(base.LatentsBase):
    """Optimizable latent variables of the Gaussian levelset parameterization.

    Attributes:
        amplitude: Amplitudes of the Gaussian radial basis functions, one per
            levelset control point (control points lie on the two trailing axes).
    """

    amplitude: jnp.ndarray
|
49
|
+
|
50
|
+
|
51
|
+
@dataclasses.dataclass
class GaussianLevelsetMetadata(base.MetadataBase):
    """Non-latent settings of a Gaussian levelset parameterization.

    Attributes:
        length_scale_spacing_factor: How many levelset control points are placed per
            unit of minimum length scale (the mean of the density's minimum width and
            minimum spacing).
        length_scale_fwhm_factor: Gaussian full-width at half-maximum divided by the
            minimum length scale.
        smoothing_factor: When greater than 1, the density is first computed on a
            finer grid and then downsampled, which smooths the resulting geometry.
        density_shape: Shape of the density array produced from the parameters.
        density_metadata: Metadata of the density array produced from the parameters.
    """

    length_scale_spacing_factor: float
    length_scale_fwhm_factor: float
    smoothing_factor: int
    density_shape: Tuple[int, ...]
    density_metadata: base.Density2DMetadata

    def __post_init__(self) -> None:
        # Coerce to a tuple so the shape is hashable and compares equal whether a
        # list or a tuple was supplied.
        shape = self.density_shape
        self.density_shape = tuple(shape)
|
74
|
+
|
75
|
+
|
76
|
+
# Register the dataclasses as JAX pytrees. `density_shape` and `smoothing_factor`
# are listed as `meta_fields`, so they are not pytree leaves; every other field is.
tree_util.register_dataclass(
    GaussianLevelsetParams,
    meta_fields=[],
    data_fields=["latents", "metadata"],
)
tree_util.register_dataclass(
    GaussianLevelsetLatents,
    meta_fields=[],
    data_fields=["amplitude"],
)
tree_util.register_dataclass(
    GaussianLevelsetMetadata,
    meta_fields=["density_shape", "smoothing_factor"],
    data_fields=[
        "length_scale_spacing_factor",
        "length_scale_fwhm_factor",
        "density_metadata",
    ],
)

# Make the types known to `totypes.json_utils` for (de)serialization.
json_utils.register_custom_type(GaussianLevelsetParams)
json_utils.register_custom_type(GaussianLevelsetLatents)
json_utils.register_custom_type(GaussianLevelsetMetadata)
|
98
|
+
|
99
|
+
|
100
|
+
def gaussian_levelset(
    *,
    length_scale_spacing_factor: float = DEFAULT_LENGTH_SCALE_SPACING_FACTOR,
    length_scale_fwhm_factor: float = DEFAULT_LENGTH_SCALE_FWHM_FACTOR,
    length_scale_constraint_factor: float = DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR,
    smoothing_factor: int = DEFAULT_SMOOTHING_FACTOR,
    length_scale_constraint_beta: float = DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA,
    length_scale_constraint_weight: float = DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT,
    curvature_constraint_weight: float = DEFAULT_CURVATURE_CONSTRAINT_WEIGHT,
    fixed_pixel_constraint_weight: float = DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT,
    init_optimizer: optax.GradientTransformation = DEFAULT_INIT_OPTIMIZER,
    init_steps: int = DEFAULT_INIT_STEPS,
) -> base.Density2DParameterization:
    """Defines a levelset parameterization with Gaussian radial basis functions.

    The returned parameterization converts a density into Gaussian amplitudes
    (`from_density`). It converts amplitudes back into a density (`to_density`)
    and evaluates weighted length scale, curvature and fixed-pixel constraints
    (`constraints`). Its `update` adds updates to the latents while keeping the
    metadata.

    Args:
        length_scale_spacing_factor: The number of levelset control points per unit of
            minimum length scale (mean of density minimum width and minimum spacing).
        length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum to
            the minimum length scale.
        length_scale_constraint_factor: Multiplies the target length scale in the
            levelset constraints. A value greater than 1 is pessimistic and drives the
            solution to have a larger length scale (relative to smaller values).
        smoothing_factor: For values greater than 1, the density is initially computed
            at higher resolution and then downsampled, yielding smoother geometries.
        length_scale_constraint_beta: Controls relaxation of the length scale
            constraint near the zero level.
        length_scale_constraint_weight: The weight of the length scale constraint in
            the overall fabrication constraint penalty.
        curvature_constraint_weight: The weight of the curvature constraint.
        fixed_pixel_constraint_weight: The weight of the fixed pixel constraint.
        init_optimizer: The optimizer used in the initialization of the levelset
            parameterization. At initialization, the latent parameters are optimized so
            that the initial parameters match the binarized initial density.
        init_steps: The number of optimization steps used in the initialization.

    Returns:
        The `Density2DParameterization`.
    """

    def from_density_fn(density: types.Density2DArray) -> GaussianLevelsetParams:
        """Return level set parameters for the given `density`.

        The amplitudes are seeded from the sign of the density relative to the
        midpoint of its bounds. They are then fitted for `init_steps` steps of
        `init_optimizer`, so that `to_density` reproduces the binarized density
        with fixed pixels applied. Finally they are scaled so that the largest
        absolute amplitude over the two trailing axes is 1.
        """
        length_scale = (density.minimum_width + density.minimum_spacing) / 2
        # Number of control points per density pixel.
        spacing_factor = length_scale_spacing_factor / length_scale
        shape = density.shape[:-2] + (
            int(jnp.ceil(density.shape[-2] * spacing_factor)),
            int(jnp.ceil(density.shape[-1] * spacing_factor)),
        )

        # Fitting target: fixed pixels applied, then binarized to the bounds.
        mid_value = 0.5 * (density.lower_bound + density.upper_bound)
        value_range = density.upper_bound - density.lower_bound
        target_array = transforms.apply_fixed_pixels(density).array
        target_array = (
            jnp.sign(target_array - mid_value) * 0.5 * value_range + mid_value
        )

        # Generate the initial amplitude array.
        amplitude = jnp.sign(density.array - mid_value)
        amplitude = transforms.resample(amplitude, shape)

        # If the density is not periodic, ensure there are level set control points
        # beyond the edge of the density array.
        pad_width = ((0, 0),) * (amplitude.ndim - 2)
        pad_width += ((0, 0),) if density.periodic[0] else ((1, 1),)
        pad_width += ((0, 0),) if density.periodic[1] else ((1, 1),)
        amplitude = jnp.pad(amplitude, pad_width, mode="edge")

        latents = GaussianLevelsetLatents(amplitude=amplitude)
        metadata = GaussianLevelsetMetadata(
            length_scale_spacing_factor=length_scale_spacing_factor,
            length_scale_fwhm_factor=length_scale_fwhm_factor,
            smoothing_factor=smoothing_factor,
            density_shape=density.shape,
            density_metadata=base.Density2DMetadata.from_density(density),
        )

        def step_fn(
            _: int,
            latents_and_state: Tuple[PyTree, PyTree],
        ) -> Tuple[PyTree, PyTree]:
            """One optimizer step fitting the latents to `target_array`."""

            def loss_fn(latents: GaussianLevelsetLatents) -> jnp.ndarray:
                params = GaussianLevelsetParams(latents, metadata=metadata)
                density_from_params = to_density_fn(params, mask_gradient=False)
                return jnp.mean((density_from_params.array - target_array) ** 2)

            step_latents, opt_state = latents_and_state
            grad = jax.grad(loss_fn)(step_latents)
            updates, opt_state = init_optimizer.update(
                grad, params=step_latents, state=opt_state
            )
            step_latents = optax.apply_updates(step_latents, updates)
            return step_latents, opt_state

        state = init_optimizer.init(latents)
        latents, _ = jax.lax.fori_loop(
            0, init_steps, body_fun=step_fn, init_val=(latents, state)
        )

        # Normalize so the largest absolute amplitude over the two trailing axes is 1.
        # An all-zero slice (e.g. `init_steps=0` with a density lying exactly at the
        # midpoint of its bounds) is left unscaled rather than becoming `nan`.
        maxval = jnp.amax(jnp.abs(latents.amplitude), axis=(-2, -1), keepdims=True)
        maxval = jnp.where(maxval == 0, 1.0, maxval)
        latents = dataclasses.replace(latents, amplitude=latents.amplitude / maxval)
        return GaussianLevelsetParams(latents=latents, metadata=metadata)

    def to_density_fn(
        params: GaussianLevelsetParams,
        mask_gradient: bool = True,
    ) -> types.Density2DArray:
        """Return a density from the latent parameters."""
        array = _to_array(params, mask_gradient=mask_gradient, pad_pixels=0)

        example_density = _example_density(params)
        lb = example_density.lower_bound
        ub = example_density.upper_bound
        # `array` lies in [0, 1]; map it onto the density bounds.
        array = lb + array * (ub - lb)
        assert array.shape == example_density.shape
        return dataclasses.replace(example_density, array=array)

    def constraints_fn(
        params: GaussianLevelsetParams,
        mask_gradient: bool = True,
        pad_pixels: int = 2,
    ) -> jnp.ndarray:
        """Computes constraints associated with the params."""
        return analytical_constraints(
            params=params,
            length_scale_constraint_factor=length_scale_constraint_factor,
            length_scale_constraint_beta=length_scale_constraint_beta,
            length_scale_constraint_weight=length_scale_constraint_weight,
            curvature_constraint_weight=curvature_constraint_weight,
            fixed_pixel_constraint_weight=fixed_pixel_constraint_weight,
            mask_gradient=mask_gradient,
            pad_pixels=pad_pixels,
        )

    def update_fn(
        params: GaussianLevelsetParams,
        updates: GaussianLevelsetParams,
        value: jnp.ndarray,
        step: int,
    ) -> GaussianLevelsetParams:
        """Perform updates to `params` required for the given `step`."""
        del step, value
        return GaussianLevelsetParams(
            latents=tree_util.tree_map(
                lambda a, b: a + b, params.latents, updates.latents
            ),
            metadata=params.metadata,
        )

    return base.Density2DParameterization(
        to_density=to_density_fn,
        from_density=from_density_fn,
        constraints=constraints_fn,
        update=update_fn,
    )
|
253
|
+
|
254
|
+
|
255
|
+
# -----------------------------------------------------------------------------
|
256
|
+
# Functions to obtain arrays from the levelset parameterization.
|
257
|
+
# -----------------------------------------------------------------------------
|
258
|
+
|
259
|
+
|
260
|
+
def _example_density(params: GaussianLevelsetParams) -> types.Density2DArray:
    """Builds an all-zeros density carrying the shape and metadata of `params`.

    The array is created inside `jax.ensure_compile_time_eval()`, so the zeros
    are evaluated eagerly even while tracing.
    """
    metadata = params.metadata
    with jax.ensure_compile_time_eval():
        zeros = jnp.zeros(metadata.density_shape)
        density_kwargs = dataclasses.asdict(metadata.density_metadata)
        return types.Density2DArray(array=zeros, **density_kwargs)
|
267
|
+
|
268
|
+
|
269
|
+
def _to_array(
    params: GaussianLevelsetParams,
    mask_gradient: bool,
    pad_pixels: int,
) -> jnp.ndarray:
    """Return an array from the parameters.

    The levelset function is thresholded to a value of `1` where it is positive
    and `0` elsewhere. The result is then downsampled by the smoothing factor;
    pixels straddling a feature border may take intermediate values if
    `transforms.box_downsample` averages (NOTE(review): confirm). The final
    density array is obtained by rescaling this array from `[0, 1]` to the
    density's lower and upper bounds.

    Args:
        params: The parameters from which the density is obtained.
        mask_gradient: If `True`, the gradient is masked so that it is nonzero only at
            the borders of features.
        pad_pixels: A non-negative integer giving the additional pixels to be included
            beyond the boundaries of the parameterized density.

    Returns:
        The array.
    """
    example_density = _example_density(params)
    periodic: Tuple[bool, bool] = example_density.periodic
    phi = _phi_from_params(params=params, pad_pixels=pad_pixels)
    array = _levelset_threshold(phi=phi, periodic=periodic, mask_gradient=mask_gradient)
    return _downsample_spatial_dims(array, params.metadata.smoothing_factor)
|
295
|
+
|
296
|
+
|
297
|
+
def _phi_from_params(
    params: GaussianLevelsetParams,
    pad_pixels: int,
) -> jnp.ndarray:
    """Return the levelset function for the given `params`.

    The levelset function is a sum of Gaussian radial basis functions centred on
    the control points and weighted by `params.latents.amplitude`. It is
    evaluated on a grid `smoothing_factor` times finer than the density grid,
    extended by `pad_pixels` on each side. Along periodic axes the control
    points and amplitudes are replicated three times, so basis functions wrap
    across the boundary. The result is finally symmetrized according to the
    density's symmetries.

    Args:
        params: The parameters from which the density is obtained.
        pad_pixels: A non-negative integer giving the additional pixels to be included
            beyond the boundaries of the parameterized density.

    Returns:
        The levelset array `phi`. Its leading axes are those of
        `params.latents.amplitude`, and its trailing axes have sizes
        `s * (h + 2 * pad_pixels)` and `s * (w + 2 * pad_pixels)`, where `s` is the
        smoothing factor and `(h, w)` the trailing density shape.
    """
    with jax.ensure_compile_time_eval():
        example_density = _example_density(params)
        length_scale = 0.5 * (
            example_density.minimum_width + example_density.minimum_spacing
        )
        fwhm = length_scale * params.metadata.length_scale_fwhm_factor
        # Standard deviation of a Gaussian exp(-r**2 / (2 * sigma**2)) with this FWHM.
        # NOTE(review): the basis below is exp(-r**2 / sigma**2), whose realized
        # FWHM is `fwhm / sqrt(2)`. Changing it would alter existing results;
        # confirm which is intended.
        sigma = fwhm / (2 * jnp.sqrt(2 * jnp.log(2)))

        # Pixel-centre coordinates of the high-resolution evaluation grid, in units
        # of density pixels, extended by `pad_pixels` on every side.
        s_factor = params.metadata.smoothing_factor
        highres_i = (
            0.5
            + jnp.arange(
                s_factor * (-pad_pixels),
                s_factor * (pad_pixels + example_density.shape[-2]),
            )
        ) / s_factor
        highres_j = (
            0.5
            + jnp.arange(
                s_factor * (-pad_pixels),
                s_factor * (pad_pixels + example_density.shape[-1]),
            )
        ) / s_factor

        # Coordinates for the control points of the Gaussian radial basis functions.
        levelset_i, levelset_j = _control_point_coords(
            density_shape=params.metadata.density_shape[-2:],  # type: ignore[arg-type]
            levelset_shape=(
                params.latents.amplitude.shape[-2:]  # type: ignore[arg-type]
            ),
            periodic=example_density.periodic,
        )

        # Handle periodicity by replicating control points over a 3x3 supercell.
        if example_density.periodic[0]:
            levelset_i = jnp.concatenate(
                [
                    levelset_i - example_density.shape[-2],
                    levelset_i,
                    levelset_i + example_density.shape[-2],
                ],
                axis=-2,
            )
            levelset_j = jnp.concatenate([levelset_j] * 3, axis=-2)
        if example_density.periodic[1]:
            levelset_i = jnp.concatenate([levelset_i] * 3, axis=-1)
            levelset_j = jnp.concatenate(
                [
                    levelset_j - example_density.shape[-1],
                    levelset_j,
                    levelset_j + example_density.shape[-1],
                ],
                axis=-1,
            )

        levelset_i = levelset_i.flatten()
        levelset_j = levelset_j.flatten()

    # Replicate the amplitudes exactly as the control points were replicated, using
    # the same `jnp.concatenate` call as above.
    amplitude = params.latents.amplitude
    if example_density.periodic[0]:
        amplitude = jnp.concatenate([amplitude] * 3, axis=-2)
    if example_density.periodic[1]:
        amplitude = jnp.concatenate([amplitude] * 3, axis=-1)

    # Flatten the control-point axes in the same order as `levelset_i/j`. The
    # singleton axis broadcasts against the `j` coordinates inside `scan_fn`.
    amplitude = amplitude.reshape(amplitude.shape[:-2] + (1, -1))

    # Use a scan over rows to compute the array; this lowers memory consumption
    # compared with materializing every (pixel, control point) distance at once.
    def scan_fn(_: Tuple[()], i: jnp.ndarray) -> Tuple[Tuple[()], jnp.ndarray]:
        distance_sq = (i - levelset_i) ** 2 + (
            highres_j[:, jnp.newaxis] - levelset_j
        ) ** 2
        basis = jnp.exp(-distance_sq / sigma**2)
        return (), jnp.sum(basis * amplitude, axis=-1)

    _, array = jax.lax.scan(scan_fn, (), xs=highres_i)
    # `scan` stacks rows along axis 0; move that axis to the second-to-last position.
    array = jnp.moveaxis(array, 0, -2)

    assert array.shape[-2] % s_factor == 0
    assert array.shape[-1] % s_factor == 0
    array = symmetry.symmetrize(array, tuple(example_density.symmetries))
    return array
|
392
|
+
|
393
|
+
|
394
|
+
# -----------------------------------------------------------------------------
|
395
|
+
# Functions to compute constraints.
|
396
|
+
# -----------------------------------------------------------------------------
|
397
|
+
|
398
|
+
|
399
|
+
def analytical_constraints(
    params: GaussianLevelsetParams,
    length_scale_constraint_factor: float,
    length_scale_constraint_beta: float,
    length_scale_constraint_weight: float,
    curvature_constraint_weight: float,
    fixed_pixel_constraint_weight: float,
    mask_gradient: bool,
    pad_pixels: int,
) -> jnp.ndarray:
    """Computes analytical levelset constraints associated with the params.

    The length scale, curvature and fixed-pixel constraint arrays are each
    multiplied by their weight. They are stacked along a new trailing axis (in
    that order) and divided by the squared minimum length scale.

    Returns:
        The stacked, weighted, normalized constraint array.
    """
    length_scale_term, curvature_term = _levelset_constraints(
        params,
        beta=length_scale_constraint_beta,
        length_scale_constraint_factor=length_scale_constraint_factor,
        pad_pixels=pad_pixels,
    )
    fixed_pixel_term = _fixed_pixel_constraint(
        params, mask_gradient=mask_gradient, pad_pixels=pad_pixels
    )

    weighted_terms = [
        length_scale_term * length_scale_constraint_weight,
        curvature_term * curvature_constraint_weight,
        fixed_pixel_term * fixed_pixel_constraint_weight,
    ]
    stacked = jnp.stack(weighted_terms, axis=-1)

    # Dividing by the squared length scale makes the constraints (somewhat)
    # resolution-independent.
    density = _example_density(params)
    mean_length_scale = 0.5 * (density.minimum_spacing + density.minimum_width)
    return stacked / mean_length_scale**2
|
437
|
+
|
438
|
+
|
439
|
+
def _fixed_pixel_constraint(
    params: GaussianLevelsetParams,
    mask_gradient: bool,
    pad_pixels: int,
) -> jnp.ndarray:
    """Return the fixed pixel constraint array.

    At fixed pixels the value is the absolute difference between the array
    obtained from `params` and the target: `1` for fixed-solid pixels and `0`
    for fixed-void pixels. At all other pixels it is zero. The fixed masks are
    edge-padded by `pad_pixels` so they align with the padded array.

    Args:
        params: The parameters from which the density is obtained.
        mask_gradient: If `True`, the gradient is masked so that it is nonzero only at
            the borders of features.
        pad_pixels: The number of pixels added at borders. Values greater than zero
            help to ensure that sharp features at the borders are avoided.

    Returns:
        The constraints array.
    """
    array = _to_array(params, mask_gradient=mask_gradient, pad_pixels=pad_pixels)

    density = _example_density(params)
    spatial_shape = density.shape[-2:]

    def _padded_mask(fixed: Any) -> jnp.ndarray:
        # A missing mask means "no pixels fixed"; pad only the two trailing axes.
        if fixed is None:
            mask = jnp.zeros(spatial_shape, dtype=bool)
        else:
            mask = jnp.asarray(fixed)
        pad_width = ((0, 0),) * (mask.ndim - 2) + ((pad_pixels, pad_pixels),) * 2
        return jnp.pad(mask, pad_width, mode="edge")

    fixed_solid = _padded_mask(density.fixed_solid)
    fixed_void = _padded_mask(density.fixed_void)

    target = jnp.where(fixed_solid, 1, 0)
    mismatch = jnp.abs(array - target)
    return jnp.where(fixed_solid | fixed_void, mismatch, 0.0)
|
483
|
+
|
484
|
+
|
485
|
+
def _levelset_constraints(
    params: GaussianLevelsetParams,
    beta: float,
    length_scale_constraint_factor: float,
    pad_pixels: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute constraints for minimum width, spacing, and radius of curvature.

    The constraints are based on "Analytical level set fabrication constraints for
    inverse design," by D. Vercruysse et al. (2019). Constraints are satisfied when
    they are non-positive.

    https://www.nature.com/articles/s41598-019-45026-0

    Args:
        params: The parameters of the Gaussian levelset.
        beta: Parameter which relaxes the constraint near the zero-plane.
        length_scale_constraint_factor: Multiplies the target length scale in the
            levelset constraints. A value greater than 1 is pessimistic and drives the
            solution to have a larger length scale (relative to smaller values).
        pad_pixels: A non-negative integer giving the additional pixels to be included
            beyond the boundaries of the parameterized density.

    Returns:
        The minimum length scale and minimum curvature constraint arrays, each
        downsampled by the smoothing factor.
    """
    example_density = _example_density(params)
    minimum_length_scale = 0.5 * (
        example_density.minimum_width + example_density.minimum_spacing
    )

    phi, phi_v, phi_vv, inverse_radius = _phi_derivatives_and_inverse_radius(
        params,
        pad_pixels=pad_pixels,
    )

    # Target length scale, scaled by `length_scale_constraint_factor`.
    d = minimum_length_scale * length_scale_constraint_factor
    denom = jnp.pi / d * jnp.abs(phi) + beta * phi_v
    # Where `phi_vv` is ~0 the numerator vanishes, so the denominator is replaced
    # by 1 to avoid 0/0.
    # NOTE(review): `denom` can still be zero (phi == 0 and phi_v == 0) where
    # `phi_vv` is not close to zero, giving an infinite value; confirm intent.
    denom_safe = jnp.where(jnp.isclose(phi_vv, 0.0), 1.0, denom)
    length_scale_constraint = jnp.abs(phi_vv) / denom_safe - jnp.pi / d

    # NOTE(review): this guard tests `phi_v`, but the divisor is `phi`. Where `phi`
    # is exactly zero and `phi_v` is not, the quotient is infinite: the forward
    # value is finite (arctan -> +/- pi/2), but the gradient may be `nan`.
    # Confirm whether the guard should test `phi` instead.
    curvature_denom_safe = jnp.where(jnp.isclose(phi_v, 0.0), 1.0, phi)
    curvature_constraint = (
        jnp.abs(inverse_radius * jnp.arctan(phi_v / curvature_denom_safe)) - jnp.pi / d
    )

    # Downsample so that constraints shape matches the density shape.
    factor = params.metadata.smoothing_factor
    return (
        _downsample_spatial_dims(length_scale_constraint, factor),
        _downsample_spatial_dims(curvature_constraint, factor),
    )
|
537
|
+
|
538
|
+
|
539
|
+
def _phi_derivatives_and_inverse_radius(
    params: GaussianLevelsetParams,
    pad_pixels: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Compute the levelset function and its first and second derivatives.

    Derivatives use `jnp.gradient` with grid spacing `1 / smoothing_factor`.

    Returns:
        `(phi, phi_v, phi_vv, inverse_radius)`:
        - `phi_v` is the gradient magnitude `|grad phi|`.
        - `phi_vv` is the second derivative along the gradient direction,
          `n^T H n` with `n = grad phi / |grad phi|` and `H` the Hessian.
        - `inverse_radius` is the divergence of the normalized gradient.
        Where `|grad phi|` is ~0, the normalizing factors are replaced by 1.
    """
    phi = _phi_from_params(params=params, pad_pixels=pad_pixels)

    # Spacing of the high-resolution grid, in units of density pixels.
    spacing = 1 / params.metadata.smoothing_factor
    phi_x, phi_y = jnp.gradient(phi, spacing, axis=(-2, -1))
    phi_xx, phi_yx = jnp.gradient(phi_x, spacing, axis=(-2, -1))
    phi_xy, phi_yy = jnp.gradient(phi_y, spacing, axis=(-2, -1))

    phi_v = _sqrt_safe(phi_x**2 + phi_y**2)

    # Normalizers equal to 1 wherever `phi_v` is ~0, and to `phi_v` (or its
    # square) elsewhere. `phi_v` is non-negative, so no absolute value is needed.
    degenerate = jnp.isclose(phi_v, 0.0)
    norm_sq = jnp.where(degenerate, 1.0, phi_v**2)
    norm = jnp.where(degenerate, 1.0, phi_v)

    phi_vv = (
        phi_x**2 / norm_sq * phi_xx
        + (phi_x * phi_y) / norm_sq * (phi_xy + phi_yx)
        + phi_y**2 / norm_sq * phi_yy
    )

    # Curvature as the divergence of the unit normal field.
    d_nx_dx: jnp.ndarray = jnp.gradient(  # type: ignore[assignment]
        phi_x / norm, spacing, axis=-2
    )
    d_ny_dy: jnp.ndarray = jnp.gradient(  # type: ignore[assignment]
        phi_y / norm, spacing, axis=-1
    )
    inverse_radius = d_nx_dx + d_ny_dy

    return phi, phi_v, phi_vv, inverse_radius
|
578
|
+
|
579
|
+
|
580
|
+
# -----------------------------------------------------------------------------
|
581
|
+
# Helper functions.
|
582
|
+
# -----------------------------------------------------------------------------
|
583
|
+
|
584
|
+
|
585
|
+
def _control_point_coords(
    density_shape: Tuple[int, int],
    levelset_shape: Tuple[int, int],
    periodic: Tuple[bool, bool],
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Returns the control point coordinates.

    Coordinates are in units of density pixels; pixel `k` has its centre at
    `k + 0.5`. Along each axis with `n` control points and density size `D`:

    - Periodic axis: the `n` points are evenly spaced at `(k + 0.5) * D / n`,
      all within the density.
    - Non-periodic axis: the levelset carries one extra point beyond each edge.
      The points are at `(k - 0.5) * D / (n - 2)` for `k = 0 .. n - 1`, so the
      first and last lie outside the density.

    Returns:
        Arrays `(levelset_i, levelset_j)` of shape `levelset_shape[-2:]`
        (`"ij"` meshgrid indexing).
    """

    def axis_coords(num_density: int, num_points: int, is_periodic: bool) -> jnp.ndarray:
        # If the axis is NOT periodic, the first and last control points along it
        # lie outside the bounds of the density.
        offset = 0.5 if is_periodic else -0.5
        num_interior = num_points - (0 if is_periodic else 2)
        factor = num_density / num_interior
        return (offset + jnp.arange(num_points)) * factor

    levelset_i, levelset_j = jnp.meshgrid(
        axis_coords(density_shape[-2], levelset_shape[-2], periodic[0]),
        axis_coords(density_shape[-1], levelset_shape[-1], periodic[1]),
        indexing="ij",
    )
    return levelset_i, levelset_j
|
606
|
+
|
607
|
+
|
608
|
+
def _sqrt_safe(x: jnp.ndarray) -> jnp.ndarray:
    """Square root that returns 0 (with zero gradient) wherever `x` is ~0.

    The sqrt is evaluated on a substitute value of 1 at near-zero entries. This
    keeps its divergent derivative at 0 out of the backward pass; otherwise it
    would leak `nan` through the unselected branch of `jnp.where`.
    """
    near_zero = jnp.isclose(x, 0.0)
    root = jnp.sqrt(jnp.where(near_zero, 1, x))
    return jnp.where(near_zero, 0.0, root)
|
613
|
+
|
614
|
+
|
615
|
+
def _levelset_threshold(
    phi: jnp.ndarray,
    periodic: Tuple[bool, bool],
    mask_gradient: bool,
) -> jnp.ndarray:
    """Thresholds a level set function `phi`.

    The forward value is `1.0` where `phi > 0` and `0.0` elsewhere. Gradients
    pass straight through as if the operation were the identity (a
    straight-through estimator). With `mask_gradient`, they pass only at
    interface pixels.
    """
    if mask_gradient:
        on_interface = _interface_pixels(phi, periodic)
        phi = jnp.where(on_interface, phi, jax.lax.stop_gradient(phi))
    # Zero in the forward pass; carries d(phi) in the backward pass.
    straight_through = phi - jax.lax.stop_gradient(phi)
    return (phi > 0).astype(float) + straight_through
|
626
|
+
|
627
|
+
|
628
|
+
def _interface_pixels(phi: jnp.ndarray, periodic: Tuple[bool, bool]) -> jnp.ndarray:
    """Identifies interface pixels of a level set function `phi`.

    A pixel is on the interface if it is solid (`phi > 0`) and has at least one
    void 4-neighbour, or is void and has at least one solid 4-neighbour.
    Neighbours across the border are taken with "wrap" padding on periodic axes
    and "edge" padding otherwise.
    """
    batch_shape = phi.shape[:-2]
    flat_phi = phi.reshape((-1,) + phi.shape[-2:])

    modes = (
        "wrap" if periodic[0] else "edge",
        "wrap" if periodic[1] else "edge",
    )
    # Counts the up, down, left and right neighbours of each pixel.
    kernel = jnp.asarray([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)

    def count_neighbors(mask: jnp.ndarray) -> jnp.ndarray:
        padded = transforms.pad2d(mask, ((1, 1), (1, 1)), modes)
        counts = transforms.conv(
            x=padded[:, jnp.newaxis, :, :].astype(float),
            kernel=kernel[jnp.newaxis, jnp.newaxis, :, :],
            padding="VALID",
        )
        return jnp.squeeze(counts, axis=1)

    solid = flat_phi > 0
    void = ~solid
    touches_solid = count_neighbors(solid) > 0
    touches_void = count_neighbors(void) > 0

    interface = (solid & touches_void) | (void & touches_solid)
    return interface.reshape(batch_shape + interface.shape[-2:])
|
663
|
+
|
664
|
+
|
665
|
+
def _downsample_spatial_dims(x: jnp.ndarray, downsample_factor: int) -> jnp.ndarray:
    """Shrinks the two trailing axes of `x` by `downsample_factor` (floor division).

    Leading axes are unchanged; the reduction is done by `transforms.box_downsample`.
    """
    height = x.shape[-2]
    width = x.shape[-1]
    target_shape = (*x.shape[:-2], height // downsample_factor, width // downsample_factor)
    return transforms.box_downsample(x, target_shape)
|