invrs-opt 0.9.4__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invrs_opt/__init__.py CHANGED
@@ -3,7 +3,7 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.9.4"
6
+ __version__ = "v0.10.1"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
9
  from invrs_opt import parameterization as parameterization
@@ -132,7 +132,11 @@ def optimizer_client(
132
132
  response = json.loads(get_response.text)
133
133
  return json_utils.pytree_from_json(response[labels.PARAMS])
134
134
 
135
- return base.Optimizer(init=init_fn, update=update_fn, params=params_fn)
135
+ return base.Optimizer(
136
+ init=init_fn,
137
+ update=update_fn, # type: ignore[arg-type]
138
+ params=params_fn,
139
+ )
136
140
 
137
141
 
138
142
  # -----------------------------------------------------------------------------
@@ -7,6 +7,7 @@ import dataclasses
7
7
  import inspect
8
8
  from typing import Any, Protocol
9
9
 
10
+ import jax.numpy as jnp
10
11
  import optax # type: ignore[import-untyped]
11
12
  from totypes import json_utils
12
13
 
@@ -34,7 +35,7 @@ class UpdateFn(Protocol):
34
35
  self,
35
36
  *,
36
37
  grad: PyTree,
37
- value: float,
38
+ value: jnp.ndarray,
38
39
  params: PyTree,
39
40
  state: PyTree,
40
41
  ) -> PyTree:
@@ -317,7 +317,10 @@ def parameterized_lbfgsb(
317
317
  latent_params = _init_latents(params)
318
318
  metadata, latents = param_base.partition_density_metadata(latent_params)
319
319
  latents, jax_lbfgsb_state = jax.pure_callback(
320
- _init_state_pure, _example_state(latents, maxcor), latents
320
+ _init_state_pure,
321
+ _example_state(latents, maxcor),
322
+ latents,
323
+ vmap_method="sequential",
321
324
  )
322
325
  latent_params = param_base.combine_density_metadata(metadata, latents)
323
326
  return (
@@ -335,7 +338,7 @@ def parameterized_lbfgsb(
335
338
  def update_fn(
336
339
  *,
337
340
  grad: PyTree,
338
- value: float,
341
+ value: jnp.ndarray,
339
342
  params: PyTree,
340
343
  state: LbfgsbState,
341
344
  ) -> LbfgsbState:
@@ -346,15 +349,18 @@ def parameterized_lbfgsb(
346
349
  flat_latent_grad: PyTree,
347
350
  value: jnp.ndarray,
348
351
  jax_lbfgsb_state: JaxLbfgsbDict,
349
- ) -> Tuple[PyTree, NumpyLbfgsbDict]:
352
+ ) -> Tuple[NDArray, NumpyLbfgsbDict]:
350
353
  assert onp.size(value) == 1
351
354
  scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
355
+ flat_latent_params = scipy_lbfgsb_state.x.copy()
352
356
  scipy_lbfgsb_state.update(
353
357
  grad=onp.array(flat_latent_grad, dtype=onp.float64),
354
358
  value=onp.array(value, dtype=onp.float64),
355
359
  )
356
- flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
357
- return flat_latent_params, scipy_lbfgsb_state.to_dict()
360
+ updated_flat_latent_params = scipy_lbfgsb_state.x
361
+ flat_latent_updates: NDArray
362
+ flat_latent_updates = updated_flat_latent_params - flat_latent_params
363
+ return flat_latent_updates, scipy_lbfgsb_state.to_dict()
358
364
 
359
365
  step, _, latent_params, jax_lbfgsb_state = state
360
366
  metadata, latents = param_base.partition_density_metadata(latent_params)
@@ -395,16 +401,22 @@ def parameterized_lbfgsb(
395
401
  latents_grad
396
402
  ) # type: ignore[no-untyped-call]
397
403
 
398
- flat_latents, jax_lbfgsb_state = jax.pure_callback(
404
+ flat_latent_updates, jax_lbfgsb_state = jax.pure_callback(
399
405
  _update_pure,
400
406
  (flat_latents_grad, jax_lbfgsb_state),
401
407
  flat_latents_grad,
402
408
  value,
403
409
  jax_lbfgsb_state,
410
+ vmap_method="sequential",
404
411
  )
405
- latents = unflatten_fn(flat_latents)
406
- latent_params = param_base.combine_density_metadata(metadata, latents)
407
- latent_params = _update_parameterized_densities(latent_params, step)
412
+ latent_updates = unflatten_fn(flat_latent_updates)
413
+ latent_params = _apply_updates(
414
+ params=latent_params,
415
+ updates=param_base.combine_density_metadata(metadata, latent_updates),
416
+ value=value,
417
+ step=step,
418
+ )
419
+ latent_params = _clip(latent_params)
408
420
  params = _params_from_latent_params(latent_params)
409
421
  return step + 1, params, latent_params, jax_lbfgsb_state
410
422
 
@@ -433,15 +445,24 @@ def parameterized_lbfgsb(
433
445
  is_leaf=_is_parameterized_density,
434
446
  )
435
447
 
436
- def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
437
- def _update_leaf(leaf: Any) -> Any:
438
- if not _is_parameterized_density(leaf):
439
- return leaf
440
- return density_parameterization.update(leaf, step)
448
+ def _apply_updates(
449
+ params: PyTree,
450
+ updates: PyTree,
451
+ value: jnp.ndarray,
452
+ step: int,
453
+ ) -> PyTree:
454
+ def _leaf_apply_updates(update: Any, leaf: Any) -> Any:
455
+ if _is_parameterized_density(leaf):
456
+ return density_parameterization.update(
457
+ params=leaf, updates=update, value=value, step=step
458
+ )
459
+ else:
460
+ return optax.apply_updates(params=leaf, updates=update)
441
461
 
442
462
  return tree_util.tree_map(
443
- _update_leaf,
444
- latent_params,
463
+ _leaf_apply_updates,
464
+ updates,
465
+ params,
445
466
  is_leaf=_is_parameterized_density,
446
467
  )
447
468
 
@@ -197,12 +197,12 @@ def parameterized_wrapped_optax(
197
197
  def update_fn(
198
198
  *,
199
199
  grad: PyTree,
200
- value: float,
200
+ value: jnp.ndarray,
201
201
  params: PyTree,
202
202
  state: WrappedOptaxState,
203
203
  ) -> WrappedOptaxState:
204
204
  """Updates the state."""
205
- del value, params
205
+ del params
206
206
 
207
207
  step, params, latent_params, opt_state = state
208
208
  metadata, latents = param_base.partition_density_metadata(latent_params)
@@ -233,12 +233,14 @@ def parameterized_wrapped_optax(
233
233
  lambda a, b: a + b, latents_grad, constraint_loss_grad
234
234
  )
235
235
 
236
- updates, opt_state = opt.update(latents_grad, state=opt_state, params=latents)
237
- latents = optax.apply_updates(params=latents, updates=updates)
238
-
239
- latent_params = param_base.combine_density_metadata(metadata, latents)
236
+ latent_updates, opt_state = opt.update(latents_grad, opt_state, params=latents)
237
+ latent_params = _apply_updates(
238
+ params=latent_params,
239
+ updates=param_base.combine_density_metadata(metadata, latent_updates),
240
+ value=value,
241
+ step=step,
242
+ )
240
243
  latent_params = _clip(latent_params)
241
- latent_params = _update_parameterized_densities(latent_params, step + 1)
242
244
  params = _params_from_latent_params(latent_params)
243
245
  return (step + 1, params, latent_params, opt_state)
244
246
 
@@ -267,15 +269,24 @@ def parameterized_wrapped_optax(
267
269
  is_leaf=_is_parameterized_density,
268
270
  )
269
271
 
270
- def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
271
- def _update_leaf(leaf: Any) -> Any:
272
- if not _is_parameterized_density(leaf):
273
- return leaf
274
- return density_parameterization.update(leaf, step)
272
+ def _apply_updates(
273
+ params: PyTree,
274
+ updates: PyTree,
275
+ value: jnp.ndarray,
276
+ step: int,
277
+ ) -> PyTree:
278
+ def _leaf_apply_updates(update: Any, leaf: Any) -> Any:
279
+ if _is_parameterized_density(leaf):
280
+ return density_parameterization.update(
281
+ params=leaf, updates=update, value=value, step=step
282
+ )
283
+ else:
284
+ return optax.apply_updates(params=leaf, updates=update)
275
285
 
276
286
  return tree_util.tree_map(
277
- _update_leaf,
278
- latent_params,
287
+ _leaf_apply_updates,
288
+ updates,
289
+ params,
279
290
  is_leaf=_is_parameterized_density,
280
291
  )
281
292
 
@@ -97,7 +97,13 @@ class ConstraintsFn(Protocol):
97
97
  class UpdateFn(Protocol):
98
98
  """Performs the required update of a parameterized density for the given step."""
99
99
 
100
- def __call__(self, params: PyTree, step: int) -> PyTree:
100
+ def __call__(
101
+ self,
102
+ params: PyTree,
103
+ updates: PyTree,
104
+ value: jnp.ndarray,
105
+ step: int,
106
+ ) -> PyTree:
101
107
  ...
102
108
 
103
109
 
@@ -115,10 +115,20 @@ def filter_project(beta: float) -> base.Density2DParameterization:
115
115
  del params
116
116
  return jnp.asarray(0.0)
117
117
 
118
- def update_fn(params: FilterProjectParams, step: int) -> FilterProjectParams:
118
+ def update_fn(
119
+ params: FilterProjectParams,
120
+ updates: FilterProjectParams,
121
+ value: jnp.ndarray,
122
+ step: int,
123
+ ) -> FilterProjectParams:
119
124
  """Perform updates to `params` required for the given `step`."""
120
- del step
121
- return params
125
+ del step, value
126
+ return FilterProjectParams(
127
+ latents=tree_util.tree_map(
128
+ lambda a, b: a + b, params.latents, updates.latents
129
+ ),
130
+ metadata=params.metadata,
131
+ )
122
132
 
123
133
  return base.Density2DParameterization(
124
134
  to_density=to_density_fn,
@@ -229,10 +229,20 @@ def gaussian_levelset(
229
229
  pad_pixels=pad_pixels,
230
230
  )
231
231
 
232
- def update_fn(params: GaussianLevelsetParams, step: int) -> GaussianLevelsetParams:
232
+ def update_fn(
233
+ params: GaussianLevelsetParams,
234
+ updates: GaussianLevelsetParams,
235
+ value: jnp.ndarray,
236
+ step: int,
237
+ ) -> GaussianLevelsetParams:
233
238
  """Perform updates to `params` required for the given `step`."""
234
- del step
235
- return params
239
+ del step, value
240
+ return GaussianLevelsetParams(
241
+ latents=tree_util.tree_map(
242
+ lambda a, b: a + b, params.latents, updates.latents
243
+ ),
244
+ metadata=params.metadata,
245
+ )
236
246
 
237
247
  return base.Density2DParameterization(
238
248
  to_density=to_density_fn,
@@ -509,11 +519,13 @@ def _levelset_constraints(
509
519
  )
510
520
 
511
521
  d = minimum_length_scale * length_scale_constraint_factor
512
- length_scale_constraint = (
513
- jnp.abs(phi_vv) / (jnp.pi / d * jnp.abs(phi) + beta * phi_v) - jnp.pi / d
514
- )
522
+ denom = jnp.pi / d * jnp.abs(phi) + beta * phi_v
523
+ denom_safe = jnp.where(jnp.isclose(phi_vv, 0.0), 1.0, denom)
524
+ length_scale_constraint = jnp.abs(phi_vv) / denom_safe - jnp.pi / d
525
+
526
+ curvature_denom_safe = jnp.where(jnp.isclose(phi_v, 0.0), 1.0, phi)
515
527
  curvature_constraint = (
516
- jnp.abs(inverse_radius * jnp.arctan(phi_v / phi)) - jnp.pi / d
528
+ jnp.abs(inverse_radius * jnp.arctan(phi_v / curvature_denom_safe)) - jnp.pi / d
517
529
  )
518
530
 
519
531
  # Downsample so that constraints shape matches the density shape.
@@ -52,9 +52,20 @@ def pixel() -> base.Density2DParameterization:
52
52
  del params
53
53
  return jnp.asarray(0.0)
54
54
 
55
- def update_fn(params: PixelParams, step: int) -> PixelParams:
56
- del step
57
- return params
55
+ def update_fn(
56
+ params: PixelParams,
57
+ updates: PixelParams,
58
+ value: jnp.ndarray,
59
+ step: int,
60
+ ) -> PixelParams:
61
+ """Perform updates to `params` required for the given `step`."""
62
+ del step, value
63
+ return PixelParams(
64
+ latents=tree_util.tree_map(
65
+ lambda a, b: a + b, params.latents, updates.latents
66
+ ),
67
+ metadata=params.metadata,
68
+ )
58
69
 
59
70
  return base.Density2DParameterization(
60
71
  from_density=from_density_fn,
@@ -1,10 +1,10 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: invrs_opt
3
- Version: 0.9.4
3
+ Version: 0.10.1
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
7
- License: GNU LESSER GENERAL PUBLIC LICENSE
7
+ License: GNU LESSER GENERAL PUBLIC LICENSE
8
8
  Version 2.1, February 1999
9
9
 
10
10
  Copyright (C) 1991, 1999 Free Software Foundation, Inc.
@@ -513,7 +513,7 @@ Keywords: topology,optimization,jax,inverse design
513
513
  Requires-Python: >=3.7
514
514
  Description-Content-Type: text/markdown
515
515
  License-File: LICENSE
516
- Requires-Dist: jax<=0.4.35
516
+ Requires-Dist: jax>=0.4.35
517
517
  Requires-Dist: jaxlib
518
518
  Requires-Dist: numpy
519
519
  Requires-Dist: requests
@@ -521,19 +521,19 @@ Requires-Dist: optax
521
521
  Requires-Dist: scipy
522
522
  Requires-Dist: totypes
523
523
  Requires-Dist: types-requests
524
- Provides-Extra: dev
525
- Requires-Dist: bump-my-version; extra == "dev"
526
- Requires-Dist: darglint; extra == "dev"
527
- Requires-Dist: mypy; extra == "dev"
528
- Requires-Dist: pre-commit; extra == "dev"
529
524
  Provides-Extra: tests
530
525
  Requires-Dist: parameterized; extra == "tests"
531
526
  Requires-Dist: pytest; extra == "tests"
532
527
  Requires-Dist: pytest-cov; extra == "tests"
533
528
  Requires-Dist: pytest-subtests; extra == "tests"
529
+ Provides-Extra: dev
530
+ Requires-Dist: bump-my-version; extra == "dev"
531
+ Requires-Dist: darglint; extra == "dev"
532
+ Requires-Dist: mypy; extra == "dev"
533
+ Requires-Dist: pre-commit; extra == "dev"
534
534
 
535
535
  # invrs-opt - Optimization algorithms for inverse design
536
- `v0.9.4`
536
+ `v0.10.1`
537
537
 
538
538
  ## Overview
539
539
 
@@ -0,0 +1,20 @@
1
+ invrs_opt/__init__.py,sha256=bd6NWb7l6z0v6kQif3vfX3xgxdlZrp9L9EFt_Ds20UU,586
2
+ invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ invrs_opt/experimental/client.py,sha256=tbtH13FrA65XmTZfTO71CxJ78jeAEj3Zf85R-MTwbiU,4909
5
+ invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
6
+ invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
+ invrs_opt/optimizers/base.py,sha256=uFfkN2LwWzAtwh6ktwWNy2iHNOY-sW3JzI46iSFkgok,1306
8
+ invrs_opt/optimizers/lbfgsb.py,sha256=kIluuCL1TjkNrdiwqsHEzqz7PEnPa0txpGa3rAw1GXU,36983
9
+ invrs_opt/optimizers/wrapped_optax.py,sha256=781-8v_TlHsGaQF9Se9_iOEvtOLOr-BesTLudYarSlg,13685
10
+ invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ invrs_opt/parameterization/base.py,sha256=jSwrEO86lGkYQG5gWsHvcIMWpZnnbdiKpn--2qaU02g,5362
12
+ invrs_opt/parameterization/filter_project.py,sha256=XL3HTEBLrF-q_75TjhOWLNdfUOSEEjKcoM7Qj844QpQ,4590
13
+ invrs_opt/parameterization/gaussian_levelset.py,sha256=-6foekLTFoZDtMKuoMEvdxMJt0_zTxrKNJo0Vn-Rv80,26073
14
+ invrs_opt/parameterization/pixel.py,sha256=YWkyBhfYtzI8cQ-M90PAZqRAbabwVaUh0UiYIGegQHI,1955
15
+ invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
16
+ invrs_opt-0.10.1.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
+ invrs_opt-0.10.1.dist-info/METADATA,sha256=gCJPUnv5b3erC7R4Fb-Hlhbo9uAjYK0fgTYL0qFg7Jg,32661
18
+ invrs_opt-0.10.1.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
19
+ invrs_opt-0.10.1.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
20
+ invrs_opt-0.10.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.2.0)
2
+ Generator: setuptools (75.6.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,20 +0,0 @@
1
- invrs_opt/__init__.py,sha256=bVIs0NxPxNuRUOypBIE68qx-SA1lFCmqmo5cqvRpxmU,585
2
- invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- invrs_opt/experimental/client.py,sha256=t4XxnditYbM9DWZeyBPj0Sa2acvkikT0ybhUdmH2r-Y,4852
5
- invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
6
- invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
- invrs_opt/optimizers/base.py,sha256=-wNwH0J475y8FzB5aLAkc_1602LvYeF4Hddr9OiBkDY,1276
8
- invrs_opt/optimizers/lbfgsb.py,sha256=9DrmyCj4Ny04NFVRUTz3pJypbw_j5Gw4wpKfe0WKEv4,36336
9
- invrs_opt/optimizers/wrapped_optax.py,sha256=VXdCteT2kumqhP81l3p6QiEqwBffoUuJ3UjrAyX5ToA,13468
10
- invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- invrs_opt/parameterization/base.py,sha256=BObzbz6efT2nBjib0_5BSdkCmFi2f0mcZ9VJYpDzO6Q,5278
12
- invrs_opt/parameterization/filter_project.py,sha256=7Jb8JVENmBTdx3-XmI-VRm4aMjxg_Wtin8tMKKKxWvQ,4309
13
- invrs_opt/parameterization/gaussian_levelset.py,sha256=Ka4hW_OLxUaIPHQsyIOlryG7i1mC-LTIgiQQhCPwHPk,25626
14
- invrs_opt/parameterization/pixel.py,sha256=AwC4GBNNOysdICvYHv_D2tZdqJmYiRzOUZNq_-R9Z70,1617
15
- invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
16
- invrs_opt-0.9.4.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
- invrs_opt-0.9.4.dist-info/METADATA,sha256=SOt6aECrbTQNI0QuznoPrWKzUdMeahaST0G4Q7sccxE,32641
18
- invrs_opt-0.9.4.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
19
- invrs_opt-0.9.4.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
20
- invrs_opt-0.9.4.dist-info/RECORD,,