invrs-opt 0.3.2__py3-none-any.whl → 0.10.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,670 +0,0 @@
1
- """Defines a jax-style wrapper for scipy's L-BFGS-B algorithm.
2
-
3
- Copyright (c) 2023 The INVRS-IO authors.
4
- """
5
-
6
- import copy
7
- import dataclasses
8
- from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
9
-
10
- import jax
11
- import jax.numpy as jnp
12
- import numpy as onp
13
- from jax import flatten_util, tree_util
14
- from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
15
- _lbfgsb as scipy_lbfgsb,
16
- )
17
- from totypes import types
18
-
19
- from invrs_opt import base
20
- from invrs_opt.lbfgsb import transform
21
-
22
- NDArray = onp.ndarray[Any, Any]
23
- PyTree = Any
24
- ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
25
- JaxLbfgsbDict = Dict[str, jnp.ndarray]
26
- LbfgsbState = Tuple[PyTree, PyTree, JaxLbfgsbDict]
27
-
28
-
29
- # Task message prefixes for the underlying L-BFGS-B implementation.
30
- TASK_START = b"START"
31
- TASK_FG = b"FG"
32
-
33
- # Parameters which configure the state update step.
34
- UPDATE_IPRINT = -1
35
- UPDATE_PGTOL = 0.0
36
- UPDATE_FACTR = 0.0
37
-
38
- # Maximum value for the `maxcor` parameter in the L-BFGS-B scheme.
39
- MAXCOR_MAX_VALUE = 100
40
- MAXCOR_DEFAULT = 20
41
- LINE_SEARCH_MAX_STEPS_DEFAULT = 100
42
-
43
- # Maps bound scenarios to integers.
44
- BOUNDS_MAP: Dict[Tuple[bool, bool], int] = {
45
- (True, True): 0, # Both upper and lower bound are `None`.
46
- (False, True): 1, # Only upper bound is `None`.
47
- (False, False): 2, # Neither of the bounds are `None`.
48
- (True, False): 3, # Only the lower bound is `None`.
49
- }
50
-
51
- FORTRAN_INT = scipy_lbfgsb.types.intvar.dtype
52
-
53
-
54
- def lbfgsb(
55
- maxcor: int = MAXCOR_DEFAULT,
56
- line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT,
57
- ) -> base.Optimizer:
58
- """Return an optimizer implementing the standard L-BFGS-B algorithm.
59
-
60
- This optimizer wraps scipy's implementation of the algorithm, and provides
61
- a jax-style API to the scheme. The optimizer works with custom types such
62
- as the `BoundedArray` to constrain the optimization variable.
63
-
64
- Example usage is as follows:
65
-
66
- def fn(x):
67
- leaves_sum_sq = [jnp.sum(y)**2 for y in tree_util.tree_leaves(x)]
68
- return jnp.sum(jnp.asarray(leaves_sum_sq))
69
-
70
- x0 = {
71
- "a": jnp.ones((3,)),
72
- "b": BoundedArray(
73
- value=-jnp.ones((2, 5)),
74
- lower_bound=-5,
75
- upper_bound=5,
76
- ),
77
- }
78
- opt = lbfgsb(maxcor=20, line_search_max_steps=100)
79
- state = opt.init(x0)
80
- for _ in range(10):
81
- x = opt.params(state)
82
- value, grad = jax.value_and_grad(fn)(x)
83
- state = opt.update(grad=grad, value=value, params=x, state=state)
84
-
85
- While the algorithm can work with pytrees of jax arrays, numpy arrays can
86
- also be used. Thus, e.g. the optimizer can directly be used with autograd.
87
-
88
- Args:
89
- maxcor: The maximum number of variable metric corrections used to define
90
- the limited memory matrix, in the L-BFGS-B scheme.
91
- line_search_max_steps: The maximum number of steps in the line search.
92
-
93
- Returns:
94
- The `base.Optimizer`.
95
- """
96
- return transformed_lbfgsb(
97
- maxcor=maxcor,
98
- line_search_max_steps=line_search_max_steps,
99
- transform_fn=lambda x: x,
100
- initialize_latent_fn=lambda x: x,
101
- )
102
-
103
-
104
- def density_lbfgsb(
105
- beta: float,
106
- maxcor: int = MAXCOR_DEFAULT,
107
- line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT,
108
- ) -> base.Optimizer:
109
- """Return an L-BFGS-B optimizer with additional transforms for density arrays.
110
-
111
- Parameters that are of type `Density2DArray` are represented as latent parameters
112
- that are transformed (in the case where lower and upper bounds are `(-1, 1)`) by,
113
-
114
- transformed = tanh(beta * conv(density.array, gaussian_kernel)) / tanh(beta)
115
-
116
- where the kernel has a full-width at half-maximum determined by the minimum width
117
- and spacing parameters of the `Density2DArray`. Where the bounds differ, the
118
- density is scaled before the transform is applied, and then unscaled afterwards.
119
-
120
- Args:
121
- beta: Determines the steepness of the thresholding.
122
- maxcor: The maximum number of variable metric corrections used to define
123
- the limited memory matrix, in the L-BFGS-B scheme.
124
- line_search_max_steps: The maximum number of steps in the line search.
125
-
126
- Returns:
127
- The `base.Optimizer`.
128
- """
129
-
130
- def transform_fn(tree: PyTree) -> PyTree:
131
- return tree_util.tree_map(
132
- lambda x: transform_density(x) if _is_density(x) else x,
133
- tree,
134
- is_leaf=_is_density,
135
- )
136
-
137
- def initialize_latent_fn(tree: PyTree) -> PyTree:
138
- return tree_util.tree_map(
139
- lambda x: initialize_latent_density(x) if _is_density(x) else x,
140
- tree,
141
- is_leaf=_is_density,
142
- )
143
-
144
- def transform_density(density: types.Density2DArray) -> types.Density2DArray:
145
- transformed = types.symmetrize_density(density)
146
- transformed = transform.density_gaussian_filter_and_tanh(transformed, beta=beta)
147
- # Scale to ensure that the full valid range of the density array is reachable.
148
- mid_value = (density.lower_bound + density.upper_bound) / 2
149
- transformed = tree_util.tree_map(
150
- lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed
151
- )
152
- return transform.apply_fixed_pixels(transformed)
153
-
154
- def initialize_latent_density(
155
- density: types.Density2DArray,
156
- ) -> types.Density2DArray:
157
- array = transform.normalized_array_from_density(density)
158
- array = jnp.clip(array, -1, 1)
159
- array *= jnp.tanh(beta)
160
- latent_array = jnp.arctanh(array) / beta
161
- latent_array = transform.rescale_array_for_density(latent_array, density)
162
- return dataclasses.replace(density, array=latent_array)
163
-
164
- return transformed_lbfgsb(
165
- maxcor=maxcor,
166
- line_search_max_steps=line_search_max_steps,
167
- transform_fn=transform_fn,
168
- initialize_latent_fn=initialize_latent_fn,
169
- )
170
-
171
-
172
- def transformed_lbfgsb(
173
- maxcor: int,
174
- line_search_max_steps: int,
175
- transform_fn: Callable[[PyTree], PyTree],
176
- initialize_latent_fn: Callable[[PyTree], PyTree],
177
- ) -> base.Optimizer:
178
- """Construct a latent parameter L-BFGS-B optimizer.
179
-
180
- The optimized parameters are termed latent parameters, from which the
181
- actual parameters returned by the optimizer are obtained using the
182
- `transform_fn`. In the simple case where this is just `lambda x: x` (i.e.
183
- the identity), this is equivalent to the standard L-BFGS-B algorithm.
184
-
185
- Args:
186
- maxcor: The maximum number of variable metric corrections used to define
187
- the limited memory matrix, in the L-BFGS-B scheme.
188
- line_search_max_steps: The maximum number of steps in the line search.
189
- transform_fn: Function which transforms the internal latent parameters to
190
- the parameters returned by the optimizer.
191
- initialize_latent_fn: Function which computes the initial latent parameters
192
- given the initial parameters.
193
-
194
- Returns:
195
- The `base.Optimizer`.
196
- """
197
- if not isinstance(maxcor, int) or maxcor < 1 or maxcor > MAXCOR_MAX_VALUE:
198
- raise ValueError(
199
- f"`maxcor` must be greater than 0 and less than "
200
- f"{MAXCOR_MAX_VALUE}, but got {maxcor}"
201
- )
202
-
203
- if not isinstance(line_search_max_steps, int) or line_search_max_steps < 1:
204
- raise ValueError(
205
- f"`line_search_max_steps` must be greater than 0 but got "
206
- f"{line_search_max_steps}"
207
- )
208
-
209
- def init_fn(params: PyTree) -> LbfgsbState:
210
- """Initializes the optimization state."""
211
-
212
- def _init_pure(params: PyTree) -> Tuple[PyTree, JaxLbfgsbDict]:
213
- lower_bound = types.extract_lower_bound(params)
214
- upper_bound = types.extract_upper_bound(params)
215
- scipy_lbfgsb_state = ScipyLbfgsbState.init(
216
- x0=_to_numpy(params),
217
- lower_bound=_bound_for_params(lower_bound, params),
218
- upper_bound=_bound_for_params(upper_bound, params),
219
- maxcor=maxcor,
220
- line_search_max_steps=line_search_max_steps,
221
- )
222
- latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
223
- return latent_params, scipy_lbfgsb_state.to_jax()
224
-
225
- (
226
- latent_params,
227
- jax_lbfgsb_state,
228
- ) = jax.pure_callback( # type: ignore[attr-defined]
229
- _init_pure,
230
- _example_state(params, maxcor),
231
- initialize_latent_fn(params),
232
- )
233
- return transform_fn(latent_params), latent_params, jax_lbfgsb_state
234
-
235
- def params_fn(state: LbfgsbState) -> PyTree:
236
- """Returns the parameters for the given `state`."""
237
- params, _, _ = state
238
- return params
239
-
240
- def update_fn(
241
- *,
242
- grad: PyTree,
243
- value: float,
244
- params: PyTree,
245
- state: LbfgsbState,
246
- ) -> LbfgsbState:
247
- """Updates the state."""
248
- del params
249
-
250
- def _update_pure(
251
- flat_latent_grad: PyTree,
252
- value: jnp.ndarray,
253
- jax_lbfgsb_state: JaxLbfgsbDict,
254
- ) -> Tuple[PyTree, JaxLbfgsbDict]:
255
- assert onp.size(value) == 1
256
- scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
257
- scipy_lbfgsb_state.update(
258
- grad=onp.asarray(flat_latent_grad, dtype=onp.float64),
259
- value=onp.asarray(value, dtype=onp.float64),
260
- )
261
- flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
262
- return flat_latent_params, scipy_lbfgsb_state.to_jax()
263
-
264
- params, latent_params, jax_lbfgsb_state = state
265
- _, vjp_fn = jax.vjp(transform_fn, latent_params)
266
- (latent_grad,) = vjp_fn(grad)
267
- flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(latent_grad)
268
-
269
- (
270
- flat_latent_params,
271
- jax_lbfgsb_state,
272
- ) = jax.pure_callback( # type: ignore[attr-defined]
273
- _update_pure,
274
- (flat_latent_grad, jax_lbfgsb_state),
275
- flat_latent_grad,
276
- value,
277
- jax_lbfgsb_state,
278
- )
279
- latent_params = unflatten_fn(flat_latent_params)
280
- return transform_fn(latent_params), latent_params, jax_lbfgsb_state
281
-
282
- return base.Optimizer(
283
- init=init_fn,
284
- params=params_fn,
285
- update=update_fn,
286
- )
287
-
288
-
289
- # ------------------------------------------------------------------------------
290
- # Helper functions.
291
- # ------------------------------------------------------------------------------
292
-
293
-
294
- def _is_density(leaf: Any) -> Any:
295
- """Return `True` if `leaf` is a density array."""
296
- return isinstance(leaf, types.Density2DArray)
297
-
298
-
299
- def _to_numpy(params: PyTree) -> NDArray:
300
- """Flattens a `params` pytree into a single rank-1 numpy array."""
301
- x, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
302
- return onp.asarray(x, dtype=onp.float64)
303
-
304
-
305
- def _to_pytree(x_flat: NDArray, params: PyTree) -> PyTree:
306
- """Restores a pytree from a flat numpy array using the structure of `params`.
307
-
308
- Note that the returned pytree includes jax array leaves.
309
-
310
- Args:
311
- x_flat: The rank-1 numpy array to be restored.
312
- params: A pytree of parameters whose structure is replicated in the restored
313
- pytree.
314
-
315
- Returns:
316
- The restored pytree, with jax array leaves.
317
- """
318
- _, unflatten_fn = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
319
- return unflatten_fn(jnp.asarray(x_flat, dtype=float))
320
-
321
-
322
- def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
323
- """Generates a bound vector for the `params`.
324
-
325
- The `bound` can be specified in various ways; it may be `None` or a scalar,
326
- which then applies to all arrays in `params`. It may be a pytree with
327
- structure matching that of `params`, where each leaf is either `None`, a
328
- scalar, or an array matching the shape of the corresponding leaf in `params`.
329
-
330
- The returned bound is a flat array suitable for use with `ScipyLbfgsbState`.
331
-
332
- Args:
333
- bound: The pytree of bounds.
334
- params: The pytree of parameters.
335
-
336
- Returns:
337
- The flat elementwise bound.
338
- """
339
-
340
- if bound is None or onp.isscalar(bound):
341
- bound = tree_util.tree_map(
342
- lambda _: bound,
343
- params,
344
- is_leaf=lambda x: isinstance(x, types.CUSTOM_TYPES),
345
- )
346
-
347
- bound_leaves, bound_treedef = tree_util.tree_flatten(
348
- bound, is_leaf=lambda x: x is None
349
- )
350
- params_leaves = tree_util.tree_leaves(params, is_leaf=lambda x: x is None)
351
-
352
- # `bound` should be a pytree of arrays or `None`, while `params` may
353
- # include custom pytree nodes. Convert the custom nodes into standard
354
- # types to facilitate validation that the tree structures match.
355
- params_treedef = tree_util.tree_structure(
356
- tree_util.tree_map(
357
- lambda x: 0.0,
358
- tree=params,
359
- is_leaf=lambda x: x is None or isinstance(x, types.CUSTOM_TYPES),
360
- )
361
- )
362
- if bound_treedef != params_treedef: # type: ignore[operator]
363
- raise ValueError(
364
- f"Tree structure of `bound` and `params` must match, but got "
365
- f"{bound_treedef} and {params_treedef}, respectively."
366
- )
367
-
368
- bound_flat = []
369
- for b, p in zip(bound_leaves, params_leaves):
370
- if p is None:
371
- continue
372
- if b is None or onp.isscalar(b) or onp.shape(b) == ():
373
- bound_flat += [b] * onp.size(p)
374
- else:
375
- if b.shape != p.shape:
376
- raise ValueError(
377
- f"`bound` must be `None`, a scalar, or have shape matching "
378
- f"`params`, but got shape {b.shape} when params has shape "
379
- f"{p.shape}."
380
- )
381
- bound_flat += b.flatten().tolist()
382
-
383
- return bound_flat
384
-
385
-
386
- def _example_state(params: PyTree, maxcor: int) -> PyTree:
387
- """Return an example state for the given `params` and `maxcor`."""
388
- params_flat, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
389
- n = params_flat.size
390
- float_params = tree_util.tree_map(lambda x: jnp.asarray(x, dtype=float), params)
391
- example_jax_lbfgsb_state = dict(
392
- x=jnp.zeros(n, dtype=float),
393
- _maxcor=jnp.zeros((), dtype=int),
394
- _line_search_max_steps=jnp.zeros((), dtype=int),
395
- _wa=jnp.ones(_wa_size(n=n, maxcor=maxcor), dtype=float),
396
- _iwa=jnp.ones(n * 3, dtype=jnp.int32), # Fortran int
397
- _task=jnp.zeros(59, dtype=int),
398
- _csave=jnp.zeros(59, dtype=int),
399
- _lsave=jnp.zeros(4, dtype=jnp.int32), # Fortran int
400
- _isave=jnp.zeros(44, dtype=jnp.int32), # Fortran int
401
- _dsave=jnp.zeros(29, dtype=float),
402
- _lower_bound=jnp.zeros(n, dtype=float),
403
- _upper_bound=jnp.zeros(n, dtype=float),
404
- _bound_type=jnp.zeros(n, dtype=int),
405
- )
406
- return float_params, example_jax_lbfgsb_state
407
-
408
-
409
- # ------------------------------------------------------------------------------
410
- # Wrapper for scipy's L-BFGS-B implementation.
411
- # ------------------------------------------------------------------------------
412
-
413
-
414
- @dataclasses.dataclass
415
- class ScipyLbfgsbState:
416
- """Stores the state of a scipy L-BFGS-B minimization.
417
-
418
- This class enables optimization with a more functional style, giving the user
419
- control over the optimization loop. Example usage is as follows:
420
-
421
- value_fn = lambda x: onp.sum(x**2)
422
- grad_fn = lambda x: 2 * x
423
-
424
- x0 = onp.asarray([0.1, 0.2, 0.3])
425
- lb = [None, -1, 0.1]
426
- ub = [None, None, None]
427
- state = ScipyLbfgsbState.init(
428
- x0=x0, lower_bound=lb, upper_bound=ub, maxcor=20, line_search_max_steps=100
429
- )
430
-
431
- for _ in range(10):
432
- value = value_fn(state.x)
433
- grad = grad_fn(state.x)
434
- state.update(grad, value)
435
-
436
- This example converges with `state.x` equal to `(0, 0, 0.1)` and value equal
437
- to `0.01`.
438
-
439
- Attributes:
440
- x: The current solution vector.
441
- """
442
-
443
- x: NDArray
444
- # Private attributes correspond to internal variables in the
445
- # `scipy.optimize._lbfgsb_py._minimize_lbfgsb` function.
446
- _maxcor: int
447
- _line_search_max_steps: int
448
- _wa: NDArray
449
- _iwa: NDArray
450
- _task: NDArray
451
- _csave: NDArray
452
- _lsave: NDArray
453
- _isave: NDArray
454
- _dsave: NDArray
455
- _lower_bound: NDArray
456
- _upper_bound: NDArray
457
- _bound_type: NDArray
458
-
459
- def __post_init__(self) -> None:
460
- """Validates the datatypes for all state attributes."""
461
- _validate_array_dtype(self.x, onp.float64)
462
- _validate_array_dtype(self._wa, onp.float64)
463
- _validate_array_dtype(self._iwa, FORTRAN_INT)
464
- _validate_array_dtype(self._task, "S60")
465
- _validate_array_dtype(self._csave, "S60")
466
- _validate_array_dtype(self._lsave, FORTRAN_INT)
467
- _validate_array_dtype(self._isave, FORTRAN_INT)
468
- _validate_array_dtype(self._dsave, onp.float64)
469
- _validate_array_dtype(self._lower_bound, onp.float64)
470
- _validate_array_dtype(self._upper_bound, onp.float64)
471
- _validate_array_dtype(self._bound_type, int)
472
-
473
- def to_jax(self) -> Dict[str, jnp.ndarray]:
474
- """Generates a dictionary of jax arrays defining the state."""
475
- return dict(
476
- x=jnp.asarray(self.x),
477
- _maxcor=jnp.asarray(self._maxcor),
478
- _line_search_max_steps=jnp.asarray(self._line_search_max_steps),
479
- _wa=jnp.asarray(self._wa),
480
- _iwa=jnp.asarray(self._iwa),
481
- _task=_array_from_s60_str(self._task),
482
- _csave=_array_from_s60_str(self._csave),
483
- _lsave=jnp.asarray(self._lsave),
484
- _isave=jnp.asarray(self._isave),
485
- _dsave=jnp.asarray(self._dsave),
486
- _lower_bound=jnp.asarray(self._lower_bound),
487
- _upper_bound=jnp.asarray(self._upper_bound),
488
- _bound_type=jnp.asarray(self._bound_type),
489
- )
490
-
491
- @classmethod
492
- def from_jax(cls, state_dict: Dict[str, jnp.ndarray]) -> "ScipyLbfgsbState":
493
- """Converts a dictionary of jax arrays to a `ScipyLbfgsbState`."""
494
- state_dict = copy.deepcopy(state_dict)
495
- return ScipyLbfgsbState(
496
- x=onp.asarray(state_dict["x"], dtype=onp.float64),
497
- _maxcor=int(state_dict["_maxcor"]),
498
- _line_search_max_steps=int(state_dict["_line_search_max_steps"]),
499
- _wa=onp.asarray(state_dict["_wa"], onp.float64),
500
- _iwa=onp.asarray(state_dict["_iwa"], dtype=FORTRAN_INT),
501
- _task=_s60_str_from_array(state_dict["_task"]),
502
- _csave=_s60_str_from_array(state_dict["_csave"]),
503
- _lsave=onp.asarray(state_dict["_lsave"], dtype=FORTRAN_INT),
504
- _isave=onp.asarray(state_dict["_isave"], dtype=FORTRAN_INT),
505
- _dsave=onp.asarray(state_dict["_dsave"], dtype=onp.float64),
506
- _lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
507
- _upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
508
- _bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
509
- )
510
-
511
- @classmethod
512
- def init(
513
- cls,
514
- x0: NDArray,
515
- lower_bound: ElementwiseBound,
516
- upper_bound: ElementwiseBound,
517
- maxcor: int,
518
- line_search_max_steps: int,
519
- ) -> "ScipyLbfgsbState":
520
- """Initializes the `ScipyLbfgsbState` for `x0`.
521
-
522
- Args:
523
- x0: Array giving the initial solution vector.
524
- lower_bound: Array giving the elementwise optional lower bound.
525
- upper_bound: Array giving the elementwise optional upper bound.
526
- maxcor: The maximum number of variable metric corrections used to define
527
- the limited memory matrix, in the L-BFGS-B scheme.
528
- line_search_max_steps: The maximum number of steps in the line search.
529
-
530
- Returns:
531
- The `ScipyLbfgsbState`.
532
- """
533
- x0 = onp.asarray(x0)
534
- if x0.ndim > 1:
535
- raise ValueError(f"`x0` must be rank-1 but got shape {x0.shape}.")
536
- lower_bound = onp.asarray(lower_bound)
537
- upper_bound = onp.asarray(upper_bound)
538
- if x0.shape != lower_bound.shape or x0.shape != upper_bound.shape:
539
- raise ValueError(
540
- f"`x0`, `lower_bound`, and `upper_bound` must have matching "
541
- f"shape but got shapes {x0.shape}, {lower_bound.shape}, and "
542
- f"{upper_bound.shape}, respectively."
543
- )
544
- if maxcor < 1:
545
- raise ValueError(f"`maxcor` must be positive but got {maxcor}.")
546
-
547
- n = x0.size
548
- lower_bound_array, upper_bound_array, bound_type = _configure_bounds(
549
- lower_bound, upper_bound
550
- )
551
- task = onp.zeros(1, "S60")
552
- task[:] = TASK_START
553
-
554
- # See initialization of internal variables in the `lbfgsb._minimize_lbfgsb`
555
- # function.
556
- wa_size = _wa_size(n=n, maxcor=maxcor)
557
- state = ScipyLbfgsbState(
558
- x=onp.array(x0, onp.float64),
559
- _maxcor=maxcor,
560
- _line_search_max_steps=line_search_max_steps,
561
- _wa=onp.zeros(wa_size, onp.float64),
562
- _iwa=onp.zeros(3 * n, FORTRAN_INT),
563
- _task=task,
564
- _csave=onp.zeros(1, "S60"),
565
- _lsave=onp.zeros(4, FORTRAN_INT),
566
- _isave=onp.zeros(44, FORTRAN_INT),
567
- _dsave=onp.zeros(29, onp.float64),
568
- _lower_bound=lower_bound_array,
569
- _upper_bound=upper_bound_array,
570
- _bound_type=bound_type,
571
- )
572
- # The initial state requires an update with zero value and gradient. This
573
- # is because the initial task is "START", which does not actually require
574
- # value and gradient evaluation.
575
- state.update(onp.zeros(x0.shape, onp.float64), onp.zeros((), onp.float64))
576
- return state
577
-
578
- def update(
579
- self,
580
- grad: NDArray,
581
- value: NDArray,
582
- ) -> None:
583
- """Performs an in-place update of the `ScipyLbfgsbState`.
584
-
585
- Args:
586
- grad: The function gradient for the current `x`.
587
- value: The scalar function value for the current `x`.
588
- """
589
- if grad.shape != self.x.shape:
590
- raise ValueError(
591
- f"`grad` must have the same shape as attribute `x`, but got shapes "
592
- f"{grad.shape} and {self.x.shape}, respectively."
593
- )
594
- if value.shape != ():
595
- raise ValueError(f"`value` must be a scalar but got shape {value.shape}.")
596
-
597
- # The `setulb` function will sometimes return with a task that does not
598
- # require a value and gradient evaluation. In this case we simply call it
599
- # again, advancing past such "dummy" steps.
600
- for _ in range(3):
601
- scipy_lbfgsb.setulb(
602
- m=self._maxcor,
603
- x=self.x,
604
- l=self._lower_bound,
605
- u=self._upper_bound,
606
- nbd=self._bound_type,
607
- f=value,
608
- g=grad,
609
- factr=UPDATE_FACTR,
610
- pgtol=UPDATE_PGTOL,
611
- wa=self._wa,
612
- iwa=self._iwa,
613
- task=self._task,
614
- iprint=UPDATE_IPRINT,
615
- csave=self._csave,
616
- lsave=self._lsave,
617
- isave=self._isave,
618
- dsave=self._dsave,
619
- maxls=self._line_search_max_steps,
620
- )
621
- task_str = self._task.tobytes()
622
- if task_str.startswith(TASK_FG):
623
- break
624
-
625
-
626
- def _wa_size(n: int, maxcor: int) -> int:
627
- """Return the size of the `wa` attribute of lbfgsb state."""
628
- return 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
629
-
630
-
631
- def _validate_array_dtype(x: NDArray, dtype: Union[type, str]) -> None:
632
- """Validates that `x` is an array with the specified `dtype`."""
633
- if not isinstance(x, onp.ndarray):
634
- raise ValueError(f"`x` must be an `onp.ndarray` but got {type(x)}")
635
- if x.dtype != dtype:
636
- raise ValueError(f"`x` must have dtype {dtype} but got {x.dtype}")
637
-
638
-
639
- def _configure_bounds(
640
- lower_bound: ElementwiseBound,
641
- upper_bound: ElementwiseBound,
642
- ) -> Tuple[NDArray, NDArray, NDArray]:
643
- """Configures the bounds for an L-BFGS-B optimization."""
644
- bound_type = [
645
- BOUNDS_MAP[(lower is None, upper is None)]
646
- for lower, upper in zip(lower_bound, upper_bound)
647
- ]
648
- lower_bound_array = [0.0 if x is None else x for x in lower_bound]
649
- upper_bound_array = [0.0 if x is None else x for x in upper_bound]
650
- return (
651
- onp.asarray(lower_bound_array, onp.float64),
652
- onp.asarray(upper_bound_array, onp.float64),
653
- onp.asarray(bound_type),
654
- )
655
-
656
-
657
- def _array_from_s60_str(s60_str: NDArray) -> jnp.ndarray:
658
- """Return a jax array for a numpy s60 string."""
659
- assert s60_str.shape == (1,)
660
- chars = [int(o) for o in s60_str[0]]
661
- chars.extend([32] * (59 - len(chars)))
662
- return jnp.asarray(chars, dtype=int)
663
-
664
-
665
- def _s60_str_from_array(array: jnp.ndarray) -> NDArray:
666
- """Return a numpy s60 string for a jax array."""
667
- return onp.asarray(
668
- [b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
669
- dtype="S60",
670
- )
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2023 The INVRS-IO authors.
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.