invrs-opt 0.7.2__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invrs_opt/__init__.py CHANGED
@@ -3,7 +3,7 @@
  Copyright (c) 2023 The INVRS-IO authors.
  """

- __version__ = "v0.7.2"
+ __version__ = "v0.8.1"
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"

  from invrs_opt import parameterization as parameterization
invrs_opt/optimizers/lbfgsb.py CHANGED
@@ -29,7 +29,7 @@ NDArray = onp.ndarray[Any, Any]
  PyTree = Any
  ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
  JaxLbfgsbDict = Dict[str, jnp.ndarray]
- LbfgsbState = Tuple[PyTree, PyTree, JaxLbfgsbDict]
+ LbfgsbState = Tuple[int, PyTree, PyTree, JaxLbfgsbDict]


  # Task message prefixes for the underlying L-BFGS-B implementation.
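Note: `LbfgsbState` now carries a leading step counter, so any code that unpacks the state tuple directly needs a fourth slot. A minimal sketch of the layout change (variable names are illustrative; `state` is whatever the optimizer's `init`/`update` returned):

    # v0.7.2 layout: (params, latent_params, jax_lbfgsb_dict)
    # v0.8.1 layout: (step, params, latent_params, jax_lbfgsb_dict)
    step, params, latent_params, jax_lbfgsb_dict = state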
@@ -323,7 +323,7 @@ def parameterized_lbfgsb(
  ) -> jnp.ndarray:
  constraints = density_parameterization.constraints(params)
  constraints = tree_util.tree_map(
- lambda x: jnp.sum(jnp.maximum(x, 0.0)),
+ lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
  constraints,
  )
  return jnp.sum(jnp.asarray(constraints))
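The constraint penalty switches from a linear hinge `max(x, 0)` to a squared hinge `max(x, 0) ** 2`, whose gradient vanishes smoothly at the constraint boundary and grows with the size of the violation. A standalone comparison using only `jax.numpy` (the example array is illustrative):

    import jax.numpy as jnp

    x = jnp.array([-1.0, 0.0, 0.5, 2.0])  # constraint values; positive entries are violations
    linear_penalty = jnp.sum(jnp.maximum(x, 0.0))        # v0.7.2 behavior -> 2.5
    squared_penalty = jnp.sum(jnp.maximum(x, 0.0) ** 2)  # v0.8.1 behavior -> 4.25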
@@ -337,6 +337,18 @@ def parameterized_lbfgsb(
  ]
  return penalty * jnp.sum(jnp.asarray(losses))

+ def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
+ def _update_leaf(leaf: Any) -> Any:
+ if not _is_parameterized_density(leaf):
+ return leaf
+ return density_parameterization.update(leaf, step)
+
+ return tree_util.tree_map(
+ _update_leaf,
+ latent_params,
+ is_leaf=_is_parameterized_density,
+ )
+
  def init_fn(params: PyTree) -> LbfgsbState:
  """Initializes the optimization state."""

@@ -359,11 +371,11 @@ def parameterized_lbfgsb(
  latent_params, jax_lbfgsb_state = jax.pure_callback(
  _init_state_pure, _example_state(latent_params, maxcor), latent_params
  )
- return _params_from_latents(latent_params), latent_params, jax_lbfgsb_state
+ return 0, _params_from_latents(latent_params), latent_params, jax_lbfgsb_state

  def params_fn(state: LbfgsbState) -> PyTree:
  """Returns the parameters for the given `state`."""
- params, _, _ = state
+ _, params, _, _ = state
  return params

  def update_fn(
@@ -390,7 +402,7 @@ def parameterized_lbfgsb(
  flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
  return flat_latent_params, scipy_lbfgsb_state.to_jax()

- _, latent_params, jax_lbfgsb_state = state
+ step, _, latent_params, jax_lbfgsb_state = state
  _, vjp_fn = jax.vjp(_params_from_latents, latent_params)
  (latent_grad,) = vjp_fn(grad)

@@ -427,7 +439,9 @@ def parameterized_lbfgsb(
  jax_lbfgsb_state,
  )
  latent_params = unflatten_fn(flat_latent_params)
- return _params_from_latents(latent_params), latent_params, jax_lbfgsb_state
+ latent_params = _update_parameterized_densities(latent_params, step)
+ params = _params_from_latents(latent_params)
+ return step + 1, params, latent_params, jax_lbfgsb_state

  return base.Optimizer(
  init=init_fn,
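The `_update_parameterized_densities` helper introduced above is now invoked on every update, applying the parameterization's per-step `update` hook only to parameterized-density leaves and leaving all other leaves untouched. A self-contained toy version of the same `tree_map`/`is_leaf` pattern (`ToyDensity` and `update` are stand-ins, not invrs-opt types):

    import dataclasses
    from jax import tree_util

    @dataclasses.dataclass
    class ToyDensity:  # stand-in for a parameterized density leaf
        beta: float

    def update(leaf: ToyDensity, step: int) -> ToyDensity:
        return dataclasses.replace(leaf, beta=leaf.beta + step)  # e.g. a step-dependent schedule

    def update_densities(params, step):
        def _update_leaf(leaf):
            return update(leaf, step) if isinstance(leaf, ToyDensity) else leaf
        return tree_util.tree_map(
            _update_leaf, params, is_leaf=lambda x: isinstance(x, ToyDensity)
        )

    params = {"density": ToyDensity(beta=2.0), "thickness": 0.5}
    params = update_densities(params, step=3)  # only the ToyDensity leaf changes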
@@ -438,7 +452,7 @@ def parameterized_lbfgsb(

  def is_converged(state: LbfgsbState) -> jnp.ndarray:
  """Returns `True` if the optimization has converged."""
- return state[2]["converged"]
+ return state[3]["converged"]


  # ------------------------------------------------------------------------------
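Because the step counter is prepended, the converged flag moves from `state[2]` to `state[3]`. A hedged usage sketch of a loop that drives an optimizer from this module and checks convergence; `opt` stands for any `base.Optimizer` built here (e.g. by `parameterized_lbfgsb`), the keyword names in `opt.update` are assumed from the optimizer interface, and `loss_fn`/`initial_params` are toy placeholders:

    import jax
    import jax.numpy as jnp
    from invrs_opt.optimizers import lbfgsb

    def loss_fn(params):
        return jnp.sum(params**2)  # toy objective

    initial_params = jnp.ones((3,))
    state = opt.init(initial_params)  # `opt` is a placeholder optimizer
    for _ in range(100):
        params = opt.params(state)
        value, grad = jax.value_and_grad(loss_fn)(params)
        state = opt.update(grad=grad, value=value, params=params, state=state)
        if lbfgsb.is_converged(state):  # reads state[3]["converged"] in v0.8.1
            break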
invrs_opt/optimizers/wrapped_optax.py CHANGED
@@ -20,7 +20,7 @@ from invrs_opt.parameterization import (
  )

  PyTree = Any
- WrappedOptaxState = Tuple[PyTree, PyTree, PyTree]
+ WrappedOptaxState = Tuple[int, PyTree, PyTree, PyTree]


  def wrapped_optax(opt: optax.GradientTransformation) -> base.Optimizer:
@@ -178,56 +178,19 @@ def parameterized_wrapped_optax(
  if density_parameterization is None:
  density_parameterization = pixel.pixel()

- def _init_latents(params: PyTree) -> PyTree:
- def _leaf_init_latents(leaf: Any) -> Any:
- leaf = _clip(leaf)
- if not _is_density(leaf):
- return leaf
- return density_parameterization.from_density(leaf)
-
- return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
-
- def _params_from_latents(params: PyTree) -> PyTree:
- def _leaf_params_from_latents(leaf: Any) -> Any:
- if not _is_parameterized_density(leaf):
- return leaf
- return density_parameterization.to_density(leaf)
-
- return tree_util.tree_map(
- _leaf_params_from_latents,
- params,
- is_leaf=_is_parameterized_density,
- )
-
- def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
- def _constraint_loss_leaf(
- params: parameterization_base.ParameterizedDensity2DArrayBase,
- ) -> jnp.ndarray:
- constraints = density_parameterization.constraints(params)
- constraints = tree_util.tree_map(
- lambda x: jnp.sum(jnp.maximum(x, 0.0)),
- constraints,
- )
- return jnp.sum(jnp.asarray(constraints))
-
- losses = [0.0] + [
- _constraint_loss_leaf(p)
- for p in tree_util.tree_leaves(
- latent_params, is_leaf=_is_parameterized_density
- )
- if _is_parameterized_density(p)
- ]
- return penalty * jnp.sum(jnp.asarray(losses))
-
  def init_fn(params: PyTree) -> WrappedOptaxState:
  """Initializes the optimization state."""
  latent_params = _init_latents(params)
- params = _params_from_latents(latent_params)
- return params, latent_params, opt.init(latent_params)
+ return (
+ 0, # step
+ _params_from_latents(latent_params), # params
+ latent_params, # latent params
+ opt.init(tree_util.tree_leaves(latent_params)), # opt state
+ )

  def params_fn(state: WrappedOptaxState) -> PyTree:
  """Returns the parameters for the given `state`."""
- params, _, _ = state
+ _, params, _, _ = state
  return params

  def update_fn(
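In v0.8.1 the optax state is initialized over the flat list of pytree leaves (`opt.init(tree_util.tree_leaves(latent_params))` above), and the `update_fn` changes in the next hunks apply the matching flatten → update → `apply_updates` → unflatten sequence, likely so that the optax state is keyed on a plain list of arrays rather than on the custom latent containers. A self-contained sketch of that pattern on a toy pytree (not the invrs-opt latents):

    import jax.numpy as jnp
    import optax
    from jax import tree_util

    params = {"a": jnp.ones(3), "b": {"c": jnp.zeros(2)}}  # toy latent pytree
    grads = tree_util.tree_map(lambda x: 0.1 * jnp.ones_like(x), params)

    opt = optax.adam(1e-2)
    opt_state = opt.init(tree_util.tree_leaves(params))  # optax state over the leaves list

    updates_leaves, opt_state = opt.update(
        updates=tree_util.tree_leaves(grads),
        state=opt_state,
        params=tree_util.tree_leaves(params),
    )
    new_leaves = optax.apply_updates(tree_util.tree_leaves(params), updates_leaves)
    params = tree_util.tree_unflatten(tree_util.tree_structure(params), new_leaves)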
@@ -240,7 +203,8 @@ def parameterized_wrapped_optax(
  """Updates the state."""
  del value, params

- _, latent_params, opt_state = state
+ step, params, latent_params, opt_state = state
+
  _, vjp_fn = jax.vjp(_params_from_latents, latent_params)
  (latent_grad,) = vjp_fn(grad)

@@ -259,13 +223,85 @@ def parameterized_wrapped_optax(
  lambda a, b: a + b, latent_grad, constraint_loss_grad
  )

- updates, opt_state = opt.update(
- updates=latent_grad, state=opt_state, params=latent_params
+ updates_leaves, opt_state = opt.update(
+ updates=tree_util.tree_leaves(latent_grad),
+ state=opt_state,
+ params=tree_util.tree_leaves(latent_params),
+ )
+ latent_params_leaves = optax.apply_updates(
+ params=tree_util.tree_leaves(latent_params),
+ updates=updates_leaves,
+ )
+ latent_params = tree_util.tree_unflatten(
+ treedef=tree_util.tree_structure(latent_params),
+ leaves=latent_params_leaves,
  )
- latent_params = optax.apply_updates(params=latent_params, updates=updates)
+
  latent_params = _clip(latent_params)
+ latent_params = _update_parameterized_densities(latent_params, step + 1)
  params = _params_from_latents(latent_params)
- return params, latent_params, opt_state
+ return (step + 1, params, latent_params, opt_state)
+
+ # -------------------------------------------------------------------------
+ # Functions related to the density parameterization.
+ # -------------------------------------------------------------------------
+
+ def _init_latents(params: PyTree) -> PyTree:
+ def _leaf_init_latents(leaf: Any) -> Any:
+ leaf = _clip(leaf)
+ if not _is_density(leaf):
+ return leaf
+ return density_parameterization.from_density(leaf)
+
+ return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
+
+ def _params_from_latents(params: PyTree) -> PyTree:
+ def _leaf_params_from_latents(leaf: Any) -> Any:
+ if not _is_parameterized_density(leaf):
+ return leaf
+ return density_parameterization.to_density(leaf)
+
+ return tree_util.tree_map(
+ _leaf_params_from_latents,
+ params,
+ is_leaf=_is_parameterized_density,
+ )
+
+ def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
+ def _update_leaf(leaf: Any) -> Any:
+ if not _is_parameterized_density(leaf):
+ return leaf
+ return density_parameterization.update(leaf, step)
+
+ return tree_util.tree_map(
+ _update_leaf,
+ latent_params,
+ is_leaf=_is_parameterized_density,
+ )
+
+ # -------------------------------------------------------------------------
+ # Functions related to the constraints to be minimized.
+ # -------------------------------------------------------------------------
+
+ def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
+ def _constraint_loss_leaf(
+ params: parameterization_base.ParameterizedDensity2DArrayBase,
+ ) -> jnp.ndarray:
+ constraints = density_parameterization.constraints(params)
+ constraints = tree_util.tree_map(
+ lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
+ constraints,
+ )
+ return jnp.sum(jnp.asarray(constraints))
+
+ losses = [0.0] + [
+ _constraint_loss_leaf(p)
+ for p in tree_util.tree_leaves(
+ latent_params, is_leaf=_is_parameterized_density
+ )
+ if _is_parameterized_density(p)
+ ]
+ return penalty * jnp.sum(jnp.asarray(losses))

  return base.Optimizer(init=init_fn, params=params_fn, update=update_fn)

invrs_opt/parameterization/base.py CHANGED
@@ -44,6 +44,13 @@ class ConstraintsFn(Protocol):
  ...


+ class UpdateFn(Protocol):
+ """Performs the required update of a parameterized density for the given step."""
+
+ def __call__(self, params: PyTree, step: int) -> PyTree:
+ ...
+
+
  @dataclasses.dataclass
  class Density2DParameterization:
  """Stores `(from_density, to_density, constraints)` function triple."""
@@ -51,6 +58,7 @@ class Density2DParameterization:
  from_density: FromDensityFn
  to_density: ToDensityFn
  constraints: ConstraintsFn
+ update: UpdateFn


  @dataclasses.dataclass
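With the added `update` field, `Density2DParameterization` now bundles four functions, and each built-in parameterization below supplies a no-op `update_fn`. A hedged sketch of constructing a custom parameterization after this change; the `my_*` functions are placeholders, not library code:

    import jax.numpy as jnp
    from invrs_opt.parameterization import base

    def my_from_density(density):
        return density  # placeholder: real code returns a latent representation

    def my_to_density(params):
        return params  # placeholder: real code reconstructs the density

    def my_constraints(params):
        del params
        return jnp.asarray(0.0)  # no constraints in this sketch

    def my_update(params, step):
        del step
        return params  # no per-step schedule in this sketch

    parameterization = base.Density2DParameterization(
        from_density=my_from_density,
        to_density=my_to_density,
        constraints=my_constraints,
        update=my_update,  # required as of this change
    )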
invrs_opt/parameterization/filter_project.py CHANGED
@@ -85,8 +85,14 @@ def filter_project(beta: float) -> base.Density2DParameterization:
  del params
  return jnp.asarray(0.0)

+ def update_fn(params: FilterAndProjectParams, step: int) -> FilterAndProjectParams:
+ """Perform updates to `params` required for the given `step`."""
+ del step
+ return params
+
  return base.Density2DParameterization(
  to_density=to_density_fn,
  from_density=from_density_fn,
  constraints=constraints_fn,
+ update=update_fn,
  )
invrs_opt/parameterization/gaussian_levelset.py CHANGED
@@ -214,7 +214,7 @@ def gaussian_levelset(
  return params, state

  state = init_optimizer.init(params)
- params, state = jax.lax.fori_loop(
+ params, _ = jax.lax.fori_loop(
  0, init_steps, body_fun=step_fn, init_val=(params, state)
  )

@@ -269,10 +269,16 @@ def gaussian_levelset(
  )
  return constraints / length_scale**2

+ def update_fn(params: GaussianLevelsetParams, step: int) -> GaussianLevelsetParams:
+ """Perform updates to `params` required for the given `step`."""
+ del step
+ return params
+
  return base.Density2DParameterization(
  to_density=to_density_fn,
  from_density=from_density_fn,
  constraints=constraints_fn,
+ update=update_fn,
  )


invrs_opt/parameterization/pixel.py CHANGED
@@ -38,8 +38,13 @@ def pixel() -> base.Density2DParameterization:
  del params
  return jnp.asarray(0.0)

+ def update_fn(params: PixelParams, step: int) -> PixelParams:
+ del step
+ return params
+
  return base.Density2DParameterization(
  from_density=from_density_fn,
  to_density=to_density_fn,
  constraints=constraints_fn,
+ update=update_fn,
  )
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: invrs_opt
- Version: 0.7.2
+ Version: 0.8.1
  Summary: Algorithms for inverse design
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -533,7 +533,7 @@ Requires-Dist: pytest-cov; extra == "tests"
  Requires-Dist: pytest-subtests; extra == "tests"

  # invrs-opt - Optimization algorithms for inverse design
- `v0.7.2`
+ `v0.8.1`

  ## Overview

@@ -0,0 +1,20 @@
+ invrs_opt/__init__.py,sha256=tjbabqQKp5y1I5eqDa7onsy7E58P9nTk_Ifn12oMfZM,585
+ invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ invrs_opt/experimental/client.py,sha256=t4XxnditYbM9DWZeyBPj0Sa2acvkikT0ybhUdmH2r-Y,4852
+ invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
+ invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ invrs_opt/optimizers/base.py,sha256=-wNwH0J475y8FzB5aLAkc_1602LvYeF4Hddr9OiBkDY,1276
+ invrs_opt/optimizers/lbfgsb.py,sha256=2NcyFllUM5CrjOZLEBocYAlCrkWt1fiRGJrxg05waog,35149
+ invrs_opt/optimizers/wrapped_optax.py,sha256=c4iSsXHF2cJZaZhdxLpyRy26q5IxBTo4-6sY8IOLT38,13259
+ invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ invrs_opt/parameterization/base.py,sha256=k2VXZorMnkmaIWWY_G2CvS0NtxYojfvKlPccDGZyGNA,3779
+ invrs_opt/parameterization/filter_project.py,sha256=D2w8xrg34V8ysIbbr1RPvegM5WoLdz8QKUCGi80ieOI,3466
+ invrs_opt/parameterization/gaussian_levelset.py,sha256=DllvpkBpVxuDFQFu941f8v_Xh2D0m9Eh-JxBH5afNHU,25221
+ invrs_opt/parameterization/pixel.py,sha256=yWdGpf0x6Om_-7CYOcHTgNK9UF-XGAPM7GbOBNijnBw,1305
+ invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
+ invrs_opt-0.8.1.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
+ invrs_opt-0.8.1.dist-info/METADATA,sha256=4qLsuDLp3pNn5C1TTPnF0qwWps7MeN_Uc3rKsBLbPrk,32633
+ invrs_opt-0.8.1.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
+ invrs_opt-0.8.1.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
+ invrs_opt-0.8.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (72.1.0)
+ Generator: setuptools (72.2.0)
  Root-Is-Purelib: true
  Tag: py3-none-any

@@ -1,20 +0,0 @@
- invrs_opt/__init__.py,sha256=_hiwGthyPapJUx8-nP7JrJYEFX3mJYTsdUh_X0MyhEQ,585
- invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- invrs_opt/experimental/client.py,sha256=t4XxnditYbM9DWZeyBPj0Sa2acvkikT0ybhUdmH2r-Y,4852
- invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
- invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- invrs_opt/optimizers/base.py,sha256=-wNwH0J475y8FzB5aLAkc_1602LvYeF4Hddr9OiBkDY,1276
- invrs_opt/optimizers/lbfgsb.py,sha256=d7i02NZZ3yYdJg7wkERDMPpdBD3GaRPhmQexGrZPz_Y,34597
- invrs_opt/optimizers/wrapped_optax.py,sha256=L836gwzWbwxPNWh8Y7PSgFHV4PZluSklnbh3BR5djlc,11859
- invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- invrs_opt/parameterization/base.py,sha256=QTnhOfMYbDchZOFzk9graryMd6rYHlyd7E2T1TtucB8,3570
- invrs_opt/parameterization/filter_project.py,sha256=XCPqQ2ECv7DDTLRtVGJePfnjKYB2XndI6DssSr-4MZw,3239
- invrs_opt/parameterization/gaussian_levelset.py,sha256=IGdQl3XMEHYNxNPSoKmUD7ZJu_fImIZTqqqmvilfmag,24998
- invrs_opt/parameterization/pixel.py,sha256=4qCYDUCcFPr8W94whX5YFllrXgyZbPBVIJBf8m_Dv4k,1173
- invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
- invrs_opt-0.7.2.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
- invrs_opt-0.7.2.dist-info/METADATA,sha256=jx_MxaNkrJRUBRubRlDIytDRznP_1P9sWRHrqihaKi8,32633
- invrs_opt-0.7.2.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
- invrs_opt-0.7.2.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
- invrs_opt-0.7.2.dist-info/RECORD,,