invrs-opt 0.8.0__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invrs_opt/__init__.py +1 -1
- invrs_opt/optimizers/wrapped_optax.py +84 -61
- {invrs_opt-0.8.0.dist-info → invrs_opt-0.8.1.dist-info}/METADATA +2 -2
- {invrs_opt-0.8.0.dist-info → invrs_opt-0.8.1.dist-info}/RECORD +7 -7
- {invrs_opt-0.8.0.dist-info → invrs_opt-0.8.1.dist-info}/LICENSE +0 -0
- {invrs_opt-0.8.0.dist-info → invrs_opt-0.8.1.dist-info}/WHEEL +0 -0
- {invrs_opt-0.8.0.dist-info → invrs_opt-0.8.1.dist-info}/top_level.txt +0 -0
invrs_opt/optimizers/wrapped_optax.py
CHANGED
@@ -178,64 +178,15 @@ def parameterized_wrapped_optax(
|
|
178
178
|
if density_parameterization is None:
|
179
179
|
density_parameterization = pixel.pixel()
|
180
180
|
|
181
|
-
def _init_latents(params: PyTree) -> PyTree:
|
182
|
-
def _leaf_init_latents(leaf: Any) -> Any:
|
183
|
-
leaf = _clip(leaf)
|
184
|
-
if not _is_density(leaf):
|
185
|
-
return leaf
|
186
|
-
return density_parameterization.from_density(leaf)
|
187
|
-
|
188
|
-
return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
|
189
|
-
|
190
|
-
def _params_from_latents(params: PyTree) -> PyTree:
|
191
|
-
def _leaf_params_from_latents(leaf: Any) -> Any:
|
192
|
-
if not _is_parameterized_density(leaf):
|
193
|
-
return leaf
|
194
|
-
return density_parameterization.to_density(leaf)
|
195
|
-
|
196
|
-
return tree_util.tree_map(
|
197
|
-
_leaf_params_from_latents,
|
198
|
-
params,
|
199
|
-
is_leaf=_is_parameterized_density,
|
200
|
-
)
|
201
|
-
|
202
|
-
def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
|
203
|
-
def _constraint_loss_leaf(
|
204
|
-
params: parameterization_base.ParameterizedDensity2DArrayBase,
|
205
|
-
) -> jnp.ndarray:
|
206
|
-
constraints = density_parameterization.constraints(params)
|
207
|
-
constraints = tree_util.tree_map(
|
208
|
-
lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
|
209
|
-
constraints,
|
210
|
-
)
|
211
|
-
return jnp.sum(jnp.asarray(constraints))
|
212
|
-
|
213
|
-
losses = [0.0] + [
|
214
|
-
_constraint_loss_leaf(p)
|
215
|
-
for p in tree_util.tree_leaves(
|
216
|
-
latent_params, is_leaf=_is_parameterized_density
|
217
|
-
)
|
218
|
-
if _is_parameterized_density(p)
|
219
|
-
]
|
220
|
-
return penalty * jnp.sum(jnp.asarray(losses))
|
221
|
-
|
222
|
-
def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
|
223
|
-
def _update_leaf(leaf: Any) -> Any:
|
224
|
-
if not _is_parameterized_density(leaf):
|
225
|
-
return leaf
|
226
|
-
return density_parameterization.update(leaf, step)
|
227
|
-
|
228
|
-
return tree_util.tree_map(
|
229
|
-
_update_leaf,
|
230
|
-
latent_params,
|
231
|
-
is_leaf=_is_parameterized_density,
|
232
|
-
)
|
233
|
-
|
234
181
|
def init_fn(params: PyTree) -> WrappedOptaxState:
|
235
182
|
"""Initializes the optimization state."""
|
236
183
|
latent_params = _init_latents(params)
|
237
|
-
|
238
|
-
|
184
|
+
return (
|
185
|
+
0, # step
|
186
|
+
_params_from_latents(latent_params), # params
|
187
|
+
latent_params, # latent params
|
188
|
+
opt.init(tree_util.tree_leaves(latent_params)), # opt state
|
189
|
+
)
|
239
190
|
|
240
191
|
def params_fn(state: WrappedOptaxState) -> PyTree:
|
241
192
|
"""Returns the parameters for the given `state`."""
|
@@ -252,7 +203,8 @@ def parameterized_wrapped_optax(
|
|
252
203
|
"""Updates the state."""
|
253
204
|
del value, params
|
254
205
|
|
255
|
-
step,
|
206
|
+
step, params, latent_params, opt_state = state
|
207
|
+
|
256
208
|
_, vjp_fn = jax.vjp(_params_from_latents, latent_params)
|
257
209
|
(latent_grad,) = vjp_fn(grad)
|
258
210
|
|
@@ -271,14 +223,85 @@ def parameterized_wrapped_optax(
|
|
271
223
|
lambda a, b: a + b, latent_grad, constraint_loss_grad
|
272
224
|
)
|
273
225
|
|
274
|
-
|
275
|
-
updates=latent_grad,
|
226
|
+
updates_leaves, opt_state = opt.update(
|
227
|
+
updates=tree_util.tree_leaves(latent_grad),
|
228
|
+
state=opt_state,
|
229
|
+
params=tree_util.tree_leaves(latent_params),
|
230
|
+
)
|
231
|
+
latent_params_leaves = optax.apply_updates(
|
232
|
+
params=tree_util.tree_leaves(latent_params),
|
233
|
+
updates=updates_leaves,
|
276
234
|
)
|
277
|
-
latent_params =
|
235
|
+
latent_params = tree_util.tree_unflatten(
|
236
|
+
treedef=tree_util.tree_structure(latent_params),
|
237
|
+
leaves=latent_params_leaves,
|
238
|
+
)
|
239
|
+
|
278
240
|
latent_params = _clip(latent_params)
|
279
|
-
latent_params = _update_parameterized_densities(latent_params, step)
|
241
|
+
latent_params = _update_parameterized_densities(latent_params, step + 1)
|
280
242
|
params = _params_from_latents(latent_params)
|
281
|
-
return step + 1, params, latent_params, opt_state
|
243
|
+
return (step + 1, params, latent_params, opt_state)
|
244
|
+
|
245
|
+
# -------------------------------------------------------------------------
|
246
|
+
# Functions related to the density parameterization.
|
247
|
+
# -------------------------------------------------------------------------
|
248
|
+
|
249
|
+
def _init_latents(params: PyTree) -> PyTree:
|
250
|
+
def _leaf_init_latents(leaf: Any) -> Any:
|
251
|
+
leaf = _clip(leaf)
|
252
|
+
if not _is_density(leaf):
|
253
|
+
return leaf
|
254
|
+
return density_parameterization.from_density(leaf)
|
255
|
+
|
256
|
+
return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
|
257
|
+
|
258
|
+
def _params_from_latents(params: PyTree) -> PyTree:
|
259
|
+
def _leaf_params_from_latents(leaf: Any) -> Any:
|
260
|
+
if not _is_parameterized_density(leaf):
|
261
|
+
return leaf
|
262
|
+
return density_parameterization.to_density(leaf)
|
263
|
+
|
264
|
+
return tree_util.tree_map(
|
265
|
+
_leaf_params_from_latents,
|
266
|
+
params,
|
267
|
+
is_leaf=_is_parameterized_density,
|
268
|
+
)
|
269
|
+
|
270
|
+
def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
|
271
|
+
def _update_leaf(leaf: Any) -> Any:
|
272
|
+
if not _is_parameterized_density(leaf):
|
273
|
+
return leaf
|
274
|
+
return density_parameterization.update(leaf, step)
|
275
|
+
|
276
|
+
return tree_util.tree_map(
|
277
|
+
_update_leaf,
|
278
|
+
latent_params,
|
279
|
+
is_leaf=_is_parameterized_density,
|
280
|
+
)
|
281
|
+
|
282
|
+
# -------------------------------------------------------------------------
|
283
|
+
# Functions related to the constraints to be minimized.
|
284
|
+
# -------------------------------------------------------------------------
|
285
|
+
|
286
|
+
def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
|
287
|
+
def _constraint_loss_leaf(
|
288
|
+
params: parameterization_base.ParameterizedDensity2DArrayBase,
|
289
|
+
) -> jnp.ndarray:
|
290
|
+
constraints = density_parameterization.constraints(params)
|
291
|
+
constraints = tree_util.tree_map(
|
292
|
+
lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
|
293
|
+
constraints,
|
294
|
+
)
|
295
|
+
return jnp.sum(jnp.asarray(constraints))
|
296
|
+
|
297
|
+
losses = [0.0] + [
|
298
|
+
_constraint_loss_leaf(p)
|
299
|
+
for p in tree_util.tree_leaves(
|
300
|
+
latent_params, is_leaf=_is_parameterized_density
|
301
|
+
)
|
302
|
+
if _is_parameterized_density(p)
|
303
|
+
]
|
304
|
+
return penalty * jnp.sum(jnp.asarray(losses))
|
282
305
|
|
283
306
|
return base.Optimizer(init=init_fn, params=params_fn, update=update_fn)
|
284
307
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: invrs_opt
|
3
|
-
Version: 0.8.0
|
3
|
+
Version: 0.8.1
|
4
4
|
Summary: Algorithms for inverse design
|
5
5
|
Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
6
6
|
Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
@@ -533,7 +533,7 @@ Requires-Dist: pytest-cov; extra == "tests"
|
|
533
533
|
Requires-Dist: pytest-subtests; extra == "tests"
|
534
534
|
|
535
535
|
# invrs-opt - Optimization algorithms for inverse design
|
536
|
-
`v0.8.0`
|
536
|
+
`v0.8.1`
|
537
537
|
|
538
538
|
## Overview
|
539
539
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
invrs_opt/__init__.py,sha256=
|
1
|
+
invrs_opt/__init__.py,sha256=tjbabqQKp5y1I5eqDa7onsy7E58P9nTk_Ifn12oMfZM,585
|
2
2
|
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
invrs_opt/experimental/client.py,sha256=t4XxnditYbM9DWZeyBPj0Sa2acvkikT0ybhUdmH2r-Y,4852
|
@@ -6,15 +6,15 @@ invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8
|
|
6
6
|
invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
7
7
|
invrs_opt/optimizers/base.py,sha256=-wNwH0J475y8FzB5aLAkc_1602LvYeF4Hddr9OiBkDY,1276
|
8
8
|
invrs_opt/optimizers/lbfgsb.py,sha256=2NcyFllUM5CrjOZLEBocYAlCrkWt1fiRGJrxg05waog,35149
|
9
|
-
invrs_opt/optimizers/wrapped_optax.py,sha256=
|
9
|
+
invrs_opt/optimizers/wrapped_optax.py,sha256=c4iSsXHF2cJZaZhdxLpyRy26q5IxBTo4-6sY8IOLT38,13259
|
10
10
|
invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
11
|
invrs_opt/parameterization/base.py,sha256=k2VXZorMnkmaIWWY_G2CvS0NtxYojfvKlPccDGZyGNA,3779
|
12
12
|
invrs_opt/parameterization/filter_project.py,sha256=D2w8xrg34V8ysIbbr1RPvegM5WoLdz8QKUCGi80ieOI,3466
|
13
13
|
invrs_opt/parameterization/gaussian_levelset.py,sha256=DllvpkBpVxuDFQFu941f8v_Xh2D0m9Eh-JxBH5afNHU,25221
|
14
14
|
invrs_opt/parameterization/pixel.py,sha256=yWdGpf0x6Om_-7CYOcHTgNK9UF-XGAPM7GbOBNijnBw,1305
|
15
15
|
invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
|
16
|
-
invrs_opt-0.8.0.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
|
17
|
-
invrs_opt-0.8.
|
18
|
-
invrs_opt-0.8.0.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
|
19
|
-
invrs_opt-0.8.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
20
|
-
invrs_opt-0.8.0.dist-info/RECORD,,
|
16
|
+
invrs_opt-0.8.1.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
|
17
|
+
invrs_opt-0.8.1.dist-info/METADATA,sha256=4qLsuDLp3pNn5C1TTPnF0qwWps7MeN_Uc3rKsBLbPrk,32633
|
18
|
+
invrs_opt-0.8.1.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
|
19
|
+
invrs_opt-0.8.1.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
20
|
+
invrs_opt-0.8.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|