invrs-opt 0.9.2__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invrs_opt/__init__.py +1 -1
- invrs_opt/optimizers/lbfgsb.py +26 -27
- {invrs_opt-0.9.2.dist-info → invrs_opt-0.9.3.dist-info}/METADATA +3 -3
- {invrs_opt-0.9.2.dist-info → invrs_opt-0.9.3.dist-info}/RECORD +7 -7
- {invrs_opt-0.9.2.dist-info → invrs_opt-0.9.3.dist-info}/WHEEL +1 -1
- {invrs_opt-0.9.2.dist-info → invrs_opt-0.9.3.dist-info}/LICENSE +0 -0
- {invrs_opt-0.9.2.dist-info → invrs_opt-0.9.3.dist-info}/top_level.txt +0 -0
invrs_opt/__init__.py
CHANGED
invrs_opt/optimizers/lbfgsb.py
CHANGED
@@ -3,7 +3,6 @@
|
|
3
3
|
Copyright (c) 2023 The INVRS-IO authors.
|
4
4
|
"""
|
5
5
|
|
6
|
-
import copy
|
7
6
|
import dataclasses
|
8
7
|
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
9
8
|
|
@@ -28,6 +27,7 @@ from invrs_opt.parameterization import (
|
|
28
27
|
NDArray = onp.ndarray[Any, Any]
|
29
28
|
PyTree = Any
|
30
29
|
ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
|
30
|
+
NumpyLbfgsbDict = Dict[str, NDArray]
|
31
31
|
JaxLbfgsbDict = Dict[str, jnp.ndarray]
|
32
32
|
LbfgsbState = Tuple[int, PyTree, PyTree, JaxLbfgsbDict]
|
33
33
|
|
@@ -299,7 +299,7 @@ def parameterized_lbfgsb(
|
|
299
299
|
def init_fn(params: PyTree) -> LbfgsbState:
|
300
300
|
"""Initializes the optimization state."""
|
301
301
|
|
302
|
-
def _init_state_pure(latent_params: PyTree) -> Tuple[PyTree,
|
302
|
+
def _init_state_pure(latent_params: PyTree) -> Tuple[PyTree, NumpyLbfgsbDict]:
|
303
303
|
lower_bound = types.extract_lower_bound(latent_params)
|
304
304
|
upper_bound = types.extract_upper_bound(latent_params)
|
305
305
|
scipy_lbfgsb_state = ScipyLbfgsbState.init(
|
@@ -312,7 +312,7 @@ def parameterized_lbfgsb(
|
|
312
312
|
gtol=gtol,
|
313
313
|
)
|
314
314
|
latent_params = _to_pytree(scipy_lbfgsb_state.x, latent_params)
|
315
|
-
return latent_params, scipy_lbfgsb_state.
|
315
|
+
return latent_params, scipy_lbfgsb_state.to_dict()
|
316
316
|
|
317
317
|
latent_params = _init_latents(params)
|
318
318
|
metadata, latents = param_base.partition_density_metadata(latent_params)
|
@@ -346,7 +346,7 @@ def parameterized_lbfgsb(
|
|
346
346
|
flat_latent_grad: PyTree,
|
347
347
|
value: jnp.ndarray,
|
348
348
|
jax_lbfgsb_state: JaxLbfgsbDict,
|
349
|
-
) -> Tuple[PyTree,
|
349
|
+
) -> Tuple[PyTree, NumpyLbfgsbDict]:
|
350
350
|
assert onp.size(value) == 1
|
351
351
|
scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
|
352
352
|
scipy_lbfgsb_state.update(
|
@@ -354,7 +354,7 @@ def parameterized_lbfgsb(
|
|
354
354
|
value=onp.array(value, dtype=onp.float64),
|
355
355
|
)
|
356
356
|
flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
|
357
|
-
return flat_latent_params, scipy_lbfgsb_state.
|
357
|
+
return flat_latent_params, scipy_lbfgsb_state.to_dict()
|
358
358
|
|
359
359
|
step, _, latent_params, jax_lbfgsb_state = state
|
360
360
|
metadata, latents = param_base.partition_density_metadata(latent_params)
|
@@ -696,31 +696,30 @@ class ScipyLbfgsbState:
|
|
696
696
|
_validate_array_dtype(self._upper_bound, onp.float64)
|
697
697
|
_validate_array_dtype(self._bound_type, int)
|
698
698
|
|
699
|
-
def
|
699
|
+
def to_dict(self) -> NumpyLbfgsbDict:
|
700
700
|
"""Generates a dictionary of jax arrays defining the state."""
|
701
701
|
return dict(
|
702
|
-
x=
|
703
|
-
converged=
|
704
|
-
_maxcor=
|
705
|
-
_line_search_max_steps=
|
706
|
-
_ftol=
|
707
|
-
_gtol=
|
708
|
-
_wa=
|
709
|
-
_iwa=
|
702
|
+
x=onp.asarray(self.x),
|
703
|
+
converged=onp.asarray(self.converged),
|
704
|
+
_maxcor=onp.asarray(self._maxcor),
|
705
|
+
_line_search_max_steps=onp.asarray(self._line_search_max_steps),
|
706
|
+
_ftol=onp.asarray(self._ftol),
|
707
|
+
_gtol=onp.asarray(self._gtol),
|
708
|
+
_wa=onp.asarray(self._wa),
|
709
|
+
_iwa=onp.asarray(self._iwa),
|
710
710
|
_task=_array_from_s60_str(self._task),
|
711
711
|
_csave=_array_from_s60_str(self._csave),
|
712
|
-
_lsave=
|
713
|
-
_isave=
|
714
|
-
_dsave=
|
715
|
-
_lower_bound=
|
716
|
-
_upper_bound=
|
717
|
-
_bound_type=
|
712
|
+
_lsave=onp.asarray(self._lsave),
|
713
|
+
_isave=onp.asarray(self._isave),
|
714
|
+
_dsave=onp.asarray(self._dsave),
|
715
|
+
_lower_bound=onp.asarray(self._lower_bound),
|
716
|
+
_upper_bound=onp.asarray(self._upper_bound),
|
717
|
+
_bound_type=onp.asarray(self._bound_type),
|
718
718
|
)
|
719
719
|
|
720
720
|
@classmethod
|
721
|
-
def from_jax(cls, state_dict:
|
721
|
+
def from_jax(cls, state_dict: JaxLbfgsbDict) -> "ScipyLbfgsbState":
|
722
722
|
"""Converts a dictionary of jax arrays to a `ScipyLbfgsbState`."""
|
723
|
-
state_dict = copy.deepcopy(state_dict)
|
724
723
|
return ScipyLbfgsbState(
|
725
724
|
x=onp.array(state_dict["x"], dtype=onp.float64),
|
726
725
|
converged=onp.asarray(state_dict["converged"], dtype=bool),
|
@@ -730,8 +729,8 @@ class ScipyLbfgsbState:
|
|
730
729
|
_gtol=onp.asarray(state_dict["_gtol"], dtype=onp.float64),
|
731
730
|
_wa=onp.array(state_dict["_wa"], onp.float64),
|
732
731
|
_iwa=onp.array(state_dict["_iwa"], dtype=FORTRAN_INT),
|
733
|
-
_task=_s60_str_from_array(state_dict["_task"]),
|
734
|
-
_csave=_s60_str_from_array(state_dict["_csave"]),
|
732
|
+
_task=_s60_str_from_array(onp.asarray(state_dict["_task"])),
|
733
|
+
_csave=_s60_str_from_array(onp.asarray(state_dict["_csave"])),
|
735
734
|
_lsave=onp.array(state_dict["_lsave"], dtype=FORTRAN_INT),
|
736
735
|
_isave=onp.array(state_dict["_isave"], dtype=FORTRAN_INT),
|
737
736
|
_dsave=onp.array(state_dict["_dsave"], dtype=onp.float64),
|
@@ -898,15 +897,15 @@ def _configure_bounds(
|
|
898
897
|
)
|
899
898
|
|
900
899
|
|
901
|
-
def _array_from_s60_str(s60_str: NDArray) ->
|
900
|
+
def _array_from_s60_str(s60_str: NDArray) -> NDArray:
|
902
901
|
"""Return a jax array for a numpy s60 string."""
|
903
902
|
assert s60_str.shape == (1,)
|
904
903
|
chars = [int(o) for o in s60_str[0]]
|
905
904
|
chars.extend([32] * (59 - len(chars)))
|
906
|
-
return
|
905
|
+
return onp.asarray(chars, dtype=int)
|
907
906
|
|
908
907
|
|
909
|
-
def _s60_str_from_array(array:
|
908
|
+
def _s60_str_from_array(array: NDArray) -> NDArray:
|
910
909
|
"""Return a numpy s60 string for a jax array."""
|
911
910
|
return onp.asarray(
|
912
911
|
[b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: invrs_opt
|
3
|
-
Version: 0.9.
|
3
|
+
Version: 0.9.3
|
4
4
|
Summary: Algorithms for inverse design
|
5
5
|
Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
6
6
|
Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
@@ -513,7 +513,7 @@ Keywords: topology,optimization,jax,inverse design
|
|
513
513
|
Requires-Python: >=3.7
|
514
514
|
Description-Content-Type: text/markdown
|
515
515
|
License-File: LICENSE
|
516
|
-
Requires-Dist: jax
|
516
|
+
Requires-Dist: jax<=0.4.35
|
517
517
|
Requires-Dist: jaxlib
|
518
518
|
Requires-Dist: numpy
|
519
519
|
Requires-Dist: requests
|
@@ -533,7 +533,7 @@ Requires-Dist: pytest-cov; extra == "tests"
|
|
533
533
|
Requires-Dist: pytest-subtests; extra == "tests"
|
534
534
|
|
535
535
|
# invrs-opt - Optimization algorithms for inverse design
|
536
|
-
`v0.9.
|
536
|
+
`v0.9.3`
|
537
537
|
|
538
538
|
## Overview
|
539
539
|
|
@@ -1,11 +1,11 @@
|
|
1
|
-
invrs_opt/__init__.py,sha256=
|
1
|
+
invrs_opt/__init__.py,sha256=tCzzgk3fRw4_HT8YLg6RlXIrvv19H11TGJngEGL5Iyk,585
|
2
2
|
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
invrs_opt/experimental/client.py,sha256=t4XxnditYbM9DWZeyBPj0Sa2acvkikT0ybhUdmH2r-Y,4852
|
5
5
|
invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
|
6
6
|
invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
7
7
|
invrs_opt/optimizers/base.py,sha256=-wNwH0J475y8FzB5aLAkc_1602LvYeF4Hddr9OiBkDY,1276
|
8
|
-
invrs_opt/optimizers/lbfgsb.py,sha256=
|
8
|
+
invrs_opt/optimizers/lbfgsb.py,sha256=9DrmyCj4Ny04NFVRUTz3pJypbw_j5Gw4wpKfe0WKEv4,36336
|
9
9
|
invrs_opt/optimizers/wrapped_optax.py,sha256=VXdCteT2kumqhP81l3p6QiEqwBffoUuJ3UjrAyX5ToA,13468
|
10
10
|
invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
11
|
invrs_opt/parameterization/base.py,sha256=BObzbz6efT2nBjib0_5BSdkCmFi2f0mcZ9VJYpDzO6Q,5278
|
@@ -13,8 +13,8 @@ invrs_opt/parameterization/filter_project.py,sha256=7Jb8JVENmBTdx3-XmI-VRm4aMjxg
|
|
13
13
|
invrs_opt/parameterization/gaussian_levelset.py,sha256=Uagx7k69SWmass0YirD5JN8O4QDbwwKTBBjRfkIXvv8,24793
|
14
14
|
invrs_opt/parameterization/pixel.py,sha256=AwC4GBNNOysdICvYHv_D2tZdqJmYiRzOUZNq_-R9Z70,1617
|
15
15
|
invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
|
16
|
-
invrs_opt-0.9.
|
17
|
-
invrs_opt-0.9.
|
18
|
-
invrs_opt-0.9.
|
19
|
-
invrs_opt-0.9.
|
20
|
-
invrs_opt-0.9.
|
16
|
+
invrs_opt-0.9.3.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
|
17
|
+
invrs_opt-0.9.3.dist-info/METADATA,sha256=eMtr2aqWLeT2y6dOzEyETbNWbbeTzh2iCcfGu4G2cuQ,32641
|
18
|
+
invrs_opt-0.9.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
19
|
+
invrs_opt-0.9.3.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
20
|
+
invrs_opt-0.9.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|