invrs-opt 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invrs_opt/__init__.py CHANGED
@@ -3,8 +3,8 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.1.3"
6
+ __version__ = "v0.2.0"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
- from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
10
9
  from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
10
+ from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
invrs_opt/base.py CHANGED
@@ -6,8 +6,6 @@ Copyright (c) 2023 The INVRS-IO authors.
6
6
  import dataclasses
7
7
  from typing import Any, Protocol
8
8
 
9
- from totypes import json_utils
10
-
11
9
  PyTree = Any
12
10
 
13
11
 
@@ -46,23 +44,3 @@ class Optimizer:
46
44
  init: InitFn
47
45
  params: ParamsFn
48
46
  update: UpdateFn
49
-
50
-
51
- # Additional custom types and prefixes used for serializing optimizer state.
52
- CUSTOM_TYPES_AND_PREFIXES = ()
53
-
54
-
55
- def serialize(tree: PyTree) -> str:
56
- """Serializes a pytree into a string."""
57
- return json_utils.json_from_pytree(
58
- tree,
59
- extra_custom_types_and_prefixes=CUSTOM_TYPES_AND_PREFIXES,
60
- )
61
-
62
-
63
- def deserialize(serialized: str) -> PyTree:
64
- """Restores a pytree from a string."""
65
- return json_utils.pytree_from_json(
66
- serialized,
67
- extra_custom_types_and_prefixes=CUSTOM_TYPES_AND_PREFIXES,
68
- )
@@ -10,19 +10,19 @@ from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
10
10
  import jax
11
11
  import jax.numpy as jnp
12
12
  import numpy as onp
13
- from jax import tree_util
13
+ from jax import flatten_util, tree_util
14
14
  from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
15
15
  _lbfgsb as scipy_lbfgsb,
16
16
  )
17
+ from totypes import types
17
18
 
18
- from invrs_opt.lbfgsb import transform
19
19
  from invrs_opt import base
20
- from totypes import types
20
+ from invrs_opt.lbfgsb import transform
21
21
 
22
22
  NDArray = onp.ndarray[Any, Any]
23
23
  PyTree = Any
24
24
  ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
25
- LbfgsbState = Tuple[PyTree, Dict[str, NDArray]]
25
+ LbfgsbState = Tuple[PyTree, Dict[str, jnp.ndarray]]
26
26
 
27
27
 
28
28
  # Task message prefixes for the underlying L-BFGS-B implementation.
@@ -187,18 +187,24 @@ def transformed_lbfgsb(
187
187
 
188
188
  def init_fn(params: PyTree) -> LbfgsbState:
189
189
  """Initializes the optimization state."""
190
- lower_bound = types.extract_lower_bound(params)
191
- upper_bound = types.extract_upper_bound(params)
192
- scipy_lbfgsb_state = ScipyLbfgsbState.init(
193
- x0=_to_numpy(params),
194
- lower_bound=_bound_for_params(lower_bound, params),
195
- upper_bound=_bound_for_params(upper_bound, params),
196
- maxcor=maxcor,
197
- line_search_max_steps=line_search_max_steps,
190
+
191
+ def _init_pure(params: PyTree) -> LbfgsbState:
192
+ lower_bound = types.extract_lower_bound(params)
193
+ upper_bound = types.extract_upper_bound(params)
194
+ scipy_lbfgsb_state = ScipyLbfgsbState.init(
195
+ x0=_to_numpy(params),
196
+ lower_bound=_bound_for_params(lower_bound, params),
197
+ upper_bound=_bound_for_params(upper_bound, params),
198
+ maxcor=maxcor,
199
+ line_search_max_steps=line_search_max_steps,
200
+ )
201
+ latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
202
+ params = transform_fn(latent_params)
203
+ return params, scipy_lbfgsb_state.to_jax()
204
+
205
+ return jax.pure_callback( # type: ignore[no-any-return, attr-defined]
206
+ _init_pure, _example_state(params, maxcor), params
198
207
  )
199
- latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
200
- params = transform_fn(latent_params)
201
- return (params, dataclasses.asdict(scipy_lbfgsb_state))
202
208
 
203
209
  def params_fn(state: LbfgsbState) -> PyTree:
204
210
  """Returns the parameters for the given `state`."""
@@ -213,23 +219,30 @@ def transformed_lbfgsb(
213
219
  state: LbfgsbState,
214
220
  ) -> LbfgsbState:
215
221
  """Updates the state."""
216
- del params
217
- params, lbfgsb_state_dict = state
218
- # Avoid in-place updates.
219
- lbfgsb_state_dict = copy.deepcopy(lbfgsb_state_dict)
220
- scipy_lbfgsb_state = ScipyLbfgsbState(
221
- **lbfgsb_state_dict # type: ignore[arg-type]
222
- )
223
222
 
224
- latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
225
- _, vjp_fn = jax.vjp(transform_fn, latent_params)
226
- (latent_grad,) = vjp_fn(grad)
223
+ def _update_pure(
224
+ grad: PyTree, value: float, params: PyTree, state: LbfgsbState
225
+ ) -> LbfgsbState:
226
+ del params
227
227
 
228
- assert onp.size(value) == 1
229
- scipy_lbfgsb_state.update(grad=_to_numpy(latent_grad), value=onp.asarray(value))
230
- latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
231
- params = transform_fn(latent_params)
232
- return (params, dataclasses.asdict(scipy_lbfgsb_state))
228
+ params, jax_lbfgsb_state = state
229
+ scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
230
+
231
+ latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
232
+ _, vjp_fn = jax.vjp(transform_fn, latent_params)
233
+ (latent_grad,) = vjp_fn(grad)
234
+
235
+ assert onp.size(value) == 1
236
+ scipy_lbfgsb_state.update(
237
+ grad=_to_numpy(latent_grad), value=onp.asarray(value)
238
+ )
239
+ latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
240
+ params = transform_fn(latent_params)
241
+ return params, scipy_lbfgsb_state.to_jax()
242
+
243
+ return jax.pure_callback( # type: ignore[no-any-return, attr-defined]
244
+ _update_pure, state, grad, value, params, state
245
+ )
233
246
 
234
247
  return base.Optimizer(
235
248
  init=init_fn,
@@ -245,31 +258,25 @@ def transformed_lbfgsb(
245
258
 
246
259
  def _to_numpy(params: PyTree) -> NDArray:
247
260
  """Flattens a `params` pytree into a single rank-1 numpy array."""
248
- leaves = tree_util.tree_leaves(params)
249
- x_numpy: NDArray = onp.concatenate([onp.asarray(leaf).flatten() for leaf in leaves])
250
- return x_numpy.astype(onp.float64)
261
+ x, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
262
+ return onp.asarray(x, dtype=onp.float64)
251
263
 
252
264
 
253
265
  def _to_pytree(x_flat: NDArray, params: PyTree) -> PyTree:
254
266
  """Restores a pytree from a flat numpy array using the structure of `params`.
255
267
 
268
+ Note that the returned pytree includes jax array leaves.
269
+
256
270
  Args:
257
271
  x_flat: The rank-1 numpy array to be restored.
258
272
  params: A pytree of parameters whose structure is replicated in the restored
259
273
  pytree.
260
274
 
261
275
  Returns:
262
- The restored pytree.
276
+ The restored pytree, with jax array leaves.
263
277
  """
264
- leaves, treedef = tree_util.tree_flatten(params)
265
- indices_or_sections = onp.cumsum([onp.size(leaf) for leaf in leaves])
266
- x_split_flat = onp.split(x_flat, indices_or_sections)
267
- x_split = [x.reshape(onp.shape(leaf)) for x, leaf in zip(x_split_flat, leaves)]
268
- x_split_jax = [
269
- jnp.asarray(x) if isinstance(leaf, jnp.ndarray) else x
270
- for x, leaf in zip(x_split, leaves)
271
- ]
272
- return tree_util.tree_unflatten(treedef, x_split_jax)
278
+ _, unflatten_fn = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
279
+ return unflatten_fn(jnp.asarray(x_flat, dtype=float))
273
280
 
274
281
 
275
282
  def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
@@ -320,6 +327,8 @@ def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
320
327
 
321
328
  bound_flat = []
322
329
  for b, p in zip(bound_leaves, params_leaves):
330
+ if p is None:
331
+ continue
323
332
  if b is None or onp.isscalar(b) or onp.shape(b) == ():
324
333
  bound_flat += [b] * onp.size(p)
325
334
  else:
@@ -334,6 +343,29 @@ def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
334
343
  return bound_flat
335
344
 
336
345
 
346
+ def _example_state(params: PyTree, maxcor: int) -> PyTree:
347
+ """Return an example state for the given `params` and `maxcor`."""
348
+ params_flat, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
349
+ n = params_flat.size
350
+ float_params = tree_util.tree_map(lambda x: jnp.asarray(x, dtype=float), params)
351
+ example_jax_lbfgsb_state = dict(
352
+ x=jnp.zeros(n, dtype=float),
353
+ _maxcor=jnp.zeros((), dtype=int),
354
+ _line_search_max_steps=jnp.zeros((), dtype=int),
355
+ _wa=jnp.ones(_wa_size(n=n, maxcor=maxcor), dtype=float),
356
+ _iwa=jnp.ones(n * 3, dtype=jnp.int32), # Fortran int
357
+ _task=jnp.zeros(59, dtype=int),
358
+ _csave=jnp.zeros(59, dtype=int),
359
+ _lsave=jnp.zeros(4, dtype=jnp.int32), # Fortran int
360
+ _isave=jnp.zeros(44, dtype=jnp.int32), # Fortran int
361
+ _dsave=jnp.zeros(29, dtype=float),
362
+ _lower_bound=jnp.zeros(n, dtype=float),
363
+ _upper_bound=jnp.zeros(n, dtype=float),
364
+ _bound_type=jnp.zeros(n, dtype=int),
365
+ )
366
+ return float_params, example_jax_lbfgsb_state
367
+
368
+
337
369
  # ------------------------------------------------------------------------------
338
370
  # Wrapper for scipy's L-BFGS-B implementation.
339
371
  # ------------------------------------------------------------------------------
@@ -398,6 +430,44 @@ class ScipyLbfgsbState:
398
430
  _validate_array_dtype(self._upper_bound, onp.float64)
399
431
  _validate_array_dtype(self._bound_type, int)
400
432
 
433
+ def to_jax(self) -> Dict[str, jnp.ndarray]:
434
+ """Generates a dictionary of jax arrays defining the state."""
435
+ return dict(
436
+ x=jnp.asarray(self.x),
437
+ _maxcor=jnp.asarray(self._maxcor),
438
+ _line_search_max_steps=jnp.asarray(self._line_search_max_steps),
439
+ _wa=jnp.asarray(self._wa),
440
+ _iwa=jnp.asarray(self._iwa),
441
+ _task=_array_from_s60_str(self._task),
442
+ _csave=_array_from_s60_str(self._csave),
443
+ _lsave=jnp.asarray(self._lsave),
444
+ _isave=jnp.asarray(self._isave),
445
+ _dsave=jnp.asarray(self._dsave),
446
+ _lower_bound=jnp.asarray(self._lower_bound),
447
+ _upper_bound=jnp.asarray(self._upper_bound),
448
+ _bound_type=jnp.asarray(self._bound_type),
449
+ )
450
+
451
+ @classmethod
452
+ def from_jax(cls, state_dict: Dict[str, jnp.ndarray]) -> "ScipyLbfgsbState":
453
+ """Converts a dictionary of jax arrays to a `ScipyLbfgsbState`."""
454
+ state_dict = copy.deepcopy(state_dict)
455
+ return ScipyLbfgsbState(
456
+ x=onp.asarray(state_dict["x"], dtype=onp.float64),
457
+ _maxcor=int(state_dict["_maxcor"]),
458
+ _line_search_max_steps=int(state_dict["_line_search_max_steps"]),
459
+ _wa=onp.asarray(state_dict["_wa"], onp.float64),
460
+ _iwa=onp.asarray(state_dict["_iwa"], dtype=FORTRAN_INT),
461
+ _task=_s60_str_from_array(state_dict["_task"]),
462
+ _csave=_s60_str_from_array(state_dict["_csave"]),
463
+ _lsave=onp.asarray(state_dict["_lsave"], dtype=FORTRAN_INT),
464
+ _isave=onp.asarray(state_dict["_isave"], dtype=FORTRAN_INT),
465
+ _dsave=onp.asarray(state_dict["_dsave"], dtype=onp.float64),
466
+ _lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
467
+ _upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
468
+ _bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
469
+ )
470
+
401
471
  @classmethod
402
472
  def init(
403
473
  cls,
@@ -443,12 +513,12 @@ class ScipyLbfgsbState:
443
513
 
444
514
  # See initialization of internal variables in the `lbfgsb._minimize_lbfgsb`
445
515
  # function.
446
- wa_shape = 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
516
+ wa_size = _wa_size(n=n, maxcor=maxcor)
447
517
  state = ScipyLbfgsbState(
448
518
  x=onp.array(x0, onp.float64),
449
519
  _maxcor=maxcor,
450
520
  _line_search_max_steps=line_search_max_steps,
451
- _wa=onp.zeros(wa_shape, onp.float64),
521
+ _wa=onp.zeros(wa_size, onp.float64),
452
522
  _iwa=onp.zeros(3 * n, FORTRAN_INT),
453
523
  _task=task,
454
524
  _csave=onp.zeros(1, "S60"),
@@ -513,6 +583,11 @@ class ScipyLbfgsbState:
513
583
  break
514
584
 
515
585
 
586
+ def _wa_size(n: int, maxcor: int) -> int:
587
+ """Return the size of the `wa` attribute of lbfgsb state."""
588
+ return 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
589
+
590
+
516
591
  def _validate_array_dtype(x: NDArray, dtype: Union[type, str]) -> None:
517
592
  """Validates that `x` is an array with the specified `dtype`."""
518
593
  if not isinstance(x, onp.ndarray):
@@ -537,3 +612,19 @@ def _configure_bounds(
537
612
  onp.asarray(upper_bound_array, onp.float64),
538
613
  onp.asarray(bound_type),
539
614
  )
615
+
616
+
617
+ def _array_from_s60_str(s60_str: NDArray) -> jnp.ndarray:
618
+ """Return a jax array for a numpy s60 string."""
619
+ assert s60_str.shape == (1,)
620
+ chars = [int(o) for o in s60_str[0]]
621
+ chars.extend([32] * (59 - len(chars)))
622
+ return jnp.asarray(chars, dtype=int)
623
+
624
+
625
+ def _s60_str_from_array(array: jnp.ndarray) -> NDArray:
626
+ """Return a numpy s60 string for a jax array."""
627
+ return onp.asarray(
628
+ [b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
629
+ dtype="S60",
630
+ )
@@ -8,7 +8,6 @@ from typing import Tuple, Union
8
8
  import jax
9
9
  import jax.numpy as jnp
10
10
  from jax import tree_util
11
-
12
11
  from totypes import types
13
12
 
14
13
  PadMode = Union[str, Tuple[str, str]]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
- Name: invrs-opt
3
- Version: 0.1.3
2
+ Name: invrs_opt
3
+ Version: 0.2.0
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -47,7 +47,7 @@ Requires-Dist: pytest-cov ; extra == 'tests'
47
47
  Requires-Dist: pytest-subtests ; extra == 'tests'
48
48
 
49
49
  # invrs-opt - Optimization algorithms for inverse design
50
- `v0.1.3`
50
+ `v0.2.0`
51
51
 
52
52
  ## Overview
53
53
 
@@ -0,0 +1,11 @@
1
+ invrs_opt/__init__.py,sha256=ZRbgvTG7O9G3ZUHOofIAqDuyNXZoD924wUWForaFZM4,309
2
+ invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
3
+ invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ invrs_opt/lbfgsb/lbfgsb.py,sha256=POK422WPlfVZw63y5i8wkrbh18YzojW2L5nM-eue7PI,23528
6
+ invrs_opt/lbfgsb/transform.py,sha256=gYRBHUfhdkzSeBLfncdroWs3PP_o5T2X0GfDhTc82Rs,5926
7
+ invrs_opt-0.2.0.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
8
+ invrs_opt-0.2.0.dist-info/METADATA,sha256=vmoIcnI1tL3zvvGwpJBcPula-rIZ3HVbs7bzyvfpZls,3272
9
+ invrs_opt-0.2.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
10
+ invrs_opt-0.2.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
11
+ invrs_opt-0.2.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.41.3)
2
+ Generator: bdist_wheel (0.42.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,11 +0,0 @@
1
- invrs_opt/__init__.py,sha256=IpAs-pDwW_mo2FnbNkDpsR-XuxZW6h5TwvlHkc8kCuE,309
2
- invrs_opt/base.py,sha256=dm5nzlO4KXFfuIfyHcTn9V1VCU6hAy1w3IA2vfzaQD8,1481
3
- invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- invrs_opt/lbfgsb/lbfgsb.py,sha256=nZqotWv9oGbY56UKSI3zcetexpwMoaDsvqwJSgXjvwc,19597
6
- invrs_opt/lbfgsb/transform.py,sha256=TjFSeWGqlJv8uY4jtgaZ38Z5hplSX5WSQfQzN8rMV5U,5927
7
- invrs_opt-0.1.3.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
8
- invrs_opt-0.1.3.dist-info/METADATA,sha256=oXMNM5J_RzBpmnNB3IBvFhscJg6nQbyLWNMz-OljYHQ,3272
9
- invrs_opt-0.1.3.dist-info/WHEEL,sha256=Xo9-1PvkuimrydujYJAjF7pCkriuXBpUPEjma1nZyJ0,92
10
- invrs_opt-0.1.3.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
11
- invrs_opt-0.1.3.dist-info/RECORD,,