invrs-opt 0.5.1__tar.gz → 0.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. {invrs_opt-0.5.1/src/invrs_opt.egg-info → invrs_opt-0.6.0}/PKG-INFO +3 -2
  2. {invrs_opt-0.5.1 → invrs_opt-0.6.0}/README.md +1 -1
  3. {invrs_opt-0.5.1 → invrs_opt-0.6.0}/pyproject.toml +2 -1
  4. {invrs_opt-0.5.1 → invrs_opt-0.6.0}/src/invrs_opt/__init__.py +5 -1
  5. {invrs_opt-0.5.1 → invrs_opt-0.6.0}/src/invrs_opt/base.py +8 -0
  6. {invrs_opt-0.5.1 → invrs_opt-0.6.0}/src/invrs_opt/experimental/client.py +1 -2
  7. {invrs_opt-0.5.1 → invrs_opt-0.6.0}/src/invrs_opt/lbfgsb/lbfgsb.py +9 -10
  8. invrs_opt-0.6.0/src/invrs_opt/wrapped_optax/__init__.py +0 -0
  9. invrs_opt-0.6.0/src/invrs_opt/wrapped_optax/wrapped_optax.py +150 -0
  10. {invrs_opt-0.5.1 → invrs_opt-0.6.0/src/invrs_opt.egg-info}/PKG-INFO +3 -2
  11. {invrs_opt-0.5.1 → invrs_opt-0.6.0}/src/invrs_opt.egg-info/SOURCES.txt +5 -2
  12. {invrs_opt-0.5.1 → invrs_opt-0.6.0}/src/invrs_opt.egg-info/requires.txt +1 -0
  13. {invrs_opt-0.5.1 → invrs_opt-0.6.0}/tests/test_algos.py +3 -0
  14. invrs_opt-0.6.0/tests/test_transform.py +362 -0
  15. {invrs_opt-0.5.1 → invrs_opt-0.6.0}/LICENSE +0 -0
  16. {invrs_opt-0.5.1 → invrs_opt-0.6.0}/setup.cfg +0 -0
  17. {invrs_opt-0.5.1 → invrs_opt-0.6.0}/src/invrs_opt/experimental/__init__.py +0 -0
  18. {invrs_opt-0.5.1 → invrs_opt-0.6.0}/src/invrs_opt/experimental/labels.py +0 -0
  19. {invrs_opt-0.5.1 → invrs_opt-0.6.0}/src/invrs_opt/lbfgsb/__init__.py +0 -0
  20. {invrs_opt-0.5.1 → invrs_opt-0.6.0}/src/invrs_opt/py.typed +0 -0
  21. {invrs_opt-0.5.1/src/invrs_opt/lbfgsb → invrs_opt-0.6.0/src/invrs_opt}/transform.py +0 -0
  22. {invrs_opt-0.5.1 → invrs_opt-0.6.0}/src/invrs_opt.egg-info/dependency_links.txt +0 -0
  23. {invrs_opt-0.5.1 → invrs_opt-0.6.0}/src/invrs_opt.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: invrs_opt
- Version: 0.5.1
+ Version: 0.6.0
  Summary: Algorithms for inverse design
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -34,6 +34,7 @@ Requires-Dist: jax
  Requires-Dist: jaxlib
  Requires-Dist: numpy
  Requires-Dist: requests
+ Requires-Dist: optax
  Requires-Dist: scipy
  Requires-Dist: totypes
  Requires-Dist: types-requests
@@ -49,7 +50,7 @@ Requires-Dist: mypy; extra == "dev"
  Requires-Dist: pre-commit; extra == "dev"

  # invrs-opt - Optimization algorithms for inverse design
- `v0.5.1`
+ `v0.6.0`

  ## Overview

@@ -1,5 +1,5 @@
  # invrs-opt - Optimization algorithms for inverse design
- `v0.5.1`
+ `v0.6.0`

  ## Overview

@@ -1,7 +1,7 @@
  [project]

  name = "invrs_opt"
- version = "v0.5.1"
+ version = "v0.6.0"
  description = "Algorithms for inverse design"
  keywords = ["topology", "optimization", "jax", "inverse design"]
  readme = "README.md"
@@ -20,6 +20,7 @@ dependencies = [
      "jaxlib",
      "numpy",
      "requests",
+     "optax",
      "scipy",
      "totypes",
      "types-requests",
@@ -3,8 +3,12 @@
  Copyright (c) 2023 The INVRS-IO authors.
  """

- __version__ = "v0.5.1"
+ __version__ = "v0.6.0"
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"

  from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
  from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
+ from invrs_opt.wrapped_optax.wrapped_optax import (
+     density_wrapped_optax as density_wrapped_optax,
+ )
+ from invrs_opt.wrapped_optax.wrapped_optax import wrapped_optax as wrapped_optax
@@ -6,6 +6,9 @@ Copyright (c) 2023 The INVRS-IO authors.
  import dataclasses
  from typing import Any, Protocol

+ import optax  # type: ignore[import-untyped]
+ from totypes import json_utils
+
  PyTree = Any


@@ -44,3 +47,8 @@ class Optimizer:
      init: InitFn
      params: ParamsFn
      update: UpdateFn
+
+
+ # TODO: consider programatically registering all optax states here.
+ json_utils.register_custom_type(optax.EmptyState)
+ json_utils.register_custom_type(optax.ScaleByAdamState)
@@ -4,16 +4,15 @@ Copyright (c) 2023 The INVRS-IO authors.
  """

  import json
- import requests
  import time
  from typing import Any, Dict, Optional

+ import requests
  from totypes import json_utils

  from invrs_opt import base
  from invrs_opt.experimental import labels

-
  PyTree = Any
  StateToken = str

@@ -16,8 +16,7 @@ from scipy.optimize._lbfgsb_py import (  # type: ignore[import-untyped]
  )
  from totypes import types

- from invrs_opt import base
- from invrs_opt.lbfgsb import transform
+ from invrs_opt import base, transform

  NDArray = onp.ndarray[Any, Any]
  PyTree = Any
@@ -258,7 +257,7 @@ def transformed_lbfgsb(
          (
              latent_params,
              jax_lbfgsb_state,
-         ) = jax.pure_callback(  # type: ignore[attr-defined]
+         ) = jax.pure_callback(
              _init_pure,
              _example_state(params, maxcor),
              initialize_latent_fn(params),
@@ -304,7 +303,7 @@ def transformed_lbfgsb(
          (
              flat_latent_params,
              jax_lbfgsb_state,
-         ) = jax.pure_callback(  # type: ignore[attr-defined]
+         ) = jax.pure_callback(
              _update_pure,
              (flat_latent_grad, jax_lbfgsb_state),
              flat_latent_grad,
@@ -542,19 +541,19 @@ class ScipyLbfgsbState:
          """Converts a dictionary of jax arrays to a `ScipyLbfgsbState`."""
          state_dict = copy.deepcopy(state_dict)
          return ScipyLbfgsbState(
-             x=onp.asarray(state_dict["x"], dtype=onp.float64),
+             x=onp.array(state_dict["x"], dtype=onp.float64),
              converged=onp.asarray(state_dict["converged"], dtype=bool),
              _maxcor=int(state_dict["_maxcor"]),
              _line_search_max_steps=int(state_dict["_line_search_max_steps"]),
              _ftol=onp.asarray(state_dict["_ftol"], dtype=onp.float64),
              _gtol=onp.asarray(state_dict["_gtol"], dtype=onp.float64),
-             _wa=onp.asarray(state_dict["_wa"], onp.float64),
-             _iwa=onp.asarray(state_dict["_iwa"], dtype=FORTRAN_INT),
+             _wa=onp.array(state_dict["_wa"], onp.float64),
+             _iwa=onp.array(state_dict["_iwa"], dtype=FORTRAN_INT),
              _task=_s60_str_from_array(state_dict["_task"]),
              _csave=_s60_str_from_array(state_dict["_csave"]),
-             _lsave=onp.asarray(state_dict["_lsave"], dtype=FORTRAN_INT),
-             _isave=onp.asarray(state_dict["_isave"], dtype=FORTRAN_INT),
-             _dsave=onp.asarray(state_dict["_dsave"], dtype=onp.float64),
+             _lsave=onp.array(state_dict["_lsave"], dtype=FORTRAN_INT),
+             _isave=onp.array(state_dict["_isave"], dtype=FORTRAN_INT),
+             _dsave=onp.array(state_dict["_dsave"], dtype=onp.float64),
              _lower_bound=onp.asarray(state_dict["_lower_bound"], dtype=onp.float64),
              _upper_bound=onp.asarray(state_dict["_upper_bound"], dtype=onp.float64),
              _bound_type=onp.asarray(state_dict["_bound_type"], dtype=int),
@@ -0,0 +1,150 @@
+ import dataclasses
+ from typing import Any, Callable, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ import optax  # type: ignore[import-untyped]
+ from jax import tree_util
+ from totypes import types
+
+ from invrs_opt import base, transform
+
+ PyTree = Any
+ WrappedOptaxState = Tuple[PyTree, PyTree, PyTree]
+
+
+ def wrapped_optax(opt: optax.GradientTransformation) -> base.Optimizer:
+     """Return a wrapped optax optimizer."""
+     return transformed_wrapped_optax(
+         opt=opt,
+         transform_fn=lambda x: x,
+         initialize_latent_fn=lambda x: x,
+     )
+
+
+ def density_wrapped_optax(
+     opt: optax.GradientTransformation,
+     beta: float,
+ ) -> base.Optimizer:
+     """Return a wrapped optax optimizer with transforms for density arrays."""
+
+     def transform_fn(tree: PyTree) -> PyTree:
+         return tree_util.tree_map(
+             lambda x: transform_density(x) if _is_density(x) else x,
+             tree,
+             is_leaf=_is_density,
+         )
+
+     def initialize_latent_fn(tree: PyTree) -> PyTree:
+         return tree_util.tree_map(
+             lambda x: initialize_latent_density(x) if _is_density(x) else x,
+             tree,
+             is_leaf=_is_density,
+         )
+
+     def transform_density(density: types.Density2DArray) -> types.Density2DArray:
+         transformed = types.symmetrize_density(density)
+         transformed = transform.density_gaussian_filter_and_tanh(transformed, beta=beta)
+         # Scale to ensure that the full valid range of the density array is reachable.
+         mid_value = (density.lower_bound + density.upper_bound) / 2
+         transformed = tree_util.tree_map(
+             lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed
+         )
+         return transform.apply_fixed_pixels(transformed)
+
+     def initialize_latent_density(
+         density: types.Density2DArray,
+     ) -> types.Density2DArray:
+         array = transform.normalized_array_from_density(density)
+         array = jnp.clip(array, -1, 1)
+         array *= jnp.tanh(beta)
+         latent_array = jnp.arctanh(array) / beta
+         latent_array = transform.rescale_array_for_density(latent_array, density)
+         return dataclasses.replace(density, array=latent_array)
+
+     return transformed_wrapped_optax(
+         opt=opt,
+         transform_fn=transform_fn,
+         initialize_latent_fn=initialize_latent_fn,
+     )
+
+
+ def transformed_wrapped_optax(
+     opt: optax.GradientTransformation,
+     transform_fn: Callable[[PyTree], PyTree],
+     initialize_latent_fn: Callable[[PyTree], PyTree],
+ ) -> base.Optimizer:
+     """Return a wrapped optax optimizer for transformed latent parameters.
+
+     Args:
+         opt: The optax `GradientTransformation` to be wrapped.
+         transform_fn: Function which transforms the internal latent parameters to
+             the parameters returned by the optimizer.
+         initialize_latent_fn: Function which computes the initial latent parameters
+             given the initial parameters.
+
+     Returns:
+         The `base.Optimizer`.
+     """
+
+     def init_fn(params: PyTree) -> WrappedOptaxState:
+         """Initializes the optimization state."""
+         latent_params = initialize_latent_fn(_clip(params))
+         params = transform_fn(latent_params)
+         return params, latent_params, opt.init(latent_params)
+
+     def params_fn(state: WrappedOptaxState) -> PyTree:
+         """Returns the parameters for the given `state`."""
+         params, _, _ = state
+         return params
+
+     def update_fn(
+         *,
+         grad: PyTree,
+         value: float,
+         params: PyTree,
+         state: WrappedOptaxState,
+     ) -> WrappedOptaxState:
+         """Updates the state."""
+         del value
+
+         _, latent_params, opt_state = state
+         _, vjp_fn = jax.vjp(transform_fn, latent_params)
+         (latent_grad,) = vjp_fn(grad)
+
+         updates, opt_state = opt.update(latent_grad, opt_state)
+         latent_params = optax.apply_updates(params=latent_params, updates=updates)
+         latent_params = _clip(latent_params)
+         params = transform_fn(latent_params)
+         return params, latent_params, opt_state
+
+     return base.Optimizer(
+         init=init_fn,
+         params=params_fn,
+         update=update_fn,
+     )
+
+
+ def _is_density(leaf: Any) -> Any:
+     """Return `True` if `leaf` is a density array."""
+     return isinstance(leaf, types.Density2DArray)
+
+
+ def _is_custom_type(leaf: Any) -> bool:
+     """Return `True` if `leaf` is a recognized custom type."""
+     return isinstance(leaf, (types.BoundedArray, types.Density2DArray))
+
+
+ def _clip(pytree: PyTree) -> PyTree:
+     """Clips leaves on `pytree` to their bounds."""
+
+     def _clip_fn(leaf: Any) -> Any:
+         if not _is_custom_type(leaf):
+             return leaf
+         if leaf.lower_bound is None and leaf.upper_bound is None:
+             return leaf
+         return tree_util.tree_map(
+             lambda x: jnp.clip(x, leaf.lower_bound, leaf.upper_bound), leaf
+         )
+
+     return tree_util.tree_map(_clip_fn, pytree, is_leaf=_is_custom_type)
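
A minimal usage sketch of the new `wrapped_optax` entry point, for orientation: it assumes the `init`/`params`/`update` protocol defined in `base.py` and reuses the `optax.adam(1e-2)` configuration exercised in `tests/test_algos.py`; the scalar `loss_fn`, the parameter shape, and the step count are illustrative placeholders, not part of the package.

    import jax
    import jax.numpy as jnp
    import optax

    import invrs_opt

    def loss_fn(params):
        # Hypothetical objective: any differentiable scalar function of the params.
        return jnp.sum(params**2)

    opt = invrs_opt.wrapped_optax(optax.adam(1e-2))  # a `base.Optimizer`

    params = jnp.ones((3,))
    state = opt.init(params)  # (params, latent_params, optax state)

    for _ in range(20):
        params = opt.params(state)
        value, grad = jax.value_and_grad(loss_fn)(params)
        state = opt.update(grad=grad, value=value, params=params, state=state)

For density parameters, `density_wrapped_optax(optax.adam(1e-2), beta=2.0)` runs the same loop but applies the filter-and-tanh transform to `types.Density2DArray` leaves, as in the tests.
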
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: invrs_opt
- Version: 0.5.1
+ Version: 0.6.0
  Summary: Algorithms for inverse design
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -34,6 +34,7 @@ Requires-Dist: jax
  Requires-Dist: jaxlib
  Requires-Dist: numpy
  Requires-Dist: requests
+ Requires-Dist: optax
  Requires-Dist: scipy
  Requires-Dist: totypes
  Requires-Dist: types-requests
@@ -49,7 +50,7 @@ Requires-Dist: mypy; extra == "dev"
  Requires-Dist: pre-commit; extra == "dev"

  # invrs-opt - Optimization algorithms for inverse design
- `v0.5.1`
+ `v0.6.0`

  ## Overview

@@ -4,6 +4,7 @@ pyproject.toml
  src/invrs_opt/__init__.py
  src/invrs_opt/base.py
  src/invrs_opt/py.typed
+ src/invrs_opt/transform.py
  src/invrs_opt.egg-info/PKG-INFO
  src/invrs_opt.egg-info/SOURCES.txt
  src/invrs_opt.egg-info/dependency_links.txt
@@ -14,5 +15,7 @@ src/invrs_opt/experimental/client.py
  src/invrs_opt/experimental/labels.py
  src/invrs_opt/lbfgsb/__init__.py
  src/invrs_opt/lbfgsb/lbfgsb.py
- src/invrs_opt/lbfgsb/transform.py
- tests/test_algos.py
+ src/invrs_opt/wrapped_optax/__init__.py
+ src/invrs_opt/wrapped_optax/wrapped_optax.py
+ tests/test_algos.py
+ tests/test_transform.py
@@ -2,6 +2,7 @@ jax
  jaxlib
  numpy
  requests
+ optax
  scipy
  totypes
  types-requests
@@ -9,6 +9,7 @@ import unittest
  import jax
  import jax.numpy as jnp
  import numpy as onp
+ import optax
  import parameterized
  from totypes import json_utils, symmetry, types

@@ -21,6 +22,8 @@ jax.config.update("jax_enable_x64", True)
  OPTIMIZERS = [
      invrs_opt.lbfgsb(maxcor=20, line_search_max_steps=100),
      invrs_opt.density_lbfgsb(maxcor=20, line_search_max_steps=100, beta=2.0),
+     invrs_opt.wrapped_optax(optax.adam(1e-2)),
+     invrs_opt.density_wrapped_optax(optax.adam(1e-2), beta=2.0),
  ]

  # Various parameter combinations tested in this module.
@@ -0,0 +1,362 @@
+ """Tests for `transforms`.
+
+ Copyright (c) 2023 The INVRS-IO authors.
+ """
+
+ import dataclasses
+ import unittest
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as onp
+ from parameterized import parameterized
+ from totypes import types
+
+ from invrs_opt import transform
+
+
+ class GaussianFilterTest(unittest.TestCase):
+     @parameterized.expand([[1, 5], [3, 3], [5, 1]])
+     def test_transformed_matches_expected(self, minimum_width, minimum_spacing):
+         array = onp.zeros((9, 9))
+         array[4, 4] = 9
+         density = types.Density2DArray(
+             array=array,
+             lower_bound=0,
+             upper_bound=1,
+             minimum_width=minimum_width,
+             minimum_spacing=minimum_spacing,
+         )
+         beta = 1
+         transformed = transform.density_gaussian_filter_and_tanh(density, beta=beta)
+         expected = onp.asarray(
+             [
+                 [0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12],
+                 [0.12, 0.12, 0.13, 0.14, 0.14, 0.14, 0.13, 0.12, 0.12],
+                 [0.12, 0.13, 0.15, 0.22, 0.27, 0.22, 0.15, 0.13, 0.12],
+                 [0.12, 0.14, 0.22, 0.48, 0.64, 0.48, 0.22, 0.14, 0.12],
+                 [0.12, 0.14, 0.27, 0.64, 0.82, 0.64, 0.27, 0.14, 0.12],
+                 [0.12, 0.14, 0.22, 0.48, 0.64, 0.48, 0.22, 0.14, 0.12],
+                 [0.12, 0.13, 0.15, 0.22, 0.27, 0.22, 0.15, 0.13, 0.12],
+                 [0.12, 0.12, 0.13, 0.14, 0.14, 0.14, 0.13, 0.12, 0.12],
+                 [0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12],
+             ]
+         )
+         onp.testing.assert_allclose(transformed.array, expected, rtol=0.05)
+
+     @parameterized.expand([[1, 1], [3, 1], [5, 1], [10, 1], [10, 0.5], [10, 2]])
+     def test_ones_density_yields_tanh_beta(self, length_scale, upper_bound):
+         array = onp.ones((20, 20)) * upper_bound
+         density = types.Density2DArray(
+             array=array,
+             lower_bound=0,
+             upper_bound=upper_bound,
+             minimum_width=length_scale,
+             minimum_spacing=length_scale,
+         )
+         beta = 1
+         transformed = transform.density_gaussian_filter_and_tanh(density, beta=beta)
+         onp.testing.assert_allclose(
+             transformed.array,
+             (1 + onp.tanh(beta)) * 0.5 * upper_bound,
+             rtol=0.01,
+         )
+
+     def test_batch_matches_single(self):
+         beta = 4
+         density = types.Density2DArray(
+             array=onp.arange(600).reshape((6, 10, 10)),
+             minimum_width=5,
+             minimum_spacing=5,
+         )
+         transformed = transform.density_gaussian_filter_and_tanh(density, beta=beta)
+         for i in range(6):
+             transformed_single = transform.density_gaussian_filter_and_tanh(
+                 density=dataclasses.replace(
+                     density,
+                     array=density.array[i, :, :],
+                 ),
+                 beta=beta,
+             )
+             onp.testing.assert_allclose(
+                 transformed.array[i, :, :], transformed_single.array
+             )
+
+     def test_periodic(self):
+         beta = 100
+         array = onp.zeros((5, 5))
+         array[0, 0] = 9
+
+         # No periodicity.
+         density = types.Density2DArray(
+             array,
+             minimum_spacing=3,
+             minimum_width=3,
+             periodic=(False, False),
+             lower_bound=0,
+             upper_bound=1,
+         )
+         transformed = transform.density_gaussian_filter_and_tanh(density, beta=beta)
+         expected = onp.asarray(
+             [
+                 [1, 1, 1, 0, 0],
+                 [1, 1, 0, 0, 0],
+                 [1, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0],
+             ]
+         )
+         onp.testing.assert_allclose(transformed.array, expected, atol=0.01)
+
+         # Periodic along the first axis.
+         density = dataclasses.replace(density, periodic=(True, False))
+         transformed = transform.density_gaussian_filter_and_tanh(density, beta=beta)
+         expected = onp.asarray(
+             [
+                 [1, 1, 0, 0, 0],
+                 [1, 1, 0, 0, 0],
+                 [1, 0, 0, 0, 0],
+                 [1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0],
+             ]
+         )
+         onp.testing.assert_allclose(transformed.array, expected, atol=0.01)
+
+         # Periodic along the second axis.
+         density = dataclasses.replace(density, periodic=(False, True))
+         transformed = transform.density_gaussian_filter_and_tanh(density, beta=beta)
+         expected = onp.asarray(
+             [
+                 [1, 1, 1, 1, 1],
+                 [1, 1, 0, 0, 1],
+                 [0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0],
+             ]
+         )
+         onp.testing.assert_allclose(transformed.array, expected, atol=0.01)
+
+         # Periodic along both axes.
+         density = dataclasses.replace(density, periodic=(True, True))
+         transformed = transform.density_gaussian_filter_and_tanh(density, beta=beta)
+         expected = onp.asarray(
+             [
+                 [1, 1, 0, 0, 1],
+                 [1, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0],
+                 [1, 0, 0, 0, 0],
+             ]
+         )
+         onp.testing.assert_allclose(transformed.array, expected, atol=0.01)
+
+
+ class RescaleTest(unittest.TestCase):
+     @parameterized.expand([(-1.0, 1.0), (2.0, 3.0), (-0.5, -0.1)])
+     def test_normalized_array_from_density(self, lower_bound, upper_bound):
+         density = types.Density2DArray(
+             array=jnp.linspace(lower_bound, upper_bound).reshape((10, 5)),
+             lower_bound=lower_bound,
+             upper_bound=upper_bound,
+             fixed_solid=None,
+             fixed_void=None,
+             minimum_width=1,
+             minimum_spacing=1,
+         )
+         # Compute `array`, which should now have values between `-1` and `1`.
+         array = transform.normalized_array_from_density(density)
+         expected = jnp.linspace(-1.0, 1.0).reshape((10, 5))
+         onp.testing.assert_allclose(array, expected, rtol=1e-5)
+
+     @parameterized.expand([(-1.0, 1.0), (2.0, 3.0), (-0.5, -0.1)])
+     def test_rescale_array_for_density(self, lower_bound, upper_bound):
+         dummy_density = types.Density2DArray(
+             array=jnp.ones((2, 2)),
+             lower_bound=lower_bound,
+             upper_bound=upper_bound,
+             fixed_solid=None,
+             fixed_void=None,
+             minimum_width=1,
+             minimum_spacing=1,
+         )
+         array = jnp.linspace(-1.0, 1.0).reshape((10, 5))
+         rescaled = transform.rescale_array_for_density(array, dummy_density)
+         expected = jnp.linspace(lower_bound, upper_bound).reshape((10, 5))
+         onp.testing.assert_allclose(rescaled, expected, rtol=1e-5)
+
+
+ class FixedPixelTest(unittest.TestCase):
+     @parameterized.expand(
+         [
+             [
+                 jnp.asarray([[1, 0, 0, 0, 0]], dtype=bool),
+                 jnp.asarray([[0, 0, 1, 1, 0]], dtype=bool),
+                 jnp.asarray([[3.0, 0.0, -0.5, -0.5, 0.0]]),
+             ],
+             [
+                 None,
+                 jnp.asarray([[0, 0, 1, 1, 0]], dtype=bool),
+                 jnp.asarray([[0.0, 0.0, -0.5, -0.5, 0.0]]),
+             ],
+             [
+                 None,
+                 None,
+                 jnp.asarray([[0.0, 0.0, 0.0, 0.0, 0.0]]),
+             ],
+         ]
+     )
+     def test_apply_fixed_pixels(self, fixed_solid, fixed_void, expected):
+         density = types.Density2DArray(
+             array=jnp.zeros((1, 5)),
+             fixed_solid=fixed_solid,
+             fixed_void=fixed_void,
+             lower_bound=-0.5,
+             upper_bound=3.0,
+         )
+         density = transform.apply_fixed_pixels(density)
+         onp.testing.assert_array_equal(density.array, expected)
+
+
+ class Pad2DTest(unittest.TestCase):
+     def test_pad2d_edge(self):
+         array = jnp.asarray(
+             [
+                 [0, 1, 2, 3, 4],
+                 [5, 6, 7, 8, 9],
+                 [10, 11, 12, 13, 14],
+             ]
+         )
+         expected = jnp.asarray(
+             [
+                 [0, 0, 1, 2, 3, 4, 4],
+                 [0, 0, 1, 2, 3, 4, 4],
+                 [5, 5, 6, 7, 8, 9, 9],
+                 [10, 10, 11, 12, 13, 14, 14],
+                 [10, 10, 11, 12, 13, 14, 14],
+             ]
+         )
+         padded = transform.pad2d(array, ((1, 1), (1, 1)), "edge")
+         padded_both_specified = transform.pad2d(
+             array, ((1, 1), (1, 1)), ("edge", "edge")
+         )
+         onp.testing.assert_array_equal(padded, expected)
+         onp.testing.assert_array_equal(padded_both_specified, expected)
+
+     def test_pad2d_wrap(self):
+         array = jnp.asarray(
+             [
+                 [0, 1, 2, 3, 4],
+                 [5, 6, 7, 8, 9],
+                 [10, 11, 12, 13, 14],
+             ]
+         )
+         expected = jnp.asarray(
+             [
+                 [14, 10, 11, 12, 13, 14, 10],
+                 [4, 0, 1, 2, 3, 4, 0],
+                 [9, 5, 6, 7, 8, 9, 5],
+                 [14, 10, 11, 12, 13, 14, 10],
+                 [4, 0, 1, 2, 3, 4, 0],
+             ]
+         )
+         padded = transform.pad2d(array, ((1, 1), (1, 1)), "wrap")
+         padded_both_specified = transform.pad2d(
+             array, ((1, 1), (1, 1)), ("wrap", "wrap")
+         )
+         onp.testing.assert_array_equal(padded, expected)
+         onp.testing.assert_array_equal(padded_both_specified, expected)
+
+     def test_pad2d_wrap_edge(self):
+         array = jnp.asarray(
+             [
+                 [0, 1, 2, 3, 4],
+                 [5, 6, 7, 8, 9],
+                 [10, 11, 12, 13, 14],
+             ]
+         )
+         expected = jnp.asarray(
+             [
+                 [10, 10, 11, 12, 13, 14, 14],
+                 [0, 0, 1, 2, 3, 4, 4],
+                 [5, 5, 6, 7, 8, 9, 9],
+                 [10, 10, 11, 12, 13, 14, 14],
+                 [0, 0, 1, 2, 3, 4, 4],
+             ]
+         )
+         padded = transform.pad2d(array, ((1, 1), (1, 1)), ("wrap", "edge"))
+         onp.testing.assert_array_equal(padded, expected)
+
+     def test_pad2d_edge_wrap(self):
+         array = jnp.asarray(
+             [
+                 [0, 1, 2, 3, 4],
+                 [5, 6, 7, 8, 9],
+                 [10, 11, 12, 13, 14],
+             ]
+         )
+         expected = jnp.asarray(
+             [
+                 [4, 0, 1, 2, 3, 4, 0],
+                 [4, 0, 1, 2, 3, 4, 0],
+                 [9, 5, 6, 7, 8, 9, 5],
+                 [14, 10, 11, 12, 13, 14, 10],
+                 [14, 10, 11, 12, 13, 14, 10],
+             ]
+         )
+         padded = transform.pad2d(array, ((1, 1), (1, 1)), ("edge", "wrap"))
+         onp.testing.assert_array_equal(padded, expected)
+
+     def test_pad2d_batch_dims(self):
+         array = jnp.arange(300).reshape((2, 1, 5, 10, 3))
+         pad_width = ((3, 4), (1, 2))
+         padded = transform.pad2d(array, pad_width, "wrap")
+         self.assertSequenceEqual(padded.shape, (2, 1, 5, 17, 6))
+         for i in range(array.shape[0]):
+             for j in range(array.shape[1]):
+                 for k in range(array.shape[2]):
+                     onp.testing.assert_array_equal(
+                         padded[i, j, k, :, :],
+                         transform.pad2d(array[i, j, k, :, :], pad_width, "wrap"),
+                     )
+
+
+ class PadWidthForKernelShapeTest(unittest.TestCase):
+     @parameterized.expand([[(2, 3)], [(4, 4)], [(5, 6)]])
+     def test_pad_width(self, kernel_shape):
+         # Checks that when padded by the computed amount, a convolution
+         # valid padding returns an array with the original shape.
+         kernel = jnp.ones(kernel_shape)
+         array = jnp.arange(77).reshape((7, 11)).astype(float)
+         pad_width = transform.pad_width_for_kernel_shape(kernel_shape)
+         padded = jnp.pad(array, pad_width)
+         y = jax.lax.conv_general_dilated(
+             lhs=padded[jnp.newaxis, jnp.newaxis, :, :],  # HCHW
+             rhs=kernel[jnp.newaxis, jnp.newaxis, :, :],  # OIHW
+             padding="VALID",
+             dimension_numbers=("NCHW", "OIHW", "NCHW"),
+             window_strides=(1, 1),
+         )
+         self.assertEqual(y.shape, (1, 1) + array.shape)
+
+     @parameterized.expand([[(2, 3)], [(4, 4)], [(5, 6)]])
+     def test_pad_width_offsets(self, kernel_shape):
+         kernel = onp.zeros(kernel_shape)
+         kernel[kernel_shape[0] // 2, kernel_shape[1] // 2] = 1
+         array = jnp.arange(77).reshape((7, 11)).astype(float)
+         pad_width = transform.pad_width_for_kernel_shape(kernel_shape)
+         padded = jnp.pad(array, pad_width)
+         y = jax.lax.conv_general_dilated(
+             lhs=padded[jnp.newaxis, jnp.newaxis, :, :],  # HCHW
+             rhs=kernel[jnp.newaxis, jnp.newaxis, :, :],  # OIHW
+             padding="VALID",
+             dimension_numbers=("NCHW", "OIHW", "NCHW"),
+             window_strides=(1, 1),
+         )
+         onp.testing.assert_array_equal(y[0, 0, ...], array)
+
+
+ class GaussianKernelTest(unittest.TestCase):
+     @parameterized.expand([[2], [3], [4], [5]])
+     def test_gaussian_peak_on_gridpoint(self, fwhm_size_multiple):
+         kernel = transform._gaussian_kernel(1.0, fwhm_size_multiple)
+         self.assertEqual(kernel[kernel.shape[0] // 2, kernel.shape[1] // 2], 1.0)