invrs-opt 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- invrs_opt/__init__.py +2 -2
- invrs_opt/lbfgsb/lbfgsb.py +127 -33
- invrs_opt/lbfgsb/transform.py +0 -1
- {invrs_opt-0.1.4.dist-info → invrs_opt-0.2.0.dist-info}/METADATA +2 -2
- invrs_opt-0.2.0.dist-info/RECORD +11 -0
- invrs_opt-0.1.4.dist-info/RECORD +0 -11
- {invrs_opt-0.1.4.dist-info → invrs_opt-0.2.0.dist-info}/LICENSE +0 -0
- {invrs_opt-0.1.4.dist-info → invrs_opt-0.2.0.dist-info}/WHEEL +0 -0
- {invrs_opt-0.1.4.dist-info → invrs_opt-0.2.0.dist-info}/top_level.txt +0 -0
invrs_opt/__init__.py
CHANGED
@@ -3,8 +3,8 @@
|
|
3
3
|
Copyright (c) 2023 The INVRS-IO authors.
|
4
4
|
"""
|
5
5
|
|
6
|
-
__version__ = "v0.1.4"
|
6
|
+
__version__ = "v0.2.0"
|
7
7
|
__author__ = "Martin F. Schubert <mfschubert@gmail.com>"
|
8
8
|
|
9
|
-
from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
|
10
9
|
from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
|
10
|
+
from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
|
invrs_opt/lbfgsb/lbfgsb.py
CHANGED
@@ -10,20 +10,19 @@ from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
|
|
10
10
|
import jax
|
11
11
|
import jax.numpy as jnp
|
12
12
|
import numpy as onp
|
13
|
-
from jax import flatten_util
|
14
|
-
from jax import tree_util
|
13
|
+
from jax import flatten_util, tree_util
|
15
14
|
from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
|
16
15
|
_lbfgsb as scipy_lbfgsb,
|
17
16
|
)
|
17
|
+
from totypes import types
|
18
18
|
|
19
|
-
from invrs_opt.lbfgsb import transform
|
20
19
|
from invrs_opt import base
|
21
|
-
from
|
20
|
+
from invrs_opt.lbfgsb import transform
|
22
21
|
|
23
22
|
NDArray = onp.ndarray[Any, Any]
|
24
23
|
PyTree = Any
|
25
24
|
ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
|
26
|
-
LbfgsbState = Tuple[PyTree, Dict[str,
|
25
|
+
LbfgsbState = Tuple[PyTree, Dict[str, jnp.ndarray]]
|
27
26
|
|
28
27
|
|
29
28
|
# Task message prefixes for the underlying L-BFGS-B implementation.
|
@@ -188,18 +187,24 @@ def transformed_lbfgsb(
|
|
188
187
|
|
189
188
|
def init_fn(params: PyTree) -> LbfgsbState:
|
190
189
|
"""Initializes the optimization state."""
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
190
|
+
|
191
|
+
def _init_pure(params: PyTree) -> LbfgsbState:
|
192
|
+
lower_bound = types.extract_lower_bound(params)
|
193
|
+
upper_bound = types.extract_upper_bound(params)
|
194
|
+
scipy_lbfgsb_state = ScipyLbfgsbState.init(
|
195
|
+
x0=_to_numpy(params),
|
196
|
+
lower_bound=_bound_for_params(lower_bound, params),
|
197
|
+
upper_bound=_bound_for_params(upper_bound, params),
|
198
|
+
maxcor=maxcor,
|
199
|
+
line_search_max_steps=line_search_max_steps,
|
200
|
+
)
|
201
|
+
latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
|
202
|
+
params = transform_fn(latent_params)
|
203
|
+
return params, scipy_lbfgsb_state.to_jax()
|
204
|
+
|
205
|
+
return jax.pure_callback( # type: ignore[no-any-return, attr-defined]
|
206
|
+
_init_pure, _example_state(params, maxcor), params
|
199
207
|
)
|
200
|
-
latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
|
201
|
-
params = transform_fn(latent_params)
|
202
|
-
return (params, dataclasses.asdict(scipy_lbfgsb_state))
|
203
208
|
|
204
209
|
def params_fn(state: LbfgsbState) -> PyTree:
|
205
210
|
"""Returns the parameters for the given `state`."""
|
@@ -214,23 +219,30 @@ def transformed_lbfgsb(
|
|
214
219
|
state: LbfgsbState,
|
215
220
|
) -> LbfgsbState:
|
216
221
|
"""Updates the state."""
|
217
|
-
del params
|
218
|
-
params, lbfgsb_state_dict = state
|
219
|
-
# Avoid in-place updates.
|
220
|
-
lbfgsb_state_dict = copy.deepcopy(lbfgsb_state_dict)
|
221
|
-
scipy_lbfgsb_state = ScipyLbfgsbState(
|
222
|
-
**lbfgsb_state_dict # type: ignore[arg-type]
|
223
|
-
)
|
224
222
|
|
225
|
-
|
226
|
-
|
227
|
-
|
223
|
+
def _update_pure(
|
224
|
+
grad: PyTree, value: float, params: PyTree, state: LbfgsbState
|
225
|
+
) -> LbfgsbState:
|
226
|
+
del params
|
227
|
+
|
228
|
+
params, jax_lbfgsb_state = state
|
229
|
+
scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
|
228
230
|
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
231
|
+
latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
|
232
|
+
_, vjp_fn = jax.vjp(transform_fn, latent_params)
|
233
|
+
(latent_grad,) = vjp_fn(grad)
|
234
|
+
|
235
|
+
assert onp.size(value) == 1
|
236
|
+
scipy_lbfgsb_state.update(
|
237
|
+
grad=_to_numpy(latent_grad), value=onp.asarray(value)
|
238
|
+
)
|
239
|
+
latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
|
240
|
+
params = transform_fn(latent_params)
|
241
|
+
return params, scipy_lbfgsb_state.to_jax()
|
242
|
+
|
243
|
+
return jax.pure_callback( # type: ignore[no-any-return, attr-defined]
|
244
|
+
_update_pure, state, grad, value, params, state
|
245
|
+
)
|
234
246
|
|
235
247
|
return base.Optimizer(
|
236
248
|
init=init_fn,
|
@@ -331,6 +343,29 @@ def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
|
|
331
343
|
return bound_flat
|
332
344
|
|
333
345
|
|
346
|
+
def _example_state(params: PyTree, maxcor: int) -> PyTree:
|
347
|
+
"""Return an example state for the given `params` and `maxcor`."""
|
348
|
+
params_flat, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
|
349
|
+
n = params_flat.size
|
350
|
+
float_params = tree_util.tree_map(lambda x: jnp.asarray(x, dtype=float), params)
|
351
|
+
example_jax_lbfgsb_state = dict(
|
352
|
+
x=jnp.zeros(n, dtype=float),
|
353
|
+
_maxcor=jnp.zeros((), dtype=int),
|
354
|
+
_line_search_max_steps=jnp.zeros((), dtype=int),
|
355
|
+
_wa=jnp.ones(_wa_size(n=n, maxcor=maxcor), dtype=float),
|
356
|
+
_iwa=jnp.ones(n * 3, dtype=jnp.int32), # Fortran int
|
357
|
+
_task=jnp.zeros(59, dtype=int),
|
358
|
+
_csave=jnp.zeros(59, dtype=int),
|
359
|
+
_lsave=jnp.zeros(4, dtype=jnp.int32), # Fortran int
|
360
|
+
_isave=jnp.zeros(44, dtype=jnp.int32), # Fortran int
|
361
|
+
_dsave=jnp.zeros(29, dtype=float),
|
362
|
+
_lower_bound=jnp.zeros(n, dtype=float),
|
363
|
+
_upper_bound=jnp.zeros(n, dtype=float),
|
364
|
+
_bound_type=jnp.zeros(n, dtype=int),
|
365
|
+
)
|
366
|
+
return float_params, example_jax_lbfgsb_state
|
367
|
+
|
368
|
+
|
334
369
|
# ------------------------------------------------------------------------------
|
335
370
|
# Wrapper for scipy's L-BFGS-B implementation.
|
336
371
|
# ------------------------------------------------------------------------------
|
@@ -395,6 +430,44 @@ class ScipyLbfgsbState:
|
|
395
430
|
_validate_array_dtype(self._upper_bound, onp.float64)
|
396
431
|
_validate_array_dtype(self._bound_type, int)
|
397
432
|
|
433
|
+
def to_jax(self) -> Dict[str, jnp.ndarray]:
|
434
|
+
"""Generates a dictionary of jax arrays defining the state."""
|
435
|
+
return dict(
|
436
|
+
x=jnp.asarray(self.x),
|
437
|
+
_maxcor=jnp.asarray(self._maxcor),
|
438
|
+
_line_search_max_steps=jnp.asarray(self._line_search_max_steps),
|
439
|
+
_wa=jnp.asarray(self._wa),
|
440
|
+
_iwa=jnp.asarray(self._iwa),
|
441
|
+
_task=_array_from_s60_str(self._task),
|
442
|
+
_csave=_array_from_s60_str(self._csave),
|
443
|
+
_lsave=jnp.asarray(self._lsave),
|
444
|
+
_isave=jnp.asarray(self._isave),
|
445
|
+
_dsave=jnp.asarray(self._dsave),
|
446
|
+
_lower_bound=jnp.asarray(self._lower_bound),
|
447
|
+
_upper_bound=jnp.asarray(self._upper_bound),
|
448
|
+
_bound_type=jnp.asarray(self._bound_type),
|
449
|
+
)
|
450
|
+
|
451
|
+
@classmethod
|
452
|
+
def from_jax(cls, state_dict: Dict[str, jnp.ndarray]) -> "ScipyLbfgsbState":
|
453
|
+
"""Converts a dictionary of jax arrays to a `ScipyLbfgsbState`."""
|
454
|
+
state_dict = copy.deepcopy(state_dict)
|
455
|
+
return ScipyLbfgsbState(
|
456
|
+
x=onp.asarray(state_dict["x"], dtype=onp.float64),
|
457
|
+
_maxcor=int(state_dict["_maxcor"]),
|
458
|
+
_line_search_max_steps=int(state_dict["_line_search_max_steps"]),
|
459
|
+
_wa=onp.asarray(state_dict["_wa"], onp.float64),
|
460
|
+
_iwa=onp.asarray(state_dict["_iwa"], dtype=FORTRAN_INT),
|
461
|
+
_task=_s60_str_from_array(state_dict["_task"]),
|
462
|
+
_csave=_s60_str_from_array(state_dict["_csave"]),
|
463
|
+
_lsave=onp.asarray(state_dict["_lsave"], dtype=FORTRAN_INT),
|
464
|
+
_isave=onp.asarray(state_dict["_isave"], dtype=FORTRAN_INT),
|
465
|
+
_dsave=onp.asarray(state_dict["_dsave"], dtype=onp.float64),
|
466
|
+
_lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
|
467
|
+
_upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
|
468
|
+
_bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
|
469
|
+
)
|
470
|
+
|
398
471
|
@classmethod
|
399
472
|
def init(
|
400
473
|
cls,
|
@@ -440,12 +513,12 @@ class ScipyLbfgsbState:
|
|
440
513
|
|
441
514
|
# See initialization of internal variables in the `lbfgsb._minimize_lbfgsb`
|
442
515
|
# function.
|
443
|
-
|
516
|
+
wa_size = _wa_size(n=n, maxcor=maxcor)
|
444
517
|
state = ScipyLbfgsbState(
|
445
518
|
x=onp.array(x0, onp.float64),
|
446
519
|
_maxcor=maxcor,
|
447
520
|
_line_search_max_steps=line_search_max_steps,
|
448
|
-
_wa=onp.zeros(
|
521
|
+
_wa=onp.zeros(wa_size, onp.float64),
|
449
522
|
_iwa=onp.zeros(3 * n, FORTRAN_INT),
|
450
523
|
_task=task,
|
451
524
|
_csave=onp.zeros(1, "S60"),
|
@@ -510,6 +583,11 @@ class ScipyLbfgsbState:
|
|
510
583
|
break
|
511
584
|
|
512
585
|
|
586
|
+
def _wa_size(n: int, maxcor: int) -> int:
|
587
|
+
"""Return the size of the `wa` attribute of lbfgsb state."""
|
588
|
+
return 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
|
589
|
+
|
590
|
+
|
513
591
|
def _validate_array_dtype(x: NDArray, dtype: Union[type, str]) -> None:
|
514
592
|
"""Validates that `x` is an array with the specified `dtype`."""
|
515
593
|
if not isinstance(x, onp.ndarray):
|
@@ -534,3 +612,19 @@ def _configure_bounds(
|
|
534
612
|
onp.asarray(upper_bound_array, onp.float64),
|
535
613
|
onp.asarray(bound_type),
|
536
614
|
)
|
615
|
+
|
616
|
+
|
617
|
+
def _array_from_s60_str(s60_str: NDArray) -> jnp.ndarray:
|
618
|
+
"""Return a jax array for a numpy s60 string."""
|
619
|
+
assert s60_str.shape == (1,)
|
620
|
+
chars = [int(o) for o in s60_str[0]]
|
621
|
+
chars.extend([32] * (59 - len(chars)))
|
622
|
+
return jnp.asarray(chars, dtype=int)
|
623
|
+
|
624
|
+
|
625
|
+
def _s60_str_from_array(array: jnp.ndarray) -> NDArray:
|
626
|
+
"""Return a numpy s60 string for a jax array."""
|
627
|
+
return onp.asarray(
|
628
|
+
[b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
|
629
|
+
dtype="S60",
|
630
|
+
)
|
invrs_opt/lbfgsb/transform.py
CHANGED
{invrs_opt-0.1.4.dist-info → invrs_opt-0.2.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: invrs_opt
|
3
|
-
Version: 0.1.4
|
3
|
+
Version: 0.2.0
|
4
4
|
Summary: Algorithms for inverse design
|
5
5
|
Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
6
6
|
Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
@@ -47,7 +47,7 @@ Requires-Dist: pytest-cov ; extra == 'tests'
|
|
47
47
|
Requires-Dist: pytest-subtests ; extra == 'tests'
|
48
48
|
|
49
49
|
# invrs-opt - Optimization algorithms for inverse design
|
50
|
-
`v0.1.4`
|
50
|
+
`v0.2.0`
|
51
51
|
|
52
52
|
## Overview
|
53
53
|
|
invrs_opt-0.2.0.dist-info/RECORD
ADDED
@@ -0,0 +1,11 @@
|
|
1
|
+
invrs_opt/__init__.py,sha256=ZRbgvTG7O9G3ZUHOofIAqDuyNXZoD924wUWForaFZM4,309
|
2
|
+
invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
|
3
|
+
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
+
invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
+
invrs_opt/lbfgsb/lbfgsb.py,sha256=POK422WPlfVZw63y5i8wkrbh18YzojW2L5nM-eue7PI,23528
|
6
|
+
invrs_opt/lbfgsb/transform.py,sha256=gYRBHUfhdkzSeBLfncdroWs3PP_o5T2X0GfDhTc82Rs,5926
|
7
|
+
invrs_opt-0.2.0.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
|
8
|
+
invrs_opt-0.2.0.dist-info/METADATA,sha256=vmoIcnI1tL3zvvGwpJBcPula-rIZ3HVbs7bzyvfpZls,3272
|
9
|
+
invrs_opt-0.2.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
10
|
+
invrs_opt-0.2.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
11
|
+
invrs_opt-0.2.0.dist-info/RECORD,,
|
invrs_opt-0.1.4.dist-info/RECORD
DELETED
@@ -1,11 +0,0 @@
|
|
1
|
-
invrs_opt/__init__.py,sha256=ByNDfFD5Jc1XezSskVZH4ERm3pBqhBkcxS5V8qtKlHU,309
|
2
|
-
invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
|
3
|
-
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
-
invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
-
invrs_opt/lbfgsb/lbfgsb.py,sha256=tlVU6wxaSCr2FSYfEe3RUaTFT9EFyyk4i71l508vNsI,19393
|
6
|
-
invrs_opt/lbfgsb/transform.py,sha256=TjFSeWGqlJv8uY4jtgaZ38Z5hplSX5WSQfQzN8rMV5U,5927
|
7
|
-
invrs_opt-0.1.4.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
|
8
|
-
invrs_opt-0.1.4.dist-info/METADATA,sha256=iVhKz15sEHlsssEtbCmvkOs3ij5h4XGz5I9gNyUY5D4,3272
|
9
|
-
invrs_opt-0.1.4.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
10
|
-
invrs_opt-0.1.4.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
11
|
-
invrs_opt-0.1.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|