dct-autoencoder 0.2.0__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dct_autoencoder-0.2.0 → dct_autoencoder-0.3.0}/PKG-INFO +12 -13
- dct_autoencoder-0.3.0/pyproject.toml +36 -0
- {dct_autoencoder-0.2.0 → dct_autoencoder-0.3.0/src}/dct_autoencoder/core.py +59 -65
- {dct_autoencoder-0.2.0 → dct_autoencoder-0.3.0/src}/dct_autoencoder/utils.py +5 -4
- dct_autoencoder-0.2.0/LICENSE +0 -21
- dct_autoencoder-0.2.0/dct_autoencoder/visualization.py +0 -52
- dct_autoencoder-0.2.0/pyproject.toml +0 -16
- {dct_autoencoder-0.2.0 → dct_autoencoder-0.3.0}/README.md +0 -0
- {dct_autoencoder-0.2.0 → dct_autoencoder-0.3.0/src}/dct_autoencoder/__init__.py +0 -0
- {dct_autoencoder-0.2.0 → dct_autoencoder-0.3.0/src}/dct_autoencoder/basis.py +0 -0
|
@@ -1,17 +1,17 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
2
|
Name: dct-autoencoder
|
|
3
|
-
Version: 0.2.0
|
|
4
|
-
Summary:
|
|
3
|
+
Version: 0.3.0
|
|
4
|
+
Summary: Discrete Cosine Transform in PyTorch
|
|
5
5
|
Author: Dariush Bahrami
|
|
6
|
-
Author-email: dariushbahrami1993@gmail.com
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
6
|
+
Author-email: Dariush Bahrami <dariushbahrami1993@gmail.com>
|
|
7
|
+
License: MIT
|
|
8
|
+
Requires-Dist: numpy>=2.4.6
|
|
9
|
+
Requires-Dist: torch>=2.0.0 ; extra == 'torch'
|
|
10
|
+
Requires-Dist: torchvision>=0.15.0 ; extra == 'torch'
|
|
11
|
+
Requires-Python: >=3.12
|
|
12
|
+
Project-URL: Homepage, https://github.com/dariush-bahrami/dct-autoencoder
|
|
13
|
+
Project-URL: Repository, https://github.com/dariush-bahrami/dct-autoencoder
|
|
14
|
+
Provides-Extra: torch
|
|
15
15
|
Description-Content-Type: text/markdown
|
|
16
16
|
|
|
17
17
|
# DCT-Autoencoder
|
|
@@ -57,4 +57,3 @@ DCT basis functions for a block size of 16:
|
|
|
57
57
|
- [x] Improve documentation
|
|
58
58
|
- [ ] Add unit tests
|
|
59
59
|
- [x] Distribute package on PyPI
|
|
60
|
-
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "dct-autoencoder"
|
|
3
|
+
version = "0.3.0"
|
|
4
|
+
description = "Discrete Cosine Transform in PyTorch"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
license = { text = "MIT" }
|
|
7
|
+
authors = [
|
|
8
|
+
{ name = "Dariush Bahrami", email = "dariushbahrami1993@gmail.com" }
|
|
9
|
+
]
|
|
10
|
+
requires-python = ">=3.12"
|
|
11
|
+
dependencies = [
|
|
12
|
+
"numpy>=2.4.6",
|
|
13
|
+
]
|
|
14
|
+
|
|
15
|
+
[project.optional-dependencies]
|
|
16
|
+
torch = [
|
|
17
|
+
"torch>=2.0.0",
|
|
18
|
+
"torchvision>=0.15.0",
|
|
19
|
+
]
|
|
20
|
+
|
|
21
|
+
[project.urls]
|
|
22
|
+
Homepage = "https://github.com/dariush-bahrami/dct-autoencoder"
|
|
23
|
+
Repository = "https://github.com/dariush-bahrami/dct-autoencoder"
|
|
24
|
+
|
|
25
|
+
[build-system]
|
|
26
|
+
requires = ["uv_build>=0.11.16,<0.12.0"]
|
|
27
|
+
build-backend = "uv_build"
|
|
28
|
+
|
|
29
|
+
[dependency-groups]
|
|
30
|
+
dev = [
|
|
31
|
+
"dct-autoencoder[torch]",
|
|
32
|
+
"ipykernel>=7.2.0",
|
|
33
|
+
"ipywidgets>=8.1.8",
|
|
34
|
+
"matplotlib>=3.10.9",
|
|
35
|
+
"pillow>=12.2.0",
|
|
36
|
+
]
|
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
import math
|
|
2
|
-
|
|
3
1
|
import numpy as np
|
|
4
2
|
import torch
|
|
5
3
|
from torch import nn
|
|
@@ -14,15 +12,38 @@ class DCTAutoencoder(nn.Module):
|
|
|
14
12
|
|
|
15
13
|
Args:
|
|
16
14
|
block_size (int, optional): The block size. Defaults to 8.
|
|
15
|
+
num_luminance_compressed_channels (int | None, optional): Number of lowest-frequency
|
|
16
|
+
luminance channels to retain after compression. ``None`` keeps all channels
|
|
17
|
+
(no compression). Defaults to ``None``.
|
|
18
|
+
num_chrominance_compressed_channels (int | None, optional): Number of lowest-frequency
|
|
19
|
+
chrominance channels to retain after compression. ``None`` keeps all channels
|
|
20
|
+
(no compression). Defaults to ``None``.
|
|
17
21
|
"""
|
|
18
22
|
|
|
19
23
|
def __init__(
|
|
20
24
|
self,
|
|
21
25
|
block_size: int = 8,
|
|
22
|
-
|
|
23
|
-
|
|
26
|
+
num_luminance_compressed_channels: int | None = None,
|
|
27
|
+
num_chrominance_compressed_channels: int | None = None,
|
|
24
28
|
) -> None:
|
|
25
29
|
super().__init__()
|
|
30
|
+
total_channels = block_size**2
|
|
31
|
+
|
|
32
|
+
if num_luminance_compressed_channels is not None and not (
|
|
33
|
+
1 <= num_luminance_compressed_channels <= total_channels
|
|
34
|
+
):
|
|
35
|
+
raise ValueError(
|
|
36
|
+
f"num_luminance_compressed_channels must be between 1 and {total_channels}, "
|
|
37
|
+
f"got {num_luminance_compressed_channels}"
|
|
38
|
+
)
|
|
39
|
+
if num_chrominance_compressed_channels is not None and not (
|
|
40
|
+
1 <= num_chrominance_compressed_channels <= total_channels
|
|
41
|
+
):
|
|
42
|
+
raise ValueError(
|
|
43
|
+
f"num_chrominance_compressed_channels must be between 1 and {total_channels}, "
|
|
44
|
+
f"got {num_chrominance_compressed_channels}"
|
|
45
|
+
)
|
|
46
|
+
|
|
26
47
|
dct_basis = get_dct_basis(block_size)
|
|
27
48
|
basis_functions = dct_basis.basis_functions
|
|
28
49
|
kernels = basis_functions.reshape(-1, block_size, block_size)
|
|
@@ -32,7 +53,7 @@ class DCTAutoencoder(nn.Module):
|
|
|
32
53
|
spatial_frequencies_components = (
|
|
33
54
|
dct_basis.spatial_frequencies_components.reshape(-1, 2)
|
|
34
55
|
)
|
|
35
|
-
sort_indices = np.argsort(spatial_frequencies_magnitude)
|
|
56
|
+
sort_indices = np.argsort(spatial_frequencies_magnitude, kind="stable")
|
|
36
57
|
kernels = kernels[sort_indices]
|
|
37
58
|
spatial_frequencies_magnitude = spatial_frequencies_magnitude[sort_indices]
|
|
38
59
|
spatial_frequencies_components = spatial_frequencies_components[sort_indices]
|
|
@@ -62,64 +83,32 @@ class DCTAutoencoder(nn.Module):
|
|
|
62
83
|
torch.from_numpy(multiplication_factor_matrix),
|
|
63
84
|
)
|
|
64
85
|
|
|
65
|
-
self.embedding_dimension =
|
|
86
|
+
self.embedding_dimension = total_channels * 3
|
|
66
87
|
|
|
67
|
-
# compressor initialization
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
else:
|
|
83
|
-
original_frequencies = self.spatial_frequencies_components.to(
|
|
84
|
-
dtype=torch.float32
|
|
85
|
-
)
|
|
86
|
-
luminance_block_size = math.ceil(block_size * luminance_compression_ratio)
|
|
87
|
-
chrominance_block_size = math.ceil(
|
|
88
|
-
block_size * chrominance_compression_ratio
|
|
89
|
-
)
|
|
90
|
-
luminance_frequencies = get_dct_basis(
|
|
91
|
-
luminance_block_size
|
|
92
|
-
).spatial_frequencies_components.reshape(-1, 2)
|
|
93
|
-
luminance_frequencies = torch.from_numpy(luminance_frequencies).to(
|
|
94
|
-
device=original_frequencies.device, dtype=torch.float32
|
|
95
|
-
)
|
|
96
|
-
chrominance_frequencies = get_dct_basis(
|
|
97
|
-
chrominance_block_size
|
|
98
|
-
).spatial_frequencies_components.reshape(-1, 2)
|
|
99
|
-
chrominance_frequencies = torch.from_numpy(chrominance_frequencies).to(
|
|
100
|
-
device=original_frequencies.device, dtype=torch.float32
|
|
101
|
-
)
|
|
102
|
-
indices = torch.arange(block_size**2, device=original_frequencies.device)
|
|
103
|
-
luminance_mask = torch.isin(
|
|
104
|
-
indices,
|
|
105
|
-
torch.cdist(original_frequencies, luminance_frequencies, p=2).argmin(
|
|
106
|
-
dim=0
|
|
107
|
-
),
|
|
108
|
-
)
|
|
109
|
-
chrominance_mask = torch.isin(
|
|
110
|
-
indices,
|
|
111
|
-
torch.cdist(original_frequencies, chrominance_frequencies, p=2).argmin(
|
|
112
|
-
dim=0
|
|
113
|
-
),
|
|
114
|
-
)
|
|
115
|
-
luminance_passband = luminance_mask.sum()
|
|
116
|
-
chrominance_passband = chrominance_mask.sum()
|
|
88
|
+
# compressor initialization — kernels are already sorted by ascending frequency,
|
|
89
|
+
# so keeping the first N channels retains the N lowest frequencies exactly.
|
|
90
|
+
lum_n = num_luminance_compressed_channels
|
|
91
|
+
chr_n = num_chrominance_compressed_channels
|
|
92
|
+
self.do_compression = (lum_n is not None and lum_n < total_channels) or (
|
|
93
|
+
chr_n is not None and chr_n < total_channels
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
lum_passband = lum_n if lum_n is not None else total_channels
|
|
97
|
+
chr_passband = chr_n if chr_n is not None else total_channels
|
|
98
|
+
|
|
99
|
+
lum_mask = torch.zeros(total_channels, dtype=torch.bool)
|
|
100
|
+
lum_mask[:lum_passband] = True
|
|
101
|
+
chr_mask = torch.zeros(total_channels, dtype=torch.bool)
|
|
102
|
+
chr_mask[:chr_passband] = True
|
|
117
103
|
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
104
|
+
self.register_buffer("compression_luminance_mask", lum_mask)
|
|
105
|
+
self.register_buffer("compression_chrominance_mask", chr_mask)
|
|
106
|
+
self.register_buffer(
|
|
107
|
+
"compression_luminance_passband", torch.tensor(lum_passband)
|
|
108
|
+
)
|
|
109
|
+
self.register_buffer(
|
|
110
|
+
"compression_chrominance_passband", torch.tensor(chr_passband)
|
|
111
|
+
)
|
|
123
112
|
|
|
124
113
|
def encode(self, rgb_images_batch: torch.Tensor) -> torch.Tensor:
|
|
125
114
|
"""Encodes the input RGB images.
|
|
@@ -156,12 +145,14 @@ class DCTAutoencoder(nn.Module):
|
|
|
156
145
|
encodings_batch = encodings_batch / self.block_size
|
|
157
146
|
return encodings_batch
|
|
158
147
|
|
|
159
|
-
def decode(
|
|
148
|
+
def decode(
|
|
149
|
+
self, encodings_batch: torch.Tensor, clamp_output: bool = True
|
|
150
|
+
) -> torch.Tensor:
|
|
160
151
|
"""Decodes the input encoded images.
|
|
161
152
|
|
|
162
153
|
Args:
|
|
163
154
|
encodings_batch (torch.Tensor): The input encoded images.
|
|
164
|
-
|
|
155
|
+
clamp_output (bool, optional): Whether to clamp the output to the range [0, 1]. Defaults to True.
|
|
165
156
|
Returns:
|
|
166
157
|
torch.Tensor: The decoded images.
|
|
167
158
|
"""
|
|
@@ -187,6 +178,9 @@ class DCTAutoencoder(nn.Module):
|
|
|
187
178
|
ycbcr_tsr = torch.cat([y, cb, cr], dim=1)
|
|
188
179
|
ycbcr_tsr = ycbcr_tsr / 2 + 0.5
|
|
189
180
|
rgb_images_batch = ycbcr_to_rgb(ycbcr_tsr)
|
|
181
|
+
if clamp_output:
|
|
182
|
+
# clamp is expected for display; note it makes decode non-linear for out-of-range values
|
|
183
|
+
rgb_images_batch = rgb_images_batch.clamp(0, 1)
|
|
190
184
|
return rgb_images_batch
|
|
191
185
|
|
|
192
186
|
def get_num_compressed_channels(self) -> int:
|
|
@@ -198,7 +192,7 @@ class DCTAutoencoder(nn.Module):
|
|
|
198
192
|
+ 2 * self.compression_chrominance_passband.item()
|
|
199
193
|
)
|
|
200
194
|
|
|
201
|
-
def compress(self, encodings):
|
|
195
|
+
def compress(self, encodings: torch.Tensor) -> torch.Tensor:
|
|
202
196
|
if not self.do_compression:
|
|
203
197
|
return encodings
|
|
204
198
|
else:
|
|
@@ -211,7 +205,7 @@ class DCTAutoencoder(nn.Module):
|
|
|
211
205
|
compressed_encoding = torch.cat([l, c1, c2], dim=1)
|
|
212
206
|
return compressed_encoding
|
|
213
207
|
|
|
214
|
-
def decompress(self, compressed_encoding):
|
|
208
|
+
def decompress(self, compressed_encoding: torch.Tensor) -> torch.Tensor:
|
|
215
209
|
if not self.do_compression:
|
|
216
210
|
return compressed_encoding
|
|
217
211
|
else:
|
|
@@ -19,10 +19,11 @@ def ycbcr_to_rgb(image: torch.Tensor) -> torch.Tensor:
|
|
|
19
19
|
cb_shifted = cb - delta
|
|
20
20
|
cr_shifted = cr - delta
|
|
21
21
|
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
22
|
+
# Exact inverse of the forward matrix: 1/0.713, 0.299/(0.713*0.587), 0.114/(0.564*0.587), 1/0.564
|
|
23
|
+
r = y + 1.40252 * cr_shifted
|
|
24
|
+
g = y - 0.71440 * cr_shifted - 0.34434 * cb_shifted
|
|
25
|
+
b = y + 1.77305 * cb_shifted
|
|
26
|
+
return torch.stack([r, g, b], -3)
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
def rgb_to_ycbcr(image) -> torch.Tensor:
|
dct_autoencoder-0.2.0/LICENSE
DELETED
|
@@ -1,21 +0,0 @@
|
|
|
1
|
-
MIT License
|
|
2
|
-
|
|
3
|
-
Copyright (c) 2024 dariush-bahrami
|
|
4
|
-
|
|
5
|
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
-
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
-
in the Software without restriction, including without limitation the rights
|
|
8
|
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
-
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
-
furnished to do so, subject to the following conditions:
|
|
11
|
-
|
|
12
|
-
The above copyright notice and this permission notice shall be included in all
|
|
13
|
-
copies or substantial portions of the Software.
|
|
14
|
-
|
|
15
|
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
-
SOFTWARE.
|
|
@@ -1,52 +0,0 @@
|
|
|
1
|
-
import matplotlib.pyplot as plt
|
|
2
|
-
import numpy as np
|
|
3
|
-
|
|
4
|
-
from .basis import DCTBasis
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def visualize_dct_basis_functions(
|
|
8
|
-
dct_constants: DCTBasis,
|
|
9
|
-
figsize: int = 8,
|
|
10
|
-
fig_facecolor: str = "#fb6a2c",
|
|
11
|
-
title_color: str = "k",
|
|
12
|
-
title_fontsize: int = 20,
|
|
13
|
-
cmap: str = "gray",
|
|
14
|
-
) -> tuple:
|
|
15
|
-
"""Visualize the DCT basis functions.
|
|
16
|
-
|
|
17
|
-
Args:
|
|
18
|
-
dct_constants (DCTBasis): The DCT basis constants.
|
|
19
|
-
figsize (int, optional): The figure size. Defaults to 8.
|
|
20
|
-
fig_facecolor (str, optional): The figure facecolor. Defaults to "#fb6a2c".
|
|
21
|
-
title_color (str, optional): The title color. Defaults to "k".
|
|
22
|
-
title_fontsize (int, optional): The title fontsize. Defaults to 20.
|
|
23
|
-
cmap (str, optional): The colormap. Defaults to "gray".
|
|
24
|
-
|
|
25
|
-
Returns:
|
|
26
|
-
tuple: The figure and axis.
|
|
27
|
-
"""
|
|
28
|
-
block_size = dct_constants.block_size
|
|
29
|
-
basis_functions = dct_constants.basis_functions
|
|
30
|
-
basis_functions_image = np.zeros((block_size * block_size, block_size * block_size))
|
|
31
|
-
for v in range(block_size):
|
|
32
|
-
for u in range(block_size):
|
|
33
|
-
basis_functions_image[
|
|
34
|
-
v * block_size : (v + 1) * block_size,
|
|
35
|
-
u * block_size : (u + 1) * block_size,
|
|
36
|
-
] = basis_functions[v, u]
|
|
37
|
-
plt.figure(figsize=(figsize, figsize), facecolor=fig_facecolor)
|
|
38
|
-
plt.title(
|
|
39
|
-
f"DCT Basis functions (block size: {block_size}x{block_size})",
|
|
40
|
-
color=title_color,
|
|
41
|
-
fontsize=title_fontsize,
|
|
42
|
-
fontweight="bold",
|
|
43
|
-
)
|
|
44
|
-
plt.imshow(basis_functions_image, cmap=cmap)
|
|
45
|
-
plt.axis("off")
|
|
46
|
-
for i in range(block_size):
|
|
47
|
-
plt.axhline(i * block_size - 0.5, color=fig_facecolor)
|
|
48
|
-
plt.axvline(i * block_size - 0.5, color=fig_facecolor)
|
|
49
|
-
plt.tight_layout()
|
|
50
|
-
fig = plt.gcf()
|
|
51
|
-
ax = plt.gca()
|
|
52
|
-
return fig, ax
|
|
@@ -1,16 +0,0 @@
|
|
|
1
|
-
[tool.poetry]
|
|
2
|
-
name = "dct-autoencoder"
|
|
3
|
-
version = "0.2.0"
|
|
4
|
-
description = ""
|
|
5
|
-
authors = ["Dariush Bahrami <dariushbahrami1993@gmail.com>"]
|
|
6
|
-
readme = "README.md"
|
|
7
|
-
|
|
8
|
-
[tool.poetry.dependencies]
|
|
9
|
-
python = "^3.10"
|
|
10
|
-
numpy = "^2.1.1"
|
|
11
|
-
matplotlib = {version = "^3.9.2", optional = true}
|
|
12
|
-
torch = "^2.4.1"
|
|
13
|
-
|
|
14
|
-
[build-system]
|
|
15
|
-
requires = ["poetry-core"]
|
|
16
|
-
build-backend = "poetry.core.masonry.api"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|