invrs-opt 0.8.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invrs_opt/__init__.py CHANGED
@@ -3,7 +3,7 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.8.0"
6
+ __version__ = "v0.8.1"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
9
  from invrs_opt import parameterization as parameterization
@@ -178,64 +178,15 @@ def parameterized_wrapped_optax(
178
178
  if density_parameterization is None:
179
179
  density_parameterization = pixel.pixel()
180
180
 
181
- def _init_latents(params: PyTree) -> PyTree:
182
- def _leaf_init_latents(leaf: Any) -> Any:
183
- leaf = _clip(leaf)
184
- if not _is_density(leaf):
185
- return leaf
186
- return density_parameterization.from_density(leaf)
187
-
188
- return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
189
-
190
- def _params_from_latents(params: PyTree) -> PyTree:
191
- def _leaf_params_from_latents(leaf: Any) -> Any:
192
- if not _is_parameterized_density(leaf):
193
- return leaf
194
- return density_parameterization.to_density(leaf)
195
-
196
- return tree_util.tree_map(
197
- _leaf_params_from_latents,
198
- params,
199
- is_leaf=_is_parameterized_density,
200
- )
201
-
202
- def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
203
- def _constraint_loss_leaf(
204
- params: parameterization_base.ParameterizedDensity2DArrayBase,
205
- ) -> jnp.ndarray:
206
- constraints = density_parameterization.constraints(params)
207
- constraints = tree_util.tree_map(
208
- lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
209
- constraints,
210
- )
211
- return jnp.sum(jnp.asarray(constraints))
212
-
213
- losses = [0.0] + [
214
- _constraint_loss_leaf(p)
215
- for p in tree_util.tree_leaves(
216
- latent_params, is_leaf=_is_parameterized_density
217
- )
218
- if _is_parameterized_density(p)
219
- ]
220
- return penalty * jnp.sum(jnp.asarray(losses))
221
-
222
- def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
223
- def _update_leaf(leaf: Any) -> Any:
224
- if not _is_parameterized_density(leaf):
225
- return leaf
226
- return density_parameterization.update(leaf, step)
227
-
228
- return tree_util.tree_map(
229
- _update_leaf,
230
- latent_params,
231
- is_leaf=_is_parameterized_density,
232
- )
233
-
234
181
  def init_fn(params: PyTree) -> WrappedOptaxState:
235
182
  """Initializes the optimization state."""
236
183
  latent_params = _init_latents(params)
237
- params = _params_from_latents(latent_params)
238
- return 0, params, latent_params, opt.init(latent_params)
184
+ return (
185
+ 0, # step
186
+ _params_from_latents(latent_params), # params
187
+ latent_params, # latent params
188
+ opt.init(tree_util.tree_leaves(latent_params)), # opt state
189
+ )
239
190
 
240
191
  def params_fn(state: WrappedOptaxState) -> PyTree:
241
192
  """Returns the parameters for the given `state`."""
@@ -252,7 +203,8 @@ def parameterized_wrapped_optax(
252
203
  """Updates the state."""
253
204
  del value, params
254
205
 
255
- step, _, latent_params, opt_state = state
206
+ step, params, latent_params, opt_state = state
207
+
256
208
  _, vjp_fn = jax.vjp(_params_from_latents, latent_params)
257
209
  (latent_grad,) = vjp_fn(grad)
258
210
 
@@ -271,14 +223,85 @@ def parameterized_wrapped_optax(
271
223
  lambda a, b: a + b, latent_grad, constraint_loss_grad
272
224
  )
273
225
 
274
- updates, opt_state = opt.update(
275
- updates=latent_grad, state=opt_state, params=latent_params
226
+ updates_leaves, opt_state = opt.update(
227
+ updates=tree_util.tree_leaves(latent_grad),
228
+ state=opt_state,
229
+ params=tree_util.tree_leaves(latent_params),
230
+ )
231
+ latent_params_leaves = optax.apply_updates(
232
+ params=tree_util.tree_leaves(latent_params),
233
+ updates=updates_leaves,
276
234
  )
277
- latent_params = optax.apply_updates(params=latent_params, updates=updates)
235
+ latent_params = tree_util.tree_unflatten(
236
+ treedef=tree_util.tree_structure(latent_params),
237
+ leaves=latent_params_leaves,
238
+ )
239
+
278
240
  latent_params = _clip(latent_params)
279
- latent_params = _update_parameterized_densities(latent_params, step)
241
+ latent_params = _update_parameterized_densities(latent_params, step + 1)
280
242
  params = _params_from_latents(latent_params)
281
- return step + 1, params, latent_params, opt_state
243
+ return (step + 1, params, latent_params, opt_state)
244
+
245
+ # -------------------------------------------------------------------------
246
+ # Functions related to the density parameterization.
247
+ # -------------------------------------------------------------------------
248
+
249
+ def _init_latents(params: PyTree) -> PyTree:
250
+ def _leaf_init_latents(leaf: Any) -> Any:
251
+ leaf = _clip(leaf)
252
+ if not _is_density(leaf):
253
+ return leaf
254
+ return density_parameterization.from_density(leaf)
255
+
256
+ return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
257
+
258
+ def _params_from_latents(params: PyTree) -> PyTree:
259
+ def _leaf_params_from_latents(leaf: Any) -> Any:
260
+ if not _is_parameterized_density(leaf):
261
+ return leaf
262
+ return density_parameterization.to_density(leaf)
263
+
264
+ return tree_util.tree_map(
265
+ _leaf_params_from_latents,
266
+ params,
267
+ is_leaf=_is_parameterized_density,
268
+ )
269
+
270
+ def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
271
+ def _update_leaf(leaf: Any) -> Any:
272
+ if not _is_parameterized_density(leaf):
273
+ return leaf
274
+ return density_parameterization.update(leaf, step)
275
+
276
+ return tree_util.tree_map(
277
+ _update_leaf,
278
+ latent_params,
279
+ is_leaf=_is_parameterized_density,
280
+ )
281
+
282
+ # -------------------------------------------------------------------------
283
+ # Functions related to the constraints to be minimized.
284
+ # -------------------------------------------------------------------------
285
+
286
+ def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
287
+ def _constraint_loss_leaf(
288
+ params: parameterization_base.ParameterizedDensity2DArrayBase,
289
+ ) -> jnp.ndarray:
290
+ constraints = density_parameterization.constraints(params)
291
+ constraints = tree_util.tree_map(
292
+ lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
293
+ constraints,
294
+ )
295
+ return jnp.sum(jnp.asarray(constraints))
296
+
297
+ losses = [0.0] + [
298
+ _constraint_loss_leaf(p)
299
+ for p in tree_util.tree_leaves(
300
+ latent_params, is_leaf=_is_parameterized_density
301
+ )
302
+ if _is_parameterized_density(p)
303
+ ]
304
+ return penalty * jnp.sum(jnp.asarray(losses))
282
305
 
283
306
  return base.Optimizer(init=init_fn, params=params_fn, update=update_fn)
284
307
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: invrs_opt
3
- Version: 0.8.0
3
+ Version: 0.8.1
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -533,7 +533,7 @@ Requires-Dist: pytest-cov; extra == "tests"
533
533
  Requires-Dist: pytest-subtests; extra == "tests"
534
534
 
535
535
  # invrs-opt - Optimization algorithms for inverse design
536
- `v0.8.0`
536
+ `v0.8.1`
537
537
 
538
538
  ## Overview
539
539
 
@@ -1,4 +1,4 @@
1
- invrs_opt/__init__.py,sha256=kTrg48iZu7i5OlH0Nqtfh_wBn3be9u2eZgJBRqUr6uQ,585
1
+ invrs_opt/__init__.py,sha256=tjbabqQKp5y1I5eqDa7onsy7E58P9nTk_Ifn12oMfZM,585
2
2
  invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  invrs_opt/experimental/client.py,sha256=t4XxnditYbM9DWZeyBPj0Sa2acvkikT0ybhUdmH2r-Y,4852
@@ -6,15 +6,15 @@ invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8
6
6
  invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
7
  invrs_opt/optimizers/base.py,sha256=-wNwH0J475y8FzB5aLAkc_1602LvYeF4Hddr9OiBkDY,1276
8
8
  invrs_opt/optimizers/lbfgsb.py,sha256=2NcyFllUM5CrjOZLEBocYAlCrkWt1fiRGJrxg05waog,35149
9
- invrs_opt/optimizers/wrapped_optax.py,sha256=zPC2j_KkfF2RqeorxB38ovuUsg1SNwdxwhjA7gvOMC4,12387
9
+ invrs_opt/optimizers/wrapped_optax.py,sha256=c4iSsXHF2cJZaZhdxLpyRy26q5IxBTo4-6sY8IOLT38,13259
10
10
  invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  invrs_opt/parameterization/base.py,sha256=k2VXZorMnkmaIWWY_G2CvS0NtxYojfvKlPccDGZyGNA,3779
12
12
  invrs_opt/parameterization/filter_project.py,sha256=D2w8xrg34V8ysIbbr1RPvegM5WoLdz8QKUCGi80ieOI,3466
13
13
  invrs_opt/parameterization/gaussian_levelset.py,sha256=DllvpkBpVxuDFQFu941f8v_Xh2D0m9Eh-JxBH5afNHU,25221
14
14
  invrs_opt/parameterization/pixel.py,sha256=yWdGpf0x6Om_-7CYOcHTgNK9UF-XGAPM7GbOBNijnBw,1305
15
15
  invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
16
- invrs_opt-0.8.0.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
- invrs_opt-0.8.0.dist-info/METADATA,sha256=WSzpZwByiBqOZfiwSGvaVaFa5TUAPj4b-mr2qZRGLq0,32633
18
- invrs_opt-0.8.0.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
19
- invrs_opt-0.8.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
20
- invrs_opt-0.8.0.dist-info/RECORD,,
16
+ invrs_opt-0.8.1.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
+ invrs_opt-0.8.1.dist-info/METADATA,sha256=4qLsuDLp3pNn5C1TTPnF0qwWps7MeN_Uc3rKsBLbPrk,32633
18
+ invrs_opt-0.8.1.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
19
+ invrs_opt-0.8.1.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
20
+ invrs_opt-0.8.1.dist-info/RECORD,,