invrs-opt 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invrs_opt/__init__.py CHANGED
@@ -3,8 +3,8 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.1.4"
6
+ __version__ = "v0.2.0"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
- from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
10
9
  from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
10
+ from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
@@ -10,20 +10,19 @@ from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
10
10
  import jax
11
11
  import jax.numpy as jnp
12
12
  import numpy as onp
13
- from jax import flatten_util
14
- from jax import tree_util
13
+ from jax import flatten_util, tree_util
15
14
  from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
16
15
  _lbfgsb as scipy_lbfgsb,
17
16
  )
17
+ from totypes import types
18
18
 
19
- from invrs_opt.lbfgsb import transform
20
19
  from invrs_opt import base
21
- from totypes import types
20
+ from invrs_opt.lbfgsb import transform
22
21
 
23
22
  NDArray = onp.ndarray[Any, Any]
24
23
  PyTree = Any
25
24
  ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
26
- LbfgsbState = Tuple[PyTree, Dict[str, NDArray]]
25
+ LbfgsbState = Tuple[PyTree, Dict[str, jnp.ndarray]]
27
26
 
28
27
 
29
28
  # Task message prefixes for the underlying L-BFGS-B implementation.
@@ -188,18 +187,24 @@ def transformed_lbfgsb(
188
187
 
189
188
  def init_fn(params: PyTree) -> LbfgsbState:
190
189
  """Initializes the optimization state."""
191
- lower_bound = types.extract_lower_bound(params)
192
- upper_bound = types.extract_upper_bound(params)
193
- scipy_lbfgsb_state = ScipyLbfgsbState.init(
194
- x0=_to_numpy(params),
195
- lower_bound=_bound_for_params(lower_bound, params),
196
- upper_bound=_bound_for_params(upper_bound, params),
197
- maxcor=maxcor,
198
- line_search_max_steps=line_search_max_steps,
190
+
191
+ def _init_pure(params: PyTree) -> LbfgsbState:
192
+ lower_bound = types.extract_lower_bound(params)
193
+ upper_bound = types.extract_upper_bound(params)
194
+ scipy_lbfgsb_state = ScipyLbfgsbState.init(
195
+ x0=_to_numpy(params),
196
+ lower_bound=_bound_for_params(lower_bound, params),
197
+ upper_bound=_bound_for_params(upper_bound, params),
198
+ maxcor=maxcor,
199
+ line_search_max_steps=line_search_max_steps,
200
+ )
201
+ latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
202
+ params = transform_fn(latent_params)
203
+ return params, scipy_lbfgsb_state.to_jax()
204
+
205
+ return jax.pure_callback( # type: ignore[no-any-return, attr-defined]
206
+ _init_pure, _example_state(params, maxcor), params
199
207
  )
200
- latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
201
- params = transform_fn(latent_params)
202
- return (params, dataclasses.asdict(scipy_lbfgsb_state))
203
208
 
204
209
  def params_fn(state: LbfgsbState) -> PyTree:
205
210
  """Returns the parameters for the given `state`."""
@@ -214,23 +219,30 @@ def transformed_lbfgsb(
214
219
  state: LbfgsbState,
215
220
  ) -> LbfgsbState:
216
221
  """Updates the state."""
217
- del params
218
- params, lbfgsb_state_dict = state
219
- # Avoid in-place updates.
220
- lbfgsb_state_dict = copy.deepcopy(lbfgsb_state_dict)
221
- scipy_lbfgsb_state = ScipyLbfgsbState(
222
- **lbfgsb_state_dict # type: ignore[arg-type]
223
- )
224
222
 
225
- latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
226
- _, vjp_fn = jax.vjp(transform_fn, latent_params)
227
- (latent_grad,) = vjp_fn(grad)
223
+ def _update_pure(
224
+ grad: PyTree, value: float, params: PyTree, state: LbfgsbState
225
+ ) -> LbfgsbState:
226
+ del params
227
+
228
+ params, jax_lbfgsb_state = state
229
+ scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
228
230
 
229
- assert onp.size(value) == 1
230
- scipy_lbfgsb_state.update(grad=_to_numpy(latent_grad), value=onp.asarray(value))
231
- latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
232
- params = transform_fn(latent_params)
233
- return (params, dataclasses.asdict(scipy_lbfgsb_state))
231
+ latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
232
+ _, vjp_fn = jax.vjp(transform_fn, latent_params)
233
+ (latent_grad,) = vjp_fn(grad)
234
+
235
+ assert onp.size(value) == 1
236
+ scipy_lbfgsb_state.update(
237
+ grad=_to_numpy(latent_grad), value=onp.asarray(value)
238
+ )
239
+ latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
240
+ params = transform_fn(latent_params)
241
+ return params, scipy_lbfgsb_state.to_jax()
242
+
243
+ return jax.pure_callback( # type: ignore[no-any-return, attr-defined]
244
+ _update_pure, state, grad, value, params, state
245
+ )
234
246
 
235
247
  return base.Optimizer(
236
248
  init=init_fn,
@@ -331,6 +343,29 @@ def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
331
343
  return bound_flat
332
344
 
333
345
 
346
+ def _example_state(params: PyTree, maxcor: int) -> PyTree:
347
+ """Return an example state for the given `params` and `maxcor`."""
348
+ params_flat, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
349
+ n = params_flat.size
350
+ float_params = tree_util.tree_map(lambda x: jnp.asarray(x, dtype=float), params)
351
+ example_jax_lbfgsb_state = dict(
352
+ x=jnp.zeros(n, dtype=float),
353
+ _maxcor=jnp.zeros((), dtype=int),
354
+ _line_search_max_steps=jnp.zeros((), dtype=int),
355
+ _wa=jnp.ones(_wa_size(n=n, maxcor=maxcor), dtype=float),
356
+ _iwa=jnp.ones(n * 3, dtype=jnp.int32), # Fortran int
357
+ _task=jnp.zeros(59, dtype=int),
358
+ _csave=jnp.zeros(59, dtype=int),
359
+ _lsave=jnp.zeros(4, dtype=jnp.int32), # Fortran int
360
+ _isave=jnp.zeros(44, dtype=jnp.int32), # Fortran int
361
+ _dsave=jnp.zeros(29, dtype=float),
362
+ _lower_bound=jnp.zeros(n, dtype=float),
363
+ _upper_bound=jnp.zeros(n, dtype=float),
364
+ _bound_type=jnp.zeros(n, dtype=int),
365
+ )
366
+ return float_params, example_jax_lbfgsb_state
367
+
368
+
334
369
  # ------------------------------------------------------------------------------
335
370
  # Wrapper for scipy's L-BFGS-B implementation.
336
371
  # ------------------------------------------------------------------------------
@@ -395,6 +430,44 @@ class ScipyLbfgsbState:
395
430
  _validate_array_dtype(self._upper_bound, onp.float64)
396
431
  _validate_array_dtype(self._bound_type, int)
397
432
 
433
+ def to_jax(self) -> Dict[str, jnp.ndarray]:
434
+ """Generates a dictionary of jax arrays defining the state."""
435
+ return dict(
436
+ x=jnp.asarray(self.x),
437
+ _maxcor=jnp.asarray(self._maxcor),
438
+ _line_search_max_steps=jnp.asarray(self._line_search_max_steps),
439
+ _wa=jnp.asarray(self._wa),
440
+ _iwa=jnp.asarray(self._iwa),
441
+ _task=_array_from_s60_str(self._task),
442
+ _csave=_array_from_s60_str(self._csave),
443
+ _lsave=jnp.asarray(self._lsave),
444
+ _isave=jnp.asarray(self._isave),
445
+ _dsave=jnp.asarray(self._dsave),
446
+ _lower_bound=jnp.asarray(self._lower_bound),
447
+ _upper_bound=jnp.asarray(self._upper_bound),
448
+ _bound_type=jnp.asarray(self._bound_type),
449
+ )
450
+
451
+ @classmethod
452
+ def from_jax(cls, state_dict: Dict[str, jnp.ndarray]) -> "ScipyLbfgsbState":
453
+ """Converts a dictionary of jax arrays to a `ScipyLbfgsbState`."""
454
+ state_dict = copy.deepcopy(state_dict)
455
+ return ScipyLbfgsbState(
456
+ x=onp.asarray(state_dict["x"], dtype=onp.float64),
457
+ _maxcor=int(state_dict["_maxcor"]),
458
+ _line_search_max_steps=int(state_dict["_line_search_max_steps"]),
459
+ _wa=onp.asarray(state_dict["_wa"], onp.float64),
460
+ _iwa=onp.asarray(state_dict["_iwa"], dtype=FORTRAN_INT),
461
+ _task=_s60_str_from_array(state_dict["_task"]),
462
+ _csave=_s60_str_from_array(state_dict["_csave"]),
463
+ _lsave=onp.asarray(state_dict["_lsave"], dtype=FORTRAN_INT),
464
+ _isave=onp.asarray(state_dict["_isave"], dtype=FORTRAN_INT),
465
+ _dsave=onp.asarray(state_dict["_dsave"], dtype=onp.float64),
466
+ _lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
467
+ _upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
468
+ _bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
469
+ )
470
+
398
471
  @classmethod
399
472
  def init(
400
473
  cls,
@@ -440,12 +513,12 @@ class ScipyLbfgsbState:
440
513
 
441
514
  # See initialization of internal variables in the `lbfgsb._minimize_lbfgsb`
442
515
  # function.
443
- wa_shape = 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
516
+ wa_size = _wa_size(n=n, maxcor=maxcor)
444
517
  state = ScipyLbfgsbState(
445
518
  x=onp.array(x0, onp.float64),
446
519
  _maxcor=maxcor,
447
520
  _line_search_max_steps=line_search_max_steps,
448
- _wa=onp.zeros(wa_shape, onp.float64),
521
+ _wa=onp.zeros(wa_size, onp.float64),
449
522
  _iwa=onp.zeros(3 * n, FORTRAN_INT),
450
523
  _task=task,
451
524
  _csave=onp.zeros(1, "S60"),
@@ -510,6 +583,11 @@ class ScipyLbfgsbState:
510
583
  break
511
584
 
512
585
 
586
+ def _wa_size(n: int, maxcor: int) -> int:
587
+ """Return the size of the `wa` attribute of lbfgsb state."""
588
+ return 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
589
+
590
+
513
591
  def _validate_array_dtype(x: NDArray, dtype: Union[type, str]) -> None:
514
592
  """Validates that `x` is an array with the specified `dtype`."""
515
593
  if not isinstance(x, onp.ndarray):
@@ -534,3 +612,19 @@ def _configure_bounds(
534
612
  onp.asarray(upper_bound_array, onp.float64),
535
613
  onp.asarray(bound_type),
536
614
  )
615
+
616
+
617
+ def _array_from_s60_str(s60_str: NDArray) -> jnp.ndarray:
618
+ """Return a jax array for a numpy s60 string."""
619
+ assert s60_str.shape == (1,)
620
+ chars = [int(o) for o in s60_str[0]]
621
+ chars.extend([32] * (59 - len(chars)))
622
+ return jnp.asarray(chars, dtype=int)
623
+
624
+
625
+ def _s60_str_from_array(array: jnp.ndarray) -> NDArray:
626
+ """Return a numpy s60 string for a jax array."""
627
+ return onp.asarray(
628
+ [b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
629
+ dtype="S60",
630
+ )
@@ -8,7 +8,6 @@ from typing import Tuple, Union
8
8
  import jax
9
9
  import jax.numpy as jnp
10
10
  from jax import tree_util
11
-
12
11
  from totypes import types
13
12
 
14
13
  PadMode = Union[str, Tuple[str, str]]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: invrs_opt
3
- Version: 0.1.4
3
+ Version: 0.2.0
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -47,7 +47,7 @@ Requires-Dist: pytest-cov ; extra == 'tests'
47
47
  Requires-Dist: pytest-subtests ; extra == 'tests'
48
48
 
49
49
  # invrs-opt - Optimization algorithms for inverse design
50
- `v0.1.4`
50
+ `v0.2.0`
51
51
 
52
52
  ## Overview
53
53
 
@@ -0,0 +1,11 @@
1
+ invrs_opt/__init__.py,sha256=ZRbgvTG7O9G3ZUHOofIAqDuyNXZoD924wUWForaFZM4,309
2
+ invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
3
+ invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ invrs_opt/lbfgsb/lbfgsb.py,sha256=POK422WPlfVZw63y5i8wkrbh18YzojW2L5nM-eue7PI,23528
6
+ invrs_opt/lbfgsb/transform.py,sha256=gYRBHUfhdkzSeBLfncdroWs3PP_o5T2X0GfDhTc82Rs,5926
7
+ invrs_opt-0.2.0.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
8
+ invrs_opt-0.2.0.dist-info/METADATA,sha256=vmoIcnI1tL3zvvGwpJBcPula-rIZ3HVbs7bzyvfpZls,3272
9
+ invrs_opt-0.2.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
10
+ invrs_opt-0.2.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
11
+ invrs_opt-0.2.0.dist-info/RECORD,,
@@ -1,11 +0,0 @@
1
- invrs_opt/__init__.py,sha256=ByNDfFD5Jc1XezSskVZH4ERm3pBqhBkcxS5V8qtKlHU,309
2
- invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
3
- invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- invrs_opt/lbfgsb/lbfgsb.py,sha256=tlVU6wxaSCr2FSYfEe3RUaTFT9EFyyk4i71l508vNsI,19393
6
- invrs_opt/lbfgsb/transform.py,sha256=TjFSeWGqlJv8uY4jtgaZ38Z5hplSX5WSQfQzN8rMV5U,5927
7
- invrs_opt-0.1.4.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
8
- invrs_opt-0.1.4.dist-info/METADATA,sha256=iVhKz15sEHlsssEtbCmvkOs3ij5h4XGz5I9gNyUY5D4,3272
9
- invrs_opt-0.1.4.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
10
- invrs_opt-0.1.4.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
11
- invrs_opt-0.1.4.dist-info/RECORD,,