dct-autoencoder 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dct_autoencoder/__init__.py +2 -0
- dct_autoencoder/basis.py +56 -0
- dct_autoencoder/core.py +260 -0
- dct_autoencoder/utils.py +46 -0
- dct_autoencoder/visualization.py +52 -0
- dct_autoencoder-0.1.0.dist-info/LICENSE +21 -0
- dct_autoencoder-0.1.0.dist-info/METADATA +34 -0
- dct_autoencoder-0.1.0.dist-info/RECORD +9 -0
- dct_autoencoder-0.1.0.dist-info/WHEEL +4 -0
dct_autoencoder/basis.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from typing import NamedTuple
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DCTBasis(NamedTuple):
    """Precomputed 2D DCT basis and normalisation constants for one block size.

    Instances are built by ``get_dct_basis``; the shapes noted below are the
    ones that function produces, with ``N = block_size``.
    """

    # Raw cosine-product basis images indexed [v, u, row, col]; shape (N, N, N, N).
    basis_functions: np.ndarray
    # The (v, u) frequency pair of each basis image; shape (N, N, 2).
    spatial_frequencies_components: np.ndarray
    # Euclidean norm of each (v, u) pair; shape (N, N). Used to order coefficients.
    spatial_frequencies_magnitude: np.ndarray
    # Per-basis normalisation C(u) * C(v) with C(0) = 1/sqrt(2), else 1; shape (N, N).
    multiplication_factor_matrix: np.ndarray
    # Global normalisation factor (2 / N as computed by get_dct_basis).
    multiplication_factor_scalar: float
    # Side length N of the square block.
    block_size: int
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_dct_basis(block_size: int = 8) -> DCTBasis:
    """Generate the 2D DCT-II basis and normalisation constants for a block size.

    For a block of side ``N = block_size``, the raw basis image for the
    frequency pair ``(v, u)`` evaluated at pixel ``(row, col)`` is::

        cos((2 * col + 1) * u * pi / (2 * N)) * cos((2 * row + 1) * v * pi / (2 * N))

    Multiplying it by ``multiplication_factor_scalar *
    multiplication_factor_matrix[v, u]`` (that is ``2 / N * C(u) * C(v)``, with
    ``C(0) = 1 / sqrt(2)`` and ``C(k) = 1`` otherwise) yields the orthonormal
    DCT-II basis.

    Args:
        block_size (int, optional): The block size ``N``; must be a positive
            integer. Defaults to 8.

    Returns:
        DCTBasis: The DCT basis variables:
            ``basis_functions`` (float32, shape ``(N, N, N, N)``, indexed
            ``[v, u, row, col]``), ``spatial_frequencies_components`` (int64,
            shape ``(N, N, 2)``, entry ``[v, u] == (v, u)``),
            ``spatial_frequencies_magnitude`` (shape ``(N, N)``, Euclidean norm
            of ``(v, u)``), ``multiplication_factor_matrix`` (float32, shape
            ``(N, N)``, ``C(u) * C(v)``), ``multiplication_factor_scalar``
            (``2 / N``) and ``block_size``.

    Raises:
        TypeError: If ``block_size`` is not an integer.
        ValueError: If ``block_size`` is smaller than 1.
    """
    # Validate up front: previously 0 raised ZeroDivisionError and negative or
    # float sizes failed deep inside NumPy with unrelated messages.
    if not isinstance(block_size, (int, np.integer)):
        raise TypeError(
            f"block_size must be an integer, got {type(block_size).__name__}"
        )
    if block_size < 1:
        raise ValueError(f"block_size must be at least 1, got {block_size}")

    frequencies = np.arange(block_size)
    positions = np.arange(block_size)

    # cos_table[k, n] = cos((2n + 1) * k * pi / (2N)): the 1D DCT-II cosine of
    # frequency k sampled at position n. Every 2D basis image is a product of
    # two rows of this table.
    cos_table = np.cos(
        ((2 * positions[np.newaxis, :] + 1) * frequencies[:, np.newaxis] * np.pi)
        / (2 * block_size)
    )

    # basis_functions[v, u, row, col] = cos_table[u, col] * cos_table[v, row]:
    # the horizontal (column) factor uses u, the vertical (row) factor uses v.
    basis_functions = (
        cos_table[np.newaxis, :, np.newaxis, :]
        * cos_table[:, np.newaxis, :, np.newaxis]
    ).astype(np.float32)

    # spatial_frequencies[v, u] == (v, u)
    v_grid, u_grid = np.meshgrid(frequencies, frequencies, indexing="ij")
    spatial_frequencies = np.stack([v_grid, u_grid], axis=-1).astype(np.int64)
    spatial_frequencies_magnitude = np.linalg.norm(spatial_frequencies, axis=2)

    # C(k) = 1/sqrt(2) for the DC term and 1 otherwise; the 2D factor at
    # [v, u] is C(v) * C(u).
    normalisation = np.where(frequencies == 0, 1 / np.sqrt(2), 1.0)
    multiplication_factor_matrix = np.outer(normalisation, normalisation).astype(
        np.float32
    )

    multiplication_factor_scalar = 2 / block_size

    return DCTBasis(
        basis_functions=basis_functions,
        spatial_frequencies_components=spatial_frequencies,
        spatial_frequencies_magnitude=spatial_frequencies_magnitude,
        multiplication_factor_matrix=multiplication_factor_matrix,
        multiplication_factor_scalar=multiplication_factor_scalar,
        block_size=block_size,
    )
|
dct_autoencoder/core.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
from torch import nn
|
|
4
|
+
from torch.nn import functional as F
|
|
5
|
+
|
|
6
|
+
from .basis import get_dct_basis
|
|
7
|
+
from .utils import rgb_to_ycbcr, ycbcr_to_rgb
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class DCTAutoencoder(nn.Module):
    """Blockwise 2D DCT encoder/decoder for RGB images.

    ``encode`` converts RGB images to YCbCr, maps them to [-1, 1] and projects
    every non-overlapping ``block_size x block_size`` block of each channel
    onto the orthonormal DCT-II basis. Coefficients are ordered by increasing
    spatial-frequency magnitude. ``compress`` can therefore drop high
    frequencies by keeping only a leading fraction of each channel's
    coefficients. ``decompress`` and ``decode`` invert these steps.

    Encodings have shape ``(batch, 3 * block_size**2, height // block_size,
    width // block_size)``. The Y, Cb and Cr coefficients are stacked in that
    order along dimension 1.

    Args:
        block_size (int, optional): The block size. Defaults to 8.
    """

    def __init__(self, block_size: int = 8) -> None:
        super().__init__()
        dct_basis = get_dct_basis(block_size)

        # One convolution kernel per (v, u) frequency pair.
        kernels = dct_basis.basis_functions.reshape(-1, block_size, block_size)
        spatial_frequencies_magnitude = (
            dct_basis.spatial_frequencies_magnitude.reshape(-1)
        )

        # Order coefficients from low to high spatial frequency. The sort must
        # be stable: pairs such as (0, 1) and (1, 0) tie on magnitude. NumPy's
        # default quicksort does not define the order of ties, which would make
        # the channel layout of encodings implementation-dependent.
        sort_indices = np.argsort(spatial_frequencies_magnitude, kind="stable")
        # Shape (out_channels=block_size**2, in_channels=1, N, N) for F.conv2d.
        kernels = kernels[sort_indices][:, np.newaxis, :, :]
        spatial_frequencies_magnitude = spatial_frequencies_magnitude[sort_indices]

        # C(u) * C(v) per coefficient, in kernel order, shaped (1, N*N, 1, 1)
        # to broadcast over the (batch, coefficient, h, w) conv output.
        multiplication_factor_matrix = dct_basis.multiplication_factor_matrix.reshape(
            -1
        )[sort_indices]
        multiplication_factor_matrix = multiplication_factor_matrix[
            np.newaxis, :, np.newaxis, np.newaxis
        ]

        self.register_buffer("kernels", torch.from_numpy(kernels))
        self.register_buffer(
            "spatial_frequencies_magnitude",
            torch.from_numpy(spatial_frequencies_magnitude),
        )
        self.register_buffer("block_size", torch.tensor(block_size))
        self.register_buffer(
            "multiplication_factor_scalar",
            torch.tensor(dct_basis.multiplication_factor_scalar),
        )
        self.register_buffer(
            "multiplication_factor_matrix",
            torch.from_numpy(multiplication_factor_matrix),
        )

        # Channels of an uncompressed encoding: block_size**2 per colour channel.
        self.embedding_dimension = (block_size**2) * 3

    def _block_size_int(self) -> int:
        """Return the ``block_size`` buffer as a Python int (read once)."""
        return int(self.block_size)

    def _coefficient_counts(
        self,
        luminance_compression_ratio: float,
        chrominance_compression_ratio: float,
    ) -> tuple[int, int, int]:
        """Return ``(per_channel, kept_luminance, kept_chrominance)`` counts.

        ``get_num_compressed_channels``, ``compress`` and ``decompress`` all use
        this helper so that they agree on how many coefficients are kept. The
        counts are rounded with ``torch.round``.

        Raises:
            ValueError: If either ratio is outside [0, 1] or is NaN.
        """
        num_per_channel_encodings = self.block_size**2
        kept = []
        for name, ratio in (
            ("luminance_compression_ratio", luminance_compression_ratio),
            ("chrominance_compression_ratio", chrominance_compression_ratio),
        ):
            ratio = float(ratio)
            # `not (a <= x <= b)` also rejects NaN.
            if not 0.0 <= ratio <= 1.0:
                raise ValueError(f"{name} must be in [0, 1], got {ratio}")
            kept.append(int(torch.round(num_per_channel_encodings * ratio)))
        return int(num_per_channel_encodings), kept[0], kept[1]

    def encode(self, rgb_images_batch: torch.Tensor) -> torch.Tensor:
        """Encodes the input RGB images.

        Args:
            rgb_images_batch (torch.Tensor): The input RGB images, of shape
                ``(batch, 3, height, width)`` with values in [0, 1]. ``height``
                and ``width`` must be divisible by the block size.

        Returns:
            torch.Tensor: DCT coefficients of shape ``(batch, 3 * block_size**2,
            height // block_size, width // block_size)``. The Y, Cb and Cr
            coefficients are stacked in that order along dimension 1.

        Raises:
            ValueError: If the input is not a 4D batch of 3-channel images, or
                its spatial size is not divisible by the block size.
        """
        if rgb_images_batch.ndim != 4:
            raise ValueError(
                "Input images must have shape (batch, 3, height, width), "
                f"got {tuple(rgb_images_batch.shape)}"
            )
        _, c, h, w = rgb_images_batch.shape
        if c != 3:
            raise ValueError("Input images must be RGB")
        block_size = self._block_size_int()
        if h % block_size != 0 or w % block_size != 0:
            raise ValueError("Image dimensions must be divisible by the block size")

        # YCbCr in [0, 1] -> [-1, 1]
        ycbcr_tsr = 2 * rgb_to_ycbcr(rgb_images_batch) - 1

        c1 = self.multiplication_factor_scalar
        c2 = self.multiplication_factor_matrix
        # A stride-N convolution with the kernel bank computes, for every block,
        # its inner product with each basis image. Scaling by c1 * c2 makes the
        # transform orthonormal.
        return torch.cat(
            [
                c1
                * c2
                * F.conv2d(ycbcr_tsr[:, [i], :, :], self.kernels, stride=block_size)
                for i in range(3)
            ],
            dim=1,
        )

    def decode(self, encodings_batch: torch.Tensor) -> torch.Tensor:
        """Decodes the input encoded images.

        Args:
            encodings_batch (torch.Tensor): Encodings as produced by ``encode``
                (or ``decompress``), of shape ``(batch, 3 * block_size**2,
                h_blocks, w_blocks)``.

        Returns:
            torch.Tensor: RGB images of shape ``(batch, 3, h_blocks * block_size,
            w_blocks * block_size)``, clamped to [0, 1] by ``ycbcr_to_rgb``.

        Raises:
            ValueError: If ``encodings_batch`` is not 4D with
                ``3 * block_size**2`` channels.
        """
        block_size = self._block_size_int()
        num_per_channel_encodings = block_size**2
        if (
            encodings_batch.ndim != 4
            or encodings_batch.shape[1] != 3 * num_per_channel_encodings
        ):
            raise ValueError(
                "Encodings must have shape "
                f"(batch, {3 * num_per_channel_encodings}, h, w), "
                f"got {tuple(encodings_batch.shape)}"
            )

        c1 = self.multiplication_factor_scalar
        c2 = self.multiplication_factor_matrix
        # Inverse DCT: weight each coefficient by its normalisation factor and
        # scatter the basis images back with a stride-N transposed convolution.
        channels = [
            c1
            * F.conv_transpose2d(
                encodings_batch[
                    :,
                    i * num_per_channel_encodings : (i + 1) * num_per_channel_encodings,
                    :,
                    :,
                ]
                * c2,
                self.kernels,
                stride=block_size,
            )
            for i in range(3)
        ]

        # [-1, 1] -> [0, 1], then YCbCr -> RGB
        ycbcr_tsr = torch.cat(channels, dim=1) / 2 + 0.5
        return ycbcr_to_rgb(ycbcr_tsr)

    def get_num_compressed_channels(
        self,
        luminance_compression_ratio: float = 1 / 2,
        chrominance_compression_ratio: float = 1 / 4,
    ) -> int:
        """Get the number of compressed channels.

        Args:
            luminance_compression_ratio (float, optional): Fraction of the Y
                coefficients to keep, in [0, 1]. Defaults to 1/2.
            chrominance_compression_ratio (float, optional): Fraction of the Cb
                and of the Cr coefficients to keep, in [0, 1]. Defaults to 1/4.

        Returns:
            int: The number of channels ``compress`` produces for these ratios.

        Raises:
            ValueError: If a ratio is outside [0, 1].
        """
        _, num_luminance, num_chrominance = self._coefficient_counts(
            luminance_compression_ratio, chrominance_compression_ratio
        )
        return num_luminance + 2 * num_chrominance

    def compress(
        self,
        encodings_batch: torch.Tensor,
        luminance_compression_ratio: float = 1 / 2,
        chrominance_compression_ratio: float = 1 / 4,
    ) -> torch.Tensor:
        """Compresses the input encodings.

        Keeps the leading (lowest-frequency) coefficients of each of the Y, Cb
        and Cr channel groups and drops the rest.

        Args:
            encodings_batch (torch.Tensor): Encodings from ``encode`` with
                ``3 * block_size**2`` channels along dimension 1.
            luminance_compression_ratio (float, optional): Fraction of the Y
                coefficients to keep, in [0, 1]. Defaults to 1/2.
            chrominance_compression_ratio (float, optional): Fraction of the Cb
                and of the Cr coefficients to keep, in [0, 1]. Defaults to 1/4.

        Returns:
            torch.Tensor: The compressed encodings. They have
            ``get_num_compressed_channels(...)`` channels, in Y, Cb, Cr order.

        Raises:
            ValueError: If a ratio is outside [0, 1] or the channel count is
                not ``3 * block_size**2``.
        """
        num_per_channel, num_luminance, num_chrominance = self._coefficient_counts(
            luminance_compression_ratio, chrominance_compression_ratio
        )
        if encodings_batch.shape[1] != 3 * num_per_channel:
            raise ValueError(
                f"Expected {3 * num_per_channel} encoding channels, "
                f"got {encodings_batch.shape[1]}"
            )

        luminance_encodings = encodings_batch[:, :num_luminance]
        chrominance_blue_encodings = encodings_batch[
            :, num_per_channel : num_per_channel + num_chrominance
        ]
        chrominance_red_encodings = encodings_batch[
            :, 2 * num_per_channel : 2 * num_per_channel + num_chrominance
        ]
        return torch.cat(
            [
                luminance_encodings,
                chrominance_blue_encodings,
                chrominance_red_encodings,
            ],
            dim=1,
        )

    def decompress(
        self,
        compressed_encodings_batch: torch.Tensor,
        luminance_compression_ratio: float = 1 / 2,
        chrominance_compression_ratio: float = 1 / 4,
    ) -> torch.Tensor:
        """Decompresses the input compressed encodings.

        The ratios must match the ones passed to ``compress``. Dropped
        coefficients are filled with zeros.

        Args:
            compressed_encodings_batch (torch.Tensor): Output of ``compress``,
                of shape ``(batch, channels, h, w)``.
            luminance_compression_ratio (float, optional): The luminance
                compression ratio used by ``compress``. Defaults to 1/2.
            chrominance_compression_ratio (float, optional): The chrominance
                compression ratio used by ``compress``. Defaults to 1/4.

        Returns:
            torch.Tensor: Encodings with ``3 * block_size**2`` channels, on the
            same dtype and device as the input, ready for ``decode``.

        Raises:
            ValueError: If a ratio is outside [0, 1] or the channel count does
                not match what ``compress`` produces for these ratios.
        """
        # Uses the same rounding as `compress`. The previous `torch.floor`
        # here misaligned channels whenever the two rounding modes disagreed.
        num_per_channel, num_luminance, num_chrominance = self._coefficient_counts(
            luminance_compression_ratio, chrominance_compression_ratio
        )
        b, c, h, w = compressed_encodings_batch.shape
        expected_channels = num_luminance + 2 * num_chrominance
        if c != expected_channels:
            raise ValueError(
                f"Expected {expected_channels} compressed channels for these "
                f"ratios, got {c}"
            )

        decompressed_dct_encodings = torch.zeros(
            b,
            3 * num_per_channel,
            h,
            w,
            dtype=compressed_encodings_batch.dtype,
            device=compressed_encodings_batch.device,
        )
        # Kept coefficients go to the front of each channel group. The dropped
        # high-frequency slots stay zero.
        decompressed_dct_encodings[:, :num_luminance] = compressed_encodings_batch[
            :, :num_luminance
        ]
        decompressed_dct_encodings[
            :, num_per_channel : num_per_channel + num_chrominance
        ] = compressed_encodings_batch[
            :, num_luminance : num_luminance + num_chrominance
        ]
        decompressed_dct_encodings[
            :, 2 * num_per_channel : 2 * num_per_channel + num_chrominance
        ] = compressed_encodings_batch[:, num_luminance + num_chrominance :]
        return decompressed_dct_encodings
|
dct_autoencoder/utils.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def ycbcr_to_rgb(image: torch.Tensor) -> torch.Tensor:
    """Converts an image from YCbCr to RGB color space.

    The coefficients (1.403, 0.714, 0.344, 1.773) are rounded inverses of the
    ones used by ``rgb_to_ycbcr``, so a round trip is approximate (errors on
    the order of 1e-3). The output is clamped to [0, 1].

    Args:
        image (torch.Tensor): The input image. The image should have shape
            (*, 3, height, width), with channels ordered Y, Cb, Cr. Image
            values should be in the range [0, 1], with chroma centred on 0.5.

    Returns:
        torch.Tensor: The output image in RGB color space, with the same shape
        as the input and values clamped to [0, 1].

    Raises:
        ValueError: If ``image`` has fewer than 3 dimensions or its channel
            dimension (``-3``) does not have size 3.
    """
    # Extra channels (e.g. alpha) would otherwise be silently dropped, and
    # missing ones would surface as an opaque IndexError.
    if image.ndim < 3 or image.shape[-3] != 3:
        raise ValueError(
            "Expected an image of shape (*, 3, height, width), "
            f"got {tuple(image.shape)}"
        )
    y = image[..., 0, :, :]
    cb = image[..., 1, :, :]
    cr = image[..., 2, :, :]

    # Chroma channels are stored offset by 0.5; recentre them on zero.
    delta: float = 0.5
    cb_shifted = cb - delta
    cr_shifted = cr - delta

    r = y + 1.403 * cr_shifted
    g = y - 0.714 * cr_shifted - 0.344 * cb_shifted
    b = y + 1.773 * cb_shifted
    return torch.stack([r, g, b], -3).clamp(0, 1)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def rgb_to_ycbcr(image: torch.Tensor) -> torch.Tensor:
    """Converts an image from RGB to YCbCr color space.

    Luma uses the weights 0.299, 0.587, 0.114. Chroma is the scaled blue/red
    difference from luma, offset by 0.5 so it lies around the middle of
    [0, 1].

    Args:
        image (torch.Tensor): The input image. The image should have shape
            (*, 3, height, width), with channels ordered R, G, B. Image values
            should be in the range [0, 1].

    Returns:
        torch.Tensor: The output image in YCbCr color space (channels Y, Cb,
        Cr), with the same shape as the input. No clamping is applied.

    Raises:
        ValueError: If ``image`` has fewer than 3 dimensions or its channel
            dimension (``-3``) does not have size 3.
    """
    # An RGBA input would otherwise be silently truncated to RGB, and a
    # grayscale input would fail with an opaque IndexError.
    if image.ndim < 3 or image.shape[-3] != 3:
        raise ValueError(
            "Expected an image of shape (*, 3, height, width), "
            f"got {tuple(image.shape)}"
        )
    r = image[..., 0, :, :]
    g = image[..., 1, :, :]
    b = image[..., 2, :, :]

    delta: float = 0.5
    y = 0.299 * r + 0.587 * g + 0.114 * b
    cb = (b - y) * 0.564 + delta
    cr = (r - y) * 0.713 + delta
    return torch.stack([y, cb, cr], -3)
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import matplotlib.pyplot as plt
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
from .basis import DCTBasis
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def visualize_dct_basis_functions(
    dct_constants: DCTBasis,
    figsize: int = 8,
    fig_facecolor: str = "#fb6a2c",
    title_color: str = "k",
    title_fontsize: int = 20,
    cmap: str = "gray",
) -> tuple:
    """Draw every DCT basis function as one tile of a single mosaic image.

    Basis function ``(v, u)`` occupies tile row ``v`` and tile column ``u`` of
    a ``block_size x block_size`` grid of tiles. Lines in ``fig_facecolor``
    are drawn along the top/left edge of each tile row and tile column.

    Args:
        dct_constants (DCTBasis): The DCT basis constants.
        figsize (int, optional): Width and height of the figure. Defaults to 8.
        fig_facecolor (str, optional): Figure face color, also used for the
            tile separator lines. Defaults to "#fb6a2c".
        title_color (str, optional): The title color. Defaults to "k".
        title_fontsize (int, optional): The title fontsize. Defaults to 20.
        cmap (str, optional): The colormap. Defaults to "gray".

    Returns:
        tuple: ``(fig, ax)``, the newly created figure and its current axes.
    """
    block_size = dct_constants.block_size
    mosaic_side = block_size * block_size

    # (v, u, row, col) -> (v, row, u, col): after this transpose a plain
    # reshape places basis function (v, u) at tile (v, u) of the mosaic.
    mosaic = (
        np.asarray(dct_constants.basis_functions, dtype=np.float64)
        .transpose(0, 2, 1, 3)
        .reshape(mosaic_side, mosaic_side)
    )

    fig = plt.figure(figsize=(figsize, figsize), facecolor=fig_facecolor)
    plt.title(
        f"DCT Basis functions (block size: {block_size}x{block_size})",
        color=title_color,
        fontsize=title_fontsize,
        fontweight="bold",
    )
    plt.imshow(mosaic, cmap=cmap)
    plt.axis("off")

    # Offset by half a pixel so the lines fall on pixel boundaries.
    for boundary in (tile * block_size - 0.5 for tile in range(block_size)):
        plt.axhline(boundary, color=fig_facecolor)
        plt.axvline(boundary, color=fig_facecolor)

    plt.tight_layout()
    return fig, plt.gca()
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024 dariush-bahrami
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: dct-autoencoder
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary:
|
|
5
|
+
Author: Dariush Bahrami
|
|
6
|
+
Author-email: dariushbahrami1993@gmail.com
|
|
7
|
+
Requires-Python: >=3.10,<4.0
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
12
|
+
Requires-Dist: matplotlib (>=3.9.2,<4.0.0)
|
|
13
|
+
Requires-Dist: numpy (>=2.1.1,<3.0.0)
|
|
14
|
+
Requires-Dist: torch (>=2.4.1,<3.0.0)
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
|
|
17
|
+
# dct-autoencoder
|
|
18
|
+
|
|
19
|
+
2D Discrete Cosine Transform in PyTorch
|
|
20
|
+
|
|
21
|
+

|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
## Usage
|
|
25
|
+
|
|
26
|
+
Refer to the [usage notebook](./usage.ipynb) for code examples.
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
## TODO
|
|
30
|
+
|
|
31
|
+
- [x] Add support for color images
|
|
32
|
+
- [x] Improve documentation
|
|
33
|
+
- [ ] Add tests
|
|
34
|
+
- [ ] Distribute on PyPI
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
dct_autoencoder/__init__.py,sha256=HPbvVdAiG_hWQ2ZNruAJnr8MOWJnMMQLjjwob-verAY,76
|
|
2
|
+
dct_autoencoder/basis.py,sha256=ynx3Plbts6snn7lcyfPUOCfFxUgcXMa3iYPWca7KCdI,2060
|
|
3
|
+
dct_autoencoder/core.py,sha256=ikVN_yeuSNN4NXbkU3kH_m32EIrkOn1z--tij-nCJIk,9965
|
|
4
|
+
dct_autoencoder/utils.py,sha256=oCFZgLAFFgV7dJoFXC3xfbtjM-fnEBpY1eSPywHIj4U,1306
|
|
5
|
+
dct_autoencoder/visualization.py,sha256=Mj9Ipz2oHTfzyq8hqM1UjoRxQ2ALAtoJeRAB7GZpXCs,1818
|
|
6
|
+
dct_autoencoder-0.1.0.dist-info/LICENSE,sha256=kVBYE8Z59CVgIBn5bMZF2ihgBM-2fyEDqU93DArFnQU,1072
|
|
7
|
+
dct_autoencoder-0.1.0.dist-info/METADATA,sha256=4XSc-J373HyOfe4FUk51qQy44ITa6uwN_Ylclv4P70g,856
|
|
8
|
+
dct_autoencoder-0.1.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
9
|
+
dct_autoencoder-0.1.0.dist-info/RECORD,,
|