fdwt 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dwt/DWT1D.py +70 -0
- dwt/DWT2D.py +85 -0
- dwt/DWT3D.py +101 -0
- dwt/__init__.py +26 -0
- dwt/dbFBimpulseResponse.py +24 -0
- dwt/dwt_op.py +42 -0
- dwt/filters.py +38 -0
- dwt/layout.py +82 -0
- dwt/multilevel/__init__.py +3 -0
- dwt/multilevel/dwt1.py +70 -0
- dwt/multilevel/dwt2.py +70 -0
- dwt/multilevel/dwt3.py +70 -0
- fdwt-0.1.0.dist-info/METADATA +178 -0
- fdwt-0.1.0.dist-info/RECORD +16 -0
- fdwt-0.1.0.dist-info/WHEEL +4 -0
- fdwt-0.1.0.dist-info/licenses/LICENSE +201 -0
dwt/DWT1D.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""ॐ
|
|
2
|
+
FDWT: Fast Multidimensional Discrete Wavelet Transform Layers (PyTorch).
|
|
3
|
+
Copyright 2026 Kishore Kumar Tarafdar
|
|
4
|
+
|
|
5
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
you may not use this file except in compliance with the License.
|
|
7
|
+
You may obtain a copy of the License at
|
|
8
|
+
|
|
9
|
+
https://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
|
|
11
|
+
Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
See the License for the specific language governing permissions and
|
|
15
|
+
limitations under the License."""
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from dwt.layout import DWTNDlayout, IDWTNDlayout
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class DWT1D(DWTNDlayout):
    """One-level 1D discrete wavelet analysis layer.

    Shapes:
        clean=True:  (batch, N, C) -> (batch, N/2, C*2); the low band (first half
                     of axis 1) and the high band (second half) are stacked along
                     the channel axis, low first.
        clean=False: (batch, N, C) -> (batch, N, C); the raw operator output,
                     low half of axis 1 followed by the high half.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the analysis operator from ``_get_A`` along axis 1 of ``x``."""
        length = x.shape[1]
        analysis_op = self._get_A(length, x.device)
        coeffs = torch.einsum('ij,bjc->bic', analysis_op, x)
        return self._extract_2subbands(coeffs) if self.clean else coeffs

    def _extract_2subbands(self, x: torch.Tensor) -> torch.Tensor:
        """Split axis 1 in half and place the upper half after the lower one along channels."""
        half = x.shape[1] // 2
        low = x[:, :half, :]
        high = x[:, half:, :]
        return torch.cat((low, high), dim=-1)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class IDWT1D(IDWTNDlayout):
    """One-level 1D discrete wavelet synthesis layer.

    Shapes:
        clean=True:  (batch, N/2, C*2) -> (batch, N, C); the channel axis is split
                     into low/high halves which are re-stacked along axis 1 first.
        clean=False: (batch, N, C) -> (batch, N, C)
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the synthesis operator from ``_get_S`` along axis 1.

        When ``clean`` is set, the channel-packed subbands are first unpacked
        back onto axis 1.
        """
        coeffs = self._join_2subbands(x) if self.clean else x
        synthesis_op = self._get_S(coeffs.shape[1], coeffs.device)
        return torch.einsum('ij,bjc->bic', synthesis_op, coeffs)

    def _join_2subbands(self, x: torch.Tensor) -> torch.Tensor:
        """Undo DWT1D's channel packing: halve the channels and stack them along axis 1."""
        low, high = torch.chunk(x, 2, dim=-1)
        return torch.cat((low, high), dim=1)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
if __name__ == '__main__':
    # Round-trip demo: analysis then synthesis, reporting the max absolute
    # difference between the input and its reconstruction.
    wavelet = 'haar'
    analysis = DWT1D(wavelet)
    synthesis = IDWT1D(wavelet)
    signal = torch.randn(2, 256, 1)
    subbands = analysis(signal)
    reconstruction = synthesis(subbands)
    print('DWT output:', subbands.shape)
    max_abs_err = (signal - reconstruction).abs().max().item()
    print('Reconstruction error (max):', max_abs_err)
|
dwt/DWT2D.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""ॐ
|
|
2
|
+
FDWT: Fast Multidimensional Discrete Wavelet Transform Layers (PyTorch).
|
|
3
|
+
Copyright 2026 Kishore Kumar Tarafdar
|
|
4
|
+
|
|
5
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
you may not use this file except in compliance with the License.
|
|
7
|
+
You may obtain a copy of the License at
|
|
8
|
+
|
|
9
|
+
https://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
|
|
11
|
+
Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
See the License for the specific language governing permissions and
|
|
15
|
+
limitations under the License."""
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from dwt.layout import DWTNDlayout, IDWTNDlayout
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class DWT2D(DWTNDlayout):
    """2D DWT analysis module (separable: one 1D pass along W, then one along H).

    clean=True:  (batch, H, W, C) -> (batch, H/2, W/2, C*4) [LL|LH|HL|HH along channel]
    clean=False: (batch, H, W, C) -> (batch, H, W, C)

    H and W may differ: an analysis operator is requested from ``_get_A`` for
    each axis length (a single call when H == W).

    Quadrant layout (axis 1 = H, axis 2 = W), as produced by ``_extract_4subbands``:
        LL = x[:, :H/2, :W/2]   LH = x[:, H/2:, :W/2]
        HL = x[:, :H/2, W/2:]   HH = x[:, H/2:, W/2:]
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the analysis operators along W (axis 2) and then H (axis 1).

        Args:
            x: tensor of shape (batch, H, W, C).

        Returns:
            The transformed tensor, packed into subbands when ``clean`` is set.
        """
        n_h, n_w = x.shape[1], x.shape[2]
        # One operator per distinct axis length, so square inputs still make a
        # single _get_A call.
        ops = {n: self._get_A(n, x.device) for n in dict.fromkeys((n_h, n_w))}
        # Along W (axis 2).
        x = torch.einsum('ij,bhjc->bhic', ops[n_w], x)
        # Along H (axis 1).
        x = torch.einsum('ij,bjwc->biwc', ops[n_h], x)
        if self.clean:
            return self._extract_4subbands(x)
        return x

    def _extract_4subbands(self, x: torch.Tensor) -> torch.Tensor:
        """Pack the four spatial quadrants of ``x`` along the channel axis.

        For even H and W: (batch, H, W, C) -> (batch, H/2, W/2, 4*C),
        ordered [LL, LH, HL, HH] as described in the class docstring.
        """
        mid_h = x.shape[1] // 2
        mid_w = x.shape[2] // 2
        LL = x[:, :mid_h, :mid_w, :]
        LH = x[:, mid_h:, :mid_w, :]
        HL = x[:, :mid_h, mid_w:, :]
        HH = x[:, mid_h:, mid_w:, :]
        return torch.cat([LL, LH, HL, HH], dim=-1)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class IDWT2D(IDWTNDlayout):
    """2D IDWT synthesis module (separable: one 1D pass along W, then one along H).

    clean=True:  (batch, H/2, W/2, C*4) -> (batch, H, W, C)
    clean=False: (batch, H, W, C) -> (batch, H, W, C)

    H and W may differ: a synthesis operator is requested from ``_get_S`` for
    each axis length (a single call when H == W).
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Reassemble quadrants (if ``clean``) and apply synthesis along W then H.

        Args:
            x: channel-packed subbands (clean=True) or a full-size coefficient
               tensor of shape (batch, H, W, C) (clean=False).

        Returns:
            Tensor of shape (batch, H, W, C).
        """
        if self.clean:
            x = self._join_quadrants(x)
        n_h, n_w = x.shape[1], x.shape[2]
        # One operator per distinct axis length (single call for square inputs).
        ops = {n: self._get_S(n, x.device) for n in dict.fromkeys((n_h, n_w))}
        # Along W (axis 2).
        x = torch.einsum('ij,bhjc->bhic', ops[n_w], x)
        # Along H (axis 1).
        x = torch.einsum('ij,bjwc->biwc', ops[n_h], x)
        return x

    def _join_quadrants(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of DWT2D's packing: split channels into [LL, LH, HL, HH] and tile them.

        LL/HL form the top half (axis 1) and LH/HH the bottom half; within each
        half the first block occupies the left half of axis 2.
        """
        LL, LH, HL, HH = torch.chunk(x, 4, dim=-1)
        top = torch.cat([LL, HL], dim=2)
        bottom = torch.cat([LH, HH], dim=2)
        return torch.cat([top, bottom], dim=1)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
if __name__ == '__main__':
    # Round-trip demo on a 256x256 single-channel batch: analysis then
    # synthesis, reporting the max absolute reconstruction error.
    wavelet = 'haar'
    analysis = DWT2D(wavelet)
    synthesis = IDWT2D(wavelet)
    image = torch.randn(2, 256, 256, 1)
    subbands = analysis(image)
    reconstruction = synthesis(subbands)
    print('DWT output:', subbands.shape)
    max_abs_err = (image - reconstruction).abs().max().item()
    print('Reconstruction error (max):', max_abs_err)
|
dwt/DWT3D.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
"""ॐ
|
|
2
|
+
FDWT: Fast Multidimensional Discrete Wavelet Transform Layers (PyTorch).
|
|
3
|
+
Copyright 2026 Kishore Kumar Tarafdar
|
|
4
|
+
|
|
5
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
you may not use this file except in compliance with the License.
|
|
7
|
+
You may obtain a copy of the License at
|
|
8
|
+
|
|
9
|
+
https://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
|
|
11
|
+
Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
See the License for the specific language governing permissions and
|
|
15
|
+
limitations under the License."""
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from dwt.layout import DWTNDlayout, IDWTNDlayout
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class DWT3D(DWTNDlayout):
    """3D DWT analysis module (separable passes along H, then D, then W).

    clean=True:  (batch, D, H, W, C) -> (batch, D/2, H/2, W/2, C*8)
    clean=False: (batch, D, H, W, C) -> (batch, D, H, W, C)

    D, H and W may differ: an analysis operator is requested from ``_get_A``
    for each distinct axis length (a single call for cubic inputs).

    Octant names list the band per axis in (D, H, W) order, e.g. LLH is low
    along D and H and high along W.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the analysis operators along H (axis 2), D (axis 1) and W (axis 3).

        Args:
            x: tensor of shape (batch, D, H, W, C).

        Returns:
            The transformed tensor, packed into octants when ``clean`` is set.
        """
        n_d, n_h, n_w = x.shape[1], x.shape[2], x.shape[3]
        # One operator per distinct axis length, so cubic inputs still make a
        # single _get_A call.
        ops = {n: self._get_A(n, x.device) for n in dict.fromkeys((n_d, n_h, n_w))}
        # Along H (axis 2).
        x = torch.einsum('ij,bdjwc->bdiwc', ops[n_h], x)
        # Along D (axis 1).
        x = torch.einsum('ij,bjhwc->bihwc', ops[n_d], x)
        # Along W (axis 3).
        x = torch.einsum('ij,bdhjc->bdhic', ops[n_w], x)
        if self.clean:
            return self._extract_8subbands(x)
        return x

    def _extract_8subbands(self, x: torch.Tensor) -> torch.Tensor:
        """Pack the eight spatial octants of ``x`` along the channel axis.

        For even D, H and W: (batch, D, H, W, C) -> (batch, D/2, H/2, W/2, 8*C),
        ordered [LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH].
        """
        md = x.shape[1] // 2
        mh = x.shape[2] // 2
        mw = x.shape[3] // 2
        LLL = x[:, :md, :mh, :mw, :]
        LLH = x[:, :md, :mh, mw:, :]
        LHL = x[:, :md, mh:, :mw, :]
        LHH = x[:, :md, mh:, mw:, :]
        HLL = x[:, md:, :mh, :mw, :]
        HLH = x[:, md:, :mh, mw:, :]
        HHL = x[:, md:, mh:, :mw, :]
        HHH = x[:, md:, mh:, mw:, :]
        return torch.cat([LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH], dim=-1)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class IDWT3D(IDWTNDlayout):
    """3D IDWT synthesis module (separable passes along H, then D, then W).

    clean=True:  (batch, D/2, H/2, W/2, C*8) -> (batch, D, H, W, C)
    clean=False: (batch, D, H, W, C) -> (batch, D, H, W, C)

    D, H and W may differ: a synthesis operator is requested from ``_get_S``
    for each distinct axis length (a single call for cubic inputs).
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Reassemble octants (if ``clean``) and apply synthesis along H, D and W.

        Args:
            x: channel-packed octants (clean=True) or a full-size coefficient
               tensor of shape (batch, D, H, W, C) (clean=False).

        Returns:
            Tensor of shape (batch, D, H, W, C).
        """
        if self.clean:
            x = self._join_octants(x)
        n_d, n_h, n_w = x.shape[1], x.shape[2], x.shape[3]
        # One operator per distinct axis length (single call for cubic inputs).
        ops = {n: self._get_S(n, x.device) for n in dict.fromkeys((n_d, n_h, n_w))}
        # Along H (axis 2).
        x = torch.einsum('ij,bdjwc->bdiwc', ops[n_h], x)
        # Along D (axis 1).
        x = torch.einsum('ij,bjhwc->bihwc', ops[n_d], x)
        # Along W (axis 3).
        x = torch.einsum('ij,bdhjc->bdhic', ops[n_w], x)
        return x

    def _join_octants(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of DWT3D's packing: split channels into eight octants and tile them.

        Octant names give the band per axis in (D, H, W) order; blocks are
        concatenated along W (axis 3), then H (axis 2), then D (axis 1).
        """
        LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH = torch.chunk(x, 8, dim=-1)
        front_top = torch.cat([LLL, LLH], dim=3)
        front_bot = torch.cat([LHL, LHH], dim=3)
        back_top = torch.cat([HLL, HLH], dim=3)
        back_bot = torch.cat([HHL, HHH], dim=3)
        front = torch.cat([front_top, front_bot], dim=2)
        back = torch.cat([back_top, back_bot], dim=2)
        return torch.cat([front, back], dim=1)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
if __name__ == '__main__':
    # Round-trip demo on a 32^3 two-channel volume with a biorthogonal
    # wavelet, reporting the max absolute reconstruction error.
    wavelet = 'bior1.5'
    analysis = DWT3D(wavelet)
    synthesis = IDWT3D(wavelet)
    volume = torch.randn(1, 32, 32, 32, 2)
    subbands = analysis(volume)
    reconstruction = synthesis(subbands)
    print('DWT output:', subbands.shape)
    max_abs_err = (volume - reconstruction).abs().max().item()
    print('Reconstruction error (max):', max_abs_err)
|
dwt/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""ॐ
|
|
2
|
+
FDWT: Fast Multidimensional Discrete Wavelet Transform Layers (PyTorch).
|
|
3
|
+
Copyright 2026 Kishore Kumar Tarafdar
|
|
4
|
+
|
|
5
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
you may not use this file except in compliance with the License.
|
|
7
|
+
You may obtain a copy of the License at
|
|
8
|
+
|
|
9
|
+
https://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
|
|
11
|
+
Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
See the License for the specific language governing permissions and
|
|
15
|
+
limitations under the License."""
|
|
16
|
+
|
|
17
|
+
from dwt.filters import FetchAnalysisSynthesisFilters
|
|
18
|
+
from dwt.dwt_op import make_dwt_operator_matrix_A
|
|
19
|
+
from dwt.DWT1D import DWT1D, IDWT1D
|
|
20
|
+
from dwt.DWT2D import DWT2D, IDWT2D
|
|
21
|
+
from dwt.DWT3D import DWT3D, IDWT3D
|
|
22
|
+
from dwt.multilevel.dwt1 import dwt, idwt
|
|
23
|
+
from dwt.multilevel.dwt2 import dwt2, idwt2
|
|
24
|
+
from dwt.multilevel.dwt3 import dwt3, idwt3
|
|
25
|
+
|
|
26
|
+
# Package version string; matches the fdwt 0.1.0 wheel this module ships in.
__version__: str = "0.1.0"
|