invrs-opt 0.4.0__py3-none-any.whl → 0.10.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invrs_opt/__init__.py +14 -3
- invrs_opt/experimental/client.py +7 -4
- invrs_opt/{base.py → optimizers/base.py} +16 -1
- invrs_opt/optimizers/lbfgsb.py +939 -0
- invrs_opt/optimizers/wrapped_optax.py +347 -0
- invrs_opt/parameterization/__init__.py +0 -0
- invrs_opt/parameterization/base.py +208 -0
- invrs_opt/parameterization/filter_project.py +138 -0
- invrs_opt/parameterization/gaussian_levelset.py +671 -0
- invrs_opt/parameterization/pixel.py +75 -0
- invrs_opt/{lbfgsb/transform.py → parameterization/transforms.py} +76 -11
- invrs_opt-0.10.3.dist-info/LICENSE +504 -0
- invrs_opt-0.10.3.dist-info/METADATA +560 -0
- invrs_opt-0.10.3.dist-info/RECORD +20 -0
- {invrs_opt-0.4.0.dist-info → invrs_opt-0.10.3.dist-info}/WHEEL +1 -1
- invrs_opt/lbfgsb/lbfgsb.py +0 -672
- invrs_opt-0.4.0.dist-info/LICENSE +0 -21
- invrs_opt-0.4.0.dist-info/METADATA +0 -75
- invrs_opt-0.4.0.dist-info/RECORD +0 -14
- /invrs_opt/{lbfgsb → optimizers}/__init__.py +0 -0
- {invrs_opt-0.4.0.dist-info → invrs_opt-0.10.3.dist-info}/top_level.txt +0 -0
invrs_opt/lbfgsb/lbfgsb.py
DELETED
@@ -1,672 +0,0 @@
|
|
1
|
-
"""Defines a jax-style wrapper for scipy's L-BFGS-B algorithm.
|
2
|
-
|
3
|
-
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
-
"""
|
5
|
-
|
6
|
-
import copy
|
7
|
-
import dataclasses
|
8
|
-
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
|
9
|
-
|
10
|
-
import jax
|
11
|
-
import jax.numpy as jnp
|
12
|
-
import numpy as onp
|
13
|
-
from jax import flatten_util, tree_util
|
14
|
-
from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
|
15
|
-
_lbfgsb as scipy_lbfgsb,
|
16
|
-
)
|
17
|
-
from totypes import types
|
18
|
-
|
19
|
-
from invrs_opt import base
|
20
|
-
from invrs_opt.lbfgsb import transform
|
21
|
-
|
22
|
-
# Type aliases used throughout this module.
NDArray = onp.ndarray[Any, Any]
PyTree = Any
ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
JaxLbfgsbDict = Dict[str, jnp.ndarray]
# Optimizer state: `(params, latent_params, jax_lbfgsb_state)`.
LbfgsbState = Tuple[PyTree, PyTree, JaxLbfgsbDict]


# Prefixes of the `task` messages exchanged with the compiled L-BFGS-B routine.
# "START" marks a fresh state; "FG" means a new value and gradient are needed.
TASK_START, TASK_FG = b"START", b"FG"

# Fixed `setulb` arguments used by every state update. Per scipy's documentation
# a negative `iprint` disables printed output; both tolerances are zero.
UPDATE_IPRINT = -1
UPDATE_PGTOL = 0.0
UPDATE_FACTR = 0.0

# Valid `maxcor` values lie in `[1, MAXCOR_MAX_VALUE]`; the remaining values are
# the defaults used by the optimizer factory functions.
MAXCOR_MAX_VALUE = 100
MAXCOR_DEFAULT = 20
LINE_SEARCH_MAX_STEPS_DEFAULT = 100

# Maps `(lower_is_none, upper_is_none)` to the L-BFGS-B bound-type code:
#   0 -> no bounds, 1 -> lower bound only, 2 -> both bounds, 3 -> upper only.
BOUNDS_MAP: Dict[Tuple[bool, bool], int] = {
    (True, True): 0,
    (False, True): 1,
    (False, False): 2,
    (True, False): 3,
}

# Integer dtype used by the compiled routine for its integer work arrays.
FORTRAN_INT = scipy_lbfgsb.types.intvar.dtype
|
52
|
-
|
53
|
-
|
54
|
-
def lbfgsb(
    maxcor: int = MAXCOR_DEFAULT,
    line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT,
) -> base.Optimizer:
    """Return an optimizer implementing the standard L-BFGS-B algorithm.

    The optimizer wraps scipy's implementation of L-BFGS-B behind a jax-style
    `init` / `params` / `update` API. Custom types such as `BoundedArray` can be
    used to constrain the optimization variable.

    Example usage:

        def fn(x):
            leaves_sum_sq = [jnp.sum(y)**2 for y in tree_util.tree_leaves(x)]
            return jnp.sum(jnp.asarray(leaves_sum_sq))

        x0 = {
            "a": jnp.ones((3,)),
            "b": BoundedArray(
                value=-jnp.ones((2, 5)),
                lower_bound=-5,
                upper_bound=5,
            ),
        }
        opt = lbfgsb(maxcor=20, line_search_max_steps=100)
        state = opt.init(x0)
        for _ in range(10):
            x = opt.params(state)
            value, grad = jax.value_and_grad(fn)(x)
            state = opt.update(grad, value, state)

    Pytrees of numpy arrays work as well as pytrees of jax arrays, so the
    optimizer can also be driven by e.g. autograd.

    Args:
        maxcor: The maximum number of variable metric corrections used to define
            the limited memory matrix, in the L-BFGS-B scheme.
        line_search_max_steps: The maximum number of steps in the line search.

    Returns:
        The `base.Optimizer`.
    """

    def _identity(tree: PyTree) -> PyTree:
        # Without a transform, latent and returned parameters coincide.
        return tree

    return transformed_lbfgsb(
        maxcor=maxcor,
        line_search_max_steps=line_search_max_steps,
        transform_fn=_identity,
        initialize_latent_fn=_identity,
    )
|
102
|
-
|
103
|
-
|
104
|
-
def density_lbfgsb(
    beta: float,
    maxcor: int = MAXCOR_DEFAULT,
    line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT,
) -> base.Optimizer:
    """Return an L-BFGS-B optimizer with additional transforms for density arrays.

    Leaves of type `Density2DArray` are optimized through latent parameters.
    When the lower and upper bounds are `(-1, 1)`, the latent array is mapped
    to the returned parameters by

        transformed = tanh(beta * conv(density.array, gaussian_kernel)) / tanh(beta)

    where the kernel's full-width at half-maximum is determined by the minimum
    width and spacing parameters of the `Density2DArray`. For other bounds, the
    density is scaled before the transform is applied and unscaled afterwards.
    All other leaves are passed through unchanged.

    Args:
        beta: Determines the steepness of the thresholding.
        maxcor: The maximum number of variable metric corrections used to define
            the limited memory matrix, in the L-BFGS-B scheme.
        line_search_max_steps: The maximum number of steps in the line search.

    Returns:
        The `base.Optimizer`.
    """

    def _map_densities(fn: Callable[[Any], Any]) -> Callable[[PyTree], PyTree]:
        # Build a pytree function that applies `fn` to density leaves only.
        def mapped(tree: PyTree) -> PyTree:
            return tree_util.tree_map(
                lambda leaf: fn(leaf) if _is_density(leaf) else leaf,
                tree,
                is_leaf=_is_density,
            )

        return mapped

    def transform_density(density: types.Density2DArray) -> types.Density2DArray:
        # Symmetrize, then filter and threshold the latent density.
        filtered = transform.density_gaussian_filter_and_tanh(
            types.symmetrize_density(density), beta=beta
        )
        # Rescale about the midpoint of the bounds so that the full valid range
        # of the density array is reachable.
        midpoint = (density.lower_bound + density.upper_bound) / 2
        rescaled = tree_util.tree_map(
            lambda arr: midpoint + (arr - midpoint) / jnp.tanh(beta), filtered
        )
        return transform.apply_fixed_pixels(rescaled)

    def initialize_latent_density(
        density: types.Density2DArray,
    ) -> types.Density2DArray:
        # Approximately invert the thresholding so that the initial returned
        # parameters resemble the user-supplied density.
        normalized = jnp.clip(transform.normalized_array_from_density(density), -1, 1)
        latent = jnp.arctanh(normalized * jnp.tanh(beta)) / beta
        return dataclasses.replace(
            density, array=transform.rescale_array_for_density(latent, density)
        )

    return transformed_lbfgsb(
        maxcor=maxcor,
        line_search_max_steps=line_search_max_steps,
        transform_fn=_map_densities(transform_density),
        initialize_latent_fn=_map_densities(initialize_latent_density),
    )
|
170
|
-
|
171
|
-
|
172
|
-
def transformed_lbfgsb(
    maxcor: int,
    line_search_max_steps: int,
    transform_fn: Callable[[PyTree], PyTree],
    initialize_latent_fn: Callable[[PyTree], PyTree],
) -> base.Optimizer:
    """Construct a latent parameter L-BFGS-B optimizer.

    The optimized parameters are termed latent parameters, from which the
    actual parameters returned by the optimizer are obtained using the
    `transform_fn`. In the simple case where this is just `lambda x: x` (i.e.
    the identity), this is equivalent to the standard L-BFGS-B algorithm.

    Args:
        maxcor: The maximum number of variable metric corrections used to define
            the limited memory matrix, in the L-BFGS-B scheme. Must be an
            integer in `[1, MAXCOR_MAX_VALUE]`.
        line_search_max_steps: The maximum number of steps in the line search.
            Must be a positive integer.
        transform_fn: Function which transforms the internal latent parameters to
            the parameters returned by the optimizer.
        initialize_latent_fn: Function which computes the initial latent parameters
            given the initial parameters.

    Returns:
        The `base.Optimizer`.

    Raises:
        ValueError: If `maxcor` or `line_search_max_steps` is out of range.
    """
    if not isinstance(maxcor, int) or maxcor < 1 or maxcor > MAXCOR_MAX_VALUE:
        # `MAXCOR_MAX_VALUE` itself is accepted, so the message says so.
        raise ValueError(
            f"`maxcor` must be greater than 0 and less than or equal to "
            f"{MAXCOR_MAX_VALUE}, but got {maxcor}"
        )

    if not isinstance(line_search_max_steps, int) or line_search_max_steps < 1:
        raise ValueError(
            f"`line_search_max_steps` must be greater than 0 but got "
            f"{line_search_max_steps}"
        )

    def init_fn(params: PyTree) -> LbfgsbState:
        """Initializes the optimization state."""

        def _init_pure(params: PyTree) -> Tuple[PyTree, JaxLbfgsbDict]:
            # Receives the latent parameters produced by `initialize_latent_fn`.
            lower_bound = types.extract_lower_bound(params)
            upper_bound = types.extract_upper_bound(params)
            scipy_lbfgsb_state = ScipyLbfgsbState.init(
                x0=_to_numpy(params),
                lower_bound=_bound_for_params(lower_bound, params),
                upper_bound=_bound_for_params(upper_bound, params),
                maxcor=maxcor,
                line_search_max_steps=line_search_max_steps,
            )
            latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
            return latent_params, scipy_lbfgsb_state.to_jax()

        # The scipy state is managed with numpy on the Python side, so it is
        # created in a callback; `_example_state` gives the result shapes/dtypes.
        (
            latent_params,
            jax_lbfgsb_state,
        ) = jax.pure_callback(  # type: ignore[attr-defined]
            _init_pure,
            _example_state(params, maxcor),
            initialize_latent_fn(params),
        )
        return transform_fn(latent_params), latent_params, jax_lbfgsb_state

    def params_fn(state: LbfgsbState) -> PyTree:
        """Returns the parameters for the given `state`."""
        params, _, _ = state
        return params

    def update_fn(
        *,
        grad: PyTree,
        value: float,
        params: PyTree,
        state: LbfgsbState,
    ) -> LbfgsbState:
        """Updates the state."""
        del params

        def _update_pure(
            flat_latent_grad: PyTree,
            value: jnp.ndarray,
            jax_lbfgsb_state: JaxLbfgsbDict,
        ) -> Tuple[PyTree, JaxLbfgsbDict]:
            value = onp.asarray(value, dtype=onp.float64)
            # Explicit check rather than `assert`, which is stripped under `-O`.
            if value.size != 1:
                raise ValueError(
                    f"`value` must have exactly one element but got shape "
                    f"{value.shape}."
                )
            scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
            scipy_lbfgsb_state.update(
                grad=onp.asarray(flat_latent_grad, dtype=onp.float64),
                # `update` requires a 0-d value; size-1 arrays are squeezed.
                value=value.reshape(()),
            )
            flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
            return flat_latent_params, scipy_lbfgsb_state.to_jax()

        _, latent_params, jax_lbfgsb_state = state
        # Pull the gradient with respect to the returned parameters back to the
        # latent parameters through `transform_fn`.
        _, vjp_fn = jax.vjp(transform_fn, latent_params)
        (latent_grad,) = vjp_fn(grad)
        flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(
            latent_grad
        )  # type: ignore[no-untyped-call]

        # `flat_latent_grad` doubles as the shape/dtype template for the
        # updated flat latent parameters returned by the callback.
        (
            flat_latent_params,
            jax_lbfgsb_state,
        ) = jax.pure_callback(  # type: ignore[attr-defined]
            _update_pure,
            (flat_latent_grad, jax_lbfgsb_state),
            flat_latent_grad,
            value,
            jax_lbfgsb_state,
        )
        latent_params = unflatten_fn(flat_latent_params)
        return transform_fn(latent_params), latent_params, jax_lbfgsb_state

    return base.Optimizer(
        init=init_fn,
        params=params_fn,
        update=update_fn,
    )
|
289
|
-
|
290
|
-
|
291
|
-
# ------------------------------------------------------------------------------
|
292
|
-
# Helper functions.
|
293
|
-
# ------------------------------------------------------------------------------
|
294
|
-
|
295
|
-
|
296
|
-
def _is_density(leaf: Any) -> bool:
    """Return `True` if `leaf` is a `types.Density2DArray`.

    Used as an `is_leaf` predicate with `tree_util.tree_map`, so that density
    arrays are handled as single leaves instead of being traversed into.
    """
    return isinstance(leaf, types.Density2DArray)
|
299
|
-
|
300
|
-
|
301
|
-
def _to_numpy(params: PyTree) -> NDArray:
    """Ravel every leaf of `params` into a single rank-1 `float64` numpy array."""
    flat = flatten_util.ravel_pytree(params)[0]  # type: ignore[no-untyped-call]
    return onp.asarray(flat, dtype=onp.float64)
|
305
|
-
|
306
|
-
|
307
|
-
def _to_pytree(x_flat: NDArray, params: PyTree) -> PyTree:
    """Rebuild a pytree shaped like `params` from the flat vector `x_flat`.

    Args:
        x_flat: Rank-1 array holding all leaf values, ordered as by
            `flatten_util.ravel_pytree(params)`.
        params: Pytree whose structure and leaf shapes are replicated.

    Returns:
        A pytree with the structure of `params` and jax array leaves.
    """
    unravel = flatten_util.ravel_pytree(params)[1]  # type: ignore[no-untyped-call]
    return unravel(jnp.asarray(x_flat, dtype=float))
|
322
|
-
|
323
|
-
|
324
|
-
def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
    """Generates a bound vector for the `params`.

    The `bound` can be specified in various ways; it may be `None` or a scalar,
    which then applies to all arrays in `params`. It may be a pytree with
    structure matching that of `params`, where each leaf is either `None`, a
    scalar, or an array matching the shape of the corresponding leaf in `params`.

    The returned bound is a flat list suitable for use with `ScipyLbfgsbState`.

    Args:
        bound: The pytree of bounds.
        params: The pytree of parameters.

    Returns:
        The flat elementwise bound.

    Raises:
        ValueError: If the tree structures of `bound` and `params` differ, or if
            a non-scalar bound leaf does not match its parameter's shape.
    """

    def _is_custom(node: Any) -> bool:
        return isinstance(node, types.CUSTOM_TYPES)

    def _is_none(node: Any) -> bool:
        return node is None

    if bound is None or onp.isscalar(bound):
        # Broadcast a single global bound to every leaf (custom types included).
        uniform_bound = bound
        bound = tree_util.tree_map(
            lambda _: uniform_bound, params, is_leaf=_is_custom
        )

    bound_leaves, bound_treedef = tree_util.tree_flatten(bound, is_leaf=_is_none)
    params_leaves = tree_util.tree_leaves(params, is_leaf=_is_none)

    # `bound` should be a pytree of arrays or `None`, while `params` may
    # include custom pytree nodes. Convert the custom nodes into standard
    # types to facilitate validation that the tree structures match.
    params_treedef = tree_util.tree_structure(
        tree_util.tree_map(
            lambda _: 0.0,
            params,
            is_leaf=lambda node: node is None or _is_custom(node),
        )
    )
    if bound_treedef != params_treedef:  # type: ignore[operator]
        raise ValueError(
            f"Tree structure of `bound` and `params` must match, but got "
            f"{bound_treedef} and {params_treedef}, respectively."
        )

    bound_flat = []
    for b, p in zip(bound_leaves, params_leaves):
        if p is None:
            continue
        if b is None or onp.isscalar(b) or onp.shape(b) == ():
            bound_flat += [b] * onp.size(p)
            continue
        # `onp.shape` also works for leaves that are not ndarrays (e.g. Python
        # floats), which previously raised `AttributeError` on `.shape`.
        if onp.shape(b) != onp.shape(p):
            raise ValueError(
                f"`bound` must be `None`, a scalar, or have shape matching "
                f"`params`, but got shape {onp.shape(b)} when params has shape "
                f"{onp.shape(p)}."
            )
        bound_flat += onp.ravel(b).tolist()

    return bound_flat
|
386
|
-
|
387
|
-
|
388
|
-
def _example_state(params: PyTree, maxcor: int) -> PyTree:
    """Return an example state for the given `params` and `maxcor`.

    The result is a `(float_params, jax_lbfgsb_state)` tuple of placeholder
    arrays; it is used as the result shape/dtype template for the
    initialization callback, so only shapes and dtypes are meaningful.
    """
    num_vars = flatten_util.ravel_pytree(params)[0].size  # type: ignore[no-untyped-call]
    float_params = tree_util.tree_map(
        lambda leaf: jnp.asarray(leaf, dtype=float), params
    )
    example_jax_lbfgsb_state = {
        "x": jnp.zeros(num_vars, dtype=float),
        "_maxcor": jnp.zeros((), dtype=int),
        "_line_search_max_steps": jnp.zeros((), dtype=int),
        "_wa": jnp.ones(_wa_size(n=num_vars, maxcor=maxcor), dtype=float),
        "_iwa": jnp.ones(num_vars * 3, dtype=jnp.int32),  # Fortran int
        # Byte strings are stored as 59 integer character codes.
        "_task": jnp.zeros(59, dtype=int),
        "_csave": jnp.zeros(59, dtype=int),
        "_lsave": jnp.zeros(4, dtype=jnp.int32),  # Fortran int
        "_isave": jnp.zeros(44, dtype=jnp.int32),  # Fortran int
        "_dsave": jnp.zeros(29, dtype=float),
        "_lower_bound": jnp.zeros(num_vars, dtype=float),
        "_upper_bound": jnp.zeros(num_vars, dtype=float),
        "_bound_type": jnp.zeros(num_vars, dtype=int),
    }
    return float_params, example_jax_lbfgsb_state
|
409
|
-
|
410
|
-
|
411
|
-
# ------------------------------------------------------------------------------
|
412
|
-
# Wrapper for scipy's L-BFGS-B implementation.
|
413
|
-
# ------------------------------------------------------------------------------
|
414
|
-
|
415
|
-
|
416
|
-
@dataclasses.dataclass
class ScipyLbfgsbState:
    """Stores the state of a scipy L-BFGS-B minimization.

    This class enables optimization with a more functional style, giving the user
    control over the optimization loop. Example usage is as follows:

        value_fn = lambda x: onp.sum(x**2)
        grad_fn = lambda x: 2 * x

        x0 = onp.asarray([0.1, 0.2, 0.3])
        lb = [None, -1, 0.1]
        ub = [None, None, None]
        state = ScipyLbfgsbState.init(
            x0=x0,
            lower_bound=lb,
            upper_bound=ub,
            maxcor=20,
            line_search_max_steps=100,
        )

        for _ in range(10):
            value = value_fn(state.x)
            grad = grad_fn(state.x)
            state.update(grad, value)

    This example converges with `state.x` equal to `(0, 0, 0.1)` and value equal
    to `0.01`.

    Attributes:
        x: The current solution vector.
    """

    x: NDArray
    # Private attributes correspond to internal variables in the
    # `scipy.optimize.lbfgsb._minimize_lbfgsb` function.
    _maxcor: int
    _line_search_max_steps: int
    _wa: NDArray
    _iwa: NDArray
    _task: NDArray
    _csave: NDArray
    _lsave: NDArray
    _isave: NDArray
    _dsave: NDArray
    _lower_bound: NDArray
    _upper_bound: NDArray
    _bound_type: NDArray

    def __post_init__(self) -> None:
        """Validates the datatypes for all state attributes."""
        _validate_array_dtype(self.x, onp.float64)
        _validate_array_dtype(self._wa, onp.float64)
        _validate_array_dtype(self._iwa, FORTRAN_INT)
        _validate_array_dtype(self._task, "S60")
        _validate_array_dtype(self._csave, "S60")
        _validate_array_dtype(self._lsave, FORTRAN_INT)
        _validate_array_dtype(self._isave, FORTRAN_INT)
        _validate_array_dtype(self._dsave, onp.float64)
        _validate_array_dtype(self._lower_bound, onp.float64)
        _validate_array_dtype(self._upper_bound, onp.float64)
        _validate_array_dtype(self._bound_type, int)

    def to_jax(self) -> Dict[str, jnp.ndarray]:
        """Generates a dictionary of jax arrays defining the state.

        The `S60` byte-string attributes are encoded as integer arrays, since
        jax arrays cannot hold byte strings.
        """
        return dict(
            x=jnp.asarray(self.x),
            _maxcor=jnp.asarray(self._maxcor),
            _line_search_max_steps=jnp.asarray(self._line_search_max_steps),
            _wa=jnp.asarray(self._wa),
            _iwa=jnp.asarray(self._iwa),
            _task=_array_from_s60_str(self._task),
            _csave=_array_from_s60_str(self._csave),
            _lsave=jnp.asarray(self._lsave),
            _isave=jnp.asarray(self._isave),
            _dsave=jnp.asarray(self._dsave),
            _lower_bound=jnp.asarray(self._lower_bound),
            _upper_bound=jnp.asarray(self._upper_bound),
            _bound_type=jnp.asarray(self._bound_type),
        )

    @classmethod
    def from_jax(cls, state_dict: Dict[str, jnp.ndarray]) -> "ScipyLbfgsbState":
        """Converts a dictionary of jax arrays to a `ScipyLbfgsbState`."""
        # Copy first, so that the arrays later modified in place by `update` are
        # independent of the caller's buffers.
        state_dict = copy.deepcopy(state_dict)
        return cls(
            x=onp.asarray(state_dict["x"], dtype=onp.float64),
            _maxcor=int(state_dict["_maxcor"]),
            _line_search_max_steps=int(state_dict["_line_search_max_steps"]),
            _wa=onp.asarray(state_dict["_wa"], onp.float64),
            _iwa=onp.asarray(state_dict["_iwa"], dtype=FORTRAN_INT),
            _task=_s60_str_from_array(state_dict["_task"]),
            _csave=_s60_str_from_array(state_dict["_csave"]),
            _lsave=onp.asarray(state_dict["_lsave"], dtype=FORTRAN_INT),
            _isave=onp.asarray(state_dict["_isave"], dtype=FORTRAN_INT),
            _dsave=onp.asarray(state_dict["_dsave"], dtype=onp.float64),
            _lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
            _upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
            _bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
        )

    @classmethod
    def init(
        cls,
        x0: NDArray,
        lower_bound: ElementwiseBound,
        upper_bound: ElementwiseBound,
        maxcor: int,
        line_search_max_steps: int,
    ) -> "ScipyLbfgsbState":
        """Initializes the `ScipyLbfgsbState` for `x0`.

        Args:
            x0: Array giving the initial solution vector.
            lower_bound: Array giving the elementwise optional lower bound.
            upper_bound: Array giving the elementwise optional upper bound.
            maxcor: The maximum number of variable metric corrections used to define
                the limited memory matrix, in the L-BFGS-B scheme.
            line_search_max_steps: The maximum number of steps in the line search.

        Returns:
            The `ScipyLbfgsbState`.

        Raises:
            ValueError: If `x0` is not rank-1, if the bounds do not match the
                shape of `x0`, or if `maxcor` is not positive.
        """
        x0 = onp.asarray(x0)
        # Exactly rank-1 is required; 0-d inputs previously slipped through and
        # later failed with an unrelated `TypeError` in `_configure_bounds`.
        if x0.ndim != 1:
            raise ValueError(f"`x0` must be rank-1 but got shape {x0.shape}.")
        lower_bound = onp.asarray(lower_bound)
        upper_bound = onp.asarray(upper_bound)
        if x0.shape != lower_bound.shape or x0.shape != upper_bound.shape:
            raise ValueError(
                f"`x0`, `lower_bound`, and `upper_bound` must have matching "
                f"shape but got shapes {x0.shape}, {lower_bound.shape}, and "
                f"{upper_bound.shape}, respectively."
            )
        if maxcor < 1:
            raise ValueError(f"`maxcor` must be positive but got {maxcor}.")

        n = x0.size
        lower_bound_array, upper_bound_array, bound_type = _configure_bounds(
            lower_bound, upper_bound
        )
        task = onp.zeros(1, "S60")
        task[:] = TASK_START

        # See initialization of internal variables in the `lbfgsb._minimize_lbfgsb`
        # function.
        wa_size = _wa_size(n=n, maxcor=maxcor)
        state = cls(
            x=onp.array(x0, onp.float64),
            _maxcor=maxcor,
            _line_search_max_steps=line_search_max_steps,
            _wa=onp.zeros(wa_size, onp.float64),
            _iwa=onp.zeros(3 * n, FORTRAN_INT),
            _task=task,
            _csave=onp.zeros(1, "S60"),
            _lsave=onp.zeros(4, FORTRAN_INT),
            _isave=onp.zeros(44, FORTRAN_INT),
            _dsave=onp.zeros(29, onp.float64),
            _lower_bound=lower_bound_array,
            _upper_bound=upper_bound_array,
            _bound_type=bound_type,
        )
        # The initial state requires an update with zero value and gradient. This
        # is because the initial task is "START", which does not actually require
        # value and gradient evaluation.
        state.update(onp.zeros(x0.shape, onp.float64), onp.zeros((), onp.float64))
        return state

    def update(
        self,
        grad: NDArray,
        value: NDArray,
    ) -> None:
        """Performs an in-place update of the `ScipyLbfgsbState`.

        Args:
            grad: The function gradient for the current `x`; any array-like with
                the same shape as `x` is accepted.
            value: The scalar function value for the current `x`; Python and
                numpy scalars are accepted.

        Raises:
            ValueError: If `grad` does not match the shape of `x`, or if `value`
                is not a scalar.
        """
        # Coerce so that lists / Python floats have a `.shape` to validate.
        grad = onp.asarray(grad)
        value = onp.asarray(value)
        if grad.shape != self.x.shape:
            raise ValueError(
                f"`grad` must have the same shape as attribute `x`, but got shapes "
                f"{grad.shape} and {self.x.shape}, respectively."
            )
        if value.shape != ():
            raise ValueError(f"`value` must be a scalar but got shape {value.shape}.")

        # The `setulb` function will sometimes return with a task that does not
        # require a value and gradient evaluation. In this case we simply call it
        # again, advancing past such "dummy" steps.
        for _ in range(3):
            scipy_lbfgsb.setulb(
                m=self._maxcor,
                x=self.x,
                l=self._lower_bound,
                u=self._upper_bound,
                nbd=self._bound_type,
                f=value,
                g=grad,
                factr=UPDATE_FACTR,
                pgtol=UPDATE_PGTOL,
                wa=self._wa,
                iwa=self._iwa,
                task=self._task,
                iprint=UPDATE_IPRINT,
                csave=self._csave,
                lsave=self._lsave,
                isave=self._isave,
                dsave=self._dsave,
                maxls=self._line_search_max_steps,
            )
            task_str = self._task.tobytes()
            if task_str.startswith(TASK_FG):
                break
|
626
|
-
|
627
|
-
|
628
|
-
def _wa_size(n: int, maxcor: int) -> int:
    """Return the length of the `wa` workspace array of the L-BFGS-B state.

    Args:
        n: Number of optimization variables.
        maxcor: Number of stored variable metric corrections.
    """
    per_variable = 2 * maxcor + 5
    per_correction = 11 * maxcor + 8
    return n * per_variable + maxcor * per_correction
|
631
|
-
|
632
|
-
|
633
|
-
def _validate_array_dtype(x: NDArray, dtype: Union[type, str]) -> None:
    """Raise `ValueError` unless `x` is a numpy array whose dtype equals `dtype`."""
    if isinstance(x, onp.ndarray):
        if x.dtype == dtype:
            return
        raise ValueError(f"`x` must have dtype {dtype} but got {x.dtype}")
    raise ValueError(f"`x` must be an `onp.ndarray` but got {type(x)}")
|
639
|
-
|
640
|
-
|
641
|
-
def _configure_bounds(
    lower_bound: ElementwiseBound,
    upper_bound: ElementwiseBound,
) -> Tuple[NDArray, NDArray, NDArray]:
    """Convert optional elementwise bounds into the arrays used by L-BFGS-B.

    Returns:
        A tuple `(lower, upper, bound_type)`. `lower` and `upper` are float64
        arrays in which missing (`None`) bounds are replaced by `0.0`;
        `bound_type` holds the per-element code looked up in `BOUNDS_MAP`.
    """

    def _fill_missing(values: ElementwiseBound) -> NDArray:
        return onp.asarray([0.0 if v is None else v for v in values], onp.float64)

    codes = [
        BOUNDS_MAP[lo is None, hi is None]
        for lo, hi in zip(lower_bound, upper_bound)
    ]
    return _fill_missing(lower_bound), _fill_missing(upper_bound), onp.asarray(codes)
|
657
|
-
|
658
|
-
|
659
|
-
def _array_from_s60_str(s60_str: NDArray) -> jnp.ndarray:
    """Encode a shape-`(1,)` numpy `S60` string as a jax array of byte codes.

    The codes are padded with spaces (byte value 32) to at least 59 entries.
    """
    assert s60_str.shape == (1,)
    codes = list(s60_str[0])
    codes += [32] * (59 - len(codes))
    return jnp.asarray(codes, dtype=int)
|
665
|
-
|
666
|
-
|
667
|
-
def _s60_str_from_array(array: jnp.ndarray) -> NDArray:
    """Decode an array of byte codes into a shape-`(1,)` numpy `S60` string."""
    pieces = [int(code).to_bytes(length=1, byteorder="big") for code in array]
    return onp.asarray([b"".join(pieces)], dtype="S60")
|
@@ -1,21 +0,0 @@
|
|
1
|
-
MIT License
|
2
|
-
|
3
|
-
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
-
|
5
|
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
-
of this software and associated documentation files (the "Software"), to deal
|
7
|
-
in the Software without restriction, including without limitation the rights
|
8
|
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
-
copies of the Software, and to permit persons to whom the Software is
|
10
|
-
furnished to do so, subject to the following conditions:
|
11
|
-
|
12
|
-
The above copyright notice and this permission notice shall be included in all
|
13
|
-
copies or substantial portions of the Software.
|
14
|
-
|
15
|
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
-
SOFTWARE.
|