invrs-opt 0.4.0__py3-none-any.whl → 0.10.3__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- invrs_opt/__init__.py +14 -3
- invrs_opt/experimental/client.py +7 -4
- invrs_opt/{base.py → optimizers/base.py} +16 -1
- invrs_opt/optimizers/lbfgsb.py +939 -0
- invrs_opt/optimizers/wrapped_optax.py +347 -0
- invrs_opt/parameterization/__init__.py +0 -0
- invrs_opt/parameterization/base.py +208 -0
- invrs_opt/parameterization/filter_project.py +138 -0
- invrs_opt/parameterization/gaussian_levelset.py +671 -0
- invrs_opt/parameterization/pixel.py +75 -0
- invrs_opt/{lbfgsb/transform.py → parameterization/transforms.py} +76 -11
- invrs_opt-0.10.3.dist-info/LICENSE +504 -0
- invrs_opt-0.10.3.dist-info/METADATA +560 -0
- invrs_opt-0.10.3.dist-info/RECORD +20 -0
- {invrs_opt-0.4.0.dist-info → invrs_opt-0.10.3.dist-info}/WHEEL +1 -1
- invrs_opt/lbfgsb/lbfgsb.py +0 -672
- invrs_opt-0.4.0.dist-info/LICENSE +0 -21
- invrs_opt-0.4.0.dist-info/METADATA +0 -75
- invrs_opt-0.4.0.dist-info/RECORD +0 -14
- /invrs_opt/{lbfgsb → optimizers}/__init__.py +0 -0
- {invrs_opt-0.4.0.dist-info → invrs_opt-0.10.3.dist-info}/top_level.txt +0 -0
invrs_opt/lbfgsb/lbfgsb.py
DELETED
@@ -1,672 +0,0 @@
|
|
1
|
-
"""Defines a jax-style wrapper for scipy's L-BFGS-B algorithm.
|
2
|
-
|
3
|
-
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
-
"""
|
5
|
-
|
6
|
-
import copy
|
7
|
-
import dataclasses
|
8
|
-
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
|
9
|
-
|
10
|
-
import jax
|
11
|
-
import jax.numpy as jnp
|
12
|
-
import numpy as onp
|
13
|
-
from jax import flatten_util, tree_util
|
14
|
-
from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
|
15
|
-
_lbfgsb as scipy_lbfgsb,
|
16
|
-
)
|
17
|
-
from totypes import types
|
18
|
-
|
19
|
-
from invrs_opt import base
|
20
|
-
from invrs_opt.lbfgsb import transform
|
21
|
-
|
22
|
-
# Type aliases shared by the functions and classes in this module.
PyTree = Any
NDArray = onp.ndarray[Any, Any]

# An elementwise bound; `None` entries denote a missing (infinite) bound.
ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]

# Optimizer state: `(params, latent_params, jax_lbfgsb_state_dict)`.
JaxLbfgsbDict = Dict[str, jnp.ndarray]
LbfgsbState = Tuple[PyTree, PyTree, JaxLbfgsbDict]
|
27
|
-
|
28
|
-
|
29
|
-
# Prefixes of the task messages exchanged with the Fortran `setulb` routine.
TASK_START = b"START"
TASK_FG = b"FG"

# Fixed arguments passed to `setulb` on every state update: printing is
# disabled and both tolerances (`pgtol`, `factr`) are set to zero.
UPDATE_IPRINT = -1
UPDATE_PGTOL = 0.0
UPDATE_FACTR = 0.0

# Limits and defaults for the L-BFGS-B configuration. `maxcor` is the number of
# stored variable-metric corrections; it may not exceed `MAXCOR_MAX_VALUE`.
MAXCOR_MAX_VALUE = 100
MAXCOR_DEFAULT = 20
LINE_SEARCH_MAX_STEPS_DEFAULT = 100

# Integer bound-type codes passed to `setulb` as `nbd`, keyed by
# `(lower_bound is None, upper_bound is None)`.
BOUNDS_MAP: Dict[Tuple[bool, bool], int] = {
    (True, True): 0,  # No bounds at all.
    (False, True): 1,  # Lower bound only.
    (False, False): 2,  # Lower and upper bounds.
    (True, False): 3,  # Upper bound only.
}

# Integer dtype of the compiled extension; the integer work arrays handed to
# `setulb` (`iwa`, `lsave`, `isave`) are created with this dtype.
FORTRAN_INT = scipy_lbfgsb.types.intvar.dtype
|
52
|
-
|
53
|
-
|
54
|
-
def lbfgsb(
    maxcor: int = MAXCOR_DEFAULT,
    line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT,
) -> base.Optimizer:
    """Return an optimizer implementing the standard L-BFGS-B algorithm.

    This optimizer wraps scipy's implementation of the algorithm, and provides
    a jax-style API to the scheme. The optimizer works with custom types such
    as the `BoundedArray` to constrain the optimization variable.

    Example usage is as follows:

        def fn(x):
            leaves_sum_sq = [jnp.sum(y)**2 for y in tree_util.tree_leaves(x)]
            return jnp.sum(jnp.asarray(leaves_sum_sq))

        x0 = {
            "a": jnp.ones((3,)),
            "b": BoundedArray(
                value=-jnp.ones((2, 5)),
                lower_bound=-5,
                upper_bound=5,
            ),
        }
        opt = lbfgsb(maxcor=20, line_search_max_steps=100)
        state = opt.init(x0)
        for _ in range(10):
            x = opt.params(state)
            value, grad = jax.value_and_grad(fn)(x)
            state = opt.update(grad=grad, value=value, params=x, state=state)

    Note that `update` takes keyword-only arguments.

    While the algorithm can work with pytrees of jax arrays, numpy arrays can
    also be used. Thus, e.g. the optimizer can directly be used with autograd.

    Args:
        maxcor: The maximum number of variable metric corrections used to define
            the limited memory matrix, in the L-BFGS-B scheme.
        line_search_max_steps: The maximum number of steps in the line search.

    Returns:
        The `base.Optimizer`.

    Raises:
        ValueError: Propagated from `transformed_lbfgsb` if `maxcor` is not an
            integer in `[1, MAXCOR_MAX_VALUE]`, or if `line_search_max_steps` is
            not a positive integer.
    """
    # With identity transforms the latent parameters are the parameters
    # themselves, which reduces to plain L-BFGS-B.
    return transformed_lbfgsb(
        maxcor=maxcor,
        line_search_max_steps=line_search_max_steps,
        transform_fn=lambda x: x,
        initialize_latent_fn=lambda x: x,
    )
|
102
|
-
|
103
|
-
|
104
|
-
def density_lbfgsb(
    beta: float,
    maxcor: int = MAXCOR_DEFAULT,
    line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT,
) -> base.Optimizer:
    """Return an L-BFGS-B optimizer with additional transforms for density arrays.

    Parameters that are of type `Density2DArray` are represented as latent
    parameters that are transformed (in the case where lower and upper bounds
    are `(-1, 1)`) by,

        transformed = tanh(beta * conv(density.array, gaussian_kernel)) / tanh(beta)

    where the kernel has a full-width at half-maximum determined by the minimum
    width and spacing parameters of the `Density2DArray`. Where the bounds differ,
    the density is scaled before the transform is applied, and then unscaled
    afterwards. The density is symmetrized before filtering, and fixed pixels
    are applied to the transformed result. Parameters of any other type are
    passed through unchanged.

    Args:
        beta: Determines the steepness of the thresholding.
        maxcor: The maximum number of variable metric corrections used to define
            the limited memory matrix, in the L-BFGS-B scheme.
        line_search_max_steps: The maximum number of steps in the line search.

    Returns:
        The `base.Optimizer`.
    """

    def transform_fn(tree: PyTree) -> PyTree:
        # `_is_density` is also the `is_leaf` predicate, so density arrays are
        # handed to `transform_density` whole rather than traversed into.
        return tree_util.tree_map(
            lambda x: transform_density(x) if _is_density(x) else x,
            tree,
            is_leaf=_is_density,
        )

    def initialize_latent_fn(tree: PyTree) -> PyTree:
        return tree_util.tree_map(
            lambda x: initialize_latent_density(x) if _is_density(x) else x,
            tree,
            is_leaf=_is_density,
        )

    def transform_density(density: types.Density2DArray) -> types.Density2DArray:
        """Maps a latent density to the density returned by the optimizer."""
        transformed = types.symmetrize_density(density)
        transformed = transform.density_gaussian_filter_and_tanh(transformed, beta=beta)
        # Scale to ensure that the full valid range of the density array is reachable.
        mid_value = (density.lower_bound + density.upper_bound) / 2
        transformed = tree_util.tree_map(
            lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed
        )
        return transform.apply_fixed_pixels(transformed)

    def initialize_latent_density(
        density: types.Density2DArray,
    ) -> types.Density2DArray:
        """Computes the initial latent density for `density`.

        Inverts the `tanh(beta * .) / tanh(beta)` thresholding on the normalized
        array. NOTE(review): the filtering step of `transform_density` is not
        inverted here, so the round trip is presumably only approximate.
        """
        array = transform.normalized_array_from_density(density)
        # Clipping keeps the `arctanh` argument within `[-tanh(beta), tanh(beta)]`.
        array = jnp.clip(array, -1, 1)
        array *= jnp.tanh(beta)
        latent_array = jnp.arctanh(array) / beta
        latent_array = transform.rescale_array_for_density(latent_array, density)
        return dataclasses.replace(density, array=latent_array)

    return transformed_lbfgsb(
        maxcor=maxcor,
        line_search_max_steps=line_search_max_steps,
        transform_fn=transform_fn,
        initialize_latent_fn=initialize_latent_fn,
    )
|
170
|
-
|
171
|
-
|
172
|
-
def transformed_lbfgsb(
    maxcor: int,
    line_search_max_steps: int,
    transform_fn: Callable[[PyTree], PyTree],
    initialize_latent_fn: Callable[[PyTree], PyTree],
) -> base.Optimizer:
    """Construct a latent parameter L-BFGS-B optimizer.

    The optimized parameters are termed latent parameters, from which the
    actual parameters returned by the optimizer are obtained using the
    `transform_fn`. In the simple case where this is just `lambda x: x` (i.e.
    the identity), this is equivalent to the standard L-BFGS-B algorithm.

    Args:
        maxcor: The maximum number of variable metric corrections used to define
            the limited memory matrix, in the L-BFGS-B scheme.
        line_search_max_steps: The maximum number of steps in the line search.
        transform_fn: Function which transforms the internal latent parameters to
            the parameters returned by the optimizer.
        initialize_latent_fn: Function which computes the initial latent parameters
            given the initial parameters.

    Returns:
        The `base.Optimizer`.

    Raises:
        ValueError: If `maxcor` is not an integer in `[1, MAXCOR_MAX_VALUE]`, or
            if `line_search_max_steps` is not a positive integer.
    """
    if not isinstance(maxcor, int) or maxcor < 1 or maxcor > MAXCOR_MAX_VALUE:
        # `maxcor == MAXCOR_MAX_VALUE` is accepted, so the message says so.
        raise ValueError(
            f"`maxcor` must be greater than 0 and less than or equal to "
            f"{MAXCOR_MAX_VALUE}, but got {maxcor}"
        )

    if not isinstance(line_search_max_steps, int) or line_search_max_steps < 1:
        raise ValueError(
            f"`line_search_max_steps` must be greater than 0 but got "
            f"{line_search_max_steps}"
        )

    def init_fn(params: PyTree) -> LbfgsbState:
        """Initializes the optimization state."""

        def _init_pure(params: PyTree) -> Tuple[PyTree, JaxLbfgsbDict]:
            # Executed on the host through `jax.pure_callback`; `params` are the
            # initial latent parameters.
            lower_bound = types.extract_lower_bound(params)
            upper_bound = types.extract_upper_bound(params)
            scipy_lbfgsb_state = ScipyLbfgsbState.init(
                x0=_to_numpy(params),
                lower_bound=_bound_for_params(lower_bound, params),
                upper_bound=_bound_for_params(upper_bound, params),
                maxcor=maxcor,
                line_search_max_steps=line_search_max_steps,
            )
            latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
            return latent_params, scipy_lbfgsb_state.to_jax()

        (
            latent_params,
            jax_lbfgsb_state,
        ) = jax.pure_callback(  # type: ignore[attr-defined]
            _init_pure,
            _example_state(params, maxcor),
            initialize_latent_fn(params),
        )
        return transform_fn(latent_params), latent_params, jax_lbfgsb_state

    def params_fn(state: LbfgsbState) -> PyTree:
        """Returns the parameters for the given `state`."""
        params, _, _ = state
        return params

    def update_fn(
        *,
        grad: PyTree,
        value: float,
        params: PyTree,
        state: LbfgsbState,
    ) -> LbfgsbState:
        """Updates the state.

        Raises:
            ValueError: (from the host callback) if `value` is not a scalar.
        """
        del params

        def _update_pure(
            flat_latent_grad: PyTree,
            value: jnp.ndarray,
            jax_lbfgsb_state: JaxLbfgsbDict,
        ) -> Tuple[PyTree, JaxLbfgsbDict]:
            # Explicit check rather than `assert`, which is stripped under `-O`.
            if onp.size(value) != 1:
                raise ValueError(
                    f"`value` must be a scalar but got shape {onp.shape(value)}."
                )
            scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
            scipy_lbfgsb_state.update(
                grad=onp.asarray(flat_latent_grad, dtype=onp.float64),
                value=onp.asarray(value, dtype=onp.float64),
            )
            flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
            return flat_latent_params, scipy_lbfgsb_state.to_jax()

        _, latent_params, jax_lbfgsb_state = state
        # Pull the gradient with respect to the returned parameters back through
        # `transform_fn`, giving the gradient with respect to the latent
        # parameters that L-BFGS-B actually optimizes.
        _, vjp_fn = jax.vjp(transform_fn, latent_params)
        (latent_grad,) = vjp_fn(grad)
        flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(
            latent_grad
        )  # type: ignore[no-untyped-call]

        (
            flat_latent_params,
            jax_lbfgsb_state,
        ) = jax.pure_callback(  # type: ignore[attr-defined]
            _update_pure,
            (flat_latent_grad, jax_lbfgsb_state),
            flat_latent_grad,
            value,
            jax_lbfgsb_state,
        )
        latent_params = unflatten_fn(flat_latent_params)
        return transform_fn(latent_params), latent_params, jax_lbfgsb_state

    return base.Optimizer(
        init=init_fn,
        params=params_fn,
        update=update_fn,
    )
|
289
|
-
|
290
|
-
|
291
|
-
# ------------------------------------------------------------------------------
|
292
|
-
# Helper functions.
|
293
|
-
# ------------------------------------------------------------------------------
|
294
|
-
|
295
|
-
|
296
|
-
def _is_density(leaf: Any) -> Any:
    """Reports whether `leaf` is an instance of `types.Density2DArray`.

    Also used as the `is_leaf` predicate so density arrays are treated as whole
    leaves during tree traversal.
    """
    density_cls = types.Density2DArray
    return isinstance(leaf, density_cls)
|
299
|
-
|
300
|
-
|
301
|
-
def _to_numpy(params: PyTree) -> NDArray:
|
302
|
-
"""Flattens a `params` pytree into a single rank-1 numpy array."""
|
303
|
-
x, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
|
304
|
-
return onp.asarray(x, dtype=onp.float64)
|
305
|
-
|
306
|
-
|
307
|
-
def _to_pytree(x_flat: NDArray, params: PyTree) -> PyTree:
|
308
|
-
"""Restores a pytree from a flat numpy array using the structure of `params`.
|
309
|
-
|
310
|
-
Note that the returned pytree includes jax array leaves.
|
311
|
-
|
312
|
-
Args:
|
313
|
-
x_flat: The rank-1 numpy array to be restored.
|
314
|
-
params: A pytree of parameters whose structure is replicated in the restored
|
315
|
-
pytree.
|
316
|
-
|
317
|
-
Returns:
|
318
|
-
The restored pytree, with jax array leaves.
|
319
|
-
"""
|
320
|
-
_, unflatten_fn = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
|
321
|
-
return unflatten_fn(jnp.asarray(x_flat, dtype=float))
|
322
|
-
|
323
|
-
|
324
|
-
def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
    """Expands `bound` into a flat, elementwise bound list for `params`.

    `bound` may be `None` or a scalar, in which case it applies to every array in
    `params`. Otherwise it must be a pytree with the same structure as `params`
    whose leaves are each `None`, a scalar, or an array shaped like the matching
    leaf of `params`.

    The result is a flat list suitable for use with `ScipyLbfgsbState`.

    Args:
        bound: The pytree of bounds.
        params: The pytree of parameters.

    Returns:
        The flat elementwise bound.

    Raises:
        ValueError: If the tree structures of `bound` and `params` differ, or if
            an array-valued bound leaf does not match its parameter's shape.
    """

    def _is_none(node: Any) -> bool:
        return node is None

    if bound is None or onp.isscalar(bound):
        # Broadcast the `None` / scalar bound onto every leaf, stopping at custom
        # types so each of them receives a single bound value.
        uniform_bound = bound
        bound = tree_util.tree_map(
            lambda _: uniform_bound,
            params,
            is_leaf=lambda x: isinstance(x, types.CUSTOM_TYPES),
        )

    bound_leaves, bound_treedef = tree_util.tree_flatten(bound, is_leaf=_is_none)
    params_leaves = tree_util.tree_leaves(params, is_leaf=_is_none)

    # `params` may contain custom pytree nodes while `bound` only holds arrays or
    # `None`; replace custom nodes with plain leaves so the structures compare.
    standardized_params = tree_util.tree_map(
        lambda _: 0.0,
        tree=params,
        is_leaf=lambda x: x is None or isinstance(x, types.CUSTOM_TYPES),
    )
    params_treedef = tree_util.tree_structure(standardized_params)
    if bound_treedef != params_treedef:  # type: ignore[operator]
        raise ValueError(
            f"Tree structure of `bound` and `params` must match, but got "
            f"{bound_treedef} and {params_treedef}, respectively."
        )

    flat_bound = []
    for leaf_bound, leaf_param in zip(bound_leaves, params_leaves):
        if leaf_param is None:
            continue
        is_uniform = (
            leaf_bound is None
            or onp.isscalar(leaf_bound)
            or onp.shape(leaf_bound) == ()
        )
        if is_uniform:
            flat_bound.extend([leaf_bound] * onp.size(leaf_param))
            continue
        if leaf_bound.shape != leaf_param.shape:
            raise ValueError(
                f"`bound` must be `None`, a scalar, or have shape matching "
                f"`params`, but got shape {leaf_bound.shape} when params has shape "
                f"{leaf_param.shape}."
            )
        flat_bound.extend(leaf_bound.flatten().tolist())

    return flat_bound
|
386
|
-
|
387
|
-
|
388
|
-
def _example_state(params: PyTree, maxcor: int) -> PyTree:
    """Builds example outputs for the `init` callback given `params` and `maxcor`.

    The result is passed to `jax.pure_callback` as the description of the
    callback's outputs: a float copy of `params` together with a dictionary
    mirroring `ScipyLbfgsbState.to_jax()` for `n` optimization variables.
    """
    n = flatten_util.ravel_pytree(params)[0].size  # type: ignore[no-untyped-call]
    fortran_int = jnp.int32  # Dtype used for the Fortran integer work arrays.
    example_jax_lbfgsb_state = {
        "x": jnp.zeros(n, dtype=float),
        "_maxcor": jnp.zeros((), dtype=int),
        "_line_search_max_steps": jnp.zeros((), dtype=int),
        "_wa": jnp.ones(_wa_size(n=n, maxcor=maxcor), dtype=float),
        "_iwa": jnp.ones(n * 3, dtype=fortran_int),
        "_task": jnp.zeros(59, dtype=int),
        "_csave": jnp.zeros(59, dtype=int),
        "_lsave": jnp.zeros(4, dtype=fortran_int),
        "_isave": jnp.zeros(44, dtype=fortran_int),
        "_dsave": jnp.zeros(29, dtype=float),
        "_lower_bound": jnp.zeros(n, dtype=float),
        "_upper_bound": jnp.zeros(n, dtype=float),
        "_bound_type": jnp.zeros(n, dtype=int),
    }
    float_params = tree_util.tree_map(
        lambda leaf: jnp.asarray(leaf, dtype=float), params
    )
    return float_params, example_jax_lbfgsb_state
|
409
|
-
|
410
|
-
|
411
|
-
# ------------------------------------------------------------------------------
|
412
|
-
# Wrapper for scipy's L-BFGS-B implementation.
|
413
|
-
# ------------------------------------------------------------------------------
|
414
|
-
|
415
|
-
|
416
|
-
@dataclasses.dataclass
class ScipyLbfgsbState:
    """Stores the state of a scipy L-BFGS-B minimization.

    This class enables optimization with a more functional style, giving the user
    control over the optimization loop. Example usage is as follows:

        value_fn = lambda x: onp.sum(x**2)
        grad_fn = lambda x: 2 * x

        x0 = onp.asarray([0.1, 0.2, 0.3])
        lb = [None, -1, 0.1]
        ub = [None, None, None]
        state = ScipyLbfgsbState.init(
            x0=x0,
            lower_bound=lb,
            upper_bound=ub,
            maxcor=20,
            line_search_max_steps=100,
        )

        for _ in range(10):
            value = value_fn(state.x)
            grad = grad_fn(state.x)
            state.update(grad, value)

    This example converges with `state.x` equal to `(0, 0, 0.1)` and value equal
    to `0.01`.

    Attributes:
        x: The current solution vector.
    """

    x: NDArray
    # Private attributes correspond to internal variables in the
    # `scipy.optimize.lbfgsb._minimize_lbfgsb` function.
    _maxcor: int
    _line_search_max_steps: int
    _wa: NDArray  # Float workspace of size `_wa_size(n, maxcor)`.
    _iwa: NDArray  # Integer workspace of size `3 * n`.
    _task: NDArray  # `S60` task message, e.g. starting with `TASK_FG`.
    _csave: NDArray
    _lsave: NDArray
    _isave: NDArray
    _dsave: NDArray
    _lower_bound: NDArray
    _upper_bound: NDArray
    _bound_type: NDArray  # Per-element codes taken from `BOUNDS_MAP`.

    def __post_init__(self) -> None:
        """Validates the datatypes for all state attributes."""
        _validate_array_dtype(self.x, onp.float64)
        _validate_array_dtype(self._wa, onp.float64)
        _validate_array_dtype(self._iwa, FORTRAN_INT)
        _validate_array_dtype(self._task, "S60")
        _validate_array_dtype(self._csave, "S60")
        _validate_array_dtype(self._lsave, FORTRAN_INT)
        _validate_array_dtype(self._isave, FORTRAN_INT)
        _validate_array_dtype(self._dsave, onp.float64)
        _validate_array_dtype(self._lower_bound, onp.float64)
        _validate_array_dtype(self._upper_bound, onp.float64)
        _validate_array_dtype(self._bound_type, int)

    def to_jax(self) -> Dict[str, jnp.ndarray]:
        """Generates a dictionary of jax arrays defining the state.

        The `S60` string attributes are encoded as integer byte arrays so that
        every entry is a numeric jax array.
        """
        return dict(
            x=jnp.asarray(self.x),
            _maxcor=jnp.asarray(self._maxcor),
            _line_search_max_steps=jnp.asarray(self._line_search_max_steps),
            _wa=jnp.asarray(self._wa),
            _iwa=jnp.asarray(self._iwa),
            _task=_array_from_s60_str(self._task),
            _csave=_array_from_s60_str(self._csave),
            _lsave=jnp.asarray(self._lsave),
            _isave=jnp.asarray(self._isave),
            _dsave=jnp.asarray(self._dsave),
            _lower_bound=jnp.asarray(self._lower_bound),
            _upper_bound=jnp.asarray(self._upper_bound),
            _bound_type=jnp.asarray(self._bound_type),
        )

    @classmethod
    def from_jax(cls, state_dict: Dict[str, jnp.ndarray]) -> "ScipyLbfgsbState":
        """Converts a dictionary of jax arrays to a `ScipyLbfgsbState`.

        This is the inverse of `to_jax`.
        """
        # NOTE(review): the deep copy presumably ensures the arrays below own
        # their buffers, since `update` hands them to `setulb` for in-place
        # modification.
        state_dict = copy.deepcopy(state_dict)
        return ScipyLbfgsbState(
            x=onp.asarray(state_dict["x"], dtype=onp.float64),
            _maxcor=int(state_dict["_maxcor"]),
            _line_search_max_steps=int(state_dict["_line_search_max_steps"]),
            _wa=onp.asarray(state_dict["_wa"], onp.float64),
            _iwa=onp.asarray(state_dict["_iwa"], dtype=FORTRAN_INT),
            _task=_s60_str_from_array(state_dict["_task"]),
            _csave=_s60_str_from_array(state_dict["_csave"]),
            _lsave=onp.asarray(state_dict["_lsave"], dtype=FORTRAN_INT),
            _isave=onp.asarray(state_dict["_isave"], dtype=FORTRAN_INT),
            _dsave=onp.asarray(state_dict["_dsave"], dtype=onp.float64),
            _lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
            _upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
            _bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
        )

    @classmethod
    def init(
        cls,
        x0: NDArray,
        lower_bound: ElementwiseBound,
        upper_bound: ElementwiseBound,
        maxcor: int,
        line_search_max_steps: int,
    ) -> "ScipyLbfgsbState":
        """Initializes the `ScipyLbfgsbState` for `x0`.

        Args:
            x0: Array giving the initial solution vector.
            lower_bound: Array giving the elementwise optional lower bound.
            upper_bound: Array giving the elementwise optional upper bound.
            maxcor: The maximum number of variable metric corrections used to define
                the limited memory matrix, in the L-BFGS-B scheme.
            line_search_max_steps: The maximum number of steps in the line search.

        Returns:
            The `ScipyLbfgsbState`.

        Raises:
            ValueError: If `x0` has rank greater than one, if the bounds do not
                match the shape of `x0`, or if `maxcor` is less than one.
        """
        x0 = onp.asarray(x0)
        if x0.ndim > 1:
            raise ValueError(f"`x0` must be rank-1 but got shape {x0.shape}.")
        lower_bound = onp.asarray(lower_bound)
        upper_bound = onp.asarray(upper_bound)
        if x0.shape != lower_bound.shape or x0.shape != upper_bound.shape:
            raise ValueError(
                f"`x0`, `lower_bound`, and `upper_bound` must have matching "
                f"shape but got shapes {x0.shape}, {lower_bound.shape}, and "
                f"{upper_bound.shape}, respectively."
            )
        if maxcor < 1:
            raise ValueError(f"`maxcor` must be positive but got {maxcor}.")

        n = x0.size
        lower_bound_array, upper_bound_array, bound_type = _configure_bounds(
            lower_bound, upper_bound
        )
        task = onp.zeros(1, "S60")
        task[:] = TASK_START

        # See initialization of internal variables in the `lbfgsb._minimize_lbfgsb`
        # function.
        wa_size = _wa_size(n=n, maxcor=maxcor)
        state = ScipyLbfgsbState(
            x=onp.array(x0, onp.float64),
            _maxcor=maxcor,
            _line_search_max_steps=line_search_max_steps,
            _wa=onp.zeros(wa_size, onp.float64),
            _iwa=onp.zeros(3 * n, FORTRAN_INT),
            _task=task,
            _csave=onp.zeros(1, "S60"),
            _lsave=onp.zeros(4, FORTRAN_INT),
            _isave=onp.zeros(44, FORTRAN_INT),
            _dsave=onp.zeros(29, onp.float64),
            _lower_bound=lower_bound_array,
            _upper_bound=upper_bound_array,
            _bound_type=bound_type,
        )
        # The initial state requires an update with zero value and gradient. This
        # is because the initial task is "START", which does not actually require
        # value and gradient evaluation.
        state.update(onp.zeros(x0.shape, onp.float64), onp.zeros((), onp.float64))
        return state

    def update(
        self,
        grad: NDArray,
        value: NDArray,
    ) -> None:
        """Performs an in-place update of the `ScipyLbfgsbState`.

        Args:
            grad: The function gradient for the current `x`; must have the same
                shape as `x`.
            value: The scalar function value for the current `x`; must have
                shape `()`.

        Raises:
            ValueError: If `grad` or `value` has the wrong shape.
        """
        if grad.shape != self.x.shape:
            raise ValueError(
                f"`grad` must have the same shape as attribute `x`, but got shapes "
                f"{grad.shape} and {self.x.shape}, respectively."
            )
        if value.shape != ():
            raise ValueError(f"`value` must be a scalar but got shape {value.shape}.")

        # The `setulb` function will sometimes return with a task that does not
        # require a value and gradient evaluation. In this case we simply call it
        # again, advancing past such "dummy" steps (at most three calls).
        for _ in range(3):
            scipy_lbfgsb.setulb(
                m=self._maxcor,
                x=self.x,
                l=self._lower_bound,
                u=self._upper_bound,
                nbd=self._bound_type,
                f=value,
                g=grad,
                factr=UPDATE_FACTR,
                pgtol=UPDATE_PGTOL,
                wa=self._wa,
                iwa=self._iwa,
                task=self._task,
                iprint=UPDATE_IPRINT,
                csave=self._csave,
                lsave=self._lsave,
                isave=self._isave,
                dsave=self._dsave,
                maxls=self._line_search_max_steps,
            )
            task_str = self._task.tobytes()
            if task_str.startswith(TASK_FG):
                break
|
626
|
-
|
627
|
-
|
628
|
-
def _wa_size(n: int, maxcor: int) -> int:
|
629
|
-
"""Return the size of the `wa` attribute of lbfgsb state."""
|
630
|
-
return 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
|
631
|
-
|
632
|
-
|
633
|
-
def _validate_array_dtype(x: NDArray, dtype: Union[type, str]) -> None:
|
634
|
-
"""Validates that `x` is an array with the specified `dtype`."""
|
635
|
-
if not isinstance(x, onp.ndarray):
|
636
|
-
raise ValueError(f"`x` must be an `onp.ndarray` but got {type(x)}")
|
637
|
-
if x.dtype != dtype:
|
638
|
-
raise ValueError(f"`x` must have dtype {dtype} but got {x.dtype}")
|
639
|
-
|
640
|
-
|
641
|
-
def _configure_bounds(
    lower_bound: ElementwiseBound,
    upper_bound: ElementwiseBound,
) -> Tuple[NDArray, NDArray, NDArray]:
    """Converts optional elementwise bounds into the arrays `setulb` consumes.

    Returns `(lower, upper, bound_type)`: float64 arrays where each missing
    (`None`) bound is replaced by `0.0`, plus per-element codes from `BOUNDS_MAP`.
    """

    def _fill_missing(bounds: ElementwiseBound) -> NDArray:
        # `None` marks a missing bound; its numeric placeholder is `0.0`.
        return onp.asarray([0.0 if b is None else b for b in bounds], onp.float64)

    codes = [
        BOUNDS_MAP[(lo is None, hi is None)]
        for lo, hi in zip(lower_bound, upper_bound)
    ]
    return (
        _fill_missing(lower_bound),
        _fill_missing(upper_bound),
        onp.asarray(codes),
    )
|
657
|
-
|
658
|
-
|
659
|
-
def _array_from_s60_str(s60_str: NDArray) -> jnp.ndarray:
|
660
|
-
"""Return a jax array for a numpy s60 string."""
|
661
|
-
assert s60_str.shape == (1,)
|
662
|
-
chars = [int(o) for o in s60_str[0]]
|
663
|
-
chars.extend([32] * (59 - len(chars)))
|
664
|
-
return jnp.asarray(chars, dtype=int)
|
665
|
-
|
666
|
-
|
667
|
-
def _s60_str_from_array(array: jnp.ndarray) -> NDArray:
|
668
|
-
"""Return a numpy s60 string for a jax array."""
|
669
|
-
return onp.asarray(
|
670
|
-
[b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
|
671
|
-
dtype="S60",
|
672
|
-
)
|
@@ -1,21 +0,0 @@
|
|
1
|
-
MIT License
|
2
|
-
|
3
|
-
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
-
|
5
|
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
-
of this software and associated documentation files (the "Software"), to deal
|
7
|
-
in the Software without restriction, including without limitation the rights
|
8
|
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
-
copies of the Software, and to permit persons to whom the Software is
|
10
|
-
furnished to do so, subject to the following conditions:
|
11
|
-
|
12
|
-
The above copyright notice and this permission notice shall be included in all
|
13
|
-
copies or substantial portions of the Software.
|
14
|
-
|
15
|
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
-
SOFTWARE.
|