invrs-opt 0.4.0__py3-none-any.whl → 0.10.3__py3-none-any.whl

@@ -0,0 +1,347 @@
+ """Defines a wrapper for optax optimizers.
+
+ Copyright (c) 2023 The INVRS-IO authors.
+ """
+
+ from typing import Any, Optional, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ import optax  # type: ignore[import-untyped]
+ from jax import tree_util
+ from totypes import types
+
+ from invrs_opt.optimizers import base
+ from invrs_opt.parameterization import (
+     base as param_base,
+     filter_project,
+     gaussian_levelset,
+     pixel,
+ )
+
+ PyTree = Any
+ WrappedOptaxState = Tuple[int, PyTree, PyTree, PyTree]
+
+
+ def wrapped_optax(opt: optax.GradientTransformation) -> base.Optimizer:
+     """Return a wrapped optax optimizer."""
+     return parameterized_wrapped_optax(
+         opt=opt, penalty=0.0, density_parameterization=None
+     )
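
The `base.Optimizer` returned here bundles `init`, `params`, and `update` functions (see `init_fn`, `params_fn`, and `update_fn` later in this file). A minimal sketch of the resulting optimize loop, assuming `wrapped_optax` is available in the current namespace (or re-exported from the package root) and using a toy quadratic loss in place of a real objective:

import jax
import jax.numpy as jnp
import optax

def loss_fn(params):
    # Toy quadratic loss; in practice this would be a simulation-based objective.
    return jnp.sum(params["x"] ** 2)

params = {"x": jnp.arange(3.0)}
opt = wrapped_optax(optax.adam(1e-1))

state = opt.init(params)
for _ in range(20):
    params = opt.params(state)
    value, grad = jax.value_and_grad(loss_fn)(params)
    state = opt.update(grad=grad, value=value, params=params, state=state)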
+
+
+ def density_wrapped_optax(
+     opt: optax.GradientTransformation,
+     *,
+     beta: float,
+ ) -> base.Optimizer:
+     """Wrapped optax optimizer with filter-project density parameterization.
+
+     In the filter-project density parameterization, the optimization variable
+     associated with a density array is a latent density array; the density is obtained
+     by convolving (i.e. "filtering") the latent density with a Gaussian kernel having
+     full-width at half-maximum equal to the length scale (the mean of declared minimum
+     width and minimum spacing). Then, a tanh nonlinearity is used as a smooth threshold
+     operation ("projection").
+
+     Args:
+         opt: The optax optimizer to be wrapped.
+         beta: Determines the sharpness of the thresholding operation.
+
+     Returns:
+         The wrapped optax optimizer.
+     """
+     return parameterized_wrapped_optax(
+         opt=opt,
+         penalty=0.0,
+         density_parameterization=filter_project.filter_project(beta=beta),
+     )
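
For orientation, a hedged sketch of pairing this optimizer with a density design variable; the `types.Density2DArray` keyword arguments mirror the `Density2DMetadata` fields shown later in this diff, and all values (shape, bounds, beta, learning rate) are illustrative.

import jax.numpy as jnp
import optax
from totypes import types

density = types.Density2DArray(
    array=jnp.full((40, 60), 0.5),  # initial design, mid-way between the bounds
    lower_bound=0.0,
    upper_bound=1.0,
    fixed_solid=None,
    fixed_void=None,
    minimum_width=8,
    minimum_spacing=8,
    periodic=(False, False),
    symmetries=(),
)
opt = density_wrapped_optax(optax.adam(2e-2), beta=4.0)
state = opt.init({"density": density})  # the latent density is created from `density`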
+
+
+ def levelset_wrapped_optax(
+     opt: optax.GradientTransformation,
+     *,
+     penalty: float,
+     length_scale_spacing_factor: float = (
+         gaussian_levelset.DEFAULT_LENGTH_SCALE_SPACING_FACTOR
+     ),
+     length_scale_fwhm_factor: float = (
+         gaussian_levelset.DEFAULT_LENGTH_SCALE_FWHM_FACTOR
+     ),
+     length_scale_constraint_factor: float = (
+         gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR
+     ),
+     smoothing_factor: int = gaussian_levelset.DEFAULT_SMOOTHING_FACTOR,
+     length_scale_constraint_beta: float = (
+         gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA
+     ),
+     length_scale_constraint_weight: float = (
+         gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT
+     ),
+     curvature_constraint_weight: float = (
+         gaussian_levelset.DEFAULT_CURVATURE_CONSTRAINT_WEIGHT
+     ),
+     fixed_pixel_constraint_weight: float = (
+         gaussian_levelset.DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT
+     ),
+     init_optimizer: optax.GradientTransformation = (
+         gaussian_levelset.DEFAULT_INIT_OPTIMIZER
+     ),
+     init_steps: int = gaussian_levelset.DEFAULT_INIT_STEPS,
+ ) -> base.Optimizer:
+     """Wrapped optax optimizer with levelset density parameterization.
+
+     In the levelset parameterization, the optimization variable associated with a
+     density array is an array giving the amplitudes of Gaussian radial basis functions
+     that represent a levelset function over the domain of the density. In the levelset
+     parameterization, gradients are nonzero only at the edges of features, and in
+     general the topology of a solution does not change during the course of
+     optimization.
+
+     The spacing and full-width at half-maximum of the Gaussian basis functions give
+     some amount of control over length scales. In addition, constraints associated with
+     length scale, radius of curvature, and deviation from fixed pixels are
+     automatically computed and penalized with a weight given by `penalty`. In general,
+     this helps ensure that features in an optimized density array violate the specified
+     constraints to a lesser degree. The constraints are based on "Analytical level set
+     fabrication constraints for inverse design," by D. Vercruysse et al. (2019).
+
+     Args:
+         opt: The optax optimizer to be wrapped.
+         penalty: The weight of the fabrication penalty, which combines length scale,
+             curvature, and fixed pixel constraints.
+         length_scale_spacing_factor: The number of levelset control points per unit of
+             minimum length scale (mean of density minimum width and minimum spacing).
+         length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum to
+             the minimum length scale.
+         length_scale_constraint_factor: Multiplies the target length scale in the
+             levelset constraints. A value greater than 1 is pessimistic and drives the
+             solution to have a larger length scale (relative to smaller values).
+         smoothing_factor: For values greater than 1, the density is initially computed
+             at higher resolution and then downsampled, yielding smoother geometries.
+         length_scale_constraint_beta: Controls relaxation of the length scale
+             constraint near the zero level.
+         length_scale_constraint_weight: The weight of the length scale constraint in
+             the overall fabrication constraint penalty.
+         curvature_constraint_weight: The weight of the curvature constraint.
+         fixed_pixel_constraint_weight: The weight of the fixed pixel constraint.
+         init_optimizer: The optimizer used in the initialization of the levelset
+             parameterization. At initialization, the latent parameters are optimized so
+             that the initial parameters match the binarized initial density.
+         init_steps: The number of optimization steps used in the initialization.
+
+     Returns:
+         The wrapped optax optimizer.
+     """
+     return parameterized_wrapped_optax(
+         opt=opt,
+         penalty=penalty,
+         density_parameterization=gaussian_levelset.gaussian_levelset(
+             length_scale_spacing_factor=length_scale_spacing_factor,
+             length_scale_fwhm_factor=length_scale_fwhm_factor,
+             length_scale_constraint_factor=length_scale_constraint_factor,
+             smoothing_factor=smoothing_factor,
+             length_scale_constraint_beta=length_scale_constraint_beta,
+             length_scale_constraint_weight=length_scale_constraint_weight,
+             curvature_constraint_weight=curvature_constraint_weight,
+             fixed_pixel_constraint_weight=fixed_pixel_constraint_weight,
+             init_optimizer=init_optimizer,
+             init_steps=init_steps,
+         ),
+     )
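
Since the docstring above describes how constraint violations are folded into the loss, here is a small illustrative sketch: constructing the optimizer requires only `opt` and `penalty`, and each positive constraint entry contributes its square to the penalty, mirroring `_constraint_loss` defined later in this file. The numbers are made up.

import jax.numpy as jnp
import optax

# Constructing the levelset-parameterized optimizer with its default factors.
levelset_opt = levelset_wrapped_optax(optax.adam(1e-2), penalty=1.0)

# How a toy constraint array enters the penalty: only positive entries
# (violations) contribute, each squared, then scaled by `penalty`.
constraints = jnp.asarray([-0.2, 0.0, 0.3, 1.0])
penalty_value = 1.0 * jnp.sum(jnp.maximum(constraints, 0.0) ** 2)  # 0.09 + 1.0 = 1.09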
+
+
+ # -----------------------------------------------------------------------------
+ # Base parameterized wrapped optax optimizer.
+ # -----------------------------------------------------------------------------
+
+
+ def parameterized_wrapped_optax(
+     opt: optax.GradientTransformation,
+     density_parameterization: Optional[param_base.Density2DParameterization],
+     penalty: float,
+ ) -> base.Optimizer:
+     """Wrapped optax optimizer with specified density parameterization.
+
+     Args:
+         opt: The optax `GradientTransformation` to be wrapped.
+         density_parameterization: The parameterization to be used, or `None`. When no
+             parameterization is given, the direct pixel parameterization is used for
+             density arrays.
+         penalty: The weight of the scalar penalty formed from the constraints of the
+             parameterization.
+
+     Returns:
+         The `base.Optimizer`.
+     """
+
+     if density_parameterization is None:
+         density_parameterization = pixel.pixel()
+
+     def init_fn(params: PyTree) -> WrappedOptaxState:
+         """Initializes the optimization state."""
+         latent_params = _init_latents(params)
+         _, latents = param_base.partition_density_metadata(latent_params)
+         return (
+             0,  # step
+             _params_from_latent_params(latent_params),  # params
+             latent_params,  # latent params
+             opt.init(latents),  # opt state
+         )
+
+     def params_fn(state: WrappedOptaxState) -> PyTree:
+         """Returns the parameters for the given `state`."""
+         _, params, _, _ = state
+         return params
+
+     def update_fn(
+         *,
+         grad: PyTree,
+         value: jnp.ndarray,
+         params: PyTree,
+         state: WrappedOptaxState,
+     ) -> WrappedOptaxState:
+         """Updates the state."""
+         del params
+
+         step, params, latent_params, opt_state = state
+         metadata, latents = param_base.partition_density_metadata(latent_params)
+
+         def _params_from_latents(latents: PyTree) -> PyTree:
+             latent_params = param_base.combine_density_metadata(metadata, latents)
+             return _params_from_latent_params(latent_params)
+
+         def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
+             latent_params = param_base.combine_density_metadata(metadata, latents)
+             return _constraint_loss(latent_params)
+
+         _, vjp_fn = jax.vjp(_params_from_latents, latents)
+         (latents_grad,) = vjp_fn(grad)
+
+         if not (
+             tree_util.tree_structure(latents_grad)
+             == tree_util.tree_structure(latents)  # type: ignore[operator]
+         ):
+             raise ValueError(
+                 f"Tree structure of `latents_grad` was different than expected, got \n"
+                 f"{tree_util.tree_structure(latents_grad)} but expected \n"
+                 f"{tree_util.tree_structure(latents)}."
+             )
+
+         constraint_loss_grad = jax.grad(_constraint_loss_latents)(latents)
+         latents_grad = tree_util.tree_map(
+             lambda a, b: a + b, latents_grad, constraint_loss_grad
+         )
+
+         latent_updates, opt_state = opt.update(latents_grad, opt_state, params=latents)
+         latent_params = _apply_updates(
+             params=latent_params,
+             updates=param_base.combine_density_metadata(metadata, latent_updates),
+             value=value,
+             step=step,
+         )
+         latent_params = _clip(latent_params)
+         params = _params_from_latent_params(latent_params)
+         return (step + 1, params, latent_params, opt_state)
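
The chain-rule step above is the heart of `update_fn`: the gradient supplied with respect to the user-visible params is pulled back onto the latents with `jax.vjp` before the optax update is applied. A self-contained sketch of that pattern, with a toy squaring transform standing in for the density parameterization:

import jax
import jax.numpy as jnp

def params_from_latents(latents):
    # Stand-in for `_params_from_latents`: params are a function of the latents.
    return latents ** 2

latents = jnp.asarray([1.0, 2.0, 3.0])
params, vjp_fn = jax.vjp(params_from_latents, latents)
grad_wrt_params = 2.0 * params          # pretend the loss is sum(params**2)
(grad_wrt_latents,) = vjp_fn(grad_wrt_params)
# grad_wrt_latents == dloss/dparams * dparams/dlatents == (2 * params) * (2 * latents)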
+
+     # -------------------------------------------------------------------------
+     # Functions related to the density parameterization.
+     # -------------------------------------------------------------------------
+
+     def _init_latents(params: PyTree) -> PyTree:
+         def _leaf_init_latents(leaf: Any) -> Any:
+             leaf = _clip(leaf)
+             if not _is_density(leaf):
+                 return leaf
+             return density_parameterization.from_density(leaf)
+
+         return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
+
+     def _params_from_latent_params(params: PyTree) -> PyTree:
+         def _leaf_params_from_latents(leaf: Any) -> Any:
+             if not _is_parameterized_density(leaf):
+                 return leaf
+             return density_parameterization.to_density(leaf)
+
+         return tree_util.tree_map(
+             _leaf_params_from_latents,
+             params,
+             is_leaf=_is_parameterized_density,
+         )
+
+     def _apply_updates(
+         params: PyTree,
+         updates: PyTree,
+         value: jnp.ndarray,
+         step: int,
+     ) -> PyTree:
+         def _leaf_apply_updates(update: Any, leaf: Any) -> Any:
+             if _is_parameterized_density(leaf):
+                 return density_parameterization.update(
+                     params=leaf, updates=update, value=value, step=step
+                 )
+             else:
+                 return optax.apply_updates(params=leaf, updates=update)
+
+         return tree_util.tree_map(
+             _leaf_apply_updates,
+             updates,
+             params,
+             is_leaf=_is_parameterized_density,
+         )
+
+     # -------------------------------------------------------------------------
+     # Functions related to the constraints to be minimized.
+     # -------------------------------------------------------------------------
+
+     def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
+         def _constraint_loss_leaf(
+             leaf: param_base.ParameterizedDensity2DArray,
+         ) -> jnp.ndarray:
+             constraints = density_parameterization.constraints(leaf)
+             constraints = tree_util.tree_map(
+                 lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
+                 constraints,
+             )
+             return jnp.sum(jnp.asarray(constraints))
+
+         losses = [0.0] + [
+             _constraint_loss_leaf(p)
+             for p in tree_util.tree_leaves(
+                 latent_params, is_leaf=_is_parameterized_density
+             )
+             if _is_parameterized_density(p)
+         ]
+         return penalty * jnp.sum(jnp.asarray(losses))
+
+     return base.Optimizer(init=init_fn, params=params_fn, update=update_fn)
+
+
+ def _is_density(leaf: Any) -> Any:
+     """Return `True` if `leaf` is a density array."""
+     return isinstance(leaf, types.Density2DArray)
+
+
+ def _is_parameterized_density(leaf: Any) -> Any:
+     """Return `True` if `leaf` is a parameterized density array."""
+     return isinstance(leaf, param_base.ParameterizedDensity2DArray)
+
+
+ def _is_custom_type(leaf: Any) -> bool:
+     """Return `True` if `leaf` is a recognized custom type."""
+     return isinstance(leaf, (types.BoundedArray, types.Density2DArray))
+
+
+ def _clip(pytree: PyTree) -> PyTree:
+     """Clips leaves on `pytree` to their bounds."""
+
+     def _clip_fn(leaf: Any) -> Any:
+         if not _is_custom_type(leaf):
+             return leaf
+         if leaf.lower_bound is None and leaf.upper_bound is None:
+             return leaf
+         return tree_util.tree_map(
+             lambda x: jnp.clip(x, leaf.lower_bound, leaf.upper_bound), leaf
+         )
+
+     return tree_util.tree_map(_clip_fn, pytree, is_leaf=_is_custom_type)
@@ -0,0 +1,208 @@
+ """Base types for density parameterizations.
+
+ Copyright (c) 2023 The INVRS-IO authors.
+ """
+
+ import dataclasses
+ from typing import Any, Optional, Protocol, Sequence, Tuple
+
+ import jax.numpy as jnp
+ import numpy as onp
+ from jax import tree_util
+ from totypes import json_utils, partition_utils, types
+
+ Array = jnp.ndarray | onp.ndarray[Any, Any]
+ PyTree = Any
+
+
+ @dataclasses.dataclass
+ class ParameterizedDensity2DArray:
+     """Stores latents and metadata for a parameterized density array."""
+
+     latents: "LatentsBase"
+     metadata: Optional["MetadataBase"]
+
+
+ class LatentsBase:
+     """Base class for latents of a parameterized density array."""
+
+     pass
+
+
+ class MetadataBase:
+     """Base class for metadata of a parameterized density array."""
+
+     pass
+
+
+ tree_util.register_dataclass(
+     ParameterizedDensity2DArray,
+     data_fields=["latents", "metadata"],
+     meta_fields=[],
+ )
+ json_utils.register_custom_type(ParameterizedDensity2DArray)
+
+
+ def partition_density_metadata(tree: PyTree) -> Tuple[PyTree, PyTree]:
+     """Splits a pytree with parameterized densities into metadata and latents."""
+     metadata, latents = partition_utils.partition(
+         tree,
+         select_fn=lambda x: isinstance(x, MetadataBase),
+         is_leaf=_is_metadata_or_none,
+     )
+     return metadata, latents
+
+
+ def combine_density_metadata(metadata: PyTree, latents: PyTree) -> PyTree:
+     """Combines pytrees containing metadata and latents."""
+     return partition_utils.combine(metadata, latents, is_leaf=_is_metadata_or_none)
+
+
+ def _is_metadata_or_none(leaf: Any) -> bool:
+     """Return `True` if `leaf` is `None` or density metadata."""
+     return leaf is None or isinstance(leaf, MetadataBase)
+
+
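
A small sketch of the round trip these two helpers perform, intended to be read in this module's namespace; `ToyMetadata` is a hypothetical metadata type, and the exact placeholder `partition_utils` leaves in non-selected positions is an implementation detail of totypes.

import dataclasses
from jax import tree_util

@dataclasses.dataclass
class ToyMetadata(MetadataBase):
    beta: float

tree_util.register_dataclass(ToyMetadata, data_fields=[], meta_fields=["beta"])

tree = {"weight": 1.0, "meta": ToyMetadata(beta=2.0)}
metadata, latents = partition_density_metadata(tree)
# `metadata` keeps the ToyMetadata leaf, `latents` keeps the numeric leaf; combining
# the two partitions reproduces the original tree.
restored = combine_density_metadata(metadata, latents)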
+ @dataclasses.dataclass
+ class Density2DParameterization:
+     """Stores the `(from_density, to_density, constraints, update)` functions."""
+
+     from_density: "FromDensityFn"
+     to_density: "ToDensityFn"
+     constraints: "ConstraintsFn"
+     update: "UpdateFn"
+
+
+ class FromDensityFn(Protocol):
+     """Generates the latent representation of a density array."""
+
+     def __call__(self, density: types.Density2DArray) -> ParameterizedDensity2DArray:
+         ...
+
+
+ class ToDensityFn(Protocol):
+     """Generates a density from its latent representation."""
+
+     def __call__(self, params: PyTree) -> types.Density2DArray:
+         ...
+
+
+ class ConstraintsFn(Protocol):
+     """Computes constraints for a latent representation of a density array."""
+
+     def __call__(self, params: PyTree) -> jnp.ndarray:
+         ...
+
+
+ class UpdateFn(Protocol):
+     """Performs the required update of a parameterized density for the given step."""
+
+     def __call__(
+         self,
+         params: PyTree,
+         updates: PyTree,
+         value: jnp.ndarray,
+         step: int,
+     ) -> PyTree:
+         ...
+
+
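
Concrete parameterizations (the pixel, filter-project, and Gaussian levelset modules referenced elsewhere in this release) implement these four protocols. As a hedged sketch in this module's namespace, a trivial identity-style parameterization might look as follows; `IdentityLatents` and `identity_parameterization` are illustrative names, not part of the package.

import dataclasses
import jax.numpy as jnp
from jax import tree_util
from totypes import types

@dataclasses.dataclass
class IdentityLatents(LatentsBase):
    density: types.Density2DArray

tree_util.register_dataclass(IdentityLatents, data_fields=["density"], meta_fields=[])

def identity_parameterization() -> Density2DParameterization:
    def from_density(density: types.Density2DArray) -> ParameterizedDensity2DArray:
        # The latent representation is simply the density itself; no metadata needed.
        return ParameterizedDensity2DArray(latents=IdentityLatents(density), metadata=None)

    def to_density(params: ParameterizedDensity2DArray) -> types.Density2DArray:
        return params.latents.density

    def constraints(params: ParameterizedDensity2DArray) -> jnp.ndarray:
        del params
        return jnp.asarray(0.0)  # no parameterization-specific constraints

    def update(params, updates, value, step):
        del value, step
        return tree_util.tree_map(lambda p, u: p + u, params, updates)

    return Density2DParameterization(
        from_density=from_density,
        to_density=to_density,
        constraints=constraints,
        update=update,
    )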
+ @dataclasses.dataclass
+ class Density2DMetadata:
+     """Stores the metadata of a `Density2DArray`."""
+
+     lower_bound: float
+     upper_bound: float
+     fixed_solid: Optional[Array]
+     fixed_void: Optional[Array]
+     minimum_width: int
+     minimum_spacing: int
+     periodic: Sequence[bool]
+     symmetries: Sequence[str]
+
+     def __post_init__(self) -> None:
+         self.periodic = tuple(self.periodic)
+         self.symmetries = tuple(self.symmetries)
+
+     @classmethod
+     def from_density(cls, density: types.Density2DArray) -> "Density2DMetadata":
+         density_metadata_dict = dataclasses.asdict(density)
+         del density_metadata_dict["array"]
+         return Density2DMetadata(**density_metadata_dict)
+
+
+ def _flatten_density_2d_metadata(
+     metadata: Density2DMetadata,
+ ) -> Tuple[
+     Tuple[()],
+     Tuple[
+         float,
+         float,
+         types.HashableWrapper,
+         types.HashableWrapper,
+         int,
+         int,
+         Sequence[bool],
+         Sequence[str],
+     ],
+ ]:
+     """Flattens a `Density2DMetadata` into children and auxiliary data."""
+     return (
+         (),
+         (
+             metadata.lower_bound,
+             metadata.upper_bound,
+             types.HashableWrapper(metadata.fixed_solid),
+             types.HashableWrapper(metadata.fixed_void),
+             metadata.minimum_width,
+             metadata.minimum_spacing,
+             metadata.periodic,
+             metadata.symmetries,
+         ),
+     )
+
+
+ def _unflatten_density_2d_metadata(
+     aux: Tuple[
+         float,
+         float,
+         types.HashableWrapper,
+         types.HashableWrapper,
+         int,
+         int,
+         Sequence[bool],
+         Sequence[str],
+     ],
+     children: Tuple[()],
+ ) -> Density2DMetadata:
+     """Unflattens a flattened `Density2DMetadata`."""
+     del children
+     (
+         lower_bound,
+         upper_bound,
+         wrapped_fixed_solid,
+         wrapped_fixed_void,
+         minimum_width,
+         minimum_spacing,
+         periodic,
+         symmetries,
+     ) = aux
+     return Density2DMetadata(
+         lower_bound=lower_bound,
+         upper_bound=upper_bound,
+         fixed_solid=wrapped_fixed_solid.array,  # type: ignore[arg-type]
+         fixed_void=wrapped_fixed_void.array,  # type: ignore[arg-type]
+         minimum_width=minimum_width,
+         minimum_spacing=minimum_spacing,
+         periodic=tuple(periodic),
+         symmetries=tuple(symmetries),
+     )
+
+
+ tree_util.register_pytree_node(
+     Density2DMetadata,
+     flatten_func=_flatten_density_2d_metadata,
+     unflatten_func=_unflatten_density_2d_metadata,
+ )
+
+ json_utils.register_custom_type(Density2DMetadata)
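
Because `_flatten_density_2d_metadata` returns an empty children tuple, a registered `Density2DMetadata` contributes no leaves to a pytree; it travels in the treedef as auxiliary data and is therefore untouched by gradients and optax updates. A quick hedged check of that behavior (field values illustrative):

from jax import tree_util

metadata = Density2DMetadata(
    lower_bound=0.0,
    upper_bound=1.0,
    fixed_solid=None,
    fixed_void=None,
    minimum_width=5,
    minimum_spacing=5,
    periodic=(False, False),
    symmetries=(),
)
leaves, treedef = tree_util.tree_flatten(metadata)
assert leaves == []  # all of the metadata lives in the auxiliary data
# tree_util.tree_unflatten(treedef, leaves) reconstructs the same metadata.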
@@ -0,0 +1,138 @@
+ """Defines filter-and-project density parameterization.
+
+ Copyright (c) 2023 The INVRS-IO authors.
+ """
+
+ import dataclasses
+
+ import jax.numpy as jnp
+ from jax import tree_util
+ from totypes import json_utils, types
+
+ from invrs_opt.parameterization import base, transforms
+
+
+ @dataclasses.dataclass
+ class FilterProjectParams(base.ParameterizedDensity2DArray):
+     """Stores parameters for the filter-project parameterization."""
+
+     latents: "FilterProjectLatents"
+     metadata: "FilterProjectMetadata"
+
+
+ @dataclasses.dataclass
+ class FilterProjectLatents(base.LatentsBase):
+     """Stores latent parameters for the filter-project parameterization.
+
+     Attributes:
+         latent_density: The latent variable from which the density is obtained.
+     """
+
+     latent_density: types.Density2DArray
+
+
+ @dataclasses.dataclass
+ class FilterProjectMetadata(base.MetadataBase):
+     """Stores metadata for the filter-project parameterization.
+
+     Attributes:
+         beta: Determines the sharpness of the thresholding operation.
+     """
+
+     beta: float
+
+
+ tree_util.register_dataclass(
+     FilterProjectParams,
+     data_fields=["latents", "metadata"],
+     meta_fields=[],
+ )
+ tree_util.register_dataclass(
+     FilterProjectLatents,
+     data_fields=["latent_density"],
+     meta_fields=[],
+ )
+ tree_util.register_dataclass(
+     FilterProjectMetadata,
+     data_fields=[],
+     meta_fields=["beta"],
+ )
+ json_utils.register_custom_type(FilterProjectParams)
+ json_utils.register_custom_type(FilterProjectLatents)
+ json_utils.register_custom_type(FilterProjectMetadata)
+
+
+ def filter_project(beta: float) -> base.Density2DParameterization:
+     """Defines a filter-project parameterization for density arrays.
+
+     The `Density2DArray` is represented as a latent density array that is
+     transformed by,
+
+         transformed = tanh(beta * conv(density.array, gaussian_kernel)) / tanh(beta)
+
+     where the kernel has a full-width at half-maximum determined by the minimum width
+     and spacing parameters of the `Density2DArray`.
+
+     When the density lower and upper bounds are -1 and +1, this basic expression is
+     used directly. Where the bounds differ, the density is scaled before the transform
+     is applied, and then unscaled afterwards.
+
+     Args:
+         beta: Determines the sharpness of the thresholding operation.
+
+     Returns:
+         The `Density2DParameterization`.
+     """
+
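
As a standalone numeric sketch of the transform quoted in the docstring above, the following uses a hand-rolled Gaussian kernel in place of the one `transforms.density_gaussian_filter_and_tanh` constructs internally; the kernel size, FWHM, and `beta` are illustrative.

import jax.numpy as jnp
from jax.scipy import signal

def gaussian_kernel(fwhm: float, size: int) -> jnp.ndarray:
    # Normalized 2D Gaussian with the given full-width at half-maximum.
    sigma = fwhm / (2 * jnp.sqrt(2 * jnp.log(2.0)))
    x = jnp.arange(size) - (size - 1) / 2
    kernel_1d = jnp.exp(-(x ** 2) / (2 * sigma ** 2))
    kernel = jnp.outer(kernel_1d, kernel_1d)
    return kernel / jnp.sum(kernel)

beta = 2.0
latent = jnp.sign(jnp.linspace(-1.0, 1.0, 32))[:, jnp.newaxis] * jnp.ones((32, 32))
filtered = signal.convolve2d(latent, gaussian_kernel(fwhm=4.0, size=9), mode="same")
projected = jnp.tanh(beta * filtered) / jnp.tanh(beta)
# `projected` stays within (-1, 1); larger beta pushes values toward the bounds
# while the Gaussian filter smooths feature edges, which is the filter-project idea.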
+     def from_density_fn(density: types.Density2DArray) -> FilterProjectParams:
+         """Return latent parameters for the given `density`."""
+         array = transforms.normalized_array_from_density(density)
+         array = jnp.clip(array, -1, 1)
+         array *= jnp.tanh(beta)
+         latent_array = jnp.arctanh(array) / beta
+         latent_array = transforms.rescale_array_for_density(latent_array, density)
+         latent_density = dataclasses.replace(density, array=latent_array)
+         return FilterProjectParams(
+             latents=FilterProjectLatents(latent_density=latent_density),
+             metadata=FilterProjectMetadata(beta=beta),
+         )
+
+     def to_density_fn(params: FilterProjectParams) -> types.Density2DArray:
+         """Return a density from the latent parameters."""
+         latent_density = params.latents.latent_density
+         beta = params.metadata.beta
+
+         transformed = types.symmetrize_density(latent_density)
+         transformed = transforms.density_gaussian_filter_and_tanh(transformed, beta)
+         # Scale to ensure that the full valid range of the density array is reachable.
+         mid_value = (transformed.lower_bound + transformed.upper_bound) / 2
+         transformed = tree_util.tree_map(
+             lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed
+         )
+         return transforms.apply_fixed_pixels(transformed)
+
+     def constraints_fn(params: FilterProjectParams) -> jnp.ndarray:
+         """Computes constraints associated with the params."""
+         del params
+         return jnp.asarray(0.0)
+
+     def update_fn(
+         params: FilterProjectParams,
+         updates: FilterProjectParams,
+         value: jnp.ndarray,
+         step: int,
+     ) -> FilterProjectParams:
+         """Perform updates to `params` required for the given `step`."""
+         del step, value
+         return FilterProjectParams(
+             latents=tree_util.tree_map(
+                 lambda a, b: a + b, params.latents, updates.latents
+             ),
+             metadata=params.metadata,
+         )
+
+     return base.Density2DParameterization(
+         to_density=to_density_fn,
+         from_density=from_density_fn,
+         constraints=constraints_fn,
+         update=update_fn,
+     )
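
Finally, a hedged round-trip sketch of the parameterization returned above, assuming the remaining `Density2DArray` fields take their totypes defaults; the array values, bounds, length scales, and `beta` are illustrative.

import jax.numpy as jnp
from totypes import types

density = types.Density2DArray(
    array=jnp.full((32, 32), 0.5),
    lower_bound=0.0,
    upper_bound=1.0,
    minimum_width=6,
    minimum_spacing=6,
)
parameterization = filter_project(beta=2.0)
params = parameterization.from_density(density)       # FilterProjectParams
recovered = parameterization.to_density(params)       # types.Density2DArray
constraints = parameterization.constraints(params)    # always 0 for filter-project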