utax 0.0.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- utax-0.0.2/LICENSE +21 -0
- utax-0.0.2/PKG-INFO +32 -0
- utax-0.0.2/README.md +8 -0
- utax-0.0.2/pyproject.toml +37 -0
- utax-0.0.2/setup.cfg +4 -0
- utax-0.0.2/setup.py +31 -0
- utax-0.0.2/utax/__init__.py +9 -0
- utax-0.0.2/utax/convolution/__init__.py +5 -0
- utax-0.0.2/utax/convolution/classes.py +195 -0
- utax-0.0.2/utax/convolution/functions.py +211 -0
- utax-0.0.2/utax/info.py +29 -0
- utax-0.0.2/utax/interpolation.py +235 -0
- utax-0.0.2/utax/wavelet.py +243 -0
- utax-0.0.2/utax.egg-info/PKG-INFO +32 -0
- utax-0.0.2/utax.egg-info/SOURCES.txt +16 -0
- utax-0.0.2/utax.egg-info/dependency_links.txt +1 -0
- utax-0.0.2/utax.egg-info/requires.txt +7 -0
- utax-0.0.2/utax.egg-info/top_level.txt +1 -0
utax-0.0.2/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2022 Aymeric Galan
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
utax-0.0.2/PKG-INFO
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: utax
|
|
3
|
+
Version: 0.0.2
|
|
4
|
+
Summary: Utility functions for signal processing, compatible with the differentiable programming library JAX.
|
|
5
|
+
Home-page: https://github.com/aymgal/utax
|
|
6
|
+
Author: Austin Peel, Martin Millon, Frederic Dux, Kevin Michalewicz
|
|
7
|
+
Author-email: Aymeric Galan <aymeric.galan@gmail.com>
|
|
8
|
+
License: MIT
|
|
9
|
+
Project-URL: Homepage, https://github.com/aymgal/utax
|
|
10
|
+
Project-URL: Repository, https://github.com/aymgal/utax
|
|
11
|
+
Keywords: jax,signal processing,wavelet,convolution,interpolation
|
|
12
|
+
Requires-Python: >=3.10
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
License-File: LICENSE
|
|
15
|
+
Requires-Dist: jax>=0.5.0
|
|
16
|
+
Requires-Dist: jaxlib>=0.5.0
|
|
17
|
+
Provides-Extra: dev
|
|
18
|
+
Requires-Dist: pytest; extra == "dev"
|
|
19
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
|
20
|
+
Requires-Dist: pytest-pep8; extra == "dev"
|
|
21
|
+
Dynamic: home-page
|
|
22
|
+
Dynamic: license-file
|
|
23
|
+
Dynamic: requires-python
|
|
24
|
+
|
|
25
|
+

|
|
26
|
+

|
|
27
|
+
[](https://github.com/aymgal/utax/actions/workflows/ci_tests.yml)
|
|
28
|
+
[](https://coveralls.io/github/aymgal/utax?branch=main)
|
|
29
|
+
|
|
30
|
+
# `utax`
|
|
31
|
+
|
|
32
|
+
Utility functions for applications in signal processing problems, compatible with the differentiable programming library `JAX`.
|
utax-0.0.2/README.md
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+

|
|
2
|
+

|
|
3
|
+
[](https://github.com/aymgal/utax/actions/workflows/ci_tests.yml)
|
|
4
|
+
[](https://coveralls.io/github/aymgal/utax?branch=main)
|
|
5
|
+
|
|
6
|
+
# `utax`
|
|
7
|
+
|
|
8
|
+
Utility functions for applications in signal processing problems, compatible with the differentiable programming library `JAX`.
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=64"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "utax"
|
|
7
|
+
version = "0.0.2"
|
|
8
|
+
description = "Utility functions for signal processing, compatible with the differentiable programming library JAX."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
license = {text = "MIT"}
|
|
11
|
+
authors = [
|
|
12
|
+
{name = "Aymeric Galan", email = "aymeric.galan@gmail.com"},
|
|
13
|
+
{name = "Austin Peel"},
|
|
14
|
+
{name = "Martin Millon"},
|
|
15
|
+
{name = "Frederic Dux"},
|
|
16
|
+
{name = "Kevin Michalewicz"},
|
|
17
|
+
]
|
|
18
|
+
keywords = ["jax", "signal processing", "wavelet", "convolution", "interpolation"]
|
|
19
|
+
requires-python = ">=3.10"
|
|
20
|
+
dependencies = [
|
|
21
|
+
"jax>=0.5.0",
|
|
22
|
+
"jaxlib>=0.5.0",
|
|
23
|
+
]
|
|
24
|
+
|
|
25
|
+
[project.optional-dependencies]
|
|
26
|
+
dev = [
|
|
27
|
+
"pytest",
|
|
28
|
+
"pytest-cov",
|
|
29
|
+
"pytest-pep8",
|
|
30
|
+
]
|
|
31
|
+
|
|
32
|
+
[project.urls]
|
|
33
|
+
Homepage = "https://github.com/aymgal/utax"
|
|
34
|
+
Repository = "https://github.com/aymgal/utax"
|
|
35
|
+
|
|
36
|
+
[tool.setuptools.packages.find]
|
|
37
|
+
include = ["utax*"]
|
utax-0.0.2/setup.cfg
ADDED
utax-0.0.2/setup.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import setuptools
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
# Distribution name, which is also the directory holding the package sources.
# Deliberately NOT stored in ``__name__``: rebinding that module dunder is
# confusing and was the root cause of a wrong package description.
PACKAGE_NAME = 'utax'

this_directory = os.path.abspath(os.path.dirname(__file__))

# Read the release metadata by executing utax/info.py in an isolated
# namespace, without importing the package itself. ``__name__`` is seeded
# because info.py interpolates it into ``__about__``; without it, exec()
# falls back to the builtins module name ("builtins").
release_info = {'__name__': PACKAGE_NAME}
infopath = os.path.join(this_directory, PACKAGE_NAME, 'info.py')
with open(infopath, encoding='utf-8') as open_file:
    exec(open_file.read(), release_info)

with open(os.path.join(this_directory, 'README.md'), encoding='utf-8') as f:
    long_description = f.read()

setuptools.setup(
    name=PACKAGE_NAME,
    author=release_info['__author__'],
    author_email=release_info['__email__'],
    version=release_info['__version__'],
    url=release_info['__url__'],
    packages=setuptools.find_packages(),
    python_requires=release_info['__python__'],
    install_requires=release_info['__requires__'],
    license=release_info['__license__'],
    # One-line summary. ``__about__`` bundles author/email/year and is not a
    # summary; use the dedicated description string instead.
    description=release_info['__description__'],
    long_description=long_description,
    long_description_content_type='text/markdown',
    setup_requires=release_info['__setup_requires__'],
    tests_require=release_info['__tests_require__']
)
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from importlib.metadata import version, PackageNotFoundError
|
|
2
|
+
|
|
3
|
+
# Package contact details.
__email__ = 'aymeric.galan@gmail.com'
__author__ = ('Aymeric Galan, Austin Peel, Martin Millon, '
              'Frederic Dux, Kevin Michalewicz')

# The version comes from the installed distribution metadata. When the
# package is used from a source tree that was never installed, no such
# metadata exists and a placeholder is used instead.
try:
    __version__ = version("utax")
except PackageNotFoundError:
    __version__ = "unknown"
|
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
import numpy as np
|
|
3
|
+
from scipy.linalg import toeplitz
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
from jax.scipy import stats as jstats
|
|
6
|
+
from jax import jit
|
|
7
|
+
|
|
8
|
+
from utax.convolution import convolve_separable_dilated
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class GaussianFilter(object):
    """JAX-friendly separable Gaussian filter for 2D images."""

    def __init__(self, sigma, truncate=4.0, mode='edge'):
        """Prepare a filter that convolves images with a Gaussian.

        Parameters
        ----------
        sigma : float
            Standard deviation of the Gaussian kernel, in pixels.
        truncate : float, optional
            Truncate the filter at this many standard deviations.
            Default is 4.0.
        mode : str, optional
            Boundary handling, forwarded as ``boundary`` to
            ``convolve_separable_dilated``. Default is 'edge'.

        Note
        ----
        Reproduces `scipy.ndimage.gaussian_filter` with high accuracy.
        When ``sigma <= 0`` no kernel is built (``self.kernel is None``) and
        calling the filter returns the image unchanged.
        """
        self.kernel = None if sigma <= 0 else self.gaussian_kernel(sigma, truncate)
        self.mode = mode

    def gaussian_kernel(self, sigma, truncate):
        """Return a normalized 1D Gaussian kernel of odd length.

        Also stores the kernel half-width in ``self.radius``. The length is
        ``2 * self.radius + 1``.
        """
        # Half-width: ceil(2 * truncate * sigma) halved with integer division.
        self.radius = int(jnp.ceil(2 * truncate * sigma)) // 2

        # A non-positive sigma degenerates to the identity kernel.
        if sigma <= 0:
            return jnp.ones(1)

        # Integer offsets -radius..radius from the kernel centre.
        offsets = jnp.arange(2 * self.radius + 1) - self.radius
        weights = jstats.norm.pdf(offsets / sigma)
        return weights / weights.sum()

    @partial(jit, static_argnums=(0,))
    def __call__(self, image):
        """Jit-compiled convolution of an image by the Gaussian filter.

        Parameters
        ----------
        image : array_like
            Image to filter.
        """
        if self.kernel is None:
            return image
        return convolve_separable_dilated(image, self.kernel, boundary=self.mode)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class BlurringOperator(object):
    """2D convolution with a fixed kernel, expressed as a dense matrix.

    The kernel is turned once, at construction, into a doubly blocked
    Toeplitz matrix (see `toeplitz_matrix`). A convolution and its
    transpose are then plain matrix-vector products.

    Parameters
    ----------
    nx, ny : int
        Shape of the images to be convolved.
    kernel : 2D array
        Convolution kernel (must expose a ``.shape`` attribute).
    """

    def __init__(self, nx, ny, kernel):
        self.target_shape = nx, ny
        self.kernel = kernel
        conv_matrix, self.temp_shape = self.toeplitz_matrix(self.target_shape, self.kernel)
        self.conv_matrix = jnp.array(conv_matrix)
        # Bounds used to crop the output of the full convolution to 'same' size.
        nxk, nyk = self.kernel.shape
        self.i1, self.i2 = self._crop_bounds(nxk)
        self.j1, self.j2 = self._crop_bounds(nyk)

    @staticmethod
    def _crop_bounds(size):
        """Return the (start, end offset) used for 'same' cropping along one axis.

        ``start`` is the number of leading 'full' samples to drop. The end
        offset is non-positive: minus the number of trailing samples to drop,
        which is 0 for a kernel of size 1. `convolve_transpose` uses both
        values as zero-padding amounts.
        """
        if size % 2 == 0:
            return size//2 - 1, -size//2
        return size//2, -size//2 + 1

    @partial(jit, static_argnums=(0, 2))
    def convolve(self, image, out_padding='same'):
        """Convolve `image` with the kernel.

        Parameters
        ----------
        image : 2D array of shape ``(nx, ny)``
            Image to convolve.
        out_padding : {'same', 'full'}, static
            'full' returns the full linear convolution of shape
            ``(nx + nxk - 1, ny + nyk - 1)``. 'same' crops it back to
            ``(nx, ny)``.

        Raises
        ------
        ValueError
            If `out_padding` is not supported.
        """
        image_conv = self.v2m(self.conv_matrix.dot(self.m2v(image)), self.temp_shape)
        if out_padding == 'same':
            # Slice with explicit positive stops. The stored end offsets
            # (self.i2, self.j2) are 0 for a size-1 kernel, and a `[start:0]`
            # slice would silently return an empty array.
            nx, ny = self.target_shape
            return image_conv[self.i1:self.i1 + nx, self.j1:self.j1 + ny]
        elif out_padding == 'full':
            return image_conv
        else:
            raise ValueError(f"padding model '{out_padding}' is not supported.")

    @partial(jit, static_argnums=(0, 2))
    def convolve_transpose(self, image, in_padding='same'):
        """Apply the transpose (adjoint) of `convolve`.

        Parameters
        ----------
        image : 2D array
            Shape ``(nx, ny)`` for ``in_padding='same'``, or the full
            convolution shape for ``in_padding='full'``.
        in_padding : {'same', 'full'}, static
            Padding convention of `image`, matching ``out_padding`` of
            `convolve`.

        Returns
        -------
        2D array of shape ``(nx, ny)``.

        Raises
        ------
        ValueError
            If `in_padding` is not supported.
        """
        if in_padding == 'full':
            image_padded = image
        elif in_padding == 'same':
            # Adjoint of the 'same' crop: zero-pad back to the full shape.
            image_padded = jnp.pad(image, ((self.i1, -self.i2), (self.j1, -self.j2)),
                                   'constant', constant_values=0)
        else:
            raise ValueError(f"padding model '{in_padding}' is not supported.")
        image_conv_t = self.v2m(self.conv_matrix.T.dot(self.m2v(image_padded)), self.target_shape)
        return image_conv_t

    @staticmethod
    def toeplitz_matrix(input_shape, kernel, verbose=False):
        """Build the doubly blocked Toeplitz matrix of a full 2D convolution.

        Multiplying the returned matrix by ``m2v(image)`` and reshaping with
        ``v2m(..., out_shape)`` gives the full convolution of the image with
        `kernel`. Adapted from code by AliSaaalehi@gmail.com.

        Parameters
        ----------
        input_shape : tuple of int
            Shape of the image to be convolved.
        kernel : 2D numpy array
            Convolution kernel.
        verbose : bool, optional
            If True, intermediate results are printed after each step.

        Returns
        -------
        doubly_blocked : 2D numpy array
            Matrix of shape ``(prod(out_shape), prod(input_shape))``.
        out_shape : tuple of int
            Shape of the full convolution output.
        """
        I_row_num, I_col_num = input_shape
        F_row_num, F_col_num = kernel.shape

        # Shape of the full convolution output.
        output_row_num = I_row_num + F_row_num - 1
        output_col_num = I_col_num + F_col_num - 1
        if verbose:
            print('output dimension:', output_row_num, output_col_num)

        # Zero-pad the kernel to the output size (rows padded on top).
        F_zero_padded = np.pad(kernel, ((output_row_num - F_row_num, 0),
                                        (0, output_col_num - F_col_num)),
                               'constant', constant_values=0)
        if verbose:
            print('F_zero_padded: ', F_zero_padded)

        # One Toeplitz block per row of the padded kernel, from the last row to
        # the first. Each block has as many columns as the input. The first row
        # must be given explicitly, otherwise scipy mirrors `c` into it.
        toeplitz_list = []
        for i in range(F_zero_padded.shape[0] - 1, -1, -1):
            c = F_zero_padded[i, :]
            r = np.r_[c[0], np.zeros(I_col_num - 1)]
            toeplitz_m = toeplitz(c, r)
            toeplitz_list.append(toeplitz_m)
            if verbose:
                print('F ' + str(i) + '\n', toeplitz_m)

        # Doubly blocked indices: entry (i, j) is the 1-based index in
        # toeplitz_list of the block placed at block position (i, j).
        c = range(1, F_zero_padded.shape[0] + 1)
        r = np.r_[c[0], np.zeros(I_row_num - 1, dtype=int)]
        doubly_indices = toeplitz(c, r)
        if verbose:
            print('doubly indices \n', doubly_indices)

        # Tile the Toeplitz blocks into the doubly blocked matrix.
        b_h, b_w = toeplitz_list[0].shape
        doubly_blocked = np.zeros((b_h * doubly_indices.shape[0],
                                   b_w * doubly_indices.shape[1]))
        for i in range(doubly_indices.shape[0]):
            for j in range(doubly_indices.shape[1]):
                doubly_blocked[i * b_h:(i + 1) * b_h, j * b_w:(j + 1) * b_w] = \
                    toeplitz_list[doubly_indices[i, j] - 1]

        if verbose:
            print('doubly_blocked: ', doubly_blocked)

        return doubly_blocked, (output_row_num, output_col_num)

    @staticmethod
    def m2v(mat):
        """Flatten a matrix with its rows flipped (vectorization used by the Toeplitz form)."""
        return jnp.flipud(mat).flatten(order='C')

    @staticmethod
    def v2m(vec, output_shape):
        """Inverse of `m2v`: reshape to `output_shape`, then flip the rows back."""
        return jnp.flipud(vec.reshape(output_shape, order='C'))
|
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
import numpy as np
|
|
3
|
+
from scipy import sparse
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
from jax.scipy import stats as jstats
|
|
6
|
+
from jax import jit
|
|
7
|
+
from jax.lax import conv_general_dilated, conv_dimension_numbers
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@partial(jit, static_argnums=(2, 3))
def convolve_separable_dilated(image2D, kernel1D, dilation=1, boundary='edge'):
    """Filter a 2D image with the separable kernel ``kernel1D ⊗ kernel1D``.

    The same 1D kernel is applied first along the second axis (every row
    handled as one batch element), then along the first axis (every column
    handled as one batch element). Note that ``lax.conv_general_dilated``
    computes a cross-correlation, i.e. the kernel is not flipped.

    Parameters
    ----------
    image2D : 2D array
        Image to be filtered.
    kernel1D : 1D array
        Kernel applied along each axis.
    dilation : int, optional (static)
        Spacing between kernel taps, which enlarges its spatial extent.
        The default is 1.
    boundary : str, optional (static)
        ``mode`` passed to ``jnp.pad`` to extend the image before filtering.
        The default is 'edge'.

    Returns
    -------
    2D array
        Filtered image. It has the same shape as `image2D` when `kernel1D`
        has an odd length.
    """
    # Extend the image on every side by half the (dilated) kernel width.
    half = int(kernel1D.size // 2) * dilation
    padded = jnp.pad(image2D, ((half, half), (half, half)), mode=boundary)
    # Fred D.: THIS PADDING IS DANGEROUS AS IT WILL OVERFLOW MEMORY VERY QUICKLY
    # I LEAVE IT AS ORIGINALLY IMPLEMENTED AS I DO NOT WANT TO CHANGE THE
    # OUTPUT OF THE WAVELET TRANSFORM (this could have impact on the science)

    # Add a trailing channel axis to the image: (Nx, Ny) -> (Nx, Ny, 1), and
    # give the kernel the (in, width, out) = (1, K, 1) layout.
    lhs = padded[:, :, None]
    rhs = kernel1D[None, :, None]

    # Pass 1 treats axis 0 as the batch and convolves along axis 1
    # ('NWC'). Pass 2 treats axis 1 as the batch and convolves along
    # axis 0 ('HNC').
    along_rows = conv_dimension_numbers(lhs.shape, rhs.shape, ('NWC', 'IWO', 'NWC'))
    along_cols = conv_dimension_numbers(lhs.shape, rhs.shape, ('HNC', 'IHO', 'HNC'))

    conv1d = partial(conv_general_dilated,
                     window_strides=(1,),
                     padding='VALID',
                     rhs_dilation=(dilation,))
    blurred_rows = conv1d(lhs, rhs, dimension_numbers=along_rows)
    blurred = conv1d(blurred_rows, rhs, dimension_numbers=along_cols)

    # Drop the channel axis again.
    return blurred[:, :, 0]
|
|
84
|
+
|
|
85
|
+
@partial(jit, static_argnums=(2, 3))
def convolve_dilated1D(signal1D, kernel1D, dilation=1, boundary='edge'):
    """Filter a 1D signal with a (possibly dilated) 1D kernel.

    As with ``lax.conv_general_dilated``, this is a cross-correlation: the
    kernel is not flipped.

    Parameters
    ----------
    signal1D : 1D array
        Signal to be filtered.
    kernel1D : 1D array
        Kernel to filter `signal1D` with.
    dilation : int, optional (static)
        Spacing between kernel taps, which enlarges its spatial extent.
        The default is 1.
    boundary : str, optional (static)
        ``mode`` passed to ``jnp.pad`` to extend the signal. Default 'edge'.

    Returns
    -------
    1D array
        Filtered signal. It has the same length as `signal1D` for an
        odd-length kernel.
    """
    # Extend both ends by half the (dilated) kernel width.
    half = int(kernel1D.size // 2) * dilation
    padded = jnp.pad(signal1D, (half, half), mode=boundary)

    # Default lax layout: lhs (batch, channel, width), rhs (out, in, width).
    filtered = conv_general_dilated(
        padded[None, None],
        kernel1D[None, None],
        window_strides=(1,) * kernel1D.ndim,
        padding='VALID',
        rhs_dilation=(dilation,),
    )
    return filtered[0, 0]
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def build_convolution_matrix(psf_kernel_2d, image_shape):
    """
    Build a sparse matrix to convolve an image via matrix-vector product.
    Ported from C++ code in VKL from Vernardos & Koopmans 2022.

    For an odd-sized kernel, ``(B @ image.ravel()).reshape(image_shape)``
    is the convolution of ``image`` with ``psf_kernel_2d``, cropped to the
    input shape ("same" mode), with pixels outside the image treated as zero.

    Note: intended for kernels with an odd number of pixels along each side,
    smaller than the image along that side. The even-size branch might be
    broken in certain cases.

    Parameters
    ----------
    psf_kernel_2d : 2D array_like
        Convolution kernel, shape (n_rows, n_cols).
    image_shape : tuple of int
        Shape (Ni, Nj) of the images to be convolved.

    Returns
    -------
    scipy.sparse.csr_matrix
        Matrix of shape (Ni*Nj, Ni*Nj) acting on row-major flattened images.

    Authors: @gvernard, @aymgal
    """
    Ni, Nj = image_shape
    # Kernel rows follow the image's first axis (index i, "y"). Kernel columns
    # follow its second axis (index j, "x"). Ncropx is therefore also the row
    # stride of the flattened kernel used below.
    Ncropy, Ncropx = np.shape(psf_kernel_2d)

    def setCroppedLimitsEven(k, Ncrop, Nimg, Nquad):
        if k < (Nquad - 1):
            Npre = k
            Npost = Nquad
            offset = Nquad - k
        elif k > (Nimg - Nquad - 1):
            Npre = Nquad
            Npost = Nimg - k
            offset = 0
        else:
            Npre = Nquad
            Npost = Nquad
            offset = 0
        return Npre, Npost, offset

    def setCroppedLimitsOdd(k, Ncrop, Nimg, Nquad):
        # Near the start, the kernel is cropped on the low side and its first
        # used pixel is shifted by `offset`. Near the end, it is cropped on the
        # high side. Otherwise the full kernel fits.
        if k < (Nquad - 1):
            Npre = k
            Npost = Nquad
            offset = Nquad - 1 - k
        elif k > (Nimg - Nquad - 1):
            Npre = Nquad - 1
            Npost = Nimg - k
            offset = 0
        else:
            Npre = Nquad - 1
            Npost = Nquad
            offset = 0
        return Npre, Npost, offset

    def _limits_for(Ncrop):
        """Pick the cropping helper and half-size for one kernel axis."""
        if Ncrop % 2 == 0:
            # Warning: this might be broken in certain cases
            return setCroppedLimitsEven, Ncrop // 2
        return setCroppedLimitsOdd, int(np.ceil(Ncrop / 2.))

    # Each axis is handled according to its OWN kernel size.
    func_limits_x, Nquadx = _limits_for(Ncropx)
    func_limits_y, Nquady = _limits_for(Ncropy)

    # Create the blurring matrix in coordinate (COO-style) form.
    blur = np.asarray(psf_kernel_2d).ravel()
    sparse_B_rows, sparse_B_cols = [], []
    sparse_B_values = []
    for i in range(Ni):  # loop over image rows
        for j in range(Nj):  # loop over image columns

            Nleft, Nright, crop_offsetx = func_limits_x(j, Ncropx, Nj, Nquadx)
            Ntop, Nbottom, crop_offsety = func_limits_y(i, Ncropy, Ni, Nquady)

            crop_offset = crop_offsety * Ncropx + crop_offsetx

            for ii in range(i - Ntop, i + Nbottom):  # loop over PSF rows
                ic = ii - i + Ntop

                for jj in range(j - Nleft, j + Nright):  # loop over PSF columns
                    jc = jj - j + Nleft

                    val = blur[crop_offset + ic * Ncropx + jc]

                    # save entries
                    # (note: rows and cols were inverted from the VKL code)
                    sparse_B_rows.append(ii * Nj + jj)
                    sparse_B_cols.append(i * Nj + j)
                    sparse_B_values.append(val)

    # The operator maps an image to an image, so it is square with one
    # row/column per pixel (Ni*Nj). Ni**2 x Nj**2 is only right when Ni == Nj.
    npix = Ni * Nj
    blurring_matrix = sparse.csr_matrix((sparse_B_values, (sparse_B_rows, sparse_B_cols)),
                                        shape=(npix, npix))
    return blurring_matrix
|
utax-0.0.2/utax/info.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""PACKAGE INFO
|
|
2
|
+
|
|
3
|
+
This module provides some basic information about the package.
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
# Set the package release version
|
|
8
|
+
version_info = (0, 0, 2)
|
|
9
|
+
__version__ = '.'.join(str(c) for c in version_info)
|
|
10
|
+
|
|
11
|
+
# Set the package details
|
|
12
|
+
__author__ = 'Aymeric Galan, Austin Peel, Martin Millon, Frederic Dux, Kevin Michalewicz'
|
|
13
|
+
__email__ = 'aymeric.galan@gmail.com'
|
|
14
|
+
__year__ = '2022'
|
|
15
|
+
__url__ = 'https://github.com/aymgal/utax'
|
|
16
|
+
__description__ = 'Utility functions for signal processing, compatible with the differentable programming library JAX.'
|
|
17
|
+
__python__ = '>=3.10'
|
|
18
|
+
__requires__ = [
|
|
19
|
+
'jax>=0.5.0',
|
|
20
|
+
'jaxlib>=0.5.0',
|
|
21
|
+
] # Package dependencies
|
|
22
|
+
|
|
23
|
+
# Default package properties
|
|
24
|
+
__license__ = 'MIT'
|
|
25
|
+
__about__ = ('{} Author: {}, Email: {}, Year: {}, {}'
|
|
26
|
+
''.format(__name__, __author__, __email__, __year__,
|
|
27
|
+
__description__))
|
|
28
|
+
__setup_requires__ = ['pytest-runner', ]
|
|
29
|
+
__tests_require__ = ['pytest', 'pytest-cov', 'pytest-pep8']
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
from jax import lax, vmap #, jit
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class BilinearInterpolator(object):
    """Bilinear interpolation of a 2D field.

    Functionality is modelled after scipy.interpolate.RectBivariateSpline
    when `kx` and `ky` are both equal to 1. Results match the scipy version when
    interpolated values lie within the x and y domain (boundaries included).
    Returned values can be significantly different outside the natural domain,
    as the scipy version does not extrapolate. Evaluation of this jax version
    is MUCH SLOWER as well.

    Parameters
    ----------
    x, y : 1D array_like
        Grid coordinates along the first and second axis of `z`. They need
        not be sorted: both axes are sorted internally and `z` is permuted
        accordingly.
    z : 2D array_like, shape (len(x), len(y))
        Field values on the grid.
    allow_extrapolation : bool, optional
        If False, evaluations at points outside the grid's bounding box
        return 0. Default is True.
    """
    def __init__(self, x, y, z, allow_extrapolation=True):
        x = jnp.asarray(x)
        y = jnp.asarray(y)
        # Sort each axis with a stable argsort and apply the SAME permutation
        # to z. An already increasing axis is left untouched, and unlike a
        # plain flip this is also correct for non-monotonic coordinates.
        # Works on traced arrays too (no Python-level branching).
        x_order = jnp.argsort(x)
        y_order = jnp.argsort(y)
        self.x = x[x_order]
        self.y = y[y_order]
        self.z = jnp.asarray(z)[x_order][:, y_order]
        self._extrapol_bool = allow_extrapolation

    def __call__(self, x, y, dx=0, dy=0):
        """Vectorized evaluation of the interpolation or its derivatives.

        Parameters
        ----------
        x, y : array_like
            Position(s) at which to evaluate the interpolation.
        dx, dy : int, either 0 or 1
            If 1, return the first partial derivative of the interpolation
            with respect to that coordinate. Only one of (dx, dy) should be
            nonzero at a time. If both are 1, dx takes precedence.

        """
        x = jnp.atleast_1d(x)
        y = jnp.atleast_1d(y)

        error_msg_type = "dx and dy must be integers"
        error_msg_value = "dx and dy must only be either 0 or 1"
        assert isinstance(dx, int) and isinstance(dy, int), error_msg_type
        assert dx in (0, 1) and dy in (0, 1), error_msg_value
        if dx == 1: dy = 0

        return vmap(self._evaluate, in_axes=(0, 0, None, None))(x, y, dx, dy)

    def _compute_coeffs(self, x, y):
        """Return the bilinear coefficients (a0, a1, a2, a3) of the cell containing (x, y).

        Within that cell the interpolant is ``a0 + a1*x + a2*y + a3*x*y``.
        """
        # Find the cell that the point (x, y) falls in. Clipping makes
        # out-of-range points use the nearest boundary cell (extrapolation).
        x_ind = jnp.searchsorted(self.x, x, side='right') - 1
        x_ind = jnp.clip(x_ind, a_min=0, a_max=(len(self.x) - 2))
        y_ind = jnp.searchsorted(self.y, y, side='right') - 1
        y_ind = jnp.clip(y_ind, a_min=0, a_max=(len(self.y) - 2))

        # Determine the coordinates and dimensions of this cell
        x1 = self.x[x_ind]
        x2 = self.x[x_ind + 1]
        y1 = self.y[y_ind]
        y2 = self.y[y_ind + 1]
        area = (x2 - x1) * (y2 - y1)

        # Function values at the four corners
        v11 = self.z[x_ind, y_ind]
        v12 = self.z[x_ind, y_ind + 1]
        v21 = self.z[x_ind + 1, y_ind]
        v22 = self.z[x_ind + 1, y_ind + 1]

        # Expand the standard bilinear form into polynomial coefficients
        a0_ = v11 * x2 * y2 - v12 * x2 * y1 - v21 * x1 * y2 + v22 * x1 * y1
        a1_ = -v11 * y2 + v12 * y1 + v21 * y2 - v22 * y1
        a2_ = -v11 * x2 + v12 * x2 + v21 * x1 - v22 * x1
        a3_ = v11 - v12 - v21 + v22

        return a0_ / area, a1_ / area, a2_ / area, a3_ / area

    def _evaluate(self, x, y, dx=0, dy=0):
        """Single-point evaluation of the interpolation."""
        a0, a1, a2, a3 = self._compute_coeffs(x, y)
        if (dx, dy) == (0, 0):
            result = a0 + a1 * x + a2 * y + a3 * x * y
        elif (dx, dy) == (1, 0):
            result = a1 + a3 * y
        else:
            result = a2 + a3 * x
        # if extrapolation is not allowed, then we mask out values outside the original bounding box
        result = lax.cond(self._extrapol_bool,
                          lambda _: result,
                          lambda _: result * (x >= self.x[0]) * (x <= self.x[-1]) * (y >= self.y[0]) * (y <= self.y[-1]),
                          operand=None)
        return result
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class BicubicInterpolator(object):
|
|
122
|
+
"""Bicubic interpolation of a 2D field.
|
|
123
|
+
|
|
124
|
+
Functionality is modelled after scipy.interpolate.RectBivariateSpline
|
|
125
|
+
when `kx` and `ky` are both equal to 3.
|
|
126
|
+
|
|
127
|
+
"""
|
|
128
|
+
def __init__(self, x, y, z, zx=None, zy=None, zxy=None, allow_extrapolation=True):
|
|
129
|
+
self.z = jnp.array(z)
|
|
130
|
+
if np.all(np.diff(x) >= 0): # check if sorted in increasing order
|
|
131
|
+
self.x = jnp.array(x)
|
|
132
|
+
else:
|
|
133
|
+
self.x = jnp.array(np.sort(x))
|
|
134
|
+
self.z = jnp.flip(self.z, axis=1)
|
|
135
|
+
if np.all(np.diff(y) >= 0): # check if sorted in increasing order
|
|
136
|
+
self.y = jnp.array(y)
|
|
137
|
+
else:
|
|
138
|
+
self.y = jnp.array(np.sort(y))
|
|
139
|
+
self.z = jnp.flip(self.z, axis=0)
|
|
140
|
+
|
|
141
|
+
# Assume uniform coordinate spacing
|
|
142
|
+
self.dx = self.x[1] - self.x[0]
|
|
143
|
+
self.dy = self.y[1] - self.y[0]
|
|
144
|
+
|
|
145
|
+
# Compute approximate partial derivatives if not provided
|
|
146
|
+
if zx is None:
|
|
147
|
+
self.zx = jnp.gradient(z, axis=0) / self.dx
|
|
148
|
+
else:
|
|
149
|
+
self.zx = zy
|
|
150
|
+
if zy is None:
|
|
151
|
+
self.zy = jnp.gradient(z, axis=1) / self.dy
|
|
152
|
+
else:
|
|
153
|
+
self.zy = zx
|
|
154
|
+
if zxy is None:
|
|
155
|
+
self.zxy = jnp.gradient(self.zx, axis=1) / self.dy
|
|
156
|
+
else:
|
|
157
|
+
self.zxy = zxy
|
|
158
|
+
|
|
159
|
+
# Prepare coefficients for function evaluations
|
|
160
|
+
self._A = jnp.array([[1., 0., 0., 0.],
|
|
161
|
+
[0., 0., 1., 0.],
|
|
162
|
+
[-3., 3., -2., -1.],
|
|
163
|
+
[2., -2., 1., 1.]])
|
|
164
|
+
self._B = jnp.array([[1., 0., -3., 2.],
|
|
165
|
+
[0., 0., 3., -2.],
|
|
166
|
+
[0., 1., -2., 1.],
|
|
167
|
+
[0., 0., -1., 1.]])
|
|
168
|
+
row0 = [self.z[:-1,:-1], self.z[:-1,1:], self.dy * self.zy[:-1,:-1], self.dy * self.zy[:-1,1:]]
|
|
169
|
+
row1 = [self.z[1:,:-1], self.z[1:,1:], self.dy * self.zy[1:,:-1], self.dy * self.zy[1:,1:]]
|
|
170
|
+
row2 = self.dx * jnp.array([self.zx[:-1,:-1], self.zx[:-1,1:],
|
|
171
|
+
self.dy * self.zxy[:-1,:-1], self.dy * self.zxy[:-1,1:]])
|
|
172
|
+
row3 = self.dx * jnp.array([self.zx[1:,:-1], self.zx[1:,1:],
|
|
173
|
+
self.dy * self.zxy[1:,:-1], self.dy * self.zxy[1:,1:]])
|
|
174
|
+
self._m = jnp.array([row0, row1, row2, row3])
|
|
175
|
+
|
|
176
|
+
self._m = jnp.transpose(self._m, axes=(2, 3, 0, 1))
|
|
177
|
+
|
|
178
|
+
self._extrapol_bool = allow_extrapolation
|
|
179
|
+
|
|
180
|
+
def __call__(self, x, y, dx=0, dy=0):
    """Vectorized evaluation of the interpolation or its derivatives.

    Parameters
    ----------
    x, y : array_like
        Position(s) at which to evaluate the interpolation. Scalars are
        promoted to 1D arrays, and x and y are broadcast against each other,
        so inputs of any dimensionality (1D, 2D, 3D, ...) are accepted.
    dx, dy : int, either 0, 1, or 2
        Return the nth partial derivative of the interpolation
        with respect to the specified coordinate. Only one of (dx, dy)
        should be nonzero at a time.

    Returns
    -------
    Array with the broadcast shape of ``jnp.atleast_1d(x)`` and
    ``jnp.atleast_1d(y)``, holding one value from ``self._evaluate`` per point.

    """
    x, y = jnp.broadcast_arrays(jnp.atleast_1d(x), jnp.atleast_1d(y))
    # A single vmap over the flattened coordinates handles every input rank
    # (the previous ndim-specific branches left inputs with ndim > 2 unhandled);
    # the original shape is restored afterwards.
    evaluate = vmap(self._evaluate, in_axes=(0, 0, None, None))
    values = evaluate(jnp.ravel(x), jnp.ravel(y), dx, dy)
    return jnp.reshape(values, x.shape)
|
|
201
|
+
|
|
202
|
+
def _evaluate(self, x, y, dx=0, dy=0):
    """Evaluate the interpolation (or one of its derivatives) at a single point.

    Parameters
    ----------
    x, y : scalar
        Coordinates of the point.
    dx, dy : int, either 0, 1, or 2
        Order of the partial derivative with respect to x and y.

    Returns
    -------
    Scalar value of the bicubic polynomial of the grid cell containing
    (x, y). Points outside the grid are evaluated with the polynomial of the
    nearest boundary cell; when extrapolation is disabled
    (``self._extrapol_bool`` false) they are set to 0.

    Raises
    ------
    ValueError
        If dx or dy is not 0, 1 or 2.
    """
    if dx not in (0, 1, 2) or dy not in (0, 1, 2):
        raise ValueError(
            f"Derivative orders must be 0, 1 or 2, got dx={dx} and dy={dy}.")

    # Determine which cell (i, j) the point (x, y) falls in. Cell i spans
    # [self.x[i], self.x[i + 1]], so valid indices are 0 .. len(self.x) - 2.
    # Clipping on both sides keeps the cell origin self.x[i] consistent with
    # the coefficient block self._m[i, j] for points outside the grid.
    # NOTE(review): assumes self._m has shape (len(x) - 1, len(y) - 1, 4, 4).
    i = jnp.clip(jnp.searchsorted(self.x, x) - 1, 0, self.x.shape[0] - 2)
    j = jnp.clip(jnp.searchsorted(self.y, y) - 1, 0, self.y.shape[0] - 2)

    # Local cell coordinates: within [0, 1] on the grid, outside when extrapolating
    u = (x - self.x[i]) / self.dx
    v = (y - self.y[j]) / self.dy

    # Polynomial coefficients of the cell: a[p, q] multiplies u**p * v**q
    a = jnp.dot(self._A, jnp.dot(self._m[i, j], self._B))

    def basis(t, order, step):
        # Powers of t (or their derivatives) with the chain-rule factor 1 / step**order
        if order == 0:
            return jnp.asarray([1., t, t**2, t**3])
        if order == 1:
            return jnp.asarray([0., 1., 2. * t, 3. * t**2]) / step
        return jnp.asarray([0., 0., 2., 6. * t]) / step**2

    uu = basis(u, dx, self.dx)
    vv = basis(v, dy, self.dy)
    result = jnp.dot(uu, jnp.dot(a, vv))

    # If extrapolation is not allowed, mask out values outside the original bounding box
    inside = ((x >= self.x[0]) & (x <= self.x[-1])
              & (y >= self.y[0]) & (y <= self.y[-1]))
    return lax.cond(self._extrapol_bool,
                    lambda _: result,
                    lambda _: jnp.where(inside, result, 0.),
                    operand=None)
|
|
235
|
+
|
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
from jax import jit
|
|
4
|
+
|
|
5
|
+
from utax.convolution import convolve_separable_dilated, convolve_dilated1D
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class WaveletTransform(object):
    """
    Wavelet transform implemented with JAX, using the 'a trous' algorithm.

    At scale j the smoothing filter is dilated by 2**j (zeros inserted
    between its taps), independently of the filter length.

    Parameters
    ----------
    nscales : int
        Number of wavelet scales (non-negative). The decomposition returns
        ``nscales`` detail levels followed by the final coarse level, i.e.
        ``nscales + 1`` levels stacked along a new leading axis. With
        ``nscales == 0`` the input is returned unchanged.
    wavelet_type : str
        Smoothing filter. Supported types are 'starlet', 'battle-lemarie-1',
        'battle-lemarie-3' and 'battle-lemarie-5'.
    second_gen : bool
        If True, use the second-generation variant: detail coefficients are
        computed against a twice-smoothed signal and the reconstruction
        re-applies the smoothing filter. Not available for ``dim=1``.
    dim : int
        Dimensionality of the convolution (2 for image, 1 for 1D vector).

    Raises
    ------
    ValueError
        If ``nscales`` is negative, or ``wavelet_type`` or ``dim`` is unsupported.
    NotImplementedError
        If ``second_gen`` is requested with ``dim=1``.

    """

    # Unnormalised filter taps; they are rescaled to unit sum in __init__.
    _FILTER_TAPS = {
        'starlet': (1., 4., 6., 4., 1.),
        # order 1 (2 vanishing moments)
        'battle-lemarie-1': (
            -0.000122686, -0.000224296, 0.000511636,
            0.000923371, -0.002201945, -0.003883261, 0.009990599,
            0.016974805, -0.051945337, -0.06910102, 0.39729643,
            0.817645956, 0.39729643, -0.06910102, -0.051945337,
            0.016974805, 0.009990599, -0.003883261, -0.002201945,
            0.000923371, 0.000511636, -0.000224296, -0.000122686,
        ),
        # order 3 (4 vanishing moments)
        'battle-lemarie-3': (
            0.000146098, -0.000232304, -0.000285414,
            0.000462093, 0.000559952, -0.000927187, -0.001103748,
            0.00188212, 0.002186714, -0.003882426, -0.00435384,
            0.008201477, 0.008685294, -0.017982291, -0.017176331,
            0.042068328, 0.032080869, -0.110036987, -0.050201753,
            0.433923147, 0.766130398, 0.433923147, -0.050201753,
            -0.110036987, 0.032080869, 0.042068328, -0.017176331,
            -0.017982291, 0.008685294, 0.008201477, -0.00435384,
            -0.003882426, 0.002186714, 0.00188212, -0.001103748,
            -0.000927187, 0.000559952, 0.000462093, -0.000285414,
            -0.000232304, 0.000146098,
        ),
        # order 5 (6 vanishing moments)
        'battle-lemarie-5': (
            1.4299532e-004, 1.5656611e-004, -2.2509746e-004, -2.4421337e-004,
            3.5556002e-004, 3.8161407e-004, -5.6393016e-004, -5.9748378e-004,
            8.9882146e-004, 9.3739129e-004, -1.4412528e-003, -1.4737089e-003,
            2.3286290e-003, 2.3211761e-003, -3.7992267e-003, -3.6602095e-003,
            6.2791340e-003, 5.7683203e-003, -1.0562022e-002, -9.0493510e-003,
            1.8208558e-002, 1.4009689e-002, -3.2519969e-002, -2.1006296e-002,
            6.1312356e-002, 2.9474179e-002, -1.2926869e-001, -3.7019995e-002,
            4.4246341e-001,
            7.4723338e-001,
            4.4246341e-001,
            -3.7019995e-002, -1.2926869e-001, 2.9474179e-002, 6.1312356e-002,
            -2.1006296e-002, -3.2519969e-002, 1.4009689e-002, 1.8208558e-002,
            -9.0493510e-003, -1.0562022e-002, 5.7683203e-003, 6.2791340e-003,
            -3.6602095e-003, -3.7992267e-003, 2.3211761e-003, 2.3286290e-003,
            -1.4737089e-003, -1.4412528e-003, 9.3739129e-004, 8.9882146e-004,
            -5.9748378e-004, -5.6393016e-004, 3.8161407e-004, 3.5556002e-004,
            -2.4421337e-004, -2.2509746e-004, 1.5656611e-004, 1.4299532e-004,
        ),
    }

    def __init__(self, nscales, wavelet_type='starlet', second_gen=False, dim=2):
        if nscales < 0:
            raise ValueError("nscales must be a non-negative integer")
        self._n_scales = nscales
        self._second_gen = second_gen
        self._norms = None  # cache for scale_norms

        if wavelet_type not in self._FILTER_TAPS:
            raise ValueError(f"'{wavelet_type}' wavelet transform is not supported")
        taps = jnp.array(self._FILTER_TAPS[wavelet_type])
        # Normalise the taps to unit sum
        self._h = taps / jnp.sum(taps)

        if self._second_gen:
            self.decompose = self._decompose_2nd_gen
            self.reconstruct = self._reconstruct_2nd_gen
        else:
            self.decompose = self._decompose_1st_gen
            self.reconstruct = self._reconstruct_1st_gen

        self._dim = dim
        if self._dim == 1:
            self.convolve = convolve_dilated1D
            if self._second_gen:
                raise NotImplementedError('2nd generation of wavelet not yet implemented for 1D decomposition.')
        elif self._dim == 2:
            self.convolve = convolve_separable_dilated
        else:
            raise ValueError(f"Dimensionality {dim} not supported.")

    @property
    def scale_norms(self):
        """L2 norm of each decomposition level of a centred unit impulse.

        Computed on first access from a Dirac signal of ``2**(nscales + 2)``
        pixels per axis, then cached. Returns one value per level.
        """
        if self._norms is None:
            npix_dirac = 2**(self._n_scales + 2)
            dirac = (jnp.arange(npix_dirac) == npix_dirac // 2).astype(float)
            if self._dim == 2:
                # single non-zero pixel at the centre of a square image
                dirac = jnp.diag(dirac)
            wt_dirac = self.decompose(dirac)
            if wt_dirac.ndim == self._dim:
                # nscales == 0: decompose returns its input without a level axis
                wt_dirac = jnp.expand_dims(wt_dirac, 0)
            pixel_axes = tuple(range(1, wt_dirac.ndim))
            self._norms = jnp.sqrt(jnp.sum(wt_dirac**2, axis=pixel_axes))
        return self._norms

    def _smooth(self, signal, scale):
        """Convolve ``signal`` with the filter dilated by ``2**scale`` (a trous)."""
        if scale == 0:
            return self.convolve(signal, self._h)
        return self.convolve(signal, self._h, dilation=2**scale)

    def _atrous_decompose(self, image, second_gen):
        """Shared body of the 1st- and 2nd-generation decompositions.

        Returns an array of shape ``(nscales + 1, *image.shape)``: the detail
        levels for scales 0 .. nscales-1 followed by the coarse level, or
        ``image`` itself when ``nscales == 0``.
        """
        if self._n_scales == 0:
            return image
        levels = []
        cj = image
        for scale in range(self._n_scales):
            cj1 = self._smooth(cj, scale)
            # 2nd generation: details are taken against the twice-smoothed signal
            reference = self._smooth(cj1, scale) if second_gen else cj1
            levels.append(cj - reference)
            cj = cj1
        # Append final coarse scale
        levels.append(cj)
        return jnp.stack(levels, axis=0)

    @partial(jit, static_argnums=(0,))
    def _decompose_1st_gen(self, image):
        """Decompose an image into the chosen wavelet basis (1st generation)."""
        return self._atrous_decompose(image, second_gen=False)

    @partial(jit, static_argnums=(0,))
    def _decompose_2nd_gen(self, image):
        """Decompose an image into the chosen wavelet basis (2nd generation)."""
        return self._atrous_decompose(image, second_gen=True)

    @partial(jit, static_argnums=(0,))
    def _reconstruct_1st_gen(self, coeffs):
        """Invert the 1st-generation decomposition by summing all levels."""
        if self._n_scales == 0 and coeffs.ndim == self._dim:
            # with no scales, the decomposition is the input itself
            return coeffs
        return jnp.sum(coeffs, axis=0)

    @partial(jit, static_argnums=(0,))
    def _reconstruct_2nd_gen(self, coeffs):
        """Invert the 2nd-generation decomposition.

        Starting from the coarse level, each finer smooth level is recovered as
        ``c_j = smooth(c_{j+1}, j) + w_j``.

        Raises
        ------
        ValueError
            If the number of levels in ``coeffs`` is not ``nscales + 1``.
        """
        if self._n_scales == 0 and coeffs.ndim == self._dim:
            # with no scales, the decomposition is the input itself
            return coeffs
        if coeffs.shape[-3] != self._n_scales + 1:
            raise ValueError("Wavelet coefficients are not consistent with number of scales")
        if self._n_scales == 0:
            return coeffs[0, :, :]

        cj = coeffs[self._n_scales]
        for scale in range(self._n_scales - 1, -1, -1):
            cj = self._smooth(cj, scale) + coeffs[scale]
        return cj
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: utax
|
|
3
|
+
Version: 0.0.2
|
|
4
|
+
Summary: Utility functions for signal processing, compatible with the differentiable programming library JAX.
|
|
5
|
+
Home-page: https://github.com/aymgal/utax
|
|
6
|
+
Author: Austin Peel, Martin Millon, Frederic Dux, Kevin Michalewicz
|
|
7
|
+
Author-email: Aymeric Galan <aymeric.galan@gmail.com>
|
|
8
|
+
License: MIT
|
|
9
|
+
Project-URL: Homepage, https://github.com/aymgal/utax
|
|
10
|
+
Project-URL: Repository, https://github.com/aymgal/utax
|
|
11
|
+
Keywords: jax,signal processing,wavelet,convolution,interpolation
|
|
12
|
+
Requires-Python: >=3.10
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
License-File: LICENSE
|
|
15
|
+
Requires-Dist: jax>=0.5.0
|
|
16
|
+
Requires-Dist: jaxlib>=0.5.0
|
|
17
|
+
Provides-Extra: dev
|
|
18
|
+
Requires-Dist: pytest; extra == "dev"
|
|
19
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
|
20
|
+
Requires-Dist: pytest-pep8; extra == "dev"
|
|
21
|
+
Dynamic: home-page
|
|
22
|
+
Dynamic: license-file
|
|
23
|
+
Dynamic: requires-python
|
|
24
|
+
|
|
25
|
+

|
|
26
|
+

|
|
27
|
+
[](https://github.com/aymgal/utax/actions/workflows/ci_tests.yml)
|
|
28
|
+
[](https://coveralls.io/github/aymgal/utax?branch=main)
|
|
29
|
+
|
|
30
|
+
# `utax`
|
|
31
|
+
|
|
32
|
+
Utility functions for applications in signal processing problems, compatible with the differentiable programming library `JAX`.
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
setup.py
|
|
5
|
+
utax/__init__.py
|
|
6
|
+
utax/info.py
|
|
7
|
+
utax/interpolation.py
|
|
8
|
+
utax/wavelet.py
|
|
9
|
+
utax.egg-info/PKG-INFO
|
|
10
|
+
utax.egg-info/SOURCES.txt
|
|
11
|
+
utax.egg-info/dependency_links.txt
|
|
12
|
+
utax.egg-info/requires.txt
|
|
13
|
+
utax.egg-info/top_level.txt
|
|
14
|
+
utax/convolution/__init__.py
|
|
15
|
+
utax/convolution/classes.py
|
|
16
|
+
utax/convolution/functions.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
utax
|