invrs-opt 0.3.2__py3-none-any.whl → 0.10.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invrs_opt/__init__.py +14 -3
- invrs_opt/experimental/client.py +155 -0
- invrs_opt/experimental/labels.py +23 -0
- invrs_opt/optimizers/__init__.py +0 -0
- invrs_opt/{base.py → optimizers/base.py} +16 -1
- invrs_opt/optimizers/lbfgsb.py +939 -0
- invrs_opt/optimizers/wrapped_optax.py +347 -0
- invrs_opt/parameterization/__init__.py +0 -0
- invrs_opt/parameterization/base.py +208 -0
- invrs_opt/parameterization/filter_project.py +138 -0
- invrs_opt/parameterization/gaussian_levelset.py +671 -0
- invrs_opt/parameterization/pixel.py +75 -0
- invrs_opt/{lbfgsb/transform.py → parameterization/transforms.py} +76 -11
- invrs_opt-0.10.3.dist-info/LICENSE +504 -0
- invrs_opt-0.10.3.dist-info/METADATA +560 -0
- invrs_opt-0.10.3.dist-info/RECORD +20 -0
- {invrs_opt-0.3.2.dist-info → invrs_opt-0.10.3.dist-info}/WHEEL +1 -1
- invrs_opt/lbfgsb/lbfgsb.py +0 -670
- invrs_opt-0.3.2.dist-info/LICENSE +0 -21
- invrs_opt-0.3.2.dist-info/METADATA +0 -73
- invrs_opt-0.3.2.dist-info/RECORD +0 -11
- /invrs_opt/{lbfgsb → experimental}/__init__.py +0 -0
- {invrs_opt-0.3.2.dist-info → invrs_opt-0.10.3.dist-info}/top_level.txt +0 -0
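
The largest addition is the rewritten L-BFGS-B wrapper in `invrs_opt/optimizers/lbfgsb.py`, reproduced in full below. It exposes a jax-style `base.Optimizer` with `init`, `params`, and `update` functions. As a rough sketch of the intended driving loop (assuming, as the `invrs_opt/__init__.py` change suggests but does not show here, that `lbfgsb` is re-exported at the package top level):

```python
# Sketch only; `invrs_opt.lbfgsb` is assumed to be a top-level re-export of
# `invrs_opt.optimizers.lbfgsb.lbfgsb`.
import jax
import jax.numpy as jnp

import invrs_opt


def loss_fn(x):
    return jnp.sum((x - 1.0) ** 2)


opt = invrs_opt.lbfgsb(maxcor=20)
state = opt.init(jnp.zeros((3,)))
for _ in range(20):
    x = opt.params(state)  # Current parameters.
    value, grad = jax.value_and_grad(loss_fn)(x)  # Loss and gradient.
    state = opt.update(grad=grad, value=value, params=x, state=state)
```

As the module below shows, `update` routes the value and gradient through `jax.pure_callback` into scipy's Fortran L-BFGS-B routine, so the loop can be traced by jax while the line search and bound handling stay in scipy.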
invrs_opt/optimizers/lbfgsb.py (added)
@@ -0,0 +1,939 @@
+"""Defines a jax-style wrapper for scipy's L-BFGS-B algorithm.
+
+Copyright (c) 2023 The INVRS-IO authors.
+"""
+
+import dataclasses
+import functools
+from packaging import version
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
+
+import jax
+import jax.numpy as jnp
+import numpy as onp
+import optax  # type: ignore[import-untyped]
+from jax import flatten_util, tree_util
+from scipy.optimize._lbfgsb_py import (  # type: ignore[import-untyped]
+    _lbfgsb as scipy_lbfgsb,
+)
+from totypes import types
+
+from invrs_opt.optimizers import base
+from invrs_opt.parameterization import (
+    base as param_base,
+    filter_project,
+    gaussian_levelset,
+    pixel,
+)
+
+NDArray = onp.ndarray[Any, Any]
+PyTree = Any
+ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
+NumpyLbfgsbDict = Dict[str, NDArray]
+JaxLbfgsbDict = Dict[str, jnp.ndarray]
+LbfgsbState = Tuple[int, PyTree, PyTree, JaxLbfgsbDict]
+
+
+# Task message prefixes for the underlying L-BFGS-B implementation.
+TASK_START = b"START"
+TASK_FG = b"FG"
+TASK_CONVERGED = b"CONVERGENCE"
+
+UPDATE_IPRINT = -1
+
+# Maximum value for the `maxcor` parameter in the L-BFGS-B scheme.
+MAXCOR_MAX_VALUE = 100
+DEFAULT_MAXCOR = 20
+DEFAULT_LINE_SEARCH_MAX_STEPS = 100
+DEFAULT_FTOL = 0.0
+DEFAULT_GTOL = 0.0
+
+# Maps bound scenarios to integers.
+BOUNDS_MAP: Dict[Tuple[bool, bool], int] = {
+    (True, True): 0,  # Both upper and lower bound are `None`.
+    (False, True): 1,  # Only upper bound is `None`.
+    (False, False): 2,  # Neither of the bounds are `None`.
+    (True, False): 3,  # Only the lower bound is `None`.
+}
+
+FORTRAN_INT = scipy_lbfgsb.types.intvar.dtype
+
+if version.Version(jax.__version__) > version.Version("0.4.31"):
+    callback_sequential = functools.partial(jax.pure_callback, vmap_method="sequential")
+else:
+    callback_sequential = functools.partial(jax.pure_callback, vectorized=False)
+
+
+def lbfgsb(
+    *,
+    maxcor: int = DEFAULT_MAXCOR,
+    line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
+    ftol: float = DEFAULT_FTOL,
+    gtol: float = DEFAULT_GTOL,
+) -> base.Optimizer:
+    """Optimizer implementing the standard L-BFGS-B algorithm.
+
+    The standard L-BFGS-B algorithm uses the direct pixel parameterization for density
+    arrays, which simply enforces that values are between the declared upper and lower
+    bounds of the density.
+
+    When an optimization is determined to have converged (by `ftol` or `gtol` criteria)
+    the optimizer `params` function will simply return the optimal parameters. The
+    convergence can be queried by `is_converged(state)`.
+
+    Args:
+        maxcor: The maximum number of variable metric corrections used to define the
+            limited memory matrix, in the L-BFGS-B scheme.
+        line_search_max_steps: The maximum number of steps in the line search.
+        ftol: Convergence criteria based on function values. See scipy documentation
+            for details.
+        gtol: Convergence criteria based on gradient.
+
+    Returns:
+        The `Optimizer` implementing the L-BFGS-B optimizer.
+    """
+    return parameterized_lbfgsb(
+        density_parameterization=None,
+        penalty=0.0,
+        maxcor=maxcor,
+        line_search_max_steps=line_search_max_steps,
+        ftol=ftol,
+        gtol=gtol,
+    )
+
+
+def density_lbfgsb(
+    *,
+    beta: float,
+    maxcor: int = DEFAULT_MAXCOR,
+    line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
+    ftol: float = DEFAULT_FTOL,
+    gtol: float = DEFAULT_GTOL,
+) -> base.Optimizer:
+    """Optimizer using L-BFGS-B algorithm with filter-project density parameterization.
+
+    In the filter-project density parameterization, the optimization variable
+    associated with a density array is a latent density array; the density is obtained
+    by convolving (i.e. "filtering") the latent density with a Gaussian kernel having
+    full-width at half-maximum equal to the length scale (the mean of declared minimum
+    width and minimum spacing). Then, a tanh nonlinearity is used as a smooth threshold
+    operation ("projection").
+
+    When an optimization is determined to have converged (by `ftol` or `gtol` criteria)
+    the optimizer `params` function will simply return the optimal parameters. The
+    convergence can be queried by `is_converged(state)`.
+
+    Args:
+        beta: Determines the sharpness of the thresholding operation.
+        maxcor: The maximum number of variable metric corrections used to define the
+            limited memory matrix, in the L-BFGS-B scheme.
+        line_search_max_steps: The maximum number of steps in the line search.
+        ftol: Convergence criteria based on function values. See scipy documentation
+            for details.
+        gtol: Convergence criteria based on gradient.
+
+    Returns:
+        The `Optimizer` implementing the L-BFGS-B optimizer.
+    """
+    return parameterized_lbfgsb(
+        density_parameterization=filter_project.filter_project(beta=beta),
+        penalty=0.0,
+        maxcor=maxcor,
+        line_search_max_steps=line_search_max_steps,
+        ftol=ftol,
+        gtol=gtol,
+    )
+
+
+def levelset_lbfgsb(
+    *,
+    penalty: float,
+    length_scale_spacing_factor: float = (
+        gaussian_levelset.DEFAULT_LENGTH_SCALE_SPACING_FACTOR
+    ),
+    length_scale_fwhm_factor: float = (
+        gaussian_levelset.DEFAULT_LENGTH_SCALE_FWHM_FACTOR
+    ),
+    length_scale_constraint_factor: float = (
+        gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR
+    ),
+    smoothing_factor: int = gaussian_levelset.DEFAULT_SMOOTHING_FACTOR,
+    length_scale_constraint_beta: float = (
+        gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA
+    ),
+    length_scale_constraint_weight: float = (
+        gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT
+    ),
+    curvature_constraint_weight: float = (
+        gaussian_levelset.DEFAULT_CURVATURE_CONSTRAINT_WEIGHT
+    ),
+    fixed_pixel_constraint_weight: float = (
+        gaussian_levelset.DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT
+    ),
+    init_optimizer: optax.GradientTransformation = (
+        gaussian_levelset.DEFAULT_INIT_OPTIMIZER
+    ),
+    init_steps: int = gaussian_levelset.DEFAULT_INIT_STEPS,
+    maxcor: int = DEFAULT_MAXCOR,
+    line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
+    ftol: float = DEFAULT_FTOL,
+    gtol: float = DEFAULT_GTOL,
+) -> base.Optimizer:
+    """Optimizer using L-BFGS-B algorithm with levelset density parameterization.
+
+    In the levelset parameterization, the optimization variable associated with a
+    density array is an array giving the amplitudes of Gaussian radial basis functions
+    that represent a levelset function over the domain of the density. In the levelset
+    parameterization, gradients are nonzero only at the edges of features, and in
+    general the topology of a solution does not change during the course of
+    optimization.
+
+    The spacing and full-width at half-maximum of the Gaussian basis functions give
+    some amount of control over length scales. In addition, constraints associated with
+    length scale, radius of curvature, and deviation from fixed pixels are
+    automatically computed and penalized with a weight given by `penalty`. In general,
+    this helps ensure that features in an optimized density array violate the specified
+    constraints to a lesser degree. The constraints are based on "Analytical level set
+    fabrication constraints for inverse design," by D. Vercruysse et al. (2019).
+
+    When an optimization is determined to have converged (by `ftol` or `gtol` criteria)
+    the optimizer `params` function will simply return the optimal parameters. The
+    convergence can be queried by `is_converged(state)`.
+
+    Args:
+        penalty: The weight of the fabrication penalty, which combines length scale,
+            curvature, and fixed pixel constraints.
+        length_scale_spacing_factor: The number of levelset control points per unit of
+            minimum length scale (mean of density minimum width and minimum spacing).
+        length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum to
+            the minimum length scale.
+        length_scale_constraint_factor: Multiplies the target length scale in the
+            levelset constraints. A value greater than 1 is pessimistic and drives the
+            solution to have a larger length scale (relative to smaller values).
+        smoothing_factor: For values greater than 1, the density is initially computed
+            at higher resolution and then downsampled, yielding smoother geometries.
+        length_scale_constraint_beta: Controls relaxation of the length scale
+            constraint near the zero level.
+        length_scale_constraint_weight: The weight of the length scale constraint in
+            the overall fabrication constraint penalty.
+        curvature_constraint_weight: The weight of the curvature constraint.
+        fixed_pixel_constraint_weight: The weight of the fixed pixel constraint.
+        init_optimizer: The optimizer used in the initialization of the levelset
+            parameterization. At initialization, the latent parameters are optimized so
+            that the initial parameters match the binarized initial density.
+        init_steps: The number of optimization steps used in the initialization.
+        maxcor: The maximum number of variable metric corrections used to define the
+            limited memory matrix, in the L-BFGS-B scheme.
+        line_search_max_steps: The maximum number of steps in the line search.
+        ftol: Convergence criteria based on function values. See scipy documentation
+            for details.
+        gtol: Convergence criteria based on gradient.
+
+    Returns:
+        The `Optimizer` implementing the L-BFGS-B optimizer.
+    """
+    return parameterized_lbfgsb(
+        density_parameterization=gaussian_levelset.gaussian_levelset(
+            length_scale_spacing_factor=length_scale_spacing_factor,
+            length_scale_fwhm_factor=length_scale_fwhm_factor,
+            length_scale_constraint_factor=length_scale_constraint_factor,
+            smoothing_factor=smoothing_factor,
+            length_scale_constraint_beta=length_scale_constraint_beta,
+            length_scale_constraint_weight=length_scale_constraint_weight,
+            curvature_constraint_weight=curvature_constraint_weight,
+            fixed_pixel_constraint_weight=fixed_pixel_constraint_weight,
+            init_optimizer=init_optimizer,
+            init_steps=init_steps,
+        ),
+        penalty=penalty,
+        maxcor=maxcor,
+        line_search_max_steps=line_search_max_steps,
+        ftol=ftol,
+        gtol=gtol,
+    )
+
+
+# -----------------------------------------------------------------------------
+# Base parameterized L-BFGS-B optimizer.
+# -----------------------------------------------------------------------------
+
+
+def parameterized_lbfgsb(
+    density_parameterization: Optional[param_base.Density2DParameterization],
+    penalty: float,
+    maxcor: int = DEFAULT_MAXCOR,
+    line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
+    ftol: float = DEFAULT_FTOL,
+    gtol: float = DEFAULT_GTOL,
+) -> base.Optimizer:
+    """Optimizer using L-BFGS-B optimizer with specified density parameterization.
+
+    This optimizer wraps scipy's implementation of the algorithm, and provides
+    a jax-style API to the scheme. The optimizer works with custom types such
+    as the `BoundedArray` to constrain the optimization variable.
+
+    Args:
+        density_parameterization: The parameterization to be used, or `None`. When no
+            parameterization is given, the direct pixel parameterization is used for
+            density arrays.
+        penalty: The weight of the scalar penalty formed from the constraints of the
+            parameterization.
+        maxcor: The maximum number of variable metric corrections used to define the
+            limited memory matrix, in the L-BFGS-B scheme.
+        line_search_max_steps: The maximum number of steps in the line search.
+        ftol: Convergence criteria based on function values. See scipy documentation
+            for details.
+        gtol: Convergence criteria based on gradient.
+
+    Returns:
+        The `base.Optimizer`.
+    """
+    if not isinstance(maxcor, int) or maxcor < 1 or maxcor > MAXCOR_MAX_VALUE:
+        raise ValueError(
+            f"`maxcor` must be greater than 0 and less than "
+            f"{MAXCOR_MAX_VALUE}, but got {maxcor}"
+        )
+
+    if not isinstance(line_search_max_steps, int) or line_search_max_steps < 1:
+        raise ValueError(
+            f"`line_search_max_steps` must be greater than 0 but got "
+            f"{line_search_max_steps}"
+        )
+
+    if density_parameterization is None:
+        density_parameterization = pixel.pixel()
+
+    def init_fn(params: PyTree) -> LbfgsbState:
+        """Initializes the optimization state."""
+
+        def _init_state_pure(latent_params: PyTree) -> Tuple[PyTree, NumpyLbfgsbDict]:
+            lower_bound = types.extract_lower_bound(latent_params)
+            upper_bound = types.extract_upper_bound(latent_params)
+            scipy_lbfgsb_state = ScipyLbfgsbState.init(
+                x0=_to_numpy(latent_params),
+                lower_bound=_bound_for_params(lower_bound, latent_params),
+                upper_bound=_bound_for_params(upper_bound, latent_params),
+                maxcor=maxcor,
+                line_search_max_steps=line_search_max_steps,
+                ftol=ftol,
+                gtol=gtol,
+            )
+            latent_params = _to_pytree(scipy_lbfgsb_state.x, latent_params)
+            return latent_params, scipy_lbfgsb_state.to_dict()
+
+        latent_params = _init_latents(params)
+        metadata, latents = param_base.partition_density_metadata(latent_params)
+        latents, jax_lbfgsb_state = callback_sequential(
+            _init_state_pure,
+            _example_state(latents, maxcor),
+            latents,
+        )
+        latent_params = param_base.combine_density_metadata(metadata, latents)
+        return (
+            0,  # step
+            _params_from_latent_params(latent_params),  # params
+            latent_params,  # latent params
+            jax_lbfgsb_state,  # opt state
+        )
+
+    def params_fn(state: LbfgsbState) -> PyTree:
+        """Returns the parameters for the given `state`."""
+        _, params, _, _ = state
+        return params
+
+    def update_fn(
+        *,
+        grad: PyTree,
+        value: jnp.ndarray,
+        params: PyTree,
+        state: LbfgsbState,
+    ) -> LbfgsbState:
+        """Updates the state."""
+        del params
+
+        def _update_pure(
+            flat_latent_grad: PyTree,
+            value: jnp.ndarray,
+            jax_lbfgsb_state: JaxLbfgsbDict,
+        ) -> Tuple[NDArray, NumpyLbfgsbDict]:
+            assert onp.size(value) == 1
+            scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
+            flat_latent_params = scipy_lbfgsb_state.x.copy()
+            scipy_lbfgsb_state.update(
+                grad=onp.array(flat_latent_grad, dtype=onp.float64),
+                value=onp.array(value, dtype=onp.float64),
+            )
+            updated_flat_latent_params = scipy_lbfgsb_state.x
+            flat_latent_updates: NDArray
+            flat_latent_updates = updated_flat_latent_params - flat_latent_params
+            return flat_latent_updates, scipy_lbfgsb_state.to_dict()
+
+        step, _, latent_params, jax_lbfgsb_state = state
+        metadata, latents = param_base.partition_density_metadata(latent_params)
+
+        def _params_from_latents(latents: PyTree) -> PyTree:
+            latent_params = param_base.combine_density_metadata(metadata, latents)
+            return _params_from_latent_params(latent_params)
+
+        def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
+            latent_params = param_base.combine_density_metadata(metadata, latents)
+            return _constraint_loss(latent_params)
+
+        _, vjp_fn = jax.vjp(_params_from_latents, latents)
+        (latents_grad,) = vjp_fn(grad)
+
+        if not (
+            tree_util.tree_structure(latents_grad)
+            == tree_util.tree_structure(latents)  # type: ignore[operator]
+        ):
+            raise ValueError(
+                f"Tree structure of `latents_grad` was different than expected, got \n"
+                f"{tree_util.tree_structure(latents_grad)} but expected \n"
+                f"{tree_util.tree_structure(latents)}."
+            )
+
+        (
+            constraint_loss_value,
+            constraint_loss_grad,
+        ) = jax.value_and_grad(
+            _constraint_loss_latents
+        )(latents)
+        value += constraint_loss_value
+        latents_grad = tree_util.tree_map(
+            lambda a, b: a + b, latents_grad, constraint_loss_grad
+        )
+
+        flat_latents_grad, unflatten_fn = flatten_util.ravel_pytree(
+            latents_grad
+        )  # type: ignore[no-untyped-call]
+
+        flat_latent_updates, jax_lbfgsb_state = callback_sequential(
+            _update_pure,
+            (flat_latents_grad, jax_lbfgsb_state),
+            flat_latents_grad,
+            value,
+            jax_lbfgsb_state,
+        )
+        latent_updates = unflatten_fn(flat_latent_updates)
+        latent_params = _apply_updates(
+            params=latent_params,
+            updates=param_base.combine_density_metadata(metadata, latent_updates),
+            value=value,
+            step=step,
+        )
+        latent_params = _clip(latent_params)
+        params = _params_from_latent_params(latent_params)
+        return step + 1, params, latent_params, jax_lbfgsb_state
+
+    # -------------------------------------------------------------------------
+    # Functions related to the density parameterization.
+    # -------------------------------------------------------------------------
+
+    def _init_latents(params: PyTree) -> PyTree:
+        def _leaf_init_latents(leaf: Any) -> Any:
+            leaf = _clip(leaf)
+            if not _is_density(leaf) or density_parameterization is None:
+                return leaf
+            return density_parameterization.from_density(leaf)
+
+        return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
+
+    def _params_from_latent_params(latent_params: PyTree) -> PyTree:
+        def _leaf_params_from_latents(leaf: Any) -> Any:
+            if not _is_parameterized_density(leaf) or density_parameterization is None:
+                return leaf
+            return density_parameterization.to_density(leaf)
+
+        return tree_util.tree_map(
+            _leaf_params_from_latents,
+            latent_params,
+            is_leaf=_is_parameterized_density,
+        )
+
+    def _apply_updates(
+        params: PyTree,
+        updates: PyTree,
+        value: jnp.ndarray,
+        step: int,
+    ) -> PyTree:
+        def _leaf_apply_updates(update: Any, leaf: Any) -> Any:
+            if _is_parameterized_density(leaf):
+                return density_parameterization.update(
+                    params=leaf, updates=update, value=value, step=step
+                )
+            else:
+                return optax.apply_updates(params=leaf, updates=update)
+
+        return tree_util.tree_map(
+            _leaf_apply_updates,
+            updates,
+            params,
+            is_leaf=_is_parameterized_density,
+        )
+
+    # -------------------------------------------------------------------------
+    # Functions related to the constraints to be minimized.
+    # -------------------------------------------------------------------------
+
+    def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
+        def _constraint_loss_leaf(
+            leaf: param_base.ParameterizedDensity2DArray,
+        ) -> jnp.ndarray:
+            constraints = density_parameterization.constraints(leaf)
+            constraints = tree_util.tree_map(
+                lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
+                constraints,
+            )
+            return jnp.sum(jnp.asarray(constraints))
+
+        losses = [0.0] + [
+            _constraint_loss_leaf(p)
+            for p in tree_util.tree_leaves(
+                latent_params, is_leaf=_is_parameterized_density
+            )
+            if _is_parameterized_density(p)
+        ]
+        return penalty * jnp.sum(jnp.asarray(losses))
+
+    return base.Optimizer(
+        init=init_fn,
+        params=params_fn,
+        update=update_fn,
+    )
+
+
+def is_converged(state: LbfgsbState) -> jnp.ndarray:
+    """Returns `True` if the optimization has converged."""
+    return state[3]["converged"]
+
+
+# ------------------------------------------------------------------------------
+# Helper functions.
+# ------------------------------------------------------------------------------
+
+
+def _is_density(leaf: Any) -> Any:
+    """Return `True` if `leaf` is a density array."""
+    return isinstance(leaf, types.Density2DArray)
+
+
+def _is_parameterized_density(leaf: Any) -> Any:
+    """Return `True` if `leaf` is a parameterized density array."""
+    return isinstance(leaf, param_base.ParameterizedDensity2DArray)
+
+
+def _is_custom_type(leaf: Any) -> bool:
+    """Return `True` if `leaf` is a recognized custom type."""
+    return isinstance(leaf, (types.BoundedArray, types.Density2DArray))
+
+
+def _clip(pytree: PyTree) -> PyTree:
+    """Clips leaves on `pytree` to their bounds."""
+
+    def _clip_fn(leaf: Any) -> Any:
+        if not _is_custom_type(leaf):
+            return leaf
+        if leaf.lower_bound is None and leaf.upper_bound is None:
+            return leaf
+        return tree_util.tree_map(
+            lambda x: jnp.clip(x, leaf.lower_bound, leaf.upper_bound), leaf
+        )
+
+    return tree_util.tree_map(_clip_fn, pytree, is_leaf=_is_custom_type)
+
+
+def _to_numpy(params: PyTree) -> NDArray:
+    """Flattens a `params` pytree into a single rank-1 numpy array."""
+    x, _ = flatten_util.ravel_pytree(params)  # type: ignore[no-untyped-call]
+    return onp.asarray(x, dtype=onp.float64)
+
+
+def _to_pytree(x_flat: NDArray, params: PyTree) -> PyTree:
+    """Restores a pytree from a flat numpy array using the structure of `params`.
+
+    Note that the returned pytree includes jax array leaves.
+
+    Args:
+        x_flat: The rank-1 numpy array to be restored.
+        params: A pytree of parameters whose structure is replicated in the restored
+            pytree.
+
+    Returns:
+        The restored pytree, with jax array leaves.
+    """
+    _, unflatten_fn = flatten_util.ravel_pytree(params)  # type: ignore[no-untyped-call]
+    return unflatten_fn(jnp.asarray(x_flat, dtype=float))
+
+
+def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
+    """Generates a bound vector for the `params`.
+
+    The `bound` can be specified in various ways; it may be `None` or a scalar,
+    which then applies to all arrays in `params`. It may be a pytree with
+    structure matching that of `params`, where each leaf is either `None`, a
+    scalar, or an array matching the shape of the corresponding leaf in `params`.
+
+    The returned bound is a flat array suitable for use with `ScipyLbfgsbState`.
+
+    Args:
+        bound: The pytree of bounds.
+        params: The pytree of parameters.
+
+    Returns:
+        The flat elementwise bound.
+    """
+
+    if bound is None or onp.isscalar(bound):
+        bound = tree_util.tree_map(
+            lambda _: bound,
+            params,
+            is_leaf=lambda x: isinstance(x, types.CUSTOM_TYPES),
+        )
+
+    bound_leaves, bound_treedef = tree_util.tree_flatten(
+        bound, is_leaf=lambda x: x is None
+    )
+    params_leaves = tree_util.tree_leaves(params, is_leaf=lambda x: x is None)
+
+    # `bound` should be a pytree of arrays or `None`, while `params` may
+    # include custom pytree nodes. Convert the custom nodes into standard
+    # types to facilitate validation that the tree structures match.
+    params_treedef = tree_util.tree_structure(
+        tree_util.tree_map(
+            lambda x: 0.0,
+            tree=params,
+            is_leaf=lambda x: x is None or isinstance(x, types.CUSTOM_TYPES),
+        )
+    )
+    if bound_treedef != params_treedef:  # type: ignore[operator]
+        raise ValueError(
+            f"Tree structure of `bound` and `params` must match, but got "
+            f"{bound_treedef} and {params_treedef}, respectively."
+        )
+
+    bound_flat = []
+    for b, p in zip(bound_leaves, params_leaves):
+        if p is None:
+            continue
+        if b is None or onp.isscalar(b) or onp.shape(b) == ():
+            bound_flat += [b] * onp.size(p)
+        else:
+            if b.shape != p.shape:
+                raise ValueError(
+                    f"`bound` must be `None`, a scalar, or have shape matching "
+                    f"`params`, but got shape {b.shape} when params has shape "
+                    f"{p.shape}."
+                )
+            bound_flat += b.flatten().tolist()
+
+    return bound_flat
+
+
+def _example_state(params: PyTree, maxcor: int) -> PyTree:
+    """Return an example state for the given `params` and `maxcor`."""
+    params_flat, _ = flatten_util.ravel_pytree(params)  # type: ignore[no-untyped-call]
+    n = params_flat.size
+    float_params = tree_util.tree_map(lambda x: jnp.asarray(x, dtype=float), params)
+    example_jax_lbfgsb_state = dict(
+        x=jnp.zeros(n, dtype=float),
+        converged=jnp.asarray(False),
+        _maxcor=jnp.zeros((), dtype=int),
+        _line_search_max_steps=jnp.zeros((), dtype=int),
+        _ftol=jnp.zeros((), dtype=float),
+        _gtol=jnp.zeros((), dtype=float),
+        _wa=jnp.ones(_wa_size(n=n, maxcor=maxcor), dtype=float),
+        _iwa=jnp.ones(n * 3, dtype=jnp.int32),  # Fortran int
+        _task=jnp.zeros(59, dtype=int),
+        _csave=jnp.zeros(59, dtype=int),
+        _lsave=jnp.zeros(4, dtype=jnp.int32),  # Fortran int
+        _isave=jnp.zeros(44, dtype=jnp.int32),  # Fortran int
+        _dsave=jnp.zeros(29, dtype=float),
+        _lower_bound=jnp.zeros(n, dtype=float),
+        _upper_bound=jnp.zeros(n, dtype=float),
+        _bound_type=jnp.zeros(n, dtype=int),
+    )
+    return float_params, example_jax_lbfgsb_state
+
+
+# ------------------------------------------------------------------------------
+# Wrapper for scipy's L-BFGS-B implementation.
+# ------------------------------------------------------------------------------
+
+
+@dataclasses.dataclass
+class ScipyLbfgsbState:
+    """Stores the state of a scipy L-BFGS-B minimization.
+
+    This class enables optimization with a more functional style, giving the user
+    control over the optimization loop. Example usage is as follows:
+
+        value_fn = lambda x: onp.sum(x**2)
+        grad_fn = lambda x: 2 * x
+
+        x0 = onp.asarray([0.1, 0.2, 0.3])
+        lb = [None, -1, 0.1]
+        ub = [None, None, None]
+        state = ScipyLbfgsbState.init(
+            x0=x0,
+            lower_bound=lb,
+            upper_bound=ub,
+            maxcor=20,
+            line_search_max_steps=100,
+            ftol=0.0,
+            gtol=0.0,
+        )
+
+        for _ in range(10):
+            value = value_fn(state.x)
+            grad = grad_fn(state.x)
+            state.update(grad, value)
+
+    This example converges with `state.x` equal to `(0, 0, 0.1)` and value equal
+    to `0.01`.
+
+    Attributes:
+        x: The current solution vector.
+    """
+
+    x: NDArray
+    converged: NDArray
+    # Private attributes correspond to internal variables in the `scipy.optimize.
+    # lbfgsb._minimize_lbfgsb` function.
+    _maxcor: int
+    _line_search_max_steps: int
+    _ftol: NDArray
+    _gtol: NDArray
+    _wa: NDArray
+    _iwa: NDArray
+    _task: NDArray
+    _csave: NDArray
+    _lsave: NDArray
+    _isave: NDArray
+    _dsave: NDArray
+    _lower_bound: NDArray
+    _upper_bound: NDArray
+    _bound_type: NDArray
+
+    def __post_init__(self) -> None:
+        """Validates the datatypes for all state attributes."""
+        _validate_array_dtype(self.x, onp.float64)
+        _validate_array_dtype(self._wa, onp.float64)
+        _validate_array_dtype(self._iwa, FORTRAN_INT)
+        _validate_array_dtype(self._task, "S60")
+        _validate_array_dtype(self._csave, "S60")
+        _validate_array_dtype(self._lsave, FORTRAN_INT)
+        _validate_array_dtype(self._isave, FORTRAN_INT)
+        _validate_array_dtype(self._dsave, onp.float64)
+        _validate_array_dtype(self._lower_bound, onp.float64)
+        _validate_array_dtype(self._upper_bound, onp.float64)
+        _validate_array_dtype(self._bound_type, int)
+
+    def to_dict(self) -> NumpyLbfgsbDict:
+        """Generates a dictionary of jax arrays defining the state."""
+        return dict(
+            x=onp.asarray(self.x),
+            converged=onp.asarray(self.converged),
+            _maxcor=onp.asarray(self._maxcor),
+            _line_search_max_steps=onp.asarray(self._line_search_max_steps),
+            _ftol=onp.asarray(self._ftol),
+            _gtol=onp.asarray(self._gtol),
+            _wa=onp.asarray(self._wa),
+            _iwa=onp.asarray(self._iwa),
+            _task=_array_from_s60_str(self._task),
+            _csave=_array_from_s60_str(self._csave),
+            _lsave=onp.asarray(self._lsave),
+            _isave=onp.asarray(self._isave),
+            _dsave=onp.asarray(self._dsave),
+            _lower_bound=onp.asarray(self._lower_bound),
+            _upper_bound=onp.asarray(self._upper_bound),
+            _bound_type=onp.asarray(self._bound_type),
+        )
+
+    @classmethod
+    def from_jax(cls, state_dict: JaxLbfgsbDict) -> "ScipyLbfgsbState":
+        """Converts a dictionary of jax arrays to a `ScipyLbfgsbState`."""
+        return ScipyLbfgsbState(
+            x=onp.array(state_dict["x"], dtype=onp.float64),
+            converged=onp.asarray(state_dict["converged"], dtype=bool),
+            _maxcor=int(state_dict["_maxcor"]),
+            _line_search_max_steps=int(state_dict["_line_search_max_steps"]),
+            _ftol=onp.asarray(state_dict["_ftol"], dtype=onp.float64),
+            _gtol=onp.asarray(state_dict["_gtol"], dtype=onp.float64),
+            _wa=onp.array(state_dict["_wa"], onp.float64),
+            _iwa=onp.array(state_dict["_iwa"], dtype=FORTRAN_INT),
+            _task=_s60_str_from_array(onp.asarray(state_dict["_task"])),
+            _csave=_s60_str_from_array(onp.asarray(state_dict["_csave"])),
+            _lsave=onp.array(state_dict["_lsave"], dtype=FORTRAN_INT),
+            _isave=onp.array(state_dict["_isave"], dtype=FORTRAN_INT),
+            _dsave=onp.array(state_dict["_dsave"], dtype=onp.float64),
+            _lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
+            _upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
+            _bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
+        )
+
+    @classmethod
+    def init(
+        cls,
+        x0: NDArray,
+        lower_bound: ElementwiseBound,
+        upper_bound: ElementwiseBound,
+        maxcor: int,
+        line_search_max_steps: int,
+        ftol: float,
+        gtol: float,
+    ) -> "ScipyLbfgsbState":
+        """Initializes the `ScipyLbfgsbState` for `x0`.
+
+        Args:
+            x0: Array giving the initial solution vector.
+            lower_bound: Array giving the elementwise optional lower bound.
+            upper_bound: Array giving the elementwise optional upper bound.
+            maxcor: The maximum number of variable metric corrections used to define
+                the limited memory matrix, in the L-BFGS-B scheme.
+            line_search_max_steps: The maximum number of steps in the line search.
+            ftol: Tolerance for stopping criteria based on function values. See scipy
+                documentation for details.
+            gtol: Tolerance for stopping criteria based on gradient.
+
+        Returns:
+            The `ScipyLbfgsbState`.
+        """
+        x0 = onp.asarray(x0)
+        if x0.ndim > 1:
+            raise ValueError(f"`x0` must be rank-1 but got shape {x0.shape}.")
+        lower_bound = onp.asarray(lower_bound)
+        upper_bound = onp.asarray(upper_bound)
+        if x0.shape != lower_bound.shape or x0.shape != upper_bound.shape:
+            raise ValueError(
+                f"`x0`, `lower_bound`, and `upper_bound` must have matching "
+                f"shape but got shapes {x0.shape}, {lower_bound.shape}, and "
+                f"{upper_bound.shape}, respectively."
+            )
+        if maxcor < 1:
+            raise ValueError(f"`maxcor` must be positive but got {maxcor}.")
+
+        n = x0.size
+        lower_bound_array, upper_bound_array, bound_type = _configure_bounds(
+            lower_bound, upper_bound
+        )
+        task = onp.zeros(1, "S60")
+        task[:] = TASK_START
+
+        # See initialization of internal variables in the `lbfgsb._minimize_lbfgsb`
+        # function.
+        wa_size = _wa_size(n=n, maxcor=maxcor)
+        state = ScipyLbfgsbState(
+            x=onp.array(x0, onp.float64),
+            converged=onp.asarray(False),
+            _maxcor=maxcor,
+            _line_search_max_steps=line_search_max_steps,
+            _ftol=onp.asarray(ftol, onp.float64),
+            _gtol=onp.asarray(gtol, onp.float64),
+            _wa=onp.zeros(wa_size, onp.float64),
+            _iwa=onp.zeros(3 * n, FORTRAN_INT),
+            _task=task,
+            _csave=onp.zeros(1, "S60"),
+            _lsave=onp.zeros(4, FORTRAN_INT),
+            _isave=onp.zeros(44, FORTRAN_INT),
+            _dsave=onp.zeros(29, onp.float64),
+            _lower_bound=lower_bound_array,
+            _upper_bound=upper_bound_array,
+            _bound_type=bound_type,
+        )
+        # The initial state requires an update with zero value and gradient. This
+        # is because the initial task is "START", which does not actually require
+        # value and gradient evaluation.
+        state.update(onp.zeros(x0.shape, onp.float64), onp.zeros((), onp.float64))
+        return state
+
+    def update(
+        self,
+        grad: NDArray,
+        value: NDArray,
+    ) -> None:
+        """Performs an in-place update of the `ScipyLbfgsbState` if not converged.
+
+        Args:
+            grad: The function gradient for the current `x`.
+            value: The scalar function value for the current `x`.
+        """
+        if self.converged:
+            return
+        if grad.shape != self.x.shape:
+            raise ValueError(
+                f"`grad` must have the same shape as attribute `x`, but got shapes "
+                f"{grad.shape} and {self.x.shape}, respectively."
+            )
+        if value.shape != ():
+            raise ValueError(f"`value` must be a scalar but got shape {value.shape}.")
+
+        # The `setulb` function will sometimes return with a task that does not
+        # require a value and gradient evaluation. In this case we simply call it
+        # again, advancing past such "dummy" steps.
+        for _ in range(3):
+            scipy_lbfgsb.setulb(
+                m=self._maxcor,
+                x=self.x,
+                l=self._lower_bound,
+                u=self._upper_bound,
+                nbd=self._bound_type,
+                f=value,
+                g=grad,
+                factr=self._ftol / onp.finfo(float).eps,
+                pgtol=self._gtol,
+                wa=self._wa,
+                iwa=self._iwa,
+                task=self._task,
+                iprint=UPDATE_IPRINT,
+                csave=self._csave,
+                lsave=self._lsave,
+                isave=self._isave,
+                dsave=self._dsave,
+                maxls=self._line_search_max_steps,
+            )
+            task_str = self._task.tobytes()
+            if task_str.startswith(TASK_CONVERGED):
+                self.converged = onp.asarray(True)
+            if task_str.startswith(TASK_FG):
+                break
+
+
+def _wa_size(n: int, maxcor: int) -> int:
+    """Return the size of the `wa` attribute of lbfgsb state."""
+    return 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
+
+
+def _validate_array_dtype(x: NDArray, dtype: Union[type, str]) -> None:
+    """Validates that `x` is an array with the specified `dtype`."""
+    if not isinstance(x, onp.ndarray):
+        raise ValueError(f"`x` must be an `onp.ndarray` but got {type(x)}")
+    if x.dtype != dtype:
+        raise ValueError(f"`x` must have dtype {dtype} but got {x.dtype}")
+
+
+def _configure_bounds(
+    lower_bound: ElementwiseBound,
+    upper_bound: ElementwiseBound,
+) -> Tuple[NDArray, NDArray, NDArray]:
+    """Configures the bounds for an L-BFGS-B optimization."""
+    bound_type = [
+        BOUNDS_MAP[(lower is None, upper is None)]
+        for lower, upper in zip(lower_bound, upper_bound)
+    ]
+    lower_bound_array = [0.0 if x is None else x for x in lower_bound]
+    upper_bound_array = [0.0 if x is None else x for x in upper_bound]
+    return (
+        onp.asarray(lower_bound_array, onp.float64),
+        onp.asarray(upper_bound_array, onp.float64),
+        onp.asarray(bound_type),
+    )
+
+
+def _array_from_s60_str(s60_str: NDArray) -> NDArray:
+    """Return a jax array for a numpy s60 string."""
+    assert s60_str.shape == (1,)
+    chars = [int(o) for o in s60_str[0]]
+    chars.extend([32] * (59 - len(chars)))
+    return onp.asarray(chars, dtype=int)
+
+
+def _s60_str_from_array(array: NDArray) -> NDArray:
+    """Return a numpy s60 string for a jax array."""
+    return onp.asarray(
+        [b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
+        dtype="S60",
+    )