invrs-opt 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invrs_opt/__init__.py +2 -2
- invrs_opt/base.py +0 -22
- invrs_opt/lbfgsb/lbfgsb.py +136 -45
- invrs_opt/lbfgsb/transform.py +0 -1
- {invrs_opt-0.1.3.dist-info → invrs_opt-0.2.0.dist-info}/METADATA +3 -3
- invrs_opt-0.2.0.dist-info/RECORD +11 -0
- {invrs_opt-0.1.3.dist-info → invrs_opt-0.2.0.dist-info}/WHEEL +1 -1
- invrs_opt-0.1.3.dist-info/RECORD +0 -11
- {invrs_opt-0.1.3.dist-info → invrs_opt-0.2.0.dist-info}/LICENSE +0 -0
- {invrs_opt-0.1.3.dist-info → invrs_opt-0.2.0.dist-info}/top_level.txt +0 -0
invrs_opt/__init__.py
CHANGED
@@ -3,8 +3,8 @@
|
|
3
3
|
Copyright (c) 2023 The INVRS-IO authors.
|
4
4
|
"""
|
5
5
|
|
6
|
-
__version__ = "v0.
|
6
|
+
__version__ = "v0.2.0"
|
7
7
|
__author__ = "Martin F. Schubert <mfschubert@gmail.com>"
|
8
8
|
|
9
|
-
from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
|
10
9
|
from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
|
10
|
+
from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
|
invrs_opt/base.py
CHANGED
@@ -6,8 +6,6 @@ Copyright (c) 2023 The INVRS-IO authors.
|
|
6
6
|
import dataclasses
|
7
7
|
from typing import Any, Protocol
|
8
8
|
|
9
|
-
from totypes import json_utils
|
10
|
-
|
11
9
|
PyTree = Any
|
12
10
|
|
13
11
|
|
@@ -46,23 +44,3 @@ class Optimizer:
|
|
46
44
|
init: InitFn
|
47
45
|
params: ParamsFn
|
48
46
|
update: UpdateFn
|
49
|
-
|
50
|
-
|
51
|
-
# Additional custom types and prefixes used for serializing optimizer state.
|
52
|
-
CUSTOM_TYPES_AND_PREFIXES = ()
|
53
|
-
|
54
|
-
|
55
|
-
def serialize(tree: PyTree) -> str:
|
56
|
-
"""Serializes a pytree into a string."""
|
57
|
-
return json_utils.json_from_pytree(
|
58
|
-
tree,
|
59
|
-
extra_custom_types_and_prefixes=CUSTOM_TYPES_AND_PREFIXES,
|
60
|
-
)
|
61
|
-
|
62
|
-
|
63
|
-
def deserialize(serialized: str) -> PyTree:
|
64
|
-
"""Restores a pytree from a string."""
|
65
|
-
return json_utils.pytree_from_json(
|
66
|
-
serialized,
|
67
|
-
extra_custom_types_and_prefixes=CUSTOM_TYPES_AND_PREFIXES,
|
68
|
-
)
|
invrs_opt/lbfgsb/lbfgsb.py
CHANGED
@@ -10,19 +10,19 @@ from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
|
|
10
10
|
import jax
|
11
11
|
import jax.numpy as jnp
|
12
12
|
import numpy as onp
|
13
|
-
from jax import tree_util
|
13
|
+
from jax import flatten_util, tree_util
|
14
14
|
from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
|
15
15
|
_lbfgsb as scipy_lbfgsb,
|
16
16
|
)
|
17
|
+
from totypes import types
|
17
18
|
|
18
|
-
from invrs_opt.lbfgsb import transform
|
19
19
|
from invrs_opt import base
|
20
|
-
from
|
20
|
+
from invrs_opt.lbfgsb import transform
|
21
21
|
|
22
22
|
NDArray = onp.ndarray[Any, Any]
|
23
23
|
PyTree = Any
|
24
24
|
ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
|
25
|
-
LbfgsbState = Tuple[PyTree, Dict[str,
|
25
|
+
LbfgsbState = Tuple[PyTree, Dict[str, jnp.ndarray]]
|
26
26
|
|
27
27
|
|
28
28
|
# Task message prefixes for the underlying L-BFGS-B implementation.
|
@@ -187,18 +187,24 @@ def transformed_lbfgsb(
|
|
187
187
|
|
188
188
|
def init_fn(params: PyTree) -> LbfgsbState:
|
189
189
|
"""Initializes the optimization state."""
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
190
|
+
|
191
|
+
def _init_pure(params: PyTree) -> LbfgsbState:
|
192
|
+
lower_bound = types.extract_lower_bound(params)
|
193
|
+
upper_bound = types.extract_upper_bound(params)
|
194
|
+
scipy_lbfgsb_state = ScipyLbfgsbState.init(
|
195
|
+
x0=_to_numpy(params),
|
196
|
+
lower_bound=_bound_for_params(lower_bound, params),
|
197
|
+
upper_bound=_bound_for_params(upper_bound, params),
|
198
|
+
maxcor=maxcor,
|
199
|
+
line_search_max_steps=line_search_max_steps,
|
200
|
+
)
|
201
|
+
latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
|
202
|
+
params = transform_fn(latent_params)
|
203
|
+
return params, scipy_lbfgsb_state.to_jax()
|
204
|
+
|
205
|
+
return jax.pure_callback( # type: ignore[no-any-return, attr-defined]
|
206
|
+
_init_pure, _example_state(params, maxcor), params
|
198
207
|
)
|
199
|
-
latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
|
200
|
-
params = transform_fn(latent_params)
|
201
|
-
return (params, dataclasses.asdict(scipy_lbfgsb_state))
|
202
208
|
|
203
209
|
def params_fn(state: LbfgsbState) -> PyTree:
|
204
210
|
"""Returns the parameters for the given `state`."""
|
@@ -213,23 +219,30 @@ def transformed_lbfgsb(
|
|
213
219
|
state: LbfgsbState,
|
214
220
|
) -> LbfgsbState:
|
215
221
|
"""Updates the state."""
|
216
|
-
del params
|
217
|
-
params, lbfgsb_state_dict = state
|
218
|
-
# Avoid in-place updates.
|
219
|
-
lbfgsb_state_dict = copy.deepcopy(lbfgsb_state_dict)
|
220
|
-
scipy_lbfgsb_state = ScipyLbfgsbState(
|
221
|
-
**lbfgsb_state_dict # type: ignore[arg-type]
|
222
|
-
)
|
223
222
|
|
224
|
-
|
225
|
-
|
226
|
-
|
223
|
+
def _update_pure(
|
224
|
+
grad: PyTree, value: float, params: PyTree, state: LbfgsbState
|
225
|
+
) -> LbfgsbState:
|
226
|
+
del params
|
227
227
|
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
228
|
+
params, jax_lbfgsb_state = state
|
229
|
+
scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
|
230
|
+
|
231
|
+
latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
|
232
|
+
_, vjp_fn = jax.vjp(transform_fn, latent_params)
|
233
|
+
(latent_grad,) = vjp_fn(grad)
|
234
|
+
|
235
|
+
assert onp.size(value) == 1
|
236
|
+
scipy_lbfgsb_state.update(
|
237
|
+
grad=_to_numpy(latent_grad), value=onp.asarray(value)
|
238
|
+
)
|
239
|
+
latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
|
240
|
+
params = transform_fn(latent_params)
|
241
|
+
return params, scipy_lbfgsb_state.to_jax()
|
242
|
+
|
243
|
+
return jax.pure_callback( # type: ignore[no-any-return, attr-defined]
|
244
|
+
_update_pure, state, grad, value, params, state
|
245
|
+
)
|
233
246
|
|
234
247
|
return base.Optimizer(
|
235
248
|
init=init_fn,
|
@@ -245,31 +258,25 @@ def transformed_lbfgsb(
|
|
245
258
|
|
246
259
|
def _to_numpy(params: PyTree) -> NDArray:
|
247
260
|
"""Flattens a `params` pytree into a single rank-1 numpy array."""
|
248
|
-
|
249
|
-
|
250
|
-
return x_numpy.astype(onp.float64)
|
261
|
+
x, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
|
262
|
+
return onp.asarray(x, dtype=onp.float64)
|
251
263
|
|
252
264
|
|
253
265
|
def _to_pytree(x_flat: NDArray, params: PyTree) -> PyTree:
|
254
266
|
"""Restores a pytree from a flat numpy array using the structure of `params`.
|
255
267
|
|
268
|
+
Note that the returned pytree includes jax array leaves.
|
269
|
+
|
256
270
|
Args:
|
257
271
|
x_flat: The rank-1 numpy array to be restored.
|
258
272
|
params: A pytree of parameters whose structure is replicated in the restored
|
259
273
|
pytree.
|
260
274
|
|
261
275
|
Returns:
|
262
|
-
The restored pytree.
|
276
|
+
The restored pytree, with jax array leaves.
|
263
277
|
"""
|
264
|
-
|
265
|
-
|
266
|
-
x_split_flat = onp.split(x_flat, indices_or_sections)
|
267
|
-
x_split = [x.reshape(onp.shape(leaf)) for x, leaf in zip(x_split_flat, leaves)]
|
268
|
-
x_split_jax = [
|
269
|
-
jnp.asarray(x) if isinstance(leaf, jnp.ndarray) else x
|
270
|
-
for x, leaf in zip(x_split, leaves)
|
271
|
-
]
|
272
|
-
return tree_util.tree_unflatten(treedef, x_split_jax)
|
278
|
+
_, unflatten_fn = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
|
279
|
+
return unflatten_fn(jnp.asarray(x_flat, dtype=float))
|
273
280
|
|
274
281
|
|
275
282
|
def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
|
@@ -320,6 +327,8 @@ def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
|
|
320
327
|
|
321
328
|
bound_flat = []
|
322
329
|
for b, p in zip(bound_leaves, params_leaves):
|
330
|
+
if p is None:
|
331
|
+
continue
|
323
332
|
if b is None or onp.isscalar(b) or onp.shape(b) == ():
|
324
333
|
bound_flat += [b] * onp.size(p)
|
325
334
|
else:
|
@@ -334,6 +343,29 @@ def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
|
|
334
343
|
return bound_flat
|
335
344
|
|
336
345
|
|
346
|
+
def _example_state(params: PyTree, maxcor: int) -> PyTree:
|
347
|
+
"""Return an example state for the given `params` and `maxcor`."""
|
348
|
+
params_flat, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
|
349
|
+
n = params_flat.size
|
350
|
+
float_params = tree_util.tree_map(lambda x: jnp.asarray(x, dtype=float), params)
|
351
|
+
example_jax_lbfgsb_state = dict(
|
352
|
+
x=jnp.zeros(n, dtype=float),
|
353
|
+
_maxcor=jnp.zeros((), dtype=int),
|
354
|
+
_line_search_max_steps=jnp.zeros((), dtype=int),
|
355
|
+
_wa=jnp.ones(_wa_size(n=n, maxcor=maxcor), dtype=float),
|
356
|
+
_iwa=jnp.ones(n * 3, dtype=jnp.int32), # Fortran int
|
357
|
+
_task=jnp.zeros(59, dtype=int),
|
358
|
+
_csave=jnp.zeros(59, dtype=int),
|
359
|
+
_lsave=jnp.zeros(4, dtype=jnp.int32), # Fortran int
|
360
|
+
_isave=jnp.zeros(44, dtype=jnp.int32), # Fortran int
|
361
|
+
_dsave=jnp.zeros(29, dtype=float),
|
362
|
+
_lower_bound=jnp.zeros(n, dtype=float),
|
363
|
+
_upper_bound=jnp.zeros(n, dtype=float),
|
364
|
+
_bound_type=jnp.zeros(n, dtype=int),
|
365
|
+
)
|
366
|
+
return float_params, example_jax_lbfgsb_state
|
367
|
+
|
368
|
+
|
337
369
|
# ------------------------------------------------------------------------------
|
338
370
|
# Wrapper for scipy's L-BFGS-B implementation.
|
339
371
|
# ------------------------------------------------------------------------------
|
@@ -398,6 +430,44 @@ class ScipyLbfgsbState:
|
|
398
430
|
_validate_array_dtype(self._upper_bound, onp.float64)
|
399
431
|
_validate_array_dtype(self._bound_type, int)
|
400
432
|
|
433
|
+
def to_jax(self) -> Dict[str, jnp.ndarray]:
|
434
|
+
"""Generates a dictionary of jax arrays defining the state."""
|
435
|
+
return dict(
|
436
|
+
x=jnp.asarray(self.x),
|
437
|
+
_maxcor=jnp.asarray(self._maxcor),
|
438
|
+
_line_search_max_steps=jnp.asarray(self._line_search_max_steps),
|
439
|
+
_wa=jnp.asarray(self._wa),
|
440
|
+
_iwa=jnp.asarray(self._iwa),
|
441
|
+
_task=_array_from_s60_str(self._task),
|
442
|
+
_csave=_array_from_s60_str(self._csave),
|
443
|
+
_lsave=jnp.asarray(self._lsave),
|
444
|
+
_isave=jnp.asarray(self._isave),
|
445
|
+
_dsave=jnp.asarray(self._dsave),
|
446
|
+
_lower_bound=jnp.asarray(self._lower_bound),
|
447
|
+
_upper_bound=jnp.asarray(self._upper_bound),
|
448
|
+
_bound_type=jnp.asarray(self._bound_type),
|
449
|
+
)
|
450
|
+
|
451
|
+
@classmethod
|
452
|
+
def from_jax(cls, state_dict: Dict[str, jnp.ndarray]) -> "ScipyLbfgsbState":
|
453
|
+
"""Converts a dictionary of jax arrays to a `ScipyLbfgsbState`."""
|
454
|
+
state_dict = copy.deepcopy(state_dict)
|
455
|
+
return ScipyLbfgsbState(
|
456
|
+
x=onp.asarray(state_dict["x"], dtype=onp.float64),
|
457
|
+
_maxcor=int(state_dict["_maxcor"]),
|
458
|
+
_line_search_max_steps=int(state_dict["_line_search_max_steps"]),
|
459
|
+
_wa=onp.asarray(state_dict["_wa"], onp.float64),
|
460
|
+
_iwa=onp.asarray(state_dict["_iwa"], dtype=FORTRAN_INT),
|
461
|
+
_task=_s60_str_from_array(state_dict["_task"]),
|
462
|
+
_csave=_s60_str_from_array(state_dict["_csave"]),
|
463
|
+
_lsave=onp.asarray(state_dict["_lsave"], dtype=FORTRAN_INT),
|
464
|
+
_isave=onp.asarray(state_dict["_isave"], dtype=FORTRAN_INT),
|
465
|
+
_dsave=onp.asarray(state_dict["_dsave"], dtype=onp.float64),
|
466
|
+
_lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
|
467
|
+
_upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
|
468
|
+
_bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
|
469
|
+
)
|
470
|
+
|
401
471
|
@classmethod
|
402
472
|
def init(
|
403
473
|
cls,
|
@@ -443,12 +513,12 @@ class ScipyLbfgsbState:
|
|
443
513
|
|
444
514
|
# See initialization of internal variables in the `lbfgsb._minimize_lbfgsb`
|
445
515
|
# function.
|
446
|
-
|
516
|
+
wa_size = _wa_size(n=n, maxcor=maxcor)
|
447
517
|
state = ScipyLbfgsbState(
|
448
518
|
x=onp.array(x0, onp.float64),
|
449
519
|
_maxcor=maxcor,
|
450
520
|
_line_search_max_steps=line_search_max_steps,
|
451
|
-
_wa=onp.zeros(
|
521
|
+
_wa=onp.zeros(wa_size, onp.float64),
|
452
522
|
_iwa=onp.zeros(3 * n, FORTRAN_INT),
|
453
523
|
_task=task,
|
454
524
|
_csave=onp.zeros(1, "S60"),
|
@@ -513,6 +583,11 @@ class ScipyLbfgsbState:
|
|
513
583
|
break
|
514
584
|
|
515
585
|
|
586
|
+
def _wa_size(n: int, maxcor: int) -> int:
|
587
|
+
"""Return the size of the `wa` attribute of lbfgsb state."""
|
588
|
+
return 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
|
589
|
+
|
590
|
+
|
516
591
|
def _validate_array_dtype(x: NDArray, dtype: Union[type, str]) -> None:
|
517
592
|
"""Validates that `x` is an array with the specified `dtype`."""
|
518
593
|
if not isinstance(x, onp.ndarray):
|
@@ -537,3 +612,19 @@ def _configure_bounds(
|
|
537
612
|
onp.asarray(upper_bound_array, onp.float64),
|
538
613
|
onp.asarray(bound_type),
|
539
614
|
)
|
615
|
+
|
616
|
+
|
617
|
+
def _array_from_s60_str(s60_str: NDArray) -> jnp.ndarray:
|
618
|
+
"""Return a jax array for a numpy s60 string."""
|
619
|
+
assert s60_str.shape == (1,)
|
620
|
+
chars = [int(o) for o in s60_str[0]]
|
621
|
+
chars.extend([32] * (59 - len(chars)))
|
622
|
+
return jnp.asarray(chars, dtype=int)
|
623
|
+
|
624
|
+
|
625
|
+
def _s60_str_from_array(array: jnp.ndarray) -> NDArray:
|
626
|
+
"""Return a numpy s60 string for a jax array."""
|
627
|
+
return onp.asarray(
|
628
|
+
[b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
|
629
|
+
dtype="S60",
|
630
|
+
)
|
invrs_opt/lbfgsb/transform.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
|
-
Name:
|
3
|
-
Version: 0.
|
2
|
+
Name: invrs_opt
|
3
|
+
Version: 0.2.0
|
4
4
|
Summary: Algorithms for inverse design
|
5
5
|
Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
6
6
|
Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
@@ -47,7 +47,7 @@ Requires-Dist: pytest-cov ; extra == 'tests'
|
|
47
47
|
Requires-Dist: pytest-subtests ; extra == 'tests'
|
48
48
|
|
49
49
|
# invrs-opt - Optimization algorithms for inverse design
|
50
|
-
`v0.
|
50
|
+
`v0.2.0`
|
51
51
|
|
52
52
|
## Overview
|
53
53
|
|
@@ -0,0 +1,11 @@
|
|
1
|
+
invrs_opt/__init__.py,sha256=ZRbgvTG7O9G3ZUHOofIAqDuyNXZoD924wUWForaFZM4,309
|
2
|
+
invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
|
3
|
+
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
+
invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
+
invrs_opt/lbfgsb/lbfgsb.py,sha256=POK422WPlfVZw63y5i8wkrbh18YzojW2L5nM-eue7PI,23528
|
6
|
+
invrs_opt/lbfgsb/transform.py,sha256=gYRBHUfhdkzSeBLfncdroWs3PP_o5T2X0GfDhTc82Rs,5926
|
7
|
+
invrs_opt-0.2.0.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
|
8
|
+
invrs_opt-0.2.0.dist-info/METADATA,sha256=vmoIcnI1tL3zvvGwpJBcPula-rIZ3HVbs7bzyvfpZls,3272
|
9
|
+
invrs_opt-0.2.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
10
|
+
invrs_opt-0.2.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
11
|
+
invrs_opt-0.2.0.dist-info/RECORD,,
|
invrs_opt-0.1.3.dist-info/RECORD
DELETED
@@ -1,11 +0,0 @@
|
|
1
|
-
invrs_opt/__init__.py,sha256=IpAs-pDwW_mo2FnbNkDpsR-XuxZW6h5TwvlHkc8kCuE,309
|
2
|
-
invrs_opt/base.py,sha256=dm5nzlO4KXFfuIfyHcTn9V1VCU6hAy1w3IA2vfzaQD8,1481
|
3
|
-
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
-
invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
-
invrs_opt/lbfgsb/lbfgsb.py,sha256=nZqotWv9oGbY56UKSI3zcetexpwMoaDsvqwJSgXjvwc,19597
|
6
|
-
invrs_opt/lbfgsb/transform.py,sha256=TjFSeWGqlJv8uY4jtgaZ38Z5hplSX5WSQfQzN8rMV5U,5927
|
7
|
-
invrs_opt-0.1.3.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
|
8
|
-
invrs_opt-0.1.3.dist-info/METADATA,sha256=oXMNM5J_RzBpmnNB3IBvFhscJg6nQbyLWNMz-OljYHQ,3272
|
9
|
-
invrs_opt-0.1.3.dist-info/WHEEL,sha256=Xo9-1PvkuimrydujYJAjF7pCkriuXBpUPEjma1nZyJ0,92
|
10
|
-
invrs_opt-0.1.3.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
11
|
-
invrs_opt-0.1.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|