invrs-opt 0.4.0__py3-none-any.whl → 0.10.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,347 @@
1
+ """Defines a wrapper for optax optimizers.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
+ from typing import Any, Optional, Tuple
7
+
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import optax # type: ignore[import-untyped]
11
+ from jax import tree_util
12
+ from totypes import types
13
+
14
+ from invrs_opt.optimizers import base
15
+ from invrs_opt.parameterization import (
16
+ base as param_base,
17
+ filter_project,
18
+ gaussian_levelset,
19
+ pixel,
20
+ )
21
+
22
# Generic alias for an arbitrary JAX pytree.
PyTree = Any
# Optimizer state tuple, laid out in this module as
# `(step, params, latent_params, opt_state)`.
WrappedOptaxState = Tuple[int, PyTree, PyTree, PyTree]
24
+
25
+
26
def wrapped_optax(opt: optax.GradientTransformation) -> base.Optimizer:
    """Return a wrapped optax optimizer.

    Density arrays use the direct pixel parameterization, which is selected by
    passing `density_parameterization=None`. No constraint penalty is applied.

    Args:
        opt: The optax optimizer to be wrapped.

    Returns:
        The wrapped optax optimizer.
    """
    return parameterized_wrapped_optax(
        opt=opt,
        density_parameterization=None,
        penalty=0.0,
    )
31
+
32
+
33
def density_wrapped_optax(
    opt: optax.GradientTransformation,
    *,
    beta: float,
) -> base.Optimizer:
    """Wrapped optax optimizer with filter-project density parameterization.

    Each density array is represented by a latent density array. To obtain the
    density, the latent array is first convolved ("filtered") with a Gaussian
    kernel. The kernel's full-width at half-maximum equals the length scale,
    i.e. the mean of the declared minimum width and minimum spacing. A tanh
    nonlinearity then acts as a smooth threshold ("projection").

    No constraint penalty is applied with this parameterization.

    Args:
        opt: The optax optimizer to be wrapped.
        beta: Determines the sharpness of the thresholding operation.

    Returns:
        The wrapped optax optimizer.
    """
    parameterization = filter_project.filter_project(beta=beta)
    return parameterized_wrapped_optax(
        opt=opt,
        density_parameterization=parameterization,
        penalty=0.0,
    )
59
+
60
+
61
def levelset_wrapped_optax(
    opt: optax.GradientTransformation,
    *,
    penalty: float,
    length_scale_spacing_factor: float = (
        gaussian_levelset.DEFAULT_LENGTH_SCALE_SPACING_FACTOR
    ),
    length_scale_fwhm_factor: float = (
        gaussian_levelset.DEFAULT_LENGTH_SCALE_FWHM_FACTOR
    ),
    length_scale_constraint_factor: float = (
        gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR
    ),
    smoothing_factor: int = gaussian_levelset.DEFAULT_SMOOTHING_FACTOR,
    length_scale_constraint_beta: float = (
        gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA
    ),
    length_scale_constraint_weight: float = (
        gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT
    ),
    curvature_constraint_weight: float = (
        gaussian_levelset.DEFAULT_CURVATURE_CONSTRAINT_WEIGHT
    ),
    fixed_pixel_constraint_weight: float = (
        gaussian_levelset.DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT
    ),
    init_optimizer: optax.GradientTransformation = (
        gaussian_levelset.DEFAULT_INIT_OPTIMIZER
    ),
    init_steps: int = gaussian_levelset.DEFAULT_INIT_STEPS,
) -> base.Optimizer:
    """Wrapped optax optimizer with levelset density parameterization.

    In the levelset parameterization, the optimization variable associated with a
    density array is an array giving the amplitudes of Gaussian radial basis
    functions. These functions represent a levelset function over the domain of
    the density. In the levelset parameterization, gradients are nonzero only at
    the edges of features, and in general the topology of a solution does not
    change during the course of optimization.

    The spacing and full-width at half-maximum of the Gaussian basis functions
    give some amount of control over length scales. In addition, constraints
    associated with length scale, radius of curvature, and deviation from fixed
    pixels are automatically computed and penalized with a weight given by
    `penalty`. In general, this helps ensure that features in an optimized
    density array violate the specified constraints to a lesser degree. The
    constraints are based on "Analytical level set fabrication constraints for
    inverse design," by D. Vercruysse et al. (2019).

    Args:
        opt: The optax optimizer to be wrapped.
        penalty: The weight of the fabrication penalty, which combines length scale,
            curvature, and fixed pixel constraints.
        length_scale_spacing_factor: The number of levelset control points per unit of
            minimum length scale (mean of density minimum width and minimum spacing).
        length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum to
            the minimum length scale.
        length_scale_constraint_factor: Multiplies the target length scale in the
            levelset constraints. A value greater than 1 is pessimistic and drives the
            solution to have a larger length scale (relative to smaller values).
        smoothing_factor: For values greater than 1, the density is initially computed
            at higher resolution and then downsampled, yielding smoother geometries.
        length_scale_constraint_beta: Controls relaxation of the length scale
            constraint near the zero level.
        length_scale_constraint_weight: The weight of the length scale constraint in
            the overall fabrication constraint penalty.
        curvature_constraint_weight: The weight of the curvature constraint.
        fixed_pixel_constraint_weight: The weight of the fixed pixel constraint.
        init_optimizer: The optimizer used in the initialization of the levelset
            parameterization. At initialization, the latent parameters are optimized
            so that the initial parameters match the binarized initial density.
        init_steps: The number of optimization steps used in the initialization.

    Returns:
        The wrapped optax optimizer.
    """
    return parameterized_wrapped_optax(
        opt=opt,
        penalty=penalty,
        density_parameterization=gaussian_levelset.gaussian_levelset(
            length_scale_spacing_factor=length_scale_spacing_factor,
            length_scale_fwhm_factor=length_scale_fwhm_factor,
            length_scale_constraint_factor=length_scale_constraint_factor,
            smoothing_factor=smoothing_factor,
            length_scale_constraint_beta=length_scale_constraint_beta,
            length_scale_constraint_weight=length_scale_constraint_weight,
            curvature_constraint_weight=curvature_constraint_weight,
            fixed_pixel_constraint_weight=fixed_pixel_constraint_weight,
            init_optimizer=init_optimizer,
            init_steps=init_steps,
        ),
    )
152
+
153
+
154
+ # -----------------------------------------------------------------------------
155
+ # Base parameterized wrapped optax optimizer.
156
+ # -----------------------------------------------------------------------------
157
+
158
+
159
def parameterized_wrapped_optax(
    opt: optax.GradientTransformation,
    density_parameterization: Optional[param_base.Density2DParameterization],
    penalty: float,
) -> base.Optimizer:
    """Wrapped optax optimizer with specified density parameterization.

    `density_parameterization` converts each density array in the parameters
    into a latent representation, and the wrapped `opt` operates on those
    latents. At each update:
    - gradients with respect to the parameters are backpropagated to the
      latents with `jax.vjp`;
    - the gradient of a penalty formed from the parameterization's constraints
      is added;
    - the result is passed to `opt.update`.

    Args:
        opt: The optax `GradientTransformation` to be wrapped.
        density_parameterization: The parameterization to be used, or `None`. When no
            parameterization is given, the direct pixel parameterization is used for
            density arrays.
        penalty: The weight of the scalar penalty formed from the constraints of the
            parameterization.

    Returns:
        The `base.Optimizer`.
    """

    if density_parameterization is None:
        density_parameterization = pixel.pixel()

    def init_fn(params: PyTree) -> WrappedOptaxState:
        """Initializes the optimization state."""
        # Bounded leaves are clipped and density leaves are converted to latents.
        latent_params = _init_latents(params)
        # Only the latents (not the metadata) are seen by the optax optimizer.
        _, latents = param_base.partition_density_metadata(latent_params)
        return (
            0,  # step
            _params_from_latent_params(latent_params),  # params
            latent_params,  # latent params
            opt.init(latents),  # opt state
        )

    def params_fn(state: WrappedOptaxState) -> PyTree:
        """Returns the parameters for the given `state`."""
        _, params, _, _ = state
        return params

    def update_fn(
        *,
        grad: PyTree,
        value: jnp.ndarray,
        params: PyTree,
        state: WrappedOptaxState,
    ) -> WrappedOptaxState:
        """Updates the state."""
        # The `params` argument is unused; the parameters are unpacked from `state`.
        del params

        step, params, latent_params, opt_state = state
        metadata, latents = param_base.partition_density_metadata(latent_params)

        # Maps latents (with the fixed `metadata`) to the user-facing params.
        def _params_from_latents(latents: PyTree) -> PyTree:
            latent_params = param_base.combine_density_metadata(metadata, latents)
            return _params_from_latent_params(latent_params)

        # Weighted constraint penalty as a function of the latents only.
        def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
            latent_params = param_base.combine_density_metadata(metadata, latents)
            return _constraint_loss(latent_params)

        # Backpropagate `grad` (with respect to params) to the latents.
        _, vjp_fn = jax.vjp(_params_from_latents, latents)
        (latents_grad,) = vjp_fn(grad)

        if not (
            tree_util.tree_structure(latents_grad)
            == tree_util.tree_structure(latents)  # type: ignore[operator]
        ):
            raise ValueError(
                f"Tree structure of `latents_grad` was different than expected, got \n"
                f"{tree_util.tree_structure(latents_grad)} but expected \n"
                f"{tree_util.tree_structure(latents)}."
            )

        # Add the gradient of the constraint penalty to the objective gradient.
        constraint_loss_grad = jax.grad(_constraint_loss_latents)(latents)
        latents_grad = tree_util.tree_map(
            lambda a, b: a + b, latents_grad, constraint_loss_grad
        )

        latent_updates, opt_state = opt.update(latents_grad, opt_state, params=latents)
        # Density leaves are updated by the parameterization's `update`; all other
        # leaves use `optax.apply_updates`.
        latent_params = _apply_updates(
            params=latent_params,
            updates=param_base.combine_density_metadata(metadata, latent_updates),
            value=value,
            step=step,
        )
        # Keep bounded leaves within their declared bounds.
        latent_params = _clip(latent_params)
        params = _params_from_latent_params(latent_params)
        return (step + 1, params, latent_params, opt_state)

    # -------------------------------------------------------------------------
    # Functions related to the density parameterization.
    # -------------------------------------------------------------------------

    def _init_latents(params: PyTree) -> PyTree:
        # Clips each custom-type leaf and converts density arrays to latents.
        def _leaf_init_latents(leaf: Any) -> Any:
            leaf = _clip(leaf)
            if not _is_density(leaf):
                return leaf
            return density_parameterization.from_density(leaf)

        # `is_leaf` makes `BoundedArray`/`Density2DArray` instances single leaves.
        return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)

    def _params_from_latent_params(params: PyTree) -> PyTree:
        # Converts parameterized densities back to densities; other leaves pass through.
        def _leaf_params_from_latents(leaf: Any) -> Any:
            if not _is_parameterized_density(leaf):
                return leaf
            return density_parameterization.to_density(leaf)

        return tree_util.tree_map(
            _leaf_params_from_latents,
            params,
            is_leaf=_is_parameterized_density,
        )

    def _apply_updates(
        params: PyTree,
        updates: PyTree,
        value: jnp.ndarray,
        step: int,
    ) -> PyTree:
        # Applies `updates` to `params`, delegating parameterized densities to the
        # parameterization's `update` function.
        def _leaf_apply_updates(update: Any, leaf: Any) -> Any:
            if _is_parameterized_density(leaf):
                return density_parameterization.update(
                    params=leaf, updates=update, value=value, step=step
                )
            else:
                return optax.apply_updates(params=leaf, updates=update)

        return tree_util.tree_map(
            _leaf_apply_updates,
            updates,
            params,
            is_leaf=_is_parameterized_density,
        )

    # -------------------------------------------------------------------------
    # Functions related to the constraints to be minimized.
    # -------------------------------------------------------------------------

    def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
        # Returns `penalty` times the summed constraint violations of all
        # parameterized densities in `latent_params`.
        def _constraint_loss_leaf(
            leaf: param_base.ParameterizedDensity2DArray,
        ) -> jnp.ndarray:
            constraints = density_parameterization.constraints(leaf)
            # Only positive constraint values (violations) contribute, squared.
            constraints = tree_util.tree_map(
                lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
                constraints,
            )
            return jnp.sum(jnp.asarray(constraints))

        # The leading `0.0` keeps the list non-empty when there are no
        # parameterized densities.
        losses = [0.0] + [
            _constraint_loss_leaf(p)
            for p in tree_util.tree_leaves(
                latent_params, is_leaf=_is_parameterized_density
            )
            if _is_parameterized_density(p)
        ]
        return penalty * jnp.sum(jnp.asarray(losses))

    return base.Optimizer(init=init_fn, params=params_fn, update=update_fn)
318
+
319
+
320
def _is_density(leaf: Any) -> bool:
    """Return `True` if `leaf` is a density array.

    Args:
        leaf: Any object, typically a pytree node.

    Returns:
        Whether `leaf` is an instance of `types.Density2DArray`.
    """
    return isinstance(leaf, types.Density2DArray)
323
+
324
+
325
def _is_parameterized_density(leaf: Any) -> bool:
    """Return `True` if `leaf` is a parameterized density array.

    Args:
        leaf: Any object, typically a pytree node.

    Returns:
        Whether `leaf` is an instance of `param_base.ParameterizedDensity2DArray`.
    """
    return isinstance(leaf, param_base.ParameterizedDensity2DArray)
328
+
329
+
330
def _is_custom_type(leaf: Any) -> bool:
    """Return `True` if `leaf` is a `BoundedArray` or a `Density2DArray`."""
    recognized_types = (types.BoundedArray, types.Density2DArray)
    return isinstance(leaf, recognized_types)
333
+
334
+
335
def _clip(pytree: PyTree) -> PyTree:
    """Clips leaves on `pytree` to their bounds.

    A leaf is clipped only when both conditions hold:
    - it is a recognized custom type (`BoundedArray` or `Density2DArray`);
    - at least one of its bounds is set.

    In that case every array inside it is passed through `jnp.clip` with the
    leaf's `lower_bound` and `upper_bound`. All other leaves are returned
    unchanged.
    """

    def _clip_leaf(leaf: Any) -> Any:
        has_bounds = _is_custom_type(leaf) and not (
            leaf.lower_bound is None and leaf.upper_bound is None
        )
        if has_bounds:
            return tree_util.tree_map(
                lambda arr: jnp.clip(arr, leaf.lower_bound, leaf.upper_bound),
                leaf,
            )
        return leaf

    return tree_util.tree_map(_clip_leaf, pytree, is_leaf=_is_custom_type)
File without changes
@@ -0,0 +1,208 @@
1
+ """Base types for density parameterizations.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
+ import dataclasses
7
+ from typing import Any, Optional, Protocol, Sequence, Tuple
8
+
9
+ import jax.numpy as jnp
10
+ import numpy as onp
11
+ from jax import tree_util
12
+ from totypes import json_utils, partition_utils, types
13
+
14
# Either a JAX array or a NumPy array.
Array = jnp.ndarray | onp.ndarray[Any, Any]
# Generic alias for an arbitrary JAX pytree.
PyTree = Any
16
+
17
+
18
@dataclasses.dataclass
class ParameterizedDensity2DArray:
    """Latent parameters plus metadata that together represent a density array.

    Concrete parameterizations subclass this and narrow the field types.
    """

    # Trainable latent representation of the density.
    latents: "LatentsBase"
    # Parameterization-specific metadata; may be `None`.
    metadata: Optional["MetadataBase"]
24
+
25
+
26
class LatentsBase:
    """Base class for latents of a parameterized density array.

    `partition_density_metadata` selects `MetadataBase` instances as metadata.
    Latents subclass this marker class so they are not selected.
    """
30
+
31
+
32
class MetadataBase:
    """Base class for metadata of a parameterized density array.

    `partition_density_metadata` uses `isinstance(x, MetadataBase)` to decide
    which nodes are metadata.
    """
36
+
37
+
38
# Register `ParameterizedDensity2DArray` as a pytree. Both `latents` and
# `metadata` are pytree children; there are no static (meta) fields.
tree_util.register_dataclass(
    ParameterizedDensity2DArray,
    data_fields=["latents", "metadata"],
    meta_fields=[],
)
# Register with totypes' JSON utilities (presumably to enable serialization of
# instances; confirm against `json_utils`).
json_utils.register_custom_type(ParameterizedDensity2DArray)
44
+
45
+
46
def partition_density_metadata(tree: PyTree) -> Tuple[PyTree, PyTree]:
    """Splits a pytree with parameterized densities into metadata and latents.

    Args:
        tree: The pytree to be partitioned.

    Returns:
        A `(metadata, latents)` tuple. The first element holds the nodes
        selected by `isinstance(x, MetadataBase)`; the second holds the rest.
        The placeholder layout is determined by `partition_utils.partition`.
    """

    def _is_metadata(node: Any) -> bool:
        return isinstance(node, MetadataBase)

    selected, remainder = partition_utils.partition(
        tree,
        select_fn=_is_metadata,
        is_leaf=_is_metadata_or_none,
    )
    return selected, remainder
54
+
55
+
56
def combine_density_metadata(metadata: PyTree, latents: PyTree) -> PyTree:
    """Merge a metadata pytree and a latents pytree back into one pytree.

    This is used with the outputs of `partition_density_metadata`. `None`
    entries and `MetadataBase` instances are treated as leaves.
    """
    merged = partition_utils.combine(
        metadata,
        latents,
        is_leaf=_is_metadata_or_none,
    )
    return merged
59
+
60
+
61
+ def _is_metadata_or_none(leaf: Any) -> bool:
62
+ """Return `True` if `leaf` is `None` or density metadata."""
63
+ return leaf is None or isinstance(leaf, MetadataBase)
64
+
65
+
66
@dataclasses.dataclass
class Density2DParameterization:
    """Stores the four functions that define a density parameterization.

    Attributes:
        from_density: Generates the latent representation of a density array.
        to_density: Generates a density array from its latent representation.
        constraints: Computes constraints for a latent representation.
        update: Applies updates to a latent representation for a given step.
    """

    from_density: "FromDensityFn"
    to_density: "ToDensityFn"
    constraints: "ConstraintsFn"
    update: "UpdateFn"
74
+
75
+
76
class FromDensityFn(Protocol):
    """Callable protocol: compute the latent representation of a density."""

    def __call__(self, density: types.Density2DArray) -> ParameterizedDensity2DArray:
        """Return the parameterized (latent) form of `density`."""
81
+
82
+
83
class ToDensityFn(Protocol):
    """Callable protocol: recover a density array from its latents."""

    def __call__(self, params: PyTree) -> types.Density2DArray:
        """Return the density represented by `params`."""
88
+
89
+
90
class ConstraintsFn(Protocol):
    """Callable protocol: evaluate constraints on a latent density representation."""

    def __call__(self, params: PyTree) -> jnp.ndarray:
        """Return the constraint values for `params`."""
95
+
96
+
97
class UpdateFn(Protocol):
    """Callable protocol: apply the update of a parameterized density at a step."""

    def __call__(
        self,
        params: PyTree,
        updates: PyTree,
        value: jnp.ndarray,
        step: int,
    ) -> PyTree:
        """Return `params` after applying `updates` for the given `step`."""
108
+
109
+
110
@dataclasses.dataclass
class Density2DMetadata:
    """Stores the metadata of a `Density2DArray`, i.e. all fields except `array`."""

    lower_bound: float
    upper_bound: float
    fixed_solid: Optional[Array]
    fixed_void: Optional[Array]
    minimum_width: int
    minimum_spacing: int
    periodic: Sequence[bool]
    symmetries: Sequence[str]

    def __post_init__(self) -> None:
        # Normalize sequences to tuples, presumably so they are hashable when
        # stored as pytree auxiliary data.
        self.periodic = tuple(self.periodic)
        self.symmetries = tuple(self.symmetries)

    @classmethod
    def from_density(cls, density: types.Density2DArray) -> "Density2DMetadata":
        """Extract the metadata from `density`.

        Args:
            density: The density array whose non-`array` fields are extracted.

        Returns:
            A new instance of `cls` holding every field of `density` except
            `array`. Note that `dataclasses.asdict` recursively copies the
            field values.
        """
        density_metadata_dict = dataclasses.asdict(density)
        del density_metadata_dict["array"]
        # Use `cls` (rather than a hard-coded class) so subclasses construct
        # instances of themselves.
        return cls(**density_metadata_dict)
132
+
133
+
134
def _flatten_density_2d_metadata(
    metadata: Density2DMetadata,
) -> Tuple[
    Tuple[()],
    Tuple[
        float,
        float,
        types.HashableWrapper,
        types.HashableWrapper,
        int,
        int,
        Sequence[bool],
        Sequence[str],
    ],
]:
    """Flatten a `Density2DMetadata` into `(children, auxiliary data)`.

    The metadata has no pytree children. Every field goes into the auxiliary
    data, and the optional fixed-pixel arrays are wrapped in
    `types.HashableWrapper`.
    """
    wrapped_solid = types.HashableWrapper(metadata.fixed_solid)
    wrapped_void = types.HashableWrapper(metadata.fixed_void)
    aux_data = (
        metadata.lower_bound,
        metadata.upper_bound,
        wrapped_solid,
        wrapped_void,
        metadata.minimum_width,
        metadata.minimum_spacing,
        metadata.periodic,
        metadata.symmetries,
    )
    no_children: Tuple[()] = ()
    return no_children, aux_data
163
+
164
+
165
def _unflatten_density_2d_metadata(
    aux: Tuple[
        float,
        float,
        types.HashableWrapper,
        types.HashableWrapper,
        int,
        int,
        Sequence[bool],
        Sequence[str],
    ],
    children: Tuple[()],
) -> Density2DMetadata:
    """Rebuild a `Density2DMetadata` from its flattened auxiliary data.

    `children` is always empty and is ignored. The wrapped fixed-pixel arrays
    are unwrapped via their `.array` attribute.
    """
    del children
    lo, hi, solid_wrapper, void_wrapper, min_width, min_spacing, periodic, symmetries = aux
    return Density2DMetadata(
        lower_bound=lo,
        upper_bound=hi,
        fixed_solid=solid_wrapper.array,  # type: ignore[arg-type]
        fixed_void=void_wrapper.array,  # type: ignore[arg-type]
        minimum_width=min_width,
        minimum_spacing=min_spacing,
        periodic=tuple(periodic),
        symmetries=tuple(symmetries),
    )
200
+
201
+
202
# Register `Density2DMetadata` as a pytree using the custom flatten/unflatten
# functions. These place every field in the auxiliary data and produce no
# children.
tree_util.register_pytree_node(
    Density2DMetadata,
    flatten_func=_flatten_density_2d_metadata,
    unflatten_func=_unflatten_density_2d_metadata,
)

# Register with totypes' JSON utilities (presumably to enable serialization;
# confirm against `json_utils`).
json_utils.register_custom_type(Density2DMetadata)
@@ -0,0 +1,138 @@
1
+ """Defines filter-and-project density parameterization.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
+ import dataclasses
7
+
8
+ import jax.numpy as jnp
9
+ from jax import tree_util
10
+ from totypes import json_utils, types
11
+
12
+ from invrs_opt.parameterization import base, transforms
13
+
14
+
15
@dataclasses.dataclass
class FilterProjectParams(base.ParameterizedDensity2DArray):
    """Latents plus metadata of the filter-project parameterization.

    Narrows the field types of `base.ParameterizedDensity2DArray`.
    """

    # Holds the latent density array.
    latents: "FilterProjectLatents"
    # Holds the projection sharpness `beta`.
    metadata: "FilterProjectMetadata"
21
+
22
+
23
@dataclasses.dataclass
class FilterProjectLatents(base.LatentsBase):
    """Stores latent parameters for the filter-project parameterization.

    Attributes:
        latent_density: The latent variable from which the density is obtained.
    """

    latent_density: types.Density2DArray
32
+
33
+
34
@dataclasses.dataclass
class FilterProjectMetadata(base.MetadataBase):
    """Metadata of the filter-project parameterization.

    Attributes:
        beta: Sharpness of the tanh thresholding (projection) step.
    """

    beta: float
43
+
44
+
45
# Register the filter-project dataclasses as pytrees. For the params and the
# latents, all fields are pytree children. For the metadata, `beta` is a meta
# (static) field, so it is part of the tree structure rather than a leaf.
tree_util.register_dataclass(
    FilterProjectParams,
    data_fields=["latents", "metadata"],
    meta_fields=[],
)
tree_util.register_dataclass(
    FilterProjectLatents,
    data_fields=["latent_density"],
    meta_fields=[],
)
tree_util.register_dataclass(
    FilterProjectMetadata,
    data_fields=[],
    meta_fields=["beta"],
)
# Register with totypes' JSON utilities (presumably to enable serialization;
# confirm against `json_utils`).
json_utils.register_custom_type(FilterProjectParams)
json_utils.register_custom_type(FilterProjectLatents)
json_utils.register_custom_type(FilterProjectMetadata)
63
+
64
+
65
def filter_project(beta: float) -> base.Density2DParameterization:
    """Defines a filter-project parameterization for density arrays.

    The `Density2DArray` is represented as a latent density array that is
    transformed by,

        transformed = tanh(beta * conv(density.array, gaussian_kernel)) / tanh(beta)

    where the kernel has a full-width at half-maximum determined by the minimum
    width and spacing parameters of the `Density2DArray`.

    When the density lower and upper bounds are -1 and +1, this basic expression
    is used directly. Where the bounds differ, the density is scaled before the
    transform is applied, and then unscaled afterwards.

    Args:
        beta: Determines the sharpness of the thresholding operation.

    Returns:
        The `Density2DParameterization`.
    """

    def from_density_fn(density: types.Density2DArray) -> FilterProjectParams:
        """Return latent parameters for the given `density`."""
        array = transforms.normalized_array_from_density(density)
        array = jnp.clip(array, -1, 1)
        # Invert the projection `tanh(beta * x) / tanh(beta)`; the filter is not
        # inverted. Multiplying by `tanh(beta)` (< 1 in magnitude) keeps the
        # argument of `arctanh` away from +/-1, so the result stays finite.
        array *= jnp.tanh(beta)
        latent_array = jnp.arctanh(array) / beta
        latent_array = transforms.rescale_array_for_density(latent_array, density)
        latent_density = dataclasses.replace(density, array=latent_array)
        return FilterProjectParams(
            latents=FilterProjectLatents(latent_density=latent_density),
            metadata=FilterProjectMetadata(beta=beta),
        )

    def to_density_fn(params: FilterProjectParams) -> types.Density2DArray:
        """Return a density from the latent parameters."""
        latent_density = params.latents.latent_density
        beta = params.metadata.beta

        transformed = types.symmetrize_density(latent_density)
        transformed = transforms.density_gaussian_filter_and_tanh(transformed, beta)
        # Scale to ensure that the full valid range of the density array is reachable.
        mid_value = (transformed.lower_bound + transformed.upper_bound) / 2
        transformed = tree_util.tree_map(
            lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed
        )
        return transforms.apply_fixed_pixels(transformed)

    def constraints_fn(params: FilterProjectParams) -> jnp.ndarray:
        """Computes constraints associated with the params.

        The filter-project parameterization has no constraints, so this is
        always zero.
        """
        del params
        return jnp.asarray(0.0)

    def update_fn(
        params: FilterProjectParams,
        updates: FilterProjectParams,
        value: jnp.ndarray,
        step: int,
    ) -> FilterProjectParams:
        """Perform updates to `params` required for the given `step`.

        Latents are updated additively; metadata is carried over unchanged.
        """
        del step, value
        return FilterProjectParams(
            latents=tree_util.tree_map(
                lambda a, b: a + b, params.latents, updates.latents
            ),
            metadata=params.metadata,
        )

    return base.Density2DParameterization(
        to_density=to_density_fn,
        from_density=from_density_fn,
        constraints=constraints_fn,
        update=update_fn,
    )