invrs-opt 0.5.2__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invrs_opt/__init__.py CHANGED
@@ -3,8 +3,12 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.5.2"
6
+ __version__ = "v0.6.0"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
9
  from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
10
10
  from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
11
+ from invrs_opt.wrapped_optax.wrapped_optax import (
12
+ density_wrapped_optax as density_wrapped_optax,
13
+ )
14
+ from invrs_opt.wrapped_optax.wrapped_optax import wrapped_optax as wrapped_optax
invrs_opt/base.py CHANGED
@@ -6,6 +6,9 @@ Copyright (c) 2023 The INVRS-IO authors.
6
6
  import dataclasses
7
7
  from typing import Any, Protocol
8
8
 
9
+ import optax # type: ignore[import-untyped]
10
+ from totypes import json_utils
11
+
9
12
  PyTree = Any
10
13
 
11
14
 
@@ -44,3 +47,8 @@ class Optimizer:
44
47
  init: InitFn
45
48
  params: ParamsFn
46
49
  update: UpdateFn
50
+
51
+
52
# Register these optax optimizer-state types as custom types with
# `totypes.json_utils` (presumably so that optimizer states containing them can
# be JSON-(de)serialized -- confirm against `json_utils`).
# TODO: consider programmatically registering all optax states here.
for _optax_state_type in (optax.EmptyState, optax.ScaleByAdamState):
    json_utils.register_custom_type(_optax_state_type)
del _optax_state_type
@@ -4,16 +4,15 @@ Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
6
  import json
7
- import requests
8
7
  import time
9
8
  from typing import Any, Dict, Optional
10
9
 
10
+ import requests
11
11
  from totypes import json_utils
12
12
 
13
13
  from invrs_opt import base
14
14
  from invrs_opt.experimental import labels
15
15
 
16
-
17
16
  PyTree = Any
18
17
  StateToken = str
19
18
 
@@ -16,8 +16,7 @@ from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
16
16
  )
17
17
  from totypes import types
18
18
 
19
- from invrs_opt import base
20
- from invrs_opt.lbfgsb import transform
19
+ from invrs_opt import base, transform
21
20
 
22
21
  NDArray = onp.ndarray[Any, Any]
23
22
  PyTree = Any
File without changes
@@ -0,0 +1,150 @@
1
+ import dataclasses
2
+ from typing import Any, Callable, Tuple
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import optax # type: ignore[import-untyped]
7
+ from jax import tree_util
8
+ from totypes import types
9
+
10
+ from invrs_opt import base, transform
11
+
12
+ PyTree = Any
13
+ WrappedOptaxState = Tuple[PyTree, PyTree, PyTree]
14
+
15
+
16
def wrapped_optax(opt: optax.GradientTransformation) -> base.Optimizer:
    """Return a wrapped optax optimizer.

    The optax transformation is wrapped via `transformed_wrapped_optax`, with the
    identity function used both as the latent-to-parameter transform and as the
    latent initializer.

    Args:
        opt: The optax `GradientTransformation` to be wrapped.

    Returns:
        The `base.Optimizer`.
    """

    def _identity(tree: PyTree) -> PyTree:
        return tree

    return transformed_wrapped_optax(
        opt=opt,
        transform_fn=_identity,
        initialize_latent_fn=_identity,
    )
23
+
24
+
25
def density_wrapped_optax(
    opt: optax.GradientTransformation,
    beta: float,
) -> base.Optimizer:
    """Return a wrapped optax optimizer with transforms for density arrays.

    Each `types.Density2DArray` leaf of the parameters is optimized through a
    latent density. The parameters exposed by the optimizer are computed from the
    latent density by:
    1. symmetrizing it;
    2. applying `transform.density_gaussian_filter_and_tanh` with `beta`;
    3. rescaling about the midpoint of the density bounds by `1 / tanh(beta)`;
    4. applying fixed pixels.
    All other leaves are passed through unchanged.

    Args:
        opt: The optax `GradientTransformation` to be wrapped.
        beta: Forwarded as `beta` to `transform.density_gaussian_filter_and_tanh`;
            also sets the `1 / tanh(beta)` rescaling. Must be nonzero, since
            `tanh(0) == 0` would make both the transform and the latent
            initialization divide by zero.

    Returns:
        The `base.Optimizer`.

    Raises:
        ValueError: If `beta` is zero.
    """
    if beta == 0:
        # Both `1 / tanh(beta)` and `arctanh(...) / beta` below are undefined here
        # and would silently produce non-finite parameters.
        raise ValueError(f"`beta` must be nonzero, but got {beta}.")

    def _map_densities(
        fn: Callable[[types.Density2DArray], types.Density2DArray],
        tree: PyTree,
    ) -> PyTree:
        """Apply `fn` to every density leaf of `tree`; other leaves are unchanged."""
        return tree_util.tree_map(
            lambda x: fn(x) if _is_density(x) else x,
            tree,
            is_leaf=_is_density,
        )

    def transform_density(density: types.Density2DArray) -> types.Density2DArray:
        """Map a latent density to the density returned by the optimizer."""
        transformed = types.symmetrize_density(density)
        transformed = transform.density_gaussian_filter_and_tanh(transformed, beta=beta)
        # Scale to ensure that the full valid range of the density array is reachable.
        mid_value = (density.lower_bound + density.upper_bound) / 2
        transformed = tree_util.tree_map(
            lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed
        )
        return transform.apply_fixed_pixels(transformed)

    def initialize_latent_density(
        density: types.Density2DArray,
    ) -> types.Density2DArray:
        """Compute the initial latent density for an initial `density`.

        Inverts the tanh projection and `1 / tanh(beta)` rescaling; the Gaussian
        filter is not inverted.
        """
        array = transform.normalized_array_from_density(density)
        # Clipping keeps `array * tanh(beta)` strictly inside (-1, 1), so the
        # `arctanh` below stays finite.
        array = jnp.clip(array, -1, 1)
        array *= jnp.tanh(beta)
        latent_array = jnp.arctanh(array) / beta
        latent_array = transform.rescale_array_for_density(latent_array, density)
        return dataclasses.replace(density, array=latent_array)

    def transform_fn(tree: PyTree) -> PyTree:
        return _map_densities(transform_density, tree)

    def initialize_latent_fn(tree: PyTree) -> PyTree:
        return _map_densities(initialize_latent_density, tree)

    return transformed_wrapped_optax(
        opt=opt,
        transform_fn=transform_fn,
        initialize_latent_fn=initialize_latent_fn,
    )
70
+
71
+
72
def transformed_wrapped_optax(
    opt: optax.GradientTransformation,
    transform_fn: Callable[[PyTree], PyTree],
    initialize_latent_fn: Callable[[PyTree], PyTree],
) -> base.Optimizer:
    """Return a wrapped optax optimizer for transformed latent parameters.

    The optimizer state is the tuple `(params, latent_params, opt_state)`. The
    optax transformation updates `latent_params`. The parameters exposed to the
    caller are always `transform_fn(latent_params)`.

    Args:
        opt: The optax `GradientTransformation` to be wrapped.
        transform_fn: Function which transforms the internal latent parameters to
            the parameters returned by the optimizer.
        initialize_latent_fn: Function which computes the initial latent
            parameters given the (clipped) initial parameters.

    Returns:
        The `base.Optimizer`.
    """

    def init_fn(params: PyTree) -> WrappedOptaxState:
        """Initializes the optimization state."""
        latent = initialize_latent_fn(_clip(params))
        return transform_fn(latent), latent, opt.init(latent)

    def params_fn(state: WrappedOptaxState) -> PyTree:
        """Returns the parameters for the given `state`."""
        (current_params, _unused_latent, _unused_opt_state) = state
        return current_params

    def update_fn(
        *,
        grad: PyTree,
        value: float,
        params: PyTree,
        state: WrappedOptaxState,
    ) -> WrappedOptaxState:
        """Updates the state."""
        # `value` is unused. `params` is ignored because the new parameters are
        # recomputed from the updated latent parameters.
        del value, params

        _, latent, opt_state = state
        # Pull the gradient with respect to the parameters back to latent space.
        _, pullback = jax.vjp(transform_fn, latent)
        (latent_grad,) = pullback(grad)

        updates, next_opt_state = opt.update(latent_grad, opt_state)
        next_latent = _clip(optax.apply_updates(params=latent, updates=updates))
        return transform_fn(next_latent), next_latent, next_opt_state

    return base.Optimizer(
        init=init_fn,
        params=params_fn,
        update=update_fn,
    )
126
+
127
+
128
def _is_density(leaf: Any) -> bool:
    """Return `True` if `leaf` is a density array.

    Args:
        leaf: Any pytree node.

    Returns:
        Whether `leaf` is a `types.Density2DArray`.
    """
    return isinstance(leaf, types.Density2DArray)
131
+
132
+
133
def _is_custom_type(leaf: Any) -> bool:
    """Check whether `leaf` is a `BoundedArray` or a `Density2DArray`."""
    recognized = (types.BoundedArray, types.Density2DArray)
    return isinstance(leaf, recognized)
136
+
137
+
138
def _clip(pytree: PyTree) -> PyTree:
    """Clip each bounded custom-type leaf of `pytree` to its own bounds.

    The following leaves are returned unchanged:
    - leaves that are not a `BoundedArray` or `Density2DArray`;
    - custom-type leaves whose `lower_bound` and `upper_bound` are both `None`.
    """

    def _clip_leaf(leaf: Any) -> Any:
        has_some_bound = _is_custom_type(leaf) and (
            leaf.lower_bound is not None or leaf.upper_bound is not None
        )
        if has_some_bound:
            low, high = leaf.lower_bound, leaf.upper_bound
            return tree_util.tree_map(lambda arr: jnp.clip(arr, low, high), leaf)
        return leaf

    return tree_util.tree_map(_clip_leaf, pytree, is_leaf=_is_custom_type)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: invrs_opt
3
- Version: 0.5.2
3
+ Version: 0.6.0
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -34,6 +34,7 @@ Requires-Dist: jax
34
34
  Requires-Dist: jaxlib
35
35
  Requires-Dist: numpy
36
36
  Requires-Dist: requests
37
+ Requires-Dist: optax
37
38
  Requires-Dist: scipy
38
39
  Requires-Dist: totypes
39
40
  Requires-Dist: types-requests
@@ -49,7 +50,7 @@ Requires-Dist: pytest-cov ; extra == 'tests'
49
50
  Requires-Dist: pytest-subtests ; extra == 'tests'
50
51
 
51
52
  # invrs-opt - Optimization algorithms for inverse design
52
- `v0.5.2`
53
+ `v0.6.0`
53
54
 
54
55
  ## Overview
55
56
 
@@ -0,0 +1,16 @@
1
+ invrs_opt/__init__.py,sha256=35pvMpeqvJgU4DizUO5hTzeE9j93prbB9RMtcaoFYwg,496
2
+ invrs_opt/base.py,sha256=FdQyPTlWGo03YztI3K2_QBN6Q-v0PeXv6XCyXu_uh_4,1160
3
+ invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ invrs_opt/transform.py,sha256=a_Saj9Wq4lvqCJBrg5L2Z9DZ2NVs1xqrHLqha90a9Ws,5971
5
+ invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ invrs_opt/experimental/client.py,sha256=MqC_TguT9IGrG7WW54vwz6QQMylKkbCjHxFPIG9vQMA,4841
7
+ invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
8
+ invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
+ invrs_opt/lbfgsb/lbfgsb.py,sha256=pfrqCaOMco-eHUQe2q03hbla9D2TYqmMB-07jK4-5Ik,27792
10
+ invrs_opt/wrapped_optax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ invrs_opt/wrapped_optax/wrapped_optax.py,sha256=-ke0MNCb2EB0ntlj5IHIHrvybOVF4m24DM6JI4_Ktcc,4974
12
+ invrs_opt-0.6.0.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
13
+ invrs_opt-0.6.0.dist-info/METADATA,sha256=V4hpzjEovC2CU0IgUBMAgKzPfgqGvLlutQ9X5-FOuIk,3347
14
+ invrs_opt-0.6.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
15
+ invrs_opt-0.6.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
16
+ invrs_opt-0.6.0.dist-info/RECORD,,
@@ -1,14 +0,0 @@
1
- invrs_opt/__init__.py,sha256=TPyc9kznbXNbM67XmXpnFIGTmGFQmYzPojKRueSZelU,309
2
- invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
3
- invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- invrs_opt/experimental/client.py,sha256=td5o_YqqbcSypDrWCVrHGSoF8UxEdOLtKU0z9Dth9lA,4842
6
- invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
7
- invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
- invrs_opt/lbfgsb/lbfgsb.py,sha256=EKNTDTwEHcO_JeOIckMHVer5hxYcaNuqq9EZgnFnWJk,27820
9
- invrs_opt/lbfgsb/transform.py,sha256=a_Saj9Wq4lvqCJBrg5L2Z9DZ2NVs1xqrHLqha90a9Ws,5971
10
- invrs_opt-0.5.2.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
11
- invrs_opt-0.5.2.dist-info/METADATA,sha256=SgIEqMR9ybcipS1NnHi-KSVTSZ3BzWKrv1ctw4jfqcE,3326
12
- invrs_opt-0.5.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
13
- invrs_opt-0.5.2.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
14
- invrs_opt-0.5.2.dist-info/RECORD,,
File without changes