invrs-opt 0.6.0__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invrs_opt/__init__.py CHANGED
@@ -3,12 +3,19 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.6.0"
6
+ __version__ = "v0.7.1"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
- from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
10
- from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
11
- from invrs_opt.wrapped_optax.wrapped_optax import (
9
+ from invrs_opt import parameterization as parameterization
10
+
11
+ from invrs_opt.optimizers.lbfgsb import (
12
+ density_lbfgsb as density_lbfgsb,
13
+ lbfgsb as lbfgsb,
14
+ levelset_lbfgsb as levelset_lbfgsb,
15
+ )
16
+
17
+ from invrs_opt.optimizers.wrapped_optax import (
12
18
  density_wrapped_optax as density_wrapped_optax,
19
+ levelset_wrapped_optax as levelset_wrapped_optax,
20
+ wrapped_optax as wrapped_optax,
13
21
  )
14
- from invrs_opt.wrapped_optax.wrapped_optax import wrapped_optax as wrapped_optax
@@ -10,8 +10,8 @@ from typing import Any, Dict, Optional
10
10
  import requests
11
11
  from totypes import json_utils
12
12
 
13
- from invrs_opt import base
14
13
  from invrs_opt.experimental import labels
14
+ from invrs_opt.optimizers import base
15
15
 
16
16
  PyTree = Any
17
17
  StateToken = str
@@ -4,6 +4,7 @@ Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
6
  import dataclasses
7
+ import inspect
7
8
  from typing import Any, Protocol
8
9
 
9
10
  import optax # type: ignore[import-untyped]
@@ -49,6 +50,11 @@ class Optimizer:
49
50
  update: UpdateFn
50
51
 
51
52
 
52
- # TODO: consider programatically registering all optax states here.
53
- json_utils.register_custom_type(optax.EmptyState)
54
- json_utils.register_custom_type(optax.ScaleByAdamState)
53
+ # Register all optax state types for serialization.
54
+ optax_types = {}
55
+ for name, obj in inspect.getmembers(optax):
56
+ if name.endswith("State") and isinstance(obj, type):
57
+ optax_types[obj] = True
58
+
59
+ for obj in optax_types.keys():
60
+ json_utils.register_custom_type(obj)
@@ -5,18 +5,25 @@ Copyright (c) 2023 The INVRS-IO authors.
5
5
 
6
6
  import copy
7
7
  import dataclasses
8
- from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
8
+ from typing import Any, Dict, Optional, Sequence, Tuple, Union
9
9
 
10
10
  import jax
11
11
  import jax.numpy as jnp
12
12
  import numpy as onp
13
+ import optax # type: ignore[import-untyped]
13
14
  from jax import flatten_util, tree_util
14
15
  from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
15
16
  _lbfgsb as scipy_lbfgsb,
16
17
  )
17
18
  from totypes import types
18
19
 
19
- from invrs_opt import base, transform
20
+ from invrs_opt.optimizers import base
21
+ from invrs_opt.parameterization import (
22
+ base as parameterization_base,
23
+ filter_project,
24
+ gaussian_levelset,
25
+ pixel,
26
+ )
20
27
 
21
28
  NDArray = onp.ndarray[Any, Any]
22
29
  PyTree = Any
@@ -34,10 +41,10 @@ UPDATE_IPRINT = -1
34
41
 
35
42
  # Maximum value for the `maxcor` parameter in the L-BFGS-B scheme.
36
43
  MAXCOR_MAX_VALUE = 100
37
- MAXCOR_DEFAULT = 20
38
- LINE_SEARCH_MAX_STEPS_DEFAULT = 100
39
- FTOL_DEFAULT = 0.0
40
- GTOL_DEFAULT = 0.0
44
+ DEFAULT_MAXCOR = 20
45
+ DEFAULT_LINE_SEARCH_MAX_STEPS = 100
46
+ DEFAULT_FTOL = 0.0
47
+ DEFAULT_GTOL = 0.0
41
48
 
42
49
  # Maps bound scenarios to integers.
43
50
  BOUNDS_MAP: Dict[Tuple[bool, bool], int] = {
@@ -51,175 +58,225 @@ FORTRAN_INT = scipy_lbfgsb.types.intvar.dtype
51
58
 
52
59
 
53
60
  def lbfgsb(
54
- maxcor: int = MAXCOR_DEFAULT,
55
- line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT,
56
- ftol: float = FTOL_DEFAULT,
57
- gtol: float = GTOL_DEFAULT,
61
+ *,
62
+ maxcor: int = DEFAULT_MAXCOR,
63
+ line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
64
+ ftol: float = DEFAULT_FTOL,
65
+ gtol: float = DEFAULT_GTOL,
58
66
  ) -> base.Optimizer:
59
- """Return an optimizer implementing the standard L-BFGS-B algorithm.
67
+ """Optimizer implementing the standard L-BFGS-B algorithm.
60
68
 
61
- This optimizer wraps scipy's implementation of the algorithm, and provides
62
- a jax-style API to the scheme. The optimizer works with custom types such
63
- as the `BoundedArray` to constrain the optimization variable.
69
+ The standard L-BFGS-B algorithm uses the direct pixel parameterization for density
70
+ arrays, which simply enforces that values are between the declared upper and lower
71
+ bounds of the density.
64
72
 
65
- Example usage is as follows:
66
-
67
- def fn(x):
68
- leaves_sum_sq = [jnp.sum(y)**2 for y in tree_util.tree_leaves(x)]
69
- return jnp.sum(jnp.asarray(leaves_sum_sq))
70
-
71
- x0 = {
72
- "a": jnp.ones((3,)),
73
- "b": BoundedArray(
74
- value=-jnp.ones((2, 5)),
75
- lower_bound=-5,
76
- upper_bound=5,
77
- ),
78
- }
79
- opt = lbfgsb(maxcor=20, line_search_max_steps=100)
80
- state = opt.init(x0)
81
- for _ in range(10):
82
- x = opt.params(state)
83
- value, grad = jax.value_and_grad(fn)(x)
84
- state = opt.update(grad, value, state)
85
-
86
- While the algorithm can work with pytrees of jax arrays, numpy arrays can
87
- also be used. Thus, e.g. the optimizer can directly be used with autograd.
88
-
89
- When the optimization has converged (according to `ftol` or `gtol` criteria), the
90
- optimizer simply returns the parameters which obtained the converged result. The
73
+ When an optimization is determined to have converged (by `ftol` or `gtol` criteria)
74
+ the optimizer `params` function will simply return the optimal parameters. The
91
75
  convergence can be queried by `is_converged(state)`.
92
76
 
93
77
  Args:
94
- maxcor: The maximum number of variable metric corrections used to define
95
- the limited memory matrix, in the L-BFGS-B scheme.
78
+ maxcor: The maximum number of variable metric corrections used to define the
79
+ limited memory matrix, in the L-BFGS-B scheme.
96
80
  line_search_max_steps: The maximum number of steps in the line search.
97
- ftol: Tolerance for stopping criteria based on function values. See scipy
98
- documentation for details.
99
- gtol: Tolerance for stopping criteria based on gradient.
81
+ ftol: Convergence criteria based on function values. See scipy documentation
82
+ for details.
83
+ gtol: Convergence criteria based on gradient.
100
84
 
101
85
  Returns:
102
- The `base.Optimizer`.
86
+ The `Optimizer` implementing the L-BFGS-B optimizer.
103
87
  """
104
- return transformed_lbfgsb(
88
+ return parameterized_lbfgsb(
89
+ density_parameterization=None,
90
+ penalty=0.0,
105
91
  maxcor=maxcor,
106
92
  line_search_max_steps=line_search_max_steps,
107
93
  ftol=ftol,
108
94
  gtol=gtol,
109
- transform_fn=lambda x: x,
110
- initialize_latent_fn=lambda x: x,
111
95
  )
112
96
 
113
97
 
114
98
  def density_lbfgsb(
99
+ *,
115
100
  beta: float,
116
- maxcor: int = MAXCOR_DEFAULT,
117
- line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT,
118
- ftol: float = FTOL_DEFAULT,
119
- gtol: float = GTOL_DEFAULT,
101
+ maxcor: int = DEFAULT_MAXCOR,
102
+ line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
103
+ ftol: float = DEFAULT_FTOL,
104
+ gtol: float = DEFAULT_GTOL,
120
105
  ) -> base.Optimizer:
121
- """Return an L-BFGS-B optimizer with additional transforms for density arrays.
122
-
123
- Parameters that are of type `DensityArray2D` are represented as latent parameters
124
- that are transformed (in the case where lower and upper bounds are `(-1, 1)`) by,
106
+ """Optimizer using L-BFGS-B algorithm with filter-project density parameterization.
125
107
 
126
- transformed = tanh(beta * conv(density.array, gaussian_kernel)) / tanh(beta)
108
+ In the filter-project density parameterization, the optimization variable
109
+ associated with a density array is a latent density array; the density is obtained
110
+ by convolving (i.e. "filtering") the latent density with a Gaussian kernel having
111
+ full-width at half-maximum equal to the length scale (the mean of declared minimum
112
+ width and minimum spacing). Then, a tanh nonlinearity is used as a smooth threshold
113
+ operation ("projection").
127
114
 
128
- where the kernel has a full-width at half-maximum determined by the minimum width
129
- and spacing parameters of the `DensityArray2D`. Where the bounds differ, the
130
- density is scaled before the transform is applied, and then unscaled afterwards.
131
-
132
- When the optimization has converged (according to `ftol` or `gtol` criteria), the
133
- optimizer simply returns the parameters which obtained the converged result. The
115
+ When an optimization is determined to have converged (by `ftol` or `gtol` criteria)
116
+ the optimizer `params` function will simply return the optimal parameters. The
134
117
  convergence can be queried by `is_converged(state)`.
135
118
 
136
119
  Args:
137
- beta: Determines the steepness of the thresholding.
138
- maxcor: The maximum number of variable metric corrections used to define
139
- the limited memory matrix, in the L-BFGS-B scheme.
120
+ beta: Determines the sharpness of the thresholding operation.
121
+ maxcor: The maximum number of variable metric corrections used to define the
122
+ limited memory matrix, in the L-BFGS-B scheme.
140
123
  line_search_max_steps: The maximum number of steps in the line search.
141
- ftol: Tolerance for stopping criteria based on function values. See scipy
142
- documentation for details.
143
- gtol: Tolerance for stopping criteria based on gradient.
124
+ ftol: Convergence criteria based on function values. See scipy documentation
125
+ for details.
126
+ gtol: Convergence criteria based on gradient.
144
127
 
145
128
  Returns:
146
- The `base.Optimizer`.
129
+ The `Optimizer` implementing the L-BFGS-B optimizer.
147
130
  """
131
+ return parameterized_lbfgsb(
132
+ density_parameterization=filter_project.filter_project(beta=beta),
133
+ penalty=0.0,
134
+ maxcor=maxcor,
135
+ line_search_max_steps=line_search_max_steps,
136
+ ftol=ftol,
137
+ gtol=gtol,
138
+ )
148
139
 
149
- def transform_fn(tree: PyTree) -> PyTree:
150
- return tree_util.tree_map(
151
- lambda x: transform_density(x) if _is_density(x) else x,
152
- tree,
153
- is_leaf=_is_density,
154
- )
155
140
 
156
- def initialize_latent_fn(tree: PyTree) -> PyTree:
157
- return tree_util.tree_map(
158
- lambda x: initialize_latent_density(x) if _is_density(x) else x,
159
- tree,
160
- is_leaf=_is_density,
161
- )
141
+ def levelset_lbfgsb(
142
+ *,
143
+ penalty: float,
144
+ length_scale_spacing_factor: float = (
145
+ gaussian_levelset.DEFAULT_LENGTH_SCALE_SPACING_FACTOR
146
+ ),
147
+ length_scale_fwhm_factor: float = (
148
+ gaussian_levelset.DEFAULT_LENGTH_SCALE_FWHM_FACTOR
149
+ ),
150
+ length_scale_constraint_factor: float = (
151
+ gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR
152
+ ),
153
+ smoothing_factor: int = gaussian_levelset.DEFAULT_SMOOTHING_FACTOR,
154
+ length_scale_constraint_beta: float = (
155
+ gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA
156
+ ),
157
+ length_scale_constraint_weight: float = (
158
+ gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT
159
+ ),
160
+ curvature_constraint_weight: float = (
161
+ gaussian_levelset.DEFAULT_CURVATURE_CONSTRAINT_WEIGHT
162
+ ),
163
+ fixed_pixel_constraint_weight: float = (
164
+ gaussian_levelset.DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT
165
+ ),
166
+ init_optimizer: optax.GradientTransformation = (
167
+ gaussian_levelset.DEFAULT_INIT_OPTIMIZER
168
+ ),
169
+ init_steps: int = gaussian_levelset.DEFAULT_INIT_STEPS,
170
+ maxcor: int = DEFAULT_MAXCOR,
171
+ line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
172
+ ftol: float = DEFAULT_FTOL,
173
+ gtol: float = DEFAULT_GTOL,
174
+ ) -> base.Optimizer:
175
+ """Optimizer using L-BFGS-B algorithm with levelset density parameterization.
176
+
177
+ In the levelset parameterization, the optimization variable associated with a
178
+ density array is an array giving the amplitudes of Gaussian radial basis functions
179
+ that represent a levelset function over the domain of the density. In the levelset
180
+ parameterization, gradients are nonzero only at the edges of features, and in
181
+ general the topology of a solution does not change during the course of
182
+ optimization.
183
+
184
+ The spacing and full-width at half-maximum of the Gaussian basis functions give
185
+ some amount of control over length scales. In addition, constraints associated with
186
+ length scale, radius of curvature, and deviation from fixed pixels are
187
+ automatically computed and penalized with a weight given by `penalty`. In general,
188
+ this helps ensure that features in an optimized density array violate the specified
189
+ constraints to a lesser degree. The constraints are based on "Analytical level set
190
+ fabrication constraints for inverse design," by D. Vercruysse et al. (2019).
191
+
192
+ When an optimization is determined to have converged (by `ftol` or `gtol` criteria)
193
+ the optimizer `params` function will simply return the optimal parameters. The
194
+ convergence can be queried by `is_converged(state)`.
162
195
 
163
- def transform_density(density: types.Density2DArray) -> types.Density2DArray:
164
- transformed = types.symmetrize_density(density)
165
- transformed = transform.density_gaussian_filter_and_tanh(transformed, beta=beta)
166
- # Scale to ensure that the full valid range of the density array is reachable.
167
- mid_value = (density.lower_bound + density.upper_bound) / 2
168
- transformed = tree_util.tree_map(
169
- lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed
170
- )
171
- return transform.apply_fixed_pixels(transformed)
172
-
173
- def initialize_latent_density(
174
- density: types.Density2DArray,
175
- ) -> types.Density2DArray:
176
- array = transform.normalized_array_from_density(density)
177
- array = jnp.clip(array, -1, 1)
178
- array *= jnp.tanh(beta)
179
- latent_array = jnp.arctanh(array) / beta
180
- latent_array = transform.rescale_array_for_density(latent_array, density)
181
- return dataclasses.replace(density, array=latent_array)
182
-
183
- return transformed_lbfgsb(
196
+ Args:
197
+ penalty: The weight of the fabrication penalty, which combines length scale,
198
+ curvature, and fixed pixel constraints.
199
+ length_scale_spacing_factor: The number of levelset control points per unit of
200
+ minimum length scale (mean of density minimum width and minimum spacing).
201
+ length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum to
202
+ the minimum length scale.
203
+ length_scale_constraint_factor: Multiplies the target length scale in the
204
+ levelset constraints. A value greater than 1 is pessimistic and drives the
205
+ solution to have a larger length scale (relative to smaller values).
206
+ smoothing_factor: For values greater than 1, the density is initially computed
207
+ at higher resolution and then downsampled, yielding smoother geometries.
208
+ length_scale_constraint_beta: Controls relaxation of the length scale
209
+ constraint near the zero level.
210
+ length_scale_constraint_weight: The weight of the length scale constraint in
211
+ the overall fabrication constraint penalty.
212
+ curvature_constraint_weight: The weight of the curvature constraint.
213
+ fixed_pixel_constraint_weight: The weight of the fixed pixel constraint.
214
+ init_optimizer: The optimizer used in the initialization of the levelset
215
+ parameterization. At initialization, the latent parameters are optimized so
216
+ that the initial parameters match the binarized initial density.
217
+ init_steps: The number of optimization steps used in the initialization.
218
+ maxcor: The maximum number of variable metric corrections used to define the
219
+ limited memory matrix, in the L-BFGS-B scheme.
220
+ line_search_max_steps: The maximum number of steps in the line search.
221
+ ftol: Convergence criteria based on function values. See scipy documentation
222
+ for details.
223
+ gtol: Convergence criteria based on gradient.
224
+
225
+ Returns:
226
+ The `Optimizer` implementing the L-BFGS-B optimizer.
227
+ """
228
+ return parameterized_lbfgsb(
229
+ density_parameterization=gaussian_levelset.gaussian_levelset(
230
+ length_scale_spacing_factor=length_scale_spacing_factor,
231
+ length_scale_fwhm_factor=length_scale_fwhm_factor,
232
+ length_scale_constraint_factor=length_scale_constraint_factor,
233
+ smoothing_factor=smoothing_factor,
234
+ length_scale_constraint_beta=length_scale_constraint_beta,
235
+ length_scale_constraint_weight=length_scale_constraint_weight,
236
+ curvature_constraint_weight=curvature_constraint_weight,
237
+ fixed_pixel_constraint_weight=fixed_pixel_constraint_weight,
238
+ init_optimizer=init_optimizer,
239
+ init_steps=init_steps,
240
+ ),
241
+ penalty=penalty,
184
242
  maxcor=maxcor,
185
243
  line_search_max_steps=line_search_max_steps,
186
244
  ftol=ftol,
187
245
  gtol=gtol,
188
- transform_fn=transform_fn,
189
- initialize_latent_fn=initialize_latent_fn,
190
246
  )
191
247
 
192
248
 
193
- def transformed_lbfgsb(
194
- maxcor: int,
195
- line_search_max_steps: int,
196
- ftol: float,
197
- gtol: float,
198
- transform_fn: Callable[[PyTree], PyTree],
199
- initialize_latent_fn: Callable[[PyTree], PyTree],
200
- ) -> base.Optimizer:
201
- """Construct an latent parameter L-BFGS-B optimizer.
249
+ # -----------------------------------------------------------------------------
250
+ # Base parameterized L-BFGS-B optimizer.
251
+ # -----------------------------------------------------------------------------
202
252
 
203
- The optimized parameters are termed latent parameters, from which the
204
- actual parameters returned by the optimizer are obtained using the
205
- `transform_fn`. In the simple case where this is just `lambda x: x` (i.e.
206
- the identity), this is equivalent to the standard L-BFGS-B algorithm.
207
253
 
208
- When the optimization has converged (according to `ftol` or `gtol` criteria), the
209
- optimizer simply returns the parameters which obtained the converged result. The
210
- convergence can be queried by `is_converged(state)`.
254
+ def parameterized_lbfgsb(
255
+ density_parameterization: Optional[parameterization_base.Density2DParameterization],
256
+ penalty: float,
257
+ maxcor: int = DEFAULT_MAXCOR,
258
+ line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
259
+ ftol: float = DEFAULT_FTOL,
260
+ gtol: float = DEFAULT_GTOL,
261
+ ) -> base.Optimizer:
262
+ """Optimizer using L-BFGS-B algorithm with specified density parameterization.
263
+
264
+ This optimizer wraps scipy's implementation of the algorithm, and provides
265
+ a jax-style API to the scheme. The optimizer works with custom types such
266
+ as the `BoundedArray` to constrain the optimization variable.
211
267
 
212
268
  Args:
213
- maxcor: The maximum number of variable metric corrections used to define
214
- the limited memory matrix, in the L-BFGS-B scheme.
269
+ density_parameterization: The parameterization to be used, or `None`. When no
270
+ parameterization is given, the direct pixel parameterization is used for
271
+ density arrays.
272
+ penalty: The weight of the scalar penalty formed from the constraints of the
273
+ parameterization.
274
+ maxcor: The maximum number of variable metric corrections used to define the
275
+ limited memory matrix, in the L-BFGS-B scheme.
215
276
  line_search_max_steps: The maximum number of steps in the line search.
216
- ftol: Tolerance for stopping criteria based on function values. See scipy
217
- documentation for details.
218
- gtol: Tolerance for stopping criteria based on gradient.
219
- transform_fn: Function which transforms the internal latent parameters to
220
- the parameters returned by the optimizer.
221
- initialize_latent_fn: Function which computes the initial latent parameters
222
- given the initial parameters.
277
+ ftol: Convergence criteria based on function values. See scipy documentation
278
+ for details.
279
+ gtol: Convergence criteria based on gradient.
223
280
 
224
281
  Returns:
225
282
  The `base.Optimizer`.
@@ -236,33 +293,73 @@ def transformed_lbfgsb(
236
293
  f"{line_search_max_steps}"
237
294
  )
238
295
 
296
+ if density_parameterization is None:
297
+ density_parameterization = pixel.pixel()
298
+
299
+ def _init_latents(params: PyTree) -> PyTree:
300
+ def _leaf_init_latents(leaf: Any) -> Any:
301
+ leaf = _clip(leaf)
302
+ if not _is_density(leaf) or density_parameterization is None:
303
+ return leaf
304
+ return density_parameterization.from_density(leaf)
305
+
306
+ return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
307
+
308
+ def _params_from_latents(latent_params: PyTree) -> PyTree:
309
+ def _leaf_params_from_latents(leaf: Any) -> Any:
310
+ if not _is_parameterized_density(leaf) or density_parameterization is None:
311
+ return leaf
312
+ return density_parameterization.to_density(leaf)
313
+
314
+ return tree_util.tree_map(
315
+ _leaf_params_from_latents,
316
+ latent_params,
317
+ is_leaf=_is_parameterized_density,
318
+ )
319
+
320
+ def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
321
+ def _constraint_loss_leaf(
322
+ params: parameterization_base.ParameterizedDensity2DArrayBase,
323
+ ) -> jnp.ndarray:
324
+ constraints = density_parameterization.constraints(params)
325
+ constraints = tree_util.tree_map(
326
+ lambda x: jnp.sum(jnp.maximum(x, 0.0)),
327
+ constraints,
328
+ )
329
+ return jnp.sum(jnp.asarray(constraints))
330
+
331
+ losses = [0.0] + [
332
+ _constraint_loss_leaf(p)
333
+ for p in tree_util.tree_leaves(
334
+ latent_params, is_leaf=_is_parameterized_density
335
+ )
336
+ if _is_parameterized_density(p)
337
+ ]
338
+ return penalty * jnp.sum(jnp.asarray(losses))
339
+
239
340
  def init_fn(params: PyTree) -> LbfgsbState:
240
341
  """Initializes the optimization state."""
241
342
 
242
- def _init_pure(params: PyTree) -> Tuple[PyTree, JaxLbfgsbDict]:
243
- lower_bound = types.extract_lower_bound(params)
244
- upper_bound = types.extract_upper_bound(params)
343
+ def _init_state_pure(latent_params: PyTree) -> Tuple[PyTree, JaxLbfgsbDict]:
344
+ lower_bound = types.extract_lower_bound(latent_params)
345
+ upper_bound = types.extract_upper_bound(latent_params)
245
346
  scipy_lbfgsb_state = ScipyLbfgsbState.init(
246
- x0=_to_numpy(params),
247
- lower_bound=_bound_for_params(lower_bound, params),
248
- upper_bound=_bound_for_params(upper_bound, params),
347
+ x0=_to_numpy(latent_params),
348
+ lower_bound=_bound_for_params(lower_bound, latent_params),
349
+ upper_bound=_bound_for_params(upper_bound, latent_params),
249
350
  maxcor=maxcor,
250
351
  line_search_max_steps=line_search_max_steps,
251
352
  ftol=ftol,
252
353
  gtol=gtol,
253
354
  )
254
- latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
355
+ latent_params = _to_pytree(scipy_lbfgsb_state.x, latent_params)
255
356
  return latent_params, scipy_lbfgsb_state.to_jax()
256
357
 
257
- (
258
- latent_params,
259
- jax_lbfgsb_state,
260
- ) = jax.pure_callback(
261
- _init_pure,
262
- _example_state(params, maxcor),
263
- initialize_latent_fn(params),
358
+ latent_params = _init_latents(params)
359
+ latent_params, jax_lbfgsb_state = jax.pure_callback(
360
+ _init_state_pure, _example_state(latent_params, maxcor), latent_params
264
361
  )
265
- return transform_fn(latent_params), latent_params, jax_lbfgsb_state
362
+ return _params_from_latents(latent_params), latent_params, jax_lbfgsb_state
266
363
 
267
364
  def params_fn(state: LbfgsbState) -> PyTree:
268
365
  """Returns the parameters for the given `state`."""
@@ -294,16 +391,35 @@ def transformed_lbfgsb(
294
391
  return flat_latent_params, scipy_lbfgsb_state.to_jax()
295
392
 
296
393
  _, latent_params, jax_lbfgsb_state = state
297
- _, vjp_fn = jax.vjp(transform_fn, latent_params)
394
+ _, vjp_fn = jax.vjp(_params_from_latents, latent_params)
298
395
  (latent_grad,) = vjp_fn(grad)
396
+
397
+ if not (
398
+ tree_util.tree_structure(latent_grad)
399
+ == tree_util.tree_structure(latent_params) # type: ignore[operator]
400
+ ):
401
+ raise ValueError(
402
+ f"Tree structure of `latent_grad` was different than expected, got \n"
403
+ f"{tree_util.tree_structure(latent_grad)} but expected \n"
404
+ f"{tree_util.tree_structure(latent_params)}."
405
+ )
406
+
407
+ (
408
+ constraint_loss_value,
409
+ constraint_loss_grad,
410
+ ) = jax.value_and_grad(
411
+ _constraint_loss
412
+ )(latent_params)
413
+ value += constraint_loss_value
414
+ latent_grad = tree_util.tree_map(
415
+ lambda a, b: a + b, latent_grad, constraint_loss_grad
416
+ )
417
+
299
418
  flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(
300
419
  latent_grad
301
420
  ) # type: ignore[no-untyped-call]
302
421
 
303
- (
304
- flat_latent_params,
305
- jax_lbfgsb_state,
306
- ) = jax.pure_callback(
422
+ flat_latent_params, jax_lbfgsb_state = jax.pure_callback(
307
423
  _update_pure,
308
424
  (flat_latent_grad, jax_lbfgsb_state),
309
425
  flat_latent_grad,
@@ -311,7 +427,7 @@ def transformed_lbfgsb(
311
427
  jax_lbfgsb_state,
312
428
  )
313
429
  latent_params = unflatten_fn(flat_latent_params)
314
- return transform_fn(latent_params), latent_params, jax_lbfgsb_state
430
+ return _params_from_latents(latent_params), latent_params, jax_lbfgsb_state
315
431
 
316
432
  return base.Optimizer(
317
433
  init=init_fn,
@@ -335,6 +451,31 @@ def _is_density(leaf: Any) -> Any:
335
451
  return isinstance(leaf, types.Density2DArray)
336
452
 
337
453
 
454
+ def _is_parameterized_density(leaf: Any) -> Any:
455
+ """Return `True` if `leaf` is a parameterized density array."""
456
+ return isinstance(leaf, parameterization_base.ParameterizedDensity2DArrayBase)
457
+
458
+
459
+ def _is_custom_type(leaf: Any) -> bool:
460
+ """Return `True` if `leaf` is a recognized custom type."""
461
+ return isinstance(leaf, (types.BoundedArray, types.Density2DArray))
462
+
463
+
464
+ def _clip(pytree: PyTree) -> PyTree:
465
+ """Clips leaves on `pytree` to their bounds."""
466
+
467
+ def _clip_fn(leaf: Any) -> Any:
468
+ if not _is_custom_type(leaf):
469
+ return leaf
470
+ if leaf.lower_bound is None and leaf.upper_bound is None:
471
+ return leaf
472
+ return tree_util.tree_map(
473
+ lambda x: jnp.clip(x, leaf.lower_bound, leaf.upper_bound), leaf
474
+ )
475
+
476
+ return tree_util.tree_map(_clip_fn, pytree, is_leaf=_is_custom_type)
477
+
478
+
338
479
  def _to_numpy(params: PyTree) -> NDArray:
339
480
  """Flattens a `params` pytree into a single rank-1 numpy array."""
340
481
  x, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]