fdwt 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dwt/DWT1D.py ADDED
@@ -0,0 +1,70 @@
1
+ """ॐ
2
+ FDWT: Fast Multidimensional Discrete Wavelet Transform Layers (PyTorch).
3
+ Copyright 2026 Kishore Kumar Tarafdar
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ https://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License."""
16
+
17
+ import torch
18
+ from dwt.layout import DWTNDlayout, IDWTNDlayout
19
+
20
+
21
class DWT1D(DWTNDlayout):
    """One-dimensional DWT analysis layer.

    The input is contracted along axis 1 with the (N x N) analysis matrix
    returned by ``self._get_A``.

    clean=True:  (batch, N, C) -> (batch, N/2, C*2); the first half of the
                 coefficients (L) and the second half (H) are concatenated
                 on the channel axis as L||H.
    clean=False: (batch, N, C) -> (batch, N, C), raw coefficient layout.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        analysis = self._get_A(x.shape[1], x.device)
        coeffs = torch.einsum('ij,bjc->bic', analysis, x)
        return self._extract_2subbands(coeffs) if self.clean else coeffs

    def _extract_2subbands(self, x: torch.Tensor) -> torch.Tensor:
        """Move the second half of axis 1 next to the first half on the channel axis."""
        half = x.shape[1] // 2
        low, high = x[:, :half, :], x[:, half:, :]
        return torch.cat((low, high), dim=-1)
41
+
42
+
43
class IDWT1D(IDWTNDlayout):
    """One-dimensional IDWT synthesis layer.

    The (optionally re-joined) coefficients are contracted along axis 1 with
    the synthesis matrix returned by ``self._get_S``.

    clean=True:  (batch, N/2, C*2) -> (batch, N, C); the channel axis is split
                 into L||H halves, which are stacked back along axis 1.
    clean=False: (batch, N, C) -> (batch, N, C).
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        coeffs = self._join_2subbands(x) if self.clean else x
        synthesis = self._get_S(coeffs.shape[1], coeffs.device)
        return torch.einsum('ij,bjc->bic', synthesis, coeffs)

    def _join_2subbands(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of DWT1D's packing: split channels in two and stack on axis 1."""
        low, high = x.chunk(2, dim=-1)
        return torch.cat((low, high), dim=1)
60
+
61
+
62
if __name__ == '__main__':
    # Smoke test: run analysis then synthesis and print the max round-trip error.
    wavelet = 'haar'
    analysis, synthesis = DWT1D(wavelet), IDWT1D(wavelet)
    signal = torch.randn(2, 256, 1)
    coeffs = analysis(signal)
    recon = synthesis(coeffs)
    print('DWT output:', coeffs.shape)
    print('Reconstruction error (max):', (signal - recon).abs().max().item())
dwt/DWT2D.py ADDED
@@ -0,0 +1,85 @@
1
+ """ॐ
2
+ FDWT: Fast Multidimensional Discrete Wavelet Transform Layers (PyTorch).
3
+ Copyright 2026 Kishore Kumar Tarafdar
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ https://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License."""
16
+
17
+ import torch
18
+ from dwt.layout import DWTNDlayout, IDWTNDlayout
19
+
20
+
21
class DWT2D(DWTNDlayout):
    """2D DWT analysis module.

    clean=True: (batch, H, W, C) -> (batch, H/2, W/2, C*4) [LL|LH|HL|HH along channel]
    clean=False: (batch, H, W, C) -> (batch, H, W, C)

    The transform is separable: one analysis matrix (from ``self._get_A``) is
    applied along W and one along H. Each is sized from its own axis, so H and
    W may differ.

    In subband names the first letter refers to axis 1 (H) and the second to
    axis 2 (W): L = first half of that axis, H = second half. For example, LH
    is the second half along H and the first half along W.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n_h, n_w = x.shape[1], x.shape[2]
        A_h = self._get_A(n_h, x.device)
        # Square inputs reuse the same matrix, so only one lookup is made.
        A_w = A_h if n_w == n_h else self._get_A(n_w, x.device)
        # Transform along W (axis 2).
        x = torch.einsum('ij,bhjc->bhic', A_w, x)
        # Transform along H (axis 1).
        x = torch.einsum('ij,bjwc->biwc', A_h, x)
        if self.clean:
            return self._extract_4subbands(x)
        return x

    def _extract_4subbands(self, x: torch.Tensor) -> torch.Tensor:
        """Cut the (H, W) coefficient grid into quadrants and stack them on channels.

        Split points are computed per axis so non-square grids are handled.
        """
        mid_h = x.shape[1] // 2
        mid_w = x.shape[2] // 2
        LL = x[:, :mid_h, :mid_w, :]
        LH = x[:, mid_h:, :mid_w, :]
        HL = x[:, :mid_h, mid_w:, :]
        HH = x[:, mid_h:, mid_w:, :]
        return torch.cat([LL, LH, HL, HH], dim=-1)
48
+
49
+
50
class IDWT2D(IDWTNDlayout):
    """2D IDWT synthesis module.

    clean=True: (batch, H/2, W/2, C*4) -> (batch, H, W, C)
    clean=False: (batch, H, W, C) -> (batch, H, W, C)

    One synthesis matrix (from ``self._get_S``) is applied along W and one
    along H. Each is sized from its own axis, so H and W may differ.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.clean:
            x = self._join_quadrants(x)
        n_h, n_w = x.shape[1], x.shape[2]
        S_h = self._get_S(n_h, x.device)
        # Square inputs reuse the same matrix, so only one lookup is made.
        S_w = S_h if n_w == n_h else self._get_S(n_w, x.device)
        # Transform along W (axis 2).
        x = torch.einsum('ij,bhjc->bhic', S_w, x)
        # Transform along H (axis 1).
        x = torch.einsum('ij,bjwc->biwc', S_h, x)
        return x

    def _join_quadrants(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of DWT2D's packing: split channels into LL, LH, HL, HH and tile them.

        LL/HL form the first half along axis 1 and LH/HH the second half. Within
        each half, the *L / *H band is the first / second half along axis 2.
        """
        LL, LH, HL, HH = torch.chunk(x, 4, dim=-1)
        top = torch.cat([LL, HL], dim=2)
        bottom = torch.cat([LH, HH], dim=2)
        return torch.cat([top, bottom], dim=1)
75
+
76
+
77
if __name__ == '__main__':
    # Smoke test: 2D analysis then synthesis; print coefficient shape and max error.
    wavelet = 'haar'
    analysis, synthesis = DWT2D(wavelet), IDWT2D(wavelet)
    image = torch.randn(2, 256, 256, 1)
    coeffs = analysis(image)
    recon = synthesis(coeffs)
    print('DWT output:', coeffs.shape)
    print('Reconstruction error (max):', (image - recon).abs().max().item())
dwt/DWT3D.py ADDED
@@ -0,0 +1,101 @@
1
+ """ॐ
2
+ FDWT: Fast Multidimensional Discrete Wavelet Transform Layers (PyTorch).
3
+ Copyright 2026 Kishore Kumar Tarafdar
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ https://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License."""
16
+
17
+ import torch
18
+ from dwt.layout import DWTNDlayout, IDWTNDlayout
19
+
20
+
21
class DWT3D(DWTNDlayout):
    """3D DWT analysis module.

    clean=True: (batch, D, H, W, C) -> (batch, D/2, H/2, W/2, C*8)
    clean=False: (batch, D, H, W, C) -> (batch, D, H, W, C)

    Separable transform: an analysis matrix from ``self._get_A`` is applied
    along H, then D, then W. Each matrix is sized from its own axis, so D, H
    and W may differ.

    Subband names use one letter per axis in the order (D, H, W):
    L = first half of that axis, H = second half. For example, LLH is the
    first half along D and H and the second half along W.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n_d, n_h, n_w = x.shape[1], x.shape[2], x.shape[3]
        # One matrix lookup per distinct axis length (a cubic input needs only one).
        mats = {n: self._get_A(n, x.device) for n in dict.fromkeys((n_d, n_h, n_w))}
        # Transform along H (axis 2).
        x = torch.einsum('ij,bdjwc->bdiwc', mats[n_h], x)
        # Transform along D (axis 1).
        x = torch.einsum('ij,bjhwc->bihwc', mats[n_d], x)
        # Transform along W (axis 3).
        x = torch.einsum('ij,bdhjc->bdhic', mats[n_w], x)
        if self.clean:
            return self._extract_8subbands(x)
        return x

    def _extract_8subbands(self, x: torch.Tensor) -> torch.Tensor:
        """Cut the (D, H, W) coefficient volume into octants and stack them on channels.

        Split points are computed per axis so non-cubic volumes are handled.
        """
        md = x.shape[1] // 2
        mh = x.shape[2] // 2
        mw = x.shape[3] // 2
        LLL = x[:, :md, :mh, :mw, :]
        LLH = x[:, :md, :mh, mw:, :]
        LHL = x[:, :md, mh:, :mw, :]
        LHH = x[:, :md, mh:, mw:, :]
        HLL = x[:, md:, :mh, :mw, :]
        HLH = x[:, md:, :mh, mw:, :]
        HHL = x[:, md:, mh:, :mw, :]
        HHH = x[:, md:, mh:, mw:, :]
        return torch.cat([LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH], dim=-1)
56
+
57
+
58
class IDWT3D(IDWTNDlayout):
    """3D IDWT synthesis module.

    clean=True: (batch, D/2, H/2, W/2, C*8) -> (batch, D, H, W, C)
    clean=False: (batch, D, H, W, C) -> (batch, D, H, W, C)

    A synthesis matrix from ``self._get_S`` is applied along H, then D, then W.
    Each matrix is sized from its own axis, so D, H and W may differ.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.clean:
            x = self._join_octants(x)
        n_d, n_h, n_w = x.shape[1], x.shape[2], x.shape[3]
        # One matrix lookup per distinct axis length (a cubic input needs only one).
        mats = {n: self._get_S(n, x.device) for n in dict.fromkeys((n_d, n_h, n_w))}
        # Transform along H (axis 2).
        x = torch.einsum('ij,bdjwc->bdiwc', mats[n_h], x)
        # Transform along D (axis 1).
        x = torch.einsum('ij,bjhwc->bihwc', mats[n_d], x)
        # Transform along W (axis 3).
        x = torch.einsum('ij,bdhjc->bdhic', mats[n_w], x)
        return x

    def _join_octants(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of DWT3D's packing: split channels into 8 bands and tile the volume.

        Band letters are ordered (D, H, W). Halves are reassembled along W
        (dim 3), then H (dim 2), then D (dim 1).
        """
        LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH = torch.chunk(x, 8, dim=-1)
        front_top = torch.cat([LLL, LLH], dim=3)
        front_bot = torch.cat([LHL, LHH], dim=3)
        back_top = torch.cat([HLL, HLH], dim=3)
        back_bot = torch.cat([HHL, HHH], dim=3)
        front = torch.cat([front_top, front_bot], dim=2)
        back = torch.cat([back_top, back_bot], dim=2)
        return torch.cat([front, back], dim=1)
91
+
92
+
93
if __name__ == '__main__':
    # Smoke test: 3D analysis then synthesis with a biorthogonal wavelet.
    wavelet = 'bior1.5'
    analysis, synthesis = DWT3D(wavelet), IDWT3D(wavelet)
    volume = torch.randn(1, 32, 32, 32, 2)
    coeffs = analysis(volume)
    recon = synthesis(coeffs)
    print('DWT output:', coeffs.shape)
    print('Reconstruction error (max):', (volume - recon).abs().max().item())
dwt/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ """ॐ
2
+ FDWT: Fast Multidimensional Discrete Wavelet Transform Layers (PyTorch).
3
+ Copyright 2026 Kishore Kumar Tarafdar
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ https://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License."""
16
+
17
+ from dwt.filters import FetchAnalysisSynthesisFilters
18
+ from dwt.dwt_op import make_dwt_operator_matrix_A
19
+ from dwt.DWT1D import DWT1D, IDWT1D
20
+ from dwt.DWT2D import DWT2D, IDWT2D
21
+ from dwt.DWT3D import DWT3D, IDWT3D
22
+ from dwt.multilevel.dwt1 import dwt, idwt
23
+ from dwt.multilevel.dwt2 import dwt2, idwt2
24
+ from dwt.multilevel.dwt3 import dwt3, idwt3
25
+
26
__version__ = "0.1.0"  # Public package version string.