invrs-opt: invrs_opt-0.1.4-py3-none-any.whl → invrs_opt-0.2.0-py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
invrs_opt/__init__.py CHANGED
@@ -3,8 +3,8 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.1.4"
6
+ __version__ = "v0.2.0"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
- from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
10
9
  from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
10
+ from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
invrs_opt/lbfgsb/lbfgsb.py CHANGED
@@ -10,20 +10,19 @@ from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
10
10
  import jax
11
11
  import jax.numpy as jnp
12
12
  import numpy as onp
13
- from jax import flatten_util
14
- from jax import tree_util
13
+ from jax import flatten_util, tree_util
15
14
  from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
16
15
  _lbfgsb as scipy_lbfgsb,
17
16
  )
17
+ from totypes import types
18
18
 
19
- from invrs_opt.lbfgsb import transform
20
19
  from invrs_opt import base
21
- from totypes import types
20
+ from invrs_opt.lbfgsb import transform
22
21
 
23
22
  NDArray = onp.ndarray[Any, Any]
24
23
  PyTree = Any
25
24
  ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
26
- LbfgsbState = Tuple[PyTree, Dict[str, NDArray]]
25
+ LbfgsbState = Tuple[PyTree, Dict[str, jnp.ndarray]]
27
26
 
28
27
 
29
28
  # Task message prefixes for the underlying L-BFGS-B implementation.
@@ -188,18 +187,24 @@ def transformed_lbfgsb(
188
187
 
189
188
  def init_fn(params: PyTree) -> LbfgsbState:
190
189
  """Initializes the optimization state."""
191
- lower_bound = types.extract_lower_bound(params)
192
- upper_bound = types.extract_upper_bound(params)
193
- scipy_lbfgsb_state = ScipyLbfgsbState.init(
194
- x0=_to_numpy(params),
195
- lower_bound=_bound_for_params(lower_bound, params),
196
- upper_bound=_bound_for_params(upper_bound, params),
197
- maxcor=maxcor,
198
- line_search_max_steps=line_search_max_steps,
190
+
191
+ def _init_pure(params: PyTree) -> LbfgsbState:
192
+ lower_bound = types.extract_lower_bound(params)
193
+ upper_bound = types.extract_upper_bound(params)
194
+ scipy_lbfgsb_state = ScipyLbfgsbState.init(
195
+ x0=_to_numpy(params),
196
+ lower_bound=_bound_for_params(lower_bound, params),
197
+ upper_bound=_bound_for_params(upper_bound, params),
198
+ maxcor=maxcor,
199
+ line_search_max_steps=line_search_max_steps,
200
+ )
201
+ latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
202
+ params = transform_fn(latent_params)
203
+ return params, scipy_lbfgsb_state.to_jax()
204
+
205
+ return jax.pure_callback( # type: ignore[no-any-return, attr-defined]
206
+ _init_pure, _example_state(params, maxcor), params
199
207
  )
200
- latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
201
- params = transform_fn(latent_params)
202
- return (params, dataclasses.asdict(scipy_lbfgsb_state))
203
208
 
204
209
  def params_fn(state: LbfgsbState) -> PyTree:
205
210
  """Returns the parameters for the given `state`."""
@@ -214,23 +219,30 @@ def transformed_lbfgsb(
214
219
  state: LbfgsbState,
215
220
  ) -> LbfgsbState:
216
221
  """Updates the state."""
217
- del params
218
- params, lbfgsb_state_dict = state
219
- # Avoid in-place updates.
220
- lbfgsb_state_dict = copy.deepcopy(lbfgsb_state_dict)
221
- scipy_lbfgsb_state = ScipyLbfgsbState(
222
- **lbfgsb_state_dict # type: ignore[arg-type]
223
- )
224
222
 
225
- latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
226
- _, vjp_fn = jax.vjp(transform_fn, latent_params)
227
- (latent_grad,) = vjp_fn(grad)
223
+ def _update_pure(
224
+ grad: PyTree, value: float, params: PyTree, state: LbfgsbState
225
+ ) -> LbfgsbState:
226
+ del params
227
+
228
+ params, jax_lbfgsb_state = state
229
+ scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
228
230
 
229
- assert onp.size(value) == 1
230
- scipy_lbfgsb_state.update(grad=_to_numpy(latent_grad), value=onp.asarray(value))
231
- latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
232
- params = transform_fn(latent_params)
233
- return (params, dataclasses.asdict(scipy_lbfgsb_state))
231
+ latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
232
+ _, vjp_fn = jax.vjp(transform_fn, latent_params)
233
+ (latent_grad,) = vjp_fn(grad)
234
+
235
+ assert onp.size(value) == 1
236
+ scipy_lbfgsb_state.update(
237
+ grad=_to_numpy(latent_grad), value=onp.asarray(value)
238
+ )
239
+ latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
240
+ params = transform_fn(latent_params)
241
+ return params, scipy_lbfgsb_state.to_jax()
242
+
243
+ return jax.pure_callback( # type: ignore[no-any-return, attr-defined]
244
+ _update_pure, state, grad, value, params, state
245
+ )
234
246
 
235
247
  return base.Optimizer(
236
248
  init=init_fn,
@@ -331,6 +343,29 @@ def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
331
343
  return bound_flat
332
344
 
333
345
 
346
+ def _example_state(params: PyTree, maxcor: int) -> PyTree:
347
+ """Return an example state for the given `params` and `maxcor`."""
348
+ params_flat, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
349
+ n = params_flat.size
350
+ float_params = tree_util.tree_map(lambda x: jnp.asarray(x, dtype=float), params)
351
+ example_jax_lbfgsb_state = dict(
352
+ x=jnp.zeros(n, dtype=float),
353
+ _maxcor=jnp.zeros((), dtype=int),
354
+ _line_search_max_steps=jnp.zeros((), dtype=int),
355
+ _wa=jnp.ones(_wa_size(n=n, maxcor=maxcor), dtype=float),
356
+ _iwa=jnp.ones(n * 3, dtype=jnp.int32), # Fortran int
357
+ _task=jnp.zeros(59, dtype=int),
358
+ _csave=jnp.zeros(59, dtype=int),
359
+ _lsave=jnp.zeros(4, dtype=jnp.int32), # Fortran int
360
+ _isave=jnp.zeros(44, dtype=jnp.int32), # Fortran int
361
+ _dsave=jnp.zeros(29, dtype=float),
362
+ _lower_bound=jnp.zeros(n, dtype=float),
363
+ _upper_bound=jnp.zeros(n, dtype=float),
364
+ _bound_type=jnp.zeros(n, dtype=int),
365
+ )
366
+ return float_params, example_jax_lbfgsb_state
367
+
368
+
334
369
  # ------------------------------------------------------------------------------
335
370
  # Wrapper for scipy's L-BFGS-B implementation.
336
371
  # ------------------------------------------------------------------------------
@@ -395,6 +430,44 @@ class ScipyLbfgsbState:
395
430
  _validate_array_dtype(self._upper_bound, onp.float64)
396
431
  _validate_array_dtype(self._bound_type, int)
397
432
 
433
+ def to_jax(self) -> Dict[str, jnp.ndarray]:
434
+ """Generates a dictionary of jax arrays defining the state."""
435
+ return dict(
436
+ x=jnp.asarray(self.x),
437
+ _maxcor=jnp.asarray(self._maxcor),
438
+ _line_search_max_steps=jnp.asarray(self._line_search_max_steps),
439
+ _wa=jnp.asarray(self._wa),
440
+ _iwa=jnp.asarray(self._iwa),
441
+ _task=_array_from_s60_str(self._task),
442
+ _csave=_array_from_s60_str(self._csave),
443
+ _lsave=jnp.asarray(self._lsave),
444
+ _isave=jnp.asarray(self._isave),
445
+ _dsave=jnp.asarray(self._dsave),
446
+ _lower_bound=jnp.asarray(self._lower_bound),
447
+ _upper_bound=jnp.asarray(self._upper_bound),
448
+ _bound_type=jnp.asarray(self._bound_type),
449
+ )
450
+
451
+ @classmethod
452
+ def from_jax(cls, state_dict: Dict[str, jnp.ndarray]) -> "ScipyLbfgsbState":
453
+ """Converts a dictionary of jax arrays to a `ScipyLbfgsbState`."""
454
+ state_dict = copy.deepcopy(state_dict)
455
+ return ScipyLbfgsbState(
456
+ x=onp.asarray(state_dict["x"], dtype=onp.float64),
457
+ _maxcor=int(state_dict["_maxcor"]),
458
+ _line_search_max_steps=int(state_dict["_line_search_max_steps"]),
459
+ _wa=onp.asarray(state_dict["_wa"], onp.float64),
460
+ _iwa=onp.asarray(state_dict["_iwa"], dtype=FORTRAN_INT),
461
+ _task=_s60_str_from_array(state_dict["_task"]),
462
+ _csave=_s60_str_from_array(state_dict["_csave"]),
463
+ _lsave=onp.asarray(state_dict["_lsave"], dtype=FORTRAN_INT),
464
+ _isave=onp.asarray(state_dict["_isave"], dtype=FORTRAN_INT),
465
+ _dsave=onp.asarray(state_dict["_dsave"], dtype=onp.float64),
466
+ _lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
467
+ _upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
468
+ _bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
469
+ )
470
+
398
471
  @classmethod
399
472
  def init(
400
473
  cls,
@@ -440,12 +513,12 @@ class ScipyLbfgsbState:
440
513
 
441
514
  # See initialization of internal variables in the `lbfgsb._minimize_lbfgsb`
442
515
  # function.
443
- wa_shape = 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
516
+ wa_size = _wa_size(n=n, maxcor=maxcor)
444
517
  state = ScipyLbfgsbState(
445
518
  x=onp.array(x0, onp.float64),
446
519
  _maxcor=maxcor,
447
520
  _line_search_max_steps=line_search_max_steps,
448
- _wa=onp.zeros(wa_shape, onp.float64),
521
+ _wa=onp.zeros(wa_size, onp.float64),
449
522
  _iwa=onp.zeros(3 * n, FORTRAN_INT),
450
523
  _task=task,
451
524
  _csave=onp.zeros(1, "S60"),
@@ -510,6 +583,11 @@ class ScipyLbfgsbState:
510
583
  break
511
584
 
512
585
 
586
+ def _wa_size(n: int, maxcor: int) -> int:
587
+ """Return the size of the `wa` attribute of lbfgsb state."""
588
+ return 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
589
+
590
+
513
591
  def _validate_array_dtype(x: NDArray, dtype: Union[type, str]) -> None:
514
592
  """Validates that `x` is an array with the specified `dtype`."""
515
593
  if not isinstance(x, onp.ndarray):
@@ -534,3 +612,19 @@ def _configure_bounds(
534
612
  onp.asarray(upper_bound_array, onp.float64),
535
613
  onp.asarray(bound_type),
536
614
  )
615
+
616
+
617
+ def _array_from_s60_str(s60_str: NDArray) -> jnp.ndarray:
618
+ """Return a jax array for a numpy s60 string."""
619
+ assert s60_str.shape == (1,)
620
+ chars = [int(o) for o in s60_str[0]]
621
+ chars.extend([32] * (59 - len(chars)))
622
+ return jnp.asarray(chars, dtype=int)
623
+
624
+
625
+ def _s60_str_from_array(array: jnp.ndarray) -> NDArray:
626
+ """Return a numpy s60 string for a jax array."""
627
+ return onp.asarray(
628
+ [b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
629
+ dtype="S60",
630
+ )
invrs_opt/lbfgsb/transform.py CHANGED
@@ -8,7 +8,6 @@ from typing import Tuple, Union
8
8
  import jax
9
9
  import jax.numpy as jnp
10
10
  from jax import tree_util
11
-
12
11
  from totypes import types
13
12
 
14
13
  PadMode = Union[str, Tuple[str, str]]
invrs_opt-0.1.4.dist-info/METADATA → invrs_opt-0.2.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: invrs_opt
3
- Version: 0.1.4
3
+ Version: 0.2.0
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -47,7 +47,7 @@ Requires-Dist: pytest-cov ; extra == 'tests'
47
47
  Requires-Dist: pytest-subtests ; extra == 'tests'
48
48
 
49
49
  # invrs-opt - Optimization algorithms for inverse design
50
- `v0.1.4`
50
+ `v0.2.0`
51
51
 
52
52
  ## Overview
53
53
 
invrs_opt-0.2.0.dist-info/RECORD ADDED
@@ -0,0 +1,11 @@
1
+ invrs_opt/__init__.py,sha256=ZRbgvTG7O9G3ZUHOofIAqDuyNXZoD924wUWForaFZM4,309
2
+ invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
3
+ invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ invrs_opt/lbfgsb/lbfgsb.py,sha256=POK422WPlfVZw63y5i8wkrbh18YzojW2L5nM-eue7PI,23528
6
+ invrs_opt/lbfgsb/transform.py,sha256=gYRBHUfhdkzSeBLfncdroWs3PP_o5T2X0GfDhTc82Rs,5926
7
+ invrs_opt-0.2.0.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
8
+ invrs_opt-0.2.0.dist-info/METADATA,sha256=vmoIcnI1tL3zvvGwpJBcPula-rIZ3HVbs7bzyvfpZls,3272
9
+ invrs_opt-0.2.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
10
+ invrs_opt-0.2.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
11
+ invrs_opt-0.2.0.dist-info/RECORD,,
invrs_opt-0.1.4.dist-info/RECORD REMOVED
@@ -1,11 +0,0 @@
1
- invrs_opt/__init__.py,sha256=ByNDfFD5Jc1XezSskVZH4ERm3pBqhBkcxS5V8qtKlHU,309
2
- invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
3
- invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- invrs_opt/lbfgsb/lbfgsb.py,sha256=tlVU6wxaSCr2FSYfEe3RUaTFT9EFyyk4i71l508vNsI,19393
6
- invrs_opt/lbfgsb/transform.py,sha256=TjFSeWGqlJv8uY4jtgaZ38Z5hplSX5WSQfQzN8rMV5U,5927
7
- invrs_opt-0.1.4.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
8
- invrs_opt-0.1.4.dist-info/METADATA,sha256=iVhKz15sEHlsssEtbCmvkOs3ij5h4XGz5I9gNyUY5D4,3272
9
- invrs_opt-0.1.4.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
10
- invrs_opt-0.1.4.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
11
- invrs_opt-0.1.4.dist-info/RECORD,,