invrs-opt 0.9.2__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff shows the differences between the contents of two publicly available versions of a package, each of which has been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
invrs_opt/__init__.py CHANGED
@@ -3,7 +3,7 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.9.2"
6
+ __version__ = "v0.9.3"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
9
  from invrs_opt import parameterization as parameterization
@@ -3,7 +3,6 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- import copy
7
6
  import dataclasses
8
7
  from typing import Any, Dict, Optional, Sequence, Tuple, Union
9
8
 
@@ -28,6 +27,7 @@ from invrs_opt.parameterization import (
28
27
  NDArray = onp.ndarray[Any, Any]
29
28
  PyTree = Any
30
29
  ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
30
+ NumpyLbfgsbDict = Dict[str, NDArray]
31
31
  JaxLbfgsbDict = Dict[str, jnp.ndarray]
32
32
  LbfgsbState = Tuple[int, PyTree, PyTree, JaxLbfgsbDict]
33
33
 
@@ -299,7 +299,7 @@ def parameterized_lbfgsb(
299
299
  def init_fn(params: PyTree) -> LbfgsbState:
300
300
  """Initializes the optimization state."""
301
301
 
302
- def _init_state_pure(latent_params: PyTree) -> Tuple[PyTree, JaxLbfgsbDict]:
302
+ def _init_state_pure(latent_params: PyTree) -> Tuple[PyTree, NumpyLbfgsbDict]:
303
303
  lower_bound = types.extract_lower_bound(latent_params)
304
304
  upper_bound = types.extract_upper_bound(latent_params)
305
305
  scipy_lbfgsb_state = ScipyLbfgsbState.init(
@@ -312,7 +312,7 @@ def parameterized_lbfgsb(
312
312
  gtol=gtol,
313
313
  )
314
314
  latent_params = _to_pytree(scipy_lbfgsb_state.x, latent_params)
315
- return latent_params, scipy_lbfgsb_state.to_jax()
315
+ return latent_params, scipy_lbfgsb_state.to_dict()
316
316
 
317
317
  latent_params = _init_latents(params)
318
318
  metadata, latents = param_base.partition_density_metadata(latent_params)
@@ -346,7 +346,7 @@ def parameterized_lbfgsb(
346
346
  flat_latent_grad: PyTree,
347
347
  value: jnp.ndarray,
348
348
  jax_lbfgsb_state: JaxLbfgsbDict,
349
- ) -> Tuple[PyTree, JaxLbfgsbDict]:
349
+ ) -> Tuple[PyTree, NumpyLbfgsbDict]:
350
350
  assert onp.size(value) == 1
351
351
  scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
352
352
  scipy_lbfgsb_state.update(
@@ -354,7 +354,7 @@ def parameterized_lbfgsb(
354
354
  value=onp.array(value, dtype=onp.float64),
355
355
  )
356
356
  flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
357
- return flat_latent_params, scipy_lbfgsb_state.to_jax()
357
+ return flat_latent_params, scipy_lbfgsb_state.to_dict()
358
358
 
359
359
  step, _, latent_params, jax_lbfgsb_state = state
360
360
  metadata, latents = param_base.partition_density_metadata(latent_params)
@@ -696,31 +696,30 @@ class ScipyLbfgsbState:
696
696
  _validate_array_dtype(self._upper_bound, onp.float64)
697
697
  _validate_array_dtype(self._bound_type, int)
698
698
 
699
- def to_jax(self) -> Dict[str, jnp.ndarray]:
699
+ def to_dict(self) -> NumpyLbfgsbDict:
700
700
  """Generates a dictionary of jax arrays defining the state."""
701
701
  return dict(
702
- x=jnp.asarray(self.x),
703
- converged=jnp.asarray(self.converged),
704
- _maxcor=jnp.asarray(self._maxcor),
705
- _line_search_max_steps=jnp.asarray(self._line_search_max_steps),
706
- _ftol=jnp.asarray(self._ftol),
707
- _gtol=jnp.asarray(self._gtol),
708
- _wa=jnp.asarray(self._wa),
709
- _iwa=jnp.asarray(self._iwa),
702
+ x=onp.asarray(self.x),
703
+ converged=onp.asarray(self.converged),
704
+ _maxcor=onp.asarray(self._maxcor),
705
+ _line_search_max_steps=onp.asarray(self._line_search_max_steps),
706
+ _ftol=onp.asarray(self._ftol),
707
+ _gtol=onp.asarray(self._gtol),
708
+ _wa=onp.asarray(self._wa),
709
+ _iwa=onp.asarray(self._iwa),
710
710
  _task=_array_from_s60_str(self._task),
711
711
  _csave=_array_from_s60_str(self._csave),
712
- _lsave=jnp.asarray(self._lsave),
713
- _isave=jnp.asarray(self._isave),
714
- _dsave=jnp.asarray(self._dsave),
715
- _lower_bound=jnp.asarray(self._lower_bound),
716
- _upper_bound=jnp.asarray(self._upper_bound),
717
- _bound_type=jnp.asarray(self._bound_type),
712
+ _lsave=onp.asarray(self._lsave),
713
+ _isave=onp.asarray(self._isave),
714
+ _dsave=onp.asarray(self._dsave),
715
+ _lower_bound=onp.asarray(self._lower_bound),
716
+ _upper_bound=onp.asarray(self._upper_bound),
717
+ _bound_type=onp.asarray(self._bound_type),
718
718
  )
719
719
 
720
720
  @classmethod
721
- def from_jax(cls, state_dict: Dict[str, jnp.ndarray]) -> "ScipyLbfgsbState":
721
+ def from_jax(cls, state_dict: JaxLbfgsbDict) -> "ScipyLbfgsbState":
722
722
  """Converts a dictionary of jax arrays to a `ScipyLbfgsbState`."""
723
- state_dict = copy.deepcopy(state_dict)
724
723
  return ScipyLbfgsbState(
725
724
  x=onp.array(state_dict["x"], dtype=onp.float64),
726
725
  converged=onp.asarray(state_dict["converged"], dtype=bool),
@@ -730,8 +729,8 @@ class ScipyLbfgsbState:
730
729
  _gtol=onp.asarray(state_dict["_gtol"], dtype=onp.float64),
731
730
  _wa=onp.array(state_dict["_wa"], onp.float64),
732
731
  _iwa=onp.array(state_dict["_iwa"], dtype=FORTRAN_INT),
733
- _task=_s60_str_from_array(state_dict["_task"]),
734
- _csave=_s60_str_from_array(state_dict["_csave"]),
732
+ _task=_s60_str_from_array(onp.asarray(state_dict["_task"])),
733
+ _csave=_s60_str_from_array(onp.asarray(state_dict["_csave"])),
735
734
  _lsave=onp.array(state_dict["_lsave"], dtype=FORTRAN_INT),
736
735
  _isave=onp.array(state_dict["_isave"], dtype=FORTRAN_INT),
737
736
  _dsave=onp.array(state_dict["_dsave"], dtype=onp.float64),
@@ -898,15 +897,15 @@ def _configure_bounds(
898
897
  )
899
898
 
900
899
 
901
- def _array_from_s60_str(s60_str: NDArray) -> jnp.ndarray:
900
+ def _array_from_s60_str(s60_str: NDArray) -> NDArray:
902
901
  """Return a jax array for a numpy s60 string."""
903
902
  assert s60_str.shape == (1,)
904
903
  chars = [int(o) for o in s60_str[0]]
905
904
  chars.extend([32] * (59 - len(chars)))
906
- return jnp.asarray(chars, dtype=int)
905
+ return onp.asarray(chars, dtype=int)
907
906
 
908
907
 
909
- def _s60_str_from_array(array: jnp.ndarray) -> NDArray:
908
+ def _s60_str_from_array(array: NDArray) -> NDArray:
910
909
  """Return a numpy s60 string for a jax array."""
911
910
  return onp.asarray(
912
911
  [b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: invrs_opt
3
- Version: 0.9.2
3
+ Version: 0.9.3
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -513,7 +513,7 @@ Keywords: topology,optimization,jax,inverse design
513
513
  Requires-Python: >=3.7
514
514
  Description-Content-Type: text/markdown
515
515
  License-File: LICENSE
516
- Requires-Dist: jax
516
+ Requires-Dist: jax<=0.4.35
517
517
  Requires-Dist: jaxlib
518
518
  Requires-Dist: numpy
519
519
  Requires-Dist: requests
@@ -533,7 +533,7 @@ Requires-Dist: pytest-cov; extra == "tests"
533
533
  Requires-Dist: pytest-subtests; extra == "tests"
534
534
 
535
535
  # invrs-opt - Optimization algorithms for inverse design
536
- `v0.9.2`
536
+ `v0.9.3`
537
537
 
538
538
  ## Overview
539
539
 
@@ -1,11 +1,11 @@
1
- invrs_opt/__init__.py,sha256=KoyFNAkv9psGOf-zbhFORZI4io0O1AmGuMyCitqeLVg,585
1
+ invrs_opt/__init__.py,sha256=tCzzgk3fRw4_HT8YLg6RlXIrvv19H11TGJngEGL5Iyk,585
2
2
  invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  invrs_opt/experimental/client.py,sha256=t4XxnditYbM9DWZeyBPj0Sa2acvkikT0ybhUdmH2r-Y,4852
5
5
  invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
6
6
  invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
7
  invrs_opt/optimizers/base.py,sha256=-wNwH0J475y8FzB5aLAkc_1602LvYeF4Hddr9OiBkDY,1276
8
- invrs_opt/optimizers/lbfgsb.py,sha256=8BPiEAqececL-zLnqrgN0CogGDkAd1tyAGndUB-kahc,36349
8
+ invrs_opt/optimizers/lbfgsb.py,sha256=9DrmyCj4Ny04NFVRUTz3pJypbw_j5Gw4wpKfe0WKEv4,36336
9
9
  invrs_opt/optimizers/wrapped_optax.py,sha256=VXdCteT2kumqhP81l3p6QiEqwBffoUuJ3UjrAyX5ToA,13468
10
10
  invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  invrs_opt/parameterization/base.py,sha256=BObzbz6efT2nBjib0_5BSdkCmFi2f0mcZ9VJYpDzO6Q,5278
@@ -13,8 +13,8 @@ invrs_opt/parameterization/filter_project.py,sha256=7Jb8JVENmBTdx3-XmI-VRm4aMjxg
13
13
  invrs_opt/parameterization/gaussian_levelset.py,sha256=Uagx7k69SWmass0YirD5JN8O4QDbwwKTBBjRfkIXvv8,24793
14
14
  invrs_opt/parameterization/pixel.py,sha256=AwC4GBNNOysdICvYHv_D2tZdqJmYiRzOUZNq_-R9Z70,1617
15
15
  invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
16
- invrs_opt-0.9.2.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
- invrs_opt-0.9.2.dist-info/METADATA,sha256=vpWyCF9sTpxzWBqce-AESwaNmeEWBimlWPFHSIniJFk,32633
18
- invrs_opt-0.9.2.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
19
- invrs_opt-0.9.2.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
20
- invrs_opt-0.9.2.dist-info/RECORD,,
16
+ invrs_opt-0.9.3.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
+ invrs_opt-0.9.3.dist-info/METADATA,sha256=eMtr2aqWLeT2y6dOzEyETbNWbbeTzh2iCcfGu4G2cuQ,32641
18
+ invrs_opt-0.9.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
19
+ invrs_opt-0.9.3.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
20
+ invrs_opt-0.9.3.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (74.1.2)
2
+ Generator: setuptools (75.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5