invrs-opt 0.6.0__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,300 @@
1
+ """Defines a wrapper for optax optimizers.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
+ from typing import Any, Optional, Tuple
7
+
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import optax # type: ignore[import-untyped]
11
+ from jax import tree_util
12
+ from totypes import types
13
+
14
+ from invrs_opt.optimizers import base
15
+ from invrs_opt.parameterization import (
16
+ base as parameterization_base,
17
+ filter_project,
18
+ gaussian_levelset,
19
+ pixel,
20
+ )
21
+
22
+ PyTree = Any
23
+ WrappedOptaxState = Tuple[PyTree, PyTree, PyTree]
24
+
25
+
26
+ def wrapped_optax(opt: optax.GradientTransformation) -> base.Optimizer:
27
+ """Return a wrapped optax optimizer."""
28
+ return parameterized_wrapped_optax(
29
+ opt=opt, penalty=0.0, density_parameterization=None
30
+ )
31
+
32
+
33
+ def density_wrapped_optax(
34
+ opt: optax.GradientTransformation,
35
+ *,
36
+ beta: float,
37
+ ) -> base.Optimizer:
38
+ """Wrapped optax optimizer with filter-project density parameterization.
39
+
40
+ In the filter-project density parameterization, the optimization variable
41
+ associated with a density array is a latent density array; the density is obtained
42
+ by convolving (i.e. "filtering") the latent density with a Gaussian kernel having
43
+ full-width at half-maximum equal to the length scale (the mean of declared minimum
44
+ width and minimum spacing). Then, a tanh nonlinearity is used as a smooth threshold
45
+ operation ("projection").
46
+
47
+ Args:
48
+ opt: The optax optimizer to be wrapped.
49
+ beta: Determines the sharpness of the thresholding operation.
50
+
51
+ Returns:
52
+ The wrapped optax optimizer.
53
+ """
54
+ return parameterized_wrapped_optax(
55
+ opt=opt,
56
+ penalty=0.0,
57
+ density_parameterization=filter_project.filter_project(beta=beta),
58
+ )
59
+
60
+
61
+ def levelset_wrapped_optax(
62
+ opt: optax.GradientTransformation,
63
+ *,
64
+ penalty: float,
65
+ length_scale_spacing_factor: float = (
66
+ gaussian_levelset.DEFAULT_LENGTH_SCALE_SPACING_FACTOR
67
+ ),
68
+ length_scale_fwhm_factor: float = (
69
+ gaussian_levelset.DEFAULT_LENGTH_SCALE_FWHM_FACTOR
70
+ ),
71
+ length_scale_constraint_factor: float = (
72
+ gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR
73
+ ),
74
+ smoothing_factor: int = gaussian_levelset.DEFAULT_SMOOTHING_FACTOR,
75
+ length_scale_constraint_beta: float = (
76
+ gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA
77
+ ),
78
+ length_scale_constraint_weight: float = (
79
+ gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT
80
+ ),
81
+ curvature_constraint_weight: float = (
82
+ gaussian_levelset.DEFAULT_CURVATURE_CONSTRAINT_WEIGHT
83
+ ),
84
+ fixed_pixel_constraint_weight: float = (
85
+ gaussian_levelset.DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT
86
+ ),
87
+ init_optimizer: optax.GradientTransformation = (
88
+ gaussian_levelset.DEFAULT_INIT_OPTIMIZER
89
+ ),
90
+ init_steps: int = gaussian_levelset.DEFAULT_INIT_STEPS,
91
+ ) -> base.Optimizer:
92
+ """Wrapped optax optimizer with levelset density parameterization.
93
+
94
+ In the levelset parameterization, the optimization variable associated with a
95
+ density array is an array giving the amplitudes of Gaussian radial basis functions
96
+ that represent a levelset function over the domain of the density. In the levelset
97
+ parameterization, gradients are nonzero only at the edges of features, and in
98
+ general the topology of a solution does not change during the course of
99
+ optimization.
100
+
101
+ The spacing and full-width at half-maximum of the Gaussian basis functions gives
102
+ some amount of control over length scales. In addition, constraints associated with
103
+ length scale, radius of curvature, and deviation from fixed pixels are
104
+ automatically computed and penalized with a weight given by `penalty`. In general,
105
+ this helps ensure that features in an optimized density array violate the specified
106
+ constraints to a lesser degree. The constraints are based on "Analytical level set
107
+ fabrication constraints for inverse design," by D. Vercruysse et al. (2019).
108
+
109
+ Args:
110
+ opt: The optax optimizer to be wrapped.
111
+ penalty: The weight of the fabrication penalty, which combines length scale,
112
+ curvature, and fixed pixel constraints.
113
+ length_scale_spacing_factor: The number of levelset control points per unit of
114
+ minimum length scale (mean of density minimum width and minimum spacing).
115
+ length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum to
116
+ the minimum length scale.
117
+ length_scale_constraint_factor: Multiplies the target length scale in the
118
+ levelset constraints. A value greater than 1 is pessimistic and drives the
119
+ solution to have a larger length scale (relative to smaller values).
120
+ smoothing_factor: For values greater than 1, the density is initially computed
121
+ at higher resolution and then downsampled, yielding smoother geometries.
122
+ length_scale_constraint_beta: Controls relaxation of the length scale
123
+ constraint near the zero level.
124
+ length_scale_constraint_weight: The weight of the length scale constraint in
125
+ the overall fabrication constraint peenalty.
126
+ curvature_constraint_weight: The weight of the curvature constraint.
127
+ fixed_pixel_constraint_weight: The weight of the fixed pixel constraint.
128
+ init_optimizer: The optimizer used in the initialization of the levelset
129
+ parameterization. At initialization, the latent parameters are optimized so
130
+ that the initial parameters match the binarized initial density.
131
+ init_steps: The number of optimization steps used in the initialization.
132
+
133
+ Returns:
134
+ The wrapped optax optimizer.
135
+ """
136
+ return parameterized_wrapped_optax(
137
+ opt=opt,
138
+ penalty=penalty,
139
+ density_parameterization=gaussian_levelset.gaussian_levelset(
140
+ length_scale_spacing_factor=length_scale_spacing_factor,
141
+ length_scale_fwhm_factor=length_scale_fwhm_factor,
142
+ length_scale_constraint_factor=length_scale_constraint_factor,
143
+ smoothing_factor=smoothing_factor,
144
+ length_scale_constraint_beta=length_scale_constraint_beta,
145
+ length_scale_constraint_weight=length_scale_constraint_weight,
146
+ curvature_constraint_weight=curvature_constraint_weight,
147
+ fixed_pixel_constraint_weight=fixed_pixel_constraint_weight,
148
+ init_optimizer=init_optimizer,
149
+ init_steps=init_steps,
150
+ ),
151
+ )
152
+
153
+
154
+ # -----------------------------------------------------------------------------
155
+ # Base parameterized wrapped optax optimizer.
156
+ # -----------------------------------------------------------------------------
157
+
158
+
159
+ def parameterized_wrapped_optax(
160
+ opt: optax.GradientTransformation,
161
+ density_parameterization: Optional[parameterization_base.Density2DParameterization],
162
+ penalty: float,
163
+ ) -> base.Optimizer:
164
+ """Wrapped optax optimizer with specified density parameterization.
165
+
166
+ Args:
167
+ opt: The optax `GradientTransformation` to be wrapped.
168
+ density_parameterization: The parameterization to be used, or `None`. When no
169
+ parameterization is given, the direct pixel parameterization is used for
170
+ density arrays.
171
+ penalty: The weight of the scalar penalty formed from the constraints of the
172
+ parameterization.
173
+
174
+ Returns:
175
+ The `base.Optimizer`.
176
+ """
177
+
178
+ if density_parameterization is None:
179
+ density_parameterization = pixel.pixel()
180
+
181
+ def _init_latents(params: PyTree) -> PyTree:
182
+ def _leaf_init_latents(leaf: Any) -> Any:
183
+ leaf = _clip(leaf)
184
+ if not _is_density(leaf):
185
+ return leaf
186
+ return density_parameterization.from_density(leaf)
187
+
188
+ return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
189
+
190
+ def _params_from_latents(params: PyTree) -> PyTree:
191
+ def _leaf_params_from_latents(leaf: Any) -> Any:
192
+ if not _is_parameterized_density(leaf):
193
+ return leaf
194
+ return density_parameterization.to_density(leaf)
195
+
196
+ return tree_util.tree_map(
197
+ _leaf_params_from_latents,
198
+ params,
199
+ is_leaf=_is_parameterized_density,
200
+ )
201
+
202
+ def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
203
+ def _constraint_loss_leaf(
204
+ params: parameterization_base.ParameterizedDensity2DArrayBase,
205
+ ) -> jnp.ndarray:
206
+ constraints = density_parameterization.constraints(params)
207
+ constraints = tree_util.tree_map(
208
+ lambda x: jnp.sum(jnp.maximum(x, 0.0)),
209
+ constraints,
210
+ )
211
+ return jnp.sum(jnp.asarray(constraints))
212
+
213
+ losses = [0.0] + [
214
+ _constraint_loss_leaf(p)
215
+ for p in tree_util.tree_leaves(
216
+ latent_params, is_leaf=_is_parameterized_density
217
+ )
218
+ if _is_parameterized_density(p)
219
+ ]
220
+ return penalty * jnp.sum(jnp.asarray(losses))
221
+
222
+ def init_fn(params: PyTree) -> WrappedOptaxState:
223
+ """Initializes the optimization state."""
224
+ latent_params = _init_latents(params)
225
+ params = _params_from_latents(latent_params)
226
+ return params, latent_params, opt.init(latent_params)
227
+
228
+ def params_fn(state: WrappedOptaxState) -> PyTree:
229
+ """Returns the parameters for the given `state`."""
230
+ params, _, _ = state
231
+ return params
232
+
233
+ def update_fn(
234
+ *,
235
+ grad: PyTree,
236
+ value: float,
237
+ params: PyTree,
238
+ state: WrappedOptaxState,
239
+ ) -> WrappedOptaxState:
240
+ """Updates the state."""
241
+ del value, params
242
+
243
+ _, latent_params, opt_state = state
244
+ _, vjp_fn = jax.vjp(_params_from_latents, latent_params)
245
+ (latent_grad,) = vjp_fn(grad)
246
+
247
+ if not (
248
+ tree_util.tree_structure(latent_grad)
249
+ == tree_util.tree_structure(latent_params) # type: ignore[operator]
250
+ ):
251
+ raise ValueError(
252
+ f"Tree structure of `latent_grad` was different than expected, got \n"
253
+ f"{tree_util.tree_structure(latent_grad)} but expected \n"
254
+ f"{tree_util.tree_structure(latent_params)}."
255
+ )
256
+
257
+ constraint_loss_grad = jax.grad(_constraint_loss)(latent_params)
258
+ latent_grad = tree_util.tree_map(
259
+ lambda a, b: a + b, latent_grad, constraint_loss_grad
260
+ )
261
+
262
+ updates, opt_state = opt.update(
263
+ updates=latent_grad, state=opt_state, params=latent_params
264
+ )
265
+ latent_params = optax.apply_updates(params=latent_params, updates=updates)
266
+ latent_params = _clip(latent_params)
267
+ params = _params_from_latents(latent_params)
268
+ return params, latent_params, opt_state
269
+
270
+ return base.Optimizer(init=init_fn, params=params_fn, update=update_fn)
271
+
272
+
273
+ def _is_density(leaf: Any) -> Any:
274
+ """Return `True` if `leaf` is a density array."""
275
+ return isinstance(leaf, types.Density2DArray)
276
+
277
+
278
+ def _is_parameterized_density(leaf: Any) -> Any:
279
+ """Return `True` if `leaf` is a parameterized density array."""
280
+ return isinstance(leaf, parameterization_base.ParameterizedDensity2DArrayBase)
281
+
282
+
283
+ def _is_custom_type(leaf: Any) -> bool:
284
+ """Return `True` if `leaf` is a recognized custom type."""
285
+ return isinstance(leaf, (types.BoundedArray, types.Density2DArray))
286
+
287
+
288
+ def _clip(pytree: PyTree) -> PyTree:
289
+ """Clips leaves on `pytree` to their bounds."""
290
+
291
+ def _clip_fn(leaf: Any) -> Any:
292
+ if not _is_custom_type(leaf):
293
+ return leaf
294
+ if leaf.lower_bound is None and leaf.upper_bound is None:
295
+ return leaf
296
+ return tree_util.tree_map(
297
+ lambda x: jnp.clip(x, leaf.lower_bound, leaf.upper_bound), leaf
298
+ )
299
+
300
+ return tree_util.tree_map(_clip_fn, pytree, is_leaf=_is_custom_type)
@@ -0,0 +1,148 @@
1
+ """Base types for density parameterizations.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
+ import dataclasses
7
+ from typing import Any, Optional, Protocol, Sequence, Tuple
8
+
9
+ import jax.numpy as jnp
10
+ import numpy as onp
11
+ from jax import tree_util
12
+ from totypes import json_utils, types
13
+
14
+ Array = jnp.ndarray | onp.ndarray[Any, Any]
15
+ PyTree = Any
16
+
17
+
18
+ class ParameterizedDensity2DArrayBase:
19
+ """Base class for parameterized density arrays."""
20
+
21
+ pass
22
+
23
+
24
+ class FromDensityFn(Protocol):
25
+ """Generate the latent representation of a density array."""
26
+
27
+ def __call__(
28
+ self, density: types.Density2DArray
29
+ ) -> ParameterizedDensity2DArrayBase:
30
+ ...
31
+
32
+
33
+ class ToDensityFn(Protocol):
34
+ """Generate a density from its latent representation."""
35
+
36
+ def __call__(self, params: PyTree) -> types.Density2DArray:
37
+ ...
38
+
39
+
40
+ class ConstraintsFn(Protocol):
41
+ """Compute constraints for a latent representation of a density array."""
42
+
43
+ def __call__(self, params: PyTree) -> jnp.ndarray:
44
+ ...
45
+
46
+
47
+ @dataclasses.dataclass
48
+ class Density2DParameterization:
49
+ """Stores `(from_density, to_density, constraints)` function triple."""
50
+
51
+ from_density: FromDensityFn
52
+ to_density: ToDensityFn
53
+ constraints: ConstraintsFn
54
+
55
+
56
+ @dataclasses.dataclass
57
+ class Density2DMetadata:
58
+ """Stores the metadata of a `Density2DArray`."""
59
+
60
+ lower_bound: float
61
+ upper_bound: float
62
+ fixed_solid: Optional[Array]
63
+ fixed_void: Optional[Array]
64
+ minimum_width: int
65
+ minimum_spacing: int
66
+ periodic: Sequence[bool]
67
+ symmetries: Sequence[str]
68
+
69
+ def __post_init__(self) -> None:
70
+ self.periodic = tuple(self.periodic)
71
+ self.symmetries = tuple(self.symmetries)
72
+
73
+
74
+ def _flatten_density_2d_metadata(
75
+ metadata: Density2DMetadata,
76
+ ) -> Tuple[
77
+ Tuple[()],
78
+ Tuple[
79
+ float,
80
+ float,
81
+ types.HashableWrapper,
82
+ types.HashableWrapper,
83
+ int,
84
+ int,
85
+ Sequence[bool],
86
+ Sequence[str],
87
+ ],
88
+ ]:
89
+ """Flattens a `Density2DMetadata` into children and auxilliary data."""
90
+ return (
91
+ (),
92
+ (
93
+ metadata.lower_bound,
94
+ metadata.upper_bound,
95
+ types.HashableWrapper(metadata.fixed_solid),
96
+ types.HashableWrapper(metadata.fixed_void),
97
+ metadata.minimum_width,
98
+ metadata.minimum_spacing,
99
+ metadata.periodic,
100
+ metadata.symmetries,
101
+ ),
102
+ )
103
+
104
+
105
+ def _unflatten_density_2d_metadata(
106
+ aux: Tuple[
107
+ float,
108
+ float,
109
+ types.HashableWrapper,
110
+ types.HashableWrapper,
111
+ int,
112
+ int,
113
+ Sequence[bool],
114
+ Sequence[str],
115
+ ],
116
+ children: Tuple[()],
117
+ ) -> Density2DMetadata:
118
+ """Unflattens a flattened `Density2DMetadata`."""
119
+ del children
120
+ (
121
+ lower_bound,
122
+ upper_bound,
123
+ wrapped_fixed_solid,
124
+ wrapped_fixed_void,
125
+ minimum_width,
126
+ minimum_spacing,
127
+ periodic,
128
+ symmetries,
129
+ ) = aux
130
+ return Density2DMetadata(
131
+ lower_bound=lower_bound,
132
+ upper_bound=upper_bound,
133
+ fixed_solid=wrapped_fixed_solid.array, # type: ignore[arg-type]
134
+ fixed_void=wrapped_fixed_void.array, # type: ignore[arg-type]
135
+ minimum_width=minimum_width,
136
+ minimum_spacing=minimum_spacing,
137
+ periodic=tuple(periodic),
138
+ symmetries=tuple(symmetries),
139
+ )
140
+
141
+
142
+ tree_util.register_pytree_node(
143
+ Density2DMetadata,
144
+ flatten_func=_flatten_density_2d_metadata,
145
+ unflatten_func=_unflatten_density_2d_metadata,
146
+ )
147
+
148
+ json_utils.register_custom_type(Density2DMetadata)
@@ -0,0 +1,92 @@
1
+ """Defines filter-and-project density parameterization.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
+ import dataclasses
7
+
8
+ import jax.numpy as jnp
9
+ from jax import tree_util
10
+ from totypes import json_utils, types
11
+
12
+ from invrs_opt.parameterization import base, transforms
13
+
14
+
15
+ @dataclasses.dataclass
16
+ class FilterAndProjectParams(base.ParameterizedDensity2DArrayBase):
17
+ """Stores the latent parameters of the pixel parameterization.
18
+
19
+ Attributes:
20
+ latent_density: The latent variable from which the density is obtained.
21
+ beta: Determines the sharpness of the thresholding operation.
22
+ """
23
+
24
+ latent_density: types.Density2DArray
25
+ beta: float
26
+
27
+
28
+ tree_util.register_dataclass(
29
+ FilterAndProjectParams,
30
+ data_fields=["latent_density"],
31
+ meta_fields=["beta"],
32
+ )
33
+
34
+ json_utils.register_custom_type(FilterAndProjectParams)
35
+
36
+
37
+ def filter_project(beta: float) -> base.Density2DParameterization:
38
+ """Defines a filter-project parameterization for density arrays.
39
+
40
+ The `DensityArray2D` is represented as latent density array that is transformed by,
41
+
42
+ transformed = tanh(beta * conv(density.array, gaussian_kernel)) / tanh(beta)
43
+
44
+ where the kernel has a full-width at half-maximum determined by the minimum width
45
+ and spacing parameters of the `DensityArray2D`.
46
+
47
+ When the density lower and upper bounds are -1 and +1, this basic expression is
48
+ Where the bounds differ, the density is scaled before the transform is applied, and
49
+ then unscaled afterwards.
50
+
51
+ Args:
52
+ beta: Determines the sharpness of the thresholding operation.
53
+
54
+ Returns:
55
+ The `Density2DParameterization`.
56
+ """
57
+
58
+ def from_density_fn(density: types.Density2DArray) -> FilterAndProjectParams:
59
+ """Return latent parameters for the given `density`."""
60
+ array = transforms.normalized_array_from_density(density)
61
+ array = jnp.clip(array, -1, 1)
62
+ array *= jnp.tanh(beta)
63
+ latent_array = jnp.arctanh(array) / beta
64
+ latent_array = transforms.rescale_array_for_density(latent_array, density)
65
+ return FilterAndProjectParams(
66
+ latent_density=dataclasses.replace(density, array=latent_array),
67
+ beta=beta,
68
+ )
69
+
70
+ def to_density_fn(params: FilterAndProjectParams) -> types.Density2DArray:
71
+ """Return a density from the latent parameters."""
72
+ transformed = types.symmetrize_density(params.latent_density)
73
+ transformed = transforms.density_gaussian_filter_and_tanh(
74
+ transformed, beta=params.beta
75
+ )
76
+ # Scale to ensure that the full valid range of the density array is reachable.
77
+ mid_value = (transformed.lower_bound + transformed.upper_bound) / 2
78
+ transformed = tree_util.tree_map(
79
+ lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed
80
+ )
81
+ return transforms.apply_fixed_pixels(transformed)
82
+
83
+ def constraints_fn(params: FilterAndProjectParams) -> jnp.ndarray:
84
+ """Computes constraints associated with the params."""
85
+ del params
86
+ return jnp.asarray(0.0)
87
+
88
+ return base.Density2DParameterization(
89
+ to_density=to_density_fn,
90
+ from_density=from_density_fn,
91
+ constraints=constraints_fn,
92
+ )