invrs-opt 0.10.5.tar.gz → 0.10.7.tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (25)
  1. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/PKG-INFO +3 -2
  2. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/pyproject.toml +1 -1
  3. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt/__init__.py +1 -1
  4. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt/optimizers/lbfgsb.py +3 -4
  5. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt/optimizers/wrapped_optax.py +3 -4
  6. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt/parameterization/base.py +31 -0
  7. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt/parameterization/gaussian_levelset.py +2 -39
  8. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt/parameterization/transforms.py +37 -0
  9. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt.egg-info/PKG-INFO +3 -2
  10. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/LICENSE +0 -0
  11. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/README.md +0 -0
  12. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/setup.cfg +0 -0
  13. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt/experimental/__init__.py +0 -0
  14. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt/experimental/client.py +0 -0
  15. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt/experimental/labels.py +0 -0
  16. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt/optimizers/__init__.py +0 -0
  17. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt/optimizers/base.py +0 -0
  18. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt/parameterization/__init__.py +0 -0
  19. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt/parameterization/filter_project.py +0 -0
  20. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt/parameterization/pixel.py +0 -0
  21. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt/py.typed +0 -0
  22. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt.egg-info/SOURCES.txt +0 -0
  23. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt.egg-info/dependency_links.txt +0 -0
  24. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt.egg-info/requires.txt +0 -0
  25. {invrs_opt-0.10.5 → invrs_opt-0.10.7}/src/invrs_opt.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: invrs_opt
3
- Version: 0.10.5
3
+ Version: 0.10.7
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -531,6 +531,7 @@ Requires-Dist: bump-my-version; extra == "dev"
531
531
  Requires-Dist: darglint; extra == "dev"
532
532
  Requires-Dist: mypy; extra == "dev"
533
533
  Requires-Dist: pre-commit; extra == "dev"
534
+ Dynamic: license-file
534
535
 
535
536
  # invrs-opt - Optimization algorithms for inverse design
536
537
  ![Continuous integration](https://github.com/invrs-io/opt/actions/workflows/build-ci.yml/badge.svg)
@@ -1,7 +1,7 @@
1
1
  [project]
2
2
 
3
3
  name = "invrs_opt"
4
- version = "v0.10.5"
4
+ version = "v0.10.7"
5
5
  description = "Algorithms for inverse design"
6
6
  keywords = ["topology", "optimization", "jax", "inverse design"]
7
7
  readme = "README.md"
@@ -3,7 +3,7 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.10.5"
6
+ __version__ = "v0.10.7"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
9
  from invrs_opt import parameterization as parameterization
@@ -380,10 +380,9 @@ def parameterized_lbfgsb(
380
380
  _, vjp_fn = jax.vjp(_params_from_latents, latents)
381
381
  (latents_grad,) = vjp_fn(grad)
382
382
 
383
- if not (
384
- tree_util.tree_structure(latents_grad)
385
- == tree_util.tree_structure(latents) # type: ignore[operator]
386
- ):
383
+ treedef = tree_util.tree_structure(latents_grad)
384
+ expected_treedef = tree_util.tree_structure(latents)
385
+ if not treedef == expected_treedef: # type: ignore[operator]
387
386
  raise ValueError(
388
387
  f"Tree structure of `latents_grad` was different than expected, got \n"
389
388
  f"{tree_util.tree_structure(latents_grad)} but expected \n"
@@ -218,10 +218,9 @@ def parameterized_wrapped_optax(
218
218
  _, vjp_fn = jax.vjp(_params_from_latents, latents)
219
219
  (latents_grad,) = vjp_fn(grad)
220
220
 
221
- if not (
222
- tree_util.tree_structure(latents_grad)
223
- == tree_util.tree_structure(latents) # type: ignore[operator]
224
- ):
221
+ treedef = tree_util.tree_structure(latents_grad)
222
+ expected_treedef = tree_util.tree_structure(latents)
223
+ if not treedef == expected_treedef: # type: ignore[operator]
225
224
  raise ValueError(
226
225
  f"Tree structure of `latents_grad` was different than expected, got \n"
227
226
  f"{tree_util.tree_structure(latents_grad)} but expected \n"
@@ -124,6 +124,22 @@ class Density2DMetadata:
124
124
  self.periodic = tuple(self.periodic)
125
125
  self.symmetries = tuple(self.symmetries)
126
126
 
127
+ def __eq__(self, other: Any) -> bool:
128
+ if not isinstance(other, Density2DMetadata):
129
+ return False
130
+ if not (
131
+ self.lower_bound == other.lower_bound
132
+ and self.upper_bound == other.upper_bound
133
+ and _arrays_equal_or_both_none(self.fixed_solid, other.fixed_solid)
134
+ and _arrays_equal_or_both_none(self.fixed_void, other.fixed_void)
135
+ and self.minimum_width == other.minimum_width
136
+ and self.minimum_spacing == other.minimum_spacing
137
+ and self.periodic == other.periodic
138
+ and self.symmetries == other.symmetries
139
+ ):
140
+ return False
141
+ return True
142
+
127
143
  @classmethod
128
144
  def from_density(self, density: types.Density2DArray) -> "Density2DMetadata":
129
145
  density_metadata_dict = dataclasses.asdict(density)
@@ -131,6 +147,21 @@ class Density2DMetadata:
131
147
  return Density2DMetadata(**density_metadata_dict)
132
148
 
133
149
 
150
+ def _arrays_equal_or_both_none(a: Optional[Array], b: Optional[Array]) -> bool:
151
+ """Return `True` if `a` and `b` are equal arrays or both `None`."""
152
+ if (a is None, b is None) not in ((True, True), (False, False)):
153
+ return False
154
+ if a is None and b is None:
155
+ return True
156
+ assert isinstance(a, onp.ndarray)
157
+ assert isinstance(b, onp.ndarray)
158
+ if a.dtype != b.dtype:
159
+ return False
160
+ if a.shape != b.shape:
161
+ return False
162
+ return bool(onp.all(a == b))
163
+
164
+
134
165
  def _flatten_density_2d_metadata(
135
166
  metadata: Density2DMetadata,
136
167
  ) -> Tuple[
@@ -388,7 +388,7 @@ def _phi_from_params(
388
388
  assert array.shape[-2] % s_factor == 0
389
389
  assert array.shape[-1] % s_factor == 0
390
390
  array = symmetry.symmetrize(array, tuple(example_density.symmetries))
391
- return array
391
+ return jnp.asarray(array)
392
392
 
393
393
 
394
394
  # -----------------------------------------------------------------------------
@@ -619,49 +619,12 @@ def _levelset_threshold(
619
619
  ) -> jnp.ndarray:
620
620
  """Thresholds a level set function `phi`."""
621
621
  if mask_gradient:
622
- interface = _interface_pixels(phi, periodic)
622
+ interface = transforms.interface_pixels(phi, periodic)
623
623
  phi = jnp.where(interface, phi, jax.lax.stop_gradient(phi))
624
624
  thresholded = (phi > 0).astype(float) + (phi - jax.lax.stop_gradient(phi))
625
625
  return thresholded
626
626
 
627
627
 
628
- def _interface_pixels(phi: jnp.ndarray, periodic: Tuple[bool, bool]) -> jnp.ndarray:
629
- """Identifies interface pixels of a level set function `phi`."""
630
- batch_shape = phi.shape[:-2]
631
- phi = phi.reshape((-1,) + phi.shape[-2:])
632
-
633
- pad_mode = (
634
- "wrap" if periodic[0] else "edge",
635
- "wrap" if periodic[1] else "edge",
636
- )
637
- pad_width = ((1, 1), (1, 1))
638
-
639
- kernel = jnp.asarray([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
640
-
641
- solid = phi > 0
642
- void = ~solid
643
-
644
- solid_padded = transforms.pad2d(solid, pad_width, pad_mode)
645
- num_solid_adjacent = transforms.conv(
646
- x=solid_padded[:, jnp.newaxis, :, :].astype(float),
647
- kernel=kernel[jnp.newaxis, jnp.newaxis, :, :],
648
- padding="VALID",
649
- )
650
- num_solid_adjacent = jnp.squeeze(num_solid_adjacent, axis=1)
651
-
652
- void_padded = transforms.pad2d(void, pad_width, pad_mode)
653
- num_void_adjacent = transforms.conv(
654
- x=void_padded[:, jnp.newaxis, :, :].astype(float),
655
- kernel=kernel[jnp.newaxis, jnp.newaxis, :, :],
656
- padding="VALID",
657
- )
658
- num_void_adjacent = jnp.squeeze(num_void_adjacent, axis=1)
659
-
660
- interface = solid & (num_void_adjacent > 0) | void & (num_solid_adjacent > 0)
661
-
662
- return interface.reshape(batch_shape + interface.shape[-2:])
663
-
664
-
665
628
  def _downsample_spatial_dims(x: jnp.ndarray, downsample_factor: int) -> jnp.ndarray:
666
629
  """Downsamples the two trailing axes of `x` by `downsample_factor`."""
667
630
  shape = x.shape[:-2] + (
@@ -231,3 +231,40 @@ def box_downsample(x: jnp.ndarray, shape: Tuple[int, ...]) -> jnp.ndarray:
231
231
  axes = list(range(1, 2 * x.ndim, 2))
232
232
  x = x.reshape(shape)
233
233
  return jnp.mean(x, axis=axes)
234
+
235
+
236
+ def interface_pixels(phi: jnp.ndarray, periodic: Tuple[bool, bool]) -> jnp.ndarray:
237
+ """Identifies interface pixels of a level set function `phi`."""
238
+ batch_shape = phi.shape[:-2]
239
+ phi = phi.reshape((-1,) + phi.shape[-2:])
240
+
241
+ pad_mode = (
242
+ "wrap" if periodic[0] else "edge",
243
+ "wrap" if periodic[1] else "edge",
244
+ )
245
+ pad_width = ((1, 1), (1, 1))
246
+
247
+ kernel = jnp.asarray([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
248
+
249
+ solid = phi > 0
250
+ void = ~solid
251
+
252
+ solid_padded = pad2d(solid, pad_width, pad_mode)
253
+ num_solid_adjacent = conv(
254
+ x=solid_padded[:, jnp.newaxis, :, :].astype(float),
255
+ kernel=kernel[jnp.newaxis, jnp.newaxis, :, :],
256
+ padding="VALID",
257
+ )
258
+ num_solid_adjacent = jnp.squeeze(num_solid_adjacent, axis=1)
259
+
260
+ void_padded = pad2d(void, pad_width, pad_mode)
261
+ num_void_adjacent = conv(
262
+ x=void_padded[:, jnp.newaxis, :, :].astype(float),
263
+ kernel=kernel[jnp.newaxis, jnp.newaxis, :, :],
264
+ padding="VALID",
265
+ )
266
+ num_void_adjacent = jnp.squeeze(num_void_adjacent, axis=1)
267
+
268
+ interface = solid & (num_void_adjacent > 0) | void & (num_solid_adjacent > 0)
269
+
270
+ return interface.reshape(batch_shape + interface.shape[-2:])
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: invrs_opt
3
- Version: 0.10.5
3
+ Version: 0.10.7
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -531,6 +531,7 @@ Requires-Dist: bump-my-version; extra == "dev"
531
531
  Requires-Dist: darglint; extra == "dev"
532
532
  Requires-Dist: mypy; extra == "dev"
533
533
  Requires-Dist: pre-commit; extra == "dev"
534
+ Dynamic: license-file
534
535
 
535
536
  # invrs-opt - Optimization algorithms for inverse design
536
537
  ![Continuous integration](https://github.com/invrs-io/opt/actions/workflows/build-ci.yml/badge.svg)
File without changes
File without changes
File without changes