invrs-opt 0.10.4__py3-none-any.whl → 0.10.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
invrs_opt/__init__.py CHANGED
@@ -3,7 +3,7 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.10.4"
6
+ __version__ = "v0.10.5"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
9
  from invrs_opt import parameterization as parameterization
@@ -56,8 +56,6 @@ BOUNDS_MAP: Dict[Tuple[bool, bool], int] = {
56
56
  (True, False): 3, # Only the lower bound is `None`.
57
57
  }
58
58
 
59
- FORTRAN_INT = scipy_lbfgsb.types.intvar.dtype
60
-
61
59
  if version.Version(jax.__version__) > version.Version("0.4.31"):
62
60
  callback_sequential = functools.partial(jax.pure_callback, vmap_method="sequential")
63
61
  else:
@@ -638,19 +636,19 @@ def _example_state(params: PyTree, maxcor: int) -> PyTree:
638
636
  x=jnp.zeros(n, dtype=float),
639
637
  converged=jnp.asarray(False),
640
638
  _maxcor=jnp.zeros((), dtype=int),
641
- _line_search_max_steps=jnp.zeros((), dtype=int),
639
+ _lower_bound=jnp.zeros(n, dtype=float),
640
+ _upper_bound=jnp.zeros(n, dtype=float),
641
+ _bound_type=jnp.zeros(n, dtype=jnp.int32),
642
642
  _ftol=jnp.zeros((), dtype=float),
643
643
  _gtol=jnp.zeros((), dtype=float),
644
644
  _wa=jnp.ones(_wa_size(n=n, maxcor=maxcor), dtype=float),
645
- _iwa=jnp.ones(n * 3, dtype=jnp.int32), # Fortran int
646
- _task=jnp.zeros(59, dtype=int),
647
- _csave=jnp.zeros(59, dtype=int),
648
- _lsave=jnp.zeros(4, dtype=jnp.int32), # Fortran int
649
- _isave=jnp.zeros(44, dtype=jnp.int32), # Fortran int
645
+ _iwa=jnp.ones(n * 3, dtype=jnp.int32),
646
+ _task=jnp.zeros(2, dtype=jnp.int32),
647
+ _lsave=jnp.zeros(4, dtype=jnp.int32),
648
+ _isave=jnp.zeros(44, dtype=jnp.int32),
650
649
  _dsave=jnp.zeros(29, dtype=float),
651
- _lower_bound=jnp.zeros(n, dtype=float),
652
- _upper_bound=jnp.zeros(n, dtype=float),
653
- _bound_type=jnp.zeros(n, dtype=int),
650
+ _ln_task=jnp.zeros(2, dtype=jnp.int32),
651
+ _line_search_max_steps=jnp.zeros((), dtype=int),
654
652
  )
655
653
  return float_params, example_jax_lbfgsb_state
656
654
 
@@ -691,36 +689,37 @@ class ScipyLbfgsbState:
691
689
 
692
690
  x: NDArray
693
691
  converged: NDArray
694
- # Private attributes correspond to internal variables in the `scipy.optimize.
695
- # lbfgsb._minimize_lbfgsb` function.
692
+ # Private attributes correspond to internal variables in the
693
+ # `scipy.optimize.lbfgsb._minimize_lbfgsb` function.
696
694
  _maxcor: int
697
- _line_search_max_steps: int
695
+ _lower_bound: NDArray
696
+ _upper_bound: NDArray
697
+ _bound_type: NDArray
698
698
  _ftol: NDArray
699
699
  _gtol: NDArray
700
700
  _wa: NDArray
701
701
  _iwa: NDArray
702
702
  _task: NDArray
703
- _csave: NDArray
704
703
  _lsave: NDArray
705
704
  _isave: NDArray
706
705
  _dsave: NDArray
707
- _lower_bound: NDArray
708
- _upper_bound: NDArray
709
- _bound_type: NDArray
706
+ _line_search_max_steps: int
707
+ _ln_task: NDArray
710
708
 
711
709
  def __post_init__(self) -> None:
712
710
  """Validates the datatypes for all state attributes."""
713
- _validate_array_dtype(self.x, onp.float64)
714
- _validate_array_dtype(self._wa, onp.float64)
715
- _validate_array_dtype(self._iwa, FORTRAN_INT)
716
- _validate_array_dtype(self._task, "S60")
717
- _validate_array_dtype(self._csave, "S60")
718
- _validate_array_dtype(self._lsave, FORTRAN_INT)
719
- _validate_array_dtype(self._isave, FORTRAN_INT)
720
- _validate_array_dtype(self._dsave, onp.float64)
721
- _validate_array_dtype(self._lower_bound, onp.float64)
722
- _validate_array_dtype(self._upper_bound, onp.float64)
723
- _validate_array_dtype(self._bound_type, int)
711
+ _validate_array_dtype("x", self.x, onp.float64)
712
+ _validate_array_dtype("_lower_bound", self._lower_bound, onp.float64)
713
+ _validate_array_dtype("_upper_bound", self._upper_bound, onp.float64)
714
+ _validate_array_dtype("_ftol", self._ftol, onp.float64)
715
+ _validate_array_dtype("_gtol", self._gtol, onp.float64)
716
+ _validate_array_dtype("_wa", self._wa, onp.float64)
717
+ _validate_array_dtype("_iwa", self._iwa, onp.int32)
718
+ _validate_array_dtype("_task", self._task, onp.int32)
719
+ _validate_array_dtype("_lsave", self._lsave, onp.int32)
720
+ _validate_array_dtype("_isave", self._isave, onp.int32)
721
+ _validate_array_dtype("_dsave", self._dsave, onp.float64)
722
+ _validate_array_dtype("_ln_task", self._ln_task, onp.int32)
724
723
 
725
724
  def to_dict(self) -> NumpyLbfgsbDict:
726
725
  """Generates a dictionary of jax arrays defining the state."""
@@ -728,19 +727,19 @@ class ScipyLbfgsbState:
728
727
  x=onp.asarray(self.x),
729
728
  converged=onp.asarray(self.converged),
730
729
  _maxcor=onp.asarray(self._maxcor),
731
- _line_search_max_steps=onp.asarray(self._line_search_max_steps),
730
+ _lower_bound=onp.asarray(self._lower_bound),
731
+ _upper_bound=onp.asarray(self._upper_bound),
732
+ _bound_type=onp.asarray(self._bound_type),
732
733
  _ftol=onp.asarray(self._ftol),
733
734
  _gtol=onp.asarray(self._gtol),
734
735
  _wa=onp.asarray(self._wa),
735
736
  _iwa=onp.asarray(self._iwa),
736
- _task=_array_from_s60_str(self._task),
737
- _csave=_array_from_s60_str(self._csave),
737
+ _task=onp.asarray(self._task),
738
738
  _lsave=onp.asarray(self._lsave),
739
739
  _isave=onp.asarray(self._isave),
740
740
  _dsave=onp.asarray(self._dsave),
741
- _lower_bound=onp.asarray(self._lower_bound),
742
- _upper_bound=onp.asarray(self._upper_bound),
743
- _bound_type=onp.asarray(self._bound_type),
741
+ _line_search_max_steps=onp.asarray(self._line_search_max_steps),
742
+ _ln_task=onp.asarray(self._ln_task),
744
743
  )
745
744
 
746
745
  @classmethod
@@ -750,19 +749,19 @@ class ScipyLbfgsbState:
750
749
  x=onp.array(state_dict["x"], dtype=onp.float64),
751
750
  converged=onp.asarray(state_dict["converged"], dtype=bool),
752
751
  _maxcor=int(state_dict["_maxcor"]),
753
- _line_search_max_steps=int(state_dict["_line_search_max_steps"]),
752
+ _lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
753
+ _upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
754
+ _bound_type=onp.asarray(state_dict["_bound_type"], dtype=onp.int32),
754
755
  _ftol=onp.asarray(state_dict["_ftol"], dtype=onp.float64),
755
756
  _gtol=onp.asarray(state_dict["_gtol"], dtype=onp.float64),
756
757
  _wa=onp.array(state_dict["_wa"], onp.float64),
757
- _iwa=onp.array(state_dict["_iwa"], dtype=FORTRAN_INT),
758
- _task=_s60_str_from_array(onp.asarray(state_dict["_task"])),
759
- _csave=_s60_str_from_array(onp.asarray(state_dict["_csave"])),
760
- _lsave=onp.array(state_dict["_lsave"], dtype=FORTRAN_INT),
761
- _isave=onp.array(state_dict["_isave"], dtype=FORTRAN_INT),
758
+ _iwa=onp.array(state_dict["_iwa"], dtype=onp.int32),
759
+ _task=onp.asarray(state_dict["_task"], dtype=onp.int32),
760
+ _lsave=onp.array(state_dict["_lsave"], dtype=onp.int32),
761
+ _isave=onp.array(state_dict["_isave"], dtype=onp.int32),
762
762
  _dsave=onp.array(state_dict["_dsave"], dtype=onp.float64),
763
- _lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
764
- _upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
765
- _bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
763
+ _line_search_max_steps=int(state_dict["_line_search_max_steps"]),
764
+ _ln_task=onp.asarray(state_dict["_ln_task"], onp.int32),
766
765
  )
767
766
 
768
767
  @classmethod
@@ -792,7 +791,6 @@ class ScipyLbfgsbState:
792
791
  Returns:
793
792
  The `ScipyLbfgsbState`.
794
793
  """
795
- x0 = onp.asarray(x0)
796
794
  if x0.ndim > 1:
797
795
  raise ValueError(f"`x0` must be rank-1 but got shape {x0.shape}.")
798
796
  lower_bound = onp.asarray(lower_bound)
@@ -810,8 +808,6 @@ class ScipyLbfgsbState:
810
808
  lower_bound_array, upper_bound_array, bound_type = _configure_bounds(
811
809
  lower_bound, upper_bound
812
810
  )
813
- task = onp.zeros(1, "S60")
814
- task[:] = TASK_START
815
811
 
816
812
  # See initialization of internal variables in the `lbfgsb._minimize_lbfgsb`
817
813
  # function.
@@ -820,19 +816,19 @@ class ScipyLbfgsbState:
820
816
  x=onp.array(x0, onp.float64),
821
817
  converged=onp.asarray(False),
822
818
  _maxcor=maxcor,
823
- _line_search_max_steps=line_search_max_steps,
819
+ _lower_bound=lower_bound_array,
820
+ _upper_bound=upper_bound_array,
821
+ _bound_type=bound_type,
824
822
  _ftol=onp.asarray(ftol, onp.float64),
825
823
  _gtol=onp.asarray(gtol, onp.float64),
826
824
  _wa=onp.zeros(wa_size, onp.float64),
827
- _iwa=onp.zeros(3 * n, FORTRAN_INT),
828
- _task=task,
829
- _csave=onp.zeros(1, "S60"),
830
- _lsave=onp.zeros(4, FORTRAN_INT),
831
- _isave=onp.zeros(44, FORTRAN_INT),
825
+ _iwa=onp.zeros(3 * n, onp.int32),
826
+ _task=onp.zeros(2, onp.int32),
827
+ _lsave=onp.zeros(4, onp.int32),
828
+ _isave=onp.zeros(44, onp.int32),
832
829
  _dsave=onp.zeros(29, onp.float64),
833
- _lower_bound=lower_bound_array,
834
- _upper_bound=upper_bound_array,
835
- _bound_type=bound_type,
830
+ _line_search_max_steps=line_search_max_steps,
831
+ _ln_task=onp.zeros(2, onp.int32),
836
832
  )
837
833
  # The initial state requires an update with zero value and gradient. This
838
834
  # is because the initial task is "START", which does not actually require
@@ -840,11 +836,7 @@ class ScipyLbfgsbState:
840
836
  state.update(onp.zeros(x0.shape, onp.float64), onp.zeros((), onp.float64))
841
837
  return state
842
838
 
843
- def update(
844
- self,
845
- grad: NDArray,
846
- value: NDArray,
847
- ) -> None:
839
+ def update(self, grad: NDArray, value: NDArray) -> None:
848
840
  """Performs an in-place update of the `ScipyLbfgsbState` if not converged.
849
841
 
850
842
  Args:
@@ -866,29 +858,27 @@ class ScipyLbfgsbState:
866
858
  # again, advancing past such "dummy" steps.
867
859
  for _ in range(3):
868
860
  scipy_lbfgsb.setulb(
869
- m=self._maxcor,
870
- x=self.x,
871
- l=self._lower_bound,
872
- u=self._upper_bound,
873
- nbd=self._bound_type,
874
- f=value,
875
- g=grad,
876
- factr=self._ftol / onp.finfo(float).eps,
877
- pgtol=self._gtol,
878
- wa=self._wa,
879
- iwa=self._iwa,
880
- task=self._task,
881
- iprint=UPDATE_IPRINT,
882
- csave=self._csave,
883
- lsave=self._lsave,
884
- isave=self._isave,
885
- dsave=self._dsave,
886
- maxls=self._line_search_max_steps,
861
+ self._maxcor, # m
862
+ self.x, # x
863
+ self._lower_bound, # low_bnd
864
+ self._upper_bound, # upper_bnd
865
+ self._bound_type, # nbnd
866
+ value, # f
867
+ grad, # g
868
+ self._ftol / onp.finfo(float).eps, # factr
869
+ self._gtol, # pgtol
870
+ self._wa, # wa
871
+ self._iwa, # iwa
872
+ self._task, # task
873
+ self._lsave, # lsave
874
+ self._isave, # isave
875
+ self._dsave, # dsave
876
+ self._line_search_max_steps, # maxls
877
+ self._ln_task, # ln_task
887
878
  )
888
- task_str = self._task.tobytes()
889
- if task_str.startswith(TASK_CONVERGED):
879
+ if self._task[0] == 4:
890
880
  self.converged = onp.asarray(True)
891
- if task_str.startswith(TASK_FG):
881
+ if self._task[0] == 3:
892
882
  break
893
883
 
894
884
 
@@ -897,12 +887,12 @@ def _wa_size(n: int, maxcor: int) -> int:
897
887
  return 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
898
888
 
899
889
 
900
- def _validate_array_dtype(x: NDArray, dtype: Union[type, str]) -> None:
890
+ def _validate_array_dtype(name: str, x: NDArray, dtype: type) -> None:
901
891
  """Validates that `x` is an array with the specified `dtype`."""
902
892
  if not isinstance(x, onp.ndarray):
903
- raise ValueError(f"`x` must be an `onp.ndarray` but got {type(x)}")
893
+ raise ValueError(f"`{name}` must be an `onp.ndarray` but got {type(x)}")
904
894
  if x.dtype != dtype:
905
- raise ValueError(f"`x` must have dtype {dtype} but got {x.dtype}")
895
+ raise ValueError(f"`{name}` must have dtype {dtype} but got {x.dtype}")
906
896
 
907
897
 
908
898
  def _configure_bounds(
@@ -919,21 +909,5 @@ def _configure_bounds(
919
909
  return (
920
910
  onp.asarray(lower_bound_array, onp.float64),
921
911
  onp.asarray(upper_bound_array, onp.float64),
922
- onp.asarray(bound_type),
923
- )
924
-
925
-
926
- def _array_from_s60_str(s60_str: NDArray) -> NDArray:
927
- """Return a jax array for a numpy s60 string."""
928
- assert s60_str.shape == (1,)
929
- chars = [int(o) for o in s60_str[0]]
930
- chars.extend([32] * (59 - len(chars)))
931
- return onp.asarray(chars, dtype=int)
932
-
933
-
934
- def _s60_str_from_array(array: NDArray) -> NDArray:
935
- """Return a numpy s60 string for a jax array."""
936
- return onp.asarray(
937
- [b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
938
- dtype="S60",
912
+ onp.asarray(bound_type, onp.int32),
939
913
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: invrs_opt
3
- Version: 0.10.4
3
+ Version: 0.10.5
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -518,7 +518,7 @@ Requires-Dist: jaxlib
518
518
  Requires-Dist: numpy
519
519
  Requires-Dist: requests
520
520
  Requires-Dist: optax
521
- Requires-Dist: scipy<1.15.0
521
+ Requires-Dist: scipy>=1.15.0
522
522
  Requires-Dist: totypes
523
523
  Requires-Dist: types-requests
524
524
  Provides-Extra: tests
@@ -1,11 +1,11 @@
1
- invrs_opt/__init__.py,sha256=os3uHllRlWdmF4rPgnDPipz7kG6NCdNaECH7DvBlUpA,586
1
+ invrs_opt/__init__.py,sha256=b2AnOoyTVD9z6CZEjHc3_S0un9BDAXTz8vv8yWtw0AE,586
2
2
  invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  invrs_opt/experimental/client.py,sha256=tbtH13FrA65XmTZfTO71CxJ78jeAEj3Zf85R-MTwbiU,4909
5
5
  invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
6
6
  invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
7
  invrs_opt/optimizers/base.py,sha256=uFfkN2LwWzAtwh6ktwWNy2iHNOY-sW3JzI46iSFkgok,1306
8
- invrs_opt/optimizers/lbfgsb.py,sha256=eOg74ulC_OgVWPXdhZ_Lte80peJsA4DnTlrClkJNaA0,37200
8
+ invrs_opt/optimizers/lbfgsb.py,sha256=WP6ouVtLaXSwJBh7CSzWR7rnRdHZuSmOr57TKF4UxMg,36659
9
9
  invrs_opt/optimizers/wrapped_optax.py,sha256=781-8v_TlHsGaQF9Se9_iOEvtOLOr-BesTLudYarSlg,13685
10
10
  invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  invrs_opt/parameterization/base.py,sha256=jSwrEO86lGkYQG5gWsHvcIMWpZnnbdiKpn--2qaU02g,5362
@@ -13,8 +13,8 @@ invrs_opt/parameterization/filter_project.py,sha256=XL3HTEBLrF-q_75TjhOWLNdfUOSE
13
13
  invrs_opt/parameterization/gaussian_levelset.py,sha256=-6foekLTFoZDtMKuoMEvdxMJt0_zTxrKNJo0Vn-Rv80,26073
14
14
  invrs_opt/parameterization/pixel.py,sha256=YWkyBhfYtzI8cQ-M90PAZqRAbabwVaUh0UiYIGegQHI,1955
15
15
  invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
16
- invrs_opt-0.10.4.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
- invrs_opt-0.10.4.dist-info/METADATA,sha256=FbUpVxnhVW3U2bAMTO21ekJFRY9AWYJsUhXYq9YAMmE,32815
18
- invrs_opt-0.10.4.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
19
- invrs_opt-0.10.4.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
20
- invrs_opt-0.10.4.dist-info/RECORD,,
16
+ invrs_opt-0.10.5.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
+ invrs_opt-0.10.5.dist-info/METADATA,sha256=JV6RdTh63uMDr2ae3PedC4BIxz54G8twEBGd38hKlZg,32816
18
+ invrs_opt-0.10.5.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
19
+ invrs_opt-0.10.5.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
20
+ invrs_opt-0.10.5.dist-info/RECORD,,