invrs-opt 0.4.0__py3-none-any.whl → 0.10.3__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,672 +0,0 @@
1
- """Defines a jax-style wrapper for scipy's L-BFGS-B algorithm.
2
-
3
- Copyright (c) 2023 The INVRS-IO authors.
4
- """
5
-
6
- import copy
7
- import dataclasses
8
- from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
9
-
10
- import jax
11
- import jax.numpy as jnp
12
- import numpy as onp
13
- from jax import flatten_util, tree_util
14
- from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
15
- _lbfgsb as scipy_lbfgsb,
16
- )
17
- from totypes import types
18
-
19
- from invrs_opt import base
20
- from invrs_opt.lbfgsb import transform
21
-
22
# Type aliases shared by the functions in this module.
NDArray = onp.ndarray[Any, Any]
PyTree = Any
# A per-element bound: an array, or a sequence whose entries are floats or
# `None` (where `None` means "no bound" for that element).
ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
# The scipy L-BFGS-B state serialized as a dictionary of jax arrays.
JaxLbfgsbDict = Dict[str, jnp.ndarray]
# Optimizer state: `(params, latent_params, jax_lbfgsb_state)`.
LbfgsbState = Tuple[PyTree, PyTree, JaxLbfgsbDict]


# Prefixes of the task messages exchanged with the Fortran `setulb` routine.
TASK_START = b"START"
TASK_FG = b"FG"

# Fixed `setulb` arguments used on every state update. NOTE(review): zero
# tolerances and a negative print level presumably disable setulb's own
# stopping tests and printing, leaving loop control to the caller — confirm.
UPDATE_IPRINT = -1
UPDATE_PGTOL = 0.0
UPDATE_FACTR = 0.0

# Limits and defaults for the user-facing L-BFGS-B configuration.
MAXCOR_MAX_VALUE = 100
MAXCOR_DEFAULT = 20
LINE_SEARCH_MAX_STEPS_DEFAULT = 100

# Maps `(lower_bound is None, upper_bound is None)` to the integer bound code
# passed to `setulb` as `nbd`.
BOUNDS_MAP: Dict[Tuple[bool, bool], int] = {
    (True, True): 0,  # No bounds at all.
    (False, True): 1,  # Lower bound only.
    (False, False): 2,  # Both lower and upper bounds.
    (True, False): 3,  # Upper bound only.
}
50
-
51
# Integer dtype used for the integer arrays handed to `setulb` (`iwa`, `lsave`,
# `isave`), read from scipy's compiled `_lbfgsb` extension module.
# NOTE(review): this relies on a private scipy attribute; confirm it exists in
# the pinned scipy version.
FORTRAN_INT = scipy_lbfgsb.types.intvar.dtype
52
-
53
-
54
def lbfgsb(
    maxcor: int = MAXCOR_DEFAULT,
    line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT,
) -> base.Optimizer:
    """Return an optimizer implementing the standard L-BFGS-B algorithm.

    This optimizer wraps scipy's implementation of the algorithm, and provides
    a jax-style API to the scheme. The optimizer works with custom types such
    as the `BoundedArray` to constrain the optimization variable.

    Example usage is as follows:

        def fn(x):
            leaves_sum_sq = [jnp.sum(y)**2 for y in tree_util.tree_leaves(x)]
            return jnp.sum(jnp.asarray(leaves_sum_sq))

        x0 = {
            "a": jnp.ones((3,)),
            "b": BoundedArray(
                value=-jnp.ones((2, 5)),
                lower_bound=-5,
                upper_bound=5,
            ),
        }
        opt = lbfgsb(maxcor=20, line_search_max_steps=100)
        state = opt.init(x0)
        for _ in range(10):
            x = opt.params(state)
            value, grad = jax.value_and_grad(fn)(x)
            state = opt.update(grad=grad, value=value, params=x, state=state)

    Note that `update` takes keyword-only arguments.

    While the algorithm can work with pytrees of jax arrays, numpy arrays can
    also be used. Thus, e.g. the optimizer can directly be used with autograd.

    Args:
        maxcor: The maximum number of variable metric corrections used to define
            the limited memory matrix, in the L-BFGS-B scheme.
        line_search_max_steps: The maximum number of steps in the line search.

    Returns:
        The `base.Optimizer`.

    Raises:
        ValueError: If `maxcor` or `line_search_max_steps` is invalid; the
            validation is performed by `transformed_lbfgsb`.
    """
    # With identity transforms the latent parameters are the parameters, which
    # is exactly the standard L-BFGS-B algorithm.
    return transformed_lbfgsb(
        maxcor=maxcor,
        line_search_max_steps=line_search_max_steps,
        transform_fn=lambda x: x,
        initialize_latent_fn=lambda x: x,
    )
102
-
103
-
104
def density_lbfgsb(
    beta: float,
    maxcor: int = MAXCOR_DEFAULT,
    line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT,
) -> base.Optimizer:
    """Return an L-BFGS-B optimizer with additional transforms for density arrays.

    Parameters that are of type `Density2DArray` are represented as latent
    parameters that are transformed (in the case where lower and upper bounds
    are `(-1, 1)`) by,

        transformed = tanh(beta * conv(density.array, gaussian_kernel)) / tanh(beta)

    where the kernel has a full-width at half-maximum determined by the minimum
    width and spacing parameters of the `Density2DArray`. Where the bounds
    differ, the density is scaled before the transform is applied, and then
    unscaled afterwards. All other parameters are passed through unchanged.

    Args:
        beta: Determines the steepness of the thresholding.
        maxcor: The maximum number of variable metric corrections used to define
            the limited memory matrix, in the L-BFGS-B scheme.
        line_search_max_steps: The maximum number of steps in the line search.

    Returns:
        The `base.Optimizer`.
    """

    def transform_fn(tree: PyTree) -> PyTree:
        # Density leaves are transformed; every other leaf is returned as-is.
        return tree_util.tree_map(
            lambda x: transform_density(x) if _is_density(x) else x,
            tree,
            is_leaf=_is_density,
        )

    def initialize_latent_fn(tree: PyTree) -> PyTree:
        # Density leaves get latent values; every other leaf is returned as-is.
        return tree_util.tree_map(
            lambda x: initialize_latent_density(x) if _is_density(x) else x,
            tree,
            is_leaf=_is_density,
        )

    def transform_density(density: types.Density2DArray) -> types.Density2DArray:
        transformed = types.symmetrize_density(density)
        transformed = transform.density_gaussian_filter_and_tanh(transformed, beta=beta)
        # Scale to ensure that the full valid range of the density array is reachable.
        mid_value = (density.lower_bound + density.upper_bound) / 2
        transformed = tree_util.tree_map(
            lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed
        )
        return transform.apply_fixed_pixels(transformed)

    def initialize_latent_density(
        density: types.Density2DArray,
    ) -> types.Density2DArray:
        # Undo the `tanh(beta * .) / tanh(beta)` part of `transform_density`.
        # The Gaussian filter is not inverted here.
        array = transform.normalized_array_from_density(density)
        # Clipping to [-1, 1] and scaling by tanh(beta) keeps `arctanh` finite.
        array = jnp.clip(array, -1, 1)
        array *= jnp.tanh(beta)
        latent_array = jnp.arctanh(array) / beta
        latent_array = transform.rescale_array_for_density(latent_array, density)
        return dataclasses.replace(density, array=latent_array)

    return transformed_lbfgsb(
        maxcor=maxcor,
        line_search_max_steps=line_search_max_steps,
        transform_fn=transform_fn,
        initialize_latent_fn=initialize_latent_fn,
    )
170
-
171
-
172
def transformed_lbfgsb(
    maxcor: int,
    line_search_max_steps: int,
    transform_fn: Callable[[PyTree], PyTree],
    initialize_latent_fn: Callable[[PyTree], PyTree],
) -> base.Optimizer:
    """Construct a latent-parameter L-BFGS-B optimizer.

    The optimized parameters are termed latent parameters, from which the
    actual parameters returned by the optimizer are obtained using the
    `transform_fn`. In the simple case where this is just `lambda x: x` (i.e.
    the identity), this is equivalent to the standard L-BFGS-B algorithm.

    Gradients given to `update` are with respect to the transformed parameters.
    They are mapped to gradients with respect to the latent parameters using the
    vector-Jacobian product of `transform_fn`.

    Args:
        maxcor: The maximum number of variable metric corrections used to define
            the limited memory matrix, in the L-BFGS-B scheme.
        line_search_max_steps: The maximum number of steps in the line search.
        transform_fn: Function which transforms the internal latent parameters to
            the parameters returned by the optimizer.
        initialize_latent_fn: Function which computes the initial latent parameters
            given the initial parameters.

    Returns:
        The `base.Optimizer`.

    Raises:
        ValueError: If `maxcor` is not an integer in `[1, MAXCOR_MAX_VALUE]`, or
            if `line_search_max_steps` is not a positive integer.
    """
    if not isinstance(maxcor, int) or maxcor < 1 or maxcor > MAXCOR_MAX_VALUE:
        raise ValueError(
            f"`maxcor` must be greater than 0 and no greater than "
            f"{MAXCOR_MAX_VALUE}, but got {maxcor}"
        )

    if not isinstance(line_search_max_steps, int) or line_search_max_steps < 1:
        raise ValueError(
            f"`line_search_max_steps` must be greater than 0 but got "
            f"{line_search_max_steps}"
        )

    def init_fn(params: PyTree) -> LbfgsbState:
        """Initializes the optimization state.

        The scipy state is built on the host inside `jax.pure_callback`;
        `_example_state` supplies the shapes and dtypes of the callback results.
        """

        def _init_pure(params: PyTree) -> Tuple[PyTree, JaxLbfgsbDict]:
            lower_bound = types.extract_lower_bound(params)
            upper_bound = types.extract_upper_bound(params)
            scipy_lbfgsb_state = ScipyLbfgsbState.init(
                x0=_to_numpy(params),
                lower_bound=_bound_for_params(lower_bound, params),
                upper_bound=_bound_for_params(upper_bound, params),
                maxcor=maxcor,
                line_search_max_steps=line_search_max_steps,
            )
            latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
            return latent_params, scipy_lbfgsb_state.to_jax()

        (
            latent_params,
            jax_lbfgsb_state,
        ) = jax.pure_callback(  # type: ignore[attr-defined]
            _init_pure,
            _example_state(params, maxcor),
            initialize_latent_fn(params),
        )
        return transform_fn(latent_params), latent_params, jax_lbfgsb_state

    def params_fn(state: LbfgsbState) -> PyTree:
        """Returns the parameters for the given `state`."""
        params, _, _ = state
        return params

    def update_fn(
        *,
        grad: PyTree,
        value: float,
        params: PyTree,
        state: LbfgsbState,
    ) -> LbfgsbState:
        """Updates the state.

        Args:
            grad: Gradient of the objective with respect to the parameters
                returned by `params_fn`, i.e. the transformed parameters.
            value: Scalar (size-1) objective value at those parameters.
            params: Unused; the latent parameters are taken from `state`.
            state: The current optimizer state.

        Returns:
            The updated state.
        """
        del params

        def _update_pure(
            flat_latent_grad: PyTree,
            value: jnp.ndarray,
            jax_lbfgsb_state: JaxLbfgsbDict,
        ) -> Tuple[PyTree, JaxLbfgsbDict]:
            # Validate with an exception rather than `assert`, which is stripped
            # under `python -O`. Any size-1 value (e.g. shape `(1,)`) is reshaped
            # to the rank-0 array that `ScipyLbfgsbState.update` requires.
            value = onp.asarray(value, dtype=onp.float64)
            if value.size != 1:
                raise ValueError(
                    f"`value` must be a scalar but got shape {value.shape}."
                )
            scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
            scipy_lbfgsb_state.update(
                grad=onp.asarray(flat_latent_grad, dtype=onp.float64),
                value=value.reshape(()),
            )
            flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
            return flat_latent_params, scipy_lbfgsb_state.to_jax()

        _, latent_params, jax_lbfgsb_state = state
        # Chain rule: pull the gradient w.r.t. the transformed parameters back to
        # a gradient w.r.t. the latent parameters via the VJP of `transform_fn`.
        _, vjp_fn = jax.vjp(transform_fn, latent_params)
        (latent_grad,) = vjp_fn(grad)
        flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(
            latent_grad
        )  # type: ignore[no-untyped-call]

        (
            flat_latent_params,
            jax_lbfgsb_state,
        ) = jax.pure_callback(  # type: ignore[attr-defined]
            _update_pure,
            (flat_latent_grad, jax_lbfgsb_state),
            flat_latent_grad,
            value,
            jax_lbfgsb_state,
        )
        latent_params = unflatten_fn(flat_latent_params)
        return transform_fn(latent_params), latent_params, jax_lbfgsb_state

    return base.Optimizer(
        init=init_fn,
        params=params_fn,
        update=update_fn,
    )
289
-
290
-
291
- # ------------------------------------------------------------------------------
292
- # Helper functions.
293
- # ------------------------------------------------------------------------------
294
-
295
-
296
def _is_density(leaf: Any) -> Any:
    """Predicate identifying `types.Density2DArray` instances.

    Used as an `is_leaf` callback so that density arrays are handled as single
    leaves instead of being traversed into.
    """
    density_type = types.Density2DArray
    return isinstance(leaf, density_type)
299
-
300
-
301
def _to_numpy(params: PyTree) -> NDArray:
    """Ravel every leaf of `params` into one rank-1 float64 numpy array."""
    return onp.asarray(
        flatten_util.ravel_pytree(params)[0],  # type: ignore[no-untyped-call]
        dtype=onp.float64,
    )
305
-
306
-
307
def _to_pytree(x_flat: NDArray, params: PyTree) -> PyTree:
    """Rebuild a pytree shaped like `params` from the flat array `x_flat`.

    This is the inverse of raveling `params`: `x_flat` is split and reshaped to
    match the leaves of `params`. The leaves of the result are jax arrays.

    Args:
        x_flat: The rank-1 numpy array to be restored.
        params: A pytree of parameters whose structure is replicated in the
            restored pytree.

    Returns:
        The restored pytree, with jax array leaves.
    """
    unravel = flatten_util.ravel_pytree(params)[1]  # type: ignore[no-untyped-call]
    as_jax = jnp.asarray(x_flat, dtype=float)
    return unravel(as_jax)
322
-
323
-
324
def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
    """Expand `bound` into a flat, per-element bound list aligned with `params`.

    `bound` may be given in either of these forms:
      - `None` or a scalar, which is applied to every array in `params`;
      - a pytree mirroring `params`, whose leaves are each `None`, a scalar, or
        an array with the same shape as the corresponding `params` leaf.

    `None` leaves of `params` contribute no entries. The result is suitable for
    `ScipyLbfgsbState.init`.

    Args:
        bound: The pytree of bounds.
        params: The pytree of parameters.

    Returns:
        A flat list with one entry per scalar element of `params`; each entry is
        `None` or a scalar bound.

    Raises:
        ValueError: If the tree structures of `bound` and `params` differ, or if
            an array-valued bound does not match the shape of its parameter.
    """

    def is_custom(node: Any) -> bool:
        return isinstance(node, types.CUSTOM_TYPES)

    def is_none(node: Any) -> bool:
        return node is None

    if bound is None or onp.isscalar(bound):
        # Broadcast the single bound to every leaf of `params`, treating custom
        # types as leaves.
        uniform_bound = bound
        bound = tree_util.tree_map(lambda _: uniform_bound, params, is_leaf=is_custom)

    bound_leaves, bound_treedef = tree_util.tree_flatten(bound, is_leaf=is_none)
    params_leaves = tree_util.tree_leaves(params, is_leaf=is_none)

    # `bound` should be a pytree of arrays or `None`, while `params` may include
    # custom pytree nodes. Replace every custom node (and `None`) with a plain
    # placeholder leaf so the two structures can be compared directly.
    placeholder_tree = tree_util.tree_map(
        lambda _: 0.0,
        params,
        is_leaf=lambda node: is_none(node) or is_custom(node),
    )
    params_treedef = tree_util.tree_structure(placeholder_tree)
    if bound_treedef != params_treedef:  # type: ignore[operator]
        raise ValueError(
            f"Tree structure of `bound` and `params` must match, but got "
            f"{bound_treedef} and {params_treedef}, respectively."
        )

    def expand(b: Any, p: Any) -> list:
        # Scalar-like bounds (including `None`) repeat once per element of `p`.
        if b is None or onp.isscalar(b) or onp.shape(b) == ():
            return [b] * onp.size(p)
        if b.shape != p.shape:
            raise ValueError(
                f"`bound` must be `None`, a scalar, or have shape matching "
                f"`params`, but got shape {b.shape} when params has shape "
                f"{p.shape}."
            )
        return b.flatten().tolist()

    flat_bound: list = []
    for b, p in zip(bound_leaves, params_leaves):
        if p is not None:
            flat_bound.extend(expand(b, p))
    return flat_bound
386
-
387
-
388
def _example_state(params: PyTree, maxcor: int) -> PyTree:
    """Return example outputs for the `jax.pure_callback` that initializes state.

    Only the structure, shapes and dtypes of the returned values matter; they
    describe the results produced by the initialization callback.

    Args:
        params: The parameters; their structure and total size are used.
        maxcor: The number of stored corrections, which sets the size of `_wa`.

    Returns:
        A tuple `(float_params, example_jax_lbfgsb_state)`. `float_params` is
        `params` with every leaf cast to jax's default float dtype. The
        dictionary has the same keys as `ScipyLbfgsbState.to_jax`.
    """
    flat_params, _ = flatten_util.ravel_pytree(params)  # type: ignore[no-untyped-call]
    num = flat_params.size
    float_params = tree_util.tree_map(
        lambda leaf: jnp.asarray(leaf, dtype=float), params
    )

    # Fortran int; assumed to match `FORTRAN_INT` on supported platforms.
    fortran_int = jnp.int32
    example_jax_lbfgsb_state = {
        "x": jnp.zeros(num, dtype=float),
        "_maxcor": jnp.zeros((), dtype=int),
        "_line_search_max_steps": jnp.zeros((), dtype=int),
        "_wa": jnp.ones(_wa_size(n=num, maxcor=maxcor), dtype=float),
        "_iwa": jnp.ones(num * 3, dtype=fortran_int),
        # S60 task strings are carried as 59 integer character codes.
        "_task": jnp.zeros(59, dtype=int),
        "_csave": jnp.zeros(59, dtype=int),
        "_lsave": jnp.zeros(4, dtype=fortran_int),
        "_isave": jnp.zeros(44, dtype=fortran_int),
        "_dsave": jnp.zeros(29, dtype=float),
        "_lower_bound": jnp.zeros(num, dtype=float),
        "_upper_bound": jnp.zeros(num, dtype=float),
        "_bound_type": jnp.zeros(num, dtype=int),
    }
    return float_params, example_jax_lbfgsb_state
409
-
410
-
411
- # ------------------------------------------------------------------------------
412
- # Wrapper for scipy's L-BFGS-B implementation.
413
- # ------------------------------------------------------------------------------
414
-
415
-
416
@dataclasses.dataclass
class ScipyLbfgsbState:
    """Stores the state of a scipy L-BFGS-B minimization.

    This class enables optimization with a more functional style, giving the user
    control over the optimization loop. Example usage is as follows:

        value_fn = lambda x: onp.sum(x**2)
        grad_fn = lambda x: 2 * x

        x0 = onp.asarray([0.1, 0.2, 0.3])
        lb = [None, -1, 0.1]
        ub = [None, None, None]
        state = ScipyLbfgsbState.init(
            x0=x0,
            lower_bound=lb,
            upper_bound=ub,
            maxcor=20,
            line_search_max_steps=100,
        )

        for _ in range(10):
            value = value_fn(state.x)
            grad = grad_fn(state.x)
            state.update(grad, value)

    This example converges with `state.x` equal to `(0, 0, 0.1)` and value equal
    to `0.01`.

    Each call to `update` invokes scipy's Fortran `setulb` routine, which
    modifies the state arrays (including `x`) in place.

    Attributes:
        x: The current solution vector (float64).
    """

    x: NDArray
    # Private attributes correspond to internal variables in the `scipy.optimize.
    # lbfgsb._minimize_lbfgsb` function.
    _maxcor: int  # Passed to `setulb` as `m`.
    _line_search_max_steps: int  # Passed to `setulb` as `maxls`.
    _wa: NDArray  # float64 work array of length `_wa_size(n, maxcor)`.
    _iwa: NDArray  # Integer work array of length `3 * n`.
    _task: NDArray  # Length-1 "S60" array holding the current task message.
    _csave: NDArray  # Length-1 "S60" array passed to `setulb` as `csave`.
    _lsave: NDArray  # 4 Fortran ints passed as `lsave`.
    _isave: NDArray  # 44 Fortran ints passed as `isave`.
    _dsave: NDArray  # 29 float64 values passed as `dsave`.
    _lower_bound: NDArray  # Elementwise lower bounds (0.0 where unbounded).
    _upper_bound: NDArray  # Elementwise upper bounds (0.0 where unbounded).
    _bound_type: NDArray  # Elementwise `BOUNDS_MAP` codes, passed as `nbd`.

    def __post_init__(self) -> None:
        """Validates the datatypes for all state attributes.

        NOTE(review): the arrays are handed directly to `setulb`, which updates
        them in place; presumably a mismatched dtype would make f2py work on a
        converted copy and silently drop the update — confirm.
        """
        _validate_array_dtype(self.x, onp.float64)
        _validate_array_dtype(self._wa, onp.float64)
        _validate_array_dtype(self._iwa, FORTRAN_INT)
        _validate_array_dtype(self._task, "S60")
        _validate_array_dtype(self._csave, "S60")
        _validate_array_dtype(self._lsave, FORTRAN_INT)
        _validate_array_dtype(self._isave, FORTRAN_INT)
        _validate_array_dtype(self._dsave, onp.float64)
        _validate_array_dtype(self._lower_bound, onp.float64)
        _validate_array_dtype(self._upper_bound, onp.float64)
        _validate_array_dtype(self._bound_type, int)

    def to_jax(self) -> Dict[str, jnp.ndarray]:
        """Generates a dictionary of jax arrays defining the state.

        The "S60" byte-string attributes are encoded as arrays of integer
        character codes via `_array_from_s60_str`; `from_jax` reverses this.
        """
        return dict(
            x=jnp.asarray(self.x),
            _maxcor=jnp.asarray(self._maxcor),
            _line_search_max_steps=jnp.asarray(self._line_search_max_steps),
            _wa=jnp.asarray(self._wa),
            _iwa=jnp.asarray(self._iwa),
            _task=_array_from_s60_str(self._task),
            _csave=_array_from_s60_str(self._csave),
            _lsave=jnp.asarray(self._lsave),
            _isave=jnp.asarray(self._isave),
            _dsave=jnp.asarray(self._dsave),
            _lower_bound=jnp.asarray(self._lower_bound),
            _upper_bound=jnp.asarray(self._upper_bound),
            _bound_type=jnp.asarray(self._bound_type),
        )

    @classmethod
    def from_jax(cls, state_dict: Dict[str, jnp.ndarray]) -> "ScipyLbfgsbState":
        """Converts a dictionary of jax arrays to a `ScipyLbfgsbState`.

        Each entry is cast back to the dtype that `__post_init__` requires.
        """
        # NOTE(review): the deep copy presumably ensures the new state does not
        # share (possibly read-only) buffers with `state_dict`, since `update`
        # mutates the arrays in place — confirm.
        state_dict = copy.deepcopy(state_dict)
        return ScipyLbfgsbState(
            x=onp.asarray(state_dict["x"], dtype=onp.float64),
            _maxcor=int(state_dict["_maxcor"]),
            _line_search_max_steps=int(state_dict["_line_search_max_steps"]),
            _wa=onp.asarray(state_dict["_wa"], onp.float64),
            _iwa=onp.asarray(state_dict["_iwa"], dtype=FORTRAN_INT),
            _task=_s60_str_from_array(state_dict["_task"]),
            _csave=_s60_str_from_array(state_dict["_csave"]),
            _lsave=onp.asarray(state_dict["_lsave"], dtype=FORTRAN_INT),
            _isave=onp.asarray(state_dict["_isave"], dtype=FORTRAN_INT),
            _dsave=onp.asarray(state_dict["_dsave"], dtype=onp.float64),
            _lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
            _upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
            _bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
        )

    @classmethod
    def init(
        cls,
        x0: NDArray,
        lower_bound: ElementwiseBound,
        upper_bound: ElementwiseBound,
        maxcor: int,
        line_search_max_steps: int,
    ) -> "ScipyLbfgsbState":
        """Initializes the `ScipyLbfgsbState` for `x0`.

        Args:
            x0: Array giving the initial solution vector.
            lower_bound: Array giving the elementwise optional lower bound.
            upper_bound: Array giving the elementwise optional upper bound.
            maxcor: The maximum number of variable metric corrections used to define
                the limited memory matrix, in the L-BFGS-B scheme.
            line_search_max_steps: The maximum number of steps in the line search.

        Returns:
            The `ScipyLbfgsbState`.

        Raises:
            ValueError: If `x0` has rank greater than 1, if the bounds do not
                match the shape of `x0`, or if `maxcor` is less than 1.
        """
        x0 = onp.asarray(x0)
        if x0.ndim > 1:
            raise ValueError(f"`x0` must be rank-1 but got shape {x0.shape}.")
        lower_bound = onp.asarray(lower_bound)
        upper_bound = onp.asarray(upper_bound)
        if x0.shape != lower_bound.shape or x0.shape != upper_bound.shape:
            raise ValueError(
                f"`x0`, `lower_bound`, and `upper_bound` must have matching "
                f"shape but got shapes {x0.shape}, {lower_bound.shape}, and "
                f"{upper_bound.shape}, respectively."
            )
        if maxcor < 1:
            raise ValueError(f"`maxcor` must be positive but got {maxcor}.")

        n = x0.size
        lower_bound_array, upper_bound_array, bound_type = _configure_bounds(
            lower_bound, upper_bound
        )
        # The first `setulb` call must see the "START" task.
        task = onp.zeros(1, "S60")
        task[:] = TASK_START

        # See initialization of internal variables in the `lbfgsb._minimize_lbfgsb`
        # function.
        wa_size = _wa_size(n=n, maxcor=maxcor)
        state = ScipyLbfgsbState(
            x=onp.array(x0, onp.float64),
            _maxcor=maxcor,
            _line_search_max_steps=line_search_max_steps,
            _wa=onp.zeros(wa_size, onp.float64),
            _iwa=onp.zeros(3 * n, FORTRAN_INT),
            _task=task,
            _csave=onp.zeros(1, "S60"),
            _lsave=onp.zeros(4, FORTRAN_INT),
            _isave=onp.zeros(44, FORTRAN_INT),
            _dsave=onp.zeros(29, onp.float64),
            _lower_bound=lower_bound_array,
            _upper_bound=upper_bound_array,
            _bound_type=bound_type,
        )
        # The initial state requires an update with zero value and gradient. This
        # is because the initial task is "START", which does not actually require
        # value and gradient evaluation.
        state.update(onp.zeros(x0.shape, onp.float64), onp.zeros((), onp.float64))
        return state

    def update(
        self,
        grad: NDArray,
        value: NDArray,
    ) -> None:
        """Performs an in-place update of the `ScipyLbfgsbState`.

        After this call, `self.x` holds the next point at which the caller
        should evaluate the value and gradient.

        Args:
            grad: The function gradient for the current `x`.
            value: The scalar function value for the current `x`.

        Raises:
            ValueError: If `grad` does not match the shape of `x`, or if `value`
                is not a rank-0 array.
        """
        if grad.shape != self.x.shape:
            raise ValueError(
                f"`grad` must have the same shape as attribute `x`, but got shapes "
                f"{grad.shape} and {self.x.shape}, respectively."
            )
        if value.shape != ():
            raise ValueError(f"`value` must be a scalar but got shape {value.shape}.")

        # The `setulb` function will sometimes return with a task that does not
        # require a value and gradient evaluation. In this case we simply call it
        # again, advancing past such "dummy" steps.
        for _ in range(3):
            scipy_lbfgsb.setulb(
                m=self._maxcor,
                x=self.x,
                l=self._lower_bound,
                u=self._upper_bound,
                nbd=self._bound_type,
                f=value,
                g=grad,
                factr=UPDATE_FACTR,
                pgtol=UPDATE_PGTOL,
                wa=self._wa,
                iwa=self._iwa,
                task=self._task,
                iprint=UPDATE_IPRINT,
                csave=self._csave,
                lsave=self._lsave,
                isave=self._isave,
                dsave=self._dsave,
                maxls=self._line_search_max_steps,
            )
            # Stop once `setulb` asks for a new function value and gradient.
            task_str = self._task.tobytes()
            if task_str.startswith(TASK_FG):
                break
626
-
627
-
628
- def _wa_size(n: int, maxcor: int) -> int:
629
- """Return the size of the `wa` attribute of lbfgsb state."""
630
- return 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
631
-
632
-
633
def _validate_array_dtype(x: NDArray, dtype: Union[type, str]) -> None:
    """Raise `ValueError` unless `x` is a numpy array whose dtype equals `dtype`."""
    if isinstance(x, onp.ndarray):
        if x.dtype == dtype:
            return
        raise ValueError(f"`x` must have dtype {dtype} but got {x.dtype}")
    raise ValueError(f"`x` must be an `onp.ndarray` but got {type(x)}")
639
-
640
-
641
def _configure_bounds(
    lower_bound: ElementwiseBound,
    upper_bound: ElementwiseBound,
) -> Tuple[NDArray, NDArray, NDArray]:
    """Configures the bounds for an L-BFGS-B optimization.

    An element is unbounded below if its lower bound is `None` or `-inf`. It is
    unbounded above if its upper bound is `None` or `+inf`. This matches scipy's
    `minimize(method="L-BFGS-B")`, which does not pass infinite bound values to
    the Fortran routine.

    Args:
        lower_bound: Elementwise lower bounds; `None` means no lower bound.
        upper_bound: Elementwise upper bounds; `None` means no upper bound.

    Returns:
        A tuple `(lower_bound_array, upper_bound_array, bound_type)`. The bound
        arrays are float64, with `0.0` in place of absent bounds; presumably
        `setulb` ignores those values given the bound type. `bound_type` is an
        integer array of `BOUNDS_MAP` codes.
    """

    def _is_unbounded(value: Any, infinity: float) -> bool:
        # `None` and the matching infinity both mean "no bound on this side".
        return value is None or bool(value == infinity)

    no_lower = [_is_unbounded(b, -onp.inf) for b in lower_bound]
    no_upper = [_is_unbounded(b, onp.inf) for b in upper_bound]
    bound_type = [BOUNDS_MAP[(lo, up)] for lo, up in zip(no_lower, no_upper)]
    lower_bound_array = [0.0 if skip else b for b, skip in zip(lower_bound, no_lower)]
    upper_bound_array = [0.0 if skip else b for b, skip in zip(upper_bound, no_upper)]
    return (
        onp.asarray(lower_bound_array, onp.float64),
        onp.asarray(upper_bound_array, onp.float64),
        # Explicit dtype: for empty input `onp.asarray([])` would be float64,
        # which fails the `int` dtype validation in `ScipyLbfgsbState`.
        onp.asarray(bound_type, dtype=int),
    )
657
-
658
-
659
def _array_from_s60_str(s60_str: NDArray) -> jnp.ndarray:
    """Encode a length-1 numpy "S60" array as a jax array of character codes.

    The bytes are right-padded with spaces (code 32) to a length of 59. Longer
    contents are left unpadded.
    """
    assert s60_str.shape == (1,)
    padded = bytes(s60_str[0]).ljust(59, b" ")
    return jnp.asarray(list(padded), dtype=int)
665
-
666
-
667
def _s60_str_from_array(array: jnp.ndarray) -> NDArray:
    """Decode integer character codes into a length-1 numpy "S60" array.

    Each element of `array` becomes a single byte. This reverses the encoding
    produced by `_array_from_s60_str`; any space padding is kept.
    """
    single_bytes = [int(code).to_bytes(length=1, byteorder="big") for code in array]
    return onp.asarray([b"".join(single_bytes)], dtype="S60")
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2023 The INVRS-IO authors.
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.