invrs-opt 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invrs_opt/__init__.py +2 -2
- invrs_opt/lbfgsb/lbfgsb.py +127 -33
- invrs_opt/lbfgsb/transform.py +0 -1
- {invrs_opt-0.1.4.dist-info → invrs_opt-0.2.0.dist-info}/METADATA +2 -2
- invrs_opt-0.2.0.dist-info/RECORD +11 -0
- invrs_opt-0.1.4.dist-info/RECORD +0 -11
- {invrs_opt-0.1.4.dist-info → invrs_opt-0.2.0.dist-info}/LICENSE +0 -0
- {invrs_opt-0.1.4.dist-info → invrs_opt-0.2.0.dist-info}/WHEEL +0 -0
- {invrs_opt-0.1.4.dist-info → invrs_opt-0.2.0.dist-info}/top_level.txt +0 -0
invrs_opt/__init__.py
CHANGED
@@ -3,8 +3,8 @@
|
|
3
3
|
Copyright (c) 2023 The INVRS-IO authors.
|
4
4
|
"""
|
5
5
|
|
6
|
-
__version__ = "v0.1.4"
|
6
|
+
__version__ = "v0.2.0"
|
7
7
|
__author__ = "Martin F. Schubert <mfschubert@gmail.com>"
|
8
8
|
|
9
|
-
from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
|
10
9
|
from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
|
10
|
+
from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
|
invrs_opt/lbfgsb/lbfgsb.py
CHANGED
@@ -10,20 +10,19 @@ from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
|
|
10
10
|
import jax
|
11
11
|
import jax.numpy as jnp
|
12
12
|
import numpy as onp
|
13
|
-
from jax import flatten_util
|
14
|
-
from jax import tree_util
|
13
|
+
from jax import flatten_util, tree_util
|
15
14
|
from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
|
16
15
|
_lbfgsb as scipy_lbfgsb,
|
17
16
|
)
|
17
|
+
from totypes import types
|
18
18
|
|
19
|
-
from invrs_opt.lbfgsb import transform
|
20
19
|
from invrs_opt import base
|
21
|
-
from totypes import types
|
20
|
+
from invrs_opt.lbfgsb import transform
|
22
21
|
|
23
22
|
NDArray = onp.ndarray[Any, Any]
|
24
23
|
PyTree = Any
|
25
24
|
ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
|
26
|
-
LbfgsbState = Tuple[PyTree, Dict[str,
|
25
|
+
LbfgsbState = Tuple[PyTree, Dict[str, jnp.ndarray]]
|
27
26
|
|
28
27
|
|
29
28
|
# Task message prefixes for the underlying L-BFGS-B implementation.
|
@@ -188,18 +187,24 @@ def transformed_lbfgsb(
|
|
188
187
|
|
189
188
|
def init_fn(params: PyTree) -> LbfgsbState:
|
190
189
|
"""Initializes the optimization state."""
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
190
|
+
|
191
|
+
def _init_pure(params: PyTree) -> LbfgsbState:
|
192
|
+
lower_bound = types.extract_lower_bound(params)
|
193
|
+
upper_bound = types.extract_upper_bound(params)
|
194
|
+
scipy_lbfgsb_state = ScipyLbfgsbState.init(
|
195
|
+
x0=_to_numpy(params),
|
196
|
+
lower_bound=_bound_for_params(lower_bound, params),
|
197
|
+
upper_bound=_bound_for_params(upper_bound, params),
|
198
|
+
maxcor=maxcor,
|
199
|
+
line_search_max_steps=line_search_max_steps,
|
200
|
+
)
|
201
|
+
latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
|
202
|
+
params = transform_fn(latent_params)
|
203
|
+
return params, scipy_lbfgsb_state.to_jax()
|
204
|
+
|
205
|
+
return jax.pure_callback( # type: ignore[no-any-return, attr-defined]
|
206
|
+
_init_pure, _example_state(params, maxcor), params
|
199
207
|
)
|
200
|
-
latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
|
201
|
-
params = transform_fn(latent_params)
|
202
|
-
return (params, dataclasses.asdict(scipy_lbfgsb_state))
|
203
208
|
|
204
209
|
def params_fn(state: LbfgsbState) -> PyTree:
|
205
210
|
"""Returns the parameters for the given `state`."""
|
@@ -214,23 +219,30 @@ def transformed_lbfgsb(
|
|
214
219
|
state: LbfgsbState,
|
215
220
|
) -> LbfgsbState:
|
216
221
|
"""Updates the state."""
|
217
|
-
del params
|
218
|
-
params, lbfgsb_state_dict = state
|
219
|
-
# Avoid in-place updates.
|
220
|
-
lbfgsb_state_dict = copy.deepcopy(lbfgsb_state_dict)
|
221
|
-
scipy_lbfgsb_state = ScipyLbfgsbState(
|
222
|
-
**lbfgsb_state_dict # type: ignore[arg-type]
|
223
|
-
)
|
224
222
|
|
225
|
-
|
226
|
-
|
227
|
-
|
223
|
+
def _update_pure(
|
224
|
+
grad: PyTree, value: float, params: PyTree, state: LbfgsbState
|
225
|
+
) -> LbfgsbState:
|
226
|
+
del params
|
227
|
+
|
228
|
+
params, jax_lbfgsb_state = state
|
229
|
+
scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
|
228
230
|
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
231
|
+
latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
|
232
|
+
_, vjp_fn = jax.vjp(transform_fn, latent_params)
|
233
|
+
(latent_grad,) = vjp_fn(grad)
|
234
|
+
|
235
|
+
assert onp.size(value) == 1
|
236
|
+
scipy_lbfgsb_state.update(
|
237
|
+
grad=_to_numpy(latent_grad), value=onp.asarray(value)
|
238
|
+
)
|
239
|
+
latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
|
240
|
+
params = transform_fn(latent_params)
|
241
|
+
return params, scipy_lbfgsb_state.to_jax()
|
242
|
+
|
243
|
+
return jax.pure_callback( # type: ignore[no-any-return, attr-defined]
|
244
|
+
_update_pure, state, grad, value, params, state
|
245
|
+
)
|
234
246
|
|
235
247
|
return base.Optimizer(
|
236
248
|
init=init_fn,
|
@@ -331,6 +343,29 @@ def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
|
|
331
343
|
return bound_flat
|
332
344
|
|
333
345
|
|
346
|
+
def _example_state(params: PyTree, maxcor: int) -> PyTree:
|
347
|
+
"""Return an example state for the given `params` and `maxcor`."""
|
348
|
+
params_flat, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
|
349
|
+
n = params_flat.size
|
350
|
+
float_params = tree_util.tree_map(lambda x: jnp.asarray(x, dtype=float), params)
|
351
|
+
example_jax_lbfgsb_state = dict(
|
352
|
+
x=jnp.zeros(n, dtype=float),
|
353
|
+
_maxcor=jnp.zeros((), dtype=int),
|
354
|
+
_line_search_max_steps=jnp.zeros((), dtype=int),
|
355
|
+
_wa=jnp.ones(_wa_size(n=n, maxcor=maxcor), dtype=float),
|
356
|
+
_iwa=jnp.ones(n * 3, dtype=jnp.int32), # Fortran int
|
357
|
+
_task=jnp.zeros(59, dtype=int),
|
358
|
+
_csave=jnp.zeros(59, dtype=int),
|
359
|
+
_lsave=jnp.zeros(4, dtype=jnp.int32), # Fortran int
|
360
|
+
_isave=jnp.zeros(44, dtype=jnp.int32), # Fortran int
|
361
|
+
_dsave=jnp.zeros(29, dtype=float),
|
362
|
+
_lower_bound=jnp.zeros(n, dtype=float),
|
363
|
+
_upper_bound=jnp.zeros(n, dtype=float),
|
364
|
+
_bound_type=jnp.zeros(n, dtype=int),
|
365
|
+
)
|
366
|
+
return float_params, example_jax_lbfgsb_state
|
367
|
+
|
368
|
+
|
334
369
|
# ------------------------------------------------------------------------------
|
335
370
|
# Wrapper for scipy's L-BFGS-B implementation.
|
336
371
|
# ------------------------------------------------------------------------------
|
@@ -395,6 +430,44 @@ class ScipyLbfgsbState:
|
|
395
430
|
_validate_array_dtype(self._upper_bound, onp.float64)
|
396
431
|
_validate_array_dtype(self._bound_type, int)
|
397
432
|
|
433
|
+
def to_jax(self) -> Dict[str, jnp.ndarray]:
|
434
|
+
"""Generates a dictionary of jax arrays defining the state."""
|
435
|
+
return dict(
|
436
|
+
x=jnp.asarray(self.x),
|
437
|
+
_maxcor=jnp.asarray(self._maxcor),
|
438
|
+
_line_search_max_steps=jnp.asarray(self._line_search_max_steps),
|
439
|
+
_wa=jnp.asarray(self._wa),
|
440
|
+
_iwa=jnp.asarray(self._iwa),
|
441
|
+
_task=_array_from_s60_str(self._task),
|
442
|
+
_csave=_array_from_s60_str(self._csave),
|
443
|
+
_lsave=jnp.asarray(self._lsave),
|
444
|
+
_isave=jnp.asarray(self._isave),
|
445
|
+
_dsave=jnp.asarray(self._dsave),
|
446
|
+
_lower_bound=jnp.asarray(self._lower_bound),
|
447
|
+
_upper_bound=jnp.asarray(self._upper_bound),
|
448
|
+
_bound_type=jnp.asarray(self._bound_type),
|
449
|
+
)
|
450
|
+
|
451
|
+
@classmethod
|
452
|
+
def from_jax(cls, state_dict: Dict[str, jnp.ndarray]) -> "ScipyLbfgsbState":
|
453
|
+
"""Converts a dictionary of jax arrays to a `ScipyLbfgsbState`."""
|
454
|
+
state_dict = copy.deepcopy(state_dict)
|
455
|
+
return ScipyLbfgsbState(
|
456
|
+
x=onp.asarray(state_dict["x"], dtype=onp.float64),
|
457
|
+
_maxcor=int(state_dict["_maxcor"]),
|
458
|
+
_line_search_max_steps=int(state_dict["_line_search_max_steps"]),
|
459
|
+
_wa=onp.asarray(state_dict["_wa"], onp.float64),
|
460
|
+
_iwa=onp.asarray(state_dict["_iwa"], dtype=FORTRAN_INT),
|
461
|
+
_task=_s60_str_from_array(state_dict["_task"]),
|
462
|
+
_csave=_s60_str_from_array(state_dict["_csave"]),
|
463
|
+
_lsave=onp.asarray(state_dict["_lsave"], dtype=FORTRAN_INT),
|
464
|
+
_isave=onp.asarray(state_dict["_isave"], dtype=FORTRAN_INT),
|
465
|
+
_dsave=onp.asarray(state_dict["_dsave"], dtype=onp.float64),
|
466
|
+
_lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
|
467
|
+
_upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
|
468
|
+
_bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
|
469
|
+
)
|
470
|
+
|
398
471
|
@classmethod
|
399
472
|
def init(
|
400
473
|
cls,
|
@@ -440,12 +513,12 @@ class ScipyLbfgsbState:
|
|
440
513
|
|
441
514
|
# See initialization of internal variables in the `lbfgsb._minimize_lbfgsb`
|
442
515
|
# function.
|
443
|
-
|
516
|
+
wa_size = _wa_size(n=n, maxcor=maxcor)
|
444
517
|
state = ScipyLbfgsbState(
|
445
518
|
x=onp.array(x0, onp.float64),
|
446
519
|
_maxcor=maxcor,
|
447
520
|
_line_search_max_steps=line_search_max_steps,
|
448
|
-
_wa=onp.zeros(
|
521
|
+
_wa=onp.zeros(wa_size, onp.float64),
|
449
522
|
_iwa=onp.zeros(3 * n, FORTRAN_INT),
|
450
523
|
_task=task,
|
451
524
|
_csave=onp.zeros(1, "S60"),
|
@@ -510,6 +583,11 @@ class ScipyLbfgsbState:
|
|
510
583
|
break
|
511
584
|
|
512
585
|
|
586
|
+
def _wa_size(n: int, maxcor: int) -> int:
|
587
|
+
"""Return the size of the `wa` attribute of lbfgsb state."""
|
588
|
+
return 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
|
589
|
+
|
590
|
+
|
513
591
|
def _validate_array_dtype(x: NDArray, dtype: Union[type, str]) -> None:
|
514
592
|
"""Validates that `x` is an array with the specified `dtype`."""
|
515
593
|
if not isinstance(x, onp.ndarray):
|
@@ -534,3 +612,19 @@ def _configure_bounds(
|
|
534
612
|
onp.asarray(upper_bound_array, onp.float64),
|
535
613
|
onp.asarray(bound_type),
|
536
614
|
)
|
615
|
+
|
616
|
+
|
617
|
+
def _array_from_s60_str(s60_str: NDArray) -> jnp.ndarray:
|
618
|
+
"""Return a jax array for a numpy s60 string."""
|
619
|
+
assert s60_str.shape == (1,)
|
620
|
+
chars = [int(o) for o in s60_str[0]]
|
621
|
+
chars.extend([32] * (59 - len(chars)))
|
622
|
+
return jnp.asarray(chars, dtype=int)
|
623
|
+
|
624
|
+
|
625
|
+
def _s60_str_from_array(array: jnp.ndarray) -> NDArray:
|
626
|
+
"""Return a numpy s60 string for a jax array."""
|
627
|
+
return onp.asarray(
|
628
|
+
[b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
|
629
|
+
dtype="S60",
|
630
|
+
)
|
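invrs_opt/lbfgsb/transform.py
CHANGED
(hunk not shown; the file summary lists +0 -1 for this file)
{invrs_opt-0.1.4.dist-info → invrs_opt-0.2.0.dist-info}/METADATA
CHANGED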
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: invrs_opt
|
3
|
-
Version: 0.1.4
|
3
|
+
Version: 0.2.0
|
4
4
|
Summary: Algorithms for inverse design
|
5
5
|
Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
6
6
|
Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
@@ -47,7 +47,7 @@ Requires-Dist: pytest-cov ; extra == 'tests'
|
|
47
47
|
Requires-Dist: pytest-subtests ; extra == 'tests'
|
48
48
|
|
49
49
|
# invrs-opt - Optimization algorithms for inverse design
|
50
|
-
`v0.1.4`
|
50
|
+
`v0.2.0`
|
51
51
|
|
52
52
|
## Overview
|
53
53
|
|
invrs_opt-0.2.0.dist-info/RECORD
ADDED
@@ -0,0 +1,11 @@
|
|
1
|
+
invrs_opt/__init__.py,sha256=ZRbgvTG7O9G3ZUHOofIAqDuyNXZoD924wUWForaFZM4,309
|
2
|
+
invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
|
3
|
+
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
+
invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
+
invrs_opt/lbfgsb/lbfgsb.py,sha256=POK422WPlfVZw63y5i8wkrbh18YzojW2L5nM-eue7PI,23528
|
6
|
+
invrs_opt/lbfgsb/transform.py,sha256=gYRBHUfhdkzSeBLfncdroWs3PP_o5T2X0GfDhTc82Rs,5926
|
7
|
+
invrs_opt-0.2.0.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
|
8
|
+
invrs_opt-0.2.0.dist-info/METADATA,sha256=vmoIcnI1tL3zvvGwpJBcPula-rIZ3HVbs7bzyvfpZls,3272
|
9
|
+
invrs_opt-0.2.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
10
|
+
invrs_opt-0.2.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
11
|
+
invrs_opt-0.2.0.dist-info/RECORD,,
|
invrs_opt-0.1.4.dist-info/RECORD
DELETED
@@ -1,11 +0,0 @@
|
|
1
|
-
invrs_opt/__init__.py,sha256=ByNDfFD5Jc1XezSskVZH4ERm3pBqhBkcxS5V8qtKlHU,309
|
2
|
-
invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
|
3
|
-
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
-
invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
-
invrs_opt/lbfgsb/lbfgsb.py,sha256=tlVU6wxaSCr2FSYfEe3RUaTFT9EFyyk4i71l508vNsI,19393
|
6
|
-
invrs_opt/lbfgsb/transform.py,sha256=TjFSeWGqlJv8uY4jtgaZ38Z5hplSX5WSQfQzN8rMV5U,5927
|
7
|
-
invrs_opt-0.1.4.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
|
8
|
-
invrs_opt-0.1.4.dist-info/METADATA,sha256=iVhKz15sEHlsssEtbCmvkOs3ij5h4XGz5I9gNyUY5D4,3272
|
9
|
-
invrs_opt-0.1.4.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
10
|
-
invrs_opt-0.1.4.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
11
|
-
invrs_opt-0.1.4.dist-info/RECORD,,
|
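{invrs_opt-0.1.4.dist-info → invrs_opt-0.2.0.dist-info}/LICENSE
File without changes
|
{invrs_opt-0.1.4.dist-info → invrs_opt-0.2.0.dist-info}/WHEEL
File without changes
|
{invrs_opt-0.1.4.dist-info → invrs_opt-0.2.0.dist-info}/top_level.txt
File without changes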
|