invrs-opt 0.3.2__py3-none-any.whl → 0.10.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invrs_opt/__init__.py +14 -3
- invrs_opt/experimental/client.py +155 -0
- invrs_opt/experimental/labels.py +23 -0
- invrs_opt/optimizers/__init__.py +0 -0
- invrs_opt/{base.py → optimizers/base.py} +16 -1
- invrs_opt/optimizers/lbfgsb.py +939 -0
- invrs_opt/optimizers/wrapped_optax.py +347 -0
- invrs_opt/parameterization/__init__.py +0 -0
- invrs_opt/parameterization/base.py +208 -0
- invrs_opt/parameterization/filter_project.py +138 -0
- invrs_opt/parameterization/gaussian_levelset.py +671 -0
- invrs_opt/parameterization/pixel.py +75 -0
- invrs_opt/{lbfgsb/transform.py → parameterization/transforms.py} +76 -11
- invrs_opt-0.10.3.dist-info/LICENSE +504 -0
- invrs_opt-0.10.3.dist-info/METADATA +560 -0
- invrs_opt-0.10.3.dist-info/RECORD +20 -0
- {invrs_opt-0.3.2.dist-info → invrs_opt-0.10.3.dist-info}/WHEEL +1 -1
- invrs_opt/lbfgsb/lbfgsb.py +0 -670
- invrs_opt-0.3.2.dist-info/LICENSE +0 -21
- invrs_opt-0.3.2.dist-info/METADATA +0 -73
- invrs_opt-0.3.2.dist-info/RECORD +0 -11
- /invrs_opt/{lbfgsb → experimental}/__init__.py +0 -0
- {invrs_opt-0.3.2.dist-info → invrs_opt-0.10.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,347 @@
|
|
1
|
+
"""Defines a wrapper for optax optimizers.
|
2
|
+
|
3
|
+
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
+
"""
|
5
|
+
|
6
|
+
from typing import Any, Optional, Tuple
|
7
|
+
|
8
|
+
import jax
|
9
|
+
import jax.numpy as jnp
|
10
|
+
import optax # type: ignore[import-untyped]
|
11
|
+
from jax import tree_util
|
12
|
+
from totypes import types
|
13
|
+
|
14
|
+
from invrs_opt.optimizers import base
|
15
|
+
from invrs_opt.parameterization import (
|
16
|
+
base as param_base,
|
17
|
+
filter_project,
|
18
|
+
gaussian_levelset,
|
19
|
+
pixel,
|
20
|
+
)
|
21
|
+
|
22
|
+
# Alias used to annotate arbitrary pytrees.
PyTree = Any
# State of the wrapped optimizer: `(step, params, latent_params, opt_state)`.
WrappedOptaxState = Tuple[int, PyTree, PyTree, PyTree]
|
24
|
+
|
25
|
+
|
26
|
+
def wrapped_optax(opt: optax.GradientTransformation) -> base.Optimizer:
    """Wraps an optax optimizer with zero penalty and no given parameterization.

    Args:
        opt: The optax `GradientTransformation` to be wrapped.

    Returns:
        The `base.Optimizer` built by `parameterized_wrapped_optax`.
    """
    return parameterized_wrapped_optax(
        opt=opt,
        density_parameterization=None,
        penalty=0.0,
    )
|
31
|
+
|
32
|
+
|
33
|
+
def density_wrapped_optax(
    opt: optax.GradientTransformation,
    *,
    beta: float,
) -> base.Optimizer:
    """Wraps an optax optimizer using the filter-project density parameterization.

    With this parameterization, a density array is optimized through a latent density
    array. The density is recovered by filtering (convolving) the latent density with
    a Gaussian kernel whose full-width at half-maximum equals the length scale (the
    mean of the declared minimum width and minimum spacing), followed by a tanh
    nonlinearity that acts as a smooth threshold ("projection").

    Args:
        opt: The optax optimizer to be wrapped.
        beta: Determines the sharpness of the thresholding operation.

    Returns:
        The wrapped optax optimizer.
    """
    parameterization = filter_project.filter_project(beta=beta)
    return parameterized_wrapped_optax(
        opt=opt,
        density_parameterization=parameterization,
        penalty=0.0,
    )
|
59
|
+
|
60
|
+
|
61
|
+
def levelset_wrapped_optax(
    opt: optax.GradientTransformation,
    *,
    penalty: float,
    length_scale_spacing_factor: float = (
        gaussian_levelset.DEFAULT_LENGTH_SCALE_SPACING_FACTOR
    ),
    length_scale_fwhm_factor: float = (
        gaussian_levelset.DEFAULT_LENGTH_SCALE_FWHM_FACTOR
    ),
    length_scale_constraint_factor: float = (
        gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR
    ),
    smoothing_factor: int = gaussian_levelset.DEFAULT_SMOOTHING_FACTOR,
    length_scale_constraint_beta: float = (
        gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA
    ),
    length_scale_constraint_weight: float = (
        gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT
    ),
    curvature_constraint_weight: float = (
        gaussian_levelset.DEFAULT_CURVATURE_CONSTRAINT_WEIGHT
    ),
    fixed_pixel_constraint_weight: float = (
        gaussian_levelset.DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT
    ),
    init_optimizer: optax.GradientTransformation = (
        gaussian_levelset.DEFAULT_INIT_OPTIMIZER
    ),
    init_steps: int = gaussian_levelset.DEFAULT_INIT_STEPS,
) -> base.Optimizer:
    """Wraps an optax optimizer using the levelset density parameterization.

    With the levelset parameterization, the optimization variable for a density array
    is an array of amplitudes of Gaussian radial basis functions; together these
    represent a levelset function over the domain of the density. Gradients are
    nonzero only at the edges of features, and in general the topology of a solution
    does not change during the course of optimization.

    The spacing and full-width at half-maximum of the Gaussian basis functions give
    some control over length scales. In addition, constraints associated with length
    scale, radius of curvature, and deviation from fixed pixels are automatically
    computed and penalized with a weight given by `penalty`. In general, this helps
    ensure that features in an optimized density array violate the specified
    constraints to a lesser degree. The constraints are based on "Analytical level set
    fabrication constraints for inverse design," by D. Vercruysse et al. (2019).

    Args:
        opt: The optax optimizer to be wrapped.
        penalty: The weight of the fabrication penalty, which combines length scale,
            curvature, and fixed pixel constraints.
        length_scale_spacing_factor: The number of levelset control points per unit of
            minimum length scale (mean of density minimum width and minimum spacing).
        length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum to
            the minimum length scale.
        length_scale_constraint_factor: Multiplies the target length scale in the
            levelset constraints. A value greater than 1 is pessimistic and drives the
            solution to have a larger length scale (relative to smaller values).
        smoothing_factor: For values greater than 1, the density is initially computed
            at higher resolution and then downsampled, yielding smoother geometries.
        length_scale_constraint_beta: Controls relaxation of the length scale
            constraint near the zero level.
        length_scale_constraint_weight: The weight of the length scale constraint in
            the overall fabrication constraint penalty.
        curvature_constraint_weight: The weight of the curvature constraint.
        fixed_pixel_constraint_weight: The weight of the fixed pixel constraint.
        init_optimizer: The optimizer used in the initialization of the levelset
            parameterization. At initialization, the latent parameters are optimized so
            that the initial parameters match the binarized initial density.
        init_steps: The number of optimization steps used in the initialization.

    Returns:
        The wrapped optax optimizer.
    """
    parameterization = gaussian_levelset.gaussian_levelset(
        length_scale_spacing_factor=length_scale_spacing_factor,
        length_scale_fwhm_factor=length_scale_fwhm_factor,
        length_scale_constraint_factor=length_scale_constraint_factor,
        smoothing_factor=smoothing_factor,
        length_scale_constraint_beta=length_scale_constraint_beta,
        length_scale_constraint_weight=length_scale_constraint_weight,
        curvature_constraint_weight=curvature_constraint_weight,
        fixed_pixel_constraint_weight=fixed_pixel_constraint_weight,
        init_optimizer=init_optimizer,
        init_steps=init_steps,
    )
    return parameterized_wrapped_optax(
        opt=opt,
        density_parameterization=parameterization,
        penalty=penalty,
    )
|
152
|
+
|
153
|
+
|
154
|
+
# -----------------------------------------------------------------------------
|
155
|
+
# Base parameterized wrapped optax optimizer.
|
156
|
+
# -----------------------------------------------------------------------------
|
157
|
+
|
158
|
+
|
159
|
+
def parameterized_wrapped_optax(
    opt: optax.GradientTransformation,
    density_parameterization: Optional[param_base.Density2DParameterization],
    penalty: float,
) -> base.Optimizer:
    """Wrapped optax optimizer with specified density parameterization.

    The optimizer state is the tuple `(step, params, latent_params, opt_state)`. The
    wrapped optax optimizer acts only on the latents (the latent parameters with the
    density metadata partitioned out), and `params` is recomputed from the latent
    parameters after every update.

    Args:
        opt: The optax `GradientTransformation` to be wrapped.
        density_parameterization: The parameterization to be used, or `None`. When no
            parameterization is given, the direct pixel parameterization is used for
            density arrays.
        penalty: The weight of the scalar penalty formed from the constraints of the
            parameterization.

    Returns:
        The `base.Optimizer`.
    """

    if density_parameterization is None:
        # Fall back to the direct pixel parameterization.
        density_parameterization = pixel.pixel()

    def init_fn(params: PyTree) -> WrappedOptaxState:
        """Initializes the optimization state."""
        latent_params = _init_latents(params)
        # Only the latents (metadata removed) are handed to the optax optimizer.
        _, latents = param_base.partition_density_metadata(latent_params)
        return (
            0,  # step
            _params_from_latent_params(latent_params),  # params
            latent_params,  # latent params
            opt.init(latents),  # opt state
        )

    def params_fn(state: WrappedOptaxState) -> PyTree:
        """Returns the parameters for the given `state`."""
        _, params, _, _ = state
        return params

    def update_fn(
        *,
        grad: PyTree,
        value: jnp.ndarray,
        params: PyTree,
        state: WrappedOptaxState,
    ) -> WrappedOptaxState:
        """Updates the state."""
        # The `params` argument is discarded; the params stored in `state` are used.
        del params

        step, params, latent_params, opt_state = state
        metadata, latents = param_base.partition_density_metadata(latent_params)

        def _params_from_latents(latents: PyTree) -> PyTree:
            # Recombine with the (closed-over) metadata before computing params.
            latent_params = param_base.combine_density_metadata(metadata, latents)
            return _params_from_latent_params(latent_params)

        def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
            latent_params = param_base.combine_density_metadata(metadata, latents)
            return _constraint_loss(latent_params)

        # Backpropagate `grad` through the latents-to-params map, giving the gradient
        # with respect to the latents.
        _, vjp_fn = jax.vjp(_params_from_latents, latents)
        (latents_grad,) = vjp_fn(grad)

        if not (
            tree_util.tree_structure(latents_grad)
            == tree_util.tree_structure(latents)  # type: ignore[operator]
        ):
            raise ValueError(
                f"Tree structure of `latents_grad` was different than expected, got \n"
                f"{tree_util.tree_structure(latents_grad)} but expected \n"
                f"{tree_util.tree_structure(latents)}."
            )

        # Add the gradient of the penalty-weighted constraint loss.
        constraint_loss_grad = jax.grad(_constraint_loss_latents)(latents)
        latents_grad = tree_util.tree_map(
            lambda a, b: a + b, latents_grad, constraint_loss_grad
        )

        latent_updates, opt_state = opt.update(latents_grad, opt_state, params=latents)
        # The updates are recombined with the metadata so that the updates tree can
        # be mapped against `latent_params`.
        latent_params = _apply_updates(
            params=latent_params,
            updates=param_base.combine_density_metadata(metadata, latent_updates),
            value=value,
            step=step,
        )
        latent_params = _clip(latent_params)
        params = _params_from_latent_params(latent_params)
        return (step + 1, params, latent_params, opt_state)

    # -------------------------------------------------------------------------
    # Functions related to the density parameterization.
    # -------------------------------------------------------------------------

    def _init_latents(params: PyTree) -> PyTree:
        """Clips each leaf and converts density leaves to their latent form."""

        def _leaf_init_latents(leaf: Any) -> Any:
            leaf = _clip(leaf)
            if not _is_density(leaf):
                return leaf
            return density_parameterization.from_density(leaf)

        return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)

    def _params_from_latent_params(params: PyTree) -> PyTree:
        """Converts parameterized density leaves to densities via `to_density`."""

        def _leaf_params_from_latents(leaf: Any) -> Any:
            if not _is_parameterized_density(leaf):
                return leaf
            return density_parameterization.to_density(leaf)

        return tree_util.tree_map(
            _leaf_params_from_latents,
            params,
            is_leaf=_is_parameterized_density,
        )

    def _apply_updates(
        params: PyTree,
        updates: PyTree,
        value: jnp.ndarray,
        step: int,
    ) -> PyTree:
        """Applies `updates` to `params`, leaf by leaf.

        Parameterized density leaves are updated by the `update` function of the
        density parameterization; all other leaves use `optax.apply_updates`.
        """

        def _leaf_apply_updates(update: Any, leaf: Any) -> Any:
            if _is_parameterized_density(leaf):
                return density_parameterization.update(
                    params=leaf, updates=update, value=value, step=step
                )
            else:
                return optax.apply_updates(params=leaf, updates=update)

        # Note the argument order: `updates` is the first tree, `params` the second.
        return tree_util.tree_map(
            _leaf_apply_updates,
            updates,
            params,
            is_leaf=_is_parameterized_density,
        )

    # -------------------------------------------------------------------------
    # Functions related to the constraints to be minimized.
    # -------------------------------------------------------------------------

    def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
        """Returns `penalty` times the summed constraint losses of all densities."""

        def _constraint_loss_leaf(
            leaf: param_base.ParameterizedDensity2DArray,
        ) -> jnp.ndarray:
            constraints = density_parameterization.constraints(leaf)
            # Only positive constraint values contribute, and do so quadratically.
            constraints = tree_util.tree_map(
                lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
                constraints,
            )
            return jnp.sum(jnp.asarray(constraints))

        # The leading `0.0` keeps the list non-empty when the tree contains no
        # parameterized densities.
        losses = [0.0] + [
            _constraint_loss_leaf(p)
            for p in tree_util.tree_leaves(
                latent_params, is_leaf=_is_parameterized_density
            )
            if _is_parameterized_density(p)
        ]
        return penalty * jnp.sum(jnp.asarray(losses))

    return base.Optimizer(init=init_fn, params=params_fn, update=update_fn)
|
318
|
+
|
319
|
+
|
320
|
+
def _is_density(leaf: Any) -> Any:
    """Reports whether `leaf` is a `types.Density2DArray` instance."""
    leaf_is_density = isinstance(leaf, types.Density2DArray)
    return leaf_is_density
|
323
|
+
|
324
|
+
|
325
|
+
def _is_parameterized_density(leaf: Any) -> Any:
    """Reports whether `leaf` is a `ParameterizedDensity2DArray` instance."""
    parameterized_type = param_base.ParameterizedDensity2DArray
    return isinstance(leaf, parameterized_type)
|
328
|
+
|
329
|
+
|
330
|
+
def _is_custom_type(leaf: Any) -> bool:
    """Reports whether `leaf` is a `BoundedArray` or a `Density2DArray`."""
    recognized_types = (types.BoundedArray, types.Density2DArray)
    return isinstance(leaf, recognized_types)
|
333
|
+
|
334
|
+
|
335
|
+
def _clip(pytree: PyTree) -> PyTree:
    """Returns `pytree` with its bounded leaves clipped to their bounds.

    Leaves that are not a recognized custom type, or that declare neither a lower nor
    an upper bound, are returned unchanged.
    """

    def _clip_leaf(node: Any) -> Any:
        unbounded = not _is_custom_type(node) or (
            node.lower_bound is None and node.upper_bound is None
        )
        if unbounded:
            return node
        return tree_util.tree_map(
            lambda arr: jnp.clip(arr, node.lower_bound, node.upper_bound), node
        )

    return tree_util.tree_map(_clip_leaf, pytree, is_leaf=_is_custom_type)
|
File without changes
|
@@ -0,0 +1,208 @@
|
|
1
|
+
"""Base types for density parameterizations.
|
2
|
+
|
3
|
+
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import dataclasses
|
7
|
+
from typing import Any, Optional, Protocol, Sequence, Tuple
|
8
|
+
|
9
|
+
import jax.numpy as jnp
|
10
|
+
import numpy as onp
|
11
|
+
from jax import tree_util
|
12
|
+
from totypes import json_utils, partition_utils, types
|
13
|
+
|
14
|
+
# Alias for either a jax or a numpy array.
Array = jnp.ndarray | onp.ndarray[Any, Any]
# Alias used to annotate arbitrary pytrees.
PyTree = Any
|
16
|
+
|
17
|
+
|
18
|
+
@dataclasses.dataclass
class ParameterizedDensity2DArray:
    """Pairs the latents of a parameterized density array with its metadata.

    Attributes:
        latents: The latents of the parameterized density.
        metadata: The metadata of the parameterized density, or `None`.
    """

    latents: "LatentsBase"
    metadata: Optional["MetadataBase"]
|
24
|
+
|
25
|
+
|
26
|
+
class LatentsBase:
    """Common base class for the latents of a parameterized density array."""
|
30
|
+
|
31
|
+
|
32
|
+
class MetadataBase:
    """Common base class for the metadata of a parameterized density array."""
|
36
|
+
|
37
|
+
|
38
|
+
# Register `ParameterizedDensity2DArray` with jax. Both fields are declared as data
# fields and there are no meta fields.
tree_util.register_dataclass(
    ParameterizedDensity2DArray,
    data_fields=["latents", "metadata"],
    meta_fields=[],
)
# NOTE(review): presumably makes the type (de)serializable through
# `totypes.json_utils` -- confirm against the totypes documentation.
json_utils.register_custom_type(ParameterizedDensity2DArray)
|
44
|
+
|
45
|
+
|
46
|
+
def partition_density_metadata(tree: PyTree) -> Tuple[PyTree, PyTree]:
    """Splits a pytree with parameterized densities into metadata and latents.

    Args:
        tree: The pytree to be partitioned.

    Returns:
        The `(metadata, latents)` pair produced by `partition_utils.partition`, with
        `MetadataBase` instances being the selected leaves.
    """

    def _select_metadata(node: Any) -> bool:
        return isinstance(node, MetadataBase)

    metadata, latents = partition_utils.partition(
        tree,
        select_fn=_select_metadata,
        is_leaf=_is_metadata_or_none,
    )
    return metadata, latents
|
54
|
+
|
55
|
+
|
56
|
+
def combine_density_metadata(metadata: PyTree, latents: PyTree) -> PyTree:
    """Merges a metadata pytree and a latents pytree into a single pytree."""
    combined = partition_utils.combine(
        metadata,
        latents,
        is_leaf=_is_metadata_or_none,
    )
    return combined
|
59
|
+
|
60
|
+
|
61
|
+
def _is_metadata_or_none(leaf: Any) -> bool:
    """Reports whether `leaf` is `None` or a `MetadataBase` instance."""
    if leaf is None:
        return True
    return isinstance(leaf, MetadataBase)
|
64
|
+
|
65
|
+
|
66
|
+
@dataclasses.dataclass
class Density2DParameterization:
    """Bundles the four functions that define a density parameterization.

    Attributes:
        from_density: Generates the latent representation of a density array.
        to_density: Generates a density from its latent representation.
        constraints: Computes constraints for a latent representation.
        update: Performs the update of a parameterized density for a given step.
    """

    from_density: "FromDensityFn"
    to_density: "ToDensityFn"
    constraints: "ConstraintsFn"
    update: "UpdateFn"
|
74
|
+
|
75
|
+
|
76
|
+
class FromDensityFn(Protocol):
    """Callable that maps a density array to its latent representation."""

    def __call__(
        self,
        density: types.Density2DArray,
    ) -> ParameterizedDensity2DArray: ...
|
81
|
+
|
82
|
+
|
83
|
+
class ToDensityFn(Protocol):
    """Callable that maps a latent representation to a density array."""

    def __call__(
        self,
        params: PyTree,
    ) -> types.Density2DArray: ...
|
88
|
+
|
89
|
+
|
90
|
+
class ConstraintsFn(Protocol):
    """Callable that computes constraints for a latent density representation."""

    def __call__(
        self,
        params: PyTree,
    ) -> jnp.ndarray: ...
|
95
|
+
|
96
|
+
|
97
|
+
class UpdateFn(Protocol):
    """Callable that updates a parameterized density for the given step."""

    def __call__(
        self, params: PyTree, updates: PyTree, value: jnp.ndarray, step: int
    ) -> PyTree: ...
|
108
|
+
|
109
|
+
|
110
|
+
@dataclasses.dataclass
class Density2DMetadata:
    """Stores the metadata of a `Density2DArray`.

    The fields are those which `from_density` copies from a density, i.e. every field
    of the density other than `array`. The `periodic` and `symmetries` sequences are
    converted to tuples after initialization.
    """

    lower_bound: float
    upper_bound: float
    fixed_solid: Optional[Array]
    fixed_void: Optional[Array]
    minimum_width: int
    minimum_spacing: int
    periodic: Sequence[bool]
    symmetries: Sequence[str]

    def __post_init__(self) -> None:
        self.periodic = tuple(self.periodic)
        self.symmetries = tuple(self.symmetries)

    @classmethod
    def from_density(cls, density: types.Density2DArray) -> "Density2DMetadata":
        """Builds the metadata for `density`.

        Args:
            density: The density dataclass whose fields, other than `array`, are used
                to construct the metadata.

        Returns:
            An instance of `cls` (so that subclasses construct themselves).
        """
        density_metadata_dict = dataclasses.asdict(density)
        # The array itself is not metadata.
        del density_metadata_dict["array"]
        return cls(**density_metadata_dict)
|
132
|
+
|
133
|
+
|
134
|
+
def _flatten_density_2d_metadata(
    metadata: Density2DMetadata,
) -> Tuple[
    Tuple[()],
    Tuple[
        float,
        float,
        types.HashableWrapper,
        types.HashableWrapper,
        int,
        int,
        Sequence[bool],
        Sequence[str],
    ],
]:
    """Flattens a `Density2DMetadata` into children and auxiliary data.

    There are no children; every field is placed in the auxiliary data, with the
    fixed-pixel arrays wrapped in `types.HashableWrapper`.
    """
    wrapped_fixed_solid = types.HashableWrapper(metadata.fixed_solid)
    wrapped_fixed_void = types.HashableWrapper(metadata.fixed_void)
    aux = (
        metadata.lower_bound,
        metadata.upper_bound,
        wrapped_fixed_solid,
        wrapped_fixed_void,
        metadata.minimum_width,
        metadata.minimum_spacing,
        metadata.periodic,
        metadata.symmetries,
    )
    children = ()
    return children, aux
|
163
|
+
|
164
|
+
|
165
|
+
def _unflatten_density_2d_metadata(
    aux: Tuple[
        float,
        float,
        types.HashableWrapper,
        types.HashableWrapper,
        int,
        int,
        Sequence[bool],
        Sequence[str],
    ],
    children: Tuple[()],
) -> Density2DMetadata:
    """Rebuilds a `Density2DMetadata` from its auxiliary data.

    The `children` argument is unused. The fixed-pixel arrays are read from the
    `array` attribute of their wrappers.
    """
    del children
    lo, hi, solid, void, width, spacing, periodic, symmetries = aux
    return Density2DMetadata(
        lower_bound=lo,
        upper_bound=hi,
        fixed_solid=solid.array,  # type: ignore[arg-type]
        fixed_void=void.array,  # type: ignore[arg-type]
        minimum_width=width,
        minimum_spacing=spacing,
        periodic=tuple(periodic),
        symmetries=tuple(symmetries),
    )
|
200
|
+
|
201
|
+
|
202
|
+
# Register `Density2DMetadata` as a pytree node using the flatten/unflatten
# functions defined in this module.
tree_util.register_pytree_node(
    Density2DMetadata,
    flatten_func=_flatten_density_2d_metadata,
    unflatten_func=_unflatten_density_2d_metadata,
)

# NOTE(review): presumably makes the type (de)serializable through
# `totypes.json_utils` -- confirm against the totypes documentation.
json_utils.register_custom_type(Density2DMetadata)
|
@@ -0,0 +1,138 @@
|
|
1
|
+
"""Defines filter-and-project density parameterization.
|
2
|
+
|
3
|
+
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import dataclasses
|
7
|
+
|
8
|
+
import jax.numpy as jnp
|
9
|
+
from jax import tree_util
|
10
|
+
from totypes import json_utils, types
|
11
|
+
|
12
|
+
from invrs_opt.parameterization import base, transforms
|
13
|
+
|
14
|
+
|
15
|
+
@dataclasses.dataclass
class FilterProjectParams(base.ParameterizedDensity2DArray):
    """Parameters of the filter-project parameterization.

    Attributes:
        latents: The latent parameters of the parameterization.
        metadata: The metadata of the parameterization.
    """

    latents: "FilterProjectLatents"
    metadata: "FilterProjectMetadata"
|
21
|
+
|
22
|
+
|
23
|
+
@dataclasses.dataclass
class FilterProjectLatents(base.LatentsBase):
    """Latent parameters of the filter-project parameterization.

    Attributes:
        latent_density: The latent variable from which the density is obtained.
    """

    latent_density: types.Density2DArray
|
32
|
+
|
33
|
+
|
34
|
+
@dataclasses.dataclass
class FilterProjectMetadata(base.MetadataBase):
    """Metadata of the filter-project parameterization.

    Attributes:
        beta: Determines the sharpness of the thresholding operation.
    """

    beta: float
|
43
|
+
|
44
|
+
|
45
|
+
# Register the filter-project dataclasses with jax. The params and latents types
# declare their fields as data fields, while `beta` is declared as a meta field of
# the metadata type.
tree_util.register_dataclass(
    FilterProjectParams,
    data_fields=["latents", "metadata"],
    meta_fields=[],
)
tree_util.register_dataclass(
    FilterProjectLatents,
    data_fields=["latent_density"],
    meta_fields=[],
)
tree_util.register_dataclass(
    FilterProjectMetadata,
    data_fields=[],
    meta_fields=["beta"],
)
# NOTE(review): presumably makes the types (de)serializable through
# `totypes.json_utils` -- confirm against the totypes documentation.
json_utils.register_custom_type(FilterProjectParams)
json_utils.register_custom_type(FilterProjectLatents)
json_utils.register_custom_type(FilterProjectMetadata)
|
63
|
+
|
64
|
+
|
65
|
+
def filter_project(beta: float) -> base.Density2DParameterization:
    """Defines a filter-project parameterization for density arrays.

    The `Density2DArray` is represented as a latent density array that is transformed
    by,

        transformed = tanh(beta * conv(density.array, gaussian_kernel)) / tanh(beta)

    where the kernel has a full-width at half-maximum determined by the minimum width
    and spacing parameters of the `Density2DArray`.

    When the density lower and upper bounds are -1 and +1, this basic expression is
    used as written. Where the bounds differ, the density is scaled before the
    transform is applied, and then unscaled afterwards.

    Args:
        beta: Determines the sharpness of the thresholding operation.

    Returns:
        The `Density2DParameterization`.
    """

    def from_density_fn(density: types.Density2DArray) -> FilterProjectParams:
        """Return latent parameters for the given `density`."""
        array = transforms.normalized_array_from_density(density)
        array = jnp.clip(array, -1, 1)
        # Invert the projection only, i.e. `tanh(beta * latent) / tanh(beta)`; no
        # attempt is made here to invert the filtering.
        array *= jnp.tanh(beta)
        # NOTE(review): if `tanh(beta)` rounds to 1 in the working precision, pixels
        # at the bounds map to +/-inf here -- confirm that large `beta` is handled.
        latent_array = jnp.arctanh(array) / beta
        latent_array = transforms.rescale_array_for_density(latent_array, density)
        latent_density = dataclasses.replace(density, array=latent_array)
        return FilterProjectParams(
            latents=FilterProjectLatents(latent_density=latent_density),
            metadata=FilterProjectMetadata(beta=beta),
        )

    def to_density_fn(params: FilterProjectParams) -> types.Density2DArray:
        """Return a density from the latent parameters."""
        latent_density = params.latents.latent_density
        beta = params.metadata.beta

        transformed = types.symmetrize_density(latent_density)
        transformed = transforms.density_gaussian_filter_and_tanh(transformed, beta)
        # Scale to ensure that the full valid range of the density array is reachable.
        mid_value = (transformed.lower_bound + transformed.upper_bound) / 2
        transformed = tree_util.tree_map(
            lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed
        )
        return transforms.apply_fixed_pixels(transformed)

    def constraints_fn(params: FilterProjectParams) -> jnp.ndarray:
        """Computes constraints associated with the params; always zero here."""
        del params
        return jnp.asarray(0.0)

    def update_fn(
        params: FilterProjectParams,
        updates: FilterProjectParams,
        value: jnp.ndarray,
        step: int,
    ) -> FilterProjectParams:
        """Perform updates to `params` required for the given `step`.

        The updates are added to the latents; the metadata is carried over unchanged.
        The `value` and `step` arguments are unused.
        """
        del step, value
        return FilterProjectParams(
            latents=tree_util.tree_map(
                lambda a, b: a + b, params.latents, updates.latents
            ),
            metadata=params.metadata,
        )

    return base.Density2DParameterization(
        to_density=to_density_fn,
        from_density=from_density_fn,
        constraints=constraints_fn,
        update=update_fn,
    )
|