invrs-opt 0.1.3__tar.gz → 0.2.0__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {invrs_opt-0.1.3 → invrs_opt-0.2.0}/PKG-INFO +2 -2
- {invrs_opt-0.1.3 → invrs_opt-0.2.0}/README.md +1 -1
- {invrs_opt-0.1.3 → invrs_opt-0.2.0}/pyproject.toml +1 -1
- {invrs_opt-0.1.3 → invrs_opt-0.2.0}/src/invrs_opt/__init__.py +2 -2
- {invrs_opt-0.1.3 → invrs_opt-0.2.0}/src/invrs_opt/base.py +0 -22
- {invrs_opt-0.1.3 → invrs_opt-0.2.0}/src/invrs_opt/lbfgsb/lbfgsb.py +136 -45
- {invrs_opt-0.1.3 → invrs_opt-0.2.0}/src/invrs_opt/lbfgsb/transform.py +0 -1
- {invrs_opt-0.1.3 → invrs_opt-0.2.0}/src/invrs_opt.egg-info/PKG-INFO +3 -3
- {invrs_opt-0.1.3 → invrs_opt-0.2.0}/tests/test_algos.py +12 -4
- {invrs_opt-0.1.3 → invrs_opt-0.2.0}/LICENSE +0 -0
- {invrs_opt-0.1.3 → invrs_opt-0.2.0}/setup.cfg +0 -0
- {invrs_opt-0.1.3 → invrs_opt-0.2.0}/src/invrs_opt/lbfgsb/__init__.py +0 -0
- {invrs_opt-0.1.3 → invrs_opt-0.2.0}/src/invrs_opt/py.typed +0 -0
- {invrs_opt-0.1.3 → invrs_opt-0.2.0}/src/invrs_opt.egg-info/SOURCES.txt +0 -0
- {invrs_opt-0.1.3 → invrs_opt-0.2.0}/src/invrs_opt.egg-info/dependency_links.txt +0 -0
- {invrs_opt-0.1.3 → invrs_opt-0.2.0}/src/invrs_opt.egg-info/requires.txt +0 -0
- {invrs_opt-0.1.3 → invrs_opt-0.2.0}/src/invrs_opt.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: invrs_opt
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.2.0
|
4
4
|
Summary: Algorithms for inverse design
|
5
5
|
Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
6
6
|
Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
@@ -47,7 +47,7 @@ Requires-Dist: mypy; extra == "dev"
|
|
47
47
|
Requires-Dist: pre-commit; extra == "dev"
|
48
48
|
|
49
49
|
# invrs-opt - Optimization algorithms for inverse design
|
50
|
-
`v0.
|
50
|
+
`v0.2.0`
|
51
51
|
|
52
52
|
## Overview
|
53
53
|
|
@@ -3,8 +3,8 @@
|
|
3
3
|
Copyright (c) 2023 The INVRS-IO authors.
|
4
4
|
"""
|
5
5
|
|
6
|
-
__version__ = "v0.
|
6
|
+
__version__ = "v0.2.0"
|
7
7
|
__author__ = "Martin F. Schubert <mfschubert@gmail.com>"
|
8
8
|
|
9
|
-
from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
|
10
9
|
from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
|
10
|
+
from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
|
@@ -6,8 +6,6 @@ Copyright (c) 2023 The INVRS-IO authors.
|
|
6
6
|
import dataclasses
|
7
7
|
from typing import Any, Protocol
|
8
8
|
|
9
|
-
from totypes import json_utils
|
10
|
-
|
11
9
|
PyTree = Any
|
12
10
|
|
13
11
|
|
@@ -46,23 +44,3 @@ class Optimizer:
|
|
46
44
|
init: InitFn
|
47
45
|
params: ParamsFn
|
48
46
|
update: UpdateFn
|
49
|
-
|
50
|
-
|
51
|
-
# Additional custom types and prefixes used for serializing optimizer state.
|
52
|
-
CUSTOM_TYPES_AND_PREFIXES = ()
|
53
|
-
|
54
|
-
|
55
|
-
def serialize(tree: PyTree) -> str:
|
56
|
-
"""Serializes a pytree into a string."""
|
57
|
-
return json_utils.json_from_pytree(
|
58
|
-
tree,
|
59
|
-
extra_custom_types_and_prefixes=CUSTOM_TYPES_AND_PREFIXES,
|
60
|
-
)
|
61
|
-
|
62
|
-
|
63
|
-
def deserialize(serialized: str) -> PyTree:
|
64
|
-
"""Restores a pytree from a string."""
|
65
|
-
return json_utils.pytree_from_json(
|
66
|
-
serialized,
|
67
|
-
extra_custom_types_and_prefixes=CUSTOM_TYPES_AND_PREFIXES,
|
68
|
-
)
|
@@ -10,19 +10,19 @@ from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
|
|
10
10
|
import jax
|
11
11
|
import jax.numpy as jnp
|
12
12
|
import numpy as onp
|
13
|
-
from jax import tree_util
|
13
|
+
from jax import flatten_util, tree_util
|
14
14
|
from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
|
15
15
|
_lbfgsb as scipy_lbfgsb,
|
16
16
|
)
|
17
|
+
from totypes import types
|
17
18
|
|
18
|
-
from invrs_opt.lbfgsb import transform
|
19
19
|
from invrs_opt import base
|
20
|
-
from
|
20
|
+
from invrs_opt.lbfgsb import transform
|
21
21
|
|
22
22
|
NDArray = onp.ndarray[Any, Any]
|
23
23
|
PyTree = Any
|
24
24
|
ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
|
25
|
-
LbfgsbState = Tuple[PyTree, Dict[str,
|
25
|
+
LbfgsbState = Tuple[PyTree, Dict[str, jnp.ndarray]]
|
26
26
|
|
27
27
|
|
28
28
|
# Task message prefixes for the underlying L-BFGS-B implementation.
|
@@ -187,18 +187,24 @@ def transformed_lbfgsb(
|
|
187
187
|
|
188
188
|
def init_fn(params: PyTree) -> LbfgsbState:
|
189
189
|
"""Initializes the optimization state."""
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
190
|
+
|
191
|
+
def _init_pure(params: PyTree) -> LbfgsbState:
|
192
|
+
lower_bound = types.extract_lower_bound(params)
|
193
|
+
upper_bound = types.extract_upper_bound(params)
|
194
|
+
scipy_lbfgsb_state = ScipyLbfgsbState.init(
|
195
|
+
x0=_to_numpy(params),
|
196
|
+
lower_bound=_bound_for_params(lower_bound, params),
|
197
|
+
upper_bound=_bound_for_params(upper_bound, params),
|
198
|
+
maxcor=maxcor,
|
199
|
+
line_search_max_steps=line_search_max_steps,
|
200
|
+
)
|
201
|
+
latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
|
202
|
+
params = transform_fn(latent_params)
|
203
|
+
return params, scipy_lbfgsb_state.to_jax()
|
204
|
+
|
205
|
+
return jax.pure_callback( # type: ignore[no-any-return, attr-defined]
|
206
|
+
_init_pure, _example_state(params, maxcor), params
|
198
207
|
)
|
199
|
-
latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
|
200
|
-
params = transform_fn(latent_params)
|
201
|
-
return (params, dataclasses.asdict(scipy_lbfgsb_state))
|
202
208
|
|
203
209
|
def params_fn(state: LbfgsbState) -> PyTree:
|
204
210
|
"""Returns the parameters for the given `state`."""
|
@@ -213,23 +219,30 @@ def transformed_lbfgsb(
|
|
213
219
|
state: LbfgsbState,
|
214
220
|
) -> LbfgsbState:
|
215
221
|
"""Updates the state."""
|
216
|
-
del params
|
217
|
-
params, lbfgsb_state_dict = state
|
218
|
-
# Avoid in-place updates.
|
219
|
-
lbfgsb_state_dict = copy.deepcopy(lbfgsb_state_dict)
|
220
|
-
scipy_lbfgsb_state = ScipyLbfgsbState(
|
221
|
-
**lbfgsb_state_dict # type: ignore[arg-type]
|
222
|
-
)
|
223
222
|
|
224
|
-
|
225
|
-
|
226
|
-
|
223
|
+
def _update_pure(
|
224
|
+
grad: PyTree, value: float, params: PyTree, state: LbfgsbState
|
225
|
+
) -> LbfgsbState:
|
226
|
+
del params
|
227
227
|
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
228
|
+
params, jax_lbfgsb_state = state
|
229
|
+
scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
|
230
|
+
|
231
|
+
latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
|
232
|
+
_, vjp_fn = jax.vjp(transform_fn, latent_params)
|
233
|
+
(latent_grad,) = vjp_fn(grad)
|
234
|
+
|
235
|
+
assert onp.size(value) == 1
|
236
|
+
scipy_lbfgsb_state.update(
|
237
|
+
grad=_to_numpy(latent_grad), value=onp.asarray(value)
|
238
|
+
)
|
239
|
+
latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
|
240
|
+
params = transform_fn(latent_params)
|
241
|
+
return params, scipy_lbfgsb_state.to_jax()
|
242
|
+
|
243
|
+
return jax.pure_callback( # type: ignore[no-any-return, attr-defined]
|
244
|
+
_update_pure, state, grad, value, params, state
|
245
|
+
)
|
233
246
|
|
234
247
|
return base.Optimizer(
|
235
248
|
init=init_fn,
|
@@ -245,31 +258,25 @@ def transformed_lbfgsb(
|
|
245
258
|
|
246
259
|
def _to_numpy(params: PyTree) -> NDArray:
|
247
260
|
"""Flattens a `params` pytree into a single rank-1 numpy array."""
|
248
|
-
|
249
|
-
|
250
|
-
return x_numpy.astype(onp.float64)
|
261
|
+
x, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
|
262
|
+
return onp.asarray(x, dtype=onp.float64)
|
251
263
|
|
252
264
|
|
253
265
|
def _to_pytree(x_flat: NDArray, params: PyTree) -> PyTree:
|
254
266
|
"""Restores a pytree from a flat numpy array using the structure of `params`.
|
255
267
|
|
268
|
+
Note that the returned pytree includes jax array leaves.
|
269
|
+
|
256
270
|
Args:
|
257
271
|
x_flat: The rank-1 numpy array to be restored.
|
258
272
|
params: A pytree of parameters whose structure is replicated in the restored
|
259
273
|
pytree.
|
260
274
|
|
261
275
|
Returns:
|
262
|
-
The restored pytree.
|
276
|
+
The restored pytree, with jax array leaves.
|
263
277
|
"""
|
264
|
-
|
265
|
-
|
266
|
-
x_split_flat = onp.split(x_flat, indices_or_sections)
|
267
|
-
x_split = [x.reshape(onp.shape(leaf)) for x, leaf in zip(x_split_flat, leaves)]
|
268
|
-
x_split_jax = [
|
269
|
-
jnp.asarray(x) if isinstance(leaf, jnp.ndarray) else x
|
270
|
-
for x, leaf in zip(x_split, leaves)
|
271
|
-
]
|
272
|
-
return tree_util.tree_unflatten(treedef, x_split_jax)
|
278
|
+
_, unflatten_fn = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
|
279
|
+
return unflatten_fn(jnp.asarray(x_flat, dtype=float))
|
273
280
|
|
274
281
|
|
275
282
|
def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
|
@@ -320,6 +327,8 @@ def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
|
|
320
327
|
|
321
328
|
bound_flat = []
|
322
329
|
for b, p in zip(bound_leaves, params_leaves):
|
330
|
+
if p is None:
|
331
|
+
continue
|
323
332
|
if b is None or onp.isscalar(b) or onp.shape(b) == ():
|
324
333
|
bound_flat += [b] * onp.size(p)
|
325
334
|
else:
|
@@ -334,6 +343,29 @@ def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
|
|
334
343
|
return bound_flat
|
335
344
|
|
336
345
|
|
346
|
+
def _example_state(params: PyTree, maxcor: int) -> PyTree:
|
347
|
+
"""Return an example state for the given `params` and `maxcor`."""
|
348
|
+
params_flat, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
|
349
|
+
n = params_flat.size
|
350
|
+
float_params = tree_util.tree_map(lambda x: jnp.asarray(x, dtype=float), params)
|
351
|
+
example_jax_lbfgsb_state = dict(
|
352
|
+
x=jnp.zeros(n, dtype=float),
|
353
|
+
_maxcor=jnp.zeros((), dtype=int),
|
354
|
+
_line_search_max_steps=jnp.zeros((), dtype=int),
|
355
|
+
_wa=jnp.ones(_wa_size(n=n, maxcor=maxcor), dtype=float),
|
356
|
+
_iwa=jnp.ones(n * 3, dtype=jnp.int32), # Fortran int
|
357
|
+
_task=jnp.zeros(59, dtype=int),
|
358
|
+
_csave=jnp.zeros(59, dtype=int),
|
359
|
+
_lsave=jnp.zeros(4, dtype=jnp.int32), # Fortran int
|
360
|
+
_isave=jnp.zeros(44, dtype=jnp.int32), # Fortran int
|
361
|
+
_dsave=jnp.zeros(29, dtype=float),
|
362
|
+
_lower_bound=jnp.zeros(n, dtype=float),
|
363
|
+
_upper_bound=jnp.zeros(n, dtype=float),
|
364
|
+
_bound_type=jnp.zeros(n, dtype=int),
|
365
|
+
)
|
366
|
+
return float_params, example_jax_lbfgsb_state
|
367
|
+
|
368
|
+
|
337
369
|
# ------------------------------------------------------------------------------
|
338
370
|
# Wrapper for scipy's L-BFGS-B implementation.
|
339
371
|
# ------------------------------------------------------------------------------
|
@@ -398,6 +430,44 @@ class ScipyLbfgsbState:
|
|
398
430
|
_validate_array_dtype(self._upper_bound, onp.float64)
|
399
431
|
_validate_array_dtype(self._bound_type, int)
|
400
432
|
|
433
|
+
def to_jax(self) -> Dict[str, jnp.ndarray]:
|
434
|
+
"""Generates a dictionary of jax arrays defining the state."""
|
435
|
+
return dict(
|
436
|
+
x=jnp.asarray(self.x),
|
437
|
+
_maxcor=jnp.asarray(self._maxcor),
|
438
|
+
_line_search_max_steps=jnp.asarray(self._line_search_max_steps),
|
439
|
+
_wa=jnp.asarray(self._wa),
|
440
|
+
_iwa=jnp.asarray(self._iwa),
|
441
|
+
_task=_array_from_s60_str(self._task),
|
442
|
+
_csave=_array_from_s60_str(self._csave),
|
443
|
+
_lsave=jnp.asarray(self._lsave),
|
444
|
+
_isave=jnp.asarray(self._isave),
|
445
|
+
_dsave=jnp.asarray(self._dsave),
|
446
|
+
_lower_bound=jnp.asarray(self._lower_bound),
|
447
|
+
_upper_bound=jnp.asarray(self._upper_bound),
|
448
|
+
_bound_type=jnp.asarray(self._bound_type),
|
449
|
+
)
|
450
|
+
|
451
|
+
@classmethod
|
452
|
+
def from_jax(cls, state_dict: Dict[str, jnp.ndarray]) -> "ScipyLbfgsbState":
|
453
|
+
"""Converts a dictionary of jax arrays to a `ScipyLbfgsbState`."""
|
454
|
+
state_dict = copy.deepcopy(state_dict)
|
455
|
+
return ScipyLbfgsbState(
|
456
|
+
x=onp.asarray(state_dict["x"], dtype=onp.float64),
|
457
|
+
_maxcor=int(state_dict["_maxcor"]),
|
458
|
+
_line_search_max_steps=int(state_dict["_line_search_max_steps"]),
|
459
|
+
_wa=onp.asarray(state_dict["_wa"], onp.float64),
|
460
|
+
_iwa=onp.asarray(state_dict["_iwa"], dtype=FORTRAN_INT),
|
461
|
+
_task=_s60_str_from_array(state_dict["_task"]),
|
462
|
+
_csave=_s60_str_from_array(state_dict["_csave"]),
|
463
|
+
_lsave=onp.asarray(state_dict["_lsave"], dtype=FORTRAN_INT),
|
464
|
+
_isave=onp.asarray(state_dict["_isave"], dtype=FORTRAN_INT),
|
465
|
+
_dsave=onp.asarray(state_dict["_dsave"], dtype=onp.float64),
|
466
|
+
_lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
|
467
|
+
_upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
|
468
|
+
_bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
|
469
|
+
)
|
470
|
+
|
401
471
|
@classmethod
|
402
472
|
def init(
|
403
473
|
cls,
|
@@ -443,12 +513,12 @@ class ScipyLbfgsbState:
|
|
443
513
|
|
444
514
|
# See initialization of internal variables in the `lbfgsb._minimize_lbfgsb`
|
445
515
|
# function.
|
446
|
-
|
516
|
+
wa_size = _wa_size(n=n, maxcor=maxcor)
|
447
517
|
state = ScipyLbfgsbState(
|
448
518
|
x=onp.array(x0, onp.float64),
|
449
519
|
_maxcor=maxcor,
|
450
520
|
_line_search_max_steps=line_search_max_steps,
|
451
|
-
_wa=onp.zeros(
|
521
|
+
_wa=onp.zeros(wa_size, onp.float64),
|
452
522
|
_iwa=onp.zeros(3 * n, FORTRAN_INT),
|
453
523
|
_task=task,
|
454
524
|
_csave=onp.zeros(1, "S60"),
|
@@ -513,6 +583,11 @@ class ScipyLbfgsbState:
|
|
513
583
|
break
|
514
584
|
|
515
585
|
|
586
|
+
def _wa_size(n: int, maxcor: int) -> int:
|
587
|
+
"""Return the size of the `wa` attribute of lbfgsb state."""
|
588
|
+
return 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
|
589
|
+
|
590
|
+
|
516
591
|
def _validate_array_dtype(x: NDArray, dtype: Union[type, str]) -> None:
|
517
592
|
"""Validates that `x` is an array with the specified `dtype`."""
|
518
593
|
if not isinstance(x, onp.ndarray):
|
@@ -537,3 +612,19 @@ def _configure_bounds(
|
|
537
612
|
onp.asarray(upper_bound_array, onp.float64),
|
538
613
|
onp.asarray(bound_type),
|
539
614
|
)
|
615
|
+
|
616
|
+
|
617
|
+
def _array_from_s60_str(s60_str: NDArray) -> jnp.ndarray:
|
618
|
+
"""Return a jax array for a numpy s60 string."""
|
619
|
+
assert s60_str.shape == (1,)
|
620
|
+
chars = [int(o) for o in s60_str[0]]
|
621
|
+
chars.extend([32] * (59 - len(chars)))
|
622
|
+
return jnp.asarray(chars, dtype=int)
|
623
|
+
|
624
|
+
|
625
|
+
def _s60_str_from_array(array: jnp.ndarray) -> NDArray:
|
626
|
+
"""Return a numpy s60 string for a jax array."""
|
627
|
+
return onp.asarray(
|
628
|
+
[b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
|
629
|
+
dtype="S60",
|
630
|
+
)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
|
-
Name:
|
3
|
-
Version: 0.
|
2
|
+
Name: invrs_opt
|
3
|
+
Version: 0.2.0
|
4
4
|
Summary: Algorithms for inverse design
|
5
5
|
Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
6
6
|
Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
@@ -47,7 +47,7 @@ Requires-Dist: mypy; extra == "dev"
|
|
47
47
|
Requires-Dist: pre-commit; extra == "dev"
|
48
48
|
|
49
49
|
# invrs-opt - Optimization algorithms for inverse design
|
50
|
-
`v0.
|
50
|
+
`v0.2.0`
|
51
51
|
|
52
52
|
## Overview
|
53
53
|
|
@@ -10,9 +10,9 @@ import jax
|
|
10
10
|
import jax.numpy as jnp
|
11
11
|
import numpy as onp
|
12
12
|
import parameterized
|
13
|
+
from totypes import json_utils, symmetry, types
|
13
14
|
|
14
15
|
import invrs_opt
|
15
|
-
from totypes import symmetry, types
|
16
16
|
|
17
17
|
jax.config.update("jax_enable_x64", True)
|
18
18
|
|
@@ -153,14 +153,22 @@ def _lists_to_tuple(pytree, max_depth=10):
|
|
153
153
|
return pytree
|
154
154
|
|
155
155
|
|
156
|
+
def serialize(pytree) -> str:
|
157
|
+
return json_utils.json_from_pytree(pytree=pytree)
|
158
|
+
|
159
|
+
|
160
|
+
def deserialize(serialized):
|
161
|
+
return json_utils.pytree_from_json(serialized=serialized)
|
162
|
+
|
163
|
+
|
156
164
|
class BasicOptimizerTest(unittest.TestCase):
|
157
165
|
@parameterized.parameterized.expand(itertools.product(PARAMS, OPTIMIZERS))
|
158
166
|
def test_state_is_serializable(self, params, opt):
|
159
167
|
state = opt.init(params)
|
160
168
|
leaves, treedef = jax.tree_util.tree_flatten(state)
|
161
169
|
|
162
|
-
serialized_state =
|
163
|
-
restored_state =
|
170
|
+
serialized_state = serialize(state)
|
171
|
+
restored_state = deserialize(serialized_state)
|
164
172
|
# Serialization/deserialization unavoidably converts tuples to lists.
|
165
173
|
# Convert back to tuples to facilitate comparison.
|
166
174
|
restored_state = _lists_to_tuple(restored_state)
|
@@ -211,7 +219,7 @@ class BasicOptimizerTest(unittest.TestCase):
|
|
211
219
|
expected_grad_list.append(grad)
|
212
220
|
|
213
221
|
def serdes(x):
|
214
|
-
return
|
222
|
+
return deserialize(serialize(x))
|
215
223
|
|
216
224
|
# Optimize with serialization.
|
217
225
|
params_list = []
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|