dct-autoencoder 0.1.2__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dct_autoencoder-0.1.2 → dct_autoencoder-0.2.0}/PKG-INFO +1 -1
- dct_autoencoder-0.2.0/dct_autoencoder/core.py +257 -0
- {dct_autoencoder-0.1.2 → dct_autoencoder-0.2.0}/pyproject.toml +1 -1
- dct_autoencoder-0.1.2/dct_autoencoder/core.py +0 -260
- {dct_autoencoder-0.1.2 → dct_autoencoder-0.2.0}/LICENSE +0 -0
- {dct_autoencoder-0.1.2 → dct_autoencoder-0.2.0}/README.md +0 -0
- {dct_autoencoder-0.1.2 → dct_autoencoder-0.2.0}/dct_autoencoder/__init__.py +0 -0
- {dct_autoencoder-0.1.2 → dct_autoencoder-0.2.0}/dct_autoencoder/basis.py +0 -0
- {dct_autoencoder-0.1.2 → dct_autoencoder-0.2.0}/dct_autoencoder/utils.py +0 -0
- {dct_autoencoder-0.1.2 → dct_autoencoder-0.2.0}/dct_autoencoder/visualization.py +0 -0
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from torch import nn
|
|
6
|
+
from torch.nn import functional as F
|
|
7
|
+
|
|
8
|
+
from .basis import get_dct_basis
|
|
9
|
+
from .utils import rgb_to_ycbcr, ycbcr_to_rgb
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DCTAutoencoder(nn.Module):
    """DCT Autoencoder.

    RGB images are converted to YCbCr and each channel is projected, block by
    block, onto a DCT basis implemented as a strided convolution. The basis
    functions (and every per-coefficient quantity) are ordered by increasing
    spatial-frequency magnitude.

    Args:
        block_size (int, optional): The block size. Defaults to 8.
        luminance_compression_ratio (float, optional): A reduced DCT basis of
            side ``ceil(block_size * ratio)`` is built and each of its
            frequencies is matched to the nearest frequency of the full basis;
            only the matched luminance (Y) coefficients are kept by
            :meth:`compress`. Defaults to 1/2.
        chrominance_compression_ratio (float, optional): Same as
            ``luminance_compression_ratio`` but applied to the two chrominance
            (Cb, Cr) channels. Defaults to 1/4.

    If both ratios are exactly 1, :meth:`compress` and :meth:`decompress`
    return their input unchanged.
    """

    def __init__(
        self,
        block_size: int = 8,
        luminance_compression_ratio: float = 1 / 2,
        chrominance_compression_ratio: float = 1 / 4,
    ) -> None:
        super().__init__()
        dct_basis = get_dct_basis(block_size)
        kernels = dct_basis.basis_functions.reshape(-1, block_size, block_size)
        spatial_frequencies_magnitude = dct_basis.spatial_frequencies_magnitude.reshape(
            -1
        )
        spatial_frequencies_components = (
            dct_basis.spatial_frequencies_components.reshape(-1, 2)
        )

        # Sort every per-coefficient quantity with the same permutation so that
        # channel i of an encoding always refers to the same basis function.
        sort_indices = np.argsort(spatial_frequencies_magnitude)
        kernels = kernels[sort_indices][:, np.newaxis, :, :]
        spatial_frequencies_magnitude = spatial_frequencies_magnitude[sort_indices]
        spatial_frequencies_components = spatial_frequencies_components[sort_indices]

        multiplication_factor_scalar = dct_basis.multiplication_factor_scalar
        multiplication_factor_matrix = dct_basis.multiplication_factor_matrix.reshape(
            -1
        )[sort_indices]
        # Shape (1, block_size**2, 1, 1) so it broadcasts over conv2d outputs.
        multiplication_factor_matrix = multiplication_factor_matrix[
            np.newaxis, :, np.newaxis, np.newaxis
        ]

        self.register_buffer("kernels", torch.from_numpy(kernels))
        self.register_buffer(
            "spatial_frequencies_magnitude",
            torch.from_numpy(spatial_frequencies_magnitude),
        )
        self.register_buffer(
            "spatial_frequencies_components",
            torch.from_numpy(spatial_frequencies_components),
        )
        self.register_buffer("block_size", torch.tensor(block_size))
        self.register_buffer(
            "multiplication_factor_scalar", torch.tensor(multiplication_factor_scalar)
        )
        self.register_buffer(
            "multiplication_factor_matrix",
            torch.from_numpy(multiplication_factor_matrix),
        )

        self.embedding_dimension = (block_size**2) * 3

        # compressor initialization
        if luminance_compression_ratio == 1 and chrominance_compression_ratio == 1:
            self.do_compression = False
            device = self.spatial_frequencies_components.device
            luminance_mask = torch.ones(block_size**2, dtype=torch.bool, device=device)
            chrominance_mask = torch.ones(
                block_size**2, dtype=torch.bool, device=device
            )
            self.compression_luminance_passband = block_size**2
            self.compression_chrominance_passband = block_size**2
        else:
            self.do_compression = True
            luminance_mask = self._build_compression_mask(
                block_size, luminance_compression_ratio
            )
            chrominance_mask = self._build_compression_mask(
                block_size, chrominance_compression_ratio
            )
            self.compression_luminance_passband = luminance_mask.sum()
            self.compression_chrominance_passband = chrominance_mask.sum()

        # Registered as non-persistent buffers so the masks follow
        # `.to(device)` / `.cuda()` with the module while leaving the
        # state_dict keys unchanged.
        self.register_buffer(
            "compression_luminance_mask", luminance_mask, persistent=False
        )
        self.register_buffer(
            "compression_chrominance_mask", chrominance_mask, persistent=False
        )

    def _build_compression_mask(
        self, block_size: int, compression_ratio: float
    ) -> torch.Tensor:
        """Return a bool mask over the sorted coefficients that compression keeps.

        A reduced basis of side ``ceil(block_size * compression_ratio)`` is
        built; for each of its frequencies the nearest (Euclidean) frequency of
        the full sorted basis is selected.
        """
        original_frequencies = self.spatial_frequencies_components.to(
            dtype=torch.float32
        )
        reduced_block_size = math.ceil(block_size * compression_ratio)
        reduced_frequencies = get_dct_basis(
            reduced_block_size
        ).spatial_frequencies_components.reshape(-1, 2)
        reduced_frequencies = torch.from_numpy(reduced_frequencies).to(
            device=original_frequencies.device, dtype=torch.float32
        )
        indices = torch.arange(block_size**2, device=original_frequencies.device)
        # For every reduced-basis frequency, index of the closest full-basis one.
        nearest = torch.cdist(original_frequencies, reduced_frequencies, p=2).argmin(
            dim=0
        )
        return torch.isin(indices, nearest)

    def encode(self, rgb_images_batch: torch.Tensor) -> torch.Tensor:
        """Encodes the input RGB images.

        The images are converted to YCbCr, mapped with ``2 * x - 1`` and each
        channel is projected onto the sorted DCT basis block by block.

        Args:
            rgb_images_batch (torch.Tensor): The input RGB images, shape
                (batch, 3, height, width) with height and width divisible by
                ``block_size``. Image values should be in the range [0, 1].

        Returns:
            torch.Tensor: The encoded images, shape
            (batch, 3 * block_size**2, height // block_size, width // block_size),
            channels grouped Y, Cb, Cr and, within each group, ordered by
            increasing spatial-frequency magnitude. Values are divided by
            ``block_size``.

        Raises:
            ValueError: If the input does not have 3 channels or its spatial
                size is not a multiple of ``block_size``.
        """
        # check input (exactly 4 dims are required by this unpacking)
        _, c, h, w = rgb_images_batch.shape
        if c != 3:
            raise ValueError("Input images must be RGB")
        if h % self.block_size != 0 or w % self.block_size != 0:
            raise ValueError("Image dimensions must be divisible by the block size")
        # convert to YCbCr
        ycbcr_tsr = rgb_to_ycbcr(rgb_images_batch)
        # affine map x -> 2x - 1
        ycbcr_tsr = 2 * ycbcr_tsr - 1
        y = ycbcr_tsr[:, [0], :, :]
        cb = ycbcr_tsr[:, [1], :, :]
        cr = ycbcr_tsr[:, [2], :, :]

        # DCT encode: one conv2d output channel per basis function.
        stride = self.block_size.item()
        c1 = self.multiplication_factor_scalar
        c2 = self.multiplication_factor_matrix
        y = c1 * c2 * F.conv2d(y, self.kernels, stride=stride)
        cb = c1 * c2 * F.conv2d(cb, self.kernels, stride=stride)
        cr = c1 * c2 * F.conv2d(cr, self.kernels, stride=stride)
        encodings_batch = torch.cat([y, cb, cr], dim=1)
        # scale down
        encodings_batch = encodings_batch / self.block_size
        return encodings_batch

    def decode(self, encodings_batch: torch.Tensor) -> torch.Tensor:
        """Decodes the input encoded images.

        Inverse of :meth:`encode`: undoes the ``block_size`` scaling, applies
        the transposed DCT per channel, undoes the ``2 * x - 1`` map and
        converts YCbCr back to RGB.

        Args:
            encodings_batch (torch.Tensor): The input encoded images, laid out
                as produced by :meth:`encode` (use :meth:`decompress` first on
                compressed encodings).

        Returns:
            torch.Tensor: The decoded RGB images.
        """
        # scale up
        encodings_batch = encodings_batch * self.block_size
        org_ch = int(self.block_size.item()) ** 2
        y = encodings_batch[:, :org_ch, :, :]
        cb = encodings_batch[:, org_ch : org_ch * 2, :, :]
        cr = encodings_batch[:, org_ch * 2 :, :, :]

        # DCT Decode
        stride = self.block_size.item()
        c1 = self.multiplication_factor_scalar
        c2 = self.multiplication_factor_matrix
        y = c1 * F.conv_transpose2d(y * c2, self.kernels, stride=stride)
        cb = c1 * F.conv_transpose2d(cb * c2, self.kernels, stride=stride)
        cr = c1 * F.conv_transpose2d(cr * c2, self.kernels, stride=stride)

        # convert to RGB
        ycbcr_tsr = torch.cat([y, cb, cr], dim=1)
        ycbcr_tsr = ycbcr_tsr / 2 + 0.5
        rgb_images_batch = ycbcr_to_rgb(ycbcr_tsr)
        return rgb_images_batch

    def get_num_compressed_channels(self) -> int:
        """Return the channel count produced by :meth:`compress`.

        Returns:
            int: ``3 * block_size**2`` when compression is disabled, otherwise
            the luminance passband plus twice the chrominance passband.
        """
        if not self.do_compression:
            # Convert to a Python int: `self.block_size` is a tensor buffer.
            return int(self.block_size.item()) ** 2 * 3
        return int(self.compression_luminance_passband) + 2 * int(
            self.compression_chrominance_passband
        )

    def compress(self, encodings: torch.Tensor) -> torch.Tensor:
        """Keep only the coefficients selected by the compression masks.

        Args:
            encodings (torch.Tensor): Encodings laid out as produced by
                :meth:`encode` (Y, Cb, Cr groups of ``block_size**2`` channels
                along dim 1).

        Returns:
            torch.Tensor: The compressed encodings with
            :meth:`get_num_compressed_channels` channels, or ``encodings``
            itself when compression is disabled.
        """
        if not self.do_compression:
            return encodings
        luma, chroma_b, chroma_r = encodings.chunk(3, dim=1)
        luminance_mask = self.compression_luminance_mask
        chrominance_mask = self.compression_chrominance_mask
        return torch.cat(
            [
                luma[:, luminance_mask, :, :],
                chroma_b[:, chrominance_mask, :, :],
                chroma_r[:, chrominance_mask, :, :],
            ],
            dim=1,
        )

    def _scatter_coefficients(
        self, kept: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Place `kept` channels at `mask` positions of a zero-filled tensor."""
        batch_size, _, height, width = kept.shape
        full = kept.new_zeros(
            batch_size, int(self.block_size.item()) ** 2, height, width
        )
        full[:, mask, :, :] = kept
        return full

    def decompress(self, compressed_encoding: torch.Tensor) -> torch.Tensor:
        """Invert :meth:`compress`, zero-filling the dropped coefficients.

        Args:
            compressed_encoding (torch.Tensor): Output of :meth:`compress`.

        Returns:
            torch.Tensor: Encodings with ``3 * block_size**2`` channels laid
            out as produced by :meth:`encode`, or ``compressed_encoding``
            itself when compression is disabled.
        """
        if not self.do_compression:
            return compressed_encoding
        luminance_passband = int(self.compression_luminance_passband)
        chrominance_passband = int(self.compression_chrominance_passband)
        luma, chroma_b, chroma_r = compressed_encoding.split(
            [luminance_passband, chrominance_passband, chrominance_passband],
            dim=1,
        )
        return torch.cat(
            [
                self._scatter_coefficients(luma, self.compression_luminance_mask),
                self._scatter_coefficients(
                    chroma_b, self.compression_chrominance_mask
                ),
                self._scatter_coefficients(
                    chroma_r, self.compression_chrominance_mask
                ),
            ],
            dim=1,
        )
|
|
@@ -1,260 +0,0 @@
|
|
|
1
|
-
import numpy as np
|
|
2
|
-
import torch
|
|
3
|
-
from torch import nn
|
|
4
|
-
from torch.nn import functional as F
|
|
5
|
-
|
|
6
|
-
from .basis import get_dct_basis
|
|
7
|
-
from .utils import rgb_to_ycbcr, ycbcr_to_rgb
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class DCTAutoencoder(nn.Module):
    """DCT Autoencoder.

    RGB images are converted to YCbCr and each channel is projected, block by
    block, onto a DCT basis implemented as a strided convolution. The basis
    functions are ordered by increasing spatial-frequency magnitude, so
    compression keeps a leading slice of each channel group.

    Args:
        block_size (int, optional): The block size. Defaults to 8.
    """

    def __init__(self, block_size: int = 8) -> None:
        super().__init__()
        dct_basis = get_dct_basis(block_size)
        basis_functions = dct_basis.basis_functions
        kernels = basis_functions.reshape(-1, block_size, block_size)
        spatial_frequencies_magnitude = dct_basis.spatial_frequencies_magnitude.reshape(
            -1
        )
        # Order kernels and factors by increasing frequency magnitude.
        sort_indices = np.argsort(spatial_frequencies_magnitude)
        kernels = kernels[sort_indices]
        spatial_frequencies_magnitude = spatial_frequencies_magnitude[sort_indices]
        kernels = kernels[:, np.newaxis, :, :]
        multiplication_factor_scalar = dct_basis.multiplication_factor_scalar
        multiplication_factor_matrix = dct_basis.multiplication_factor_matrix
        multiplication_factor_matrix = multiplication_factor_matrix.reshape(-1)
        multiplication_factor_matrix = multiplication_factor_matrix[sort_indices]
        # Shape (1, block_size**2, 1, 1) so it broadcasts over conv2d outputs.
        multiplication_factor_matrix = multiplication_factor_matrix[
            np.newaxis, :, np.newaxis, np.newaxis
        ]
        self.register_buffer("kernels", torch.from_numpy(kernels))
        self.register_buffer(
            "spatial_frequencies_magnitude",
            torch.from_numpy(spatial_frequencies_magnitude),
        )
        self.register_buffer("block_size", torch.tensor(block_size))
        self.register_buffer(
            "multiplication_factor_scalar", torch.tensor(multiplication_factor_scalar)
        )
        self.register_buffer(
            "multiplication_factor_matrix",
            torch.from_numpy(multiplication_factor_matrix),
        )

        self.embedding_dimension = (block_size**2) * 3

    def encode(self, rgb_images_batch: torch.Tensor) -> torch.Tensor:
        """Encodes the input RGB images.

        Args:
            rgb_images_batch (torch.Tensor): The input RGB images. The images should
                have shape (batch, 3, height, width), with height and width
                divisible by ``block_size``. Image values should be in the range
                [0, 1].

        Returns:
            torch.Tensor: The encoded images, shape
            (batch, 3 * block_size**2, height // block_size, width // block_size).

        Raises:
            ValueError: If the input does not have 3 channels or its spatial size
                is not a multiple of ``block_size``.
        """
        # check input (exactly 4 dims are required by this unpacking)
        _, c, h, w = rgb_images_batch.shape
        if c != 3:
            raise ValueError("Input images must be RGB")
        if h % self.block_size != 0 or w % self.block_size != 0:
            raise ValueError("Image dimensions must be divisible by the block size")
        # convert to YCbCr
        ycbcr_tsr = rgb_to_ycbcr(rgb_images_batch)
        # affine map x -> 2x - 1
        ycbcr_tsr = 2 * ycbcr_tsr - 1
        y = ycbcr_tsr[:, [0], :, :]
        cb = ycbcr_tsr[:, [1], :, :]
        cr = ycbcr_tsr[:, [2], :, :]

        # DCT encode
        c1 = self.multiplication_factor_scalar
        c2 = self.multiplication_factor_matrix
        y = c1 * c2 * F.conv2d(y, self.kernels, stride=self.block_size.item())
        cb = c1 * c2 * F.conv2d(cb, self.kernels, stride=self.block_size.item())
        cr = c1 * c2 * F.conv2d(cr, self.kernels, stride=self.block_size.item())

        return torch.cat([y, cb, cr], dim=1)

    def decode(self, encodings_batch: torch.Tensor) -> torch.Tensor:
        """Decodes the input encoded images.

        Args:
            encodings_batch (torch.Tensor): The input encoded images, laid out as
                produced by :meth:`encode`.

        Returns:
            torch.Tensor: The decoded images.
        """
        org_ch = self.block_size**2
        y = encodings_batch[:, :org_ch, :, :]
        cb = encodings_batch[:, org_ch : org_ch * 2, :, :]
        cr = encodings_batch[:, org_ch * 2 :, :, :]

        # DCT Decode
        c1 = self.multiplication_factor_scalar
        c2 = self.multiplication_factor_matrix
        y = c1 * F.conv_transpose2d(y * c2, self.kernels, stride=self.block_size.item())
        cb = c1 * F.conv_transpose2d(
            cb * c2, self.kernels, stride=self.block_size.item()
        )
        cr = c1 * F.conv_transpose2d(
            cr * c2, self.kernels, stride=self.block_size.item()
        )

        # convert to RGB
        ycbcr_tsr = torch.cat([y, cb, cr], dim=1)
        ycbcr_tsr = ycbcr_tsr / 2 + 0.5
        rgb_images_batch = ycbcr_to_rgb(ycbcr_tsr)
        return rgb_images_batch

    def _num_kept_encodings(self, compression_ratio: float) -> int:
        """Number of leading (lowest-frequency) coefficients kept per channel.

        The single source of truth for compress / decompress /
        get_num_compressed_channels, so all three agree on the split points.
        """
        num_per_channel_encodings = self.block_size**2
        return int(torch.round(num_per_channel_encodings * compression_ratio).item())

    def get_num_compressed_channels(
        self,
        luminance_compression_ratio: float = 1 / 2,
        chrominance_compression_ratio: float = 1 / 4,
    ) -> int:
        """Get the number of compressed channels.

        Args:
            luminance_compression_ratio (float, optional): The luminance compression
                ratio. Defaults to 1/2.
            chrominance_compression_ratio (float, optional): The chrominance compression
                ratio. Defaults to 1/4.

        Returns:
            int: The number of compressed channels.
        """
        num_luminance_encodings = self._num_kept_encodings(luminance_compression_ratio)
        num_chrominance_encodings = self._num_kept_encodings(
            chrominance_compression_ratio
        )
        return num_luminance_encodings + 2 * num_chrominance_encodings

    def compress(
        self,
        encodings_batch: torch.Tensor,
        luminance_compression_ratio: float = 1 / 2,
        chrominance_compression_ratio: float = 1 / 4,
    ) -> torch.Tensor:
        """Compresses the input encodings.

        Keeps the first (lowest-frequency) coefficients of each Y/Cb/Cr group.

        Args:
            encodings_batch (torch.Tensor): The input encodings.
            luminance_compression_ratio (float, optional): The luminance compression
                ratio. Defaults to 1/2.
            chrominance_compression_ratio (float, optional): The chrominance compression
                ratio. Defaults to 1/4.

        Returns:
            torch.Tensor: The compressed encodings.
        """
        num_per_channel_encodings = int(self.block_size.item()) ** 2
        num_luminance_encodings = self._num_kept_encodings(luminance_compression_ratio)
        num_chrominance_encodings = self._num_kept_encodings(
            chrominance_compression_ratio
        )

        luminance_encodings = encodings_batch[:, :num_per_channel_encodings]
        chrominance_blue_encodings = encodings_batch[
            :, num_per_channel_encodings : 2 * num_per_channel_encodings
        ]
        chrominance_red_encodings = encodings_batch[:, 2 * num_per_channel_encodings :]

        return torch.cat(
            [
                luminance_encodings[:, :num_luminance_encodings],
                chrominance_blue_encodings[:, :num_chrominance_encodings],
                chrominance_red_encodings[:, :num_chrominance_encodings],
            ],
            dim=1,
        )

    def decompress(
        self,
        compressed_encodings_batch: torch.Tensor,
        luminance_compression_ratio: float = 1 / 2,
        chrominance_compression_ratio: float = 1 / 4,
    ) -> torch.Tensor:
        """Decompresses the input compressed encodings.

        Dropped coefficients are filled with zeros. The ratios must match the
        ones passed to :meth:`compress`.

        Args:
            compressed_encodings_batch (torch.Tensor): The input compressed encodings.
            luminance_compression_ratio (float, optional): The luminance compression
                ratio. Defaults to 1/2.
            chrominance_compression_ratio (float, optional): The chrominance compression
                ratio. Defaults to 1/4.

        Returns:
            torch.Tensor: The decompressed encodings.
        """
        b, _, h, w = compressed_encodings_batch.shape
        dtype = compressed_encodings_batch.dtype
        device = compressed_encodings_batch.device

        num_per_channel_encodings = int(self.block_size.item()) ** 2
        # Same rounding as compress(); the previous floor() disagreed with it
        # for ratios such as 0.7 and mis-sliced the channel groups.
        num_luminance_encodings = self._num_kept_encodings(luminance_compression_ratio)
        num_chrominance_encodings = self._num_kept_encodings(
            chrominance_compression_ratio
        )
        chroma_start = num_luminance_encodings
        chroma_red_start = num_luminance_encodings + num_chrominance_encodings
        compressed_groups = [
            (compressed_encodings_batch[:, :chroma_start], num_luminance_encodings),
            (
                compressed_encodings_batch[:, chroma_start:chroma_red_start],
                num_chrominance_encodings,
            ),
            (
                compressed_encodings_batch[:, chroma_red_start:],
                num_chrominance_encodings,
            ),
        ]

        decompressed_groups = []
        for compressed_group, num_kept in compressed_groups:
            full_group = torch.zeros(
                b, num_per_channel_encodings, h, w, dtype=dtype, device=device
            )
            full_group[:, :num_kept, :, :] = compressed_group
            decompressed_groups.append(full_group)

        return torch.cat(decompressed_groups, dim=1)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|