invrs-opt 0.3.2__py3-none-any.whl → 0.10.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,939 @@
+ """Defines a jax-style wrapper for scipy's L-BFGS-B algorithm.
+
+ Copyright (c) 2023 The INVRS-IO authors.
+ """
+
+ import dataclasses
+ import functools
+ from packaging import version
+ from typing import Any, Dict, Optional, Sequence, Tuple, Union
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as onp
+ import optax  # type: ignore[import-untyped]
+ from jax import flatten_util, tree_util
+ from scipy.optimize._lbfgsb_py import (  # type: ignore[import-untyped]
+     _lbfgsb as scipy_lbfgsb,
+ )
+ from totypes import types
+
+ from invrs_opt.optimizers import base
+ from invrs_opt.parameterization import (
+     base as param_base,
+     filter_project,
+     gaussian_levelset,
+     pixel,
+ )
+
+ NDArray = onp.ndarray[Any, Any]
+ PyTree = Any
+ ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
+ NumpyLbfgsbDict = Dict[str, NDArray]
+ JaxLbfgsbDict = Dict[str, jnp.ndarray]
+ LbfgsbState = Tuple[int, PyTree, PyTree, JaxLbfgsbDict]
+
+
+ # Task message prefixes for the underlying L-BFGS-B implementation.
+ TASK_START = b"START"
+ TASK_FG = b"FG"
+ TASK_CONVERGED = b"CONVERGENCE"
+
+ UPDATE_IPRINT = -1
+
+ # Maximum value for the `maxcor` parameter in the L-BFGS-B scheme.
+ MAXCOR_MAX_VALUE = 100
+ DEFAULT_MAXCOR = 20
+ DEFAULT_LINE_SEARCH_MAX_STEPS = 100
+ DEFAULT_FTOL = 0.0
+ DEFAULT_GTOL = 0.0
+
+ # Maps `(lower is None, upper is None)` bound scenarios to the integer
+ # bound-type codes expected by the scipy L-BFGS-B implementation.
+ BOUNDS_MAP: Dict[Tuple[bool, bool], int] = {
+     (True, True): 0,  # Both upper and lower bound are `None`.
+     (False, True): 1,  # Only upper bound is `None`.
+     (False, False): 2,  # Neither of the bounds is `None`.
+     (True, False): 3,  # Only the lower bound is `None`.
+ }
+
+ FORTRAN_INT = scipy_lbfgsb.types.intvar.dtype
+
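+ # `jax.pure_callback` replaced its `vectorized` argument with `vmap_method` in
+ # newer jax releases; select the argument spelling supported by the installed
+ # jax version.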
+ if version.Version(jax.__version__) > version.Version("0.4.31"):
+     callback_sequential = functools.partial(jax.pure_callback, vmap_method="sequential")
+ else:
+     callback_sequential = functools.partial(jax.pure_callback, vectorized=False)
+
+
+ def lbfgsb(
+     *,
+     maxcor: int = DEFAULT_MAXCOR,
+     line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
+     ftol: float = DEFAULT_FTOL,
+     gtol: float = DEFAULT_GTOL,
+ ) -> base.Optimizer:
+     """Optimizer implementing the standard L-BFGS-B algorithm.
+
+     The standard L-BFGS-B algorithm uses the direct pixel parameterization for
+     density arrays, which simply enforces that values lie between the declared
+     upper and lower bounds of the density.
+
+     When an optimization is determined to have converged (by the `ftol` or `gtol`
+     criteria), the optimizer `params` function simply returns the optimal
+     parameters. Convergence can be queried by `is_converged(state)`.
+
+     Args:
+         maxcor: The maximum number of variable metric corrections used to define
+             the limited memory matrix, in the L-BFGS-B scheme.
+         line_search_max_steps: The maximum number of steps in the line search.
+         ftol: Convergence criterion based on function values. See the scipy
+             documentation for details.
+         gtol: Convergence criterion based on the gradient.
+
+     Returns:
+         The `Optimizer` implementing the L-BFGS-B optimizer.
+     """
+     return parameterized_lbfgsb(
+         density_parameterization=None,
+         penalty=0.0,
+         maxcor=maxcor,
+         line_search_max_steps=line_search_max_steps,
+         ftol=ftol,
+         gtol=gtol,
+     )
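+
+
+ # Illustrative usage sketch: the returned `base.Optimizer` bundles `init`,
+ # `params`, and `update` functions, used in a jax-style loop. Here `loss_fn`,
+ # `initial_params`, and `num_steps` are hypothetical stand-ins.
+ #
+ #     opt = lbfgsb(maxcor=20, ftol=1e-9)
+ #     state = opt.init(initial_params)
+ #     for _ in range(num_steps):
+ #         params = opt.params(state)
+ #         value, grad = jax.value_and_grad(loss_fn)(params)
+ #         state = opt.update(grad=grad, value=value, params=params, state=state)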
+
+
+ def density_lbfgsb(
+     *,
+     beta: float,
+     maxcor: int = DEFAULT_MAXCOR,
+     line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
+     ftol: float = DEFAULT_FTOL,
+     gtol: float = DEFAULT_GTOL,
+ ) -> base.Optimizer:
+     """Optimizer using the L-BFGS-B algorithm with filter-project parameterization.
+
+     In the filter-project density parameterization, the optimization variable
+     associated with a density array is a latent density array; the density is
+     obtained by convolving (i.e. "filtering") the latent density with a Gaussian
+     kernel having full-width at half-maximum equal to the length scale (the mean
+     of the declared minimum width and minimum spacing). Then, a tanh nonlinearity
+     is used as a smooth thresholding operation ("projection").
+
+     When an optimization is determined to have converged (by the `ftol` or `gtol`
+     criteria), the optimizer `params` function simply returns the optimal
+     parameters. Convergence can be queried by `is_converged(state)`.
+
+     Args:
+         beta: Determines the sharpness of the thresholding operation.
+         maxcor: The maximum number of variable metric corrections used to define
+             the limited memory matrix, in the L-BFGS-B scheme.
+         line_search_max_steps: The maximum number of steps in the line search.
+         ftol: Convergence criterion based on function values. See the scipy
+             documentation for details.
+         gtol: Convergence criterion based on the gradient.
+
+     Returns:
+         The `Optimizer` implementing the L-BFGS-B optimizer.
+     """
+     return parameterized_lbfgsb(
+         density_parameterization=filter_project.filter_project(beta=beta),
+         penalty=0.0,
+         maxcor=maxcor,
+         line_search_max_steps=line_search_max_steps,
+         ftol=ftol,
+         gtol=gtol,
+     )
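+
+
+ # Worked example of the filtering step described above, with hypothetical
+ # values: a density with `minimum_width=8` and `minimum_spacing=6` has length
+ # scale (8 + 6) / 2 = 7 pixels, and so the latent density is filtered with a
+ # Gaussian kernel whose full-width at half-maximum is 7 pixels before the
+ # `beta`-controlled tanh projection is applied.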
+
+
+ def levelset_lbfgsb(
+     *,
+     penalty: float,
+     length_scale_spacing_factor: float = (
+         gaussian_levelset.DEFAULT_LENGTH_SCALE_SPACING_FACTOR
+     ),
+     length_scale_fwhm_factor: float = (
+         gaussian_levelset.DEFAULT_LENGTH_SCALE_FWHM_FACTOR
+     ),
+     length_scale_constraint_factor: float = (
+         gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR
+     ),
+     smoothing_factor: int = gaussian_levelset.DEFAULT_SMOOTHING_FACTOR,
+     length_scale_constraint_beta: float = (
+         gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA
+     ),
+     length_scale_constraint_weight: float = (
+         gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT
+     ),
+     curvature_constraint_weight: float = (
+         gaussian_levelset.DEFAULT_CURVATURE_CONSTRAINT_WEIGHT
+     ),
+     fixed_pixel_constraint_weight: float = (
+         gaussian_levelset.DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT
+     ),
+     init_optimizer: optax.GradientTransformation = (
+         gaussian_levelset.DEFAULT_INIT_OPTIMIZER
+     ),
+     init_steps: int = gaussian_levelset.DEFAULT_INIT_STEPS,
+     maxcor: int = DEFAULT_MAXCOR,
+     line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
+     ftol: float = DEFAULT_FTOL,
+     gtol: float = DEFAULT_GTOL,
+ ) -> base.Optimizer:
+     """Optimizer using the L-BFGS-B algorithm with levelset parameterization.
+
+     In the levelset parameterization, the optimization variable associated with
+     a density array is an array giving the amplitudes of Gaussian radial basis
+     functions that represent a levelset function over the domain of the density.
+     In the levelset parameterization, gradients are nonzero only at the edges of
+     features, and in general the topology of a solution does not change during
+     the course of optimization.
+
+     The spacing and full-width at half-maximum of the Gaussian basis functions
+     give some control over length scales. In addition, constraints associated
+     with length scale, radius of curvature, and deviation from fixed pixels are
+     automatically computed and penalized with a weight given by `penalty`. In
+     general, this reduces the degree to which features in an optimized density
+     array violate the specified constraints. The constraints are based on
+     "Analytical level set fabrication constraints for inverse design," by
+     D. Vercruysse et al. (2019).
+
+     When an optimization is determined to have converged (by the `ftol` or `gtol`
+     criteria), the optimizer `params` function simply returns the optimal
+     parameters. Convergence can be queried by `is_converged(state)`.
+
+     Args:
+         penalty: The weight of the fabrication penalty, which combines length
+             scale, curvature, and fixed pixel constraints.
+         length_scale_spacing_factor: The number of levelset control points per
+             unit of minimum length scale (mean of density minimum width and
+             minimum spacing).
+         length_scale_fwhm_factor: The ratio of Gaussian full-width at
+             half-maximum to the minimum length scale.
+         length_scale_constraint_factor: Multiplies the target length scale in
+             the levelset constraints. A value greater than 1 is pessimistic and
+             drives the solution to have a larger length scale (relative to
+             smaller values).
+         smoothing_factor: For values greater than 1, the density is initially
+             computed at higher resolution and then downsampled, yielding
+             smoother geometries.
+         length_scale_constraint_beta: Controls relaxation of the length scale
+             constraint near the zero level.
+         length_scale_constraint_weight: The weight of the length scale
+             constraint in the overall fabrication constraint penalty.
+         curvature_constraint_weight: The weight of the curvature constraint.
+         fixed_pixel_constraint_weight: The weight of the fixed pixel constraint.
+         init_optimizer: The optimizer used in the initialization of the levelset
+             parameterization. At initialization, the latent parameters are
+             optimized so that the initial parameters match the binarized initial
+             density.
+         init_steps: The number of optimization steps used in the initialization.
+         maxcor: The maximum number of variable metric corrections used to define
+             the limited memory matrix, in the L-BFGS-B scheme.
+         line_search_max_steps: The maximum number of steps in the line search.
+         ftol: Convergence criterion based on function values. See the scipy
+             documentation for details.
+         gtol: Convergence criterion based on the gradient.
+
+     Returns:
+         The `Optimizer` implementing the L-BFGS-B optimizer.
+     """
+     return parameterized_lbfgsb(
+         density_parameterization=gaussian_levelset.gaussian_levelset(
+             length_scale_spacing_factor=length_scale_spacing_factor,
+             length_scale_fwhm_factor=length_scale_fwhm_factor,
+             length_scale_constraint_factor=length_scale_constraint_factor,
+             smoothing_factor=smoothing_factor,
+             length_scale_constraint_beta=length_scale_constraint_beta,
+             length_scale_constraint_weight=length_scale_constraint_weight,
+             curvature_constraint_weight=curvature_constraint_weight,
+             fixed_pixel_constraint_weight=fixed_pixel_constraint_weight,
+             init_optimizer=init_optimizer,
+             init_steps=init_steps,
+         ),
+         penalty=penalty,
+         maxcor=maxcor,
+         line_search_max_steps=line_search_max_steps,
+         ftol=ftol,
+         gtol=gtol,
+     )
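+
+
+ # Sketch of how `penalty` enters the optimization (see `_constraint_loss`
+ # below): each constraint array `c` contributes `sum(maximum(c, 0) ** 2)`, and
+ # the scaled total is added to the objective value passed to L-BFGS-B:
+ #
+ #     total_value = value + penalty * sum_i(sum(maximum(c_i, 0) ** 2))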
+
+
+ # -----------------------------------------------------------------------------
+ # Base parameterized L-BFGS-B optimizer.
+ # -----------------------------------------------------------------------------
+
+
+ def parameterized_lbfgsb(
+     density_parameterization: Optional[param_base.Density2DParameterization],
+     penalty: float,
+     maxcor: int = DEFAULT_MAXCOR,
+     line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
+     ftol: float = DEFAULT_FTOL,
+     gtol: float = DEFAULT_GTOL,
+ ) -> base.Optimizer:
+     """Optimizer using the L-BFGS-B algorithm with a density parameterization.
+
+     This optimizer wraps scipy's implementation of the algorithm and provides
+     a jax-style API to the scheme. The optimizer works with custom types such
+     as the `BoundedArray` to constrain the optimization variable.
+
+     Args:
+         density_parameterization: The parameterization to be used, or `None`.
+             When no parameterization is given, the direct pixel parameterization
+             is used for density arrays.
+         penalty: The weight of the scalar penalty formed from the constraints of
+             the parameterization.
+         maxcor: The maximum number of variable metric corrections used to define
+             the limited memory matrix, in the L-BFGS-B scheme.
+         line_search_max_steps: The maximum number of steps in the line search.
+         ftol: Convergence criterion based on function values. See the scipy
+             documentation for details.
+         gtol: Convergence criterion based on the gradient.
+
+     Returns:
+         The `base.Optimizer`.
+     """
+     if not isinstance(maxcor, int) or maxcor < 1 or maxcor > MAXCOR_MAX_VALUE:
+         raise ValueError(
+             f"`maxcor` must be a positive integer no greater than "
+             f"{MAXCOR_MAX_VALUE}, but got {maxcor}"
+         )
+
+     if not isinstance(line_search_max_steps, int) or line_search_max_steps < 1:
+         raise ValueError(
+             f"`line_search_max_steps` must be a positive integer but got "
+             f"{line_search_max_steps}"
+         )
+
+     if density_parameterization is None:
+         density_parameterization = pixel.pixel()
+
+     def init_fn(params: PyTree) -> LbfgsbState:
+         """Initializes the optimization state."""
+
+         def _init_state_pure(latent_params: PyTree) -> Tuple[PyTree, NumpyLbfgsbDict]:
+             lower_bound = types.extract_lower_bound(latent_params)
+             upper_bound = types.extract_upper_bound(latent_params)
+             scipy_lbfgsb_state = ScipyLbfgsbState.init(
+                 x0=_to_numpy(latent_params),
+                 lower_bound=_bound_for_params(lower_bound, latent_params),
+                 upper_bound=_bound_for_params(upper_bound, latent_params),
+                 maxcor=maxcor,
+                 line_search_max_steps=line_search_max_steps,
+                 ftol=ftol,
+                 gtol=gtol,
+             )
+             latent_params = _to_pytree(scipy_lbfgsb_state.x, latent_params)
+             return latent_params, scipy_lbfgsb_state.to_dict()
+
+         latent_params = _init_latents(params)
+         metadata, latents = param_base.partition_density_metadata(latent_params)
+         latents, jax_lbfgsb_state = callback_sequential(
+             _init_state_pure,
+             _example_state(latents, maxcor),
+             latents,
+         )
+         latent_params = param_base.combine_density_metadata(metadata, latents)
+         return (
+             0,  # step
+             _params_from_latent_params(latent_params),  # params
+             latent_params,  # latent params
+             jax_lbfgsb_state,  # opt state
+         )
+
+     def params_fn(state: LbfgsbState) -> PyTree:
+         """Returns the parameters for the given `state`."""
+         _, params, _, _ = state
+         return params
+
+     def update_fn(
+         *,
+         grad: PyTree,
+         value: jnp.ndarray,
+         params: PyTree,
+         state: LbfgsbState,
+     ) -> LbfgsbState:
+         """Updates the state."""
+         del params
+
+         def _update_pure(
+             flat_latent_grad: PyTree,
+             value: jnp.ndarray,
+             jax_lbfgsb_state: JaxLbfgsbDict,
+         ) -> Tuple[NDArray, NumpyLbfgsbDict]:
+             assert onp.size(value) == 1
+             scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
+             flat_latent_params = scipy_lbfgsb_state.x.copy()
+             scipy_lbfgsb_state.update(
+                 grad=onp.array(flat_latent_grad, dtype=onp.float64),
+                 value=onp.array(value, dtype=onp.float64),
+             )
+             updated_flat_latent_params = scipy_lbfgsb_state.x
+             flat_latent_updates: NDArray
+             flat_latent_updates = updated_flat_latent_params - flat_latent_params
+             return flat_latent_updates, scipy_lbfgsb_state.to_dict()
+
+         step, _, latent_params, jax_lbfgsb_state = state
+         metadata, latents = param_base.partition_density_metadata(latent_params)
+
+         def _params_from_latents(latents: PyTree) -> PyTree:
+             latent_params = param_base.combine_density_metadata(metadata, latents)
+             return _params_from_latent_params(latent_params)
+
+         def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
+             latent_params = param_base.combine_density_metadata(metadata, latents)
+             return _constraint_loss(latent_params)
+
+         _, vjp_fn = jax.vjp(_params_from_latents, latents)
+         (latents_grad,) = vjp_fn(grad)
+
+         if not (
+             tree_util.tree_structure(latents_grad)
+             == tree_util.tree_structure(latents)  # type: ignore[operator]
+         ):
+             raise ValueError(
+                 f"Tree structure of `latents_grad` was different than expected, got \n"
+                 f"{tree_util.tree_structure(latents_grad)} but expected \n"
+                 f"{tree_util.tree_structure(latents)}."
+             )
+
+         constraint_loss_value, constraint_loss_grad = jax.value_and_grad(
+             _constraint_loss_latents
+         )(latents)
+         value += constraint_loss_value
+         latents_grad = tree_util.tree_map(
+             lambda a, b: a + b, latents_grad, constraint_loss_grad
+         )
+
+         flat_latents_grad, unflatten_fn = flatten_util.ravel_pytree(
+             latents_grad
+         )  # type: ignore[no-untyped-call]
+
+         flat_latent_updates, jax_lbfgsb_state = callback_sequential(
+             _update_pure,
+             (flat_latents_grad, jax_lbfgsb_state),
+             flat_latents_grad,
+             value,
+             jax_lbfgsb_state,
+         )
+         latent_updates = unflatten_fn(flat_latent_updates)
+         latent_params = _apply_updates(
+             params=latent_params,
+             updates=param_base.combine_density_metadata(metadata, latent_updates),
+             value=value,
+             step=step,
+         )
+         latent_params = _clip(latent_params)
+         params = _params_from_latent_params(latent_params)
+         return step + 1, params, latent_params, jax_lbfgsb_state
+
+     # -------------------------------------------------------------------------
+     # Functions related to the density parameterization.
+     # -------------------------------------------------------------------------
+
+     def _init_latents(params: PyTree) -> PyTree:
+         def _leaf_init_latents(leaf: Any) -> Any:
+             leaf = _clip(leaf)
+             if not _is_density(leaf) or density_parameterization is None:
+                 return leaf
+             return density_parameterization.from_density(leaf)
+
+         return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
+
+     def _params_from_latent_params(latent_params: PyTree) -> PyTree:
+         def _leaf_params_from_latents(leaf: Any) -> Any:
+             if not _is_parameterized_density(leaf) or density_parameterization is None:
+                 return leaf
+             return density_parameterization.to_density(leaf)
+
+         return tree_util.tree_map(
+             _leaf_params_from_latents,
+             latent_params,
+             is_leaf=_is_parameterized_density,
+         )
+
+     def _apply_updates(
+         params: PyTree,
+         updates: PyTree,
+         value: jnp.ndarray,
+         step: int,
+     ) -> PyTree:
+         def _leaf_apply_updates(update: Any, leaf: Any) -> Any:
+             if _is_parameterized_density(leaf):
+                 return density_parameterization.update(
+                     params=leaf, updates=update, value=value, step=step
+                 )
+             else:
+                 return optax.apply_updates(params=leaf, updates=update)
+
+         return tree_util.tree_map(
+             _leaf_apply_updates,
+             updates,
+             params,
+             is_leaf=_is_parameterized_density,
+         )
+
+     # -------------------------------------------------------------------------
+     # Functions related to the constraints to be minimized.
+     # -------------------------------------------------------------------------
+
+     def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
+         def _constraint_loss_leaf(
+             leaf: param_base.ParameterizedDensity2DArray,
+         ) -> jnp.ndarray:
+             constraints = density_parameterization.constraints(leaf)
+             constraints = tree_util.tree_map(
+                 lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
+                 constraints,
+             )
+             return jnp.sum(jnp.asarray(constraints))
+
+         losses = [0.0] + [
+             _constraint_loss_leaf(p)
+             for p in tree_util.tree_leaves(
+                 latent_params, is_leaf=_is_parameterized_density
+             )
+             if _is_parameterized_density(p)
+         ]
+         return penalty * jnp.sum(jnp.asarray(losses))
+
+     return base.Optimizer(
+         init=init_fn,
+         params=params_fn,
+         update=update_fn,
+     )
+
+
+ def is_converged(state: LbfgsbState) -> jnp.ndarray:
+     """Returns `True` if the optimization has converged."""
+     return state[3]["converged"]
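+
+
+ # Example: `is_converged` can gate an early exit from an optimization loop,
+ # e.g. `if is_converged(state): break`. Once converged, the wrapped
+ # `ScipyLbfgsbState.update` becomes a no-op.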
+
+
+ # ------------------------------------------------------------------------------
+ # Helper functions.
+ # ------------------------------------------------------------------------------
+
+
+ def _is_density(leaf: Any) -> Any:
+     """Return `True` if `leaf` is a density array."""
+     return isinstance(leaf, types.Density2DArray)
+
+
+ def _is_parameterized_density(leaf: Any) -> Any:
+     """Return `True` if `leaf` is a parameterized density array."""
+     return isinstance(leaf, param_base.ParameterizedDensity2DArray)
+
+
+ def _is_custom_type(leaf: Any) -> bool:
+     """Return `True` if `leaf` is a recognized custom type."""
+     return isinstance(leaf, (types.BoundedArray, types.Density2DArray))
+
+
+ def _clip(pytree: PyTree) -> PyTree:
+     """Clips leaves of `pytree` to their bounds."""
+
+     def _clip_fn(leaf: Any) -> Any:
+         if not _is_custom_type(leaf):
+             return leaf
+         if leaf.lower_bound is None and leaf.upper_bound is None:
+             return leaf
+         return tree_util.tree_map(
+             lambda x: jnp.clip(x, leaf.lower_bound, leaf.upper_bound), leaf
+         )
+
+     return tree_util.tree_map(_clip_fn, pytree, is_leaf=_is_custom_type)
+
+
+ def _to_numpy(params: PyTree) -> NDArray:
+     """Flattens a `params` pytree into a single rank-1 numpy array."""
+     x, _ = flatten_util.ravel_pytree(params)  # type: ignore[no-untyped-call]
+     return onp.asarray(x, dtype=onp.float64)
+
+
+ def _to_pytree(x_flat: NDArray, params: PyTree) -> PyTree:
+     """Restores a pytree from a flat numpy array using the structure of `params`.
+
+     Note that the returned pytree includes jax array leaves.
+
+     Args:
+         x_flat: The rank-1 numpy array to be restored.
+         params: A pytree of parameters whose structure is replicated in the
+             restored pytree.
+
+     Returns:
+         The restored pytree, with jax array leaves.
+     """
+     _, unflatten_fn = flatten_util.ravel_pytree(params)  # type: ignore[no-untyped-call]
+     return unflatten_fn(jnp.asarray(x_flat, dtype=float))
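+
+
+ # Note: `_to_numpy` and `_to_pytree` are mutual inverses up to dtype; for a
+ # hypothetical pytree `p`, `_to_pytree(_to_numpy(p), p)` reproduces `p` with
+ # jax array leaves.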
+
+
+ def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
+     """Generates a bound vector for the `params`.
+
+     The `bound` can be specified in various ways; it may be `None` or a scalar,
+     which then applies to all arrays in `params`. It may be a pytree with
+     structure matching that of `params`, where each leaf is either `None`, a
+     scalar, or an array matching the shape of the corresponding leaf in `params`.
+
+     The returned bound is a flat sequence suitable for use with
+     `ScipyLbfgsbState`.
+
+     Args:
+         bound: The pytree of bounds.
+         params: The pytree of parameters.
+
+     Returns:
+         The flat elementwise bound.
+     """
+
+     if bound is None or onp.isscalar(bound):
+         bound = tree_util.tree_map(
+             lambda _: bound,
+             params,
+             is_leaf=lambda x: isinstance(x, types.CUSTOM_TYPES),
+         )
+
+     bound_leaves, bound_treedef = tree_util.tree_flatten(
+         bound, is_leaf=lambda x: x is None
+     )
+     params_leaves = tree_util.tree_leaves(params, is_leaf=lambda x: x is None)
+
+     # `bound` should be a pytree of arrays or `None`, while `params` may
+     # include custom pytree nodes. Convert the custom nodes into standard
+     # types to facilitate validation that the tree structures match.
+     params_treedef = tree_util.tree_structure(
+         tree_util.tree_map(
+             lambda x: 0.0,
+             tree=params,
+             is_leaf=lambda x: x is None or isinstance(x, types.CUSTOM_TYPES),
+         )
+     )
+     if bound_treedef != params_treedef:  # type: ignore[operator]
+         raise ValueError(
+             f"Tree structure of `bound` and `params` must match, but got "
+             f"{bound_treedef} and {params_treedef}, respectively."
+         )
+
+     bound_flat = []
+     for b, p in zip(bound_leaves, params_leaves):
+         if p is None:
+             continue
+         if b is None or onp.isscalar(b) or onp.shape(b) == ():
+             bound_flat += [b] * onp.size(p)
+         else:
+             if b.shape != p.shape:
+                 raise ValueError(
+                     f"`bound` must be `None`, a scalar, or have shape matching "
+                     f"`params`, but got shape {b.shape} when params has shape "
+                     f"{p.shape}."
+                 )
+             bound_flat += b.flatten().tolist()
+
+     return bound_flat
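+
+
+ # Illustrative behavior, for hypothetical `params` containing a single leaf
+ # of size 2: scalar and `None` bounds broadcast elementwise, i.e.
+ #
+ #     _bound_for_params(None, params)  # -> [None, None]
+ #     _bound_for_params(1.0, params)   # -> [1.0, 1.0]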
+
+
+ def _example_state(params: PyTree, maxcor: int) -> PyTree:
+     """Return an example state for the given `params` and `maxcor`."""
+     params_flat, _ = flatten_util.ravel_pytree(params)  # type: ignore[no-untyped-call]
+     n = params_flat.size
+     float_params = tree_util.tree_map(lambda x: jnp.asarray(x, dtype=float), params)
+     example_jax_lbfgsb_state = dict(
+         x=jnp.zeros(n, dtype=float),
+         converged=jnp.asarray(False),
+         _maxcor=jnp.zeros((), dtype=int),
+         _line_search_max_steps=jnp.zeros((), dtype=int),
+         _ftol=jnp.zeros((), dtype=float),
+         _gtol=jnp.zeros((), dtype=float),
+         _wa=jnp.ones(_wa_size(n=n, maxcor=maxcor), dtype=float),
+         _iwa=jnp.ones(n * 3, dtype=jnp.int32),  # Fortran int
+         _task=jnp.zeros(59, dtype=int),
+         _csave=jnp.zeros(59, dtype=int),
+         _lsave=jnp.zeros(4, dtype=jnp.int32),  # Fortran int
+         _isave=jnp.zeros(44, dtype=jnp.int32),  # Fortran int
+         _dsave=jnp.zeros(29, dtype=float),
+         _lower_bound=jnp.zeros(n, dtype=float),
+         _upper_bound=jnp.zeros(n, dtype=float),
+         _bound_type=jnp.zeros(n, dtype=int),
+     )
+     return float_params, example_jax_lbfgsb_state
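+
+
+ # The `(float_params, example_jax_lbfgsb_state)` pair above serves as the
+ # result shape/dtype specification passed to `jax.pure_callback` in `init_fn`,
+ # and so must match the structure actually returned by `_init_state_pure`.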
+
+
+ # ------------------------------------------------------------------------------
+ # Wrapper for scipy's L-BFGS-B implementation.
+ # ------------------------------------------------------------------------------
+
+
+ @dataclasses.dataclass
+ class ScipyLbfgsbState:
+     """Stores the state of a scipy L-BFGS-B minimization.
+
+     This class enables optimization with a more functional style, giving the user
+     control over the optimization loop. Example usage is as follows:
+
+         value_fn = lambda x: onp.sum(x**2)
+         grad_fn = lambda x: 2 * x
+
+         x0 = onp.asarray([0.1, 0.2, 0.3])
+         lb = [None, -1, 0.1]
+         ub = [None, None, None]
+         state = ScipyLbfgsbState.init(
+             x0=x0, lower_bound=lb, upper_bound=ub, maxcor=20
+         )
+
+         for _ in range(10):
+             value = value_fn(state.x)
+             grad = grad_fn(state.x)
+             state.update(grad, value)
+
+     This example converges with `state.x` equal to `(0, 0, 0.1)` and value equal
+     to `0.01`.
+
+     Attributes:
+         x: The current solution vector.
+     """
+
+     x: NDArray
+     converged: NDArray
+     # Private attributes correspond to internal variables in the
+     # `scipy.optimize._lbfgsb_py._minimize_lbfgsb` function.
+     _maxcor: int
+     _line_search_max_steps: int
+     _ftol: NDArray
+     _gtol: NDArray
+     _wa: NDArray
+     _iwa: NDArray
+     _task: NDArray
+     _csave: NDArray
+     _lsave: NDArray
+     _isave: NDArray
+     _dsave: NDArray
+     _lower_bound: NDArray
+     _upper_bound: NDArray
+     _bound_type: NDArray
+
+     def __post_init__(self) -> None:
+         """Validates the datatypes for all state attributes."""
+         _validate_array_dtype(self.x, onp.float64)
+         _validate_array_dtype(self._wa, onp.float64)
+         _validate_array_dtype(self._iwa, FORTRAN_INT)
+         _validate_array_dtype(self._task, "S60")
+         _validate_array_dtype(self._csave, "S60")
+         _validate_array_dtype(self._lsave, FORTRAN_INT)
+         _validate_array_dtype(self._isave, FORTRAN_INT)
+         _validate_array_dtype(self._dsave, onp.float64)
+         _validate_array_dtype(self._lower_bound, onp.float64)
+         _validate_array_dtype(self._upper_bound, onp.float64)
+         _validate_array_dtype(self._bound_type, int)
+
+     def to_dict(self) -> NumpyLbfgsbDict:
+         """Generates a dictionary of numpy arrays defining the state."""
+         return dict(
+             x=onp.asarray(self.x),
+             converged=onp.asarray(self.converged),
+             _maxcor=onp.asarray(self._maxcor),
+             _line_search_max_steps=onp.asarray(self._line_search_max_steps),
+             _ftol=onp.asarray(self._ftol),
+             _gtol=onp.asarray(self._gtol),
+             _wa=onp.asarray(self._wa),
+             _iwa=onp.asarray(self._iwa),
+             _task=_array_from_s60_str(self._task),
+             _csave=_array_from_s60_str(self._csave),
+             _lsave=onp.asarray(self._lsave),
+             _isave=onp.asarray(self._isave),
+             _dsave=onp.asarray(self._dsave),
+             _lower_bound=onp.asarray(self._lower_bound),
+             _upper_bound=onp.asarray(self._upper_bound),
+             _bound_type=onp.asarray(self._bound_type),
+         )
+
+     @classmethod
+     def from_jax(cls, state_dict: JaxLbfgsbDict) -> "ScipyLbfgsbState":
+         """Converts a dictionary of jax arrays to a `ScipyLbfgsbState`."""
+         return ScipyLbfgsbState(
+             x=onp.array(state_dict["x"], dtype=onp.float64),
+             converged=onp.asarray(state_dict["converged"], dtype=bool),
+             _maxcor=int(state_dict["_maxcor"]),
+             _line_search_max_steps=int(state_dict["_line_search_max_steps"]),
+             _ftol=onp.asarray(state_dict["_ftol"], dtype=onp.float64),
+             _gtol=onp.asarray(state_dict["_gtol"], dtype=onp.float64),
+             _wa=onp.array(state_dict["_wa"], onp.float64),
+             _iwa=onp.array(state_dict["_iwa"], dtype=FORTRAN_INT),
+             _task=_s60_str_from_array(onp.asarray(state_dict["_task"])),
+             _csave=_s60_str_from_array(onp.asarray(state_dict["_csave"])),
+             _lsave=onp.array(state_dict["_lsave"], dtype=FORTRAN_INT),
+             _isave=onp.array(state_dict["_isave"], dtype=FORTRAN_INT),
+             _dsave=onp.array(state_dict["_dsave"], dtype=onp.float64),
+             _lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
+             _upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
+             _bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
+         )
+
+     @classmethod
+     def init(
+         cls,
+         x0: NDArray,
+         lower_bound: ElementwiseBound,
+         upper_bound: ElementwiseBound,
+         maxcor: int,
+         line_search_max_steps: int,
+         ftol: float,
+         gtol: float,
+     ) -> "ScipyLbfgsbState":
+         """Initializes the `ScipyLbfgsbState` for `x0`.
+
+         Args:
+             x0: Array giving the initial solution vector.
+             lower_bound: Array giving the elementwise optional lower bound.
+             upper_bound: Array giving the elementwise optional upper bound.
+             maxcor: The maximum number of variable metric corrections used to
+                 define the limited memory matrix, in the L-BFGS-B scheme.
+             line_search_max_steps: The maximum number of steps in the line search.
+             ftol: Tolerance for the stopping criterion based on function values.
+                 See the scipy documentation for details.
+             gtol: Tolerance for the stopping criterion based on the gradient.
+
+         Returns:
+             The `ScipyLbfgsbState`.
+         """
+         x0 = onp.asarray(x0)
+         if x0.ndim > 1:
+             raise ValueError(f"`x0` must be rank-1 but got shape {x0.shape}.")
+         lower_bound = onp.asarray(lower_bound)
+         upper_bound = onp.asarray(upper_bound)
+         if x0.shape != lower_bound.shape or x0.shape != upper_bound.shape:
+             raise ValueError(
+                 f"`x0`, `lower_bound`, and `upper_bound` must have matching "
+                 f"shape but got shapes {x0.shape}, {lower_bound.shape}, and "
+                 f"{upper_bound.shape}, respectively."
+             )
+         if maxcor < 1:
+             raise ValueError(f"`maxcor` must be positive but got {maxcor}.")
+
+         n = x0.size
+         lower_bound_array, upper_bound_array, bound_type = _configure_bounds(
+             lower_bound, upper_bound
+         )
+         task = onp.zeros(1, "S60")
+         task[:] = TASK_START
+
+         # See the initialization of internal variables in scipy's
+         # `_minimize_lbfgsb` function.
+         wa_size = _wa_size(n=n, maxcor=maxcor)
+         state = ScipyLbfgsbState(
+             x=onp.array(x0, onp.float64),
+             converged=onp.asarray(False),
+             _maxcor=maxcor,
+             _line_search_max_steps=line_search_max_steps,
+             _ftol=onp.asarray(ftol, onp.float64),
+             _gtol=onp.asarray(gtol, onp.float64),
+             _wa=onp.zeros(wa_size, onp.float64),
+             _iwa=onp.zeros(3 * n, FORTRAN_INT),
+             _task=task,
+             _csave=onp.zeros(1, "S60"),
+             _lsave=onp.zeros(4, FORTRAN_INT),
+             _isave=onp.zeros(44, FORTRAN_INT),
+             _dsave=onp.zeros(29, onp.float64),
+             _lower_bound=lower_bound_array,
+             _upper_bound=upper_bound_array,
+             _bound_type=bound_type,
+         )
+         # The initial state requires an update with zero value and gradient. This
+         # is because the initial task is "START", which does not actually require
+         # value and gradient evaluation.
+         state.update(onp.zeros(x0.shape, onp.float64), onp.zeros((), onp.float64))
+         return state
+
+     def update(
+         self,
+         grad: NDArray,
+         value: NDArray,
+     ) -> None:
+         """Performs an in-place update of the `ScipyLbfgsbState` if not converged.
+
+         Args:
+             grad: The function gradient for the current `x`.
+             value: The scalar function value for the current `x`.
+         """
+         if self.converged:
+             return
+         if grad.shape != self.x.shape:
+             raise ValueError(
+                 f"`grad` must have the same shape as attribute `x`, but got shapes "
+                 f"{grad.shape} and {self.x.shape}, respectively."
+             )
+         if value.shape != ():
+             raise ValueError(f"`value` must be a scalar but got shape {value.shape}.")
+
+         # The `setulb` function will sometimes return with a task that does not
+         # require a value and gradient evaluation. In this case we simply call it
+         # again, advancing past such "dummy" steps.
+         for _ in range(3):
+             scipy_lbfgsb.setulb(
+                 m=self._maxcor,
+                 x=self.x,
+                 l=self._lower_bound,
+                 u=self._upper_bound,
+                 nbd=self._bound_type,
+                 f=value,
+                 g=grad,
+                 factr=self._ftol / onp.finfo(float).eps,
+                 pgtol=self._gtol,
+                 wa=self._wa,
+                 iwa=self._iwa,
+                 task=self._task,
+                 iprint=UPDATE_IPRINT,
+                 csave=self._csave,
+                 lsave=self._lsave,
+                 isave=self._isave,
+                 dsave=self._dsave,
+                 maxls=self._line_search_max_steps,
+             )
+             task_str = self._task.tobytes()
+             if task_str.startswith(TASK_CONVERGED):
+                 self.converged = onp.asarray(True)
+             if task_str.startswith(TASK_FG):
+                 break
+
+
+ def _wa_size(n: int, maxcor: int) -> int:
+     """Return the size of the `wa` attribute of the lbfgsb state."""
+     return 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
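+
+
+ # Note: `_wa_size` is assumed to mirror the workspace allocation performed in
+ # scipy's `_minimize_lbfgsb`; if scipy changes this allocation, the example
+ # state in `_example_state` must change accordingly.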
+
+
+ def _validate_array_dtype(x: NDArray, dtype: Union[type, str]) -> None:
+     """Validates that `x` is an array with the specified `dtype`."""
+     if not isinstance(x, onp.ndarray):
+         raise ValueError(f"`x` must be an `onp.ndarray` but got {type(x)}")
+     if x.dtype != dtype:
+         raise ValueError(f"`x` must have dtype {dtype} but got {x.dtype}")
+
+
+ def _configure_bounds(
+     lower_bound: ElementwiseBound,
+     upper_bound: ElementwiseBound,
+ ) -> Tuple[NDArray, NDArray, NDArray]:
+     """Configures the bounds for an L-BFGS-B optimization."""
+     bound_type = [
+         BOUNDS_MAP[(lower is None, upper is None)]
+         for lower, upper in zip(lower_bound, upper_bound)
+     ]
+     lower_bound_array = [0.0 if x is None else x for x in lower_bound]
+     upper_bound_array = [0.0 if x is None else x for x in upper_bound]
+     return (
+         onp.asarray(lower_bound_array, onp.float64),
+         onp.asarray(upper_bound_array, onp.float64),
+         onp.asarray(bound_type),
+     )
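+
+
+ # Example with hypothetical bounds: `lower_bound=[None, -1.0]` and
+ # `upper_bound=[None, None]` yield `bound_type=[0, 1]` (unbounded, and bounded
+ # from below only), with `None` entries replaced by placeholder zeros in the
+ # returned bound arrays.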
+
+
+ def _array_from_s60_str(s60_str: NDArray) -> NDArray:
+     """Return an integer array encoding a numpy S60 string."""
+     assert s60_str.shape == (1,)
+     chars = [int(o) for o in s60_str[0]]
+     chars.extend([32] * (59 - len(chars)))
+     return onp.asarray(chars, dtype=int)
+
+
+ def _s60_str_from_array(array: NDArray) -> NDArray:
+     """Return a numpy S60 string for an integer array."""
+     return onp.asarray(
+         [b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
+         dtype="S60",
+     )
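+
+
+ # Example of the task-string round trip, using the constants defined above: a
+ # task array holding b"START" encodes to the ASCII codes of "START" padded
+ # with spaces (32), which `_s60_str_from_array` maps back to an S60 string:
+ #
+ #     task = onp.zeros(1, "S60")
+ #     task[:] = TASK_START
+ #     _array_from_s60_str(task)  # array([83, 84, 65, 82, 84, 32, 32, ...])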