invrs-opt 0.10.4__py3-none-any.whl → 0.10.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- invrs_opt/__init__.py +1 -1
- invrs_opt/optimizers/lbfgsb.py +76 -102
- {invrs_opt-0.10.4.dist-info → invrs_opt-0.10.5.dist-info}/METADATA +2 -2
- {invrs_opt-0.10.4.dist-info → invrs_opt-0.10.5.dist-info}/RECORD +7 -7
- {invrs_opt-0.10.4.dist-info → invrs_opt-0.10.5.dist-info}/LICENSE +0 -0
- {invrs_opt-0.10.4.dist-info → invrs_opt-0.10.5.dist-info}/WHEEL +0 -0
- {invrs_opt-0.10.4.dist-info → invrs_opt-0.10.5.dist-info}/top_level.txt +0 -0
invrs_opt/__init__.py
CHANGED
invrs_opt/optimizers/lbfgsb.py
CHANGED
@@ -56,8 +56,6 @@ BOUNDS_MAP: Dict[Tuple[bool, bool], int] = {
|
|
56
56
|
(True, False): 3, # Only the lower bound is `None`.
|
57
57
|
}
|
58
58
|
|
59
|
-
FORTRAN_INT = scipy_lbfgsb.types.intvar.dtype
|
60
|
-
|
61
59
|
if version.Version(jax.__version__) > version.Version("0.4.31"):
|
62
60
|
callback_sequential = functools.partial(jax.pure_callback, vmap_method="sequential")
|
63
61
|
else:
|
@@ -638,19 +636,19 @@ def _example_state(params: PyTree, maxcor: int) -> PyTree:
|
|
638
636
|
x=jnp.zeros(n, dtype=float),
|
639
637
|
converged=jnp.asarray(False),
|
640
638
|
_maxcor=jnp.zeros((), dtype=int),
|
641
|
-
|
639
|
+
_lower_bound=jnp.zeros(n, dtype=float),
|
640
|
+
_upper_bound=jnp.zeros(n, dtype=float),
|
641
|
+
_bound_type=jnp.zeros(n, dtype=jnp.int32),
|
642
642
|
_ftol=jnp.zeros((), dtype=float),
|
643
643
|
_gtol=jnp.zeros((), dtype=float),
|
644
644
|
_wa=jnp.ones(_wa_size(n=n, maxcor=maxcor), dtype=float),
|
645
|
-
_iwa=jnp.ones(n * 3, dtype=jnp.int32),
|
646
|
-
_task=jnp.zeros(
|
647
|
-
|
648
|
-
|
649
|
-
_isave=jnp.zeros(44, dtype=jnp.int32), # Fortran int
|
645
|
+
_iwa=jnp.ones(n * 3, dtype=jnp.int32),
|
646
|
+
_task=jnp.zeros(2, dtype=jnp.int32),
|
647
|
+
_lsave=jnp.zeros(4, dtype=jnp.int32),
|
648
|
+
_isave=jnp.zeros(44, dtype=jnp.int32),
|
650
649
|
_dsave=jnp.zeros(29, dtype=float),
|
651
|
-
|
652
|
-
|
653
|
-
_bound_type=jnp.zeros(n, dtype=int),
|
650
|
+
_ln_task=jnp.zeros(2, dtype=jnp.int32),
|
651
|
+
_line_search_max_steps=jnp.zeros((), dtype=int),
|
654
652
|
)
|
655
653
|
return float_params, example_jax_lbfgsb_state
|
656
654
|
|
@@ -691,36 +689,37 @@ class ScipyLbfgsbState:
|
|
691
689
|
|
692
690
|
x: NDArray
|
693
691
|
converged: NDArray
|
694
|
-
# Private attributes correspond to internal variables in the
|
695
|
-
# lbfgsb._minimize_lbfgsb` function.
|
692
|
+
# Private attributes correspond to internal variables in the
|
693
|
+
# `scipy.optimize.lbfgsb._minimize_lbfgsb` function.
|
696
694
|
_maxcor: int
|
697
|
-
|
695
|
+
_lower_bound: NDArray
|
696
|
+
_upper_bound: NDArray
|
697
|
+
_bound_type: NDArray
|
698
698
|
_ftol: NDArray
|
699
699
|
_gtol: NDArray
|
700
700
|
_wa: NDArray
|
701
701
|
_iwa: NDArray
|
702
702
|
_task: NDArray
|
703
|
-
_csave: NDArray
|
704
703
|
_lsave: NDArray
|
705
704
|
_isave: NDArray
|
706
705
|
_dsave: NDArray
|
707
|
-
|
708
|
-
|
709
|
-
_bound_type: NDArray
|
706
|
+
_line_search_max_steps: int
|
707
|
+
_ln_task: NDArray
|
710
708
|
|
711
709
|
def __post_init__(self) -> None:
|
712
710
|
"""Validates the datatypes for all state attributes."""
|
713
|
-
_validate_array_dtype(self.x, onp.float64)
|
714
|
-
_validate_array_dtype(self.
|
715
|
-
_validate_array_dtype(self.
|
716
|
-
_validate_array_dtype(self.
|
717
|
-
_validate_array_dtype(self.
|
718
|
-
_validate_array_dtype(self.
|
719
|
-
_validate_array_dtype(self.
|
720
|
-
_validate_array_dtype(self.
|
721
|
-
_validate_array_dtype(self.
|
722
|
-
_validate_array_dtype(self.
|
723
|
-
_validate_array_dtype(self.
|
711
|
+
_validate_array_dtype("x", self.x, onp.float64)
|
712
|
+
_validate_array_dtype("_lower_bound", self._lower_bound, onp.float64)
|
713
|
+
_validate_array_dtype("_upper_bound", self._upper_bound, onp.float64)
|
714
|
+
_validate_array_dtype("_ftol", self._ftol, onp.float64)
|
715
|
+
_validate_array_dtype("_gtol", self._gtol, onp.float64)
|
716
|
+
_validate_array_dtype("_wa", self._wa, onp.float64)
|
717
|
+
_validate_array_dtype("_iwa", self._iwa, onp.int32)
|
718
|
+
_validate_array_dtype("_task", self._task, onp.int32)
|
719
|
+
_validate_array_dtype("_lsave", self._lsave, onp.int32)
|
720
|
+
_validate_array_dtype("_isave", self._isave, onp.int32)
|
721
|
+
_validate_array_dtype("_dsave", self._dsave, onp.float64)
|
722
|
+
_validate_array_dtype("_ln_task", self._ln_task, onp.int32)
|
724
723
|
|
725
724
|
def to_dict(self) -> NumpyLbfgsbDict:
|
726
725
|
"""Generates a dictionary of jax arrays defining the state."""
|
@@ -728,19 +727,19 @@ class ScipyLbfgsbState:
|
|
728
727
|
x=onp.asarray(self.x),
|
729
728
|
converged=onp.asarray(self.converged),
|
730
729
|
_maxcor=onp.asarray(self._maxcor),
|
731
|
-
|
730
|
+
_lower_bound=onp.asarray(self._lower_bound),
|
731
|
+
_upper_bound=onp.asarray(self._upper_bound),
|
732
|
+
_bound_type=onp.asarray(self._bound_type),
|
732
733
|
_ftol=onp.asarray(self._ftol),
|
733
734
|
_gtol=onp.asarray(self._gtol),
|
734
735
|
_wa=onp.asarray(self._wa),
|
735
736
|
_iwa=onp.asarray(self._iwa),
|
736
|
-
_task=
|
737
|
-
_csave=_array_from_s60_str(self._csave),
|
737
|
+
_task=onp.asarray(self._task),
|
738
738
|
_lsave=onp.asarray(self._lsave),
|
739
739
|
_isave=onp.asarray(self._isave),
|
740
740
|
_dsave=onp.asarray(self._dsave),
|
741
|
-
|
742
|
-
|
743
|
-
_bound_type=onp.asarray(self._bound_type),
|
741
|
+
_line_search_max_steps=onp.asarray(self._line_search_max_steps),
|
742
|
+
_ln_task=onp.asarray(self._ln_task),
|
744
743
|
)
|
745
744
|
|
746
745
|
@classmethod
|
@@ -750,19 +749,19 @@ class ScipyLbfgsbState:
|
|
750
749
|
x=onp.array(state_dict["x"], dtype=onp.float64),
|
751
750
|
converged=onp.asarray(state_dict["converged"], dtype=bool),
|
752
751
|
_maxcor=int(state_dict["_maxcor"]),
|
753
|
-
|
752
|
+
_lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
|
753
|
+
_upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
|
754
|
+
_bound_type=onp.asarray(state_dict["_bound_type"], dtype=onp.int32),
|
754
755
|
_ftol=onp.asarray(state_dict["_ftol"], dtype=onp.float64),
|
755
756
|
_gtol=onp.asarray(state_dict["_gtol"], dtype=onp.float64),
|
756
757
|
_wa=onp.array(state_dict["_wa"], onp.float64),
|
757
|
-
_iwa=onp.array(state_dict["_iwa"], dtype=
|
758
|
-
_task=
|
759
|
-
|
760
|
-
|
761
|
-
_isave=onp.array(state_dict["_isave"], dtype=FORTRAN_INT),
|
758
|
+
_iwa=onp.array(state_dict["_iwa"], dtype=onp.int32),
|
759
|
+
_task=onp.asarray(state_dict["_task"], dtype=onp.int32),
|
760
|
+
_lsave=onp.array(state_dict["_lsave"], dtype=onp.int32),
|
761
|
+
_isave=onp.array(state_dict["_isave"], dtype=onp.int32),
|
762
762
|
_dsave=onp.array(state_dict["_dsave"], dtype=onp.float64),
|
763
|
-
|
764
|
-
|
765
|
-
_bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
|
763
|
+
_line_search_max_steps=int(state_dict["_line_search_max_steps"]),
|
764
|
+
_ln_task=onp.asarray(state_dict["_ln_task"], onp.int32),
|
766
765
|
)
|
767
766
|
|
768
767
|
@classmethod
|
@@ -792,7 +791,6 @@ class ScipyLbfgsbState:
|
|
792
791
|
Returns:
|
793
792
|
The `ScipyLbfgsbState`.
|
794
793
|
"""
|
795
|
-
x0 = onp.asarray(x0)
|
796
794
|
if x0.ndim > 1:
|
797
795
|
raise ValueError(f"`x0` must be rank-1 but got shape {x0.shape}.")
|
798
796
|
lower_bound = onp.asarray(lower_bound)
|
@@ -810,8 +808,6 @@ class ScipyLbfgsbState:
|
|
810
808
|
lower_bound_array, upper_bound_array, bound_type = _configure_bounds(
|
811
809
|
lower_bound, upper_bound
|
812
810
|
)
|
813
|
-
task = onp.zeros(1, "S60")
|
814
|
-
task[:] = TASK_START
|
815
811
|
|
816
812
|
# See initialization of internal variables in the `lbfgsb._minimize_lbfgsb`
|
817
813
|
# function.
|
@@ -820,19 +816,19 @@ class ScipyLbfgsbState:
|
|
820
816
|
x=onp.array(x0, onp.float64),
|
821
817
|
converged=onp.asarray(False),
|
822
818
|
_maxcor=maxcor,
|
823
|
-
|
819
|
+
_lower_bound=lower_bound_array,
|
820
|
+
_upper_bound=upper_bound_array,
|
821
|
+
_bound_type=bound_type,
|
824
822
|
_ftol=onp.asarray(ftol, onp.float64),
|
825
823
|
_gtol=onp.asarray(gtol, onp.float64),
|
826
824
|
_wa=onp.zeros(wa_size, onp.float64),
|
827
|
-
_iwa=onp.zeros(3 * n,
|
828
|
-
_task=
|
829
|
-
|
830
|
-
|
831
|
-
_isave=onp.zeros(44, FORTRAN_INT),
|
825
|
+
_iwa=onp.zeros(3 * n, onp.int32),
|
826
|
+
_task=onp.zeros(2, onp.int32),
|
827
|
+
_lsave=onp.zeros(4, onp.int32),
|
828
|
+
_isave=onp.zeros(44, onp.int32),
|
832
829
|
_dsave=onp.zeros(29, onp.float64),
|
833
|
-
|
834
|
-
|
835
|
-
_bound_type=bound_type,
|
830
|
+
_line_search_max_steps=line_search_max_steps,
|
831
|
+
_ln_task=onp.zeros(2, onp.int32),
|
836
832
|
)
|
837
833
|
# The initial state requires an update with zero value and gradient. This
|
838
834
|
# is because the initial task is "START", which does not actually require
|
@@ -840,11 +836,7 @@ class ScipyLbfgsbState:
|
|
840
836
|
state.update(onp.zeros(x0.shape, onp.float64), onp.zeros((), onp.float64))
|
841
837
|
return state
|
842
838
|
|
843
|
-
def update(
|
844
|
-
self,
|
845
|
-
grad: NDArray,
|
846
|
-
value: NDArray,
|
847
|
-
) -> None:
|
839
|
+
def update(self, grad: NDArray, value: NDArray) -> None:
|
848
840
|
"""Performs an in-place update of the `ScipyLbfgsbState` if not converged.
|
849
841
|
|
850
842
|
Args:
|
@@ -866,29 +858,27 @@ class ScipyLbfgsbState:
|
|
866
858
|
# again, advancing past such "dummy" steps.
|
867
859
|
for _ in range(3):
|
868
860
|
scipy_lbfgsb.setulb(
|
869
|
-
|
870
|
-
|
871
|
-
|
872
|
-
|
873
|
-
|
874
|
-
|
875
|
-
|
876
|
-
|
877
|
-
|
878
|
-
|
879
|
-
|
880
|
-
|
881
|
-
|
882
|
-
|
883
|
-
|
884
|
-
|
885
|
-
|
886
|
-
maxls=self._line_search_max_steps,
|
861
|
+
self._maxcor, # m
|
862
|
+
self.x, # x
|
863
|
+
self._lower_bound, # low_bnd
|
864
|
+
self._upper_bound, # upper_bnd
|
865
|
+
self._bound_type, # nbnd
|
866
|
+
value, # f
|
867
|
+
grad, # g
|
868
|
+
self._ftol / onp.finfo(float).eps, # factr
|
869
|
+
self._gtol, # pgtol
|
870
|
+
self._wa, # wa
|
871
|
+
self._iwa, # iwa
|
872
|
+
self._task, # task
|
873
|
+
self._lsave, # lsave
|
874
|
+
self._isave, # isave
|
875
|
+
self._dsave, # dsave
|
876
|
+
self._line_search_max_steps, # maxls
|
877
|
+
self._ln_task, # ln_task
|
887
878
|
)
|
888
|
-
|
889
|
-
if task_str.startswith(TASK_CONVERGED):
|
879
|
+
if self._task[0] == 4:
|
890
880
|
self.converged = onp.asarray(True)
|
891
|
-
if
|
881
|
+
if self._task[0] == 3:
|
892
882
|
break
|
893
883
|
|
894
884
|
|
@@ -897,12 +887,12 @@ def _wa_size(n: int, maxcor: int) -> int:
|
|
897
887
|
return 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
|
898
888
|
|
899
889
|
|
900
|
-
def _validate_array_dtype(x: NDArray, dtype:
|
890
|
+
def _validate_array_dtype(name: str, x: NDArray, dtype: type) -> None:
|
901
891
|
"""Validates that `x` is an array with the specified `dtype`."""
|
902
892
|
if not isinstance(x, onp.ndarray):
|
903
|
-
raise ValueError(f"`
|
893
|
+
raise ValueError(f"`{name}` must be an `onp.ndarray` but got {type(x)}")
|
904
894
|
if x.dtype != dtype:
|
905
|
-
raise ValueError(f"`
|
895
|
+
raise ValueError(f"`{name}` must have dtype {dtype} but got {x.dtype}")
|
906
896
|
|
907
897
|
|
908
898
|
def _configure_bounds(
|
@@ -919,21 +909,5 @@ def _configure_bounds(
|
|
919
909
|
return (
|
920
910
|
onp.asarray(lower_bound_array, onp.float64),
|
921
911
|
onp.asarray(upper_bound_array, onp.float64),
|
922
|
-
onp.asarray(bound_type),
|
923
|
-
)
|
924
|
-
|
925
|
-
|
926
|
-
def _array_from_s60_str(s60_str: NDArray) -> NDArray:
|
927
|
-
"""Return a jax array for a numpy s60 string."""
|
928
|
-
assert s60_str.shape == (1,)
|
929
|
-
chars = [int(o) for o in s60_str[0]]
|
930
|
-
chars.extend([32] * (59 - len(chars)))
|
931
|
-
return onp.asarray(chars, dtype=int)
|
932
|
-
|
933
|
-
|
934
|
-
def _s60_str_from_array(array: NDArray) -> NDArray:
|
935
|
-
"""Return a numpy s60 string for a jax array."""
|
936
|
-
return onp.asarray(
|
937
|
-
[b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
|
938
|
-
dtype="S60",
|
912
|
+
onp.asarray(bound_type, onp.int32),
|
939
913
|
)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: invrs_opt
|
3
|
-
Version: 0.10.
|
3
|
+
Version: 0.10.5
|
4
4
|
Summary: Algorithms for inverse design
|
5
5
|
Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
6
6
|
Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
@@ -518,7 +518,7 @@ Requires-Dist: jaxlib
|
|
518
518
|
Requires-Dist: numpy
|
519
519
|
Requires-Dist: requests
|
520
520
|
Requires-Dist: optax
|
521
|
-
Requires-Dist: scipy
|
521
|
+
Requires-Dist: scipy>=1.15.0
|
522
522
|
Requires-Dist: totypes
|
523
523
|
Requires-Dist: types-requests
|
524
524
|
Provides-Extra: tests
|
@@ -1,11 +1,11 @@
|
|
1
|
-
invrs_opt/__init__.py,sha256=
|
1
|
+
invrs_opt/__init__.py,sha256=b2AnOoyTVD9z6CZEjHc3_S0un9BDAXTz8vv8yWtw0AE,586
|
2
2
|
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
invrs_opt/experimental/client.py,sha256=tbtH13FrA65XmTZfTO71CxJ78jeAEj3Zf85R-MTwbiU,4909
|
5
5
|
invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
|
6
6
|
invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
7
7
|
invrs_opt/optimizers/base.py,sha256=uFfkN2LwWzAtwh6ktwWNy2iHNOY-sW3JzI46iSFkgok,1306
|
8
|
-
invrs_opt/optimizers/lbfgsb.py,sha256=
|
8
|
+
invrs_opt/optimizers/lbfgsb.py,sha256=WP6ouVtLaXSwJBh7CSzWR7rnRdHZuSmOr57TKF4UxMg,36659
|
9
9
|
invrs_opt/optimizers/wrapped_optax.py,sha256=781-8v_TlHsGaQF9Se9_iOEvtOLOr-BesTLudYarSlg,13685
|
10
10
|
invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
11
|
invrs_opt/parameterization/base.py,sha256=jSwrEO86lGkYQG5gWsHvcIMWpZnnbdiKpn--2qaU02g,5362
|
@@ -13,8 +13,8 @@ invrs_opt/parameterization/filter_project.py,sha256=XL3HTEBLrF-q_75TjhOWLNdfUOSE
|
|
13
13
|
invrs_opt/parameterization/gaussian_levelset.py,sha256=-6foekLTFoZDtMKuoMEvdxMJt0_zTxrKNJo0Vn-Rv80,26073
|
14
14
|
invrs_opt/parameterization/pixel.py,sha256=YWkyBhfYtzI8cQ-M90PAZqRAbabwVaUh0UiYIGegQHI,1955
|
15
15
|
invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
|
16
|
-
invrs_opt-0.10.
|
17
|
-
invrs_opt-0.10.
|
18
|
-
invrs_opt-0.10.
|
19
|
-
invrs_opt-0.10.
|
20
|
-
invrs_opt-0.10.
|
16
|
+
invrs_opt-0.10.5.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
|
17
|
+
invrs_opt-0.10.5.dist-info/METADATA,sha256=JV6RdTh63uMDr2ae3PedC4BIxz54G8twEBGd38hKlZg,32816
|
18
|
+
invrs_opt-0.10.5.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
19
|
+
invrs_opt-0.10.5.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
20
|
+
invrs_opt-0.10.5.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|