invrs-opt 0.10.5__py3-none-any.whl → 0.10.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invrs_opt/__init__.py +1 -1
- invrs_opt/parameterization/gaussian_levelset.py +1 -38
- invrs_opt/parameterization/transforms.py +37 -0
- {invrs_opt-0.10.5.dist-info → invrs_opt-0.10.6.dist-info}/METADATA +1 -1
- {invrs_opt-0.10.5.dist-info → invrs_opt-0.10.6.dist-info}/RECORD +8 -8
- {invrs_opt-0.10.5.dist-info → invrs_opt-0.10.6.dist-info}/WHEEL +1 -1
- {invrs_opt-0.10.5.dist-info → invrs_opt-0.10.6.dist-info}/LICENSE +0 -0
- {invrs_opt-0.10.5.dist-info → invrs_opt-0.10.6.dist-info}/top_level.txt +0 -0
invrs_opt/parameterization/gaussian_levelset.py
CHANGED
@@ -619,49 +619,12 @@ def _levelset_threshold(
|
|
619
619
|
) -> jnp.ndarray:
|
620
620
|
"""Thresholds a level set function `phi`."""
|
621
621
|
if mask_gradient:
|
622
|
-
interface =
|
622
|
+
interface = transforms.interface_pixels(phi, periodic)
|
623
623
|
phi = jnp.where(interface, phi, jax.lax.stop_gradient(phi))
|
624
624
|
thresholded = (phi > 0).astype(float) + (phi - jax.lax.stop_gradient(phi))
|
625
625
|
return thresholded
|
626
626
|
|
627
627
|
|
628
|
-
def _interface_pixels(phi: jnp.ndarray, periodic: Tuple[bool, bool]) -> jnp.ndarray:
|
629
|
-
"""Identifies interface pixels of a level set function `phi`."""
|
630
|
-
batch_shape = phi.shape[:-2]
|
631
|
-
phi = phi.reshape((-1,) + phi.shape[-2:])
|
632
|
-
|
633
|
-
pad_mode = (
|
634
|
-
"wrap" if periodic[0] else "edge",
|
635
|
-
"wrap" if periodic[1] else "edge",
|
636
|
-
)
|
637
|
-
pad_width = ((1, 1), (1, 1))
|
638
|
-
|
639
|
-
kernel = jnp.asarray([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
|
640
|
-
|
641
|
-
solid = phi > 0
|
642
|
-
void = ~solid
|
643
|
-
|
644
|
-
solid_padded = transforms.pad2d(solid, pad_width, pad_mode)
|
645
|
-
num_solid_adjacent = transforms.conv(
|
646
|
-
x=solid_padded[:, jnp.newaxis, :, :].astype(float),
|
647
|
-
kernel=kernel[jnp.newaxis, jnp.newaxis, :, :],
|
648
|
-
padding="VALID",
|
649
|
-
)
|
650
|
-
num_solid_adjacent = jnp.squeeze(num_solid_adjacent, axis=1)
|
651
|
-
|
652
|
-
void_padded = transforms.pad2d(void, pad_width, pad_mode)
|
653
|
-
num_void_adjacent = transforms.conv(
|
654
|
-
x=void_padded[:, jnp.newaxis, :, :].astype(float),
|
655
|
-
kernel=kernel[jnp.newaxis, jnp.newaxis, :, :],
|
656
|
-
padding="VALID",
|
657
|
-
)
|
658
|
-
num_void_adjacent = jnp.squeeze(num_void_adjacent, axis=1)
|
659
|
-
|
660
|
-
interface = solid & (num_void_adjacent > 0) | void & (num_solid_adjacent > 0)
|
661
|
-
|
662
|
-
return interface.reshape(batch_shape + interface.shape[-2:])
|
663
|
-
|
664
|
-
|
665
628
|
def _downsample_spatial_dims(x: jnp.ndarray, downsample_factor: int) -> jnp.ndarray:
|
666
629
|
"""Downsamples the two trailing axes of `x` by `downsample_factor`."""
|
667
630
|
shape = x.shape[:-2] + (
|
invrs_opt/parameterization/transforms.py
CHANGED
@@ -231,3 +231,40 @@ def box_downsample(x: jnp.ndarray, shape: Tuple[int, ...]) -> jnp.ndarray:
|
|
231
231
|
axes = list(range(1, 2 * x.ndim, 2))
|
232
232
|
x = x.reshape(shape)
|
233
233
|
return jnp.mean(x, axis=axes)
|
234
|
+
|
235
|
+
|
236
|
+
def interface_pixels(phi: jnp.ndarray, periodic: Tuple[bool, bool]) -> jnp.ndarray:
    """Identifies interface pixels of a level set function `phi`.

    Pixels with `phi > 0` are solid; all others are void. A pixel is an
    interface pixel if at least one of its four edge-adjacent neighbors
    (up, down, left, right) is of the opposite kind.

    Args:
        phi: Level set array. The two trailing axes are the spatial axes; any
            leading axes are treated as batch dimensions.
        periodic: For each of the two spatial axes, whether the domain wraps
            around. Non-periodic axes are padded by replicating the edge pixel,
            so out-of-bounds neighbors never create an interface.

    Returns:
        Boolean array with the same shape as `phi`.
    """
    batch_shape = phi.shape[:-2]
    phi = phi.reshape((-1,) + phi.shape[-2:])
    solid = phi > 0

    # Pad each spatial axis separately so each can use its own mode. The
    # corners of the padded array are never read by the 4-neighbor stencil, so
    # padding the axes one after the other is safe.
    row_mode = "wrap" if periodic[0] else "edge"
    col_mode = "wrap" if periodic[1] else "edge"
    padded = solid.astype(jnp.int32)
    padded = jnp.pad(padded, ((0, 0), (1, 1), (0, 0)), mode=row_mode)
    padded = jnp.pad(padded, ((0, 0), (0, 0), (1, 1)), mode=col_mode)

    # Exact integer count of solid pixels among the four edge-adjacent
    # neighbors. Each pixel has exactly four neighbors, so the number of void
    # neighbors is `4 - num_solid_adjacent`. A single count therefore serves
    # both the solid and the void test.
    num_solid_adjacent = (
        padded[:, :-2, 1:-1]  # neighbor above
        + padded[:, 2:, 1:-1]  # neighbor below
        + padded[:, 1:-1, :-2]  # neighbor to the left
        + padded[:, 1:-1, 2:]  # neighbor to the right
    )

    # Solid pixels are interface pixels when some neighbor is void. Void pixels
    # are interface pixels when some neighbor is solid.
    interface = jnp.where(solid, num_solid_adjacent < 4, num_solid_adjacent > 0)
    return interface.reshape(batch_shape + interface.shape[-2:])
|
{invrs_opt-0.10.5.dist-info → invrs_opt-0.10.6.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
invrs_opt/__init__.py,sha256=
|
1
|
+
invrs_opt/__init__.py,sha256=4uq04SwbUcX0hVAhaX0-GJ4R4tRJAAXfIzzfU-VcalA,586
|
2
2
|
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
invrs_opt/experimental/client.py,sha256=tbtH13FrA65XmTZfTO71CxJ78jeAEj3Zf85R-MTwbiU,4909
|
@@ -10,11 +10,11 @@ invrs_opt/optimizers/wrapped_optax.py,sha256=781-8v_TlHsGaQF9Se9_iOEvtOLOr-BesTL
|
|
10
10
|
invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
11
|
invrs_opt/parameterization/base.py,sha256=jSwrEO86lGkYQG5gWsHvcIMWpZnnbdiKpn--2qaU02g,5362
|
12
12
|
invrs_opt/parameterization/filter_project.py,sha256=XL3HTEBLrF-q_75TjhOWLNdfUOSEEjKcoM7Qj844QpQ,4590
|
13
|
-
invrs_opt/parameterization/gaussian_levelset.py,sha256
|
13
|
+
invrs_opt/parameterization/gaussian_levelset.py,sha256=PDvjdgBzklRTCUoBpo4ZMcmXeTkn0BpZEzQj7ojtYGE,24813
|
14
14
|
invrs_opt/parameterization/pixel.py,sha256=YWkyBhfYtzI8cQ-M90PAZqRAbabwVaUh0UiYIGegQHI,1955
|
15
|
-
invrs_opt/parameterization/transforms.py,sha256=
|
16
|
-
invrs_opt-0.10.
|
17
|
-
invrs_opt-0.10.
|
18
|
-
invrs_opt-0.10.
|
19
|
-
invrs_opt-0.10.
|
20
|
-
invrs_opt-0.10.
|
15
|
+
invrs_opt/parameterization/transforms.py,sha256=mqDKuAg4wpSL9kh0oYKxtSoH0mHOQeKG1RND2fJSYaU,9441
|
16
|
+
invrs_opt-0.10.6.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
|
17
|
+
invrs_opt-0.10.6.dist-info/METADATA,sha256=40XUQ7i3S4nYRGUqkEVOwKrJwGtpvdwmpD9ojFpAeAM,32816
|
18
|
+
invrs_opt-0.10.6.dist-info/WHEEL,sha256=nn6H5-ilmfVryoAQl3ZQ2l8SH5imPWFpm1A5FgEuFV4,91
|
19
|
+
invrs_opt-0.10.6.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
20
|
+
invrs_opt-0.10.6.dist-info/RECORD,,
|
File without changes
|
File without changes
|