invrs-opt 0.5.2__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invrs_opt/__init__.py +14 -3
- invrs_opt/experimental/client.py +2 -3
- invrs_opt/{base.py → optimizers/base.py} +10 -0
- invrs_opt/{lbfgsb → optimizers}/lbfgsb.py +293 -153
- invrs_opt/optimizers/wrapped_optax.py +300 -0
- invrs_opt/parameterization/__init__.py +0 -0
- invrs_opt/parameterization/base.py +148 -0
- invrs_opt/parameterization/filter_project.py +92 -0
- invrs_opt/parameterization/gaussian_levelset.py +643 -0
- invrs_opt/parameterization/pixel.py +45 -0
- invrs_opt/{lbfgsb/transform.py → parameterization/transforms.py} +76 -11
- {invrs_opt-0.5.2.dist-info → invrs_opt-0.7.0.dist-info}/METADATA +11 -10
- invrs_opt-0.7.0.dist-info/RECORD +20 -0
- {invrs_opt-0.5.2.dist-info → invrs_opt-0.7.0.dist-info}/WHEEL +1 -1
- invrs_opt-0.5.2.dist-info/RECORD +0 -14
- /invrs_opt/{lbfgsb → optimizers}/__init__.py +0 -0
- {invrs_opt-0.5.2.dist-info → invrs_opt-0.7.0.dist-info}/LICENSE +0 -0
- {invrs_opt-0.5.2.dist-info → invrs_opt-0.7.0.dist-info}/top_level.txt +0 -0
invrs_opt/__init__.py
CHANGED
@@ -3,8 +3,19 @@
|
|
3
3
|
Copyright (c) 2023 The INVRS-IO authors.
|
4
4
|
"""
|
5
5
|
|
6
|
-
__version__ = "v0.
|
6
|
+
__version__ = "v0.7.0"
|
7
7
|
__author__ = "Martin F. Schubert <mfschubert@gmail.com>"
|
8
8
|
|
9
|
-
from invrs_opt
|
10
|
-
|
9
|
+
from invrs_opt import parameterization as parameterization
|
10
|
+
|
11
|
+
from invrs_opt.optimizers.lbfgsb import (
|
12
|
+
density_lbfgsb as density_lbfgsb,
|
13
|
+
lbfgsb as lbfgsb,
|
14
|
+
levelset_lbfgsb as levelset_lbfgsb,
|
15
|
+
)
|
16
|
+
|
17
|
+
from invrs_opt.optimizers.wrapped_optax import (
|
18
|
+
density_wrapped_optax as density_wrapped_optax,
|
19
|
+
levelset_wrapped_optax as levelset_wrapped_optax,
|
20
|
+
wrapped_optax as wrapped_optax,
|
21
|
+
)
|
invrs_opt/experimental/client.py
CHANGED
@@ -4,15 +4,14 @@ Copyright (c) 2023 The INVRS-IO authors.
|
|
4
4
|
"""
|
5
5
|
|
6
6
|
import json
|
7
|
-
import requests
|
8
7
|
import time
|
9
8
|
from typing import Any, Dict, Optional
|
10
9
|
|
10
|
+
import requests
|
11
11
|
from totypes import json_utils
|
12
12
|
|
13
|
-
from invrs_opt import base
|
14
13
|
from invrs_opt.experimental import labels
|
15
|
-
|
14
|
+
from invrs_opt.optimizers import base
|
16
15
|
|
17
16
|
PyTree = Any
|
18
17
|
StateToken = str
|
@@ -4,8 +4,12 @@ Copyright (c) 2023 The INVRS-IO authors.
|
|
4
4
|
"""
|
5
5
|
|
6
6
|
import dataclasses
|
7
|
+
import inspect
|
7
8
|
from typing import Any, Protocol
|
8
9
|
|
10
|
+
import optax # type: ignore[import-untyped]
|
11
|
+
from totypes import json_utils
|
12
|
+
|
9
13
|
PyTree = Any
|
10
14
|
|
11
15
|
|
@@ -44,3 +48,9 @@ class Optimizer:
|
|
44
48
|
init: InitFn
|
45
49
|
params: ParamsFn
|
46
50
|
update: UpdateFn
|
51
|
+
|
52
|
+
|
53
|
+
# Register all optax state types for serialization.
|
54
|
+
for name, obj in inspect.getmembers(optax):
|
55
|
+
if name.endswith("State") and isinstance(obj, type):
|
56
|
+
json_utils.register_custom_type(obj)
|
@@ -5,19 +5,25 @@ Copyright (c) 2023 The INVRS-IO authors.
|
|
5
5
|
|
6
6
|
import copy
|
7
7
|
import dataclasses
|
8
|
-
from typing import Any,
|
8
|
+
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
9
9
|
|
10
10
|
import jax
|
11
11
|
import jax.numpy as jnp
|
12
12
|
import numpy as onp
|
13
|
+
import optax # type: ignore[import-untyped]
|
13
14
|
from jax import flatten_util, tree_util
|
14
15
|
from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
|
15
16
|
_lbfgsb as scipy_lbfgsb,
|
16
17
|
)
|
17
18
|
from totypes import types
|
18
19
|
|
19
|
-
from invrs_opt import base
|
20
|
-
from invrs_opt.
|
20
|
+
from invrs_opt.optimizers import base
|
21
|
+
from invrs_opt.parameterization import (
|
22
|
+
base as parameterization_base,
|
23
|
+
filter_project,
|
24
|
+
gaussian_levelset,
|
25
|
+
pixel,
|
26
|
+
)
|
21
27
|
|
22
28
|
NDArray = onp.ndarray[Any, Any]
|
23
29
|
PyTree = Any
|
@@ -35,10 +41,10 @@ UPDATE_IPRINT = -1
|
|
35
41
|
|
36
42
|
# Maximum value for the `maxcor` parameter in the L-BFGS-B scheme.
|
37
43
|
MAXCOR_MAX_VALUE = 100
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
44
|
+
DEFAULT_MAXCOR = 20
|
45
|
+
DEFAULT_LINE_SEARCH_MAX_STEPS = 100
|
46
|
+
DEFAULT_FTOL = 0.0
|
47
|
+
DEFAULT_GTOL = 0.0
|
42
48
|
|
43
49
|
# Maps bound scenarios to integers.
|
44
50
|
BOUNDS_MAP: Dict[Tuple[bool, bool], int] = {
|
@@ -52,175 +58,225 @@ FORTRAN_INT = scipy_lbfgsb.types.intvar.dtype
|
|
52
58
|
|
53
59
|
|
54
60
|
def lbfgsb(
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
61
|
+
*,
|
62
|
+
maxcor: int = DEFAULT_MAXCOR,
|
63
|
+
line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
|
64
|
+
ftol: float = DEFAULT_FTOL,
|
65
|
+
gtol: float = DEFAULT_GTOL,
|
59
66
|
) -> base.Optimizer:
|
60
|
-
"""
|
67
|
+
"""Optimizer implementing the standard L-BFGS-B algorithm.
|
61
68
|
|
62
|
-
|
63
|
-
|
64
|
-
|
69
|
+
The standard L-BFGS-B algorithm uses the direct pixel parameterization for density
|
70
|
+
arrays, which simply enforces that values are between the declared upper and lower
|
71
|
+
bounds of the density.
|
65
72
|
|
66
|
-
|
67
|
-
|
68
|
-
def fn(x):
|
69
|
-
leaves_sum_sq = [jnp.sum(y)**2 for y in tree_util.tree_leaves(x)]
|
70
|
-
return jnp.sum(jnp.asarray(leaves_sum_sq))
|
71
|
-
|
72
|
-
x0 = {
|
73
|
-
"a": jnp.ones((3,)),
|
74
|
-
"b": BoundedArray(
|
75
|
-
value=-jnp.ones((2, 5)),
|
76
|
-
lower_bound=-5,
|
77
|
-
upper_bound=5,
|
78
|
-
),
|
79
|
-
}
|
80
|
-
opt = lbfgsb(maxcor=20, line_search_max_steps=100)
|
81
|
-
state = opt.init(x0)
|
82
|
-
for _ in range(10):
|
83
|
-
x = opt.params(state)
|
84
|
-
value, grad = jax.value_and_grad(fn)(x)
|
85
|
-
state = opt.update(grad, value, state)
|
86
|
-
|
87
|
-
While the algorithm can work with pytrees of jax arrays, numpy arrays can
|
88
|
-
also be used. Thus, e.g. the optimizer can directly be used with autograd.
|
89
|
-
|
90
|
-
When the optimization has converged (according to `ftol` or `gtol` criteria), the
|
91
|
-
optimizer simply returns the parameters which obtained the converged result. The
|
73
|
+
When an optimization is determined to have converged (by `ftol` or `gtol` criteria)
|
74
|
+
the optimizer `params` function will simply return the optimal parameters. The
|
92
75
|
convergence can be queried by `is_converged(state)`.
|
93
76
|
|
94
77
|
Args:
|
95
|
-
maxcor: The maximum number of variable metric corrections used to define
|
96
|
-
|
78
|
+
maxcor: The maximum number of variable metric corrections used to define the
|
79
|
+
limited memory matrix, in the L-BFGS-B scheme.
|
97
80
|
line_search_max_steps: The maximum number of steps in the line search.
|
98
|
-
ftol:
|
99
|
-
|
100
|
-
gtol:
|
81
|
+
ftol: Convergence criteria based on function values. See scipy documentation
|
82
|
+
for details.
|
83
|
+
gtol: Convergence criteria based on gradient.
|
101
84
|
|
102
85
|
Returns:
|
103
|
-
The `
|
86
|
+
The `Optimizer` implementing the L-BFGS-B optimizer.
|
104
87
|
"""
|
105
|
-
return
|
88
|
+
return parameterized_lbfgsb(
|
89
|
+
density_parameterization=None,
|
90
|
+
penalty=0.0,
|
106
91
|
maxcor=maxcor,
|
107
92
|
line_search_max_steps=line_search_max_steps,
|
108
93
|
ftol=ftol,
|
109
94
|
gtol=gtol,
|
110
|
-
transform_fn=lambda x: x,
|
111
|
-
initialize_latent_fn=lambda x: x,
|
112
95
|
)
|
113
96
|
|
114
97
|
|
115
98
|
def density_lbfgsb(
|
99
|
+
*,
|
116
100
|
beta: float,
|
117
|
-
maxcor: int =
|
118
|
-
line_search_max_steps: int =
|
119
|
-
ftol: float =
|
120
|
-
gtol: float =
|
101
|
+
maxcor: int = DEFAULT_MAXCOR,
|
102
|
+
line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
|
103
|
+
ftol: float = DEFAULT_FTOL,
|
104
|
+
gtol: float = DEFAULT_GTOL,
|
121
105
|
) -> base.Optimizer:
|
122
|
-
"""
|
123
|
-
|
124
|
-
Parameters that are of type `DensityArray2D` are represented as latent parameters
|
125
|
-
that are transformed (in the case where lower and upper bounds are `(-1, 1)`) by,
|
106
|
+
"""Optimizer using L-BFGS-B algorithm with filter-project density parameterization.
|
126
107
|
|
127
|
-
|
108
|
+
In the filter-project density parameterization, the optimization variable
|
109
|
+
associated with a density array is a latent density array; the density is obtained
|
110
|
+
by convolving (i.e. "filtering") the latent density with a Gaussian kernel having
|
111
|
+
full-width at half-maximum equal to the length scale (the mean of declared minimum
|
112
|
+
width and minimum spacing). Then, a tanh nonlinearity is used as a smooth threshold
|
113
|
+
operation ("projection").
|
128
114
|
|
129
|
-
|
130
|
-
|
131
|
-
density is scaled before the transform is applied, and then unscaled afterwards.
|
132
|
-
|
133
|
-
When the optimization has converged (according to `ftol` or `gtol` criteria), the
|
134
|
-
optimizer simply returns the parameters which obtained the converged result. The
|
115
|
+
When an optimization is determined to have converged (by `ftol` or `gtol` criteria)
|
116
|
+
the optimizer `params` function will simply return the optimal parameters. The
|
135
117
|
convergence can be queried by `is_converged(state)`.
|
136
118
|
|
137
119
|
Args:
|
138
|
-
beta: Determines the
|
139
|
-
maxcor: The maximum number of variable metric corrections used to define
|
140
|
-
|
120
|
+
beta: Determines the sharpness of the thresholding operation.
|
121
|
+
maxcor: The maximum number of variable metric corrections used to define the
|
122
|
+
limited memory matrix, in the L-BFGS-B scheme.
|
141
123
|
line_search_max_steps: The maximum number of steps in the line search.
|
142
|
-
ftol:
|
143
|
-
|
144
|
-
gtol:
|
124
|
+
ftol: Convergence criteria based on function values. See scipy documentation
|
125
|
+
for details.
|
126
|
+
gtol: Convergence criteria based on gradient.
|
145
127
|
|
146
128
|
Returns:
|
147
|
-
The `
|
129
|
+
The `Optimizer` implementing the L-BFGS-B optimizer.
|
148
130
|
"""
|
131
|
+
return parameterized_lbfgsb(
|
132
|
+
density_parameterization=filter_project.filter_project(beta=beta),
|
133
|
+
penalty=0.0,
|
134
|
+
maxcor=maxcor,
|
135
|
+
line_search_max_steps=line_search_max_steps,
|
136
|
+
ftol=ftol,
|
137
|
+
gtol=gtol,
|
138
|
+
)
|
149
139
|
|
150
|
-
def transform_fn(tree: PyTree) -> PyTree:
|
151
|
-
return tree_util.tree_map(
|
152
|
-
lambda x: transform_density(x) if _is_density(x) else x,
|
153
|
-
tree,
|
154
|
-
is_leaf=_is_density,
|
155
|
-
)
|
156
140
|
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
141
|
+
def levelset_lbfgsb(
|
142
|
+
*,
|
143
|
+
penalty: float,
|
144
|
+
length_scale_spacing_factor: float = (
|
145
|
+
gaussian_levelset.DEFAULT_LENGTH_SCALE_SPACING_FACTOR
|
146
|
+
),
|
147
|
+
length_scale_fwhm_factor: float = (
|
148
|
+
gaussian_levelset.DEFAULT_LENGTH_SCALE_FWHM_FACTOR
|
149
|
+
),
|
150
|
+
length_scale_constraint_factor: float = (
|
151
|
+
gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR
|
152
|
+
),
|
153
|
+
smoothing_factor: int = gaussian_levelset.DEFAULT_SMOOTHING_FACTOR,
|
154
|
+
length_scale_constraint_beta: float = (
|
155
|
+
gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA
|
156
|
+
),
|
157
|
+
length_scale_constraint_weight: float = (
|
158
|
+
gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT
|
159
|
+
),
|
160
|
+
curvature_constraint_weight: float = (
|
161
|
+
gaussian_levelset.DEFAULT_CURVATURE_CONSTRAINT_WEIGHT
|
162
|
+
),
|
163
|
+
fixed_pixel_constraint_weight: float = (
|
164
|
+
gaussian_levelset.DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT
|
165
|
+
),
|
166
|
+
init_optimizer: optax.GradientTransformation = (
|
167
|
+
gaussian_levelset.DEFAULT_INIT_OPTIMIZER
|
168
|
+
),
|
169
|
+
init_steps: int = gaussian_levelset.DEFAULT_INIT_STEPS,
|
170
|
+
maxcor: int = DEFAULT_MAXCOR,
|
171
|
+
line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
|
172
|
+
ftol: float = DEFAULT_FTOL,
|
173
|
+
gtol: float = DEFAULT_GTOL,
|
174
|
+
) -> base.Optimizer:
|
175
|
+
"""Optimizer using L-BFGS-B algorithm with levelset density parameterization.
|
176
|
+
|
177
|
+
In the levelset parameterization, the optimization variable associated with a
|
178
|
+
density array is an array giving the amplitudes of Gaussian radial basis functions
|
179
|
+
that represent a levelset function over the domain of the density. In the levelset
|
180
|
+
parameterization, gradients are nonzero only at the edges of features, and in
|
181
|
+
general the topology of a solution does not change during the course of
|
182
|
+
optimization.
|
183
|
+
|
184
|
+
The spacing and full-width at half-maximum of the Gaussian basis functions give
|
185
|
+
some amount of control over length scales. In addition, constraints associated with
|
186
|
+
length scale, radius of curvature, and deviation from fixed pixels are
|
187
|
+
automatically computed and penalized with a weight given by `penalty`. In general,
|
188
|
+
this helps ensure that features in an optimized density array violate the specified
|
189
|
+
constraints to a lesser degree. The constraints are based on "Analytical level set
|
190
|
+
fabrication constraints for inverse design," by D. Vercruysse et al. (2019).
|
191
|
+
|
192
|
+
When an optimization is determined to have converged (by `ftol` or `gtol` criteria)
|
193
|
+
the optimizer `params` function will simply return the optimal parameters. The
|
194
|
+
convergence can be queried by `is_converged(state)`.
|
163
195
|
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
196
|
+
Args:
|
197
|
+
penalty: The weight of the fabrication penalty, which combines length scale,
|
198
|
+
curvature, and fixed pixel constraints.
|
199
|
+
length_scale_spacing_factor: The number of levelset control points per unit of
|
200
|
+
minimum length scale (mean of density minimum width and minimum spacing).
|
201
|
+
length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum to
|
202
|
+
the minimum length scale.
|
203
|
+
length_scale_constraint_factor: Multiplies the target length scale in the
|
204
|
+
levelset constraints. A value greater than 1 is pessimistic and drives the
|
205
|
+
solution to have a larger length scale (relative to smaller values).
|
206
|
+
smoothing_factor: For values greater than 1, the density is initially computed
|
207
|
+
at higher resolution and then downsampled, yielding smoother geometries.
|
208
|
+
length_scale_constraint_beta: Controls relaxation of the length scale
|
209
|
+
constraint near the zero level.
|
210
|
+
length_scale_constraint_weight: The weight of the length scale constraint in
|
211
|
+
the overall fabrication constraint penalty.
|
212
|
+
curvature_constraint_weight: The weight of the curvature constraint.
|
213
|
+
fixed_pixel_constraint_weight: The weight of the fixed pixel constraint.
|
214
|
+
init_optimizer: The optimizer used in the initialization of the levelset
|
215
|
+
parameterization. At initialization, the latent parameters are optimized so
|
216
|
+
that the initial parameters match the binarized initial density.
|
217
|
+
init_steps: The number of optimization steps used in the initialization.
|
218
|
+
maxcor: The maximum number of variable metric corrections used to define the
|
219
|
+
limited memory matrix, in the L-BFGS-B scheme.
|
220
|
+
line_search_max_steps: The maximum number of steps in the line search.
|
221
|
+
ftol: Convergence criteria based on function values. See scipy documentation
|
222
|
+
for details.
|
223
|
+
gtol: Convergence criteria based on gradient.
|
224
|
+
|
225
|
+
Returns:
|
226
|
+
The `Optimizer` implementing the L-BFGS-B optimizer.
|
227
|
+
"""
|
228
|
+
return parameterized_lbfgsb(
|
229
|
+
density_parameterization=gaussian_levelset.gaussian_levelset(
|
230
|
+
length_scale_spacing_factor=length_scale_spacing_factor,
|
231
|
+
length_scale_fwhm_factor=length_scale_fwhm_factor,
|
232
|
+
length_scale_constraint_factor=length_scale_constraint_factor,
|
233
|
+
smoothing_factor=smoothing_factor,
|
234
|
+
length_scale_constraint_beta=length_scale_constraint_beta,
|
235
|
+
length_scale_constraint_weight=length_scale_constraint_weight,
|
236
|
+
curvature_constraint_weight=curvature_constraint_weight,
|
237
|
+
fixed_pixel_constraint_weight=fixed_pixel_constraint_weight,
|
238
|
+
init_optimizer=init_optimizer,
|
239
|
+
init_steps=init_steps,
|
240
|
+
),
|
241
|
+
penalty=penalty,
|
185
242
|
maxcor=maxcor,
|
186
243
|
line_search_max_steps=line_search_max_steps,
|
187
244
|
ftol=ftol,
|
188
245
|
gtol=gtol,
|
189
|
-
transform_fn=transform_fn,
|
190
|
-
initialize_latent_fn=initialize_latent_fn,
|
191
246
|
)
|
192
247
|
|
193
248
|
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
ftol: float,
|
198
|
-
gtol: float,
|
199
|
-
transform_fn: Callable[[PyTree], PyTree],
|
200
|
-
initialize_latent_fn: Callable[[PyTree], PyTree],
|
201
|
-
) -> base.Optimizer:
|
202
|
-
"""Construct a latent parameter L-BFGS-B optimizer.
|
249
|
+
# -----------------------------------------------------------------------------
|
250
|
+
# Base parameterized L-BFGS-B optimizer.
|
251
|
+
# -----------------------------------------------------------------------------
|
203
252
|
|
204
|
-
The optimized parameters are termed latent parameters, from which the
|
205
|
-
actual parameters returned by the optimizer are obtained using the
|
206
|
-
`transform_fn`. In the simple case where this is just `lambda x: x` (i.e.
|
207
|
-
the identity), this is equivalent to the standard L-BFGS-B algorithm.
|
208
253
|
|
209
|
-
|
210
|
-
|
211
|
-
|
254
|
+
def parameterized_lbfgsb(
|
255
|
+
density_parameterization: Optional[parameterization_base.Density2DParameterization],
|
256
|
+
penalty: float,
|
257
|
+
maxcor: int = DEFAULT_MAXCOR,
|
258
|
+
line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
|
259
|
+
ftol: float = DEFAULT_FTOL,
|
260
|
+
gtol: float = DEFAULT_GTOL,
|
261
|
+
) -> base.Optimizer:
|
262
|
+
"""Optimizer using L-BFGS-B algorithm with specified density parameterization.
|
263
|
+
|
264
|
+
This optimizer wraps scipy's implementation of the algorithm, and provides
|
265
|
+
a jax-style API to the scheme. The optimizer works with custom types such
|
266
|
+
as the `BoundedArray` to constrain the optimization variable.
|
212
267
|
|
213
268
|
Args:
|
214
|
-
|
215
|
-
the
|
269
|
+
density_parameterization: The parameterization to be used, or `None`. When no
|
270
|
+
parameterization is given, the direct pixel parameterization is used for
|
271
|
+
density arrays.
|
272
|
+
penalty: The weight of the scalar penalty formed from the constraints of the
|
273
|
+
parameterization.
|
274
|
+
maxcor: The maximum number of variable metric corrections used to define the
|
275
|
+
limited memory matrix, in the L-BFGS-B scheme.
|
216
276
|
line_search_max_steps: The maximum number of steps in the line search.
|
217
|
-
ftol:
|
218
|
-
|
219
|
-
gtol:
|
220
|
-
transform_fn: Function which transforms the internal latent parameters to
|
221
|
-
the parameters returned by the optimizer.
|
222
|
-
initialize_latent_fn: Function which computes the initial latent parameters
|
223
|
-
given the initial parameters.
|
277
|
+
ftol: Convergence criteria based on function values. See scipy documentation
|
278
|
+
for details.
|
279
|
+
gtol: Convergence criteria based on gradient.
|
224
280
|
|
225
281
|
Returns:
|
226
282
|
The `base.Optimizer`.
|
@@ -237,33 +293,73 @@ def transformed_lbfgsb(
|
|
237
293
|
f"{line_search_max_steps}"
|
238
294
|
)
|
239
295
|
|
296
|
+
if density_parameterization is None:
|
297
|
+
density_parameterization = pixel.pixel()
|
298
|
+
|
299
|
+
def _init_latents(params: PyTree) -> PyTree:
|
300
|
+
def _leaf_init_latents(leaf: Any) -> Any:
|
301
|
+
leaf = _clip(leaf)
|
302
|
+
if not _is_density(leaf) or density_parameterization is None:
|
303
|
+
return leaf
|
304
|
+
return density_parameterization.from_density(leaf)
|
305
|
+
|
306
|
+
return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
|
307
|
+
|
308
|
+
def _params_from_latents(latent_params: PyTree) -> PyTree:
|
309
|
+
def _leaf_params_from_latents(leaf: Any) -> Any:
|
310
|
+
if not _is_parameterized_density(leaf) or density_parameterization is None:
|
311
|
+
return leaf
|
312
|
+
return density_parameterization.to_density(leaf)
|
313
|
+
|
314
|
+
return tree_util.tree_map(
|
315
|
+
_leaf_params_from_latents,
|
316
|
+
latent_params,
|
317
|
+
is_leaf=_is_parameterized_density,
|
318
|
+
)
|
319
|
+
|
320
|
+
def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
|
321
|
+
def _constraint_loss_leaf(
|
322
|
+
params: parameterization_base.ParameterizedDensity2DArrayBase,
|
323
|
+
) -> jnp.ndarray:
|
324
|
+
constraints = density_parameterization.constraints(params)
|
325
|
+
constraints = tree_util.tree_map(
|
326
|
+
lambda x: jnp.sum(jnp.maximum(x, 0.0)),
|
327
|
+
constraints,
|
328
|
+
)
|
329
|
+
return jnp.sum(jnp.asarray(constraints))
|
330
|
+
|
331
|
+
losses = [0.0] + [
|
332
|
+
_constraint_loss_leaf(p)
|
333
|
+
for p in tree_util.tree_leaves(
|
334
|
+
latent_params, is_leaf=_is_parameterized_density
|
335
|
+
)
|
336
|
+
if _is_parameterized_density(p)
|
337
|
+
]
|
338
|
+
return penalty * jnp.sum(jnp.asarray(losses))
|
339
|
+
|
240
340
|
def init_fn(params: PyTree) -> LbfgsbState:
|
241
341
|
"""Initializes the optimization state."""
|
242
342
|
|
243
|
-
def
|
244
|
-
lower_bound = types.extract_lower_bound(
|
245
|
-
upper_bound = types.extract_upper_bound(
|
343
|
+
def _init_state_pure(latent_params: PyTree) -> Tuple[PyTree, JaxLbfgsbDict]:
|
344
|
+
lower_bound = types.extract_lower_bound(latent_params)
|
345
|
+
upper_bound = types.extract_upper_bound(latent_params)
|
246
346
|
scipy_lbfgsb_state = ScipyLbfgsbState.init(
|
247
|
-
x0=_to_numpy(
|
248
|
-
lower_bound=_bound_for_params(lower_bound,
|
249
|
-
upper_bound=_bound_for_params(upper_bound,
|
347
|
+
x0=_to_numpy(latent_params),
|
348
|
+
lower_bound=_bound_for_params(lower_bound, latent_params),
|
349
|
+
upper_bound=_bound_for_params(upper_bound, latent_params),
|
250
350
|
maxcor=maxcor,
|
251
351
|
line_search_max_steps=line_search_max_steps,
|
252
352
|
ftol=ftol,
|
253
353
|
gtol=gtol,
|
254
354
|
)
|
255
|
-
latent_params = _to_pytree(scipy_lbfgsb_state.x,
|
355
|
+
latent_params = _to_pytree(scipy_lbfgsb_state.x, latent_params)
|
256
356
|
return latent_params, scipy_lbfgsb_state.to_jax()
|
257
357
|
|
258
|
-
(
|
259
|
-
|
260
|
-
|
261
|
-
) = jax.pure_callback(
|
262
|
-
_init_pure,
|
263
|
-
_example_state(params, maxcor),
|
264
|
-
initialize_latent_fn(params),
|
358
|
+
latent_params = _init_latents(params)
|
359
|
+
latent_params, jax_lbfgsb_state = jax.pure_callback(
|
360
|
+
_init_state_pure, _example_state(latent_params, maxcor), latent_params
|
265
361
|
)
|
266
|
-
return
|
362
|
+
return _params_from_latents(latent_params), latent_params, jax_lbfgsb_state
|
267
363
|
|
268
364
|
def params_fn(state: LbfgsbState) -> PyTree:
|
269
365
|
"""Returns the parameters for the given `state`."""
|
@@ -295,16 +391,35 @@ def transformed_lbfgsb(
|
|
295
391
|
return flat_latent_params, scipy_lbfgsb_state.to_jax()
|
296
392
|
|
297
393
|
_, latent_params, jax_lbfgsb_state = state
|
298
|
-
_, vjp_fn = jax.vjp(
|
394
|
+
_, vjp_fn = jax.vjp(_params_from_latents, latent_params)
|
299
395
|
(latent_grad,) = vjp_fn(grad)
|
396
|
+
|
397
|
+
if not (
|
398
|
+
tree_util.tree_structure(latent_grad)
|
399
|
+
== tree_util.tree_structure(latent_params) # type: ignore[operator]
|
400
|
+
):
|
401
|
+
raise ValueError(
|
402
|
+
f"Tree structure of `latent_grad` was different than expected, got \n"
|
403
|
+
f"{tree_util.tree_structure(latent_grad)} but expected \n"
|
404
|
+
f"{tree_util.tree_structure(latent_params)}."
|
405
|
+
)
|
406
|
+
|
407
|
+
(
|
408
|
+
constraint_loss_value,
|
409
|
+
constraint_loss_grad,
|
410
|
+
) = jax.value_and_grad(
|
411
|
+
_constraint_loss
|
412
|
+
)(latent_params)
|
413
|
+
value += constraint_loss_value
|
414
|
+
latent_grad = tree_util.tree_map(
|
415
|
+
lambda a, b: a + b, latent_grad, constraint_loss_grad
|
416
|
+
)
|
417
|
+
|
300
418
|
flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(
|
301
419
|
latent_grad
|
302
420
|
) # type: ignore[no-untyped-call]
|
303
421
|
|
304
|
-
(
|
305
|
-
flat_latent_params,
|
306
|
-
jax_lbfgsb_state,
|
307
|
-
) = jax.pure_callback(
|
422
|
+
flat_latent_params, jax_lbfgsb_state = jax.pure_callback(
|
308
423
|
_update_pure,
|
309
424
|
(flat_latent_grad, jax_lbfgsb_state),
|
310
425
|
flat_latent_grad,
|
@@ -312,7 +427,7 @@ def transformed_lbfgsb(
|
|
312
427
|
jax_lbfgsb_state,
|
313
428
|
)
|
314
429
|
latent_params = unflatten_fn(flat_latent_params)
|
315
|
-
return
|
430
|
+
return _params_from_latents(latent_params), latent_params, jax_lbfgsb_state
|
316
431
|
|
317
432
|
return base.Optimizer(
|
318
433
|
init=init_fn,
|
@@ -336,6 +451,31 @@ def _is_density(leaf: Any) -> Any:
|
|
336
451
|
return isinstance(leaf, types.Density2DArray)
|
337
452
|
|
338
453
|
|
454
|
+
def _is_parameterized_density(leaf: Any) -> Any:
|
455
|
+
"""Return `True` if `leaf` is a parameterized density array."""
|
456
|
+
return isinstance(leaf, parameterization_base.ParameterizedDensity2DArrayBase)
|
457
|
+
|
458
|
+
|
459
|
+
def _is_custom_type(leaf: Any) -> bool:
|
460
|
+
"""Return `True` if `leaf` is a recognized custom type."""
|
461
|
+
return isinstance(leaf, (types.BoundedArray, types.Density2DArray))
|
462
|
+
|
463
|
+
|
464
|
+
def _clip(pytree: PyTree) -> PyTree:
|
465
|
+
"""Clips leaves on `pytree` to their bounds."""
|
466
|
+
|
467
|
+
def _clip_fn(leaf: Any) -> Any:
|
468
|
+
if not _is_custom_type(leaf):
|
469
|
+
return leaf
|
470
|
+
if leaf.lower_bound is None and leaf.upper_bound is None:
|
471
|
+
return leaf
|
472
|
+
return tree_util.tree_map(
|
473
|
+
lambda x: jnp.clip(x, leaf.lower_bound, leaf.upper_bound), leaf
|
474
|
+
)
|
475
|
+
|
476
|
+
return tree_util.tree_map(_clip_fn, pytree, is_leaf=_is_custom_type)
|
477
|
+
|
478
|
+
|
339
479
|
def _to_numpy(params: PyTree) -> NDArray:
|
340
480
|
"""Flattens a `params` pytree into a single rank-1 numpy array."""
|
341
481
|
x, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
|