invrs-opt 0.1.3__tar.gz → 0.2.0__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: invrs_opt
3
- Version: 0.1.3
3
+ Version: 0.2.0
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -47,7 +47,7 @@ Requires-Dist: mypy; extra == "dev"
47
47
  Requires-Dist: pre-commit; extra == "dev"
48
48
 
49
49
  # invrs-opt - Optimization algorithms for inverse design
50
- `v0.1.3`
50
+ `v0.2.0`
51
51
 
52
52
  ## Overview
53
53
 
@@ -1,5 +1,5 @@
1
1
  # invrs-opt - Optimization algorithms for inverse design
2
- `v0.1.3`
2
+ `v0.2.0`
3
3
 
4
4
  ## Overview
5
5
 
@@ -1,7 +1,7 @@
1
1
  [project]
2
2
 
3
3
  name = "invrs_opt"
4
- version = "v0.1.3"
4
+ version = "v0.2.0"
5
5
  description = "Algorithms for inverse design"
6
6
  keywords = ["topology", "optimization", "jax", "inverse design"]
7
7
  readme = "README.md"
@@ -3,8 +3,8 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.1.3"
6
+ __version__ = "v0.2.0"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
- from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
10
9
  from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
10
+ from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
@@ -6,8 +6,6 @@ Copyright (c) 2023 The INVRS-IO authors.
6
6
  import dataclasses
7
7
  from typing import Any, Protocol
8
8
 
9
- from totypes import json_utils
10
-
11
9
  PyTree = Any
12
10
 
13
11
 
@@ -46,23 +44,3 @@ class Optimizer:
46
44
  init: InitFn
47
45
  params: ParamsFn
48
46
  update: UpdateFn
49
-
50
-
51
- # Additional custom types and prefixes used for serializing optimizer state.
52
- CUSTOM_TYPES_AND_PREFIXES = ()
53
-
54
-
55
- def serialize(tree: PyTree) -> str:
56
- """Serializes a pytree into a string."""
57
- return json_utils.json_from_pytree(
58
- tree,
59
- extra_custom_types_and_prefixes=CUSTOM_TYPES_AND_PREFIXES,
60
- )
61
-
62
-
63
- def deserialize(serialized: str) -> PyTree:
64
- """Restores a pytree from a string."""
65
- return json_utils.pytree_from_json(
66
- serialized,
67
- extra_custom_types_and_prefixes=CUSTOM_TYPES_AND_PREFIXES,
68
- )
@@ -10,19 +10,19 @@ from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
10
10
  import jax
11
11
  import jax.numpy as jnp
12
12
  import numpy as onp
13
- from jax import tree_util
13
+ from jax import flatten_util, tree_util
14
14
  from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
15
15
  _lbfgsb as scipy_lbfgsb,
16
16
  )
17
+ from totypes import types
17
18
 
18
- from invrs_opt.lbfgsb import transform
19
19
  from invrs_opt import base
20
- from totypes import types
20
+ from invrs_opt.lbfgsb import transform
21
21
 
22
22
  NDArray = onp.ndarray[Any, Any]
23
23
  PyTree = Any
24
24
  ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
25
- LbfgsbState = Tuple[PyTree, Dict[str, NDArray]]
25
+ LbfgsbState = Tuple[PyTree, Dict[str, jnp.ndarray]]
26
26
 
27
27
 
28
28
  # Task message prefixes for the underlying L-BFGS-B implementation.
@@ -187,18 +187,24 @@ def transformed_lbfgsb(
187
187
 
188
188
  def init_fn(params: PyTree) -> LbfgsbState:
189
189
  """Initializes the optimization state."""
190
- lower_bound = types.extract_lower_bound(params)
191
- upper_bound = types.extract_upper_bound(params)
192
- scipy_lbfgsb_state = ScipyLbfgsbState.init(
193
- x0=_to_numpy(params),
194
- lower_bound=_bound_for_params(lower_bound, params),
195
- upper_bound=_bound_for_params(upper_bound, params),
196
- maxcor=maxcor,
197
- line_search_max_steps=line_search_max_steps,
190
+
191
+ def _init_pure(params: PyTree) -> LbfgsbState:
192
+ lower_bound = types.extract_lower_bound(params)
193
+ upper_bound = types.extract_upper_bound(params)
194
+ scipy_lbfgsb_state = ScipyLbfgsbState.init(
195
+ x0=_to_numpy(params),
196
+ lower_bound=_bound_for_params(lower_bound, params),
197
+ upper_bound=_bound_for_params(upper_bound, params),
198
+ maxcor=maxcor,
199
+ line_search_max_steps=line_search_max_steps,
200
+ )
201
+ latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
202
+ params = transform_fn(latent_params)
203
+ return params, scipy_lbfgsb_state.to_jax()
204
+
205
+ return jax.pure_callback( # type: ignore[no-any-return, attr-defined]
206
+ _init_pure, _example_state(params, maxcor), params
198
207
  )
199
- latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
200
- params = transform_fn(latent_params)
201
- return (params, dataclasses.asdict(scipy_lbfgsb_state))
202
208
 
203
209
  def params_fn(state: LbfgsbState) -> PyTree:
204
210
  """Returns the parameters for the given `state`."""
@@ -213,23 +219,30 @@ def transformed_lbfgsb(
213
219
  state: LbfgsbState,
214
220
  ) -> LbfgsbState:
215
221
  """Updates the state."""
216
- del params
217
- params, lbfgsb_state_dict = state
218
- # Avoid in-place updates.
219
- lbfgsb_state_dict = copy.deepcopy(lbfgsb_state_dict)
220
- scipy_lbfgsb_state = ScipyLbfgsbState(
221
- **lbfgsb_state_dict # type: ignore[arg-type]
222
- )
223
222
 
224
- latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
225
- _, vjp_fn = jax.vjp(transform_fn, latent_params)
226
- (latent_grad,) = vjp_fn(grad)
223
+ def _update_pure(
224
+ grad: PyTree, value: float, params: PyTree, state: LbfgsbState
225
+ ) -> LbfgsbState:
226
+ del params
227
227
 
228
- assert onp.size(value) == 1
229
- scipy_lbfgsb_state.update(grad=_to_numpy(latent_grad), value=onp.asarray(value))
230
- latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
231
- params = transform_fn(latent_params)
232
- return (params, dataclasses.asdict(scipy_lbfgsb_state))
228
+ params, jax_lbfgsb_state = state
229
+ scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
230
+
231
+ latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
232
+ _, vjp_fn = jax.vjp(transform_fn, latent_params)
233
+ (latent_grad,) = vjp_fn(grad)
234
+
235
+ assert onp.size(value) == 1
236
+ scipy_lbfgsb_state.update(
237
+ grad=_to_numpy(latent_grad), value=onp.asarray(value)
238
+ )
239
+ latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
240
+ params = transform_fn(latent_params)
241
+ return params, scipy_lbfgsb_state.to_jax()
242
+
243
+ return jax.pure_callback( # type: ignore[no-any-return, attr-defined]
244
+ _update_pure, state, grad, value, params, state
245
+ )
233
246
 
234
247
  return base.Optimizer(
235
248
  init=init_fn,
@@ -245,31 +258,25 @@ def transformed_lbfgsb(
245
258
 
246
259
  def _to_numpy(params: PyTree) -> NDArray:
247
260
  """Flattens a `params` pytree into a single rank-1 numpy array."""
248
- leaves = tree_util.tree_leaves(params)
249
- x_numpy: NDArray = onp.concatenate([onp.asarray(leaf).flatten() for leaf in leaves])
250
- return x_numpy.astype(onp.float64)
261
+ x, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
262
+ return onp.asarray(x, dtype=onp.float64)
251
263
 
252
264
 
253
265
  def _to_pytree(x_flat: NDArray, params: PyTree) -> PyTree:
254
266
  """Restores a pytree from a flat numpy array using the structure of `params`.
255
267
 
268
+ Note that the returned pytree includes jax array leaves.
269
+
256
270
  Args:
257
271
  x_flat: The rank-1 numpy array to be restored.
258
272
  params: A pytree of parameters whose structure is replicated in the restored
259
273
  pytree.
260
274
 
261
275
  Returns:
262
- The restored pytree.
276
+ The restored pytree, with jax array leaves.
263
277
  """
264
- leaves, treedef = tree_util.tree_flatten(params)
265
- indices_or_sections = onp.cumsum([onp.size(leaf) for leaf in leaves])
266
- x_split_flat = onp.split(x_flat, indices_or_sections)
267
- x_split = [x.reshape(onp.shape(leaf)) for x, leaf in zip(x_split_flat, leaves)]
268
- x_split_jax = [
269
- jnp.asarray(x) if isinstance(leaf, jnp.ndarray) else x
270
- for x, leaf in zip(x_split, leaves)
271
- ]
272
- return tree_util.tree_unflatten(treedef, x_split_jax)
278
+ _, unflatten_fn = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
279
+ return unflatten_fn(jnp.asarray(x_flat, dtype=float))
273
280
 
274
281
 
275
282
  def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
@@ -320,6 +327,8 @@ def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
320
327
 
321
328
  bound_flat = []
322
329
  for b, p in zip(bound_leaves, params_leaves):
330
+ if p is None:
331
+ continue
323
332
  if b is None or onp.isscalar(b) or onp.shape(b) == ():
324
333
  bound_flat += [b] * onp.size(p)
325
334
  else:
@@ -334,6 +343,29 @@ def _bound_for_params(bound: PyTree, params: PyTree) -> ElementwiseBound:
334
343
  return bound_flat
335
344
 
336
345
 
346
def _example_state(params: PyTree, maxcor: int) -> PyTree:
    """Build a shape/dtype template of the optimizer state for `params`.

    The returned `(params, state_dict)` pair is passed to `jax.pure_callback`
    as the expected result structure, so only the shapes and dtypes of its
    leaves are significant; the placeholder values themselves are arbitrary.

    Args:
        params: The parameters pytree being optimized.
        maxcor: Passed to `_wa_size` to size the `_wa` workspace entry.

    Returns:
        A tuple of `params` with every leaf converted to a float jax array, and
        a dict of placeholder jax arrays keyed like the L-BFGS-B state dict.
    """
    flat, _ = flatten_util.ravel_pytree(params)  # type: ignore[no-untyped-call]
    num_elements = flat.size
    template_params = tree_util.tree_map(
        lambda leaf: jnp.asarray(leaf, dtype=float), params
    )
    placeholder_state = {
        "x": jnp.zeros(num_elements, dtype=float),
        "_maxcor": jnp.zeros((), dtype=int),
        "_line_search_max_steps": jnp.zeros((), dtype=int),
        "_wa": jnp.ones(_wa_size(n=num_elements, maxcor=maxcor), dtype=float),
        # `_iwa`, `_lsave` and `_isave` are Fortran integer buffers -> int32.
        "_iwa": jnp.ones(3 * num_elements, dtype=jnp.int32),
        # Byte-code encodings of the task/csave strings; 59 matches the padded
        # length produced by `_array_from_s60_str`.
        "_task": jnp.zeros(59, dtype=int),
        "_csave": jnp.zeros(59, dtype=int),
        "_lsave": jnp.zeros(4, dtype=jnp.int32),
        "_isave": jnp.zeros(44, dtype=jnp.int32),
        "_dsave": jnp.zeros(29, dtype=float),
        "_lower_bound": jnp.zeros(num_elements, dtype=float),
        "_upper_bound": jnp.zeros(num_elements, dtype=float),
        "_bound_type": jnp.zeros(num_elements, dtype=int),
    }
    return template_params, placeholder_state
367
+
368
+
337
369
  # ------------------------------------------------------------------------------
338
370
  # Wrapper for scipy's L-BFGS-B implementation.
339
371
  # ------------------------------------------------------------------------------
@@ -398,6 +430,44 @@ class ScipyLbfgsbState:
398
430
  _validate_array_dtype(self._upper_bound, onp.float64)
399
431
  _validate_array_dtype(self._bound_type, int)
400
432
 
433
def to_jax(self) -> Dict[str, jnp.ndarray]:
    """Export the state as a dict of jax arrays, one entry per state field.

    Numeric fields are converted with `jnp.asarray`. The character buffers
    `_task` and `_csave` are encoded as integer arrays of byte codes via
    `_array_from_s60_str`. `from_jax` performs the reverse conversion.
    """
    to_array = jnp.asarray
    return {
        "x": to_array(self.x),
        "_maxcor": to_array(self._maxcor),
        "_line_search_max_steps": to_array(self._line_search_max_steps),
        "_wa": to_array(self._wa),
        "_iwa": to_array(self._iwa),
        "_task": _array_from_s60_str(self._task),
        "_csave": _array_from_s60_str(self._csave),
        "_lsave": to_array(self._lsave),
        "_isave": to_array(self._isave),
        "_dsave": to_array(self._dsave),
        "_lower_bound": to_array(self._lower_bound),
        "_upper_bound": to_array(self._upper_bound),
        "_bound_type": to_array(self._bound_type),
    }
450
+
451
@classmethod
def from_jax(cls, state_dict: Dict[str, jnp.ndarray]) -> "ScipyLbfgsbState":
    """Rebuild a state object from the dict produced by `to_jax`.

    `ScipyLbfgsbState.update` mutates the state in place. Every array field is
    therefore created as a fresh, writable numpy copy that shares no memory
    with `state_dict`. The entries of `state_dict` may be jax arrays, or numpy
    arrays that (presumably) can be read-only views of jax buffers. The input
    dict itself is never modified.

    Args:
        state_dict: Mapping with the keys produced by `to_jax`.

    Returns:
        A new instance of `cls`, so subclasses round-trip to their own type.
    """

    def _owned_copy(key: str, dtype: Any) -> NDArray:
        # `copy=True` guarantees an independent, writable array even when the
        # source is already a numpy array of the requested dtype.
        return onp.array(state_dict[key], dtype=dtype, copy=True)

    return cls(
        x=_owned_copy("x", onp.float64),
        _maxcor=int(state_dict["_maxcor"]),
        _line_search_max_steps=int(state_dict["_line_search_max_steps"]),
        _wa=_owned_copy("_wa", onp.float64),
        _iwa=_owned_copy("_iwa", FORTRAN_INT),
        _task=_s60_str_from_array(state_dict["_task"]),
        _csave=_s60_str_from_array(state_dict["_csave"]),
        _lsave=_owned_copy("_lsave", FORTRAN_INT),
        _isave=_owned_copy("_isave", FORTRAN_INT),
        _dsave=_owned_copy("_dsave", onp.float64),
        _lower_bound=_owned_copy("_lower_bound", onp.float64),
        _upper_bound=_owned_copy("_upper_bound", onp.float64),
        _bound_type=_owned_copy("_bound_type", int),
    )
470
+
401
471
  @classmethod
402
472
  def init(
403
473
  cls,
@@ -443,12 +513,12 @@ class ScipyLbfgsbState:
443
513
 
444
514
  # See initialization of internal variables in the `lbfgsb._minimize_lbfgsb`
445
515
  # function.
446
- wa_shape = 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
516
+ wa_size = _wa_size(n=n, maxcor=maxcor)
447
517
  state = ScipyLbfgsbState(
448
518
  x=onp.array(x0, onp.float64),
449
519
  _maxcor=maxcor,
450
520
  _line_search_max_steps=line_search_max_steps,
451
- _wa=onp.zeros(wa_shape, onp.float64),
521
+ _wa=onp.zeros(wa_size, onp.float64),
452
522
  _iwa=onp.zeros(3 * n, FORTRAN_INT),
453
523
  _task=task,
454
524
  _csave=onp.zeros(1, "S60"),
@@ -513,6 +583,11 @@ class ScipyLbfgsbState:
513
583
  break
514
584
 
515
585
 
586
+ def _wa_size(n: int, maxcor: int) -> int:
587
+ """Return the size of the `wa` attribute of lbfgsb state."""
588
+ return 2 * maxcor * n + 5 * n + 11 * maxcor**2 + 8 * maxcor
589
+
590
+
516
591
  def _validate_array_dtype(x: NDArray, dtype: Union[type, str]) -> None:
517
592
  """Validates that `x` is an array with the specified `dtype`."""
518
593
  if not isinstance(x, onp.ndarray):
@@ -537,3 +612,19 @@ def _configure_bounds(
537
612
  onp.asarray(upper_bound_array, onp.float64),
538
613
  onp.asarray(bound_type),
539
614
  )
615
+
616
+
617
+ def _array_from_s60_str(s60_str: NDArray) -> jnp.ndarray:
618
+ """Return a jax array for a numpy s60 string."""
619
+ assert s60_str.shape == (1,)
620
+ chars = [int(o) for o in s60_str[0]]
621
+ chars.extend([32] * (59 - len(chars)))
622
+ return jnp.asarray(chars, dtype=int)
623
+
624
+
625
+ def _s60_str_from_array(array: jnp.ndarray) -> NDArray:
626
+ """Return a numpy s60 string for a jax array."""
627
+ return onp.asarray(
628
+ [b"".join(int(i).to_bytes(length=1, byteorder="big") for i in array)],
629
+ dtype="S60",
630
+ )
@@ -8,7 +8,6 @@ from typing import Tuple, Union
8
8
  import jax
9
9
  import jax.numpy as jnp
10
10
  from jax import tree_util
11
-
12
11
  from totypes import types
13
12
 
14
13
  PadMode = Union[str, Tuple[str, str]]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
- Name: invrs-opt
3
- Version: 0.1.3
2
+ Name: invrs_opt
3
+ Version: 0.2.0
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -47,7 +47,7 @@ Requires-Dist: mypy; extra == "dev"
47
47
  Requires-Dist: pre-commit; extra == "dev"
48
48
 
49
49
  # invrs-opt - Optimization algorithms for inverse design
50
- `v0.1.3`
50
+ `v0.2.0`
51
51
 
52
52
  ## Overview
53
53
 
@@ -10,9 +10,9 @@ import jax
10
10
  import jax.numpy as jnp
11
11
  import numpy as onp
12
12
  import parameterized
13
+ from totypes import json_utils, symmetry, types
13
14
 
14
15
  import invrs_opt
15
- from totypes import symmetry, types
16
16
 
17
17
  jax.config.update("jax_enable_x64", True)
18
18
 
@@ -153,14 +153,22 @@ def _lists_to_tuple(pytree, max_depth=10):
153
153
  return pytree
154
154
 
155
155
 
156
def serialize(pytree) -> str:
    """Serialize `pytree` to a JSON string with `totypes.json_utils`."""
    json_str: str = json_utils.json_from_pytree(pytree=pytree)
    return json_str
158
+
159
+
160
def deserialize(serialized):
    """Restore a pytree from a JSON string with `totypes.json_utils`."""
    restored = json_utils.pytree_from_json(serialized=serialized)
    return restored
162
+
163
+
156
164
  class BasicOptimizerTest(unittest.TestCase):
157
165
  @parameterized.parameterized.expand(itertools.product(PARAMS, OPTIMIZERS))
158
166
  def test_state_is_serializable(self, params, opt):
159
167
  state = opt.init(params)
160
168
  leaves, treedef = jax.tree_util.tree_flatten(state)
161
169
 
162
- serialized_state = invrs_opt.base.serialize(state)
163
- restored_state = invrs_opt.base.deserialize(serialized_state)
170
+ serialized_state = serialize(state)
171
+ restored_state = deserialize(serialized_state)
164
172
  # Serialization/deserialization unavoidably converts tuples to lists.
165
173
  # Convert back to tuples to facilitate comparison.
166
174
  restored_state = _lists_to_tuple(restored_state)
@@ -211,7 +219,7 @@ class BasicOptimizerTest(unittest.TestCase):
211
219
  expected_grad_list.append(grad)
212
220
 
213
221
  def serdes(x):
214
- return invrs_opt.base.deserialize(invrs_opt.base.serialize(x))
222
+ return deserialize(serialize(x))
215
223
 
216
224
  # Optimize with serialization.
217
225
  params_list = []
File without changes
File without changes