dct-autoencoder 0.1.2__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: dct-autoencoder
3
- Version: 0.1.2
3
+ Version: 0.2.0
4
4
  Summary:
5
5
  Author: Dariush Bahrami
6
6
  Author-email: dariushbahrami1993@gmail.com
@@ -0,0 +1,257 @@
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from .basis import get_dct_basis
9
+ from .utils import rgb_to_ycbcr, ycbcr_to_rgb
10
+
11
+
12
class DCTAutoencoder(nn.Module):
    """Block-wise DCT autoencoder with optional frequency-mask compression.

    Images are converted to YCbCr, rescaled to [-1, 1] and projected onto the
    DCT basis of ``block_size x block_size`` blocks with a strided
    convolution. The basis kernels are ordered by ascending spatial-frequency
    magnitude, so the channel index of an encoding grows with frequency.

    Args:
        block_size (int, optional): Side length of the square DCT block.
            Defaults to 8.
        luminance_compression_ratio (float, optional): Fraction of the block
            side length used to build the reduced DCT basis whose frequencies
            select the luminance coefficients that are kept. Defaults to 1/2.
        chrominance_compression_ratio (float, optional): Same as above, for
            both chrominance planes. Defaults to 1/4.

    Compression is disabled only when both ratios are exactly 1.
    """

    def __init__(
        self,
        block_size: int = 8,
        luminance_compression_ratio: float = 1 / 2,
        chrominance_compression_ratio: float = 1 / 4,
    ) -> None:
        super().__init__()
        dct_basis = get_dct_basis(block_size)

        # Flatten the basis description to one entry per basis function.
        kernels = dct_basis.basis_functions.reshape(-1, block_size, block_size)
        spatial_frequencies_magnitude = (
            dct_basis.spatial_frequencies_magnitude.reshape(-1)
        )
        spatial_frequencies_components = (
            dct_basis.spatial_frequencies_components.reshape(-1, 2)
        )
        multiplication_factor_matrix = dct_basis.multiplication_factor_matrix.reshape(
            -1
        )

        # Order everything by ascending frequency magnitude so that all
        # per-coefficient arrays share the same channel ordering.
        sort_indices = np.argsort(spatial_frequencies_magnitude)
        kernels = kernels[sort_indices]
        spatial_frequencies_magnitude = spatial_frequencies_magnitude[sort_indices]
        spatial_frequencies_components = spatial_frequencies_components[sort_indices]
        multiplication_factor_matrix = multiplication_factor_matrix[sort_indices]

        # conv2d weights are (out_channels, in_channels, kH, kW); the factor
        # matrix is shaped to broadcast over (batch, channels, H, W).
        kernels = kernels[:, np.newaxis, :, :]
        multiplication_factor_matrix = multiplication_factor_matrix[
            np.newaxis, :, np.newaxis, np.newaxis
        ]

        self.register_buffer("kernels", torch.from_numpy(kernels))
        self.register_buffer(
            "spatial_frequencies_magnitude",
            torch.from_numpy(spatial_frequencies_magnitude),
        )
        self.register_buffer(
            "spatial_frequencies_components",
            torch.from_numpy(spatial_frequencies_components),
        )
        self.register_buffer("block_size", torch.tensor(block_size))
        self.register_buffer(
            "multiplication_factor_scalar",
            torch.tensor(dct_basis.multiplication_factor_scalar),
        )
        self.register_buffer(
            "multiplication_factor_matrix",
            torch.from_numpy(multiplication_factor_matrix),
        )

        self.embedding_dimension = (block_size**2) * 3

        # compressor initialization
        if luminance_compression_ratio == 1 and chrominance_compression_ratio == 1:
            device = self.spatial_frequencies_components.device
            self.do_compression = False
            self.compression_luminance_mask = torch.ones(
                block_size**2, dtype=torch.bool, device=device
            )
            self.compression_chrominance_mask = torch.ones(
                block_size**2, dtype=torch.bool, device=device
            )
            self.compression_luminance_passband = block_size**2
            self.compression_chrominance_passband = block_size**2
        else:
            luminance_mask = self._nearest_frequency_mask(
                block_size, luminance_compression_ratio
            )
            chrominance_mask = self._nearest_frequency_mask(
                block_size, chrominance_compression_ratio
            )
            self.do_compression = True
            self.compression_luminance_mask = luminance_mask
            self.compression_chrominance_mask = chrominance_mask
            # Passbands are the number of kept coefficients (0-dim tensors).
            self.compression_luminance_passband = luminance_mask.sum()
            self.compression_chrominance_passband = chrominance_mask.sum()

    def _nearest_frequency_mask(
        self, block_size: int, compression_ratio: float
    ) -> torch.Tensor:
        """Return the boolean mask of coefficients kept for one ratio.

        A DCT basis is built for a block of side
        ``ceil(block_size * compression_ratio)``. For each of its frequency
        pairs, the closest (Euclidean) frequency pair of the full basis is
        looked up; the mask is True at every full-basis index that was
        selected at least once.
        """
        original_frequencies = self.spatial_frequencies_components.to(
            dtype=torch.float32
        )
        reduced_block_size = math.ceil(block_size * compression_ratio)
        reduced_frequencies = get_dct_basis(
            reduced_block_size
        ).spatial_frequencies_components.reshape(-1, 2)
        reduced_frequencies = torch.from_numpy(reduced_frequencies).to(
            device=original_frequencies.device, dtype=torch.float32
        )
        nearest = torch.cdist(original_frequencies, reduced_frequencies, p=2).argmin(
            dim=0
        )
        indices = torch.arange(block_size**2, device=original_frequencies.device)
        return torch.isin(indices, nearest)

    def encode(self, rgb_images_batch: torch.Tensor) -> torch.Tensor:
        """Encodes the input RGB images.

        Args:
            rgb_images_batch (torch.Tensor): The input RGB images with shape
                (batch, 3, height, width). Height and width must be multiples
                of the block size. Image values should be in the range [0, 1].

        Returns:
            torch.Tensor: The encoded images. The channel axis holds the
                luminance coefficients followed by the two chrominance
                planes, ``block_size**2`` channels each.

        Raises:
            ValueError: If the input is not 4-dimensional, does not have three
                channels, or its spatial size is not divisible by the block
                size.
        """
        # check input
        if rgb_images_batch.dim() != 4:
            raise ValueError("Input images must have shape (batch, 3, height, width)")
        _, channels, height, width = rgb_images_batch.shape
        if channels != 3:
            raise ValueError("Input images must be RGB")
        block = int(self.block_size)
        if height % block != 0 or width % block != 0:
            raise ValueError("Image dimensions must be divisible by the block size")

        # convert to YCbCr and normalize to -1, 1
        ycbcr_tsr = 2 * rgb_to_ycbcr(rgb_images_batch) - 1

        # DCT encode: one non-overlapping block per output position.
        scale = self.multiplication_factor_scalar * self.multiplication_factor_matrix
        planes = [
            scale * F.conv2d(ycbcr_tsr[:, [index], :, :], self.kernels, stride=block)
            for index in range(3)
        ]
        # scale down
        return torch.cat(planes, dim=1) / block

    def decode(self, encodings_batch: torch.Tensor) -> torch.Tensor:
        """Decodes the input encoded images.

        Args:
            encodings_batch (torch.Tensor): The input encoded images, laid out
                as produced by ``encode`` (uncompressed).

        Returns:
            torch.Tensor: The decoded RGB images.
        """
        block = int(self.block_size)
        per_plane = block**2
        # scale up (inverse of the final division in ``encode``)
        encodings_batch = encodings_batch * block

        # DCT Decode
        scalar = self.multiplication_factor_scalar
        matrix = self.multiplication_factor_matrix
        plane_encodings = (
            encodings_batch[:, :per_plane, :, :],
            encodings_batch[:, per_plane : per_plane * 2, :, :],
            encodings_batch[:, per_plane * 2 :, :, :],
        )
        planes = [
            scalar * F.conv_transpose2d(plane * matrix, self.kernels, stride=block)
            for plane in plane_encodings
        ]

        # undo the [-1, 1] normalization and convert to RGB
        ycbcr_tsr = torch.cat(planes, dim=1) / 2 + 0.5
        return ycbcr_to_rgb(ycbcr_tsr)

    def get_num_compressed_channels(self) -> int:
        """Return the channel count of a compressed encoding as a Python int.

        Without compression this is ``3 * block_size**2``; otherwise it is
        the luminance passband plus twice the chrominance passband.
        """
        if not self.do_compression:
            # ``block_size`` is a tensor buffer; convert so an int is returned.
            return int(self.block_size) ** 2 * 3
        return int(self.compression_luminance_passband) + 2 * int(
            self.compression_chrominance_passband
        )

    def compress(self, encodings: torch.Tensor) -> torch.Tensor:
        """Keep only the masked coefficients of each plane.

        Args:
            encodings (torch.Tensor): Encodings whose channel axis splits
                into three equal chunks (luminance, then two chrominance).

        Returns:
            torch.Tensor: The kept channels concatenated in the same plane
                order, or ``encodings`` itself when compression is disabled.
        """
        if not self.do_compression:
            return encodings
        luminance, chroma_a, chroma_b = encodings.chunk(3, dim=1)
        luminance_mask = self.compression_luminance_mask
        chrominance_mask = self.compression_chrominance_mask
        return torch.cat(
            [
                luminance[:, luminance_mask, :, :],
                chroma_a[:, chrominance_mask, :, :],
                chroma_b[:, chrominance_mask, :, :],
            ],
            dim=1,
        )

    def _scatter_plane(self, kept: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Place kept channels back at their masked positions, zeros elsewhere."""
        batch_size, _, height, width = kept.shape
        plane = torch.zeros(
            batch_size,
            int(self.block_size) ** 2,
            height,
            width,
            device=kept.device,
            dtype=kept.dtype,
        )
        plane[:, mask, :, :] = kept
        return plane

    def decompress(self, compressed_encoding: torch.Tensor) -> torch.Tensor:
        """Inverse layout of ``compress``; dropped coefficients become zero.

        Args:
            compressed_encoding (torch.Tensor): Output of ``compress``.

        Returns:
            torch.Tensor: Encodings with ``3 * block_size**2`` channels, or
                ``compressed_encoding`` itself when compression is disabled.
        """
        if not self.do_compression:
            return compressed_encoding
        luminance_passband = int(self.compression_luminance_passband)
        chrominance_passband = int(self.compression_chrominance_passband)
        luminance_kept, chroma_a_kept, chroma_b_kept = compressed_encoding.split(
            [luminance_passband, chrominance_passband, chrominance_passband],
            dim=1,
        )
        luminance_mask = self.compression_luminance_mask
        chrominance_mask = self.compression_chrominance_mask
        return torch.cat(
            [
                self._scatter_plane(luminance_kept, luminance_mask),
                self._scatter_plane(chroma_a_kept, chrominance_mask),
                self._scatter_plane(chroma_b_kept, chrominance_mask),
            ],
            dim=1,
        )
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "dct-autoencoder"
3
- version = "0.1.2"
3
+ version = "0.2.0"
4
4
  description = ""
5
5
  authors = ["Dariush Bahrami <dariushbahrami1993@gmail.com>"]
6
6
  readme = "README.md"
@@ -1,260 +0,0 @@
1
- import numpy as np
2
- import torch
3
- from torch import nn
4
- from torch.nn import functional as F
5
-
6
- from .basis import get_dct_basis
7
- from .utils import rgb_to_ycbcr, ycbcr_to_rgb
8
-
9
-
10
- class DCTAutoencoder(nn.Module):
11
- """DCT Autoencoder.
12
-
13
- Args:
14
- block_size (int, optional): The block size. Defaults to 8.
15
- """
16
-
17
- def __init__(self, block_size: int = 8) -> None:
18
- super().__init__()
19
- dct_basis = get_dct_basis(block_size)
20
- basis_functions = dct_basis.basis_functions
21
- kernels = basis_functions.reshape(-1, block_size, block_size)
22
- spatial_frequencies_magnitude = dct_basis.spatial_frequencies_magnitude.reshape(
23
- -1
24
- )
25
- sort_indices = np.argsort(spatial_frequencies_magnitude)
26
- kernels = kernels[sort_indices]
27
- spatial_frequencies_magnitude = spatial_frequencies_magnitude[sort_indices]
28
- kernels = kernels[:, np.newaxis, :, :]
29
- multiplication_factor_scalar = dct_basis.multiplication_factor_scalar
30
- multiplication_factor_matrix = dct_basis.multiplication_factor_matrix
31
- multiplication_factor_matrix = multiplication_factor_matrix.reshape(-1)
32
- multiplication_factor_matrix = multiplication_factor_matrix[sort_indices]
33
- multiplication_factor_matrix = multiplication_factor_matrix[
34
- np.newaxis, :, np.newaxis, np.newaxis
35
- ]
36
- self.register_buffer("kernels", torch.from_numpy(kernels))
37
- self.register_buffer(
38
- "spatial_frequencies_magnitude",
39
- torch.from_numpy(spatial_frequencies_magnitude),
40
- )
41
- self.register_buffer("block_size", torch.tensor(block_size))
42
- self.register_buffer(
43
- "multiplication_factor_scalar", torch.tensor(multiplication_factor_scalar)
44
- )
45
- self.register_buffer(
46
- "multiplication_factor_matrix",
47
- torch.from_numpy(multiplication_factor_matrix),
48
- )
49
-
50
- self.embedding_dimension = (block_size**2) * 3
51
-
52
- def encode(self, rgb_images_batch: torch.Tensor) -> torch.Tensor:
53
- """Encodes the input RGB images.
54
-
55
- Args:
56
- rgb_images_batch (torch.Tensor): The input RGB images. The images should
57
- have shape (*, 3, height, width). Image values should be in the range [0, 1].
58
-
59
- Returns:
60
- torch.Tensor: The encoded images.
61
- """
62
- # check input
63
- b, c, h, w = rgb_images_batch.shape
64
- if c != 3:
65
- raise ValueError("Input images must be RGB")
66
- if h % self.block_size != 0 or w % self.block_size != 0:
67
- raise ValueError("Image dimensions must be divisible by the block size")
68
- # convert to YCbCr
69
- ycbcr_tsr = rgb_to_ycbcr(rgb_images_batch)
70
- # normalize to -1, 1
71
- ycbcr_tsr = 2 * ycbcr_tsr - 1
72
- y = ycbcr_tsr[:, [0], :, :]
73
- cb = ycbcr_tsr[:, [1], :, :]
74
- cr = ycbcr_tsr[:, [2], :, :]
75
-
76
- # DCT encode
77
- c1 = self.multiplication_factor_scalar
78
- c2 = self.multiplication_factor_matrix
79
- y = c1 * c2 * F.conv2d(y, self.kernels, stride=self.block_size.item())
80
- cb = c1 * c2 * F.conv2d(cb, self.kernels, stride=self.block_size.item())
81
- cr = c1 * c2 * F.conv2d(cr, self.kernels, stride=self.block_size.item())
82
-
83
- return torch.cat([y, cb, cr], dim=1)
84
-
85
- def decode(self, encodings_batch: torch.Tensor) -> torch.Tensor:
86
- """Decodes the input encoded images.
87
-
88
- Args:
89
- encodings_batch (torch.Tensor): The input encoded images.
90
-
91
- Returns:
92
- torch.Tensor: The decoded images.
93
- """
94
- org_ch = self.block_size**2
95
- y = encodings_batch[:, :org_ch, :, :]
96
- cb = encodings_batch[:, org_ch : org_ch * 2, :, :]
97
- cr = encodings_batch[:, org_ch * 2 :, :, :]
98
-
99
- # DCT Decode
100
- c1 = self.multiplication_factor_scalar
101
- c2 = self.multiplication_factor_matrix
102
- y = c1 * F.conv_transpose2d(y * c2, self.kernels, stride=self.block_size.item())
103
- cb = c1 * F.conv_transpose2d(
104
- cb * c2, self.kernels, stride=self.block_size.item()
105
- )
106
- cr = c1 * F.conv_transpose2d(
107
- cr * c2, self.kernels, stride=self.block_size.item()
108
- )
109
-
110
- # convert to RGB
111
- ycbcr_tsr = torch.cat([y, cb, cr], dim=1)
112
- ycbcr_tsr = ycbcr_tsr / 2 + 0.5
113
- rgb_images_batch = ycbcr_to_rgb(ycbcr_tsr)
114
- return rgb_images_batch
115
-
116
- def get_num_compressed_channels(
117
- self,
118
- luminance_compression_ratio: float = 1 / 2,
119
- chrominance_compression_ratio: float = 1 / 4,
120
- ) -> int:
121
- """Get the number of compressed channels.
122
-
123
- Args:
124
- luminance_compression_ratio (float, optional): The luminance compression
125
- ratio. Defaults to 1/2.
126
- chrominance_compression_ratio (float, optional): The chrominance compression
127
- ratio. Defaults to 1/4.
128
-
129
- Returns:
130
- int: The number of compressed channels.
131
- """
132
- num_per_channel_encodings = self.block_size**2
133
- num_luminance_encodings = torch.round(
134
- num_per_channel_encodings * luminance_compression_ratio
135
- ).int()
136
- num_chrominance_encodings = torch.round(
137
- num_per_channel_encodings * chrominance_compression_ratio
138
- ).int()
139
- return (num_luminance_encodings + 2 * num_chrominance_encodings).item()
140
-
141
- def compress(
142
- self,
143
- encodings_batch: torch.Tensor,
144
- luminance_compression_ratio: float = 1 / 2,
145
- chrominance_compression_ratio: float = 1 / 4,
146
- ) -> torch.Tensor:
147
- """Compresses the input encodings.
148
-
149
- Args:
150
- encodings_batch (torch.Tensor): The input encodings.
151
- luminance_compression_ratio (float, optional): The luminance compression
152
- ratio. Defaults to 1/2.
153
- chrominance_compression_ratio (float, optional): The chrominance compression
154
- ratio. Defaults to 1/4.
155
-
156
- Returns:
157
- torch.Tensor: The compressed encodings.
158
- """
159
-
160
- num_per_channel_encodings = self.block_size**2
161
- num_luminance_encodings = torch.round(
162
- num_per_channel_encodings * luminance_compression_ratio
163
- ).int()
164
- num_chrominance_encodings = torch.round(
165
- num_per_channel_encodings * chrominance_compression_ratio
166
- ).int()
167
-
168
- luminance_encodings = encodings_batch[:, :num_per_channel_encodings]
169
- chrominance_blue_encodings = encodings_batch[
170
- :, num_per_channel_encodings : 2 * num_per_channel_encodings
171
- ]
172
- chrominance_red_encodings = encodings_batch[:, 2 * num_per_channel_encodings :]
173
-
174
- luminance_encodings = luminance_encodings[:, :num_luminance_encodings]
175
- chrominance_blue_encodings = chrominance_blue_encodings[
176
- :, :num_chrominance_encodings
177
- ]
178
- chrominance_red_encodings = chrominance_red_encodings[
179
- :, :num_chrominance_encodings
180
- ]
181
- compressed_dct_encodings = torch.cat(
182
- [
183
- luminance_encodings,
184
- chrominance_blue_encodings,
185
- chrominance_red_encodings,
186
- ],
187
- dim=1,
188
- )
189
- return compressed_dct_encodings
190
-
191
- def decompress(
192
- self,
193
- compressed_encodings_batch: torch.Tensor,
194
- luminance_compression_ratio: float = 1 / 2,
195
- chrominance_compression_ratio: float = 1 / 4,
196
- ) -> torch.Tensor:
197
- """Decompresses the input compressed encodings.
198
-
199
- Args:
200
- compressed_encodings_batch (torch.Tensor): The input compressed encodings.
201
- luminance_compression_ratio (float, optional): The luminance compression
202
- ratio. Defaults to 1/2.
203
- chrominance_compression_ratio (float, optional): The chrominance compression
204
- ratio. Defaults to 1/4.
205
-
206
- Returns:
207
- torch.Tensor: The decompressed encodings.
208
- """
209
-
210
- b, _, h, w = compressed_encodings_batch.shape
211
- dtype = compressed_encodings_batch.dtype
212
- device = compressed_encodings_batch.device
213
-
214
- num_per_channel_encodings = self.block_size**2
215
- num_luminance_encodings = torch.floor(
216
- num_per_channel_encodings * luminance_compression_ratio
217
- ).int()
218
- num_chrominance_encodings = torch.floor(
219
- num_per_channel_encodings * chrominance_compression_ratio
220
- ).int()
221
- compressed_luminance_encodings = compressed_encodings_batch[
222
- :, :num_luminance_encodings
223
- ]
224
- compressed_chrominance_blue_encodings = compressed_encodings_batch[
225
- :,
226
- num_luminance_encodings : num_luminance_encodings
227
- + num_chrominance_encodings,
228
- ]
229
- compressed_chrominance_red_encodings = compressed_encodings_batch[
230
- :, num_luminance_encodings + num_chrominance_encodings :
231
- ]
232
-
233
- luminance_encodings = torch.zeros(
234
- b, num_per_channel_encodings, h, w, dtype=dtype, device=device
235
- )
236
- luminance_encodings[:, :num_luminance_encodings, :, :] = (
237
- compressed_luminance_encodings
238
- )
239
- chrominance_blue_encodings = torch.zeros(
240
- b, num_per_channel_encodings, h, w, dtype=dtype, device=device
241
- )
242
- chrominance_blue_encodings[:, :num_chrominance_encodings, :, :] = (
243
- compressed_chrominance_blue_encodings
244
- )
245
- chrominance_red_encodings = torch.zeros(
246
- b, num_per_channel_encodings, h, w, dtype=dtype, device=device
247
- )
248
- chrominance_red_encodings[:, :num_chrominance_encodings, :, :] = (
249
- compressed_chrominance_red_encodings
250
- )
251
- decompressed_dct_encodings = torch.cat(
252
- [
253
- luminance_encodings,
254
- chrominance_blue_encodings,
255
- chrominance_red_encodings,
256
- ],
257
- dim=1,
258
- )
259
-
260
- return decompressed_dct_encodings
File without changes