invrs-opt 0.5.1__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invrs_opt/__init__.py CHANGED
@@ -3,8 +3,12 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.5.1"
6
+ __version__ = "v0.6.0"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
9
  from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
10
10
  from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
11
+ from invrs_opt.wrapped_optax.wrapped_optax import (
12
+ density_wrapped_optax as density_wrapped_optax,
13
+ )
14
+ from invrs_opt.wrapped_optax.wrapped_optax import wrapped_optax as wrapped_optax
invrs_opt/base.py CHANGED
@@ -6,6 +6,9 @@ Copyright (c) 2023 The INVRS-IO authors.
6
6
  import dataclasses
7
7
  from typing import Any, Protocol
8
8
 
9
+ import optax # type: ignore[import-untyped]
10
+ from totypes import json_utils
11
+
9
12
  PyTree = Any
10
13
 
11
14
 
@@ -44,3 +47,8 @@ class Optimizer:
44
47
  init: InitFn
45
48
  params: ParamsFn
46
49
  update: UpdateFn
50
+
51
+
52
# TODO: consider programmatically registering all optax states here.
# NOTE(review): registering these optax state types with `json_utils`
# presumably lets optimizer states containing them round-trip through totypes
# JSON serialization — confirm. Only `EmptyState` and `ScaleByAdamState` are
# registered, so states produced by other optax transformations may not
# serialize.
json_utils.register_custom_type(optax.EmptyState)
json_utils.register_custom_type(optax.ScaleByAdamState)
@@ -4,16 +4,15 @@ Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
6
  import json
7
- import requests
8
7
  import time
9
8
  from typing import Any, Dict, Optional
10
9
 
10
+ import requests
11
11
  from totypes import json_utils
12
12
 
13
13
  from invrs_opt import base
14
14
  from invrs_opt.experimental import labels
15
15
 
16
-
17
16
  PyTree = Any
18
17
  StateToken = str
19
18
 
@@ -16,8 +16,7 @@ from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
16
16
  )
17
17
  from totypes import types
18
18
 
19
- from invrs_opt import base
20
- from invrs_opt.lbfgsb import transform
19
+ from invrs_opt import base, transform
21
20
 
22
21
  NDArray = onp.ndarray[Any, Any]
23
22
  PyTree = Any
@@ -258,7 +257,7 @@ def transformed_lbfgsb(
258
257
  (
259
258
  latent_params,
260
259
  jax_lbfgsb_state,
261
- ) = jax.pure_callback( # type: ignore[attr-defined]
260
+ ) = jax.pure_callback(
262
261
  _init_pure,
263
262
  _example_state(params, maxcor),
264
263
  initialize_latent_fn(params),
@@ -304,7 +303,7 @@ def transformed_lbfgsb(
304
303
  (
305
304
  flat_latent_params,
306
305
  jax_lbfgsb_state,
307
- ) = jax.pure_callback( # type: ignore[attr-defined]
306
+ ) = jax.pure_callback(
308
307
  _update_pure,
309
308
  (flat_latent_grad, jax_lbfgsb_state),
310
309
  flat_latent_grad,
@@ -542,19 +541,19 @@ class ScipyLbfgsbState:
542
541
  """Converts a dictionary of jax arrays to a `ScipyLbfgsbState`."""
543
542
  state_dict = copy.deepcopy(state_dict)
544
543
  return ScipyLbfgsbState(
545
- x=onp.asarray(state_dict["x"], dtype=onp.float64),
544
+ x=onp.array(state_dict["x"], dtype=onp.float64),
546
545
  converged=onp.asarray(state_dict["converged"], dtype=bool),
547
546
  _maxcor=int(state_dict["_maxcor"]),
548
547
  _line_search_max_steps=int(state_dict["_line_search_max_steps"]),
549
548
  _ftol=onp.asarray(state_dict["_ftol"], dtype=onp.float64),
550
549
  _gtol=onp.asarray(state_dict["_gtol"], dtype=onp.float64),
551
- _wa=onp.asarray(state_dict["_wa"], onp.float64),
552
- _iwa=onp.asarray(state_dict["_iwa"], dtype=FORTRAN_INT),
550
+ _wa=onp.array(state_dict["_wa"], onp.float64),
551
+ _iwa=onp.array(state_dict["_iwa"], dtype=FORTRAN_INT),
553
552
  _task=_s60_str_from_array(state_dict["_task"]),
554
553
  _csave=_s60_str_from_array(state_dict["_csave"]),
555
- _lsave=onp.asarray(state_dict["_lsave"], dtype=FORTRAN_INT),
556
- _isave=onp.asarray(state_dict["_isave"], dtype=FORTRAN_INT),
557
- _dsave=onp.asarray(state_dict["_dsave"], dtype=onp.float64),
554
+ _lsave=onp.array(state_dict["_lsave"], dtype=FORTRAN_INT),
555
+ _isave=onp.array(state_dict["_isave"], dtype=FORTRAN_INT),
556
+ _dsave=onp.array(state_dict["_dsave"], dtype=onp.float64),
558
557
  _lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
559
558
  _upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
560
559
  _bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
File without changes
@@ -0,0 +1,150 @@
1
+ import dataclasses
2
+ from typing import Any, Callable, Tuple
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import optax # type: ignore[import-untyped]
7
+ from jax import tree_util
8
+ from totypes import types
9
+
10
+ from invrs_opt import base, transform
11
+
12
PyTree = Any
# State of a wrapped optax optimizer: the tuple `(params, latent_params,
# opt_state)` built by `transformed_wrapped_optax`, where `opt_state` is the
# state of the wrapped optax `GradientTransformation`.
WrappedOptaxState = Tuple[PyTree, PyTree, PyTree]
14
+
15
+
16
def wrapped_optax(opt: optax.GradientTransformation) -> base.Optimizer:
    """Wrap an optax gradient transformation as a `base.Optimizer`.

    No reparameterization is performed: both the transform applied to the
    latent parameters and the function computing the initial latent
    parameters are the identity.

    Args:
        opt: The optax `GradientTransformation` to wrap.

    Returns:
        The resulting `base.Optimizer`.
    """

    def _identity(tree: PyTree) -> PyTree:
        return tree

    return transformed_wrapped_optax(
        opt=opt,
        transform_fn=_identity,
        initialize_latent_fn=_identity,
    )
23
+
24
+
25
def density_wrapped_optax(
    opt: optax.GradientTransformation,
    beta: float,
) -> base.Optimizer:
    """Return a wrapped optax optimizer with transforms for density arrays.

    Leaves of the parameter pytree that are `types.Density2DArray` are
    reparameterized: the optax optimizer updates a latent density, and the
    returned density is obtained by symmetrizing the latent density, applying
    `transform.density_gaussian_filter_and_tanh` with steepness `beta`,
    rescaling so the full `[lower_bound, upper_bound]` range is reachable, and
    finally applying fixed pixels. All other leaves pass through unchanged.

    Args:
        opt: The optax `GradientTransformation` to be wrapped.
        beta: Steepness of the `tanh` projection. Must be nonzero, since the
            transform divides by `tanh(beta)` and the latent initialization
            divides by `beta`.

    Returns:
        The `base.Optimizer`.

    Raises:
        ValueError: If `beta` is zero.
    """
    if beta == 0:
        # Fail early instead of silently producing NaN/inf parameters.
        raise ValueError(f"`beta` must be nonzero, but got {beta}.")

    def transform_fn(tree: PyTree) -> PyTree:
        # Maps latent parameters to the parameters returned by the optimizer.
        return tree_util.tree_map(
            lambda x: transform_density(x) if _is_density(x) else x,
            tree,
            is_leaf=_is_density,
        )

    def initialize_latent_fn(tree: PyTree) -> PyTree:
        # Computes initial latent parameters from the initial parameters.
        return tree_util.tree_map(
            lambda x: initialize_latent_density(x) if _is_density(x) else x,
            tree,
            is_leaf=_is_density,
        )

    def transform_density(density: types.Density2DArray) -> types.Density2DArray:
        transformed = types.symmetrize_density(density)
        transformed = transform.density_gaussian_filter_and_tanh(transformed, beta=beta)
        # Scale to ensure that the full valid range of the density array is
        # reachable; dividing by `tanh(beta)` undoes the scaling applied in
        # `initialize_latent_density`.
        mid_value = (density.lower_bound + density.upper_bound) / 2
        transformed = tree_util.tree_map(
            lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed
        )
        return transform.apply_fixed_pixels(transformed)

    def initialize_latent_density(
        density: types.Density2DArray,
    ) -> types.Density2DArray:
        array = transform.normalized_array_from_density(density)
        # Clip to [-1, 1] and scale by tanh(beta) (magnitude < 1) so that
        # `arctanh` below stays finite.
        array = jnp.clip(array, -1, 1)
        array *= jnp.tanh(beta)
        latent_array = jnp.arctanh(array) / beta
        latent_array = transform.rescale_array_for_density(latent_array, density)
        return dataclasses.replace(density, array=latent_array)

    return transformed_wrapped_optax(
        opt=opt,
        transform_fn=transform_fn,
        initialize_latent_fn=initialize_latent_fn,
    )
70
+
71
+
72
def transformed_wrapped_optax(
    opt: optax.GradientTransformation,
    transform_fn: Callable[[PyTree], PyTree],
    initialize_latent_fn: Callable[[PyTree], PyTree],
) -> base.Optimizer:
    """Return a wrapped optax optimizer for transformed latent parameters.

    The optimizer state is the tuple `(params, latent_params, opt_state)`. The
    wrapped optax optimizer updates `latent_params`, which are clipped to their
    bounds after each step; `params` is always `transform_fn(latent_params)`.

    Args:
        opt: The optax `GradientTransformation` to be wrapped.
        transform_fn: Function which transforms the internal latent parameters to
            the parameters returned by the optimizer.
        initialize_latent_fn: Function which computes the initial latent parameters
            given the initial parameters.

    Returns:
        The `base.Optimizer`.
    """

    def init_fn(params: PyTree) -> WrappedOptaxState:
        """Initializes the optimization state."""
        # Clip the initial parameters to their bounds before computing latents.
        latent_params = initialize_latent_fn(_clip(params))
        params = transform_fn(latent_params)
        return params, latent_params, opt.init(latent_params)

    def params_fn(state: WrappedOptaxState) -> PyTree:
        """Returns the parameters for the given `state`."""
        params, _, _ = state
        return params

    def update_fn(
        *,
        grad: PyTree,
        value: float,
        params: PyTree,
        state: WrappedOptaxState,
    ) -> WrappedOptaxState:
        """Updates the state.

        Args:
            grad: Gradient of the objective with respect to the (transformed)
                parameters returned by `params_fn`.
            value: The objective value; unused.
            params: The current parameters; unused, since the parameters are
                recomputed from the latent parameters stored in `state`.
            state: The current optimizer state.

        Returns:
            The updated state.
        """
        del value, params

        _, latent_params, opt_state = state
        # Pull the gradient back through `transform_fn` to obtain the gradient
        # with respect to the latent parameters.
        _, vjp_fn = jax.vjp(transform_fn, latent_params)
        (latent_grad,) = vjp_fn(grad)

        # Pass the current latent parameters so that transformations which need
        # them (e.g. weight decay in `optax.adamw`) work instead of raising.
        updates, opt_state = opt.update(latent_grad, opt_state, params=latent_params)
        latent_params = optax.apply_updates(params=latent_params, updates=updates)
        latent_params = _clip(latent_params)
        params = transform_fn(latent_params)
        return params, latent_params, opt_state

    return base.Optimizer(
        init=init_fn,
        params=params_fn,
        update=update_fn,
    )
126
+
127
+
128
def _is_density(leaf: Any) -> bool:
    """Return `True` if `leaf` is a density array (`types.Density2DArray`).

    Used as the `is_leaf` predicate in this module's `tree_map` calls so that
    density arrays are handled as single leaves.
    """
    return isinstance(leaf, types.Density2DArray)
131
+
132
+
133
def _is_custom_type(leaf: Any) -> bool:
    """Check whether `leaf` is a `types.BoundedArray` or `types.Density2DArray`.

    These are the leaf types whose bounds `_clip` enforces.
    """
    return isinstance(leaf, types.BoundedArray) or isinstance(
        leaf, types.Density2DArray
    )
136
+
137
+
138
def _clip(pytree: PyTree) -> PyTree:
    """Clip each bounded leaf of `pytree` to its own bounds.

    Leaves recognized by `_is_custom_type` that have at least one non-`None`
    bound have their arrays clipped with `jnp.clip(x, lower_bound, upper_bound)`.
    Leaves with neither bound, and leaves of any other type, are returned as-is.
    """

    def _clip_leaf(node: Any) -> Any:
        # Check the type first: only custom types carry `lower_bound`/`upper_bound`.
        if _is_custom_type(node) and (
            node.lower_bound is not None or node.upper_bound is not None
        ):
            lo, hi = node.lower_bound, node.upper_bound
            return tree_util.tree_map(lambda arr: jnp.clip(arr, lo, hi), node)
        return node

    return tree_util.tree_map(_clip_leaf, pytree, is_leaf=_is_custom_type)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: invrs_opt
3
- Version: 0.5.1
3
+ Version: 0.6.0
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -34,6 +34,7 @@ Requires-Dist: jax
34
34
  Requires-Dist: jaxlib
35
35
  Requires-Dist: numpy
36
36
  Requires-Dist: requests
37
+ Requires-Dist: optax
37
38
  Requires-Dist: scipy
38
39
  Requires-Dist: totypes
39
40
  Requires-Dist: types-requests
@@ -49,7 +50,7 @@ Requires-Dist: pytest-cov ; extra == 'tests'
49
50
  Requires-Dist: pytest-subtests ; extra == 'tests'
50
51
 
51
52
  # invrs-opt - Optimization algorithms for inverse design
52
- `v0.5.1`
53
+ `v0.6.0`
53
54
 
54
55
  ## Overview
55
56
 
@@ -0,0 +1,16 @@
1
+ invrs_opt/__init__.py,sha256=35pvMpeqvJgU4DizUO5hTzeE9j93prbB9RMtcaoFYwg,496
2
+ invrs_opt/base.py,sha256=FdQyPTlWGo03YztI3K2_QBN6Q-v0PeXv6XCyXu_uh_4,1160
3
+ invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ invrs_opt/transform.py,sha256=a_Saj9Wq4lvqCJBrg5L2Z9DZ2NVs1xqrHLqha90a9Ws,5971
5
+ invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ invrs_opt/experimental/client.py,sha256=MqC_TguT9IGrG7WW54vwz6QQMylKkbCjHxFPIG9vQMA,4841
7
+ invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
8
+ invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
+ invrs_opt/lbfgsb/lbfgsb.py,sha256=pfrqCaOMco-eHUQe2q03hbla9D2TYqmMB-07jK4-5Ik,27792
10
+ invrs_opt/wrapped_optax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ invrs_opt/wrapped_optax/wrapped_optax.py,sha256=-ke0MNCb2EB0ntlj5IHIHrvybOVF4m24DM6JI4_Ktcc,4974
12
+ invrs_opt-0.6.0.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
13
+ invrs_opt-0.6.0.dist-info/METADATA,sha256=V4hpzjEovC2CU0IgUBMAgKzPfgqGvLlutQ9X5-FOuIk,3347
14
+ invrs_opt-0.6.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
15
+ invrs_opt-0.6.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
16
+ invrs_opt-0.6.0.dist-info/RECORD,,
@@ -1,14 +0,0 @@
1
- invrs_opt/__init__.py,sha256=rrTPxvCPLpsNopphWwz4MNzs0309YJ2FqnoSoym8MjM,309
2
- invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
3
- invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- invrs_opt/experimental/client.py,sha256=td5o_YqqbcSypDrWCVrHGSoF8UxEdOLtKU0z9Dth9lA,4842
6
- invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
7
- invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
- invrs_opt/lbfgsb/lbfgsb.py,sha256=DOfPVrvHYq7KBrHx3ibRO_ik4PO2D0QQtCwoubVJfVU,27892
9
- invrs_opt/lbfgsb/transform.py,sha256=a_Saj9Wq4lvqCJBrg5L2Z9DZ2NVs1xqrHLqha90a9Ws,5971
10
- invrs_opt-0.5.1.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
11
- invrs_opt-0.5.1.dist-info/METADATA,sha256=xBu95bCkaOqZzteNeNHUta09LrlQ3oHZcTBVSd6j9Zk,3326
12
- invrs_opt-0.5.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
13
- invrs_opt-0.5.1.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
14
- invrs_opt-0.5.1.dist-info/RECORD,,
File without changes