dct-autoencoder 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2 @@
1
+ from .basis import DCTBasis, get_dct_basis
2
+ from .core import DCTAutoencoder
@@ -0,0 +1,56 @@
1
+ from typing import NamedTuple
2
+
3
+ import numpy as np
4
+
5
+
6
class DCTBasis(NamedTuple):
    """Precomputed quantities for a 2D DCT-II of ``block_size x block_size`` blocks.

    Attributes:
        basis_functions: float32 array of shape
            ``(block_size, block_size, block_size, block_size)`` indexed as
            ``[v, u, y, x]``; entry ``[v, u]`` is the unnormalized cosine image
            ``cos((2x + 1) u pi / 2N) * cos((2y + 1) v pi / 2N)``.
        spatial_frequencies_components: int64 array of shape
            ``(block_size, block_size, 2)`` with ``[v, u] == (v, u)``.
        spatial_frequencies_magnitude: float64 array of shape
            ``(block_size, block_size)``, the Euclidean norm of ``(v, u)``.
        multiplication_factor_matrix: float32 array of shape
            ``(block_size, block_size)`` holding ``c(v) * c(u)`` with
            ``c(0) = 1 / sqrt(2)`` and ``c(k) = 1`` otherwise.
        multiplication_factor_scalar: The global factor ``2 / block_size``.
        block_size: The block size ``N``.
    """

    basis_functions: np.ndarray
    spatial_frequencies_components: np.ndarray
    spatial_frequencies_magnitude: np.ndarray
    multiplication_factor_matrix: np.ndarray
    multiplication_factor_scalar: float
    block_size: int


def get_dct_basis(block_size: int = 8) -> DCTBasis:
    """Generate the DCT basis variables for a given block size.

    With ``N = block_size``, the 2D DCT-II coefficient for frequency
    ``(v, u)`` of an ``N x N`` block ``f`` is::

        F[v, u] = multiplication_factor_scalar
                  * multiplication_factor_matrix[v, u]
                  * sum(basis_functions[v, u] * f)

    and with these factors the transform is orthonormal.

    Args:
        block_size (int, optional): The block size. Defaults to 8.

    Returns:
        DCTBasis: The DCT basis variables.

    Raises:
        TypeError: If ``block_size`` is not an integer.
        ValueError: If ``block_size`` is smaller than 1.
    """
    # bool is an int subclass but never a meaningful block size.
    if isinstance(block_size, bool) or not isinstance(block_size, (int, np.integer)):
        raise TypeError(
            f"block_size must be an integer, got {type(block_size).__name__}"
        )
    if block_size < 1:
        raise ValueError(f"block_size must be at least 1, got {block_size}")

    frequencies = np.arange(block_size)
    positions = np.arange(block_size)

    # cosines[k, n] = cos((2n + 1) k pi / (2N)): the 1D DCT-II basis with
    # frequency k evaluated at sample n.
    cosines = np.cos(
        ((2 * positions[np.newaxis, :] + 1) * frequencies[:, np.newaxis] * np.pi)
        / (2 * block_size)
    )
    # The 2D basis is separable:
    # basis_functions[v, u, y, x] = cosines[v, y] * cosines[u, x].
    basis_functions = (
        cosines[:, np.newaxis, :, np.newaxis] * cosines[np.newaxis, :, np.newaxis, :]
    ).astype(np.float32)

    # spatial_frequencies[v, u] == (v, u)
    v_grid, u_grid = np.meshgrid(frequencies, frequencies, indexing="ij")
    spatial_frequencies = np.stack([v_grid, u_grid], axis=-1).astype(np.int64)
    spatial_frequencies_magnitude = np.linalg.norm(spatial_frequencies, axis=2)

    # c(0) = 1/sqrt(2) normalizes the constant (DC) component; c(k) = 1 otherwise.
    normalization = np.where(frequencies == 0, 1 / np.sqrt(2), 1.0)
    multiplication_factor_matrix = np.outer(normalization, normalization).astype(
        np.float32
    )
    multiplication_factor_scalar = 2 / block_size

    return DCTBasis(
        basis_functions=basis_functions,
        spatial_frequencies_components=spatial_frequencies,
        spatial_frequencies_magnitude=spatial_frequencies_magnitude,
        multiplication_factor_matrix=multiplication_factor_matrix,
        multiplication_factor_scalar=multiplication_factor_scalar,
        block_size=block_size,
    )
@@ -0,0 +1,260 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from .basis import get_dct_basis
7
+ from .utils import rgb_to_ycbcr, ycbcr_to_rgb
8
+
9
+
10
class DCTAutoencoder(nn.Module):
    """Blockwise 2D DCT encoder/decoder for RGB images.

    ``encode`` converts RGB images to YCbCr, rescales them from [0, 1] to
    [-1, 1] and applies the DCT described by ``get_dct_basis(block_size)`` to
    every non-overlapping ``block_size x block_size`` block of each channel;
    ``decode`` inverts this. The ``block_size**2`` coefficients of a colour
    channel are ordered by increasing spatial-frequency magnitude, so keeping a
    prefix of them (see ``compress``) drops the highest frequencies first.

    Encodings have ``3 * block_size**2`` channels laid out as
    ``[Y coefficients, Cb coefficients, Cr coefficients]``.

    Args:
        block_size (int, optional): The block size. Defaults to 8.
    """

    def __init__(self, block_size: int = 8) -> None:
        super().__init__()
        dct_basis = get_dct_basis(block_size)
        basis_functions = dct_basis.basis_functions
        # One (block_size, block_size) kernel per (v, u) frequency pair.
        kernels = basis_functions.reshape(-1, block_size, block_size)
        spatial_frequencies_magnitude = dct_basis.spatial_frequencies_magnitude.reshape(
            -1
        )
        # Order coefficients from low to high frequency magnitude.
        # NOTE(review): the default argsort kind is not guaranteed to be stable,
        # so the relative order of equal-magnitude frequencies (e.g. (0, 1) vs
        # (1, 0)) is left to NumPy. Kept unchanged to preserve the existing
        # channel layout.
        sort_indices = np.argsort(spatial_frequencies_magnitude)
        kernels = kernels[sort_indices]
        spatial_frequencies_magnitude = spatial_frequencies_magnitude[sort_indices]
        # conv2d weight layout: (out_channels, in_channels=1, kH, kW).
        kernels = kernels[:, np.newaxis, :, :]
        multiplication_factor_scalar = dct_basis.multiplication_factor_scalar
        multiplication_factor_matrix = dct_basis.multiplication_factor_matrix
        multiplication_factor_matrix = multiplication_factor_matrix.reshape(-1)
        multiplication_factor_matrix = multiplication_factor_matrix[sort_indices]
        # Shape (1, block_size**2, 1, 1) so it broadcasts over conv outputs.
        multiplication_factor_matrix = multiplication_factor_matrix[
            np.newaxis, :, np.newaxis, np.newaxis
        ]
        self.register_buffer("kernels", torch.from_numpy(kernels))
        self.register_buffer(
            "spatial_frequencies_magnitude",
            torch.from_numpy(spatial_frequencies_magnitude),
        )
        self.register_buffer("block_size", torch.tensor(block_size))
        self.register_buffer(
            "multiplication_factor_scalar", torch.tensor(multiplication_factor_scalar)
        )
        self.register_buffer(
            "multiplication_factor_matrix",
            torch.from_numpy(multiplication_factor_matrix),
        )

        self.embedding_dimension = (block_size**2) * 3

    def _num_kept_encodings(self, compression_ratio: float) -> int:
        """Return how many of a channel's ``block_size**2`` coefficients to keep.

        Shared by ``get_num_compressed_channels``, ``compress`` and
        ``decompress`` so the three always agree. Uses the same tensor
        arithmetic and ``torch.round`` as before, so channel counts are
        unchanged for valid ratios.

        Raises:
            ValueError: If ``compression_ratio`` is not in [0, 1].
        """
        # Also rejects NaN. Negative ratios previously produced negative slice
        # bounds that silently kept almost every coefficient.
        if not 0 <= compression_ratio <= 1:
            raise ValueError(
                f"Compression ratios must be in [0, 1], got {compression_ratio}"
            )
        num_per_channel_encodings = self.block_size**2
        return int(torch.round(num_per_channel_encodings * compression_ratio).int())

    def encode(self, rgb_images_batch: torch.Tensor) -> torch.Tensor:
        """Encodes the input RGB images.

        Args:
            rgb_images_batch (torch.Tensor): The input RGB images. The images should
                have shape (*, 3, height, width). Image values should be in the range
                [0, 1]. Height and width must be divisible by the block size.

        Returns:
            torch.Tensor: The encoded images, of shape
            (*, 3 * block_size**2, height // block_size, width // block_size).

        Raises:
            ValueError: If the input is not RGB-shaped or its spatial size is not
                divisible by the block size.
        """
        # check input
        if rgb_images_batch.ndim < 3:
            raise ValueError("Input images must have shape (*, 3, height, width)")
        *leading_shape, c, h, w = rgb_images_batch.shape
        if c != 3:
            raise ValueError("Input images must be RGB")
        # Read the buffer once per call (avoids repeated device syncs).
        block_size = int(self.block_size)
        if h % block_size != 0 or w % block_size != 0:
            raise ValueError("Image dimensions must be divisible by the block size")

        # Collapse any leading dimensions into a single batch dimension.
        images = rgb_images_batch.reshape(-1, c, h, w)
        # convert to YCbCr and normalize from [0, 1] to [-1, 1]
        ycbcr_tsr = 2 * rgb_to_ycbcr(images) - 1

        # Match the buffers to the input dtype so float64/float16 inputs work.
        dtype = ycbcr_tsr.dtype
        kernels = self.kernels.to(dtype)
        scale = (
            self.multiplication_factor_scalar * self.multiplication_factor_matrix
        ).to(dtype)

        # DCT encode each of Y, Cb, Cr independently with the same kernels.
        encodings = torch.cat(
            [
                scale
                * F.conv2d(ycbcr_tsr[:, [channel], :, :], kernels, stride=block_size)
                for channel in range(3)
            ],
            dim=1,
        )
        return encodings.reshape(*leading_shape, *encodings.shape[1:])

    def decode(self, encodings_batch: torch.Tensor) -> torch.Tensor:
        """Decodes the input encoded images.

        Args:
            encodings_batch (torch.Tensor): The input encoded images, of shape
                (*, 3 * block_size**2, height, width), as produced by ``encode``.

        Returns:
            torch.Tensor: The decoded RGB images, of shape
            (*, 3, height * block_size, width * block_size), clamped to [0, 1]
            by ``ycbcr_to_rgb``.

        Raises:
            ValueError: If the input does not have 3 * block_size**2 channels.
        """
        if encodings_batch.ndim < 3:
            raise ValueError(
                "Encodings must have shape (*, 3 * block_size**2, height, width)"
            )
        *leading_shape, num_channels, h, w = encodings_batch.shape
        block_size = int(self.block_size)
        org_ch = block_size**2
        if num_channels != 3 * org_ch:
            raise ValueError(
                f"Encodings must have {3 * org_ch} channels, got {num_channels}"
            )
        encodings = encodings_batch.reshape(-1, num_channels, h, w)

        dtype = encodings.dtype
        kernels = self.kernels.to(dtype)
        c1 = self.multiplication_factor_scalar.to(dtype)
        c2 = self.multiplication_factor_matrix.to(dtype)

        # DCT decode: channels [i * org_ch, (i + 1) * org_ch) hold Y, Cb, Cr.
        decoded_channels = [
            c1
            * F.conv_transpose2d(
                encodings[:, index * org_ch : (index + 1) * org_ch, :, :] * c2,
                kernels,
                stride=block_size,
            )
            for index in range(3)
        ]

        # undo the [-1, 1] normalization and convert to RGB
        ycbcr_tsr = torch.cat(decoded_channels, dim=1) / 2 + 0.5
        rgb_images_batch = ycbcr_to_rgb(ycbcr_tsr)
        return rgb_images_batch.reshape(*leading_shape, *rgb_images_batch.shape[1:])

    def get_num_compressed_channels(
        self,
        luminance_compression_ratio: float = 1 / 2,
        chrominance_compression_ratio: float = 1 / 4,
    ) -> int:
        """Get the number of compressed channels.

        Args:
            luminance_compression_ratio (float, optional): Fraction of the Y
                coefficients kept, in [0, 1]. Defaults to 1/2.
            chrominance_compression_ratio (float, optional): Fraction of each of the
                Cb and Cr coefficients kept, in [0, 1]. Defaults to 1/4.

        Returns:
            int: The number of channels ``compress`` outputs for these ratios.

        Raises:
            ValueError: If a ratio is outside [0, 1].
        """
        num_luminance_encodings = self._num_kept_encodings(luminance_compression_ratio)
        num_chrominance_encodings = self._num_kept_encodings(
            chrominance_compression_ratio
        )
        return num_luminance_encodings + 2 * num_chrominance_encodings

    def compress(
        self,
        encodings_batch: torch.Tensor,
        luminance_compression_ratio: float = 1 / 2,
        chrominance_compression_ratio: float = 1 / 4,
    ) -> torch.Tensor:
        """Compresses the input encodings by keeping only the lowest frequencies.

        Args:
            encodings_batch (torch.Tensor): The input encodings, of shape
                (*, 3 * block_size**2, height, width).
            luminance_compression_ratio (float, optional): Fraction of the Y
                coefficients kept, in [0, 1]. Defaults to 1/2.
            chrominance_compression_ratio (float, optional): Fraction of each of the
                Cb and Cr coefficients kept, in [0, 1]. Defaults to 1/4.

        Returns:
            torch.Tensor: The compressed encodings, of shape
            (*, get_num_compressed_channels(...), height, width), laid out as
            [kept Y, kept Cb, kept Cr].

        Raises:
            ValueError: If a ratio is outside [0, 1].
        """
        num_per_channel_encodings = int(self.block_size) ** 2
        num_luminance_encodings = self._num_kept_encodings(luminance_compression_ratio)
        num_chrominance_encodings = self._num_kept_encodings(
            chrominance_compression_ratio
        )

        n = num_per_channel_encodings
        # Coefficients are sorted low -> high frequency, so a prefix of each
        # colour channel's block keeps the lowest frequencies.
        return torch.cat(
            [
                encodings_batch[..., :num_luminance_encodings, :, :],
                encodings_batch[..., n : n + num_chrominance_encodings, :, :],
                encodings_batch[..., 2 * n : 2 * n + num_chrominance_encodings, :, :],
            ],
            dim=-3,
        )

    def decompress(
        self,
        compressed_encodings_batch: torch.Tensor,
        luminance_compression_ratio: float = 1 / 2,
        chrominance_compression_ratio: float = 1 / 4,
    ) -> torch.Tensor:
        """Decompresses the input compressed encodings.

        Dropped (high-frequency) coefficients are filled with zeros. The ratios
        must match the ones passed to ``compress``.

        Args:
            compressed_encodings_batch (torch.Tensor): The input compressed
                encodings, of shape (*, num_compressed_channels, height, width).
            luminance_compression_ratio (float, optional): Fraction of the Y
                coefficients kept, in [0, 1]. Defaults to 1/2.
            chrominance_compression_ratio (float, optional): Fraction of each of the
                Cb and Cr coefficients kept, in [0, 1]. Defaults to 1/4.

        Returns:
            torch.Tensor: The decompressed encodings, of shape
            (*, 3 * block_size**2, height, width).

        Raises:
            ValueError: If a ratio is outside [0, 1] or the channel count does not
                match the ratios.
        """
        if compressed_encodings_batch.ndim < 3:
            raise ValueError(
                "Compressed encodings must have shape (*, channels, height, width)"
            )
        *leading_shape, num_channels, h, w = compressed_encodings_batch.shape

        num_per_channel_encodings = int(self.block_size) ** 2
        # Must use the same rounding as ``compress`` (previously floor vs round,
        # which mis-split channels for ratios such as 0.7).
        num_luminance_encodings = self._num_kept_encodings(luminance_compression_ratio)
        num_chrominance_encodings = self._num_kept_encodings(
            chrominance_compression_ratio
        )
        expected_channels = num_luminance_encodings + 2 * num_chrominance_encodings
        if num_channels != expected_channels:
            raise ValueError(
                f"Compressed encodings must have {expected_channels} channels for "
                f"these compression ratios, got {num_channels}"
            )

        n = num_per_channel_encodings
        n_l = num_luminance_encodings
        n_c = num_chrominance_encodings
        decompressed_dct_encodings = torch.zeros(
            (*leading_shape, 3 * n, h, w),
            dtype=compressed_encodings_batch.dtype,
            device=compressed_encodings_batch.device,
        )
        decompressed_dct_encodings[..., :n_l, :, :] = compressed_encodings_batch[
            ..., :n_l, :, :
        ]
        decompressed_dct_encodings[..., n : n + n_c, :, :] = (
            compressed_encodings_batch[..., n_l : n_l + n_c, :, :]
        )
        decompressed_dct_encodings[..., 2 * n : 2 * n + n_c, :, :] = (
            compressed_encodings_batch[..., n_l + n_c :, :, :]
        )
        return decompressed_dct_encodings
@@ -0,0 +1,46 @@
1
+ import torch
2
+
3
+
4
def ycbcr_to_rgb(image: torch.Tensor) -> torch.Tensor:
    """Converts an image from YCbCr to RGB color space.

    Cb and Cr are re-centred around zero (by subtracting 0.5) before the
    linear YCbCr -> RGB mapping; the result is clamped to [0, 1].

    Args:
        image (torch.Tensor): The input image. The image should have shape
            (*, 3, height, width). Image values should be in the range [0, 1].

    Returns:
        torch.Tensor: The output image in RGB color space, shape
        (*, 3, height, width), with values clamped to [0, 1].
    """
    offset: float = 0.5
    luma = image[..., 0, :, :]
    blue_diff = image[..., 1, :, :] - offset
    red_diff = image[..., 2, :, :] - offset

    red = luma + 1.403 * red_diff
    green = luma - 0.714 * red_diff - 0.344 * blue_diff
    blue = luma + 1.773 * blue_diff

    rgb = torch.stack((red, green, blue), dim=-3)
    return rgb.clamp(min=0, max=1)
26
+
27
+
28
def rgb_to_ycbcr(image: torch.Tensor) -> torch.Tensor:
    """Converts an image from RGB to YCbCr color space.

    Luma uses the BT.601 weights (0.299, 0.587, 0.114). The chroma differences
    are scaled and offset by 0.5, so RGB inputs in [0, 1] give Y, Cb and Cr
    values in [0, 1]. ``ycbcr_to_rgb`` applies the approximate inverse.

    Args:
        image (torch.Tensor): The input image. The image should have shape
            (*, 3, height, width). Image values should be in the range [0, 1].

    Returns:
        torch.Tensor: The output image in YCbCr color space, shape
        (*, 3, height, width).
    """
    r = image[..., 0, :, :]
    g = image[..., 1, :, :]
    b = image[..., 2, :, :]

    # Offset that moves the signed chroma differences into [0, 1].
    delta: float = 0.5
    y = 0.299 * r + 0.587 * g + 0.114 * b
    cb = (b - y) * 0.564 + delta
    cr = (r - y) * 0.713 + delta
    return torch.stack([y, cb, cr], -3)
@@ -0,0 +1,52 @@
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+
4
+ from .basis import DCTBasis
5
+
6
+
7
def visualize_dct_basis_functions(
    dct_constants: DCTBasis,
    figsize: int = 8,
    fig_facecolor: str = "#fb6a2c",
    title_color: str = "k",
    title_fontsize: int = 20,
    cmap: str = "gray",
) -> tuple:
    """Visualize the DCT basis functions as a ``block_size x block_size`` grid.

    Tile ``(v, u)`` (row ``v``, column ``u``) shows
    ``dct_constants.basis_functions[v, u]``. Tiles are separated by lines in
    ``fig_facecolor``. A new pyplot figure is created and becomes the current
    figure.

    Args:
        dct_constants (DCTBasis): The DCT basis constants.
        figsize (int, optional): The figure width and height. Defaults to 8.
        fig_facecolor (str, optional): The figure facecolor, also used for the
            tile separator lines. Defaults to "#fb6a2c".
        title_color (str, optional): The title color. Defaults to "k".
        title_fontsize (int, optional): The title fontsize. Defaults to 20.
        cmap (str, optional): The colormap. Defaults to "gray".

    Returns:
        tuple: The ``(figure, axes)`` of the created plot.
    """
    block_size = dct_constants.block_size
    basis_functions = dct_constants.basis_functions
    image_size = block_size * block_size
    basis_functions_image = np.zeros((image_size, image_size))
    for v in range(block_size):
        for u in range(block_size):
            basis_functions_image[
                v * block_size : (v + 1) * block_size,
                u * block_size : (u + 1) * block_size,
            ] = basis_functions[v, u]
    plt.figure(figsize=(figsize, figsize), facecolor=fig_facecolor)
    plt.title(
        f"DCT Basis functions (block size: {block_size}x{block_size})",
        color=title_color,
        fontsize=title_fontsize,
        fontweight="bold",
    )
    plt.imshow(basis_functions_image, cmap=cmap)
    plt.axis("off")
    # Tile edges lie at i * block_size - 0.5 in pixel coordinates. Include
    # i == block_size so the bottom/right outer edges are drawn like the
    # top/left ones (previously only the top/left edges were drawn).
    for i in range(block_size + 1):
        edge = i * block_size - 0.5
        plt.axhline(edge, color=fig_facecolor)
        plt.axvline(edge, color=fig_facecolor)
    plt.tight_layout()
    fig = plt.gcf()
    ax = plt.gca()
    return fig, ax
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 dariush-bahrami
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,34 @@
1
+ Metadata-Version: 2.1
2
+ Name: dct-autoencoder
3
+ Version: 0.1.0
4
+ Summary:
5
+ Author: Dariush Bahrami
6
+ Author-email: dariushbahrami1993@gmail.com
7
+ Requires-Python: >=3.10,<4.0
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Programming Language :: Python :: 3.10
10
+ Classifier: Programming Language :: Python :: 3.11
11
+ Classifier: Programming Language :: Python :: 3.12
12
+ Requires-Dist: matplotlib (>=3.9.2,<4.0.0)
13
+ Requires-Dist: numpy (>=2.1.1,<3.0.0)
14
+ Requires-Dist: torch (>=2.4.1,<3.0.0)
15
+ Description-Content-Type: text/markdown
16
+
17
+ # dct-autoencoder
18
+
19
+ 2D Discrete Cosine Transform in PyTorch
20
+
21
+ ![DCT Basis Functions](./assets/figures/dct_basis_functions_block_size_16.png)
22
+
23
+
24
+ ## Usage
25
+
26
+ Refer to the [usage notebook](./usage.ipynb) for code examples.
27
+
28
+
29
+ ## TODO
30
+
31
+ - [x] Add support for color images
32
+ - [x] Improve documentation
33
+ - [ ] Add tests
34
+ - [ ] Distribute on PyPI
@@ -0,0 +1,9 @@
1
+ dct_autoencoder/__init__.py,sha256=HPbvVdAiG_hWQ2ZNruAJnr8MOWJnMMQLjjwob-verAY,76
2
+ dct_autoencoder/basis.py,sha256=ynx3Plbts6snn7lcyfPUOCfFxUgcXMa3iYPWca7KCdI,2060
3
+ dct_autoencoder/core.py,sha256=ikVN_yeuSNN4NXbkU3kH_m32EIrkOn1z--tij-nCJIk,9965
4
+ dct_autoencoder/utils.py,sha256=oCFZgLAFFgV7dJoFXC3xfbtjM-fnEBpY1eSPywHIj4U,1306
5
+ dct_autoencoder/visualization.py,sha256=Mj9Ipz2oHTfzyq8hqM1UjoRxQ2ALAtoJeRAB7GZpXCs,1818
6
+ dct_autoencoder-0.1.0.dist-info/LICENSE,sha256=kVBYE8Z59CVgIBn5bMZF2ihgBM-2fyEDqU93DArFnQU,1072
7
+ dct_autoencoder-0.1.0.dist-info/METADATA,sha256=4XSc-J373HyOfe4FUk51qQy44ITa6uwN_Ylclv4P70g,856
8
+ dct_autoencoder-0.1.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
9
+ dct_autoencoder-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: poetry-core 1.9.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any