invrs-opt 0.4.0__py3-none-any.whl → 0.10.3__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- invrs_opt/__init__.py +14 -3
- invrs_opt/experimental/client.py +7 -4
- invrs_opt/{base.py → optimizers/base.py} +16 -1
- invrs_opt/optimizers/lbfgsb.py +939 -0
- invrs_opt/optimizers/wrapped_optax.py +347 -0
- invrs_opt/parameterization/__init__.py +0 -0
- invrs_opt/parameterization/base.py +208 -0
- invrs_opt/parameterization/filter_project.py +138 -0
- invrs_opt/parameterization/gaussian_levelset.py +671 -0
- invrs_opt/parameterization/pixel.py +75 -0
- invrs_opt/{lbfgsb/transform.py → parameterization/transforms.py} +76 -11
- invrs_opt-0.10.3.dist-info/LICENSE +504 -0
- invrs_opt-0.10.3.dist-info/METADATA +560 -0
- invrs_opt-0.10.3.dist-info/RECORD +20 -0
- {invrs_opt-0.4.0.dist-info → invrs_opt-0.10.3.dist-info}/WHEEL +1 -1
- invrs_opt/lbfgsb/lbfgsb.py +0 -672
- invrs_opt-0.4.0.dist-info/LICENSE +0 -21
- invrs_opt-0.4.0.dist-info/METADATA +0 -75
- invrs_opt-0.4.0.dist-info/RECORD +0 -14
- /invrs_opt/{lbfgsb → optimizers}/__init__.py +0 -0
- {invrs_opt-0.4.0.dist-info → invrs_opt-0.10.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,347 @@
|
|
1
|
+
"""Defines a wrapper for optax optimizers.
|
2
|
+
|
3
|
+
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
+
"""
|
5
|
+
|
6
|
+
from typing import Any, Optional, Tuple
|
7
|
+
|
8
|
+
import jax
|
9
|
+
import jax.numpy as jnp
|
10
|
+
import optax # type: ignore[import-untyped]
|
11
|
+
from jax import tree_util
|
12
|
+
from totypes import types
|
13
|
+
|
14
|
+
from invrs_opt.optimizers import base
|
15
|
+
from invrs_opt.parameterization import (
|
16
|
+
base as param_base,
|
17
|
+
filter_project,
|
18
|
+
gaussian_levelset,
|
19
|
+
pixel,
|
20
|
+
)
|
21
|
+
|
22
|
+
# Any JAX pytree; pytrees are structurally typed, so no narrower type is used.
PyTree = Any

# Optimizer state, laid out as the tuple returned by the `init_fn` and
# `update_fn` closures of `parameterized_wrapped_optax`:
#   (step, params, latent_params, opt_state)
# where `latent_params` holds parameterized densities (latents + metadata) and
# `opt_state` is the state of the wrapped optax optimizer for the latents.
WrappedOptaxState = Tuple[int, PyTree, PyTree, PyTree]
|
24
|
+
|
25
|
+
|
26
|
+
def wrapped_optax(opt: optax.GradientTransformation) -> base.Optimizer:
    """Wraps an optax optimizer without any density parameterization.

    Density arrays are optimized directly (the pixel parameterization is used
    when no parameterization is given) and no constraint penalty is applied.

    Args:
        opt: The optax `GradientTransformation` to wrap.

    Returns:
        The wrapped optimizer.
    """
    no_penalty = 0.0
    return parameterized_wrapped_optax(
        opt=opt,
        density_parameterization=None,
        penalty=no_penalty,
    )
|
31
|
+
|
32
|
+
|
33
|
+
def density_wrapped_optax(
    opt: optax.GradientTransformation,
    *,
    beta: float,
) -> base.Optimizer:
    """Wraps an optax optimizer using the filter-project density parameterization.

    Each density array is represented by a latent density array. The density
    itself is obtained by convolving ("filtering") the latent density with a
    Gaussian kernel whose full-width at half-maximum equals the length scale
    (the mean of the declared minimum width and minimum spacing), followed by
    a smooth tanh threshold ("projection").

    Args:
        opt: The optax optimizer to be wrapped.
        beta: Determines the sharpness of the thresholding operation.

    Returns:
        The wrapped optax optimizer.
    """
    parameterization = filter_project.filter_project(beta=beta)
    return parameterized_wrapped_optax(
        opt=opt,
        density_parameterization=parameterization,
        penalty=0.0,
    )
|
59
|
+
|
60
|
+
|
61
|
+
def levelset_wrapped_optax(
    opt: optax.GradientTransformation,
    *,
    penalty: float,
    length_scale_spacing_factor: float = (
        gaussian_levelset.DEFAULT_LENGTH_SCALE_SPACING_FACTOR
    ),
    length_scale_fwhm_factor: float = (
        gaussian_levelset.DEFAULT_LENGTH_SCALE_FWHM_FACTOR
    ),
    length_scale_constraint_factor: float = (
        gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR
    ),
    smoothing_factor: int = gaussian_levelset.DEFAULT_SMOOTHING_FACTOR,
    length_scale_constraint_beta: float = (
        gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA
    ),
    length_scale_constraint_weight: float = (
        gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT
    ),
    curvature_constraint_weight: float = (
        gaussian_levelset.DEFAULT_CURVATURE_CONSTRAINT_WEIGHT
    ),
    fixed_pixel_constraint_weight: float = (
        gaussian_levelset.DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT
    ),
    init_optimizer: optax.GradientTransformation = (
        gaussian_levelset.DEFAULT_INIT_OPTIMIZER
    ),
    init_steps: int = gaussian_levelset.DEFAULT_INIT_STEPS,
) -> base.Optimizer:
    """Wrapped optax optimizer with levelset density parameterization.

    In the levelset parameterization, the optimization variable associated with a
    density array is an array giving the amplitudes of Gaussian radial basis functions
    that represent a levelset function over the domain of the density. In the levelset
    parameterization, gradients are nonzero only at the edges of features, and in
    general the topology of a solution does not change during the course of
    optimization.

    The spacing and full-width at half-maximum of the Gaussian basis functions give
    some amount of control over length scales. In addition, constraints associated with
    length scale, radius of curvature, and deviation from fixed pixels are
    automatically computed and penalized with a weight given by `penalty`. In general,
    this helps ensure that features in an optimized density array violate the specified
    constraints to a lesser degree. The constraints are based on "Analytical level set
    fabrication constraints for inverse design," by D. Vercruysse et al. (2019).

    Args:
        opt: The optax optimizer to be wrapped.
        penalty: The weight of the fabrication penalty, which combines length scale,
            curvature, and fixed pixel constraints.
        length_scale_spacing_factor: The number of levelset control points per unit of
            minimum length scale (mean of density minimum width and minimum spacing).
        length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum to
            the minimum length scale.
        length_scale_constraint_factor: Multiplies the target length scale in the
            levelset constraints. A value greater than 1 is pessimistic and drives the
            solution to have a larger length scale (relative to smaller values).
        smoothing_factor: For values greater than 1, the density is initially computed
            at higher resolution and then downsampled, yielding smoother geometries.
        length_scale_constraint_beta: Controls relaxation of the length scale
            constraint near the zero level.
        length_scale_constraint_weight: The weight of the length scale constraint in
            the overall fabrication constraint penalty.
        curvature_constraint_weight: The weight of the curvature constraint.
        fixed_pixel_constraint_weight: The weight of the fixed pixel constraint.
        init_optimizer: The optimizer used in the initialization of the levelset
            parameterization. At initialization, the latent parameters are optimized so
            that the initial parameters match the binarized initial density.
        init_steps: The number of optimization steps used in the initialization.

    Returns:
        The wrapped optax optimizer.
    """
    return parameterized_wrapped_optax(
        opt=opt,
        penalty=penalty,
        density_parameterization=gaussian_levelset.gaussian_levelset(
            length_scale_spacing_factor=length_scale_spacing_factor,
            length_scale_fwhm_factor=length_scale_fwhm_factor,
            length_scale_constraint_factor=length_scale_constraint_factor,
            smoothing_factor=smoothing_factor,
            length_scale_constraint_beta=length_scale_constraint_beta,
            length_scale_constraint_weight=length_scale_constraint_weight,
            curvature_constraint_weight=curvature_constraint_weight,
            fixed_pixel_constraint_weight=fixed_pixel_constraint_weight,
            init_optimizer=init_optimizer,
            init_steps=init_steps,
        ),
    )
|
152
|
+
|
153
|
+
|
154
|
+
# -----------------------------------------------------------------------------
|
155
|
+
# Base parameterized wrapped optax optimizer.
|
156
|
+
# -----------------------------------------------------------------------------
|
157
|
+
|
158
|
+
|
159
|
+
def parameterized_wrapped_optax(
    opt: optax.GradientTransformation,
    density_parameterization: Optional[param_base.Density2DParameterization],
    penalty: float,
) -> base.Optimizer:
    """Wrapped optax optimizer with specified density parameterization.

    Every `types.Density2DArray` leaf of the params is converted by
    `density_parameterization` into a parameterized density (latents plus
    metadata). Only the latents are handed to the wrapped optax optimizer.

    At each update, the gradient with respect to the params is pulled back to
    the latents through the parameterization's `to_density`. The gradient of
    the weighted constraint penalty is then added before calling `opt.update`.

    Args:
        opt: The optax `GradientTransformation` to be wrapped.
        density_parameterization: The parameterization to be used, or `None`. When no
            parameterization is given, the direct pixel parameterization is used for
            density arrays.
        penalty: The weight of the scalar penalty formed from the constraints of the
            parameterization.

    Returns:
        The `base.Optimizer`.
    """

    if density_parameterization is None:
        density_parameterization = pixel.pixel()

    def init_fn(params: PyTree) -> WrappedOptaxState:
        """Initializes the optimization state."""
        # Clip bounded leaves and convert density arrays to their latent form.
        latent_params = _init_latents(params)
        # The optax state is built from the latents only; metadata is excluded.
        _, latents = param_base.partition_density_metadata(latent_params)
        return (
            0,  # step
            _params_from_latent_params(latent_params),  # params
            latent_params,  # latent params
            opt.init(latents),  # opt state
        )

    def params_fn(state: WrappedOptaxState) -> PyTree:
        """Returns the parameters for the given `state`."""
        _, params, _, _ = state
        return params

    def update_fn(
        *,
        grad: PyTree,
        value: jnp.ndarray,
        params: PyTree,
        state: WrappedOptaxState,
    ) -> WrappedOptaxState:
        """Updates the state.

        Args:
            grad: Gradient of the objective with respect to the params (i.e. the
                densities, not the latents).
            value: Objective value; forwarded to the parameterization's `update`.
            params: Unused; the params stored in `state` are used instead.
            state: The current optimizer state.

        Returns:
            The new state, with the step counter incremented by one.
        """
        del params

        step, params, latent_params, opt_state = state
        metadata, latents = param_base.partition_density_metadata(latent_params)

        # Both helpers close over `metadata`, so only the latents are differentiated.
        def _params_from_latents(latents: PyTree) -> PyTree:
            latent_params = param_base.combine_density_metadata(metadata, latents)
            return _params_from_latent_params(latent_params)

        def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
            latent_params = param_base.combine_density_metadata(metadata, latents)
            return _constraint_loss(latent_params)

        # Pull the gradient with respect to the params back to the latents.
        _, vjp_fn = jax.vjp(_params_from_latents, latents)
        (latents_grad,) = vjp_fn(grad)

        # Sanity check: the pulled-back gradient must mirror the latents pytree.
        if not (
            tree_util.tree_structure(latents_grad)
            == tree_util.tree_structure(latents)  # type: ignore[operator]
        ):
            raise ValueError(
                f"Tree structure of `latents_grad` was different than expected, got \n"
                f"{tree_util.tree_structure(latents_grad)} but expected \n"
                f"{tree_util.tree_structure(latents)}."
            )

        # Add the gradient of the weighted constraint penalty.
        constraint_loss_grad = jax.grad(_constraint_loss_latents)(latents)
        latents_grad = tree_util.tree_map(
            lambda a, b: a + b, latents_grad, constraint_loss_grad
        )

        latent_updates, opt_state = opt.update(latents_grad, opt_state, params=latents)
        # Re-attach metadata to the updates so their structure mirrors `latent_params`.
        latent_params = _apply_updates(
            params=latent_params,
            updates=param_base.combine_density_metadata(metadata, latent_updates),
            value=value,
            step=step,
        )
        # Keep bounded leaves of the latent params within their declared bounds.
        latent_params = _clip(latent_params)
        params = _params_from_latent_params(latent_params)
        return (step + 1, params, latent_params, opt_state)

    # -------------------------------------------------------------------------
    # Functions related to the density parameterization.
    # -------------------------------------------------------------------------

    def _init_latents(params: PyTree) -> PyTree:
        """Returns `params` with each density array replaced by its latent form."""

        def _leaf_init_latents(leaf: Any) -> Any:
            # Bounded leaves are clipped to their bounds before any conversion.
            leaf = _clip(leaf)
            if not _is_density(leaf):
                return leaf
            return density_parameterization.from_density(leaf)

        return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)

    def _params_from_latent_params(params: PyTree) -> PyTree:
        """Converts each parameterized density back to a density; other leaves pass through."""

        def _leaf_params_from_latents(leaf: Any) -> Any:
            if not _is_parameterized_density(leaf):
                return leaf
            return density_parameterization.to_density(leaf)

        return tree_util.tree_map(
            _leaf_params_from_latents,
            params,
            is_leaf=_is_parameterized_density,
        )

    def _apply_updates(
        params: PyTree,
        updates: PyTree,
        value: jnp.ndarray,
        step: int,
    ) -> PyTree:
        """Applies `updates` to the latent `params`.

        Parameterized densities are updated by the parameterization's `update`
        function, which also receives `value` and `step`. All other leaves are
        updated with `optax.apply_updates`.
        """

        def _leaf_apply_updates(update: Any, leaf: Any) -> Any:
            if _is_parameterized_density(leaf):
                return density_parameterization.update(
                    params=leaf, updates=update, value=value, step=step
                )
            else:
                return optax.apply_updates(params=leaf, updates=update)

        # `updates` is the first tree, so `is_leaf` is evaluated on its nodes.
        return tree_util.tree_map(
            _leaf_apply_updates,
            updates,
            params,
            is_leaf=_is_parameterized_density,
        )

    # -------------------------------------------------------------------------
    # Functions related to the constraints to be minimized.
    # -------------------------------------------------------------------------

    def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
        """Returns `penalty` times the summed, squared positive constraint values.

        Non-positive constraint values contribute nothing to the loss.
        """

        def _constraint_loss_leaf(
            leaf: param_base.ParameterizedDensity2DArray,
        ) -> jnp.ndarray:
            constraints = density_parameterization.constraints(leaf)
            # Reduce every constraint array to the sum of its squared positive parts.
            constraints = tree_util.tree_map(
                lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
                constraints,
            )
            return jnp.sum(jnp.asarray(constraints))

        # The leading 0.0 keeps the sum well-defined when no densities are present.
        losses = [0.0] + [
            _constraint_loss_leaf(p)
            for p in tree_util.tree_leaves(
                latent_params, is_leaf=_is_parameterized_density
            )
            if _is_parameterized_density(p)
        ]
        return penalty * jnp.sum(jnp.asarray(losses))

    return base.Optimizer(init=init_fn, params=params_fn, update=update_fn)
|
318
|
+
|
319
|
+
|
320
|
+
def _is_density(leaf: Any) -> bool:
    """Return `True` if `leaf` is a density array.

    Args:
        leaf: Any pytree node.

    Returns:
        Whether `leaf` is a `types.Density2DArray`.
    """
    return isinstance(leaf, types.Density2DArray)
|
323
|
+
|
324
|
+
|
325
|
+
def _is_parameterized_density(leaf: Any) -> bool:
    """Return `True` if `leaf` is a parameterized density array.

    Args:
        leaf: Any pytree node.

    Returns:
        Whether `leaf` is a `param_base.ParameterizedDensity2DArray` (including
        subclasses).
    """
    return isinstance(leaf, param_base.ParameterizedDensity2DArray)
|
328
|
+
|
329
|
+
|
330
|
+
def _is_custom_type(leaf: Any) -> bool:
    """Leaf predicate matching the custom `totypes` array types.

    Returns `True` for `types.BoundedArray` and `types.Density2DArray`
    instances, `False` for everything else.
    """
    recognized = (types.BoundedArray, types.Density2DArray)
    return isinstance(leaf, recognized)
|
333
|
+
|
334
|
+
|
335
|
+
def _clip(pytree: PyTree) -> PyTree:
    """Clips bounded custom-type leaves of `pytree` to their bounds.

    Leaves that are `types.BoundedArray` or `types.Density2DArray` with at least
    one non-`None` bound are clipped to `[lower_bound, upper_bound]`. All other
    leaves, including custom types with no bounds, are returned unchanged.
    """

    def _clip_leaf(node: Any) -> Any:
        is_bounded = _is_custom_type(node) and (
            node.lower_bound is not None or node.upper_bound is not None
        )
        if not is_bounded:
            return node
        lower, upper = node.lower_bound, node.upper_bound
        return tree_util.tree_map(lambda x: jnp.clip(x, lower, upper), node)

    return tree_util.tree_map(_clip_leaf, pytree, is_leaf=_is_custom_type)
|
File without changes
|
@@ -0,0 +1,208 @@
|
|
1
|
+
"""Base types for density parameterizations.
|
2
|
+
|
3
|
+
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import dataclasses
|
7
|
+
from typing import Any, Optional, Protocol, Sequence, Tuple
|
8
|
+
|
9
|
+
import jax.numpy as jnp
|
10
|
+
import numpy as onp
|
11
|
+
from jax import tree_util
|
12
|
+
from totypes import json_utils, partition_utils, types
|
13
|
+
|
14
|
+
Array = jnp.ndarray | onp.ndarray[Any, Any]
|
15
|
+
PyTree = Any
|
16
|
+
|
17
|
+
|
18
|
+
@dataclasses.dataclass()
class ParameterizedDensity2DArray:
    """Pairs the latents of a parameterized density array with its metadata.

    Attributes:
        latents: The latent variables, an instance of a `LatentsBase` subclass.
        metadata: An instance of a `MetadataBase` subclass, or `None`.
    """

    latents: "LatentsBase"
    metadata: Optional["MetadataBase"]
|
24
|
+
|
25
|
+
|
26
|
+
class LatentsBase:
    """Base class for the latents of a parameterized density array.

    Concrete parameterizations subclass this to hold their latent variables.
    """
|
30
|
+
|
31
|
+
|
32
|
+
class MetadataBase:
    """Base class for the metadata of a parameterized density array.

    Instances of subclasses are what `partition_density_metadata` separates
    from the latents.
    """
|
36
|
+
|
37
|
+
|
38
|
+
# Register as a pytree whose children are both `latents` and `metadata`.
tree_util.register_dataclass(
    ParameterizedDensity2DArray,
    meta_fields=[],
    data_fields=["latents", "metadata"],
)
# Make the type serializable via `totypes.json_utils`.
json_utils.register_custom_type(ParameterizedDensity2DArray)
|
44
|
+
|
45
|
+
|
46
|
+
def partition_density_metadata(tree: PyTree) -> Tuple[PyTree, PyTree]:
    """Splits a pytree with parameterized densities into metadata and latents.

    `MetadataBase` instances are selected into the first returned tree, and
    everything else goes into the second. `None` and `MetadataBase` nodes are
    treated as leaves during the split.

    Args:
        tree: A pytree that may contain parameterized densities.

    Returns:
        A `(metadata, latents)` tuple of pytrees. These can be recombined with
        `combine_density_metadata`. Presumably the unselected positions of each
        tree are filled with `None` by `totypes.partition_utils`; confirm this
        against that module.
    """
    metadata, latents = partition_utils.partition(
        tree,
        select_fn=lambda x: isinstance(x, MetadataBase),
        is_leaf=_is_metadata_or_none,
    )
    return metadata, latents
|
54
|
+
|
55
|
+
|
56
|
+
def combine_density_metadata(metadata: PyTree, latents: PyTree) -> PyTree:
    """Recombines the trees produced by `partition_density_metadata`.

    `None` and `MetadataBase` nodes are treated as leaves while merging.
    """
    combined = partition_utils.combine(
        metadata, latents, is_leaf=_is_metadata_or_none
    )
    return combined
|
59
|
+
|
60
|
+
|
61
|
+
def _is_metadata_or_none(leaf: Any) -> bool:
    """Leaf predicate: `True` for `None` and for `MetadataBase` instances."""
    if leaf is None:
        return True
    return isinstance(leaf, MetadataBase)
|
64
|
+
|
65
|
+
|
66
|
+
@dataclasses.dataclass
class Density2DParameterization:
    """Stores the four functions that define a density parameterization.

    Attributes:
        from_density: Converts a density array to its parameterized form.
        to_density: Converts the parameterized form back to a density array.
        constraints: Computes constraint values for the parameterized form.
        update: Applies optimizer updates to the parameterized form.
    """

    from_density: "FromDensityFn"
    to_density: "ToDensityFn"
    constraints: "ConstraintsFn"
    update: "UpdateFn"
|
74
|
+
|
75
|
+
|
76
|
+
class FromDensityFn(Protocol):
    """Callable that builds the latent representation of a density array."""

    def __call__(self, density: types.Density2DArray) -> ParameterizedDensity2DArray:
        """Return the parameterized (latent) form of `density`."""
|
81
|
+
|
82
|
+
|
83
|
+
class ToDensityFn(Protocol):
    """Callable that reconstructs a density array from its latent representation."""

    def __call__(self, params: PyTree) -> types.Density2DArray:
        """Return the density represented by the parameterized `params`."""
|
88
|
+
|
89
|
+
|
90
|
+
class ConstraintsFn(Protocol):
    """Callable that evaluates constraints for a latent density representation."""

    def __call__(self, params: PyTree) -> jnp.ndarray:
        """Return the constraint values computed from `params`."""
|
95
|
+
|
96
|
+
|
97
|
+
class UpdateFn(Protocol):
    """Callable applying the per-step update to a parameterized density."""

    def __call__(
        self,
        params: PyTree,
        updates: PyTree,
        value: jnp.ndarray,
        step: int,
    ) -> PyTree:
        """Return `params` after applying `updates` for the given `step`.

        `value` is the current objective value, available to parameterizations
        whose update depends on it.
        """
|
108
|
+
|
109
|
+
|
110
|
+
@dataclasses.dataclass
class Density2DMetadata:
    """Stores the metadata of a `Density2DArray`, i.e. every field except `array`.

    Attributes:
        lower_bound: Lower bound of the density values.
        upper_bound: Upper bound of the density values.
        fixed_solid: Optional array of pixels fixed to solid.
        fixed_void: Optional array of pixels fixed to void.
        minimum_width: Declared minimum feature width.
        minimum_spacing: Declared minimum feature spacing.
        periodic: Per-axis periodicity flags (stored as a tuple).
        symmetries: Symmetry names (stored as a tuple).
    """

    lower_bound: float
    upper_bound: float
    fixed_solid: Optional[Array]
    fixed_void: Optional[Array]
    minimum_width: int
    minimum_spacing: int
    periodic: Sequence[bool]
    symmetries: Sequence[str]

    def __post_init__(self) -> None:
        # Normalize sequences to tuples (presumably so they are hashable when used
        # as pytree auxiliary data).
        self.periodic = tuple(self.periodic)
        self.symmetries = tuple(self.symmetries)

    @classmethod
    def from_density(cls, density: types.Density2DArray) -> "Density2DMetadata":
        """Builds metadata from `density`, discarding its `array` field.

        Args:
            density: The density whose non-array fields are copied.

        Returns:
            An instance of `cls` holding those fields.
        """
        density_metadata_dict = dataclasses.asdict(density)
        del density_metadata_dict["array"]
        return cls(**density_metadata_dict)
|
132
|
+
|
133
|
+
|
134
|
+
def _flatten_density_2d_metadata(
    metadata: Density2DMetadata,
) -> Tuple[
    Tuple[()],
    Tuple[
        float,
        float,
        types.HashableWrapper,
        types.HashableWrapper,
        int,
        int,
        Sequence[bool],
        Sequence[str],
    ],
]:
    """Flattens a `Density2DMetadata` into `(children, auxiliary data)`.

    The metadata has no pytree children, so every field is stored as auxiliary
    data. The optional fixed-pixel arrays are wrapped in `types.HashableWrapper`,
    presumably because JAX requires auxiliary data to be hashable.
    """
    aux_data = (
        metadata.lower_bound,
        metadata.upper_bound,
        types.HashableWrapper(metadata.fixed_solid),
        types.HashableWrapper(metadata.fixed_void),
        metadata.minimum_width,
        metadata.minimum_spacing,
        metadata.periodic,
        metadata.symmetries,
    )
    no_children = ()
    return no_children, aux_data
|
163
|
+
|
164
|
+
|
165
|
+
def _unflatten_density_2d_metadata(
    aux: Tuple[
        float,
        float,
        types.HashableWrapper,
        types.HashableWrapper,
        int,
        int,
        Sequence[bool],
        Sequence[str],
    ],
    children: Tuple[()],
) -> Density2DMetadata:
    """Rebuilds a `Density2DMetadata` from the auxiliary data of its flattening.

    `children` is always empty, because `_flatten_density_2d_metadata` returns
    no children. The fixed-pixel arrays are unwrapped from their
    `types.HashableWrapper`.
    """
    del children
    (
        lower,
        upper,
        solid_wrapper,
        void_wrapper,
        width,
        spacing,
        periodic_flags,
        symmetry_names,
    ) = aux
    return Density2DMetadata(
        lower_bound=lower,
        upper_bound=upper,
        fixed_solid=solid_wrapper.array,  # type: ignore[arg-type]
        fixed_void=void_wrapper.array,  # type: ignore[arg-type]
        minimum_width=width,
        minimum_spacing=spacing,
        periodic=tuple(periodic_flags),
        symmetries=tuple(symmetry_names),
    )
|
200
|
+
|
201
|
+
|
202
|
+
# A pytree node with no children: all fields travel as auxiliary data.
tree_util.register_pytree_node(
    Density2DMetadata,
    unflatten_func=_unflatten_density_2d_metadata,
    flatten_func=_flatten_density_2d_metadata,
)

# Make the type serializable via `totypes.json_utils`.
json_utils.register_custom_type(Density2DMetadata)
|
@@ -0,0 +1,138 @@
|
|
1
|
+
"""Defines filter-and-project density parameterization.
|
2
|
+
|
3
|
+
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import dataclasses
|
7
|
+
|
8
|
+
import jax.numpy as jnp
|
9
|
+
from jax import tree_util
|
10
|
+
from totypes import json_utils, types
|
11
|
+
|
12
|
+
from invrs_opt.parameterization import base, transforms
|
13
|
+
|
14
|
+
|
15
|
+
@dataclasses.dataclass()
class FilterProjectParams(base.ParameterizedDensity2DArray):
    """Parameterized density for the filter-project parameterization.

    Narrows the base-class field types to the filter-project latents and
    metadata.
    """

    latents: "FilterProjectLatents"
    metadata: "FilterProjectMetadata"
|
21
|
+
|
22
|
+
|
23
|
+
@dataclasses.dataclass
class FilterProjectLatents(base.LatentsBase):
    """Stores latent parameters for the filter-project parameterization.

    Attributes:
        latent_density: The latent variable from which the density is obtained
            by filtering and projection.
    """

    latent_density: types.Density2DArray
|
32
|
+
|
33
|
+
|
34
|
+
@dataclasses.dataclass()
class FilterProjectMetadata(base.MetadataBase):
    """Holds the fixed settings of the filter-project parameterization.

    Attributes:
        beta: Sharpness of the tanh thresholding ("projection") step.
    """

    beta: float
|
43
|
+
|
44
|
+
|
45
|
+
# Params: both `latents` and `metadata` are pytree children.
tree_util.register_dataclass(
    FilterProjectParams,
    meta_fields=[],
    data_fields=["latents", "metadata"],
)
# Latents: the latent density is the only child.
tree_util.register_dataclass(
    FilterProjectLatents,
    meta_fields=[],
    data_fields=["latent_density"],
)
# Metadata: `beta` is static auxiliary data rather than a pytree leaf.
tree_util.register_dataclass(
    FilterProjectMetadata,
    meta_fields=["beta"],
    data_fields=[],
)

# Make all three types serializable via `totypes.json_utils`.
json_utils.register_custom_type(FilterProjectParams)
json_utils.register_custom_type(FilterProjectLatents)
json_utils.register_custom_type(FilterProjectMetadata)
|
63
|
+
|
64
|
+
|
65
|
+
def filter_project(beta: float) -> base.Density2DParameterization:
    """Defines a filter-project parameterization for density arrays.

    The `Density2DArray` is represented as a latent density array that is
    transformed by,

        transformed = tanh(beta * conv(latent_density.array, gaussian_kernel)) / tanh(beta)

    where the kernel has a full-width at half-maximum determined by the minimum width
    and spacing parameters of the `Density2DArray`. The latent density is
    symmetrized before filtering, and fixed pixels are applied to the result.

    When the density lower and upper bounds are -1 and +1, this basic expression is
    used directly. Where the bounds differ, the density is scaled before the
    transform is applied, and then unscaled afterwards.

    Args:
        beta: Determines the sharpness of the thresholding operation.

    Returns:
        The `Density2DParameterization`.
    """

    def from_density_fn(density: types.Density2DArray) -> FilterProjectParams:
        """Return latent parameters for the given `density`.

        Inverts the projection (but not the filtering). The density is
        normalized and clipped to [-1, 1], then mapped through
        `arctanh(tanh(beta) * x) / beta`. Multiplying by `tanh(beta)` keeps the
        `arctanh` argument strictly inside (-1, 1) and undoes the `1 / tanh(beta)`
        rescaling applied in `to_density_fn`.
        """
        array = transforms.normalized_array_from_density(density)
        array = jnp.clip(array, -1, 1)
        array *= jnp.tanh(beta)
        latent_array = jnp.arctanh(array) / beta
        latent_array = transforms.rescale_array_for_density(latent_array, density)
        latent_density = dataclasses.replace(density, array=latent_array)
        return FilterProjectParams(
            latents=FilterProjectLatents(latent_density=latent_density),
            metadata=FilterProjectMetadata(beta=beta),
        )

    def to_density_fn(params: FilterProjectParams) -> types.Density2DArray:
        """Return a density from the latent parameters."""
        latent_density = params.latents.latent_density
        beta = params.metadata.beta

        transformed = types.symmetrize_density(latent_density)
        transformed = transforms.density_gaussian_filter_and_tanh(transformed, beta)
        # Scale to ensure that the full valid range of the density array is reachable.
        mid_value = (transformed.lower_bound + transformed.upper_bound) / 2
        transformed = tree_util.tree_map(
            lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed
        )
        return transforms.apply_fixed_pixels(transformed)

    def constraints_fn(params: FilterProjectParams) -> jnp.ndarray:
        """Computes constraints associated with the params.

        The filter-project parameterization has no constraints, so this is always zero.
        """
        del params
        return jnp.asarray(0.0)

    def update_fn(
        params: FilterProjectParams,
        updates: FilterProjectParams,
        value: jnp.ndarray,
        step: int,
    ) -> FilterProjectParams:
        """Perform updates to `params` required for the given `step`.

        The latents are updated additively. The metadata (`beta`) is carried over
        unchanged, and `value` and `step` are unused.
        """
        del step, value
        return FilterProjectParams(
            latents=tree_util.tree_map(
                lambda a, b: a + b, params.latents, updates.latents
            ),
            metadata=params.metadata,
        )

    return base.Density2DParameterization(
        to_density=to_density_fn,
        from_density=from_density_fn,
        constraints=constraints_fn,
        update=update_fn,
    )
|