invrs-opt 0.3.2__py3-none-any.whl → 0.10.3__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,670 +0,0 @@
1
- """Defines a jax-style wrapper for scipy's L-BFGS-B algorithm.
2
-
3
- Copyright (c) 2023 The INVRS-IO authors.
4
- """
5
-
6
- import copy
7
- import dataclasses
8
- from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
9
-
10
- import jax
11
- import jax.numpy as jnp
12
- import numpy as onp
13
- from jax import flatten_util, tree_util
14
- from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
15
- _lbfgsb as scipy_lbfgsb,
16
- )
17
- from totypes import types
18
-
19
- from invrs_opt import base
20
- from invrs_opt.lbfgsb import transform
21
-
22
# Type aliases shared by the optimizer and the scipy state wrapper.
NDArray = onp.ndarray[Any, Any]
PyTree = Any
ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
JaxLbfgsbDict = Dict[str, jnp.ndarray]
LbfgsbState = Tuple[PyTree, PyTree, JaxLbfgsbDict]

# Task-message prefixes understood by scipy's `setulb` routine. An
# optimization begins with `START`; a task beginning with `FG` is the signal
# that a fresh value and gradient must be supplied.
TASK_START: bytes = b"START"
TASK_FG: bytes = b"FG"

# Arguments forwarded to `setulb` on every update: printing is disabled and
# both tolerances are zero (presumably so that the caller's loop, not scipy's
# convergence tests, decides when to stop).
UPDATE_IPRINT: int = -1
UPDATE_PGTOL: float = 0.0
UPDATE_FACTR: float = 0.0

# Upper limit and default for `maxcor`, and the default line-search budget.
MAXCOR_MAX_VALUE: int = 100
MAXCOR_DEFAULT: int = 20
LINE_SEARCH_MAX_STEPS_DEFAULT: int = 100

# Key: `(lower_bound is None, upper_bound is None)`; value: the integer
# bound-type code passed to `setulb` as `nbd`.
BOUNDS_MAP: Dict[Tuple[bool, bool], int] = {
    (True, True): 0,  # No bounds at all.
    (False, True): 1,  # Lower bound only.
    (False, False): 2,  # Lower and upper bound.
    (True, False): 3,  # Upper bound only.
}

# Integer dtype used by the compiled L-BFGS-B routine for its integer arrays.
FORTRAN_INT = scipy_lbfgsb.types.intvar.dtype
52
-
53
-
54
def lbfgsb(
    maxcor: int = MAXCOR_DEFAULT,
    line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT,
) -> base.Optimizer:
    """Return an optimizer implementing the standard L-BFGS-B algorithm.

    This optimizer wraps scipy's implementation of the algorithm, and provides
    a jax-style API to the scheme. The optimizer works with custom types such
    as the `BoundedArray` to constrain the optimization variable.

    Example usage is as follows:

        def fn(x):
            leaves_sum_sq = [jnp.sum(y)**2 for y in tree_util.tree_leaves(x)]
            return jnp.sum(jnp.asarray(leaves_sum_sq))

        x0 = {
            "a": jnp.ones((3,)),
            "b": BoundedArray(
                value=-jnp.ones((2, 5)),
                lower_bound=-5,
                upper_bound=5,
            ),
        }
        opt = lbfgsb(maxcor=20, line_search_max_steps=100)
        state = opt.init(x0)
        for _ in range(10):
            x = opt.params(state)
            value, grad = jax.value_and_grad(fn)(x)
            state = opt.update(grad=grad, value=value, params=x, state=state)

    Note that `opt.update` takes keyword-only arguments.

    While the algorithm can work with pytrees of jax arrays, numpy arrays can
    also be used. Thus, e.g. the optimizer can directly be used with autograd.

    Args:
        maxcor: The maximum number of variable metric corrections used to define
            the limited memory matrix, in the L-BFGS-B scheme.
        line_search_max_steps: The maximum number of steps in the line search.

    Returns:
        The `base.Optimizer`.
    """
    # The standard algorithm is the latent-parameter optimizer with identity
    # transforms: latent parameters and returned parameters coincide.
    return transformed_lbfgsb(
        maxcor=maxcor,
        line_search_max_steps=line_search_max_steps,
        transform_fn=lambda x: x,
        initialize_latent_fn=lambda x: x,
    )
102
-
103
-
104
def density_lbfgsb(
    beta: float,
    maxcor: int = MAXCOR_DEFAULT,
    line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT,
) -> base.Optimizer:
    """Return an L-BFGS-B optimizer with additional transforms for density arrays.

    Parameters that are of type `Density2DArray` are represented as latent
    parameters that are transformed (in the case where lower and upper bounds
    are `(-1, 1)`) by,

        transformed = tanh(beta * conv(density.array, gaussian_kernel)) / tanh(beta)

    where the kernel has a full-width at half-maximum determined by the minimum
    width and spacing parameters of the `Density2DArray`. Where the bounds differ
    from `(-1, 1)`, the density is scaled before the transform is applied, and
    then unscaled afterwards. The density is also symmetrized before filtering,
    and fixed pixels are applied to the transformed result.

    Args:
        beta: Determines the steepness of the thresholding.
        maxcor: The maximum number of variable metric corrections used to define
            the limited memory matrix, in the L-BFGS-B scheme.
        line_search_max_steps: The maximum number of steps in the line search.

    Returns:
        The `base.Optimizer`.
    """

    def transform_fn(tree: PyTree) -> PyTree:
        # Density arrays are treated as leaves so they are transformed as a
        # whole; all other leaves pass through unchanged.
        return tree_util.tree_map(
            lambda x: transform_density(x) if _is_density(x) else x,
            tree,
            is_leaf=_is_density,
        )

    def initialize_latent_fn(tree: PyTree) -> PyTree:
        return tree_util.tree_map(
            lambda x: initialize_latent_density(x) if _is_density(x) else x,
            tree,
            is_leaf=_is_density,
        )

    def transform_density(density: types.Density2DArray) -> types.Density2DArray:
        transformed = types.symmetrize_density(density)
        transformed = transform.density_gaussian_filter_and_tanh(transformed, beta=beta)
        # Scale to ensure that the full valid range of the density array is reachable.
        mid_value = (density.lower_bound + density.upper_bound) / 2
        transformed = tree_util.tree_map(
            lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed
        )
        return transform.apply_fixed_pixels(transformed)

    def initialize_latent_density(
        density: types.Density2DArray,
    ) -> types.Density2DArray:
        # Approximate inverse of the tanh thresholding (the filter is not
        # inverted). Clipping keeps `arctanh` finite.
        array = transform.normalized_array_from_density(density)
        array = jnp.clip(array, -1, 1)
        array *= jnp.tanh(beta)
        latent_array = jnp.arctanh(array) / beta
        latent_array = transform.rescale_array_for_density(latent_array, density)
        return dataclasses.replace(density, array=latent_array)

    return transformed_lbfgsb(
        maxcor=maxcor,
        line_search_max_steps=line_search_max_steps,
        transform_fn=transform_fn,
        initialize_latent_fn=initialize_latent_fn,
    )
170
-
171
-
172
def transformed_lbfgsb(
    maxcor: int,
    line_search_max_steps: int,
    transform_fn: Callable[[PyTree], PyTree],
    initialize_latent_fn: Callable[[PyTree], PyTree],
) -> base.Optimizer:
    """Construct a latent-parameter L-BFGS-B optimizer.

    The optimized parameters are termed latent parameters, from which the
    actual parameters returned by the optimizer are obtained using the
    `transform_fn`. In the simple case where this is just `lambda x: x` (i.e.
    the identity), this is equivalent to the standard L-BFGS-B algorithm.

    Args:
        maxcor: The maximum number of variable metric corrections used to define
            the limited memory matrix, in the L-BFGS-B scheme.
        line_search_max_steps: The maximum number of steps in the line search.
        transform_fn: Function which transforms the internal latent parameters to
            the parameters returned by the optimizer.
        initialize_latent_fn: Function which computes the initial latent parameters
            given the initial parameters.

    Returns:
        The `base.Optimizer`.

    Raises:
        ValueError: If `maxcor` is not an integer between 1 and
            `MAXCOR_MAX_VALUE` (inclusive), or if `line_search_max_steps` is not
            a positive integer.
    """
    # `maxcor == MAXCOR_MAX_VALUE` is accepted, so the message must say
    # "less than or equal to".
    if not isinstance(maxcor, int) or maxcor < 1 or maxcor > MAXCOR_MAX_VALUE:
        raise ValueError(
            f"`maxcor` must be greater than 0 and less than or equal to "
            f"{MAXCOR_MAX_VALUE}, but got {maxcor}"
        )

    if not isinstance(line_search_max_steps, int) or line_search_max_steps < 1:
        raise ValueError(
            f"`line_search_max_steps` must be greater than 0 but got "
            f"{line_search_max_steps}"
        )

    def init_fn(params: PyTree) -> LbfgsbState:
        """Initializes the optimization state.

        The returned state is `(params, latent_params, jax_lbfgsb_state)`. The
        scipy state is built on the host through `jax.pure_callback`.
        """

        def _init_pure(params: PyTree) -> Tuple[PyTree, JaxLbfgsbDict]:
            lower_bound = types.extract_lower_bound(params)
            upper_bound = types.extract_upper_bound(params)
            scipy_lbfgsb_state = ScipyLbfgsbState.init(
                x0=_to_numpy(params),
                lower_bound=_bound_for_params(lower_bound, params),
                upper_bound=_bound_for_params(upper_bound, params),
                maxcor=maxcor,
                line_search_max_steps=line_search_max_steps,
            )
            latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
            return latent_params, scipy_lbfgsb_state.to_jax()

        (
            latent_params,
            jax_lbfgsb_state,
        ) = jax.pure_callback(  # type: ignore[attr-defined]
            _init_pure,
            _example_state(params, maxcor),
            initialize_latent_fn(params),
        )
        return transform_fn(latent_params), latent_params, jax_lbfgsb_state

    def params_fn(state: LbfgsbState) -> PyTree:
        """Returns the parameters for the given `state`."""
        params, _, _ = state
        return params

    def update_fn(
        *,
        grad: PyTree,
        value: float,
        params: PyTree,
        state: LbfgsbState,
    ) -> LbfgsbState:
        """Updates the state.

        The gradient with respect to the returned parameters is pulled back to
        the latent parameters through a VJP of `transform_fn`, flattened, and
        passed together with `value` to the host-side scipy update.
        """
        # The returned parameters are recomputed from the latent parameters.
        del params

        def _update_pure(
            flat_latent_grad: PyTree,
            value: jnp.ndarray,
            jax_lbfgsb_state: JaxLbfgsbDict,
        ) -> Tuple[PyTree, JaxLbfgsbDict]:
            assert onp.size(value) == 1
            scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
            scipy_lbfgsb_state.update(
                grad=onp.asarray(flat_latent_grad, dtype=onp.float64),
                value=onp.asarray(value, dtype=onp.float64),
            )
            flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
            return flat_latent_params, scipy_lbfgsb_state.to_jax()

        _, latent_params, jax_lbfgsb_state = state
        _, vjp_fn = jax.vjp(transform_fn, latent_params)
        (latent_grad,) = vjp_fn(grad)
        flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(latent_grad)

        (
            flat_latent_params,
            jax_lbfgsb_state,
        ) = jax.pure_callback(  # type: ignore[attr-defined]
            _update_pure,
            (flat_latent_grad, jax_lbfgsb_state),
            flat_latent_grad,
            value,
            jax_lbfgsb_state,
        )
        latent_params = unflatten_fn(flat_latent_params)
        return transform_fn(latent_params), latent_params, jax_lbfgsb_state

    return base.Optimizer(
        init=init_fn,
        params=params_fn,
        update=update_fn,
    )
287
-
288
-
289
- # ------------------------------------------------------------------------------
290
- # Helper functions.
291
- # ------------------------------------------------------------------------------
292
-
293
-
294
def _is_density(leaf: Any) -> Any:
    """Report whether `leaf` is an instance of `types.Density2DArray`."""
    density_cls = types.Density2DArray
    return isinstance(leaf, density_cls)
297
-
298
-
299
def _to_numpy(params: PyTree) -> NDArray:
    """Ravel all leaves of `params` into one rank-1 `float64` numpy array."""
    raveled = flatten_util.ravel_pytree(params)[0]  # type: ignore[no-untyped-call]
    return onp.asarray(raveled, dtype=onp.float64)
303
-
304
-
305
def _to_pytree(x_flat: NDArray, params: PyTree) -> PyTree:
    """Rebuild a pytree shaped like `params` from the flat array `x_flat`.

    The leaves of the result are jax arrays.

    Args:
        x_flat: Rank-1 numpy array holding the flattened values.
        params: Pytree whose structure (and leaf shapes) the result mirrors.

    Returns:
        A pytree with the structure of `params` and jax array leaves.
    """
    unravel = flatten_util.ravel_pytree(params)[1]  # type: ignore[no-untyped-call]
    flat_values = jnp.asarray(x_flat, dtype=float)
    return unravel(flat_values)
320
-
321
-
322
def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
    """Generates a flat, elementwise bound for the `params`.

    The `bound` can be specified in various ways; it may be `None` or a scalar,
    which then applies to all arrays in `params`. It may be a pytree with
    structure matching that of `params`, where each leaf is either `None`, a
    scalar, or an array matching the shape of the corresponding leaf in `params`.

    The returned bound is a flat list suitable for use with `ScipyLbfgsbState`.

    Args:
        bound: The pytree of bounds.
        params: The pytree of parameters.

    Returns:
        The flat elementwise bound: a list with one entry (`None` or a scalar-like
        value) per element of the flattened `params`.

    Raises:
        ValueError: If the tree structures of `bound` and `params` do not match,
            or if an array-valued bound leaf has a shape different from its
            corresponding parameter leaf.
    """

    if bound is None or onp.isscalar(bound):
        bound = tree_util.tree_map(
            lambda _: bound,
            params,
            is_leaf=lambda x: isinstance(x, types.CUSTOM_TYPES),
        )

    bound_leaves, bound_treedef = tree_util.tree_flatten(
        bound, is_leaf=lambda x: x is None
    )
    params_leaves = tree_util.tree_leaves(params, is_leaf=lambda x: x is None)

    # `bound` should be a pytree of arrays or `None`, while `params` may
    # include custom pytree nodes. Convert the custom nodes into standard
    # types to facilitate validation that the tree structures match.
    params_treedef = tree_util.tree_structure(
        tree_util.tree_map(
            lambda x: 0.0,
            tree=params,
            is_leaf=lambda x: x is None or isinstance(x, types.CUSTOM_TYPES),
        )
    )
    if bound_treedef != params_treedef:  # type: ignore[operator]
        raise ValueError(
            f"Tree structure of `bound` and `params` must match, but got "
            f"{bound_treedef} and {params_treedef}, respectively."
        )

    bound_flat = []
    for b, p in zip(bound_leaves, params_leaves):
        if p is None:
            continue
        if b is None or onp.isscalar(b) or onp.shape(b) == ():
            # A scalar (or missing) bound is broadcast to every element of `p`.
            bound_flat += [b] * onp.size(p)
            continue
        # `onp.shape` / `onp.ravel` also accept leaves lacking `.shape` /
        # `.flatten` (e.g. Python floats or nested lists), so mismatches raise
        # the intended `ValueError` rather than an `AttributeError`.
        if onp.shape(b) != onp.shape(p):
            raise ValueError(
                f"`bound` must be `None`, a scalar, or have shape matching "
                f"`params`, but got shape {onp.shape(b)} when params has shape "
                f"{onp.shape(p)}."
            )
        bound_flat += onp.ravel(b).tolist()

    return bound_flat
384
-
385
-
386
def _example_state(params: PyTree, maxcor: int) -> PyTree:
    """Build example outputs whose shapes and dtypes describe the init result.

    Returns a tuple `(float_params, state_dict)`: `params` with every leaf cast
    to float, and a dict of placeholder jax arrays sized for `maxcor` and the
    total number of parameter elements.
    """
    num_elements = flatten_util.ravel_pytree(params)[0].size  # type: ignore[no-untyped-call]
    example_params = tree_util.tree_map(
        lambda leaf: jnp.asarray(leaf, dtype=float), params
    )
    # Integer work arrays of the Fortran routine are described as int32 here.
    example_state = {
        "x": jnp.zeros(num_elements, dtype=float),
        "_maxcor": jnp.zeros((), dtype=int),
        "_line_search_max_steps": jnp.zeros((), dtype=int),
        "_wa": jnp.ones(_wa_size(n=num_elements, maxcor=maxcor), dtype=float),
        "_iwa": jnp.ones(num_elements * 3, dtype=jnp.int32),
        "_task": jnp.zeros(59, dtype=int),
        "_csave": jnp.zeros(59, dtype=int),
        "_lsave": jnp.zeros(4, dtype=jnp.int32),
        "_isave": jnp.zeros(44, dtype=jnp.int32),
        "_dsave": jnp.zeros(29, dtype=float),
        "_lower_bound": jnp.zeros(num_elements, dtype=float),
        "_upper_bound": jnp.zeros(num_elements, dtype=float),
        "_bound_type": jnp.zeros(num_elements, dtype=int),
    }
    return example_params, example_state
407
-
408
-
409
- # ------------------------------------------------------------------------------
410
- # Wrapper for scipy's L-BFGS-B implementation.
411
- # ------------------------------------------------------------------------------
412
-
413
-
414
@dataclasses.dataclass
class ScipyLbfgsbState:
    """Stores the state of a scipy L-BFGS-B minimization.

    This class enables optimization with a more functional style, giving the user
    control over the optimization loop. Example usage is as follows:

        value_fn = lambda x: onp.sum(x**2)
        grad_fn = lambda x: 2 * x

        x0 = onp.asarray([0.1, 0.2, 0.3])
        lb = [None, -1, 0.1]
        ub = [None, None, None]
        state = ScipyLbfgsbState.init(
            x0=x0,
            lower_bound=lb,
            upper_bound=ub,
            maxcor=20,
            line_search_max_steps=100,
        )

        for _ in range(10):
            value = value_fn(state.x)
            grad = grad_fn(state.x)
            state.update(grad, value)

    This example converges with `state.x` equal to `(0, 0, 0.1)` and value equal
    to `0.01`.

    Attributes:
        x: The current solution vector.
    """

    x: NDArray
    # Private attributes correspond to internal variables in the `scipy.optimize.
    # lbfgsb._minimize_lbfgsb` function.
    _maxcor: int
    _line_search_max_steps: int
    _wa: NDArray
    _iwa: NDArray
    _task: NDArray
    _csave: NDArray
    _lsave: NDArray
    _isave: NDArray
    _dsave: NDArray
    _lower_bound: NDArray
    _upper_bound: NDArray
    _bound_type: NDArray

    def __post_init__(self) -> None:
        """Validates the datatypes for all state attributes."""
        _validate_array_dtype(self.x, onp.float64)
        _validate_array_dtype(self._wa, onp.float64)
        _validate_array_dtype(self._iwa, FORTRAN_INT)
        _validate_array_dtype(self._task, "S60")
        _validate_array_dtype(self._csave, "S60")
        _validate_array_dtype(self._lsave, FORTRAN_INT)
        _validate_array_dtype(self._isave, FORTRAN_INT)
        _validate_array_dtype(self._dsave, onp.float64)
        _validate_array_dtype(self._lower_bound, onp.float64)
        _validate_array_dtype(self._upper_bound, onp.float64)
        _validate_array_dtype(self._bound_type, int)

    def to_jax(self) -> Dict[str, jnp.ndarray]:
        """Generates a dictionary of jax arrays defining the state."""
        return dict(
            x=jnp.asarray(self.x),
            _maxcor=jnp.asarray(self._maxcor),
            _line_search_max_steps=jnp.asarray(self._line_search_max_steps),
            _wa=jnp.asarray(self._wa),
            _iwa=jnp.asarray(self._iwa),
            _task=_array_from_s60_str(self._task),
            _csave=_array_from_s60_str(self._csave),
            _lsave=jnp.asarray(self._lsave),
            _isave=jnp.asarray(self._isave),
            _dsave=jnp.asarray(self._dsave),
            _lower_bound=jnp.asarray(self._lower_bound),
            _upper_bound=jnp.asarray(self._upper_bound),
            _bound_type=jnp.asarray(self._bound_type),
        )

    @classmethod
    def from_jax(cls, state_dict: Dict[str, jnp.ndarray]) -> "ScipyLbfgsbState":
        """Converts a dictionary of jax arrays to a `ScipyLbfgsbState`."""
        state_dict = copy.deepcopy(state_dict)
        return ScipyLbfgsbState(
            x=onp.asarray(state_dict["x"], dtype=onp.float64),
            _maxcor=int(state_dict["_maxcor"]),
            _line_search_max_steps=int(state_dict["_line_search_max_steps"]),
            _wa=onp.asarray(state_dict["_wa"], onp.float64),
            _iwa=onp.asarray(state_dict["_iwa"], dtype=FORTRAN_INT),
            _task=_s60_str_from_array(state_dict["_task"]),
            _csave=_s60_str_from_array(state_dict["_csave"]),
            _lsave=onp.asarray(state_dict["_lsave"], dtype=FORTRAN_INT),
            _isave=onp.asarray(state_dict["_isave"], dtype=FORTRAN_INT),
            _dsave=onp.asarray(state_dict["_dsave"], dtype=onp.float64),
            _lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
            _upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
            _bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
        )

    @classmethod
    def init(
        cls,
        x0: NDArray,
        lower_bound: ElementwiseBound,
        upper_bound: ElementwiseBound,
        maxcor: int,
        line_search_max_steps: int,
    ) -> "ScipyLbfgsbState":
        """Initializes the `ScipyLbfgsbState` for `x0`.

        Args:
            x0: Array giving the initial solution vector.
            lower_bound: Array giving the elementwise optional lower bound.
            upper_bound: Array giving the elementwise optional upper bound.
            maxcor: The maximum number of variable metric corrections used to define
                the limited memory matrix, in the L-BFGS-B scheme.
            line_search_max_steps: The maximum number of steps in the line search.

        Returns:
            The `ScipyLbfgsbState`.

        Raises:
            ValueError: If `x0` is not rank-1, if the bounds do not match the
                shape of `x0`, or if `maxcor` is not positive.
        """
        x0 = onp.asarray(x0)
        # Rank-0 input must also be rejected: it would otherwise fail later
        # with an obscure error when the bounds are iterated.
        if x0.ndim != 1:
            raise ValueError(f"`x0` must be rank-1 but got shape {x0.shape}.")
        lower_bound = onp.asarray(lower_bound)
        upper_bound = onp.asarray(upper_bound)
        if x0.shape != lower_bound.shape or x0.shape != upper_bound.shape:
            raise ValueError(
                f"`x0`, `lower_bound`, and `upper_bound` must have matching "
                f"shape but got shapes {x0.shape}, {lower_bound.shape}, and "
                f"{upper_bound.shape}, respectively."
            )
        if maxcor < 1:
            raise ValueError(f"`maxcor` must be positive but got {maxcor}.")

        n = x0.size
        lower_bound_array, upper_bound_array, bound_type = _configure_bounds(
            lower_bound, upper_bound
        )
        task = onp.zeros(1, "S60")
        task[:] = TASK_START

        # See initialization of internal variables in the `lbfgsb._minimize_lbfgsb`
        # function.
        wa_size = _wa_size(n=n, maxcor=maxcor)
        state = ScipyLbfgsbState(
            x=onp.array(x0, onp.float64),
            _maxcor=maxcor,
            _line_search_max_steps=line_search_max_steps,
            _wa=onp.zeros(wa_size, onp.float64),
            _iwa=onp.zeros(3 * n, FORTRAN_INT),
            _task=task,
            _csave=onp.zeros(1, "S60"),
            _lsave=onp.zeros(4, FORTRAN_INT),
            _isave=onp.zeros(44, FORTRAN_INT),
            _dsave=onp.zeros(29, onp.float64),
            _lower_bound=lower_bound_array,
            _upper_bound=upper_bound_array,
            _bound_type=bound_type,
        )
        # The initial state requires an update with zero value and gradient. This
        # is because the initial task is "START", which does not actually require
        # value and gradient evaluation.
        state.update(onp.zeros(x0.shape, onp.float64), onp.zeros((), onp.float64))
        return state

    def update(
        self,
        grad: NDArray,
        value: NDArray,
    ) -> None:
        """Performs an in-place update of the `ScipyLbfgsbState`.

        Args:
            grad: The function gradient for the current `x`.
            value: The scalar function value for the current `x`.

        Raises:
            ValueError: If `grad` does not match the shape of `x`, or if `value`
                is not a scalar.
        """
        if grad.shape != self.x.shape:
            raise ValueError(
                f"`grad` must have the same shape as attribute `x`, but got shapes "
                f"{grad.shape} and {self.x.shape}, respectively."
            )
        if value.shape != ():
            raise ValueError(f"`value` must be a scalar but got shape {value.shape}.")

        # The `setulb` function will sometimes return with a task that does not
        # require a value and gradient evaluation. In this case we simply call it
        # again, advancing past such "dummy" steps.
        for _ in range(3):
            scipy_lbfgsb.setulb(
                m=self._maxcor,
                x=self.x,
                l=self._lower_bound,
                u=self._upper_bound,
                nbd=self._bound_type,
                f=value,
                g=grad,
                factr=UPDATE_FACTR,
                pgtol=UPDATE_PGTOL,
                wa=self._wa,
                iwa=self._iwa,
                task=self._task,
                iprint=UPDATE_IPRINT,
                csave=self._csave,
                lsave=self._lsave,
                isave=self._isave,
                dsave=self._dsave,
                maxls=self._line_search_max_steps,
            )
            task_str = self._task.tobytes()
            if task_str.startswith(TASK_FG):
                break
624
-
625
-
626
- def _wa_size(n: int, maxcor: int) -> int:
627
- """Return the size of the `wa` attribute of lbfgsb state."""
628
- return 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
629
-
630
-
631
def _validate_array_dtype(x: NDArray, dtype: Union[type, str]) -> None:
    """Raise `ValueError` unless `x` is a numpy array whose dtype equals `dtype`."""
    if isinstance(x, onp.ndarray):
        if x.dtype == dtype:
            return
        raise ValueError(f"`x` must have dtype {dtype} but got {x.dtype}")
    raise ValueError(f"`x` must be an `onp.ndarray` but got {type(x)}")
637
-
638
-
639
def _configure_bounds(
    lower_bound: ElementwiseBound,
    upper_bound: ElementwiseBound,
) -> Tuple[NDArray, NDArray, NDArray]:
    """Build the bound arrays and bound-type codes for an L-BFGS-B optimization.

    Returns `(lower, upper, bound_type)`. `lower` and `upper` are float64
    arrays in which `None` entries are replaced by `0.0`. `bound_type` holds,
    for each element, the `BOUNDS_MAP` code for which of its bounds are `None`.
    """
    codes = [
        BOUNDS_MAP[(lo is None, hi is None)]
        for lo, hi in zip(lower_bound, upper_bound)
    ]

    def none_to_zero(values: ElementwiseBound) -> list:
        return [0.0 if v is None else v for v in values]

    lower_filled = none_to_zero(lower_bound)
    upper_filled = none_to_zero(upper_bound)
    return (
        onp.asarray(lower_filled, onp.float64),
        onp.asarray(upper_filled, onp.float64),
        onp.asarray(codes),
    )
655
-
656
-
657
def _array_from_s60_str(s60_str: NDArray) -> jnp.ndarray:
    """Encode a length-1 numpy `S60` array as a jax integer array of byte codes.

    Strings shorter than 59 bytes are right-padded with ASCII spaces (32).
    """
    assert s60_str.shape == (1,)
    byte_codes = list(s60_str[0])
    pad = [32] * (59 - len(byte_codes))
    return jnp.asarray(byte_codes + pad, dtype=int)
663
-
664
-
665
def _s60_str_from_array(array: jnp.ndarray) -> NDArray:
    """Decode an array of byte codes into a length-1 numpy `S60` array."""
    single_bytes = [
        int(code).to_bytes(length=1, byteorder="big") for code in array
    ]
    joined = b"".join(single_bytes)
    return onp.asarray([joined], dtype="S60")
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2023 The INVRS-IO authors.
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.