invrs-opt 0.5.2__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invrs_opt/__init__.py CHANGED
@@ -3,8 +3,19 @@
 Copyright (c) 2023 The INVRS-IO authors.
 """
 
-__version__ = "v0.5.2"
+__version__ = "v0.7.0"
 __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
 
-from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
-from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
+from invrs_opt import parameterization as parameterization
+
+from invrs_opt.optimizers.lbfgsb import (
+    density_lbfgsb as density_lbfgsb,
+    lbfgsb as lbfgsb,
+    levelset_lbfgsb as levelset_lbfgsb,
+)
+
+from invrs_opt.optimizers.wrapped_optax import (
+    density_wrapped_optax as density_wrapped_optax,
+    levelset_wrapped_optax as levelset_wrapped_optax,
+    wrapped_optax as wrapped_optax,
+)
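The new top-level API exports three L-BFGS-B constructors and three optax-based constructors. The following sketch is adapted from the usage example that was removed from the `lbfgsb` docstring further below; the exact `update` signature in 0.7.0 is not shown in this diff and may differ.

    import jax
    import jax.numpy as jnp

    import invrs_opt

    def loss_fn(x):
        # Hypothetical scalar loss over the params pytree.
        return jnp.sum(x["a"] ** 2)

    params = {"a": jnp.ones((3,))}
    opt = invrs_opt.lbfgsb(maxcor=20, line_search_max_steps=100)
    state = opt.init(params)
    for _ in range(10):
        x = opt.params(state)
        value, grad = jax.value_and_grad(loss_fn)(x)
        state = opt.update(grad, value, state)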
invrs_opt/experimental/client.py CHANGED
@@ -4,15 +4,14 @@ Copyright (c) 2023 The INVRS-IO authors.
 """
 
 import json
-import requests
 import time
 from typing import Any, Dict, Optional
 
+import requests
 from totypes import json_utils
 
-from invrs_opt import base
 from invrs_opt.experimental import labels
-
+from invrs_opt.optimizers import base
 
 PyTree = Any
 StateToken = str
invrs_opt/{ → optimizers}/base.py RENAMED
@@ -4,8 +4,12 @@ Copyright (c) 2023 The INVRS-IO authors.
 """
 
 import dataclasses
+import inspect
 from typing import Any, Protocol
 
+import optax  # type: ignore[import-untyped]
+from totypes import json_utils
+
 PyTree = Any
 
 
@@ -44,3 +48,9 @@ class Optimizer:
     init: InitFn
     params: ParamsFn
     update: UpdateFn
+
+
+# Register all optax state types for serialization.
+for name, obj in inspect.getmembers(optax):
+    if name.endswith("State") and isinstance(obj, type):
+        json_utils.register_custom_type(obj)
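Registering every optax `*State` type with `totypes.json_utils` lets optimizer states round-trip through the same serialization used for the custom invrs types. A minimal sketch of the effect, assuming `totypes` exposes `json_from_pytree` and `pytree_from_json` (those names do not appear in this diff):

    import optax
    from totypes import json_utils

    opt = optax.adam(1e-2)
    state = opt.init({"x": 0.0})  # contains optax.ScaleByAdamState, registered above
    serialized = json_utils.json_from_pytree(state)
    restored = json_utils.pytree_from_json(serialized)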
invrs_opt/{lbfgsb/lbfgsb.py → optimizers/lbfgsb.py} RENAMED
@@ -5,19 +5,25 @@ Copyright (c) 2023 The INVRS-IO authors.
 
 import copy
 import dataclasses
-from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
 
 import jax
 import jax.numpy as jnp
 import numpy as onp
+import optax  # type: ignore[import-untyped]
 from jax import flatten_util, tree_util
 from scipy.optimize._lbfgsb_py import (  # type: ignore[import-untyped]
     _lbfgsb as scipy_lbfgsb,
 )
 from totypes import types
 
-from invrs_opt import base
-from invrs_opt.lbfgsb import transform
+from invrs_opt.optimizers import base
+from invrs_opt.parameterization import (
+    base as parameterization_base,
+    filter_project,
+    gaussian_levelset,
+    pixel,
+)
 
 NDArray = onp.ndarray[Any, Any]
 PyTree = Any
@@ -35,10 +41,10 @@ UPDATE_IPRINT = -1
 
 # Maximum value for the `maxcor` parameter in the L-BFGS-B scheme.
 MAXCOR_MAX_VALUE = 100
-MAXCOR_DEFAULT = 20
-LINE_SEARCH_MAX_STEPS_DEFAULT = 100
-FTOL_DEFAULT = 0.0
-GTOL_DEFAULT = 0.0
+DEFAULT_MAXCOR = 20
+DEFAULT_LINE_SEARCH_MAX_STEPS = 100
+DEFAULT_FTOL = 0.0
+DEFAULT_GTOL = 0.0
 
 # Maps bound scenarios to integers.
 BOUNDS_MAP: Dict[Tuple[bool, bool], int] = {
@@ -52,175 +58,225 @@ FORTRAN_INT = scipy_lbfgsb.types.intvar.dtype
 
 
 def lbfgsb(
-    maxcor: int = MAXCOR_DEFAULT,
-    line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT,
-    ftol: float = FTOL_DEFAULT,
-    gtol: float = GTOL_DEFAULT,
+    *,
+    maxcor: int = DEFAULT_MAXCOR,
+    line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
+    ftol: float = DEFAULT_FTOL,
+    gtol: float = DEFAULT_GTOL,
 ) -> base.Optimizer:
-    """Return an optimizer implementing the standard L-BFGS-B algorithm.
+    """Optimizer implementing the standard L-BFGS-B algorithm.
 
-    This optimizer wraps scipy's implementation of the algorithm, and provides
-    a jax-style API to the scheme. The optimizer works with custom types such
-    as the `BoundedArray` to constrain the optimization variable.
+    The standard L-BFGS-B algorithm uses the direct pixel parameterization for density
+    arrays, which simply enforces that values are between the declared upper and lower
+    bounds of the density.
 
-    Example usage is as follows:
-
-        def fn(x):
-            leaves_sum_sq = [jnp.sum(y)**2 for y in tree_util.tree_leaves(x)]
-            return jnp.sum(jnp.asarray(leaves_sum_sq))
-
-        x0 = {
-            "a": jnp.ones((3,)),
-            "b": BoundedArray(
-                value=-jnp.ones((2, 5)),
-                lower_bound=-5,
-                upper_bound=5,
-            ),
-        }
-        opt = lbfgsb(maxcor=20, line_search_max_steps=100)
-        state = opt.init(x0)
-        for _ in range(10):
-            x = opt.params(state)
-            value, grad = jax.value_and_grad(fn)(x)
-            state = opt.update(grad, value, state)
-
-    While the algorithm can work with pytrees of jax arrays, numpy arrays can
-    also be used. Thus, e.g. the optimizer can directly be used with autograd.
-
-    When the optimization has converged (according to `ftol` or `gtol` criteria), the
-    optimizer simply returns the parameters which obtained the converged result. The
+    When an optimization is determined to have converged (by `ftol` or `gtol` criteria)
+    the optimizer `params` function will simply return the optimal parameters. The
     convergence can be queried by `is_converged(state)`.
 
     Args:
-        maxcor: The maximum number of variable metric corrections used to define
-            the limited memory matrix, in the L-BFGS-B scheme.
+        maxcor: The maximum number of variable metric corrections used to define the
+            limited memory matrix, in the L-BFGS-B scheme.
         line_search_max_steps: The maximum number of steps in the line search.
-        ftol: Tolerance for stopping criteria based on function values. See scipy
-            documentation for details.
-        gtol: Tolerance for stopping criteria based on gradient.
+        ftol: Convergence criteria based on function values. See scipy documentation
+            for details.
+        gtol: Convergence criteria based on gradient.
 
     Returns:
-        The `base.Optimizer`.
+        The `Optimizer` implementing the L-BFGS-B optimizer.
     """
-    return transformed_lbfgsb(
+    return parameterized_lbfgsb(
+        density_parameterization=None,
+        penalty=0.0,
         maxcor=maxcor,
         line_search_max_steps=line_search_max_steps,
         ftol=ftol,
         gtol=gtol,
-        transform_fn=lambda x: x,
-        initialize_latent_fn=lambda x: x,
     )
 
 
 def density_lbfgsb(
+    *,
     beta: float,
-    maxcor: int = MAXCOR_DEFAULT,
-    line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT,
-    ftol: float = FTOL_DEFAULT,
-    gtol: float = GTOL_DEFAULT,
+    maxcor: int = DEFAULT_MAXCOR,
+    line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
+    ftol: float = DEFAULT_FTOL,
+    gtol: float = DEFAULT_GTOL,
 ) -> base.Optimizer:
-    """Return an L-BFGS-B optimizer with additional transforms for density arrays.
-
-    Parameters that are of type `DensityArray2D` are represented as latent parameters
-    that are transformed (in the case where lower and upper bounds are `(-1, 1)`) by,
+    """Optimizer using L-BFGS-B algorithm with filter-project density parameterization.
 
-        transformed = tanh(beta * conv(density.array, gaussian_kernel)) / tanh(beta)
+    In the filter-project density parameterization, the optimization variable
+    associated with a density array is a latent density array; the density is obtained
+    by convolving (i.e. "filtering") the latent density with a Gaussian kernel having
+    full-width at half-maximum equal to the length scale (the mean of declared minimum
+    width and minimum spacing). Then, a tanh nonlinearity is used as a smooth threshold
+    operation ("projection").
 
-    where the kernel has a full-width at half-maximum determined by the minimum width
-    and spacing parameters of the `DensityArray2D`. Where the bounds differ, the
-    density is scaled before the transform is applied, and then unscaled afterwards.
-
-    When the optimization has converged (according to `ftol` or `gtol` criteria), the
-    optimizer simply returns the parameters which obtained the converged result. The
+    When an optimization is determined to have converged (by `ftol` or `gtol` criteria)
+    the optimizer `params` function will simply return the optimal parameters. The
     convergence can be queried by `is_converged(state)`.
 
     Args:
-        beta: Determines the steepness of the thresholding.
-        maxcor: The maximum number of variable metric corrections used to define
-            the limited memory matrix, in the L-BFGS-B scheme.
+        beta: Determines the sharpness of the thresholding operation.
+        maxcor: The maximum number of variable metric corrections used to define the
+            limited memory matrix, in the L-BFGS-B scheme.
         line_search_max_steps: The maximum number of steps in the line search.
-        ftol: Tolerance for stopping criteria based on function values. See scipy
-            documentation for details.
-        gtol: Tolerance for stopping criteria based on gradient.
+        ftol: Convergence criteria based on function values. See scipy documentation
+            for details.
+        gtol: Convergence criteria based on gradient.
 
     Returns:
-        The `base.Optimizer`.
+        The `Optimizer` implementing the L-BFGS-B optimizer.
     """
+    return parameterized_lbfgsb(
+        density_parameterization=filter_project.filter_project(beta=beta),
+        penalty=0.0,
+        maxcor=maxcor,
+        line_search_max_steps=line_search_max_steps,
+        ftol=ftol,
+        gtol=gtol,
+    )
 
-    def transform_fn(tree: PyTree) -> PyTree:
-        return tree_util.tree_map(
-            lambda x: transform_density(x) if _is_density(x) else x,
-            tree,
-            is_leaf=_is_density,
-        )
 
-    def initialize_latent_fn(tree: PyTree) -> PyTree:
-        return tree_util.tree_map(
-            lambda x: initialize_latent_density(x) if _is_density(x) else x,
-            tree,
-            is_leaf=_is_density,
-        )
+def levelset_lbfgsb(
+    *,
+    penalty: float,
+    length_scale_spacing_factor: float = (
+        gaussian_levelset.DEFAULT_LENGTH_SCALE_SPACING_FACTOR
+    ),
+    length_scale_fwhm_factor: float = (
+        gaussian_levelset.DEFAULT_LENGTH_SCALE_FWHM_FACTOR
+    ),
+    length_scale_constraint_factor: float = (
+        gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR
+    ),
+    smoothing_factor: int = gaussian_levelset.DEFAULT_SMOOTHING_FACTOR,
+    length_scale_constraint_beta: float = (
+        gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA
+    ),
+    length_scale_constraint_weight: float = (
+        gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT
+    ),
+    curvature_constraint_weight: float = (
+        gaussian_levelset.DEFAULT_CURVATURE_CONSTRAINT_WEIGHT
+    ),
+    fixed_pixel_constraint_weight: float = (
+        gaussian_levelset.DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT
+    ),
+    init_optimizer: optax.GradientTransformation = (
+        gaussian_levelset.DEFAULT_INIT_OPTIMIZER
+    ),
+    init_steps: int = gaussian_levelset.DEFAULT_INIT_STEPS,
+    maxcor: int = DEFAULT_MAXCOR,
+    line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
+    ftol: float = DEFAULT_FTOL,
+    gtol: float = DEFAULT_GTOL,
+) -> base.Optimizer:
+    """Optimizer using L-BFGS-B algorithm with levelset density parameterization.
+
+    In the levelset parameterization, the optimization variable associated with a
+    density array is an array giving the amplitudes of Gaussian radial basis functions
+    that represent a levelset function over the domain of the density. In the levelset
+    parameterization, gradients are nonzero only at the edges of features, and in
+    general the topology of a solution does not change during the course of
+    optimization.
+
+    The spacing and full-width at half-maximum of the Gaussian basis functions give
+    some amount of control over length scales. In addition, constraints associated with
+    length scale, radius of curvature, and deviation from fixed pixels are
+    automatically computed and penalized with a weight given by `penalty`. In general,
+    this helps ensure that features in an optimized density array violate the specified
+    constraints to a lesser degree. The constraints are based on "Analytical level set
+    fabrication constraints for inverse design," by D. Vercruysse et al. (2019).
+
+    When an optimization is determined to have converged (by `ftol` or `gtol` criteria)
+    the optimizer `params` function will simply return the optimal parameters. The
+    convergence can be queried by `is_converged(state)`.
 
-    def transform_density(density: types.Density2DArray) -> types.Density2DArray:
-        transformed = types.symmetrize_density(density)
-        transformed = transform.density_gaussian_filter_and_tanh(transformed, beta=beta)
-        # Scale to ensure that the full valid range of the density array is reachable.
-        mid_value = (density.lower_bound + density.upper_bound) / 2
-        transformed = tree_util.tree_map(
-            lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed
-        )
-        return transform.apply_fixed_pixels(transformed)
-
-    def initialize_latent_density(
-        density: types.Density2DArray,
-    ) -> types.Density2DArray:
-        array = transform.normalized_array_from_density(density)
-        array = jnp.clip(array, -1, 1)
-        array *= jnp.tanh(beta)
-        latent_array = jnp.arctanh(array) / beta
-        latent_array = transform.rescale_array_for_density(latent_array, density)
-        return dataclasses.replace(density, array=latent_array)
-
-    return transformed_lbfgsb(
+    Args:
+        penalty: The weight of the fabrication penalty, which combines length scale,
+            curvature, and fixed pixel constraints.
+        length_scale_spacing_factor: The number of levelset control points per unit of
+            minimum length scale (mean of density minimum width and minimum spacing).
+        length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum to
+            the minimum length scale.
+        length_scale_constraint_factor: Multiplies the target length scale in the
+            levelset constraints. A value greater than 1 is pessimistic and drives the
+            solution to have a larger length scale (relative to smaller values).
+        smoothing_factor: For values greater than 1, the density is initially computed
+            at higher resolution and then downsampled, yielding smoother geometries.
+        length_scale_constraint_beta: Controls relaxation of the length scale
+            constraint near the zero level.
+        length_scale_constraint_weight: The weight of the length scale constraint in
+            the overall fabrication constraint penalty.
+        curvature_constraint_weight: The weight of the curvature constraint.
+        fixed_pixel_constraint_weight: The weight of the fixed pixel constraint.
+        init_optimizer: The optimizer used in the initialization of the levelset
+            parameterization. At initialization, the latent parameters are optimized so
+            that the initial parameters match the binarized initial density.
+        init_steps: The number of optimization steps used in the initialization.
+        maxcor: The maximum number of variable metric corrections used to define the
+            limited memory matrix, in the L-BFGS-B scheme.
+        line_search_max_steps: The maximum number of steps in the line search.
+        ftol: Convergence criteria based on function values. See scipy documentation
+            for details.
+        gtol: Convergence criteria based on gradient.
+
+    Returns:
+        The `Optimizer` implementing the L-BFGS-B optimizer.
+    """
+    return parameterized_lbfgsb(
+        density_parameterization=gaussian_levelset.gaussian_levelset(
+            length_scale_spacing_factor=length_scale_spacing_factor,
+            length_scale_fwhm_factor=length_scale_fwhm_factor,
+            length_scale_constraint_factor=length_scale_constraint_factor,
+            smoothing_factor=smoothing_factor,
+            length_scale_constraint_beta=length_scale_constraint_beta,
+            length_scale_constraint_weight=length_scale_constraint_weight,
+            curvature_constraint_weight=curvature_constraint_weight,
+            fixed_pixel_constraint_weight=fixed_pixel_constraint_weight,
+            init_optimizer=init_optimizer,
+            init_steps=init_steps,
+        ),
+        penalty=penalty,
         maxcor=maxcor,
         line_search_max_steps=line_search_max_steps,
         ftol=ftol,
         gtol=gtol,
-        transform_fn=transform_fn,
-        initialize_latent_fn=initialize_latent_fn,
     )
 
 
-def transformed_lbfgsb(
-    maxcor: int,
-    line_search_max_steps: int,
-    ftol: float,
-    gtol: float,
-    transform_fn: Callable[[PyTree], PyTree],
-    initialize_latent_fn: Callable[[PyTree], PyTree],
-) -> base.Optimizer:
-    """Construct an latent parameter L-BFGS-B optimizer.
+# -----------------------------------------------------------------------------
+# Base parameterized L-BFGS-B optimizer.
+# -----------------------------------------------------------------------------
 
-    The optimized parameters are termed latent parameters, from which the
-    actual parameters returned by the optimizer are obtained using the
-    `transform_fn`. In the simple case where this is just `lambda x: x` (i.e.
-    the identity), this is equivalent to the standard L-BFGS-B algorithm.
 
-    When the optimization has converged (according to `ftol` or `gtol` criteria), the
-    optimizer simply returns the parameters which obtained the converged result. The
-    convergence can be queried by `is_converged(state)`.
+def parameterized_lbfgsb(
+    density_parameterization: Optional[parameterization_base.Density2DParameterization],
+    penalty: float,
+    maxcor: int = DEFAULT_MAXCOR,
+    line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
+    ftol: float = DEFAULT_FTOL,
+    gtol: float = DEFAULT_GTOL,
+) -> base.Optimizer:
+    """Optimizer using the L-BFGS-B algorithm with a specified density parameterization.
+
+    This optimizer wraps scipy's implementation of the algorithm, and provides
+    a jax-style API to the scheme. The optimizer works with custom types such
+    as the `BoundedArray` to constrain the optimization variable.
 
     Args:
-        maxcor: The maximum number of variable metric corrections used to define
-            the limited memory matrix, in the L-BFGS-B scheme.
+        density_parameterization: The parameterization to be used, or `None`. When no
+            parameterization is given, the direct pixel parameterization is used for
+            density arrays.
+        penalty: The weight of the scalar penalty formed from the constraints of the
+            parameterization.
+        maxcor: The maximum number of variable metric corrections used to define the
+            limited memory matrix, in the L-BFGS-B scheme.
         line_search_max_steps: The maximum number of steps in the line search.
-        ftol: Tolerance for stopping criteria based on function values. See scipy
-            documentation for details.
-        gtol: Tolerance for stopping criteria based on gradient.
-        transform_fn: Function which transforms the internal latent parameters to
-            the parameters returned by the optimizer.
-        initialize_latent_fn: Function which computes the initial latent parameters
-            given the initial parameters.
+        ftol: Convergence criteria based on function values. See scipy documentation
+            for details.
+        gtol: Convergence criteria based on gradient.
 
     Returns:
         The `base.Optimizer`.
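All three public constructors above reduce to `parameterized_lbfgsb` with different density parameterizations. A minimal usage sketch follows; the `Density2DArray` field names are assumed to follow `totypes`, and the `beta` and `penalty` values are illustrative only, not recommendations.

    import jax.numpy as jnp
    from totypes import types

    import invrs_opt

    density = types.Density2DArray(
        array=jnp.full((40, 40), 0.5),
        lower_bound=0.0,
        upper_bound=1.0,
        minimum_width=5,
        minimum_spacing=5,
    )
    opt = invrs_opt.density_lbfgsb(beta=2.0)        # filter-project
    # opt = invrs_opt.levelset_lbfgsb(penalty=1.0)  # Gaussian levelset
    state = opt.init(density)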
@@ -237,33 +293,73 @@ def transformed_lbfgsb(
             f"{line_search_max_steps}"
         )
 
+    if density_parameterization is None:
+        density_parameterization = pixel.pixel()
+
+    def _init_latents(params: PyTree) -> PyTree:
+        def _leaf_init_latents(leaf: Any) -> Any:
+            leaf = _clip(leaf)
+            if not _is_density(leaf) or density_parameterization is None:
+                return leaf
+            return density_parameterization.from_density(leaf)
+
+        return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
+
+    def _params_from_latents(latent_params: PyTree) -> PyTree:
+        def _leaf_params_from_latents(leaf: Any) -> Any:
+            if not _is_parameterized_density(leaf) or density_parameterization is None:
+                return leaf
+            return density_parameterization.to_density(leaf)
+
+        return tree_util.tree_map(
+            _leaf_params_from_latents,
+            latent_params,
+            is_leaf=_is_parameterized_density,
+        )
+
+    def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
+        def _constraint_loss_leaf(
+            params: parameterization_base.ParameterizedDensity2DArrayBase,
+        ) -> jnp.ndarray:
+            constraints = density_parameterization.constraints(params)
+            constraints = tree_util.tree_map(
+                lambda x: jnp.sum(jnp.maximum(x, 0.0)),
+                constraints,
+            )
+            return jnp.sum(jnp.asarray(constraints))
+
+        losses = [0.0] + [
+            _constraint_loss_leaf(p)
+            for p in tree_util.tree_leaves(
+                latent_params, is_leaf=_is_parameterized_density
+            )
+            if _is_parameterized_density(p)
+        ]
+        return penalty * jnp.sum(jnp.asarray(losses))
+
     def init_fn(params: PyTree) -> LbfgsbState:
         """Initializes the optimization state."""
 
-        def _init_pure(params: PyTree) -> Tuple[PyTree, JaxLbfgsbDict]:
-            lower_bound = types.extract_lower_bound(params)
-            upper_bound = types.extract_upper_bound(params)
+        def _init_state_pure(latent_params: PyTree) -> Tuple[PyTree, JaxLbfgsbDict]:
+            lower_bound = types.extract_lower_bound(latent_params)
+            upper_bound = types.extract_upper_bound(latent_params)
             scipy_lbfgsb_state = ScipyLbfgsbState.init(
-                x0=_to_numpy(params),
-                lower_bound=_bound_for_params(lower_bound, params),
-                upper_bound=_bound_for_params(upper_bound, params),
+                x0=_to_numpy(latent_params),
+                lower_bound=_bound_for_params(lower_bound, latent_params),
+                upper_bound=_bound_for_params(upper_bound, latent_params),
                 maxcor=maxcor,
                 line_search_max_steps=line_search_max_steps,
                 ftol=ftol,
                 gtol=gtol,
             )
-            latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
+            latent_params = _to_pytree(scipy_lbfgsb_state.x, latent_params)
             return latent_params, scipy_lbfgsb_state.to_jax()
 
-        (
-            latent_params,
-            jax_lbfgsb_state,
-        ) = jax.pure_callback(
-            _init_pure,
-            _example_state(params, maxcor),
-            initialize_latent_fn(params),
+        latent_params = _init_latents(params)
+        latent_params, jax_lbfgsb_state = jax.pure_callback(
+            _init_state_pure, _example_state(latent_params, maxcor), latent_params
         )
-        return transform_fn(latent_params), latent_params, jax_lbfgsb_state
+        return _params_from_latents(latent_params), latent_params, jax_lbfgsb_state
 
     def params_fn(state: LbfgsbState) -> PyTree:
         """Returns the parameters for the given `state`."""
@@ -295,16 +391,35 @@ def transformed_lbfgsb(
             return flat_latent_params, scipy_lbfgsb_state.to_jax()
 
         _, latent_params, jax_lbfgsb_state = state
-        _, vjp_fn = jax.vjp(transform_fn, latent_params)
+        _, vjp_fn = jax.vjp(_params_from_latents, latent_params)
         (latent_grad,) = vjp_fn(grad)
+
+        if not (
+            tree_util.tree_structure(latent_grad)
+            == tree_util.tree_structure(latent_params)  # type: ignore[operator]
+        ):
+            raise ValueError(
+                f"Tree structure of `latent_grad` was different than expected, got \n"
+                f"{tree_util.tree_structure(latent_grad)} but expected \n"
+                f"{tree_util.tree_structure(latent_params)}."
+            )
+
+        (
+            constraint_loss_value,
+            constraint_loss_grad,
+        ) = jax.value_and_grad(
+            _constraint_loss
+        )(latent_params)
+        value += constraint_loss_value
+        latent_grad = tree_util.tree_map(
+            lambda a, b: a + b, latent_grad, constraint_loss_grad
+        )
+
         flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(
             latent_grad
         )  # type: ignore[no-untyped-call]
 
-        (
-            flat_latent_params,
-            jax_lbfgsb_state,
-        ) = jax.pure_callback(
+        flat_latent_params, jax_lbfgsb_state = jax.pure_callback(
             _update_pure,
             (flat_latent_grad, jax_lbfgsb_state),
             flat_latent_grad,
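The update step pulls the gradient with respect to the transformed params back to a gradient with respect to the latents via `jax.vjp`, then adds the constraint-loss gradient. A minimal sketch of the pullback, with `jnp.tanh` standing in for `_params_from_latents`:

    import jax
    import jax.numpy as jnp

    def params_from_latents(latents):
        return jnp.tanh(latents)  # stand-in for the real latents-to-params map

    latents = jnp.array([0.1, -0.4])
    params, vjp_fn = jax.vjp(params_from_latents, latents)
    grad_wrt_params = 2.0 * params  # e.g. the gradient of sum(params**2)
    (grad_wrt_latents,) = vjp_fn(grad_wrt_params)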
@@ -312,7 +427,7 @@ def transformed_lbfgsb(
             jax_lbfgsb_state,
         )
         latent_params = unflatten_fn(flat_latent_params)
-        return transform_fn(latent_params), latent_params, jax_lbfgsb_state
+        return _params_from_latents(latent_params), latent_params, jax_lbfgsb_state
 
     return base.Optimizer(
         init=init_fn,
@@ -336,6 +451,31 @@ def _is_density(leaf: Any) -> Any:
     return isinstance(leaf, types.Density2DArray)
 
 
+def _is_parameterized_density(leaf: Any) -> Any:
+    """Return `True` if `leaf` is a parameterized density array."""
+    return isinstance(leaf, parameterization_base.ParameterizedDensity2DArrayBase)
+
+
+def _is_custom_type(leaf: Any) -> bool:
+    """Return `True` if `leaf` is a recognized custom type."""
+    return isinstance(leaf, (types.BoundedArray, types.Density2DArray))
+
+
+def _clip(pytree: PyTree) -> PyTree:
+    """Clips leaves on `pytree` to their bounds."""
+
+    def _clip_fn(leaf: Any) -> Any:
+        if not _is_custom_type(leaf):
+            return leaf
+        if leaf.lower_bound is None and leaf.upper_bound is None:
+            return leaf
+        return tree_util.tree_map(
+            lambda x: jnp.clip(x, leaf.lower_bound, leaf.upper_bound), leaf
+        )
+
+    return tree_util.tree_map(_clip_fn, pytree, is_leaf=_is_custom_type)
+
+
 def _to_numpy(params: PyTree) -> NDArray:
     """Flattens a `params` pytree into a single rank-1 numpy array."""
     x, _ = flatten_util.ravel_pytree(params)  # type: ignore[no-untyped-call]
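`_to_numpy` and its counterpart `_to_pytree` rely on `flatten_util.ravel_pytree`, which concatenates all leaves into a single rank-1 vector and returns a function that restores the original structure. A minimal round-trip sketch:

    import jax.numpy as jnp
    from jax import flatten_util

    tree = {"a": jnp.ones((2, 2)), "b": jnp.zeros((3,))}
    flat, unflatten = flatten_util.ravel_pytree(tree)
    assert flat.shape == (7,)  # 4 entries from "a", 3 from "b"
    restored = unflatten(flat)  # same structure and shapes as `tree`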