invrs-opt 0.3.2__py3-none-any.whl → 0.10.3__py3-none-any.whl

@@ -0,0 +1,939 @@
1
+ """Defines a jax-style wrapper for scipy's L-BFGS-B algorithm.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
+ import dataclasses
7
+ import functools
8
+ from packaging import version
9
+ from typing import Any, Dict, Optional, Sequence, Tuple, Union
10
+
11
+ import jax
12
+ import jax.numpy as jnp
13
+ import numpy as onp
14
+ import optax # type: ignore[import-untyped]
15
+ from jax import flatten_util, tree_util
16
+ from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
17
+ _lbfgsb as scipy_lbfgsb,
18
+ )
19
+ from totypes import types
20
+
21
+ from invrs_opt.optimizers import base
22
+ from invrs_opt.parameterization import (
23
+ base as param_base,
24
+ filter_project,
25
+ gaussian_levelset,
26
+ pixel,
27
+ )
28
+
29
+ NDArray = onp.ndarray[Any, Any]
30
+ PyTree = Any
31
+ ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
32
+ NumpyLbfgsbDict = Dict[str, NDArray]
33
+ JaxLbfgsbDict = Dict[str, jnp.ndarray]
34
+ LbfgsbState = Tuple[int, PyTree, PyTree, JaxLbfgsbDict]
35
+
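# For reference, an `LbfgsbState` unpacks as
#     step, params, latent_params, jax_lbfgsb_state = state
# where `params` are the user-facing parameters and `jax_lbfgsb_state` mirrors
# the internal scipy L-BFGS-B variables as jax arrays.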
36
+
37
+ # Task message prefixes for the underlying L-BFGS-B implementation.
38
+ TASK_START = b"START"
39
+ TASK_FG = b"FG"
40
+ TASK_CONVERGED = b"CONVERGENCE"
41
+
42
+ UPDATE_IPRINT = -1
43
+
44
+ # Maximum value for the `maxcor` parameter in the L-BFGS-B scheme.
45
+ MAXCOR_MAX_VALUE = 100
46
+ DEFAULT_MAXCOR = 20
47
+ DEFAULT_LINE_SEARCH_MAX_STEPS = 100
48
+ DEFAULT_FTOL = 0.0
49
+ DEFAULT_GTOL = 0.0
50
+
51
+ # Maps bound scenarios to the integer `nbd` codes passed to `setulb`.
52
+ BOUNDS_MAP: Dict[Tuple[bool, bool], int] = {
53
+ (True, True): 0, # Both upper and lower bound are `None`.
54
+ (False, True): 1, # Only upper bound is `None`.
55
+ (False, False): 2, # Neither bound is `None`.
56
+ (True, False): 3, # Only the lower bound is `None`.
57
+ }
58
+
59
+ FORTRAN_INT = scipy_lbfgsb.types.intvar.dtype
60
+
61
+ if version.Version(jax.__version__) > version.Version("0.4.31"):
62
+ callback_sequential = functools.partial(jax.pure_callback, vmap_method="sequential")
63
+ else:
64
+ callback_sequential = functools.partial(jax.pure_callback, vectorized=False)
65
+
66
+
67
+ def lbfgsb(
68
+ *,
69
+ maxcor: int = DEFAULT_MAXCOR,
70
+ line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
71
+ ftol: float = DEFAULT_FTOL,
72
+ gtol: float = DEFAULT_GTOL,
73
+ ) -> base.Optimizer:
74
+ """Optimizer implementing the standard L-BFGS-B algorithm.
75
+
76
+ The standard L-BFGS-B algorithm uses the direct pixel parameterization for density
77
+ arrays, which simply enforces that values are between the declared upper and lower
78
+ bounds of the density.
79
+
80
+ When an optimization is determined to have converged (by `ftol` or `gtol` criteria),
81
+ the optimizer `params` function will simply return the optimal parameters. The
82
+ convergence can be queried by `is_converged(state)`.
83
+
84
+ Args:
85
+ maxcor: The maximum number of variable metric corrections used to define the
86
+ limited memory matrix, in the L-BFGS-B scheme.
87
+ line_search_max_steps: The maximum number of steps in the line search.
88
+ ftol: Convergence criteria based on function values. See scipy documentation
89
+ for details.
90
+ gtol: Convergence criteria based on gradient.
91
+
92
+ Returns:
93
+ The `Optimizer` implementing the L-BFGS-B optimizer.
94
+ """
95
+ return parameterized_lbfgsb(
96
+ density_parameterization=None,
97
+ penalty=0.0,
98
+ maxcor=maxcor,
99
+ line_search_max_steps=line_search_max_steps,
100
+ ftol=ftol,
101
+ gtol=gtol,
102
+ )
103
+
104
+
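# A minimal usage sketch for the optimizer returned above, in the
# init / params / update style the `base.Optimizer` exposes. Assumes `lbfgsb`
# is re-exported at the package top level; the quadratic loss is hypothetical.
import jax
import jax.numpy as jnp

import invrs_opt

def loss_fn(x):
    return jnp.sum(x**2)

opt = invrs_opt.lbfgsb()
params = jnp.asarray([0.1, 0.2, 0.3])
state = opt.init(params)
for _ in range(10):
    params = opt.params(state)
    value, grad = jax.value_and_grad(loss_fn)(params)
    state = opt.update(grad=grad, value=value, params=params, state=state)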
105
+ def density_lbfgsb(
106
+ *,
107
+ beta: float,
108
+ maxcor: int = DEFAULT_MAXCOR,
109
+ line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
110
+ ftol: float = DEFAULT_FTOL,
111
+ gtol: float = DEFAULT_GTOL,
112
+ ) -> base.Optimizer:
113
+ """Optimizer using L-BFGS-B algorithm with filter-project density parameterization.
114
+
115
+ In the filter-project density parameterization, the optimization variable
116
+ associated with a density array is a latent density array; the density is obtained
117
+ by convolving (i.e. "filtering") the latent density with a Gaussian kernel having
118
+ full-width at half-maximum equal to the length scale (the mean of declared minimum
119
+ width and minimum spacing). Then, a tanh nonlinearity is used as a smooth threshold
120
+ operation ("projection").
121
+
122
+ When an optimization is determined to have converged (by `ftol` or `gtol` criteria),
123
+ the optimizer `params` function will simply return the optimal parameters. The
124
+ convergence can be queried by `is_converged(state)`.
125
+
126
+ Args:
127
+ beta: Determines the sharpness of the thresholding operation.
128
+ maxcor: The maximum number of variable metric corrections used to define the
129
+ limited memory matrix, in the L-BFGS-B scheme.
130
+ line_search_max_steps: The maximum number of steps in the line search.
131
+ ftol: Convergence criteria based on function values. See scipy documentation
132
+ for details.
133
+ gtol: Convergence criteria based on gradient.
134
+
135
+ Returns:
136
+ The `Optimizer` implementing the L-BFGS-B optimizer.
137
+ """
138
+ return parameterized_lbfgsb(
139
+ density_parameterization=filter_project.filter_project(beta=beta),
140
+ penalty=0.0,
141
+ maxcor=maxcor,
142
+ line_search_max_steps=line_search_max_steps,
143
+ ftol=ftol,
144
+ gtol=gtol,
145
+ )
146
+
147
+
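# An illustrative sketch of the filter-project transform described above. The
# real implementation lives in `invrs_opt.parameterization.filter_project`;
# the normalization details here are simplified assumptions.
import jax.numpy as jnp

def tanh_projection(filtered, beta, eta=0.5):
    # Smooth threshold of the filtered density about the level `eta`; larger
    # `beta` approaches a hard step.
    numerator = jnp.tanh(beta * eta) + jnp.tanh(beta * (filtered - eta))
    denominator = jnp.tanh(beta * eta) + jnp.tanh(beta * (1 - eta))
    return numerator / denominator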
148
+ def levelset_lbfgsb(
149
+ *,
150
+ penalty: float,
151
+ length_scale_spacing_factor: float = (
152
+ gaussian_levelset.DEFAULT_LENGTH_SCALE_SPACING_FACTOR
153
+ ),
154
+ length_scale_fwhm_factor: float = (
155
+ gaussian_levelset.DEFAULT_LENGTH_SCALE_FWHM_FACTOR
156
+ ),
157
+ length_scale_constraint_factor: float = (
158
+ gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR
159
+ ),
160
+ smoothing_factor: int = gaussian_levelset.DEFAULT_SMOOTHING_FACTOR,
161
+ length_scale_constraint_beta: float = (
162
+ gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA
163
+ ),
164
+ length_scale_constraint_weight: float = (
165
+ gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT
166
+ ),
167
+ curvature_constraint_weight: float = (
168
+ gaussian_levelset.DEFAULT_CURVATURE_CONSTRAINT_WEIGHT
169
+ ),
170
+ fixed_pixel_constraint_weight: float = (
171
+ gaussian_levelset.DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT
172
+ ),
173
+ init_optimizer: optax.GradientTransformation = (
174
+ gaussian_levelset.DEFAULT_INIT_OPTIMIZER
175
+ ),
176
+ init_steps: int = gaussian_levelset.DEFAULT_INIT_STEPS,
177
+ maxcor: int = DEFAULT_MAXCOR,
178
+ line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
179
+ ftol: float = DEFAULT_FTOL,
180
+ gtol: float = DEFAULT_GTOL,
181
+ ) -> base.Optimizer:
182
+ """Optimizer using L-BFGS-B algorithm with levelset density parameterization.
183
+
184
+ In the levelset parameterization, the optimization variable associated with a
185
+ density array is an array giving the amplitudes of Gaussian radial basis functions
186
+ that represent a levelset function over the domain of the density. In the levelset
187
+ parameterization, gradients are nonzero only at the edges of features, and in
188
+ general the topology of a solution does not change during the course of
189
+ optimization.
190
+
191
+ The spacing and full-width at half-maximum of the Gaussian basis functions give
192
+ some amount of control over length scales. In addition, constraints associated with
193
+ length scale, radius of curvature, and deviation from fixed pixels are
194
+ automatically computed and penalized with a weight given by `penalty`. In general,
195
+ this helps ensure that features in an optimized density array violate the specified
196
+ constraints less severely than they otherwise would. The constraints are based on "Analytical level set
197
+ fabrication constraints for inverse design," by D. Vercruysse et al. (2019).
198
+
199
+ When an optimization is determined to have converged (by `ftol` or `gtol` criteria),
200
+ the optimizer `params` function will simply return the optimal parameters. The
201
+ convergence can be queried by `is_converged(state)`.
202
+
203
+ Args:
204
+ penalty: The weight of the fabrication penalty, which combines length scale,
205
+ curvature, and fixed pixel constraints.
206
+ length_scale_spacing_factor: The number of levelset control points per unit of
207
+ minimum length scale (mean of density minimum width and minimum spacing).
208
+ length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum to
209
+ the minimum length scale.
210
+ length_scale_constraint_factor: Multiplies the target length scale in the
211
+ levelset constraints. A value greater than 1 is pessimistic and drives the
212
+ solution to have a larger length scale (relative to smaller values).
213
+ smoothing_factor: For values greater than 1, the density is initially computed
214
+ at higher resolution and then downsampled, yielding smoother geometries.
215
+ length_scale_constraint_beta: Controls relaxation of the length scale
216
+ constraint near the zero level.
217
+ length_scale_constraint_weight: The weight of the length scale constraint in
218
+ the overall fabrication constraint penalty.
219
+ curvature_constraint_weight: The weight of the curvature constraint.
220
+ fixed_pixel_constraint_weight: The weight of the fixed pixel constraint.
221
+ init_optimizer: The optimizer used in the initialization of the levelset
222
+ parameterization. At initialization, the latent parameters are optimized so
223
+ that the initial parameters match the binarized initial density.
224
+ init_steps: The number of optimization steps used in the initialization.
225
+ maxcor: The maximum number of variable metric corrections used to define the
226
+ limited memory matrix, in the L-BFGS-B scheme.
227
+ line_search_max_steps: The maximum number of steps in the line search.
228
+ ftol: Convergence criteria based on function values. See scipy documentation
229
+ for details.
230
+ gtol: Convergence criteria based on gradient.
231
+
232
+ Returns:
233
+ The `Optimizer` implementing the L-BFGS-B optimizer.
234
+ """
235
+ return parameterized_lbfgsb(
236
+ density_parameterization=gaussian_levelset.gaussian_levelset(
237
+ length_scale_spacing_factor=length_scale_spacing_factor,
238
+ length_scale_fwhm_factor=length_scale_fwhm_factor,
239
+ length_scale_constraint_factor=length_scale_constraint_factor,
240
+ smoothing_factor=smoothing_factor,
241
+ length_scale_constraint_beta=length_scale_constraint_beta,
242
+ length_scale_constraint_weight=length_scale_constraint_weight,
243
+ curvature_constraint_weight=curvature_constraint_weight,
244
+ fixed_pixel_constraint_weight=fixed_pixel_constraint_weight,
245
+ init_optimizer=init_optimizer,
246
+ init_steps=init_steps,
247
+ ),
248
+ penalty=penalty,
249
+ maxcor=maxcor,
250
+ line_search_max_steps=line_search_max_steps,
251
+ ftol=ftol,
252
+ gtol=gtol,
253
+ )
254
+
255
+
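# An illustrative sketch of the levelset idea described above, with
# hypothetical names; the package's actual parameterization is in
# `invrs_opt.parameterization.gaussian_levelset`.
import jax.numpy as jnp

def density_from_amplitudes(amplitudes, centers, fwhm, grid):
    # amplitudes: (k,), centers: (k, 2), grid: (m, 2). The levelset function
    # is a sum of Gaussian radial basis functions; here the density is simply
    # its sign, without the package's smoothing and constraint machinery.
    sigma = fwhm / (2 * jnp.sqrt(2 * jnp.log(2)))
    sq_dist = jnp.sum((grid[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
    phi = jnp.sum(amplitudes * jnp.exp(-sq_dist / (2 * sigma**2)), axis=-1)
    return jnp.where(phi > 0, 1.0, 0.0)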
256
+ # -----------------------------------------------------------------------------
257
+ # Base parameterized L-BFGS-B optimizer.
258
+ # -----------------------------------------------------------------------------
259
+
260
+
261
+ def parameterized_lbfgsb(
262
+ density_parameterization: Optional[param_base.Density2DParameterization],
263
+ penalty: float,
264
+ maxcor: int = DEFAULT_MAXCOR,
265
+ line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
266
+ ftol: float = DEFAULT_FTOL,
267
+ gtol: float = DEFAULT_GTOL,
268
+ ) -> base.Optimizer:
269
+ """Optimizer using L-BFGS-B optimizer with specified density parameterization.
270
+
271
+ This optimizer wraps scipy's implementation of the algorithm, and provides
272
+ a jax-style API to the scheme. The optimizer works with custom types such
273
+ as the `BoundedArray` to constrain the optimization variable.
274
+
275
+ Args:
276
+ density_parameterization: The parameterization to be used, or `None`. When no
277
+ parameterization is given, the direct pixel parameterization is used for
278
+ density arrays.
279
+ penalty: The weight of the scalar penalty formed from the constraints of the
280
+ parameterization.
281
+ maxcor: The maximum number of variable metric corrections used to define the
282
+ limited memory matrix, in the L-BFGS-B scheme.
283
+ line_search_max_steps: The maximum number of steps in the line search.
284
+ ftol: Convergence criteria based on function values. See scipy documentation
285
+ for details.
286
+ gtol: Convergence criteria based on gradient.
287
+
288
+ Returns:
289
+ The `base.Optimizer`.
290
+ """
291
+ if not isinstance(maxcor, int) or maxcor < 1 or maxcor > MAXCOR_MAX_VALUE:
292
+ raise ValueError(
293
+ f"`maxcor` must be greater than 0 and less than "
294
+ f"{MAXCOR_MAX_VALUE}, but got {maxcor}"
295
+ )
296
+
297
+ if not isinstance(line_search_max_steps, int) or line_search_max_steps < 1:
298
+ raise ValueError(
299
+ f"`line_search_max_steps` must be greater than 0 but got "
300
+ f"{line_search_max_steps}"
301
+ )
302
+
303
+ if density_parameterization is None:
304
+ density_parameterization = pixel.pixel()
305
+
306
+ def init_fn(params: PyTree) -> LbfgsbState:
307
+ """Initializes the optimization state."""
308
+
309
+ def _init_state_pure(latent_params: PyTree) -> Tuple[PyTree, NumpyLbfgsbDict]:
310
+ lower_bound = types.extract_lower_bound(latent_params)
311
+ upper_bound = types.extract_upper_bound(latent_params)
312
+ scipy_lbfgsb_state = ScipyLbfgsbState.init(
313
+ x0=_to_numpy(latent_params),
314
+ lower_bound=_bound_for_params(lower_bound, latent_params),
315
+ upper_bound=_bound_for_params(upper_bound, latent_params),
316
+ maxcor=maxcor,
317
+ line_search_max_steps=line_search_max_steps,
318
+ ftol=ftol,
319
+ gtol=gtol,
320
+ )
321
+ latent_params = _to_pytree(scipy_lbfgsb_state.x, latent_params)
322
+ return latent_params, scipy_lbfgsb_state.to_dict()
323
+
324
+ latent_params = _init_latents(params)
325
+ metadata, latents = param_base.partition_density_metadata(latent_params)
326
+ latents, jax_lbfgsb_state = callback_sequential(
327
+ _init_state_pure,
328
+ _example_state(latents, maxcor),
329
+ latents,
330
+ )
331
+ latent_params = param_base.combine_density_metadata(metadata, latents)
332
+ return (
333
+ 0, # step
334
+ _params_from_latent_params(latent_params), # params
335
+ latent_params, # latent params
336
+ jax_lbfgsb_state, # opt state
337
+ )
338
+
339
+ def params_fn(state: LbfgsbState) -> PyTree:
340
+ """Returns the parameters for the given `state`."""
341
+ _, params, _, _ = state
342
+ return params
343
+
344
+ def update_fn(
345
+ *,
346
+ grad: PyTree,
347
+ value: jnp.ndarray,
348
+ params: PyTree,
349
+ state: LbfgsbState,
350
+ ) -> LbfgsbState:
351
+ """Updates the state."""
352
+ del params
353
+
354
+ def _update_pure(
355
+ flat_latent_grad: PyTree,
356
+ value: jnp.ndarray,
357
+ jax_lbfgsb_state: JaxLbfgsbDict,
358
+ ) -> Tuple[NDArray, NumpyLbfgsbDict]:
359
+ assert onp.size(value) == 1
360
+ scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
361
+ flat_latent_params = scipy_lbfgsb_state.x.copy()
362
+ scipy_lbfgsb_state.update(
363
+ grad=onp.array(flat_latent_grad, dtype=onp.float64),
364
+ value=onp.array(value, dtype=onp.float64),
365
+ )
366
+ updated_flat_latent_params = scipy_lbfgsb_state.x
367
+ flat_latent_updates: NDArray
368
+ flat_latent_updates = updated_flat_latent_params - flat_latent_params
369
+ return flat_latent_updates, scipy_lbfgsb_state.to_dict()
370
+
371
+ step, _, latent_params, jax_lbfgsb_state = state
372
+ metadata, latents = param_base.partition_density_metadata(latent_params)
373
+
374
+ def _params_from_latents(latents: PyTree) -> PyTree:
375
+ latent_params = param_base.combine_density_metadata(metadata, latents)
376
+ return _params_from_latent_params(latent_params)
377
+
378
+ def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
379
+ latent_params = param_base.combine_density_metadata(metadata, latents)
380
+ return _constraint_loss(latent_params)
381
+
382
+ _, vjp_fn = jax.vjp(_params_from_latents, latents)
383
+ (latents_grad,) = vjp_fn(grad)
384
+
385
+ if not (
386
+ tree_util.tree_structure(latents_grad)
387
+ == tree_util.tree_structure(latents) # type: ignore[operator]
388
+ ):
389
+ raise ValueError(
390
+ f"Tree structure of `latents_grad` was different than expected, got \n"
391
+ f"{tree_util.tree_structure(latents_grad)} but expected \n"
392
+ f"{tree_util.tree_structure(latents)}."
393
+ )
394
+
395
+ (
396
+ constraint_loss_value,
397
+ constraint_loss_grad,
398
+ ) = jax.value_and_grad(
399
+ _constraint_loss_latents
400
+ )(latents)
401
+ value += constraint_loss_value
402
+ latents_grad = tree_util.tree_map(
403
+ lambda a, b: a + b, latents_grad, constraint_loss_grad
404
+ )
405
+
406
+ flat_latents_grad, unflatten_fn = flatten_util.ravel_pytree(
407
+ latents_grad
408
+ ) # type: ignore[no-untyped-call]
409
+
410
+ flat_latent_updates, jax_lbfgsb_state = callback_sequential(
411
+ _update_pure,
412
+ (flat_latents_grad, jax_lbfgsb_state),
413
+ flat_latents_grad,
414
+ value,
415
+ jax_lbfgsb_state,
416
+ )
417
+ latent_updates = unflatten_fn(flat_latent_updates)
418
+ latent_params = _apply_updates(
419
+ params=latent_params,
420
+ updates=param_base.combine_density_metadata(metadata, latent_updates),
421
+ value=value,
422
+ step=step,
423
+ )
424
+ latent_params = _clip(latent_params)
425
+ params = _params_from_latent_params(latent_params)
426
+ return step + 1, params, latent_params, jax_lbfgsb_state
427
+
428
+ # -------------------------------------------------------------------------
429
+ # Functions related to the density parameterization.
430
+ # -------------------------------------------------------------------------
431
+
432
+ def _init_latents(params: PyTree) -> PyTree:
433
+ def _leaf_init_latents(leaf: Any) -> Any:
434
+ leaf = _clip(leaf)
435
+ if not _is_density(leaf) or density_parameterization is None:
436
+ return leaf
437
+ return density_parameterization.from_density(leaf)
438
+
439
+ return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
440
+
441
+ def _params_from_latent_params(latent_params: PyTree) -> PyTree:
442
+ def _leaf_params_from_latents(leaf: Any) -> Any:
443
+ if not _is_parameterized_density(leaf) or density_parameterization is None:
444
+ return leaf
445
+ return density_parameterization.to_density(leaf)
446
+
447
+ return tree_util.tree_map(
448
+ _leaf_params_from_latents,
449
+ latent_params,
450
+ is_leaf=_is_parameterized_density,
451
+ )
452
+
453
+ def _apply_updates(
454
+ params: PyTree,
455
+ updates: PyTree,
456
+ value: jnp.ndarray,
457
+ step: int,
458
+ ) -> PyTree:
459
+ def _leaf_apply_updates(update: Any, leaf: Any) -> Any:
460
+ if _is_parameterized_density(leaf):
461
+ return density_parameterization.update(
462
+ params=leaf, updates=update, value=value, step=step
463
+ )
464
+ else:
465
+ return optax.apply_updates(params=leaf, updates=update)
466
+
467
+ return tree_util.tree_map(
468
+ _leaf_apply_updates,
469
+ updates,
470
+ params,
471
+ is_leaf=_is_parameterized_density,
472
+ )
473
+
474
+ # -------------------------------------------------------------------------
475
+ # Functions related to the constraints to be minimized.
476
+ # -------------------------------------------------------------------------
477
+
478
+ def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
479
+ def _constraint_loss_leaf(
480
+ leaf: param_base.ParameterizedDensity2DArray,
481
+ ) -> jnp.ndarray:
482
+ constraints = density_parameterization.constraints(leaf)
483
+ constraints = tree_util.tree_map(
484
+ lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
485
+ constraints,
486
+ )
487
+ return jnp.sum(jnp.asarray(constraints))
488
+
489
+ losses = [0.0] + [
490
+ _constraint_loss_leaf(p)
491
+ for p in tree_util.tree_leaves(
492
+ latent_params, is_leaf=_is_parameterized_density
493
+ )
494
+ if _is_parameterized_density(p)
495
+ ]
496
+ return penalty * jnp.sum(jnp.asarray(losses))
497
+
498
+ return base.Optimizer(
499
+ init=init_fn,
500
+ params=params_fn,
501
+ update=update_fn,
502
+ )
503
+
504
+
505
+ def is_converged(state: LbfgsbState) -> jnp.ndarray:
506
+ """Returns `True` if the optimization has converged."""
507
+ return state[3]["converged"]
508
+
509
+
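# Hedged usage sketch: outside of jit, convergence can gate the loop, e.g.
#     while not bool(is_converged(state)):
#         params = opt.params(state)
#         value, grad = jax.value_and_grad(loss_fn)(params)
#         state = opt.update(grad=grad, value=value, params=params, state=state)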
510
+ # ------------------------------------------------------------------------------
511
+ # Helper functions.
512
+ # ------------------------------------------------------------------------------
513
+
514
+
515
+ def _is_density(leaf: Any) -> Any:
516
+ """Return `True` if `leaf` is a density array."""
517
+ return isinstance(leaf, types.Density2DArray)
518
+
519
+
520
+ def _is_parameterized_density(leaf: Any) -> Any:
521
+ """Return `True` if `leaf` is a parameterized density array."""
522
+ return isinstance(leaf, param_base.ParameterizedDensity2DArray)
523
+
524
+
525
+ def _is_custom_type(leaf: Any) -> bool:
526
+ """Return `True` if `leaf` is a recognized custom type."""
527
+ return isinstance(leaf, (types.BoundedArray, types.Density2DArray))
528
+
529
+
530
+ def _clip(pytree: PyTree) -> PyTree:
531
+ """Clips leaves on `pytree` to their bounds."""
532
+
533
+ def _clip_fn(leaf: Any) -> Any:
534
+ if not _is_custom_type(leaf):
535
+ return leaf
536
+ if leaf.lower_bound is None and leaf.upper_bound is None:
537
+ return leaf
538
+ return tree_util.tree_map(
539
+ lambda x: jnp.clip(x, leaf.lower_bound, leaf.upper_bound), leaf
540
+ )
541
+
542
+ return tree_util.tree_map(_clip_fn, pytree, is_leaf=_is_custom_type)
543
+
544
+
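# For example, a `types.BoundedArray` with `lower_bound=0.0` and
# `upper_bound=1.0` has its array leaf clipped into [0, 1]; leaves that are
# not recognized custom types pass through unchanged.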
545
+ def _to_numpy(params: PyTree) -> NDArray:
546
+ """Flattens a `params` pytree into a single rank-1 numpy array."""
547
+ x, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
548
+ return onp.asarray(x, dtype=onp.float64)
549
+
550
+
551
+ def _to_pytree(x_flat: NDArray, params: PyTree) -> PyTree:
552
+ """Restores a pytree from a flat numpy array using the structure of `params`.
553
+
554
+ Note that the returned pytree includes jax array leaves.
555
+
556
+ Args:
557
+ x_flat: The rank-1 numpy array to be restored.
558
+ params: A pytree of parameters whose structure is replicated in the restored
559
+ pytree.
560
+
561
+ Returns:
562
+ The restored pytree, with jax array leaves.
563
+ """
564
+ _, unflatten_fn = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
565
+ return unflatten_fn(jnp.asarray(x_flat, dtype=float))
566
+
567
+
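# Hedged round-trip sketch for the two helpers above:
#     params = {"a": jnp.ones((2, 2)), "b": jnp.zeros(3)}
#     flat = _to_numpy(params)              # rank-1 float64 array of size 7
#     restored = _to_pytree(flat, params)   # same structure, jax array leaves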
568
+ def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
569
+ """Generates a bound vector for the `params`.
570
+
571
+ The `bound` can be specified in various ways; it may be `None` or a scalar,
572
+ which then applies to all arrays in `params`. It may be a pytree with
573
+ structure matching that of `params`, where each leaf is either `None`, a
574
+ scalar, or an array matching the shape of the corresponding leaf in `params`.
575
+
576
+ The returned bound is a flat array suitable for use with `ScipyLbfgsbState`.
577
+
578
+ Args:
579
+ bound: The pytree of bounds.
580
+ params: The pytree of parameters.
581
+
582
+ Returns:
583
+ The flat elementwise bound.
584
+ """
585
+
586
+ if bound is None or onp.isscalar(bound):
587
+ bound = tree_util.tree_map(
588
+ lambda _: bound,
589
+ params,
590
+ is_leaf=lambda x: isinstance(x, types.CUSTOM_TYPES),
591
+ )
592
+
593
+ bound_leaves, bound_treedef = tree_util.tree_flatten(
594
+ bound, is_leaf=lambda x: x is None
595
+ )
596
+ params_leaves = tree_util.tree_leaves(params, is_leaf=lambda x: x is None)
597
+
598
+ # `bound` should be a pytree of arrays or `None`, while `params` may
599
+ # include custom pytree nodes. Convert the custom nodes into standard
600
+ # types to facilitate validation that the tree structures match.
601
+ params_treedef = tree_util.tree_structure(
602
+ tree_util.tree_map(
603
+ lambda x: 0.0,
604
+ tree=params,
605
+ is_leaf=lambda x: x is None or isinstance(x, types.CUSTOM_TYPES),
606
+ )
607
+ )
608
+ if bound_treedef != params_treedef: # type: ignore[operator]
609
+ raise ValueError(
610
+ f"Tree structure of `bound` and `params` must match, but got "
611
+ f"{bound_treedef} and {params_treedef}, respectively."
612
+ )
613
+
614
+ bound_flat = []
615
+ for b, p in zip(bound_leaves, params_leaves):
616
+ if p is None:
617
+ continue
618
+ if b is None or onp.isscalar(b) or onp.shape(b) == ():
619
+ bound_flat += [b] * onp.size(p)
620
+ else:
621
+ if b.shape != p.shape:
622
+ raise ValueError(
623
+ f"`bound` must be `None`, a scalar, or have shape matching "
624
+ f"`params`, but got shape {b.shape} when params has shape "
625
+ f"{p.shape}."
626
+ )
627
+ bound_flat += b.flatten().tolist()
628
+
629
+ return bound_flat
630
+
631
+
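# Hedged examples of the broadcasting described above, for a plain pytree of
# numpy arrays:
#     _bound_for_params(None, {"a": onp.zeros(2)})  # -> [None, None]
#     _bound_for_params(1.0, {"a": onp.zeros(2)})   # -> [1.0, 1.0]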
632
+ def _example_state(params: PyTree, maxcor: int) -> PyTree:
633
+ """Return an example state for the given `params` and `maxcor`."""
634
+ params_flat, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
635
+ n = params_flat.size
636
+ float_params = tree_util.tree_map(lambda x: jnp.asarray(x, dtype=float), params)
637
+ example_jax_lbfgsb_state = dict(
638
+ x=jnp.zeros(n, dtype=float),
639
+ converged=jnp.asarray(False),
640
+ _maxcor=jnp.zeros((), dtype=int),
641
+ _line_search_max_steps=jnp.zeros((), dtype=int),
642
+ _ftol=jnp.zeros((), dtype=float),
643
+ _gtol=jnp.zeros((), dtype=float),
644
+ _wa=jnp.ones(_wa_size(n=n, maxcor=maxcor), dtype=float),
645
+ _iwa=jnp.ones(n * 3, dtype=jnp.int32), # Fortran int
646
+ _task=jnp.zeros(59, dtype=int),
647
+ _csave=jnp.zeros(59, dtype=int),
648
+ _lsave=jnp.zeros(4, dtype=jnp.int32), # Fortran int
649
+ _isave=jnp.zeros(44, dtype=jnp.int32), # Fortran int
650
+ _dsave=jnp.zeros(29, dtype=float),
651
+ _lower_bound=jnp.zeros(n, dtype=float),
652
+ _upper_bound=jnp.zeros(n, dtype=float),
653
+ _bound_type=jnp.zeros(n, dtype=int),
654
+ )
655
+ return float_params, example_jax_lbfgsb_state
656
+
657
+
658
+ # ------------------------------------------------------------------------------
659
+ # Wrapper for scipy's L-BFGS-B implementation.
660
+ # ------------------------------------------------------------------------------
661
+
662
+
663
+ @dataclasses.dataclass
664
+ class ScipyLbfgsbState:
665
+ """Stores the state of a scipy L-BFGS-B minimization.
666
+
667
+ This class enables optimization with a more functional style, giving the user
668
+ control over the optimization loop. Example usage is as follows:
669
+
670
+ value_fn = lambda x: onp.sum(x**2)
671
+ grad_fn = lambda x: 2 * x
672
+
673
+ x0 = onp.asarray([0.1, 0.2, 0.3])
674
+ lb = [None, -1, 0.1]
675
+ ub = [None, None, None]
676
+ state = ScipyLbfgsbState.init(
677
+ x0=x0, lower_bound=lb, upper_bound=ub, maxcor=20,
+ line_search_max_steps=100, ftol=0.0, gtol=0.0,
678
+ )
679
+
680
+ for _ in range(10):
681
+ value = value_fn(state.x)
682
+ grad = grad_fn(state.x)
683
+ state.update(grad, value)
684
+
685
+ This example converges with `state.x` equal to `(0, 0, 0.1)` and value equal
686
+ to `0.01`.
687
+
688
+ Attributes:
689
+ x: The current solution vector.
690
+ """
691
+
692
+ x: NDArray
693
+ converged: NDArray
694
+ # Private attributes correspond to internal variables in the
696
+ # `scipy.optimize._lbfgsb_py._minimize_lbfgsb` function.
696
+ _maxcor: int
697
+ _line_search_max_steps: int
698
+ _ftol: NDArray
699
+ _gtol: NDArray
700
+ _wa: NDArray
701
+ _iwa: NDArray
702
+ _task: NDArray
703
+ _csave: NDArray
704
+ _lsave: NDArray
705
+ _isave: NDArray
706
+ _dsave: NDArray
707
+ _lower_bound: NDArray
708
+ _upper_bound: NDArray
709
+ _bound_type: NDArray
710
+
711
+ def __post_init__(self) -> None:
712
+ """Validates the datatypes for all state attributes."""
713
+ _validate_array_dtype(self.x, onp.float64)
714
+ _validate_array_dtype(self._wa, onp.float64)
715
+ _validate_array_dtype(self._iwa, FORTRAN_INT)
716
+ _validate_array_dtype(self._task, "S60")
717
+ _validate_array_dtype(self._csave, "S60")
718
+ _validate_array_dtype(self._lsave, FORTRAN_INT)
719
+ _validate_array_dtype(self._isave, FORTRAN_INT)
720
+ _validate_array_dtype(self._dsave, onp.float64)
721
+ _validate_array_dtype(self._lower_bound, onp.float64)
722
+ _validate_array_dtype(self._upper_bound, onp.float64)
723
+ _validate_array_dtype(self._bound_type, int)
724
+
725
+ def to_dict(self) -> NumpyLbfgsbDict:
726
+ """Generates a dictionary of jax arrays defining the state."""
727
+ return dict(
728
+ x=onp.asarray(self.x),
729
+ converged=onp.asarray(self.converged),
730
+ _maxcor=onp.asarray(self._maxcor),
731
+ _line_search_max_steps=onp.asarray(self._line_search_max_steps),
732
+ _ftol=onp.asarray(self._ftol),
733
+ _gtol=onp.asarray(self._gtol),
734
+ _wa=onp.asarray(self._wa),
735
+ _iwa=onp.asarray(self._iwa),
736
+ _task=_array_from_s60_str(self._task),
737
+ _csave=_array_from_s60_str(self._csave),
738
+ _lsave=onp.asarray(self._lsave),
739
+ _isave=onp.asarray(self._isave),
740
+ _dsave=onp.asarray(self._dsave),
741
+ _lower_bound=onp.asarray(self._lower_bound),
742
+ _upper_bound=onp.asarray(self._upper_bound),
743
+ _bound_type=onp.asarray(self._bound_type),
744
+ )
745
+
746
+ @classmethod
747
+ def from_jax(cls, state_dict: JaxLbfgsbDict) -> "ScipyLbfgsbState":
748
+ """Converts a dictionary of jax arrays to a `ScipyLbfgsbState`."""
749
+ return ScipyLbfgsbState(
750
+ x=onp.array(state_dict["x"], dtype=onp.float64),
751
+ converged=onp.asarray(state_dict["converged"], dtype=bool),
752
+ _maxcor=int(state_dict["_maxcor"]),
753
+ _line_search_max_steps=int(state_dict["_line_search_max_steps"]),
754
+ _ftol=onp.asarray(state_dict["_ftol"], dtype=onp.float64),
755
+ _gtol=onp.asarray(state_dict["_gtol"], dtype=onp.float64),
756
+ _wa=onp.array(state_dict["_wa"], onp.float64),
757
+ _iwa=onp.array(state_dict["_iwa"], dtype=FORTRAN_INT),
758
+ _task=_s60_str_from_array(onp.asarray(state_dict["_task"])),
759
+ _csave=_s60_str_from_array(onp.asarray(state_dict["_csave"])),
760
+ _lsave=onp.array(state_dict["_lsave"], dtype=FORTRAN_INT),
761
+ _isave=onp.array(state_dict["_isave"], dtype=FORTRAN_INT),
762
+ _dsave=onp.array(state_dict["_dsave"], dtype=onp.float64),
763
+ _lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
764
+ _upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
765
+ _bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
766
+ )
767
+
768
+ @classmethod
769
+ def init(
770
+ cls,
771
+ x0: NDArray,
772
+ lower_bound: ElementwiseBound,
773
+ upper_bound: ElementwiseBound,
774
+ maxcor: int,
775
+ line_search_max_steps: int,
776
+ ftol: float,
777
+ gtol: float,
778
+ ) -> "ScipyLbfgsbState":
779
+ """Initializes the `ScipyLbfgsbState` for `x0`.
780
+
781
+ Args:
782
+ x0: Array giving the initial solution vector.
783
+ lower_bound: Array giving the elementwise optional lower bound.
784
+ upper_bound: Array giving the elementwise optional upper bound.
785
+ maxcor: The maximum number of variable metric corrections used to define
786
+ the limited memory matrix, in the L-BFGS-B scheme.
787
+ line_search_max_steps: The maximum number of steps in the line search.
788
+ ftol: Tolerance for stopping criteria based on function values. See scipy
789
+ documentation for details.
790
+ gtol: Tolerance for stopping criteria based on gradient.
791
+
792
+ Returns:
793
+ The `ScipyLbfgsbState`.
794
+ """
795
+ x0 = onp.asarray(x0)
796
+ if x0.ndim > 1:
797
+ raise ValueError(f"`x0` must be rank-1 but got shape {x0.shape}.")
798
+ lower_bound = onp.asarray(lower_bound)
799
+ upper_bound = onp.asarray(upper_bound)
800
+ if x0.shape != lower_bound.shape or x0.shape != upper_bound.shape:
801
+ raise ValueError(
802
+ f"`x0`, `lower_bound`, and `upper_bound` must have matching "
803
+ f"shape but got shapes {x0.shape}, {lower_bound.shape}, and "
804
+ f"{upper_bound.shape}, respectively."
805
+ )
806
+ if maxcor < 1:
807
+ raise ValueError(f"`maxcor` must be positive but got {maxcor}.")
808
+
809
+ n = x0.size
810
+ lower_bound_array, upper_bound_array, bound_type = _configure_bounds(
811
+ lower_bound, upper_bound
812
+ )
813
+ task = onp.zeros(1, "S60")
814
+ task[:] = TASK_START
815
+
816
+ # See initialization of internal variables in the `_lbfgsb_py._minimize_lbfgsb`
817
+ # function.
818
+ wa_size = _wa_size(n=n, maxcor=maxcor)
819
+ state = ScipyLbfgsbState(
820
+ x=onp.array(x0, onp.float64),
821
+ converged=onp.asarray(False),
822
+ _maxcor=maxcor,
823
+ _line_search_max_steps=line_search_max_steps,
824
+ _ftol=onp.asarray(ftol, onp.float64),
825
+ _gtol=onp.asarray(gtol, onp.float64),
826
+ _wa=onp.zeros(wa_size, onp.float64),
827
+ _iwa=onp.zeros(3 * n, FORTRAN_INT),
828
+ _task=task,
829
+ _csave=onp.zeros(1, "S60"),
830
+ _lsave=onp.zeros(4, FORTRAN_INT),
831
+ _isave=onp.zeros(44, FORTRAN_INT),
832
+ _dsave=onp.zeros(29, onp.float64),
833
+ _lower_bound=lower_bound_array,
834
+ _upper_bound=upper_bound_array,
835
+ _bound_type=bound_type,
836
+ )
837
+ # The initial state requires an update with zero value and gradient. This
838
+ # is because the initial task is "START", which does not actually require
839
+ # value and gradient evaluation.
840
+ state.update(onp.zeros(x0.shape, onp.float64), onp.zeros((), onp.float64))
841
+ return state
842
+
843
+ def update(
844
+ self,
845
+ grad: NDArray,
846
+ value: NDArray,
847
+ ) -> None:
848
+ """Performs an in-place update of the `ScipyLbfgsbState` if not converged.
849
+
850
+ Args:
851
+ grad: The function gradient for the current `x`.
852
+ value: The scalar function value for the current `x`.
853
+ """
854
+ if self.converged:
855
+ return
856
+ if grad.shape != self.x.shape:
857
+ raise ValueError(
858
+ f"`grad` must have the same shape as attribute `x`, but got shapes "
859
+ f"{grad.shape} and {self.x.shape}, respectively."
860
+ )
861
+ if value.shape != ():
862
+ raise ValueError(f"`value` must be a scalar but got shape {value.shape}.")
863
+
864
+ # The `setulb` function will sometimes return with a task that does not
865
+ # require a value and gradient evaluation. In this case we simply call it
866
+ # again, advancing past such "dummy" steps.
867
+ for _ in range(3):
868
+ scipy_lbfgsb.setulb(
869
+ m=self._maxcor,
870
+ x=self.x,
871
+ l=self._lower_bound,
872
+ u=self._upper_bound,
873
+ nbd=self._bound_type,
874
+ f=value,
875
+ g=grad,
876
+ factr=self._ftol / onp.finfo(float).eps,
877
+ pgtol=self._gtol,
878
+ wa=self._wa,
879
+ iwa=self._iwa,
880
+ task=self._task,
881
+ iprint=UPDATE_IPRINT,
882
+ csave=self._csave,
883
+ lsave=self._lsave,
884
+ isave=self._isave,
885
+ dsave=self._dsave,
886
+ maxls=self._line_search_max_steps,
887
+ )
888
+ task_str = self._task.tobytes()
889
+ if task_str.startswith(TASK_CONVERGED):
890
+ self.converged = onp.asarray(True)
891
+ if task_str.startswith(TASK_FG):
892
+ break
893
+
894
+
895
+ def _wa_size(n: int, maxcor: int) -> int:
896
+ """Return the size of the `wa` attribute of lbfgsb state."""
897
+ return 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
898
+
899
+
900
+ def _validate_array_dtype(x: NDArray, dtype: Union[type, str]) -> None:
901
+ """Validates that `x` is an array with the specified `dtype`."""
902
+ if not isinstance(x, onp.ndarray):
903
+ raise ValueError(f"`x` must be an `onp.ndarray` but got {type(x)}")
904
+ if x.dtype != dtype:
905
+ raise ValueError(f"`x` must have dtype {dtype} but got {x.dtype}")
906
+
907
+
908
+ def _configure_bounds(
909
+ lower_bound: ElementwiseBound,
910
+ upper_bound: ElementwiseBound,
911
+ ) -> Tuple[NDArray, NDArray, NDArray]:
912
+ """Configures the bounds for an L-BFGS-B optimization."""
913
+ bound_type = [
914
+ BOUNDS_MAP[(lower is None, upper is None)]
915
+ for lower, upper in zip(lower_bound, upper_bound)
916
+ ]
917
+ lower_bound_array = [0.0 if x is None else x for x in lower_bound]
918
+ upper_bound_array = [0.0 if x is None else x for x in upper_bound]
919
+ return (
920
+ onp.asarray(lower_bound_array, onp.float64),
921
+ onp.asarray(upper_bound_array, onp.float64),
922
+ onp.asarray(bound_type),
923
+ )
924
+
925
+
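# Hedged example: `_configure_bounds([None, -1.0], [None, None])` yields bound
# types `[0, 1]` (unbounded; lower bound only), with placeholder `0.0` entries
# in the bound arrays wherever a bound is absent.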
926
+ def _array_from_s60_str(s60_str: NDArray) -> NDArray:
927
+ """Return a jax array for a numpy s60 string."""
928
+ assert s60_str.shape == (1,)
929
+ chars = [int(o) for o in s60_str[0]]
930
+ chars.extend([32] * (59 - len(chars)))
931
+ return onp.asarray(chars, dtype=int)
932
+
933
+
934
+ def _s60_str_from_array(array: NDArray) -> NDArray:
935
+ """Return a numpy s60 string for a jax array."""
936
+ return onp.asarray(
937
+ [b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
938
+ dtype="S60",
939
+ )
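# Hedged round-trip sketch for the two helpers above: character-code arrays
# pad with spaces (code 32), so task comparisons use `startswith`.
#     task = onp.zeros(1, "S60"); task[:] = TASK_START
#     codes = _array_from_s60_str(task)      # length-59 integer array
#     restored = _s60_str_from_array(codes)  # b"START" plus trailing spaces
#     assert restored.tobytes().startswith(TASK_START)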