dct-autoencoder 0.2.0__tar.gz → 0.3.0__tar.gz

This diff compares the contents of two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
@@ -1,17 +1,17 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.3
2
2
  Name: dct-autoencoder
3
- Version: 0.2.0
4
- Summary:
3
+ Version: 0.3.0
4
+ Summary: Discrete Cosine Transform in PyTorch
5
5
  Author: Dariush Bahrami
6
- Author-email: dariushbahrami1993@gmail.com
7
- Requires-Python: >=3.10,<4.0
8
- Classifier: Programming Language :: Python :: 3
9
- Classifier: Programming Language :: Python :: 3.10
10
- Classifier: Programming Language :: Python :: 3.11
11
- Classifier: Programming Language :: Python :: 3.12
12
- Requires-Dist: matplotlib (>=3.9.2,<4.0.0)
13
- Requires-Dist: numpy (>=2.1.1,<3.0.0)
14
- Requires-Dist: torch (>=2.4.1,<3.0.0)
6
+ Author-email: Dariush Bahrami <dariushbahrami1993@gmail.com>
7
+ License: MIT
8
+ Requires-Dist: numpy>=2.4.6
9
+ Requires-Dist: torch>=2.0.0 ; extra == 'torch'
10
+ Requires-Dist: torchvision>=0.15.0 ; extra == 'torch'
11
+ Requires-Python: >=3.12
12
+ Project-URL: Homepage, https://github.com/dariush-bahrami/dct-autoencoder
13
+ Project-URL: Repository, https://github.com/dariush-bahrami/dct-autoencoder
14
+ Provides-Extra: torch
15
15
  Description-Content-Type: text/markdown
16
16
 
17
17
  # DCT-Autoencoder
@@ -57,4 +57,3 @@ DCT basis functions for a block size of 16:
57
57
  - [x] Improve documentation
58
58
  - [ ] Add unit tests
59
59
  - [x] Distribute package on PyPI
60
-
@@ -0,0 +1,36 @@
1
+ [project]
2
+ name = "dct-autoencoder"
3
+ version = "0.3.0"
4
+ description = "Discrete Cosine Transform in PyTorch"
5
+ readme = "README.md"
6
+ license = { text = "MIT" }
7
+ authors = [
8
+ { name = "Dariush Bahrami", email = "dariushbahrami1993@gmail.com" }
9
+ ]
10
+ requires-python = ">=3.12"
11
+ dependencies = [
12
+ "numpy>=2.4.6",
13
+ ]
14
+
15
+ [project.optional-dependencies]
16
+ torch = [
17
+ "torch>=2.0.0",
18
+ "torchvision>=0.15.0",
19
+ ]
20
+
21
+ [project.urls]
22
+ Homepage = "https://github.com/dariush-bahrami/dct-autoencoder"
23
+ Repository = "https://github.com/dariush-bahrami/dct-autoencoder"
24
+
25
+ [build-system]
26
+ requires = ["uv_build>=0.11.16,<0.12.0"]
27
+ build-backend = "uv_build"
28
+
29
+ [dependency-groups]
30
+ dev = [
31
+ "dct-autoencoder[torch]",
32
+ "ipykernel>=7.2.0",
33
+ "ipywidgets>=8.1.8",
34
+ "matplotlib>=3.10.9",
35
+ "pillow>=12.2.0",
36
+ ]
@@ -1,5 +1,3 @@
1
- import math
2
-
3
1
  import numpy as np
4
2
  import torch
5
3
  from torch import nn
@@ -14,15 +12,38 @@ class DCTAutoencoder(nn.Module):
14
12
 
15
13
  Args:
16
14
  block_size (int, optional): The block size. Defaults to 8.
15
+ num_luminance_compressed_channels (int | None, optional): Number of lowest-frequency
16
+ luminance channels to retain after compression. ``None`` keeps all channels
17
+ (no compression). Defaults to ``None``.
18
+ num_chrominance_compressed_channels (int | None, optional): Number of lowest-frequency
19
+ chrominance channels to retain after compression. ``None`` keeps all channels
20
+ (no compression). Defaults to ``None``.
17
21
  """
18
22
 
19
23
  def __init__(
20
24
  self,
21
25
  block_size: int = 8,
22
- luminance_compression_ratio: float = 1 / 2,
23
- chrominance_compression_ratio: float = 1 / 4,
26
+ num_luminance_compressed_channels: int | None = None,
27
+ num_chrominance_compressed_channels: int | None = None,
24
28
  ) -> None:
25
29
  super().__init__()
30
+ total_channels = block_size**2
31
+
32
+ if num_luminance_compressed_channels is not None and not (
33
+ 1 <= num_luminance_compressed_channels <= total_channels
34
+ ):
35
+ raise ValueError(
36
+ f"num_luminance_compressed_channels must be between 1 and {total_channels}, "
37
+ f"got {num_luminance_compressed_channels}"
38
+ )
39
+ if num_chrominance_compressed_channels is not None and not (
40
+ 1 <= num_chrominance_compressed_channels <= total_channels
41
+ ):
42
+ raise ValueError(
43
+ f"num_chrominance_compressed_channels must be between 1 and {total_channels}, "
44
+ f"got {num_chrominance_compressed_channels}"
45
+ )
46
+
26
47
  dct_basis = get_dct_basis(block_size)
27
48
  basis_functions = dct_basis.basis_functions
28
49
  kernels = basis_functions.reshape(-1, block_size, block_size)
@@ -32,7 +53,7 @@ class DCTAutoencoder(nn.Module):
32
53
  spatial_frequencies_components = (
33
54
  dct_basis.spatial_frequencies_components.reshape(-1, 2)
34
55
  )
35
- sort_indices = np.argsort(spatial_frequencies_magnitude)
56
+ sort_indices = np.argsort(spatial_frequencies_magnitude, kind="stable")
36
57
  kernels = kernels[sort_indices]
37
58
  spatial_frequencies_magnitude = spatial_frequencies_magnitude[sort_indices]
38
59
  spatial_frequencies_components = spatial_frequencies_components[sort_indices]
@@ -62,64 +83,32 @@ class DCTAutoencoder(nn.Module):
62
83
  torch.from_numpy(multiplication_factor_matrix),
63
84
  )
64
85
 
65
- self.embedding_dimension = (block_size**2) * 3
86
+ self.embedding_dimension = total_channels * 3
66
87
 
67
- # compressor initialization
68
- if luminance_compression_ratio == 1 and chrominance_compression_ratio == 1:
69
- self.do_compression = False
70
- self.compression_luminance_mask = torch.ones(
71
- block_size**2,
72
- dtype=bool,
73
- device=self.spatial_frequencies_components.device,
74
- )
75
- self.compression_chrominance_mask = torch.ones(
76
- block_size**2,
77
- dtype=bool,
78
- device=self.spatial_frequencies_components.device,
79
- )
80
- self.compression_luminance_passband = block_size**2
81
- self.compression_chrominance_passband = block_size**2
82
- else:
83
- original_frequencies = self.spatial_frequencies_components.to(
84
- dtype=torch.float32
85
- )
86
- luminance_block_size = math.ceil(block_size * luminance_compression_ratio)
87
- chrominance_block_size = math.ceil(
88
- block_size * chrominance_compression_ratio
89
- )
90
- luminance_frequencies = get_dct_basis(
91
- luminance_block_size
92
- ).spatial_frequencies_components.reshape(-1, 2)
93
- luminance_frequencies = torch.from_numpy(luminance_frequencies).to(
94
- device=original_frequencies.device, dtype=torch.float32
95
- )
96
- chrominance_frequencies = get_dct_basis(
97
- chrominance_block_size
98
- ).spatial_frequencies_components.reshape(-1, 2)
99
- chrominance_frequencies = torch.from_numpy(chrominance_frequencies).to(
100
- device=original_frequencies.device, dtype=torch.float32
101
- )
102
- indices = torch.arange(block_size**2, device=original_frequencies.device)
103
- luminance_mask = torch.isin(
104
- indices,
105
- torch.cdist(original_frequencies, luminance_frequencies, p=2).argmin(
106
- dim=0
107
- ),
108
- )
109
- chrominance_mask = torch.isin(
110
- indices,
111
- torch.cdist(original_frequencies, chrominance_frequencies, p=2).argmin(
112
- dim=0
113
- ),
114
- )
115
- luminance_passband = luminance_mask.sum()
116
- chrominance_passband = chrominance_mask.sum()
88
+ # compressor initialization — kernels are already sorted by ascending frequency,
89
+ # so keeping the first N channels retains the N lowest frequencies exactly.
90
+ lum_n = num_luminance_compressed_channels
91
+ chr_n = num_chrominance_compressed_channels
92
+ self.do_compression = (lum_n is not None and lum_n < total_channels) or (
93
+ chr_n is not None and chr_n < total_channels
94
+ )
95
+
96
+ lum_passband = lum_n if lum_n is not None else total_channels
97
+ chr_passband = chr_n if chr_n is not None else total_channels
98
+
99
+ lum_mask = torch.zeros(total_channels, dtype=torch.bool)
100
+ lum_mask[:lum_passband] = True
101
+ chr_mask = torch.zeros(total_channels, dtype=torch.bool)
102
+ chr_mask[:chr_passband] = True
117
103
 
118
- self.do_compression = True
119
- self.compression_luminance_mask = luminance_mask
120
- self.compression_chrominance_mask = chrominance_mask
121
- self.compression_luminance_passband = luminance_passband
122
- self.compression_chrominance_passband = chrominance_passband
104
+ self.register_buffer("compression_luminance_mask", lum_mask)
105
+ self.register_buffer("compression_chrominance_mask", chr_mask)
106
+ self.register_buffer(
107
+ "compression_luminance_passband", torch.tensor(lum_passband)
108
+ )
109
+ self.register_buffer(
110
+ "compression_chrominance_passband", torch.tensor(chr_passband)
111
+ )
123
112
 
124
113
  def encode(self, rgb_images_batch: torch.Tensor) -> torch.Tensor:
125
114
  """Encodes the input RGB images.
@@ -156,12 +145,14 @@ class DCTAutoencoder(nn.Module):
156
145
  encodings_batch = encodings_batch / self.block_size
157
146
  return encodings_batch
158
147
 
159
- def decode(self, encodings_batch: torch.Tensor) -> torch.Tensor:
148
+ def decode(
149
+ self, encodings_batch: torch.Tensor, clamp_output: bool = True
150
+ ) -> torch.Tensor:
160
151
  """Decodes the input encoded images.
161
152
 
162
153
  Args:
163
154
  encodings_batch (torch.Tensor): The input encoded images.
164
-
155
+ clamp_output (bool, optional): Whether to clamp the output to the range [0, 1]. Defaults to True.
165
156
  Returns:
166
157
  torch.Tensor: The decoded images.
167
158
  """
@@ -187,6 +178,9 @@ class DCTAutoencoder(nn.Module):
187
178
  ycbcr_tsr = torch.cat([y, cb, cr], dim=1)
188
179
  ycbcr_tsr = ycbcr_tsr / 2 + 0.5
189
180
  rgb_images_batch = ycbcr_to_rgb(ycbcr_tsr)
181
+ if clamp_output:
182
+ # clamp is expected for display; note it makes decode non-linear for out-of-range values
183
+ rgb_images_batch = rgb_images_batch.clamp(0, 1)
190
184
  return rgb_images_batch
191
185
 
192
186
  def get_num_compressed_channels(self) -> int:
@@ -198,7 +192,7 @@ class DCTAutoencoder(nn.Module):
198
192
  + 2 * self.compression_chrominance_passband.item()
199
193
  )
200
194
 
201
- def compress(self, encodings):
195
+ def compress(self, encodings: torch.Tensor) -> torch.Tensor:
202
196
  if not self.do_compression:
203
197
  return encodings
204
198
  else:
@@ -211,7 +205,7 @@ class DCTAutoencoder(nn.Module):
211
205
  compressed_encoding = torch.cat([l, c1, c2], dim=1)
212
206
  return compressed_encoding
213
207
 
214
- def decompress(self, compressed_encoding):
208
+ def decompress(self, compressed_encoding: torch.Tensor) -> torch.Tensor:
215
209
  if not self.do_compression:
216
210
  return compressed_encoding
217
211
  else:
@@ -19,10 +19,11 @@ def ycbcr_to_rgb(image: torch.Tensor) -> torch.Tensor:
19
19
  cb_shifted = cb - delta
20
20
  cr_shifted = cr - delta
21
21
 
22
- r = y + 1.403 * cr_shifted
23
- g = y - 0.714 * cr_shifted - 0.344 * cb_shifted
24
- b = y + 1.773 * cb_shifted
25
- return torch.stack([r, g, b], -3).clamp(0, 1)
22
+ # Exact inverse of the forward matrix: 1/0.713, 0.299/(0.713*0.587), 0.114/(0.564*0.587), 1/0.564
23
+ r = y + 1.40252 * cr_shifted
24
+ g = y - 0.71440 * cr_shifted - 0.34434 * cb_shifted
25
+ b = y + 1.77305 * cb_shifted
26
+ return torch.stack([r, g, b], -3)
26
27
 
27
28
 
28
29
  def rgb_to_ycbcr(image) -> torch.Tensor:
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2024 dariush-bahrami
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
@@ -1,52 +0,0 @@
1
- import matplotlib.pyplot as plt
2
- import numpy as np
3
-
4
- from .basis import DCTBasis
5
-
6
-
7
- def visualize_dct_basis_functions(
8
- dct_constants: DCTBasis,
9
- figsize: int = 8,
10
- fig_facecolor: str = "#fb6a2c",
11
- title_color: str = "k",
12
- title_fontsize: int = 20,
13
- cmap: str = "gray",
14
- ) -> tuple:
15
- """Visualize the DCT basis functions.
16
-
17
- Args:
18
- dct_constants (DCTBasis): The DCT basis constants.
19
- figsize (int, optional): The figure size. Defaults to 8.
20
- fig_facecolor (str, optional): The figure facecolor. Defaults to "#fb6a2c".
21
- title_color (str, optional): The title color. Defaults to "k".
22
- title_fontsize (int, optional): The title fontsize. Defaults to 20.
23
- cmap (str, optional): The colormap. Defaults to "gray".
24
-
25
- Returns:
26
- tuple: The figure and axis.
27
- """
28
- block_size = dct_constants.block_size
29
- basis_functions = dct_constants.basis_functions
30
- basis_functions_image = np.zeros((block_size * block_size, block_size * block_size))
31
- for v in range(block_size):
32
- for u in range(block_size):
33
- basis_functions_image[
34
- v * block_size : (v + 1) * block_size,
35
- u * block_size : (u + 1) * block_size,
36
- ] = basis_functions[v, u]
37
- plt.figure(figsize=(figsize, figsize), facecolor=fig_facecolor)
38
- plt.title(
39
- f"DCT Basis functions (block size: {block_size}x{block_size})",
40
- color=title_color,
41
- fontsize=title_fontsize,
42
- fontweight="bold",
43
- )
44
- plt.imshow(basis_functions_image, cmap=cmap)
45
- plt.axis("off")
46
- for i in range(block_size):
47
- plt.axhline(i * block_size - 0.5, color=fig_facecolor)
48
- plt.axvline(i * block_size - 0.5, color=fig_facecolor)
49
- plt.tight_layout()
50
- fig = plt.gcf()
51
- ax = plt.gca()
52
- return fig, ax
@@ -1,16 +0,0 @@
1
- [tool.poetry]
2
- name = "dct-autoencoder"
3
- version = "0.2.0"
4
- description = ""
5
- authors = ["Dariush Bahrami <dariushbahrami1993@gmail.com>"]
6
- readme = "README.md"
7
-
8
- [tool.poetry.dependencies]
9
- python = "^3.10"
10
- numpy = "^2.1.1"
11
- matplotlib = {version = "^3.9.2", optional = true}
12
- torch = "^2.4.1"
13
-
14
- [build-system]
15
- requires = ["poetry-core"]
16
- build-backend = "poetry.core.masonry.api"