utax 0.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
utax-0.0.2/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Aymeric Galan
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
utax-0.0.2/PKG-INFO ADDED
@@ -0,0 +1,32 @@
1
+ Metadata-Version: 2.4
2
+ Name: utax
3
+ Version: 0.0.2
4
+ Summary: Utility functions for signal processing, compatible with the differentiable programming library JAX.
5
+ Home-page: https://github.com/aymgal/utax
6
+ Author: Austin Peel, Martin Millon, Frederic Dux, Kevin Michalewicz
7
+ Author-email: Aymeric Galan <aymeric.galan@gmail.com>
8
+ License: MIT
9
+ Project-URL: Homepage, https://github.com/aymgal/utax
10
+ Project-URL: Repository, https://github.com/aymgal/utax
11
+ Keywords: jax,signal processing,wavelet,convolution,interpolation
12
+ Requires-Python: >=3.10
13
+ Description-Content-Type: text/markdown
14
+ License-File: LICENSE
15
+ Requires-Dist: jax>=0.5.0
16
+ Requires-Dist: jaxlib>=0.5.0
17
+ Provides-Extra: dev
18
+ Requires-Dist: pytest; extra == "dev"
19
+ Requires-Dist: pytest-cov; extra == "dev"
20
+ Requires-Dist: pytest-pep8; extra == "dev"
21
+ Dynamic: home-page
22
+ Dynamic: license-file
23
+ Dynamic: requires-python
24
+
25
+ ![License](https://img.shields.io/github/license/aymgal/utax)
26
+ ![PyPi python support](https://img.shields.io/badge/Python-3.10-blue)
27
+ [![Tests](https://github.com/aymgal/utax/actions/workflows/ci_tests.yml/badge.svg?branch=main)](https://github.com/aymgal/utax/actions/workflows/ci_tests.yml)
28
+ [![Coverage Status](https://coveralls.io/repos/github/aymgal/utax/badge.svg?branch=main)](https://coveralls.io/github/aymgal/utax?branch=main)
29
+
30
+ # `utax`
31
+
32
+ Utility functions for applications in signal processing problems, compatible with the differentiable programming library `JAX`.
utax-0.0.2/README.md ADDED
@@ -0,0 +1,8 @@
1
+ ![License](https://img.shields.io/github/license/aymgal/utax)
2
+ ![PyPi python support](https://img.shields.io/badge/Python-3.10-blue)
3
+ [![Tests](https://github.com/aymgal/utax/actions/workflows/ci_tests.yml/badge.svg?branch=main)](https://github.com/aymgal/utax/actions/workflows/ci_tests.yml)
4
+ [![Coverage Status](https://coveralls.io/repos/github/aymgal/utax/badge.svg?branch=main)](https://coveralls.io/github/aymgal/utax?branch=main)
5
+
6
+ # `utax`
7
+
8
+ Utility functions for applications in signal processing problems, compatible with the differentiable programming library `JAX`.
@@ -0,0 +1,37 @@
1
+ [build-system]
2
+ requires = ["setuptools>=64"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "utax"
7
+ version = "0.0.2"
8
+ description = "Utility functions for signal processing, compatible with the differentiable programming library JAX."
9
+ readme = "README.md"
10
+ license = {text = "MIT"}
11
+ authors = [
12
+ {name = "Aymeric Galan", email = "aymeric.galan@gmail.com"},
13
+ {name = "Austin Peel"},
14
+ {name = "Martin Millon"},
15
+ {name = "Frederic Dux"},
16
+ {name = "Kevin Michalewicz"},
17
+ ]
18
+ keywords = ["jax", "signal processing", "wavelet", "convolution", "interpolation"]
19
+ requires-python = ">=3.10"
20
+ dependencies = [
21
+ "jax>=0.5.0",
22
+ "jaxlib>=0.5.0",
23
+ ]
24
+
25
+ [project.optional-dependencies]
26
+ dev = [
27
+ "pytest",
28
+ "pytest-cov",
29
+ "pytest-pep8",
30
+ ]
31
+
32
+ [project.urls]
33
+ Homepage = "https://github.com/aymgal/utax"
34
+ Repository = "https://github.com/aymgal/utax"
35
+
36
+ [tool.setuptools.packages.find]
37
+ include = ["utax*"]
utax-0.0.2/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
utax-0.0.2/setup.py ADDED
@@ -0,0 +1,31 @@
1
import os

import setuptools

# Distribution name; also the directory that contains info.py.
# NOTE: the original assigned this to `__name__`, shadowing the module's own
# name (and breaking the `if __name__ == '__main__'` idiom); use a constant.
PACKAGE_NAME = 'utax'

# Load release metadata (__author__, __version__, ...) from <package>/info.py
# without importing the package, which would require its runtime dependencies.
release_info = {}
infopath = os.path.abspath(os.path.join(os.path.dirname(__file__),
                                        PACKAGE_NAME, 'info.py'))
with open(infopath) as open_file:
    exec(open_file.read(), release_info)

# Use the README as the long description shown on PyPI.
this_directory = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(this_directory, 'README.md'), encoding='utf-8') as f:
    long_description = f.read()

setuptools.setup(
    name=PACKAGE_NAME,
    author=release_info['__author__'],
    author_email=release_info['__email__'],
    version=release_info['__version__'],
    url=release_info['__url__'],
    packages=setuptools.find_packages(),
    python_requires=release_info['__python__'],
    install_requires=release_info['__requires__'],
    license=release_info['__license__'],
    description=release_info['__about__'],
    long_description=long_description,
    long_description_content_type='text/markdown',
    setup_requires=release_info['__setup_requires__'],
    tests_require=release_info['__tests_require__']
)
@@ -0,0 +1,9 @@
1
"""Package metadata: author contact details and the installed version."""
from importlib.metadata import PackageNotFoundError, version

# Package authors and contact address.
__author__ = 'Aymeric Galan, Austin Peel, Martin Millon, Frederic Dux, Kevin Michalewicz'
__email__ = 'aymeric.galan@gmail.com'

# Resolve the installed distribution version; fall back to a placeholder
# when the package is not installed (e.g. running from a source checkout).
try:
    __version__ = version("utax")
except PackageNotFoundError:
    __version__ = "unknown"
@@ -0,0 +1,5 @@
1
+ from .functions import (convolve_separable_dilated,
2
+ build_convolution_matrix,
3
+ convolve_dilated1D)
4
+ from .classes import (GaussianFilter,
5
+ BlurringOperator)
@@ -0,0 +1,195 @@
1
+ from functools import partial
2
+ import numpy as np
3
+ from scipy.linalg import toeplitz
4
+ import jax.numpy as jnp
5
+ from jax.scipy import stats as jstats
6
+ from jax import jit
7
+
8
+ from utax.convolution import convolve_separable_dilated
9
+
10
+
11
class GaussianFilter(object):
    """JAX-friendly Gaussian filter."""

    def __init__(self, sigma, truncate=4.0, mode='edge'):
        """Convolve an image by a gaussian filter.

        Parameters
        ----------
        sigma : float
            Standard deviation of the Gaussian kernel.
        truncate : float, optional
            Truncate the filter at this many standard deviations.
            Default is 4.0.
        mode : str, optional
            Padding mode forwarded to the separable convolution.
            Default is 'edge'.

        Note
        ----
        Reproduces `scipy.ndimage.gaussian_filter` with high accuracy.

        """
        # A non-positive sigma means "no blurring": __call__ becomes a no-op.
        self.kernel = None if sigma <= 0 else self.gaussian_kernel(sigma, truncate)
        self.mode = mode

    def gaussian_kernel(self, sigma, truncate):
        """Return a normalized 1D Gaussian kernel truncated at `truncate` sigmas."""
        # Kernel half-width in pixels, chosen so the full size is odd.
        self.radius = int(jnp.ceil(2 * truncate * sigma)) // 2
        npix = 2 * self.radius + 1  # full kernel size, always at least 1

        # Degenerate case: a non-positive sigma yields the identity kernel.
        if sigma <= 0:
            return jnp.ones(1)

        # Sample the normal PDF on the pixel grid and normalize to unit sum.
        coords = jnp.arange(npix)
        kernel = jstats.norm.pdf((coords - self.radius) / sigma)
        return kernel / kernel.sum()

    @partial(jit, static_argnums=(0,))
    def __call__(self, image):
        """Jit-compiled convolution of an image by the gaussian filter.

        Parameters
        ----------
        image : array_like
            Image to filter.
        """
        # Identity filter when sigma was non-positive.
        if self.kernel is None:
            return image
        return convolve_separable_dilated(image, self.kernel, boundary=self.mode)
66
+
67
+
68
class BlurringOperator(object):
    """2D convolution of fixed-size images expressed as a matrix product.

    The convolution by `kernel` is encoded once as a doubly-blocked Toeplitz
    matrix, so that both the blurring and its transpose reduce to
    matrix-vector products (useful e.g. for adjoint computations).
    """

    def __init__(self, nx, ny, kernel):
        self.target_shape = nx, ny
        self.kernel = kernel
        conv_matrix, self.temp_shape = self.toeplitz_matrix(self.target_shape, self.kernel)
        self.conv_matrix = jnp.array(conv_matrix)
        # Crop indices mapping the 'full' convolution output back to the
        # 'same' output size, depending on kernel parity along each axis.
        nxk, nyk = self.kernel.shape
        if nxk % 2 == 0:
            self.i1, self.i2 = nxk//2-1, -nxk//2
        else:
            self.i1, self.i2 = nxk//2, -nxk//2+1
        if nyk % 2 == 0:
            self.j1, self.j2 = nyk//2-1, -nyk//2
        else:
            self.j1, self.j2 = nyk//2, -nyk//2+1

    @partial(jit, static_argnums=(0, 2))
    def convolve(self, image, out_padding='same'):
        """Blur `image`; return the 'same'-cropped or 'full' convolution."""
        full_conv = self.v2m(self.conv_matrix.dot(self.m2v(image)), self.temp_shape)
        if out_padding == 'same':
            return full_conv[self.i1:self.i2, self.j1:self.j2]
        elif out_padding == 'full':
            return full_conv
        else:
            raise ValueError(f"padding model '{out_padding}' is not supported.")

    @partial(jit, static_argnums=(0, 2))
    def convolve_transpose(self, image, in_padding='same'):
        """Apply the transpose (adjoint) of the blurring operator."""
        if in_padding == 'full':
            padded = image
        elif in_padding == 'same':
            # Zero-pad a 'same'-sized image back to the 'full' output size.
            padded = jnp.pad(image, ((self.i1, -self.i2), (self.j1, -self.j2)),
                             'constant', constant_values=0)
        else:
            raise ValueError(f"padding model '{in_padding}' is not supported.")
        return self.v2m(self.conv_matrix.T.dot(self.m2v(padded)), self.target_shape)

    @staticmethod
    def toeplitz_matrix(input_shape, kernel, verbose=False):
        """Build the doubly-blocked Toeplitz matrix of a 'full' 2D convolution.

        Adapted from code by AliSaaalehi@gmail.com.

        Parameters
        ----------
        input_shape : tuple of int
            Shape (rows, cols) of the image to be convolved.
        kernel : 2D numpy array
            Convolution kernel.
        verbose : bool, optional
            If True, print intermediate results after each step.

        Returns
        -------
        tuple
            The dense convolution matrix and the shape of the full
            convolution output.
        """
        n_rows, n_cols = input_shape
        k_rows, k_cols = kernel.shape

        # 'full' convolution output size.
        out_rows = n_rows + k_rows - 1
        out_cols = n_cols + k_cols - 1
        if verbose: print('output dimension:', out_rows, out_cols)

        # Zero-pad the kernel up to the output size (rows on top, cols on the right).
        kernel_padded = np.pad(kernel, ((out_rows - k_rows, 0),
                                        (0, out_cols - k_cols)),
                               'constant', constant_values=0)
        if verbose: print('F_zero_padded: ', kernel_padded)

        # One Toeplitz block per padded-kernel row, iterating last row first.
        # The first row of each block must be given explicitly to `toeplitz`,
        # otherwise the result is wrong.
        blocks = []
        for row_idx in range(kernel_padded.shape[0] - 1, -1, -1):
            first_col = kernel_padded[row_idx, :]
            first_row = np.r_[first_col[0], np.zeros(n_cols - 1)]
            block = toeplitz(first_col, first_row)
            blocks.append(block)
            if verbose: print('F '+ str(row_idx)+'\n', block)

        # Index matrix telling which block goes where in the doubly-blocked matrix.
        first_col = range(1, kernel_padded.shape[0] + 1)
        first_row = np.r_[first_col[0], np.zeros(n_rows - 1, dtype=int)]
        doubly_indices = toeplitz(first_col, first_row)
        if verbose: print('doubly indices \n', doubly_indices)

        # Allocate the doubly-blocked matrix and tile the blocks into place.
        b_h, b_w = blocks[0].shape  # height and width of each block
        doubly_blocked = np.zeros((b_h * doubly_indices.shape[0],
                                   b_w * doubly_indices.shape[1]))
        for bi in range(doubly_indices.shape[0]):
            for bj in range(doubly_indices.shape[1]):
                r0, c0 = bi * b_h, bj * b_w
                doubly_blocked[r0:r0 + b_h, c0:c0 + b_w] = blocks[doubly_indices[bi, bj] - 1]

        if verbose: print('doubly_blocked: ', doubly_blocked)

        return doubly_blocked, (out_rows, out_cols)

    @staticmethod
    def m2v(mat):
        """Vectorize a matrix (row-major, after an up-down flip)."""
        return jnp.flipud(mat).flatten(order='C')

    @staticmethod
    def v2m(vec, output_shape):
        """Inverse of `m2v`: reshape, then flip back up-down."""
        return jnp.flipud(vec.reshape(output_shape, order='C'))
@@ -0,0 +1,211 @@
1
+ from functools import partial
2
+ import numpy as np
3
+ from scipy import sparse
4
+ import jax.numpy as jnp
5
+ from jax.scipy import stats as jstats
6
+ from jax import jit
7
+ from jax.lax import conv_general_dilated, conv_dimension_numbers
8
+
9
+
10
+ @partial(jit, static_argnums=(2, 3))
11
+ def convolve_separable_dilated(image2D, kernel1D, dilation=1, boundary='edge'):
12
+ """
13
+
14
+ Convolves an image contained in image2D with the 1D kernel kernel1D.
15
+ The operation is basically the following:
16
+ blured2D = image2D * (kernel1D ∧ kernel1D )
17
+ where ∧ is a wedge product, here a tensor product.
18
+
19
+
20
+
21
+ Parameters
22
+ ----------
23
+ image2D : 2D array
24
+ imaged to be convolved with the kernel.
25
+ kernel1D : 1D array
26
+ kernel to convolve the image with..
27
+ dilation : TYPE, optional
28
+ makes the spacial extent of the kernel bigger. The default is 1.
29
+
30
+ Returns
31
+ -------
32
+ 2D array
33
+ image convoluted by the kernel.
34
+
35
+ """
36
+
37
+ # padding
38
+ b = int(kernel1D.size // 2) * dilation
39
+ padded = jnp.pad(image2D, ((b, b), (b, b)), mode=boundary)
40
+ # Fred D.: THIS PADDING IS DANGEROUS AS IT WILL OVERFLOW MEMORY VERY QUICKLY
41
+ # I LEAVE IT AS ORIGINALLY IMPLEMENTED AS I DO NOT WANT TO CHANGE THE
42
+ # OUTPUT OF THE WAVELET TRANSFORM (this could have impact on the science)
43
+
44
+
45
+ # specify the row and column operations for the jax convolution routine:
46
+ image = jnp.expand_dims(padded, (2,))
47
+ # shape (Nx, Ny, 1) -- (N, W, C)
48
+ # we treat the Nx as the batch number!! (because it is a 1D convolution
49
+ # over the rows)
50
+ kernel = jnp.expand_dims(kernel1D, (0,2,))
51
+ # here we have kernel shape ~(I,W,O)
52
+ # so:
53
+ # (Nbatch, Width, Channel) * (Inputdim, Widthkernel, Outputdim)
54
+ # -> (Nbatch, Width, Channel)
55
+ # where Nbatch is our number of rows.
56
+ dimension_numbers = ('NWC', 'IWO', 'NWC')
57
+ dn = conv_dimension_numbers(image.shape,
58
+ kernel.shape,
59
+ dimension_numbers)
60
+ # with these conv_general_dilated knows how to handle the different
61
+ # axes:
62
+ rowblur = conv_general_dilated(image, kernel,
63
+ window_strides=(1,),
64
+ padding='VALID',
65
+ rhs_dilation=(dilation,),
66
+ dimension_numbers=dn)
67
+
68
+ # now we do the same for the columns, hence this time we have
69
+ # (Height, Nbatch, Channel) * (Inputdim, Widthkernel, Outputdim)
70
+ # -> (Height, Nbatch, Channel)
71
+ # where Nbatch is our number of columns.
72
+ dimension_numbers = ('HNC', 'IHO', 'HNC')
73
+ dn = conv_dimension_numbers(image.shape,
74
+ kernel.shape,
75
+ dimension_numbers)
76
+
77
+ rowcolblur = conv_general_dilated(rowblur, kernel,
78
+ window_strides=(1,),
79
+ padding='VALID',
80
+ rhs_dilation=(dilation,),
81
+ dimension_numbers=dn)
82
+
83
+ return rowcolblur[:,:,0]
84
+
85
@partial(jit, static_argnums=(2, 3))
def convolve_dilated1D(signal1D, kernel1D, dilation=1, boundary='edge'):
    """Convolve a 1D signal with a 1D kernel (optionally dilated).

    The operation is simply: blurred1D = signal1D * kernel1D.

    Parameters
    ----------
    signal1D : 1D array
        Vector to be convolved with the kernel.
    kernel1D : 1D array
        Convolution kernel.
    dilation : int, optional
        Spacing between kernel taps, which enlarges the spatial extent of
        the kernel. The default is 1.
    boundary : str, optional
        Padding mode passed to `jnp.pad`. The default is 'edge'.

    Returns
    -------
    1D array
        Signal convolved by the kernel, same length as `signal1D`.
    """
    # Pad so the 'VALID' convolution returns the original length.
    half = int(kernel1D.size // 2) * dilation
    padded = jnp.pad(signal1D, ((half, half)), mode=boundary)

    # conv_general_dilated expects batched operands; add singleton batch and
    # channel axes, convolve, then strip them again.
    strides = tuple(1 for _ in kernel1D.shape)
    blurred = conv_general_dilated(padded[None, None], kernel1D[None, None],
                                   window_strides=strides,
                                   padding='VALID',
                                   rhs_dilation=(dilation,),
                                   )
    return blurred[0, 0]
120
+
121
+
122
+
123
+ def build_convolution_matrix(psf_kernel_2d, image_shape):
124
+ """
125
+ Build a sparse matrix to convolve an image via matrix-vector product.
126
+ Ported from C++ code in VKL from Vernardos & Koopmans 2022.
127
+
128
+ Note: only works with square kernel with odd number of pixels on the side,
129
+ lower than the number of pixels on the side of the image to be convolved.
130
+
131
+ Authors: @gvernard, @aymgal
132
+ """
133
+ Ni, Nj = image_shape
134
+ Ncropx, Ncropy = psf_kernel_2d.shape
135
+
136
+ def setCroppedLimitsEven(k, Ncrop, Nimg, Nquad):
137
+ if k < (Nquad - 1):
138
+ Npre = k
139
+ Npost = Nquad
140
+ offset = Nquad - k
141
+ elif k > (Nimg - Nquad - 1):
142
+ Npre = Nquad
143
+ Npost = Nimg - k
144
+ offset = 0
145
+ else:
146
+ Npre = Nquad
147
+ Npost = Nquad
148
+ offset = 0
149
+ return Npre, Npost, offset
150
+
151
+ def setCroppedLimitsOdd(k, Ncrop, Nimg, Nquad):
152
+ if k < (Nquad - 1):
153
+ Npre = k
154
+ Npost = Nquad
155
+ offset = Nquad - 1 - k
156
+ elif k > (Nimg - Nquad - 1):
157
+ Npre = Nquad - 1
158
+ Npost = Nimg - k
159
+ offset = 0
160
+ else:
161
+ Npre = Nquad-1
162
+ Npost = Nquad
163
+ offset = 0
164
+ return Npre, Npost, offset
165
+
166
+ # get the correct method to offset the PSF kernel from the above
167
+ if Ncropx % 2 == 0:
168
+ # Warning: this might be broken in certain cases
169
+ func_limits_x = setCroppedLimitsEven
170
+ Nquadx = Ncropx//2
171
+ else:
172
+ func_limits_x = setCroppedLimitsOdd
173
+ Nquadx = int(np.ceil(Ncropx/2.))
174
+ if Ncropx % 2 == 0:
175
+ # Warning: this might be broken in certain cases
176
+ func_limits_y = setCroppedLimitsEven
177
+ Nquady = Ncropy//2
178
+ else:
179
+ func_limits_y = setCroppedLimitsOdd
180
+ Nquady = int(np.ceil(Ncropy/2.))
181
+
182
+ # create the blurring matrix in a sparse form
183
+ blur = psf_kernel_2d.flatten()
184
+ sparse_B_rows, sparse_B_cols = [], []
185
+ sparse_B_values = []
186
+ for i in range(Ni): # loop over image rows
187
+ for j in range(Nj): # loop over image columns
188
+
189
+ Nleft, Nright, crop_offsetx = func_limits_x(j, Ncropx, Nj, Nquadx)
190
+ Ntop, Nbottom, crop_offsety = func_limits_y(i, Ncropy, Ni, Nquady)
191
+
192
+ crop_offset = crop_offsety*Ncropx + crop_offsetx
193
+
194
+ for ii in range(i-Ntop, i+Nbottom): # loop over PSF rows
195
+ ic = ii - i + Ntop
196
+
197
+ for jj in range(j-Nleft, j+Nright): # loop over PSF columns
198
+ jc = jj - j + Nleft;
199
+
200
+ val = blur[crop_offset + ic*Ncropx + jc]
201
+
202
+ # save entries
203
+ # (note: rows and cols were inverted from the VKL code)
204
+ sparse_B_rows.append(ii*Nj + jj)
205
+ sparse_B_cols.append(i*Nj + j)
206
+ sparse_B_values.append(val)
207
+
208
+ # populate the sparse matrix
209
+ blurring_matrix = sparse.csr_matrix((sparse_B_values, (sparse_B_rows, sparse_B_cols)),
210
+ shape=(Ni**2, Nj**2))
211
+ return blurring_matrix
@@ -0,0 +1,29 @@
1
+ """PACKAGE INFO
2
+
3
+ This module provides some basic information about the package.
4
+
5
+ """
6
+
7
+ # Set the package release version
8
+ version_info = (0, 0, 2)
9
+ __version__ = '.'.join(str(c) for c in version_info)
10
+
11
+ # Set the package details
12
+ __author__ = 'Aymeric Galan, Austin Peel, Martin Millon, Frederic Dux, Kevin Michalewicz'
13
+ __email__ = 'aymeric.galan@gmail.com'
14
+ __year__ = '2022'
15
+ __url__ = 'https://github.com/aymgal/utax'
16
+ __description__ = 'Utility functions for signal processing, compatible with the differentable programming library JAX.'
17
+ __python__ = '>=3.10'
18
+ __requires__ = [
19
+ 'jax>=0.5.0',
20
+ 'jaxlib>=0.5.0',
21
+ ] # Package dependencies
22
+
23
+ # Default package properties
24
+ __license__ = 'MIT'
25
+ __about__ = ('{} Author: {}, Email: {}, Year: {}, {}'
26
+ ''.format(__name__, __author__, __email__, __year__,
27
+ __description__))
28
+ __setup_requires__ = ['pytest-runner', ]
29
+ __tests_require__ = ['pytest', 'pytest-cov', 'pytest-pep8']
@@ -0,0 +1,235 @@
1
+ import numpy as np
2
+ import jax.numpy as jnp
3
+ from jax import lax, vmap #, jit
4
+
5
+
6
+
7
class BilinearInterpolator(object):
    """Bilinear interpolation of a 2D field.

    Functionality is modelled after scipy.interpolate.RectBivariateSpline
    when `kx` and `ky` are both equal to 1. Results match the scipy version when
    interpolated values lie within the x and y domain (boundaries included).
    Returned values can be significantly different outside the natural domain,
    as the scipy version does not extrapolate. Evaluation of this jax version
    is MUCH SLOWER as well.

    Parameters
    ----------
    x, y : array_like
        1D coordinate arrays along the first (x) and second (y) axis of `z`.
        If not sorted in increasing order, they are sorted and `z` is flipped
        along the matching axis.
    z : array_like
        2D array of values, indexed as z[i, j] = f(x[i], y[j]).
    allow_extrapolation : bool, optional
        If False, evaluations outside the (x, y) domain are zeroed out.
    """
    def __init__(self, x, y, z, allow_extrapolation=True):
        self.z = jnp.array(z)

        # Sort x if not increasing (flipping z rows to match).
        x = jnp.array(x)
        x_sorted = jnp.sort(x)
        flip_x = ~jnp.all(jnp.diff(x) >= 0)

        def x_keep_fn(_):
            return x, self.z

        def x_sort_fn(_):
            return x_sorted, jnp.flip(self.z, axis=0)

        self.x, self.z = lax.cond(flip_x, x_sort_fn, x_keep_fn, operand=None)

        # Sort y if not increasing (flipping z columns to match).
        y = jnp.array(y)
        y_sorted = jnp.sort(y)
        flip_y = ~jnp.all(jnp.diff(y) >= 0)

        def y_keep_fn(_):
            return y, self.z

        def y_sort_fn(_):
            return y_sorted, jnp.flip(self.z, axis=1)

        self.y, self.z = lax.cond(flip_y, y_sort_fn, y_keep_fn, operand=None)
        self._extrapol_bool = allow_extrapolation

    def __call__(self, x, y, dx=0, dy=0):
        """Vectorized evaluation of the interpolation or its derivatives.

        Parameters
        ----------
        x, y : array_like
            Position(s) at which to evaluate the interpolation.
        dx, dy : int, either 0 or 1
            If 1, return the first partial derivative of the interpolation
            with respect to that coordinate. Only one of (dx, dy) should be
            nonzero at a time.

        """
        x = jnp.atleast_1d(x)
        y = jnp.atleast_1d(y)

        error_msg_type = "dx and dy must be integers"
        error_msg_value = "dx and dy must only be either 0 or 1"
        assert isinstance(dx, int) and isinstance(dy, int), error_msg_type
        assert dx in (0, 1) and dy in (0, 1), error_msg_value
        if dx == 1: dy = 0  # dx takes precedence if both are set

        return vmap(self._evaluate, in_axes=(0, 0, None, None))(x, y, dx, dy)

    def _compute_coeffs(self, x, y):
        """Coefficients (a0, a1, a2, a3) of z ~ a0 + a1*x + a2*y + a3*x*y
        on the grid cell containing (x, y)."""
        # Find the pixel that the point (x, y) falls in; clamp the index so
        # out-of-domain points extrapolate from the nearest edge cell.
        # BUGFIX: pass the bounds positionally -- the a_min/a_max keywords
        # were removed from jnp.clip in recent JAX releases (>= 0.5), which
        # this package requires.
        x_ind = jnp.searchsorted(self.x, x, side='right') - 1
        x_ind = jnp.clip(x_ind, 0, len(self.x) - 2)
        y_ind = jnp.searchsorted(self.y, y, side='right') - 1
        y_ind = jnp.clip(y_ind, 0, len(self.y) - 2)

        # Determine the coordinates and dimensions of this pixel
        x1 = self.x[x_ind]
        x2 = self.x[x_ind + 1]
        y1 = self.y[y_ind]
        y2 = self.y[y_ind + 1]
        area = (x2 - x1) * (y2 - y1)

        # Compute function values at the four corners
        # Edge padding is implicitly constant
        v11 = self.z[x_ind, y_ind]
        v12 = self.z[x_ind, y_ind + 1]
        v21 = self.z[x_ind + 1, y_ind]
        v22 = self.z[x_ind + 1, y_ind + 1]

        # Compute the standard bilinear coefficients
        a0_ = v11 * x2 * y2 - v12 * x2 * y1 - v21 * x1 * y2 + v22 * x1 * y1
        a1_ = -v11 * y2 + v12 * y1 + v21 * y2 - v22 * y1
        a2_ = -v11 * x2 + v12 * x2 + v21 * x1 - v22 * x1
        a3_ = v11 - v12 - v21 + v22

        return a0_ / area, a1_ / area, a2_ / area, a3_ / area

    def _evaluate(self, x, y, dx=0, dy=0):
        """Single-point evaluation of the interpolation."""
        a0, a1, a2, a3 = self._compute_coeffs(x, y)
        if (dx, dy) == (0, 0):
            result = a0 + a1 * x + a2 * y + a3 * x * y
        elif (dx, dy) == (1, 0):
            result = a1 + a3 * y
        else:
            result = a2 + a3 * x
        # If extrapolation is not allowed, mask out values outside the
        # original bounding box.
        result = lax.cond(self._extrapol_bool,
                          lambda _: result,
                          lambda _: result * (x >= self.x[0]) * (x <= self.x[-1]) * (y >= self.y[0]) * (y <= self.y[-1]),
                          operand=None)
        return result
119
+
120
+
121
class BicubicInterpolator(object):
    """Bicubic interpolation of a 2D field.

    Functionality is modelled after scipy.interpolate.RectBivariateSpline
    when `kx` and `ky` are both equal to 3.

    Parameters
    ----------
    x, y : array_like
        1D coordinate arrays (assumed uniformly spaced) along the first (x)
        and second (y) axis of `z`. Unsorted inputs are sorted and `z` is
        flipped along the matching axis.
    z : array_like
        2D array of values, indexed as z[i, j] = f(x[i], y[j]).
    zx, zy, zxy : array_like, optional
        Partial derivatives dz/dx, dz/dy and d2z/dxdy on the grid; finite
        differences of `z` are used when not provided.
    allow_extrapolation : bool, optional
        If False, evaluations outside the (x, y) domain are zeroed out.
    """
    def __init__(self, x, y, z, zx=None, zy=None, zxy=None, allow_extrapolation=True):
        self.z = jnp.array(z)
        if np.all(np.diff(x) >= 0):  # check if sorted in increasing order
            self.x = jnp.array(x)
        else:
            self.x = jnp.array(np.sort(x))
            # BUGFIX: x indexes the FIRST axis of z (consistent with
            # `gradient(z, axis=0)` below), so flip along axis 0; the
            # original flipped axis 1.
            self.z = jnp.flip(self.z, axis=0)
        if np.all(np.diff(y) >= 0):  # check if sorted in increasing order
            self.y = jnp.array(y)
        else:
            self.y = jnp.array(np.sort(y))
            # BUGFIX: y indexes the second axis of z, so flip along axis 1.
            self.z = jnp.flip(self.z, axis=1)

        # Assume uniform coordinate spacing
        self.dx = self.x[1] - self.x[0]
        self.dy = self.y[1] - self.y[0]

        # Compute approximate partial derivatives if not provided.
        # BUGFIX: gradients are computed from self.z (possibly flipped),
        # not the raw `z` argument, and the user-provided derivatives are
        # assigned to the correct attributes (the original swapped zx/zy).
        if zx is None:
            self.zx = jnp.gradient(self.z, axis=0) / self.dx
        else:
            self.zx = zx
        if zy is None:
            self.zy = jnp.gradient(self.z, axis=1) / self.dy
        else:
            self.zy = zy
        if zxy is None:
            self.zxy = jnp.gradient(self.zx, axis=1) / self.dy
        else:
            self.zxy = zxy

        # Prepare coefficients for function evaluations
        self._A = jnp.array([[1., 0., 0., 0.],
                             [0., 0., 1., 0.],
                             [-3., 3., -2., -1.],
                             [2., -2., 1., 1.]])
        self._B = jnp.array([[1., 0., -3., 2.],
                             [0., 0., 3., -2.],
                             [0., 1., -2., 1.],
                             [0., 0., -1., 1.]])
        row0 = [self.z[:-1,:-1], self.z[:-1,1:], self.dy * self.zy[:-1,:-1], self.dy * self.zy[:-1,1:]]
        row1 = [self.z[1:,:-1], self.z[1:,1:], self.dy * self.zy[1:,:-1], self.dy * self.zy[1:,1:]]
        row2 = self.dx * jnp.array([self.zx[:-1,:-1], self.zx[:-1,1:],
                                    self.dy * self.zxy[:-1,:-1], self.dy * self.zxy[:-1,1:]])
        row3 = self.dx * jnp.array([self.zx[1:,:-1], self.zx[1:,1:],
                                    self.dy * self.zxy[1:,:-1], self.dy * self.zxy[1:,1:]])
        self._m = jnp.array([row0, row1, row2, row3])

        # One 4x4 coefficient matrix per grid cell: shape (nx-1, ny-1, 4, 4).
        self._m = jnp.transpose(self._m, axes=(2, 3, 0, 1))

        self._extrapol_bool = allow_extrapolation

    def __call__(self, x, y, dx=0, dy=0):
        """Vectorized evaluation of the interpolation or its derivatives.

        Parameters
        ----------
        x, y : array_like
            Position(s) at which to evaluate the interpolation.
        dx, dy : int, either 0, 1, or 2
            Return the nth partial derivative of the interpolation
            with respect to the specified coordinate. Only one of (dx, dy)
            should be nonzero at a time.

        """
        x = jnp.atleast_1d(x)
        y = jnp.atleast_1d(y)
        if x.ndim == 1:
            vmap_call = vmap(self._evaluate, in_axes=(0, 0, None, None))
        elif x.ndim == 2:
            vmap_call = vmap(vmap(self._evaluate, in_axes=(0, 0, None, None)),
                             in_axes=(0, 0, None, None))
        else:
            # BUGFIX: previously fell through to an UnboundLocalError.
            raise ValueError("x and y must be at most 2-dimensional")
        return vmap_call(x, y, dx, dy)

    def _evaluate(self, x, y, dx=0, dy=0):
        """Evaluate the interpolation at a single point."""
        # Determine which pixel (i, j) the point (x, y) falls in
        i = jnp.maximum(0, jnp.searchsorted(self.x, x) - 1)
        j = jnp.maximum(0, jnp.searchsorted(self.y, y) - 1)

        # Rescale coordinates into (0, 1)
        u = (x - self.x[i]) / self.dx
        v = (y - self.y[j]) / self.dy

        # Compute interpolation coefficients
        a = jnp.dot(self._A, jnp.dot(self._m[i, j], self._B))

        # Monomial (or derivative) basis in each direction.
        if dx == 0:
            uu = jnp.asarray([1., u, u**2, u**3])
        if dx == 1:
            uu = jnp.asarray([0., 1., 2. * u, 3. * u**2]) / self.dx
        if dx == 2:
            uu = jnp.asarray([0., 0., 2., 6. * u]) / self.dx**2
        if dy == 0:
            vv = jnp.asarray([1., v, v**2, v**3])
        if dy == 1:
            vv = jnp.asarray([0., 1., 2. * v, 3. * v**2]) / self.dy
        if dy == 2:
            vv = jnp.asarray([0., 0., 2., 6. * v]) / self.dy**2
        result = jnp.dot(uu, jnp.dot(a, vv))

        # If extrapolation is not allowed, mask out values outside the
        # original bounding box.
        result = lax.cond(self._extrapol_bool,
                          lambda _: result,
                          lambda _: result * (x >= self.x[0]) * (x <= self.x[-1]) * (y >= self.y[0]) * (y <= self.y[-1]),
                          operand=None)
        return result
+ return result
235
+
@@ -0,0 +1,243 @@
1
+ from functools import partial
2
+ import jax.numpy as jnp
3
+ from jax import jit
4
+
5
+ from utax.convolution import convolve_separable_dilated, convolve_dilated1D
6
+
7
+
8
class WaveletTransform(object):
    """Multiscale wavelet transform based on the 'a trous' algorithm, using JAX.

    Parameters
    ----------
    nscales : int
        Number of scales in the decomposition.
    wavelet_type : str
        Mother wavelet; supported types are 'starlet', 'battle-lemarie-1',
        'battle-lemarie-3' and 'battle-lemarie-5'.
    second_gen : bool
        If True, use the second-generation transform (an extra smoothing
        pass per scale, with a dedicated reconstruction).
    dim : int
        Dimensionality of the convolution (2 for images, 1 for 1D vectors).

    Raises
    ------
    ValueError
        If `nscales` is negative, the wavelet type is unknown, or `dim`
        is not 1 or 2.
    NotImplementedError
        If `second_gen=True` is requested with `dim=1`.
    """

    def __init__(self, nscales, wavelet_type='starlet', second_gen=False, dim=2):
        # Fail fast on invalid configuration (the original deferred this to
        # an `assert` inside decompose, which is stripped under `python -O`).
        if nscales < 0:
            raise ValueError("nscales must be a non-negative integer")
        self._n_scales = nscales
        self._second_gen = second_gen
        if wavelet_type == 'starlet':
            # B3-spline scaling function, the classic starlet kernel.
            self._h = jnp.array([1., 4., 6., 4., 1.]) / 16.
        elif wavelet_type == 'battle-lemarie-1':  # (order 1 = 2 vanishing moments)
            self._h = jnp.array([-0.000122686, -0.000224296, 0.000511636,
                                 0.000923371, -0.002201945, -0.003883261, 0.009990599,
                                 0.016974805, -0.051945337, -0.06910102, 0.39729643,
                                 0.817645956, 0.39729643, -0.06910102, -0.051945337,
                                 0.016974805, 0.009990599, -0.003883261, -0.002201945,
                                 0.000923371, 0.000511636, -0.000224296, -0.000122686])
        elif wavelet_type == 'battle-lemarie-3':  # (order 2 = 4 vanishing moments)
            self._h = jnp.array([0.000146098, -0.000232304, -0.000285414,
                                 0.000462093, 0.000559952, -0.000927187, -0.001103748,
                                 0.00188212, 0.002186714, -0.003882426, -0.00435384,
                                 0.008201477, 0.008685294, -0.017982291, -0.017176331,
                                 0.042068328, 0.032080869, -0.110036987, -0.050201753,
                                 0.433923147, 0.766130398, 0.433923147, -0.050201753,
                                 -0.110036987, 0.032080869, 0.042068328, -0.017176331,
                                 -0.017982291, 0.008685294, 0.008201477, -0.00435384,
                                 -0.003882426, 0.002186714, 0.00188212, -0.001103748,
                                 -0.000927187, 0.000559952, 0.000462093, -0.000285414,
                                 -0.000232304, 0.000146098])
        elif wavelet_type == 'battle-lemarie-5':  # (order 5 = 6 vanishing moments)
            self._h = jnp.array([1.4299532e-004, 1.5656611e-004, -2.2509746e-004,
                                 -2.4421337e-004, 3.5556002e-004, 3.8161407e-004,
                                 -5.6393016e-004, -5.9748378e-004, 8.9882146e-004,
                                 9.3739129e-004, -1.4412528e-003, -1.4737089e-003,
                                 2.3286290e-003, 2.3211761e-003, -3.7992267e-003,
                                 -3.6602095e-003, 6.2791340e-003, 5.7683203e-003,
                                 -1.0562022e-002, -9.0493510e-003, 1.8208558e-002,
                                 1.4009689e-002, -3.2519969e-002, -2.1006296e-002,
                                 6.1312356e-002, 2.9474179e-002, -1.2926869e-001,
                                 -3.7019995e-002, 4.4246341e-001, 7.4723338e-001,
                                 4.4246341e-001, -3.7019995e-002, -1.2926869e-001,
                                 2.9474179e-002, 6.1312356e-002, -2.1006296e-002,
                                 -3.2519969e-002, 1.4009689e-002, 1.8208558e-002,
                                 -9.0493510e-003, -1.0562022e-002, 5.7683203e-003,
                                 6.2791340e-003, -3.6602095e-003, -3.7992267e-003,
                                 2.3211761e-003, 2.3286290e-003, -1.4737089e-003,
                                 -1.4412528e-003, 9.3739129e-004, 8.9882146e-004,
                                 -5.9748378e-004, -5.6393016e-004, 3.8161407e-004,
                                 3.5556002e-004, -2.4421337e-004, -2.2509746e-004,
                                 1.5656611e-004, 1.4299532e-004])
        else:
            raise ValueError(f"'{wavelet_type}' wavelet type is not supported")

        # Normalize the low-pass filter to unit sum.
        self._h /= jnp.sum(self._h)
        # NOTE(review): the dilation base is half the filter length. For the
        # 5-tap starlet this gives the classic 2**j 'a trous' dilation, but
        # for the longer Battle-Lemarie filters it grows much faster (e.g.
        # 11**j for the 23-tap filter) — confirm this is intended.
        self._fac = len(self._h) // 2

        if self._second_gen:
            self.decompose = self._decompose_2nd_gen
            self.reconstruct = self._reconstruct_2nd_gen
        else:
            self.decompose = self._decompose_1st_gen
            self.reconstruct = self._reconstruct_1st_gen

        self._dim = dim
        if self._dim == 1:
            self.convolve = convolve_dilated1D
            if self._second_gen:
                raise NotImplementedError('2nd generation of wavelet not yet implemented for 1D decomposition.')
        elif self._dim == 2:
            self.convolve = convolve_separable_dilated
        else:
            raise ValueError(f"Dimensionality {dim} not supported.")

    @property
    def scale_norms(self):
        """L2 norm of each wavelet scale, computed once from a Dirac impulse."""
        if not hasattr(self, '_norms'):
            npix_dirac = 2**(self._n_scales + 2)
            if self._dim == 1:
                dirac = (jnp.arange(npix_dirac) == int(npix_dirac / 2)).astype(float)
                wt_dirac = self.decompose(dirac)
                self._norms = jnp.sqrt(jnp.sum(wt_dirac**2, axis=(1,)))
            else:
                dirac = jnp.diag((jnp.arange(npix_dirac) == int(npix_dirac / 2)).astype(float))
                wt_dirac = self.decompose(dirac)
                self._norms = jnp.sqrt(jnp.sum(wt_dirac**2, axis=(1, 2,)))
        return self._norms

    @partial(jit, static_argnums=(0,))
    def _decompose_1st_gen(self, image):
        """Decompose an image into the chosen wavelet basis (1st generation).

        Returns an array of shape (n_scales + 1, *image.shape): the detail
        coefficients for each scale followed by the final coarse scale.
        """
        if self._n_scales == 0:
            return image

        kernel = self._h

        # First scale: smooth once, keep the residual as detail coefficients.
        c_prev = self.convolve(image, kernel)
        coeffs = [image - c_prev]

        # Remaining scales: at each step the kernel is dilated ('a trous')
        # via the dilation argument of the JAX convolution wrapper.
        # Collecting into a list and stacking once avoids the quadratic
        # repeated-concatenate of the original.
        for step in range(1, self._n_scales):
            c_next = self.convolve(c_prev, kernel, dilation=self._fac**step)
            coeffs.append(c_prev - c_next)
            c_prev = c_next

        # Append the final coarse scale.
        coeffs.append(c_prev)
        return jnp.stack(coeffs, axis=0)

    @partial(jit, static_argnums=(0,))
    def _decompose_2nd_gen(self, image):
        """Decompose an image into the chosen wavelet basis (2nd generation).

        Like the 1st generation, but each detail scale is the residual after
        a second smoothing pass, which requires the dedicated reconstruction.
        """
        if self._n_scales == 0:
            return image

        kernel = self._h

        # First scale: smooth twice; detail is the residual of the second pass.
        c_prev = self.convolve(image, kernel)
        coeffs = [image - self.convolve(c_prev, kernel)]

        # Remaining scales with dilated ('a trous') kernels.
        for step in range(1, self._n_scales):
            dilation = self._fac**step
            c_next = self.convolve(c_prev, kernel, dilation=dilation)
            coeffs.append(c_prev - self.convolve(c_next, kernel, dilation=dilation))
            c_prev = c_next

        # Append the final coarse scale.
        coeffs.append(c_prev)
        return jnp.stack(coeffs, axis=0)

    @partial(jit, static_argnums=(0,))
    def _reconstruct_1st_gen(self, coeffs):
        """Reconstruct the signal: 1st-generation synthesis is a plain sum."""
        return jnp.sum(coeffs, axis=0)

    @partial(jit, static_argnums=(0,))
    def _reconstruct_2nd_gen(self, coeffs):
        """Reconstruct the signal from 2nd-generation coefficients.

        Starts from the coarsest scale and alternates smoothing with adding
        back the detail coefficients of each finer scale.
        """
        # Shapes are static under jit, so this raises at trace time.
        if coeffs.shape[-3] != self._n_scales + 1:
            raise ValueError("Wavelet coefficients are not consistent with number of scales")
        if self._n_scales == 0:
            return coeffs[0, :, :]

        kernel = self._h

        # Start with the last (coarse) scale 'J-1'.
        cj = self.convolve(coeffs[self._n_scales, :, :], kernel,
                           dilation=self._fac**(self._n_scales - 1))
        cj = cj + coeffs[self._n_scales - 1, :, :]

        # Walk back up through the remaining scales.
        for ii in range(self._n_scales - 2, -1, -1):
            cj = self.convolve(cj, kernel, dilation=self._fac**ii) + coeffs[ii, :, :]

        return cj
@@ -0,0 +1,32 @@
1
+ Metadata-Version: 2.4
2
+ Name: utax
3
+ Version: 0.0.2
4
+ Summary: Utility functions for signal processing, compatible with the differentiable programming library JAX.
5
+ Home-page: https://github.com/aymgal/utax
6
+ Author: Austin Peel, Martin Millon, Frederic Dux, Kevin Michalewicz
7
+ Author-email: Aymeric Galan <aymeric.galan@gmail.com>
8
+ License: MIT
9
+ Project-URL: Homepage, https://github.com/aymgal/utax
10
+ Project-URL: Repository, https://github.com/aymgal/utax
11
+ Keywords: jax,signal processing,wavelet,convolution,interpolation
12
+ Requires-Python: >=3.10
13
+ Description-Content-Type: text/markdown
14
+ License-File: LICENSE
15
+ Requires-Dist: jax>=0.5.0
16
+ Requires-Dist: jaxlib>=0.5.0
17
+ Provides-Extra: dev
18
+ Requires-Dist: pytest; extra == "dev"
19
+ Requires-Dist: pytest-cov; extra == "dev"
20
+ Requires-Dist: pytest-pep8; extra == "dev"
21
+ Dynamic: home-page
22
+ Dynamic: license-file
23
+ Dynamic: requires-python
24
+
25
+ ![License](https://img.shields.io/github/license/aymgal/utax)
26
+ ![PyPi python support](https://img.shields.io/badge/Python-3.10-blue)
27
+ [![Tests](https://github.com/aymgal/utax/actions/workflows/ci_tests.yml/badge.svg?branch=main)](https://github.com/aymgal/utax/actions/workflows/ci_tests.yml)
28
+ [![Coverage Status](https://coveralls.io/repos/github/aymgal/utax/badge.svg?branch=main)](https://coveralls.io/github/aymgal/utax?branch=main)
29
+
30
+ # `utax`
31
+
32
+ Utility functions for applications in signal processing problems, compatible with the differentiable programming library `JAX`.
@@ -0,0 +1,16 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ setup.py
5
+ utax/__init__.py
6
+ utax/info.py
7
+ utax/interpolation.py
8
+ utax/wavelet.py
9
+ utax.egg-info/PKG-INFO
10
+ utax.egg-info/SOURCES.txt
11
+ utax.egg-info/dependency_links.txt
12
+ utax.egg-info/requires.txt
13
+ utax.egg-info/top_level.txt
14
+ utax/convolution/__init__.py
15
+ utax/convolution/classes.py
16
+ utax/convolution/functions.py
@@ -0,0 +1,7 @@
1
+ jax>=0.5.0
2
+ jaxlib>=0.5.0
3
+
4
+ [dev]
5
+ pytest
6
+ pytest-cov
7
+ pytest-pep8
@@ -0,0 +1 @@
1
+ utax