invrs-opt 0.4.0__py3-none-any.whl → 0.10.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invrs_opt/__init__.py +14 -3
 - invrs_opt/experimental/client.py +7 -4
 - invrs_opt/{base.py → optimizers/base.py} +16 -1
 - invrs_opt/optimizers/lbfgsb.py +939 -0
 - invrs_opt/optimizers/wrapped_optax.py +347 -0
 - invrs_opt/parameterization/__init__.py +0 -0
 - invrs_opt/parameterization/base.py +208 -0
 - invrs_opt/parameterization/filter_project.py +138 -0
 - invrs_opt/parameterization/gaussian_levelset.py +671 -0
 - invrs_opt/parameterization/pixel.py +75 -0
 - invrs_opt/{lbfgsb/transform.py → parameterization/transforms.py} +76 -11
 - invrs_opt-0.10.3.dist-info/LICENSE +504 -0
 - invrs_opt-0.10.3.dist-info/METADATA +560 -0
 - invrs_opt-0.10.3.dist-info/RECORD +20 -0
 - {invrs_opt-0.4.0.dist-info → invrs_opt-0.10.3.dist-info}/WHEEL +1 -1
 - invrs_opt/lbfgsb/lbfgsb.py +0 -672
 - invrs_opt-0.4.0.dist-info/LICENSE +0 -21
 - invrs_opt-0.4.0.dist-info/METADATA +0 -75
 - invrs_opt-0.4.0.dist-info/RECORD +0 -14
 - /invrs_opt/{lbfgsb → optimizers}/__init__.py +0 -0
 - {invrs_opt-0.4.0.dist-info → invrs_opt-0.10.3.dist-info}/top_level.txt +0 -0
 
    
        invrs_opt/lbfgsb/lbfgsb.py
    DELETED
    
    | 
         @@ -1,672 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            """Defines a jax-style wrapper for scipy's L-BFGS-B algorithm.
         
     | 
| 
       2 
     | 
    
         
            -
             
     | 
| 
       3 
     | 
    
         
            -
            Copyright (c) 2023 The INVRS-IO authors.
         
     | 
| 
       4 
     | 
    
         
            -
            """
         
     | 
| 
       5 
     | 
    
         
            -
             
     | 
| 
       6 
     | 
    
         
            -
            import copy
         
     | 
| 
       7 
     | 
    
         
            -
            import dataclasses
         
     | 
| 
       8 
     | 
    
         
            -
            from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
         
     | 
| 
       9 
     | 
    
         
            -
             
     | 
| 
       10 
     | 
    
         
            -
            import jax
         
     | 
| 
       11 
     | 
    
         
            -
            import jax.numpy as jnp
         
     | 
| 
       12 
     | 
    
         
            -
            import numpy as onp
         
     | 
| 
       13 
     | 
    
         
            -
            from jax import flatten_util, tree_util
         
     | 
| 
       14 
     | 
    
         
            -
            from scipy.optimize._lbfgsb_py import (  # type: ignore[import-untyped]
         
     | 
| 
       15 
     | 
    
         
            -
                _lbfgsb as scipy_lbfgsb,
         
     | 
| 
       16 
     | 
    
         
            -
            )
         
     | 
| 
       17 
     | 
    
         
            -
            from totypes import types
         
     | 
| 
       18 
     | 
    
         
            -
             
     | 
| 
       19 
     | 
    
         
            -
            from invrs_opt import base
         
     | 
| 
       20 
     | 
    
         
            -
            from invrs_opt.lbfgsb import transform
         
     | 
| 
       21 
     | 
    
         
            -
             
     | 
| 
       22 
     | 
    
         
            -
            # --- Type aliases -------------------------------------------------------
            NDArray = onp.ndarray[Any, Any]
            PyTree = Any
            # Per-element bound: an array, or a sequence whose entries may be `None`
            # to indicate that the element has no bound on that side.
            ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
            # String-keyed dict of arrays holding the scipy optimizer state in a form
            # that can be passed through `jax.pure_callback`.
            JaxLbfgsbDict = Dict[str, jnp.ndarray]
            # Optimizer state tuple: `(params, latent_params, jax_lbfgsb_state)`.
            LbfgsbState = Tuple[PyTree, PyTree, JaxLbfgsbDict]


            # --- Messages understood by scipy's L-BFGS-B routine ----------------------
            TASK_START = b"START"
            TASK_FG = b"FG"

            # --- Fixed arguments for every state-update step --------------------------
            # A negative `iprint` suppresses printing in scipy's L-BFGS-B; the zero
            # tolerances presumably leave termination entirely up to the caller.
            UPDATE_IPRINT = -1
            UPDATE_PGTOL = 0.0
            UPDATE_FACTR = 0.0

            # --- Limits and defaults for the user-facing options ----------------------
            # `maxcor` (number of stored variable-metric corrections) must not exceed
            # `MAXCOR_MAX_VALUE`.
            MAXCOR_MAX_VALUE = 100
            MAXCOR_DEFAULT = 20
            LINE_SEARCH_MAX_STEPS_DEFAULT = 100

            # --- Bound-type codes -----------------------------------------------------
            # Keyed by `(lower_is_none, upper_is_none)` for each variable.
            BOUNDS_MAP: Dict[Tuple[bool, bool], int] = dict(
                [
                    ((True, True), 0),  # Unbounded: neither side has a bound.
                    ((False, True), 1),  # Lower bound only (upper bound is `None`).
                    ((False, False), 2),  # Both a lower and an upper bound.
                    ((True, False), 3),  # Upper bound only (lower bound is `None`).
                ]
            )
         
     | 
| 
       50 
     | 
    
         
            -
             
     | 
| 
       51 
     | 
    
         
            -
            # dtype of the integer variable type exposed by scipy's compiled `_lbfgsb`
            # module (the Fortran integer kind, judging by the name).
            # NOTE(review): `types.intvar` lives in scipy's private `_lbfgsb` extension
            # and is not public API; it may be missing in other scipy releases —
            # confirm against the pinned scipy version.
            FORTRAN_INT = scipy_lbfgsb.types.intvar.dtype
         
     | 
| 
       52 
     | 
    
         
            -
             
     | 
| 
       53 
     | 
    
         
            -
             
     | 
| 
       54 
     | 
    
         
            -
            def lbfgsb(
                maxcor: int = MAXCOR_DEFAULT,
                line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT,
            ) -> base.Optimizer:
                """Return an optimizer implementing the standard L-BFGS-B algorithm.

                This optimizer wraps scipy's implementation of the algorithm, and provides
                a jax-style API to the scheme. The optimizer works with custom types such
                as the `BoundedArray` to constrain the optimization variable.

                Example usage is as follows:

                    def fn(x):
                        leaves_sum_sq = [jnp.sum(y)**2 for y in tree_util.tree_leaves(x)]
                        return jnp.sum(jnp.asarray(leaves_sum_sq))

                    x0 = {
                        "a": jnp.ones((3,)),
                        "b": BoundedArray(
                            value=-jnp.ones((2, 5)),
                            lower_bound=-5,
                            upper_bound=5,
                        ),
                    }
                    opt = lbfgsb(maxcor=20, line_search_max_steps=100)
                    state = opt.init(x0)
                    for _ in range(10):
                        x = opt.params(state)
                        value, grad = jax.value_and_grad(fn)(x)
                        state = opt.update(grad=grad, value=value, params=x, state=state)

                Note that `update` takes keyword-only arguments, including the current
                `params`.

                While the algorithm can work with pytrees of jax arrays, numpy arrays can
                also be used. Thus, e.g., the optimizer can directly be used with autograd.

                Args:
                    maxcor: The maximum number of variable metric corrections used to define
                        the limited memory matrix, in the L-BFGS-B scheme.
                    line_search_max_steps: The maximum number of steps in the line search.

                Returns:
                    The `base.Optimizer`.

                Raises:
                    ValueError: If `maxcor` is not an integer in `[1, MAXCOR_MAX_VALUE]`, or
                        if `line_search_max_steps` is not a positive integer.
                """
                # Plain L-BFGS-B optimizes the parameters directly, so both the latent
                # initialization and the latent-to-params transform are the identity.
                return transformed_lbfgsb(
                    maxcor=maxcor,
                    line_search_max_steps=line_search_max_steps,
                    transform_fn=lambda x: x,
                    initialize_latent_fn=lambda x: x,
                )
         
     | 
| 
       102 
     | 
    
         
            -
             
     | 
| 
       103 
     | 
    
         
            -
             
     | 
| 
       104 
     | 
    
         
            -
            def density_lbfgsb(
                beta: float,
                maxcor: int = MAXCOR_DEFAULT,
                line_search_max_steps: int = LINE_SEARCH_MAX_STEPS_DEFAULT,
            ) -> base.Optimizer:
                """Return an L-BFGS-B optimizer with additional transforms for density arrays.

                Parameters that are of type `Density2DArray` are represented as latent
                parameters that are transformed (in the case where lower and upper bounds
                are `(-1, 1)`) by,

                    transformed = tanh(beta * conv(density.array, gaussian_kernel)) / tanh(beta)

                where the kernel has a full-width at half-maximum determined by the minimum
                width and spacing parameters of the `Density2DArray`. Where the bounds
                differ from `(-1, 1)`, the density is scaled before the transform is
                applied, and then unscaled afterwards. The density is symmetrized before
                filtering, and fixed pixels are applied to the final result. All other
                parameters are optimized directly.

                Args:
                    beta: Determines the steepness of the thresholding.
                    maxcor: The maximum number of variable metric corrections used to define
                        the limited memory matrix, in the L-BFGS-B scheme.
                    line_search_max_steps: The maximum number of steps in the line search.

                Returns:
                    The `base.Optimizer`.

                Raises:
                    ValueError: If `maxcor` is not an integer in `[1, MAXCOR_MAX_VALUE]`, or
                        if `line_search_max_steps` is not a positive integer.
                """

                def transform_fn(tree: PyTree) -> PyTree:
                    # Treat each density as a single leaf so the whole object is transformed.
                    return tree_util.tree_map(
                        lambda x: transform_density(x) if _is_density(x) else x,
                        tree,
                        is_leaf=_is_density,
                    )

                def initialize_latent_fn(tree: PyTree) -> PyTree:
                    return tree_util.tree_map(
                        lambda x: initialize_latent_density(x) if _is_density(x) else x,
                        tree,
                        is_leaf=_is_density,
                    )

                def transform_density(density: types.Density2DArray) -> types.Density2DArray:
                    """Map a latent density to the density returned by the optimizer."""
                    transformed = types.symmetrize_density(density)
                    transformed = transform.density_gaussian_filter_and_tanh(transformed, beta=beta)
                    # Scale to ensure that the full valid range of the density array is reachable.
                    # Dividing by `tanh(beta)` about the midpoint undoes the range compression
                    # introduced by the tanh thresholding.
                    mid_value = (density.lower_bound + density.upper_bound) / 2
                    transformed = tree_util.tree_map(
                        lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed
                    )
                    return transform.apply_fixed_pixels(transformed)

                def initialize_latent_density(
                    density: types.Density2DArray,
                ) -> types.Density2DArray:
                    """Compute latent parameters for an initial `density`.

                    Only the tanh thresholding and scaling of `transform_density` are
                    inverted; the Gaussian filter, symmetrization and fixed pixels are not,
                    so the initial density is only approximately reproduced.
                    """
                    array = transform.normalized_array_from_density(density)
                    # Clipping to [-1, 1] keeps `array * tanh(beta)` strictly inside (-1, 1)
                    # for finite `beta`, so `arctanh` below stays finite.
                    array = jnp.clip(array, -1, 1)
                    array *= jnp.tanh(beta)
                    latent_array = jnp.arctanh(array) / beta
                    latent_array = transform.rescale_array_for_density(latent_array, density)
                    return dataclasses.replace(density, array=latent_array)

                return transformed_lbfgsb(
                    maxcor=maxcor,
                    line_search_max_steps=line_search_max_steps,
                    transform_fn=transform_fn,
                    initialize_latent_fn=initialize_latent_fn,
                )
         
     | 
| 
       170 
     | 
    
         
            -
             
     | 
| 
       171 
     | 
    
         
            -
             
     | 
| 
       172 
     | 
    
         
            -
            def transformed_lbfgsb(
         
     | 
| 
       173 
     | 
    
         
            -
                maxcor: int,
         
     | 
| 
       174 
     | 
    
         
            -
                line_search_max_steps: int,
         
     | 
| 
       175 
     | 
    
         
            -
                transform_fn: Callable[[PyTree], PyTree],
         
     | 
| 
       176 
     | 
    
         
            -
                initialize_latent_fn: Callable[[PyTree], PyTree],
         
     | 
| 
       177 
     | 
    
         
            -
            ) -> base.Optimizer:
         
     | 
| 
       178 
     | 
    
         
            -
                """Construct an latent parameter L-BFGS-B optimizer.
         
     | 
| 
       179 
     | 
    
         
            -
             
     | 
| 
       180 
     | 
    
         
            -
                The optimized parameters are termed latent parameters, from which the
         
     | 
| 
       181 
     | 
    
         
            -
                actual parameters returned by the optimizer are obtained using the
         
     | 
| 
       182 
     | 
    
         
            -
                `transform_fn`. In the simple case where this is just `lambda x: x` (i.e.
         
     | 
| 
       183 
     | 
    
         
            -
                the identity), this is equivalent to the standard L-BFGS-B algorithm.
         
     | 
| 
       184 
     | 
    
         
            -
             
     | 
| 
       185 
     | 
    
         
            -
                Args:
         
     | 
| 
       186 
     | 
    
         
            -
                    maxcor: The maximum number of variable metric corrections used to define
         
     | 
| 
       187 
     | 
    
         
            -
                        the limited memory matrix, in the L-BFGS-B scheme.
         
     | 
| 
       188 
     | 
    
         
            -
                    line_search_max_steps: The maximum number of steps in the line search.
         
     | 
| 
       189 
     | 
    
         
            -
                    transform_fn: Function which transforms the internal latent parameters to
         
     | 
| 
       190 
     | 
    
         
            -
                        the parameters returned by the optimizer.
         
     | 
| 
       191 
     | 
    
         
            -
                    initialize_latent_fn: Function which computes the initial latent parameters
         
     | 
| 
       192 
     | 
    
         
            -
                        given the initial parameters.
         
     | 
| 
       193 
     | 
    
         
            -
             
     | 
| 
       194 
     | 
    
         
            -
                Returns:
         
     | 
| 
       195 
     | 
    
         
            -
                    The `base.Optimizer`.
         
     | 
| 
       196 
     | 
    
         
            -
                """
         
     | 
| 
       197 
     | 
    
         
            -
                if not isinstance(maxcor, int) or maxcor < 1 or maxcor > MAXCOR_MAX_VALUE:
         
     | 
| 
       198 
     | 
    
         
            -
                    raise ValueError(
         
     | 
| 
       199 
     | 
    
         
            -
                        f"`maxcor` must be greater than 0 and less than "
         
     | 
| 
       200 
     | 
    
         
            -
                        f"{MAXCOR_MAX_VALUE}, but got {maxcor}"
         
     | 
| 
       201 
     | 
    
         
            -
                    )
         
     | 
| 
       202 
     | 
    
         
            -
             
     | 
| 
       203 
     | 
    
         
            -
                if not isinstance(line_search_max_steps, int) or line_search_max_steps < 1:
         
     | 
| 
       204 
     | 
    
         
            -
                    raise ValueError(
         
     | 
| 
       205 
     | 
    
         
            -
                        f"`line_search_max_steps` must be greater than 0 but got "
         
     | 
| 
       206 
     | 
    
         
            -
                        f"{line_search_max_steps}"
         
     | 
| 
       207 
     | 
    
         
            -
                    )
         
     | 
| 
       208 
     | 
    
         
            -
             
     | 
| 
       209 
     | 
    
         
            -
                def init_fn(params: PyTree) -> LbfgsbState:
         
     | 
| 
       210 
     | 
    
         
            -
                    """Initializes the optimization state."""
         
     | 
| 
       211 
     | 
    
         
            -
             
     | 
| 
       212 
     | 
    
         
            -
                    def _init_pure(params: PyTree) -> Tuple[PyTree, JaxLbfgsbDict]:
         
     | 
| 
       213 
     | 
    
         
            -
                        lower_bound = types.extract_lower_bound(params)
         
     | 
| 
       214 
     | 
    
         
            -
                        upper_bound = types.extract_upper_bound(params)
         
     | 
| 
       215 
     | 
    
         
            -
                        scipy_lbfgsb_state = ScipyLbfgsbState.init(
         
     | 
| 
       216 
     | 
    
         
            -
                            x0=_to_numpy(params),
         
     | 
| 
       217 
     | 
    
         
            -
                            lower_bound=_bound_for_params(lower_bound, params),
         
     | 
| 
       218 
     | 
    
         
            -
                            upper_bound=_bound_for_params(upper_bound, params),
         
     | 
| 
       219 
     | 
    
         
            -
                            maxcor=maxcor,
         
     | 
| 
       220 
     | 
    
         
            -
                            line_search_max_steps=line_search_max_steps,
         
     | 
| 
       221 
     | 
    
         
            -
                        )
         
     | 
| 
       222 
     | 
    
         
            -
                        latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
         
     | 
| 
       223 
     | 
    
         
            -
                        return latent_params, scipy_lbfgsb_state.to_jax()
         
     | 
| 
       224 
     | 
    
         
            -
             
     | 
| 
       225 
     | 
    
         
            -
                    (
         
     | 
| 
       226 
     | 
    
         
            -
                        latent_params,
         
     | 
| 
       227 
     | 
    
         
            -
                        jax_lbfgsb_state,
         
     | 
| 
       228 
     | 
    
         
            -
                    ) = jax.pure_callback(  # type: ignore[attr-defined]
         
     | 
| 
       229 
     | 
    
         
            -
                        _init_pure,
         
     | 
| 
       230 
     | 
    
         
            -
                        _example_state(params, maxcor),
         
     | 
| 
       231 
     | 
    
         
            -
                        initialize_latent_fn(params),
         
     | 
| 
       232 
     | 
    
         
            -
                    )
         
     | 
| 
       233 
     | 
    
         
            -
                    return transform_fn(latent_params), latent_params, jax_lbfgsb_state
         
     | 
| 
       234 
     | 
    
         
            -
             
     | 
| 
       235 
     | 
    
         
            -
                def params_fn(state: LbfgsbState) -> PyTree:
         
     | 
| 
       236 
     | 
    
         
            -
                    """Returns the parameters for the given `state`."""
         
     | 
| 
       237 
     | 
    
         
            -
                    params, _, _ = state
         
     | 
| 
       238 
     | 
    
         
            -
                    return params
         
     | 
| 
       239 
     | 
    
         
            -
             
     | 
| 
       240 
     | 
    
         
            -
                def update_fn(
         
     | 
| 
       241 
     | 
    
         
            -
                    *,
         
     | 
| 
       242 
     | 
    
         
            -
                    grad: PyTree,
         
     | 
| 
       243 
     | 
    
         
            -
                    value: float,
         
     | 
| 
       244 
     | 
    
         
            -
                    params: PyTree,
         
     | 
| 
       245 
     | 
    
         
            -
                    state: LbfgsbState,
         
     | 
| 
       246 
     | 
    
         
            -
                ) -> LbfgsbState:
         
     | 
| 
       247 
     | 
    
         
            -
                    """Updates the state."""
         
     | 
| 
       248 
     | 
    
         
            -
                    del params
         
     | 
| 
       249 
     | 
    
         
            -
             
     | 
| 
       250 
     | 
    
         
            -
                    def _update_pure(
         
     | 
| 
       251 
     | 
    
         
            -
                        flat_latent_grad: PyTree,
         
     | 
| 
       252 
     | 
    
         
            -
                        value: jnp.ndarray,
         
     | 
| 
       253 
     | 
    
         
            -
                        jax_lbfgsb_state: JaxLbfgsbDict,
         
     | 
| 
       254 
     | 
    
         
            -
                    ) -> Tuple[PyTree, JaxLbfgsbDict]:
         
     | 
| 
       255 
     | 
    
         
            -
                        assert onp.size(value) == 1
         
     | 
| 
       256 
     | 
    
         
            -
                        scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
         
     | 
| 
       257 
     | 
    
         
            -
                        scipy_lbfgsb_state.update(
         
     | 
| 
       258 
     | 
    
         
            -
                            grad=onp.asarray(flat_latent_grad, dtype=onp.float64),
         
     | 
| 
       259 
     | 
    
         
            -
                            value=onp.asarray(value, dtype=onp.float64),
         
     | 
| 
       260 
     | 
    
         
            -
                        )
         
     | 
| 
       261 
     | 
    
         
            -
                        flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
         
     | 
| 
       262 
     | 
    
         
            -
                        return flat_latent_params, scipy_lbfgsb_state.to_jax()
         
     | 
| 
       263 
     | 
    
         
            -
             
     | 
| 
       264 
     | 
    
         
            -
                    _, latent_params, jax_lbfgsb_state = state
         
     | 
| 
       265 
     | 
    
         
            -
                    _, vjp_fn = jax.vjp(transform_fn, latent_params)
         
     | 
| 
       266 
     | 
    
         
            -
                    (latent_grad,) = vjp_fn(grad)
         
     | 
| 
       267 
     | 
    
         
            -
                    flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(
         
     | 
| 
       268 
     | 
    
         
            -
                        latent_grad
         
     | 
| 
       269 
     | 
    
         
            -
                    )  # type: ignore[no-untyped-call]
         
     | 
| 
       270 
     | 
    
         
            -
             
     | 
| 
       271 
     | 
    
         
            -
                    (
         
     | 
| 
       272 
     | 
    
         
            -
                        flat_latent_params,
         
     | 
| 
       273 
     | 
    
         
            -
                        jax_lbfgsb_state,
         
     | 
| 
       274 
     | 
    
         
            -
                    ) = jax.pure_callback(  # type: ignore[attr-defined]
         
     | 
| 
       275 
     | 
    
         
            -
                        _update_pure,
         
     | 
| 
       276 
     | 
    
         
            -
                        (flat_latent_grad, jax_lbfgsb_state),
         
     | 
| 
       277 
     | 
    
         
            -
                        flat_latent_grad,
         
     | 
| 
       278 
     | 
    
         
            -
                        value,
         
     | 
| 
       279 
     | 
    
         
            -
                        jax_lbfgsb_state,
         
     | 
| 
       280 
     | 
    
         
            -
                    )
         
     | 
| 
       281 
     | 
    
         
            -
                    latent_params = unflatten_fn(flat_latent_params)
         
     | 
| 
       282 
     | 
    
         
            -
                    return transform_fn(latent_params), latent_params, jax_lbfgsb_state
         
     | 
| 
       283 
     | 
    
         
            -
             
     | 
| 
       284 
     | 
    
         
            -
                return base.Optimizer(
         
     | 
| 
       285 
     | 
    
         
            -
                    init=init_fn,
         
     | 
| 
       286 
     | 
    
         
            -
                    params=params_fn,
         
     | 
| 
       287 
     | 
    
         
            -
                    update=update_fn,
         
     | 
| 
       288 
     | 
    
         
            -
                )
         
     | 
| 
       289 
     | 
    
         
            -
             
     | 
| 
       290 
     | 
    
         
            -
             
     | 
| 
       291 
     | 
    
         
            -
            # ------------------------------------------------------------------------------
         
     | 
| 
       292 
     | 
    
         
            -
            # Helper functions.
         
     | 
| 
       293 
     | 
    
         
            -
            # ------------------------------------------------------------------------------
         
     | 
| 
       294 
     | 
    
         
            -
             
     | 
| 
       295 
     | 
    
         
            -
             
     | 
| 
       296 
     | 
    
         
            -
            def _is_density(leaf: Any) -> bool:
                """Return `True` if `leaf` is a `types.Density2DArray`, `False` otherwise.

                Args:
                    leaf: Any object, typically a node encountered while traversing a pytree.

                Returns:
                    Whether `leaf` is an instance of `types.Density2DArray`.
                """
                return isinstance(leaf, types.Density2DArray)
         
     | 
| 
       299 
     | 
    
         
            -
             
     | 
| 
       300 
     | 
    
         
            -
             
     | 
| 
       301 
     | 
    
         
            -
            def _to_numpy(params: PyTree) -> NDArray:
                """Ravels all leaves of `params` into one rank-1 `float64` numpy array.

                Args:
                    params: The pytree whose leaves are to be concatenated.

                Returns:
                    A rank-1 numpy array of dtype `float64`, with leaf values ordered as
                    produced by `jax.flatten_util.ravel_pytree`.
                """
                raveled = flatten_util.ravel_pytree(params)[0]  # type: ignore[no-untyped-call]
                return onp.asarray(raveled, dtype=onp.float64)
         
     | 
| 
       305 
     | 
    
         
            -
             
     | 
| 
       306 
     | 
    
         
            -
             
     | 
| 
       307 
     | 
    
         
            -
            def _to_pytree(x_flat: NDArray, params: PyTree) -> PyTree:
                """Rebuilds a pytree shaped like `params` from the flat array `x_flat`.

                The unflattening function comes from `jax.flatten_util.ravel_pytree` applied
                to `params`. The values in `x_flat` are therefore split and reshaped to
                match the leaves of `params`.

                Args:
                    x_flat: Rank-1 numpy array holding the values of the restored pytree.
                    params: Pytree whose structure is replicated in the result.

                Returns:
                    A pytree with the structure of `params` and jax array leaves.
                """
                unflatten = flatten_util.ravel_pytree(params)[1]  # type: ignore[no-untyped-call]
                x_jax = jnp.asarray(x_flat, dtype=float)
                return unflatten(x_jax)
         
     | 
| 
       322 
     | 
    
         
            -
             
     | 
| 
       323 
     | 
    
         
            -
             
     | 
| 
       324 
     | 
    
         
            -
            def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
                """Generates a flat, elementwise bound for the `params`.

                The `bound` can be specified in several ways:
                - `None` or a scalar, which then applies to all arrays in `params`.
                - A pytree with structure matching that of `params`. Each leaf is either
                  `None`, a scalar, or an array matching the shape of the corresponding
                  leaf in `params`.

                The returned bound is a flat list suitable for use with `ScipyLbfgsbState`.

                Args:
                    bound: The pytree of bounds.
                    params: The pytree of parameters.

                Returns:
                    The flat elementwise bound. It is a list with one entry (possibly `None`)
                    for each element of the non-`None` leaves of `params`.

                Raises:
                    ValueError: If the tree structure of `bound` does not match that of
                        `params`, or if a non-scalar leaf of `bound` has a shape differing
                        from the corresponding leaf of `params`.
                """
                if bound is None or onp.isscalar(bound):
                    # Broadcast the single value onto the structure of `params`, treating
                    # custom types as atomic leaves.
                    scalar_bound = bound
                    bound = tree_util.tree_map(
                        lambda _: scalar_bound,
                        params,
                        is_leaf=lambda x: isinstance(x, types.CUSTOM_TYPES),
                    )

                bound_leaves, bound_treedef = tree_util.tree_flatten(
                    bound, is_leaf=lambda x: x is None
                )
                params_leaves = tree_util.tree_leaves(params, is_leaf=lambda x: x is None)

                # `bound` should be a pytree of arrays or `None`, while `params` may
                # include custom pytree nodes. Convert the custom nodes into standard
                # types to facilitate validation that the tree structures match.
                params_treedef = tree_util.tree_structure(
                    tree_util.tree_map(
                        lambda _: 0.0,
                        tree=params,
                        is_leaf=lambda x: x is None or isinstance(x, types.CUSTOM_TYPES),
                    )
                )
                if bound_treedef != params_treedef:  # type: ignore[operator]
                    raise ValueError(
                        f"Tree structure of `bound` and `params` must match, but got "
                        f"{bound_treedef} and {params_treedef}, respectively."
                    )

                bound_flat = []
                for b, p in zip(bound_leaves, params_leaves):
                    if p is None:
                        # `None` parameters contribute no variables to the flat vector.
                        continue
                    if b is None or onp.isscalar(b) or onp.shape(b) == ():
                        # A scalar (or absent) bound applies to every element of `p`.
                        bound_flat += [b] * onp.size(p)
                        continue
                    # Use `onp.shape` rather than `.shape`. Python-scalar `params` leaves
                    # have no `shape` attribute, and should still yield the descriptive
                    # `ValueError` below.
                    b_shape = onp.shape(b)
                    p_shape = onp.shape(p)
                    if b_shape != p_shape:
                        raise ValueError(
                            f"`bound` must be `None`, a scalar, or have shape matching "
                            f"`params`, but got shape {b_shape} when params has shape "
                            f"{p_shape}."
                        )
                    bound_flat += onp.asarray(b).flatten().tolist()

                return bound_flat
         
     | 
| 
       386 
     | 
    
         
            -
             
     | 
| 
       387 
     | 
    
         
            -
             
     | 
| 
       388 
     | 
    
         
            -
            def _example_state(params: PyTree, maxcor: int) -> PyTree:
                """Return an example state for the given `params` and `maxcor`.

                The result is the tuple `(float_params, jax_lbfgsb_state)`:
                - `float_params` mirrors `params`, with every leaf cast to a float jax array.
                - `jax_lbfgsb_state` is a dict of jax arrays filled with zeros/ones. Its
                  shapes are sized for `n` variables, where `n` is the total number of
                  elements in `params`.

                NOTE(review): the array values are placeholders. Presumably only the keys,
                shapes and dtypes matter to callers; confirm against call sites.
                """
                num_vars = flatten_util.ravel_pytree(params)[0].size  # type: ignore[no-untyped-call]
                float_params = tree_util.tree_map(
                    lambda leaf: jnp.asarray(leaf, dtype=float), params
                )
                placeholder_state = {
                    "x": jnp.zeros(num_vars, dtype=float),
                    "_maxcor": jnp.zeros((), dtype=int),
                    "_line_search_max_steps": jnp.zeros((), dtype=int),
                    "_wa": jnp.ones(_wa_size(n=num_vars, maxcor=maxcor), dtype=float),
                    "_iwa": jnp.ones(num_vars * 3, dtype=jnp.int32),  # Fortran int
                    "_task": jnp.zeros(59, dtype=int),
                    "_csave": jnp.zeros(59, dtype=int),
                    "_lsave": jnp.zeros(4, dtype=jnp.int32),  # Fortran int
                    "_isave": jnp.zeros(44, dtype=jnp.int32),  # Fortran int
                    "_dsave": jnp.zeros(29, dtype=float),
                    "_lower_bound": jnp.zeros(num_vars, dtype=float),
                    "_upper_bound": jnp.zeros(num_vars, dtype=float),
                    "_bound_type": jnp.zeros(num_vars, dtype=int),
                }
                return float_params, placeholder_state
         
     | 
| 
       409 
     | 
    
         
            -
             
     | 
| 
       410 
     | 
    
         
            -
             
     | 
| 
       411 
     | 
    
         
            -
            # ------------------------------------------------------------------------------
         
     | 
| 
       412 
     | 
    
         
            -
            # Wrapper for scipy's L-BFGS-B implementation.
         
     | 
| 
       413 
     | 
    
         
            -
            # ------------------------------------------------------------------------------
         
     | 
| 
       414 
     | 
    
         
            -
             
     | 
| 
       415 
     | 
    
         
            -
             
     | 
| 
       416 
     | 
    
         
            -
            @dataclasses.dataclass
         
     | 
| 
       417 
     | 
    
         
            -
            class ScipyLbfgsbState:
         
     | 
| 
       418 
     | 
    
         
            -
                """Stores the state of a scipy L-BFGS-B minimization.
         
     | 
| 
       419 
     | 
    
         
            -
             
     | 
| 
       420 
     | 
    
         
            -
                This class enables optimization with a more functional style, giving the user
         
     | 
| 
       421 
     | 
    
         
            -
                control over the optimization loop. Example usage is as follows:
         
     | 
| 
       422 
     | 
    
         
            -
             
     | 
| 
       423 
     | 
    
         
            -
                    value_fn = lambda x: onp.sum(x**2)
         
     | 
| 
       424 
     | 
    
         
            -
                    grad_fn = lambda x: 2 * x
         
     | 
| 
       425 
     | 
    
         
            -
             
     | 
| 
       426 
     | 
    
         
            -
                    x0 = onp.asarray([0.1, 0.2, 0.3])
         
     | 
| 
       427 
     | 
    
         
            -
                    lb = [None, -1, 0.1]
         
     | 
| 
       428 
     | 
    
         
            -
                    ub = [None, None, None]
         
     | 
| 
       429 
     | 
    
         
            -
                    state = ScipyLbfgsbState.init(
         
     | 
| 
       430 
     | 
    
         
            -
                        x0=x0, lower_bound=lb, upper_bound=ub, maxcor=20
         
     | 
| 
       431 
     | 
    
         
            -
                    )
         
     | 
| 
       432 
     | 
    
         
            -
             
     | 
| 
       433 
     | 
    
         
            -
                    for _ in range(10):
         
     | 
| 
       434 
     | 
    
         
            -
                        value = value_fn(state.x)
         
     | 
| 
       435 
     | 
    
         
            -
                        grad = grad_fn(state.x)
         
     | 
| 
       436 
     | 
    
         
            -
                        state.update(grad, value)
         
     | 
| 
       437 
     | 
    
         
            -
             
     | 
| 
       438 
     | 
    
         
            -
                This example converges with `state.x` equal to `(0, 0, 0.1)` and value equal
         
     | 
| 
       439 
     | 
    
         
            -
                to `0.01`.
         
     | 
| 
       440 
     | 
    
         
            -
             
     | 
| 
       441 
     | 
    
         
            -
                Attributes:
         
     | 
| 
       442 
     | 
    
         
            -
                    x: The current solution vector.
         
     | 
| 
       443 
     | 
    
         
            -
                """
         
     | 
| 
       444 
     | 
    
         
            -
             
     | 
| 
       445 
     | 
    
         
            -
                x: NDArray
         
     | 
| 
       446 
     | 
    
         
            -
                # Private attributes correspond to internal variables in the `scipy.optimize.
         
     | 
| 
       447 
     | 
    
         
            -
                # lbfgsb._minimize_lbfgsb` function.
         
     | 
| 
       448 
     | 
    
         
            -
                _maxcor: int
         
     | 
| 
       449 
     | 
    
         
            -
                _line_search_max_steps: int
         
     | 
| 
       450 
     | 
    
         
            -
                _wa: NDArray
         
     | 
| 
       451 
     | 
    
         
            -
                _iwa: NDArray
         
     | 
| 
       452 
     | 
    
         
            -
                _task: NDArray
         
     | 
| 
       453 
     | 
    
         
            -
                _csave: NDArray
         
     | 
| 
       454 
     | 
    
         
            -
                _lsave: NDArray
         
     | 
| 
       455 
     | 
    
         
            -
                _isave: NDArray
         
     | 
| 
       456 
     | 
    
         
            -
                _dsave: NDArray
         
     | 
| 
       457 
     | 
    
         
            -
                _lower_bound: NDArray
         
     | 
| 
       458 
     | 
    
         
            -
                _upper_bound: NDArray
         
     | 
| 
       459 
     | 
    
         
            -
                _bound_type: NDArray
         
     | 
| 
       460 
     | 
    
         
            -
             
     | 
| 
       461 
     | 
    
         
            -
                def __post_init__(self) -> None:
         
     | 
| 
       462 
     | 
    
         
            -
                    """Validates the datatypes for all state attributes."""
         
     | 
| 
       463 
     | 
    
         
            -
                    _validate_array_dtype(self.x, onp.float64)
         
     | 
| 
       464 
     | 
    
         
            -
                    _validate_array_dtype(self._wa, onp.float64)
         
     | 
| 
       465 
     | 
    
         
            -
                    _validate_array_dtype(self._iwa, FORTRAN_INT)
         
     | 
| 
       466 
     | 
    
         
            -
                    _validate_array_dtype(self._task, "S60")
         
     | 
| 
       467 
     | 
    
         
            -
                    _validate_array_dtype(self._csave, "S60")
         
     | 
| 
       468 
     | 
    
         
            -
                    _validate_array_dtype(self._lsave, FORTRAN_INT)
         
     | 
| 
       469 
     | 
    
         
            -
                    _validate_array_dtype(self._isave, FORTRAN_INT)
         
     | 
| 
       470 
     | 
    
         
            -
                    _validate_array_dtype(self._dsave, onp.float64)
         
     | 
| 
       471 
     | 
    
         
            -
                    _validate_array_dtype(self._lower_bound, onp.float64)
         
     | 
| 
       472 
     | 
    
         
            -
                    _validate_array_dtype(self._upper_bound, onp.float64)
         
     | 
| 
       473 
     | 
    
         
            -
                    _validate_array_dtype(self._bound_type, int)
         
     | 
| 
       474 
     | 
    
         
            -
             
     | 
| 
       475 
     | 
    
         
            -
                def to_jax(self) -> Dict[str, jnp.ndarray]:
         
     | 
| 
       476 
     | 
    
         
            -
                    """Generates a dictionary of jax arrays defining the state."""
         
     | 
| 
       477 
     | 
    
         
            -
                    return dict(
         
     | 
| 
       478 
     | 
    
         
            -
                        x=jnp.asarray(self.x),
         
     | 
| 
       479 
     | 
    
         
            -
                        _maxcor=jnp.asarray(self._maxcor),
         
     | 
| 
       480 
     | 
    
         
            -
                        _line_search_max_steps=jnp.asarray(self._line_search_max_steps),
         
     | 
| 
       481 
     | 
    
         
            -
                        _wa=jnp.asarray(self._wa),
         
     | 
| 
       482 
     | 
    
         
            -
                        _iwa=jnp.asarray(self._iwa),
         
     | 
| 
       483 
     | 
    
         
            -
                        _task=_array_from_s60_str(self._task),
         
     | 
| 
       484 
     | 
    
         
            -
                        _csave=_array_from_s60_str(self._csave),
         
     | 
| 
       485 
     | 
    
         
            -
                        _lsave=jnp.asarray(self._lsave),
         
     | 
| 
       486 
     | 
    
         
            -
                        _isave=jnp.asarray(self._isave),
         
     | 
| 
       487 
     | 
    
         
            -
                        _dsave=jnp.asarray(self._dsave),
         
     | 
| 
       488 
     | 
    
         
            -
                        _lower_bound=jnp.asarray(self._lower_bound),
         
     | 
| 
       489 
     | 
    
         
            -
                        _upper_bound=jnp.asarray(self._upper_bound),
         
     | 
| 
       490 
     | 
    
         
            -
                        _bound_type=jnp.asarray(self._bound_type),
         
     | 
| 
       491 
     | 
    
         
            -
                    )
         
     | 
| 
       492 
     | 
    
         
            -
             
     | 
| 
       493 
     | 
    
         
            -
                @classmethod
         
     | 
| 
       494 
     | 
    
         
            -
                def from_jax(cls, state_dict: Dict[str, jnp.ndarray]) -> "ScipyLbfgsbState":
         
     | 
| 
       495 
     | 
    
         
            -
                    """Converts a dictionary of jax arrays to a `ScipyLbfgsbState`."""
         
     | 
| 
       496 
     | 
    
         
            -
                    state_dict = copy.deepcopy(state_dict)
         
     | 
| 
       497 
     | 
    
         
            -
                    return ScipyLbfgsbState(
         
     | 
| 
       498 
     | 
    
         
            -
                        x=onp.asarray(state_dict["x"], dtype=onp.float64),
         
     | 
| 
       499 
     | 
    
         
            -
                        _maxcor=int(state_dict["_maxcor"]),
         
     | 
| 
       500 
     | 
    
         
            -
                        _line_search_max_steps=int(state_dict["_line_search_max_steps"]),
         
     | 
| 
       501 
     | 
    
         
            -
                        _wa=onp.asarray(state_dict["_wa"], onp.float64),
         
     | 
| 
       502 
     | 
    
         
            -
                        _iwa=onp.asarray(state_dict["_iwa"], dtype=FORTRAN_INT),
         
     | 
| 
       503 
     | 
    
         
            -
                        _task=_s60_str_from_array(state_dict["_task"]),
         
     | 
| 
       504 
     | 
    
         
            -
                        _csave=_s60_str_from_array(state_dict["_csave"]),
         
     | 
| 
       505 
     | 
    
         
            -
                        _lsave=onp.asarray(state_dict["_lsave"], dtype=FORTRAN_INT),
         
     | 
| 
       506 
     | 
    
         
            -
                        _isave=onp.asarray(state_dict["_isave"], dtype=FORTRAN_INT),
         
     | 
| 
       507 
     | 
    
         
            -
                        _dsave=onp.asarray(state_dict["_dsave"], dtype=onp.float64),
         
     | 
| 
       508 
     | 
    
         
            -
                        _lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
         
     | 
| 
       509 
     | 
    
         
            -
                        _upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
         
     | 
| 
       510 
     | 
    
         
            -
                        _bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
         
     | 
| 
       511 
     | 
    
         
            -
                    )
         
     | 
| 
       512 
     | 
    
         
            -
             
     | 
| 
       513 
     | 
    
         
            -
                @classmethod
         
     | 
| 
       514 
     | 
    
         
            -
                def init(
         
     | 
| 
       515 
     | 
    
         
            -
                    cls,
         
     | 
| 
       516 
     | 
    
         
            -
                    x0: NDArray,
         
     | 
| 
       517 
     | 
    
         
            -
                    lower_bound: ElementwiseBound,
         
     | 
| 
       518 
     | 
    
         
            -
                    upper_bound: ElementwiseBound,
         
     | 
| 
       519 
     | 
    
         
            -
                    maxcor: int,
         
     | 
| 
       520 
     | 
    
         
            -
                    line_search_max_steps: int,
         
     | 
| 
       521 
     | 
    
         
            -
                ) -> "ScipyLbfgsbState":
         
     | 
| 
       522 
     | 
    
         
            -
                    """Initializes the `ScipyLbfgsbState` for `x0`.
         
     | 
| 
       523 
     | 
    
         
            -
             
     | 
| 
       524 
     | 
    
         
            -
                    Args:
         
     | 
| 
       525 
     | 
    
         
            -
                        x0: Array giving the initial solution vector.
         
     | 
| 
       526 
     | 
    
         
            -
                        lower_bound: Array giving the elementwise optional lower bound.
         
     | 
| 
       527 
     | 
    
         
            -
                        upper_bound: Array giving the elementwise optional upper bound.
         
     | 
| 
       528 
     | 
    
         
            -
                        maxcor: The maximum number of variable metric corrections used to define
         
     | 
| 
       529 
     | 
    
         
            -
                            the limited memory matrix, in the L-BFGS-B scheme.
         
     | 
| 
       530 
     | 
    
         
            -
                        line_search_max_steps: The maximum number of steps in the line search.
         
     | 
| 
       531 
     | 
    
         
            -
             
     | 
| 
       532 
     | 
    
         
            -
                    Returns:
         
     | 
| 
       533 
     | 
    
         
            -
                        The `ScipyLbfgsbState`.
         
     | 
| 
       534 
     | 
    
         
            -
                    """
         
     | 
| 
       535 
     | 
    
         
            -
                    x0 = onp.asarray(x0)
         
     | 
| 
       536 
     | 
    
         
            -
                    if x0.ndim > 1:
         
     | 
| 
       537 
     | 
    
         
            -
                        raise ValueError(f"`x0` must be rank-1 but got shape {x0.shape}.")
         
     | 
| 
       538 
     | 
    
         
            -
                    lower_bound = onp.asarray(lower_bound)
         
     | 
| 
       539 
     | 
    
         
            -
                    upper_bound = onp.asarray(upper_bound)
         
     | 
| 
       540 
     | 
    
         
            -
                    if x0.shape != lower_bound.shape or x0.shape != upper_bound.shape:
         
     | 
| 
       541 
     | 
    
         
            -
                        raise ValueError(
         
     | 
| 
       542 
     | 
    
         
            -
                            f"`x0`, `lower_bound`, and `upper_bound` must have matching "
         
     | 
| 
       543 
     | 
    
         
            -
                            f"shape but got shapes {x0.shape}, {lower_bound.shape}, and "
         
     | 
| 
       544 
     | 
    
         
            -
                            f"{upper_bound.shape}, respectively."
         
     | 
| 
       545 
     | 
    
         
            -
                        )
         
     | 
| 
       546 
     | 
    
         
            -
                    if maxcor < 1:
         
     | 
| 
       547 
     | 
    
         
            -
                        raise ValueError(f"`maxcor` must be positive but got {maxcor}.")
         
     | 
| 
       548 
     | 
    
         
            -
             
     | 
| 
       549 
     | 
    
         
            -
                    n = x0.size
         
     | 
| 
       550 
     | 
    
         
            -
                    lower_bound_array, upper_bound_array, bound_type = _configure_bounds(
         
     | 
| 
       551 
     | 
    
         
            -
                        lower_bound, upper_bound
         
     | 
| 
       552 
     | 
    
         
            -
                    )
         
     | 
| 
       553 
     | 
    
         
            -
                    task = onp.zeros(1, "S60")
         
     | 
| 
       554 
     | 
    
         
            -
                    task[:] = TASK_START
         
     | 
| 
       555 
     | 
    
         
            -
             
     | 
| 
       556 
     | 
    
         
            -
                    # See initialization of internal variables in the `lbfgsb._minimize_lbfgsb`
         
     | 
| 
       557 
     | 
    
         
            -
                    # function.
         
     | 
| 
       558 
     | 
    
         
            -
                    wa_size = _wa_size(n=n, maxcor=maxcor)
         
     | 
| 
       559 
     | 
    
         
            -
                    state = ScipyLbfgsbState(
         
     | 
| 
       560 
     | 
    
         
            -
                        x=onp.array(x0, onp.float64),
         
     | 
| 
       561 
     | 
    
         
            -
                        _maxcor=maxcor,
         
     | 
| 
       562 
     | 
    
         
            -
                        _line_search_max_steps=line_search_max_steps,
         
     | 
| 
       563 
     | 
    
         
            -
                        _wa=onp.zeros(wa_size, onp.float64),
         
     | 
| 
       564 
     | 
    
         
            -
                        _iwa=onp.zeros(3 * n, FORTRAN_INT),
         
     | 
| 
       565 
     | 
    
         
            -
                        _task=task,
         
     | 
| 
       566 
     | 
    
         
            -
                        _csave=onp.zeros(1, "S60"),
         
     | 
| 
       567 
     | 
    
         
            -
                        _lsave=onp.zeros(4, FORTRAN_INT),
         
     | 
| 
       568 
     | 
    
         
            -
                        _isave=onp.zeros(44, FORTRAN_INT),
         
     | 
| 
       569 
     | 
    
         
            -
                        _dsave=onp.zeros(29, onp.float64),
         
     | 
| 
       570 
     | 
    
         
            -
                        _lower_bound=lower_bound_array,
         
     | 
| 
       571 
     | 
    
         
            -
                        _upper_bound=upper_bound_array,
         
     | 
| 
       572 
     | 
    
         
            -
                        _bound_type=bound_type,
         
     | 
| 
       573 
     | 
    
         
            -
                    )
         
     | 
| 
       574 
     | 
    
         
            -
                    # The initial state requires an update with zero value and gradient. This
         
     | 
| 
       575 
     | 
    
         
            -
                    # is because the initial task is "START", which does not actually require
         
     | 
| 
       576 
     | 
    
         
            -
                    # value and gradient evaluation.
         
     | 
| 
       577 
     | 
    
         
            -
                    state.update(onp.zeros(x0.shape, onp.float64), onp.zeros((), onp.float64))
         
     | 
| 
       578 
     | 
    
         
            -
                    return state
         
     | 
| 
       579 
     | 
    
         
            -
             
     | 
| 
       580 
     | 
    
         
            -
                def update(
                    self,
                    grad: NDArray,
                    value: NDArray,
                ) -> None:
                    """Performs an in-place update of the `ScipyLbfgsbState`.

                    Hands the current value and gradient, together with the stored state
                    arrays (`x`, bounds, `_wa`, `_iwa`, `_task`, `_csave`, `_lsave`,
                    `_isave`, `_dsave`), to scipy's `setulb` routine. This method returns
                    `None`. After each call it reads the new task from `self._task`, so the
                    state arrays are expected to be modified in place by `setulb`.

                    Args:
                        grad: The function gradient for the current `x`.
                        value: The scalar function value for the current `x`.

                    Raises:
                        ValueError: If `grad.shape` differs from `self.x.shape`, or if
                            `value` is not a scalar (shape `()`).
                    """
                    # `grad` is passed to `setulb` alongside `x`, so their shapes must agree.
                    if grad.shape != self.x.shape:
                        raise ValueError(
                            f"`grad` must have the same shape as attribute `x`, but got shapes "
                            f"{grad.shape} and {self.x.shape}, respectively."
                        )
                    if value.shape != ():
                        raise ValueError(f"`value` must be a scalar but got shape {value.shape}.")

                    # The `setulb` function will sometimes return with a task that does not
                    # require a value and gradient evaluation. In this case we simply call it
                    # again, advancing past such "dummy" steps.
                    # The retry is bounded: after three calls the loop ends even if no call
                    # produced a task starting with `TASK_FG`, and `self._task` then holds
                    # whatever `setulb` last wrote.
                    for _ in range(3):
                        scipy_lbfgsb.setulb(
                            m=self._maxcor,
                            x=self.x,
                            l=self._lower_bound,
                            u=self._upper_bound,
                            nbd=self._bound_type,
                            f=value,
                            g=grad,
                            factr=UPDATE_FACTR,
                            pgtol=UPDATE_PGTOL,
                            wa=self._wa,
                            iwa=self._iwa,
                            task=self._task,
                            iprint=UPDATE_IPRINT,
                            csave=self._csave,
                            lsave=self._lsave,
                            isave=self._isave,
                            dsave=self._dsave,
                            maxls=self._line_search_max_steps,
                        )
                        # Compare the raw bytes of the task string against the `TASK_FG`
                        # prefix; stop as soon as a value/gradient evaluation is requested.
                        task_str = self._task.tobytes()
                        if task_str.startswith(TASK_FG):
                            break
         
     | 
| 
       626 
     | 
    
         
            -
             
     | 
| 
       627 
     | 
    
         
            -
             
     | 
| 
       628 
     | 
    
         
            -
            def _wa_size(n: int, maxcor: int) -> int:
         
     | 
| 
       629 
     | 
    
         
            -
                """Return the size of the `wa` attribute of lbfgsb state."""
         
     | 
| 
       630 
     | 
    
         
            -
                return 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
         
     | 
| 
       631 
     | 
    
         
            -
             
     | 
| 
       632 
     | 
    
         
            -
             
     | 
| 
       633 
     | 
    
         
            -
            def _validate_array_dtype(x: NDArray, dtype: Union[type, str]) -> None:
         
     | 
| 
       634 
     | 
    
         
            -
                """Validates that `x` is an array with the specified `dtype`."""
         
     | 
| 
       635 
     | 
    
         
            -
                if not isinstance(x, onp.ndarray):
         
     | 
| 
       636 
     | 
    
         
            -
                    raise ValueError(f"`x` must be an `onp.ndarray` but got {type(x)}")
         
     | 
| 
       637 
     | 
    
         
            -
                if x.dtype != dtype:
         
     | 
| 
       638 
     | 
    
         
            -
                    raise ValueError(f"`x` must have dtype {dtype} but got {x.dtype}")
         
     | 
| 
       639 
     | 
    
         
            -
             
     | 
| 
       640 
     | 
    
         
            -
             
     | 
| 
       641 
     | 
    
         
            -
            def _configure_bounds(
                lower_bound: ElementwiseBound,
                upper_bound: ElementwiseBound,
            ) -> Tuple[NDArray, NDArray, NDArray]:
                """Convert elementwise optional bounds into the arrays used by L-BFGS-B.

                Args:
                    lower_bound: Elementwise lower bounds; an element of `None` means that
                        element has no lower bound.
                    upper_bound: Elementwise upper bounds; an element of `None` means that
                        element has no upper bound.

                Returns:
                    A tuple `(lower, upper, bound_type)`. `lower` and `upper` are float64
                    arrays in which every `None` is replaced by `0.0`. `bound_type` holds,
                    for each element pair, the `BOUNDS_MAP` entry keyed by
                    `(lower is None, upper is None)`.
                """

                def _zero_if_none(bound):
                    return 0.0 if bound is None else bound

                bound_type = [
                    BOUNDS_MAP[(lo is None, hi is None)]
                    for lo, hi in zip(lower_bound, upper_bound)
                ]
                lower_values = list(map(_zero_if_none, lower_bound))
                upper_values = list(map(_zero_if_none, upper_bound))
                lower_array = onp.asarray(lower_values, onp.float64)
                upper_array = onp.asarray(upper_values, onp.float64)
                return lower_array, upper_array, onp.asarray(bound_type)
         
     | 
| 
       657 
     | 
    
         
            -
             
     | 
| 
       658 
     | 
    
         
            -
             
     | 
| 
       659 
     | 
    
         
            -
            def _array_from_s60_str(s60_str: NDArray) -> jnp.ndarray:
         
     | 
| 
       660 
     | 
    
         
            -
                """Return a jax array for a numpy s60 string."""
         
     | 
| 
       661 
     | 
    
         
            -
                assert s60_str.shape == (1,)
         
     | 
| 
       662 
     | 
    
         
            -
                chars = [int(o) for o in s60_str[0]]
         
     | 
| 
       663 
     | 
    
         
            -
                chars.extend([32] * (59 - len(chars)))
         
     | 
| 
       664 
     | 
    
         
            -
                return jnp.asarray(chars, dtype=int)
         
     | 
| 
       665 
     | 
    
         
            -
             
     | 
| 
       666 
     | 
    
         
            -
             
     | 
| 
       667 
     | 
    
         
            -
            def _s60_str_from_array(array: jnp.ndarray) -> NDArray:
         
     | 
| 
       668 
     | 
    
         
            -
                """Return a numpy s60 string for a jax array."""
         
     | 
| 
       669 
     | 
    
         
            -
                return onp.asarray(
         
     | 
| 
       670 
     | 
    
         
            -
                    [b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
         
     | 
| 
       671 
     | 
    
         
            -
                    dtype="S60",
         
     | 
| 
       672 
     | 
    
         
            -
                )
         
     | 
| 
         @@ -1,21 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            MIT License
         
     | 
| 
       2 
     | 
    
         
            -
             
     | 
| 
       3 
     | 
    
         
            -
            Copyright (c) 2023 The INVRS-IO authors.
         
     | 
| 
       4 
     | 
    
         
            -
             
     | 
| 
       5 
     | 
    
         
            -
            Permission is hereby granted, free of charge, to any person obtaining a copy
         
     | 
| 
       6 
     | 
    
         
            -
            of this software and associated documentation files (the "Software"), to deal
         
     | 
| 
       7 
     | 
    
         
            -
            in the Software without restriction, including without limitation the rights
         
     | 
| 
       8 
     | 
    
         
            -
            to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
         
     | 
| 
       9 
     | 
    
         
            -
            copies of the Software, and to permit persons to whom the Software is
         
     | 
| 
       10 
     | 
    
         
            -
            furnished to do so, subject to the following conditions:
         
     | 
| 
       11 
     | 
    
         
            -
             
     | 
| 
       12 
     | 
    
         
            -
            The above copyright notice and this permission notice shall be included in all
         
     | 
| 
       13 
     | 
    
         
            -
            copies or substantial portions of the Software.
         
     | 
| 
       14 
     | 
    
         
            -
             
     | 
| 
       15 
     | 
    
         
            -
            THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
         
     | 
| 
       16 
     | 
    
         
            -
            IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
         
     | 
| 
       17 
     | 
    
         
            -
            FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
         
     | 
| 
       18 
     | 
    
         
            -
            AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
         
     | 
| 
       19 
     | 
    
         
            -
            LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
         
     | 
| 
       20 
     | 
    
         
            -
            OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
         
     | 
| 
       21 
     | 
    
         
            -
            SOFTWARE.
         
     |