cache-dit 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic; see the package's registry page for more details.

@@ -0,0 +1,353 @@
1
+ import torch
2
+ import torchvision
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from cache_dit.logger import init_logger
6
+
7
+ logger = init_logger(__name__)
8
+
9
+
10
+ # Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py
11
+ try:
12
+ from torchvision.models.utils import load_state_dict_from_url
13
+ except ImportError:
14
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
15
+
16
+ # Inception weights ported to Pytorch from
17
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
18
+ FID_WEIGHTS_URL = "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" # noqa: E501
19
+
20
+
class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps."""

    # Index of the default block to return; corresponds to the output of the
    # final average pooling layer.
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to the index of the block producing it.
    BLOCK_INDEX_BY_DIM = {
        64: 0,  # First max pooling features
        192: 1,  # Second max pooling features
        768: 2,  # Pre-aux classifier features
        2048: 3,  # Final average pooling features
    }

    def __init__(
        self,
        output_blocks=(DEFAULT_BLOCK_INDEX,),
        resize_input=True,
        normalize_input=True,
        requires_grad=False,
        use_fid_inception=True,
    ):
        """Build pretrained InceptionV3.

        Parameters
        ----------
        output_blocks : iterable of int
            Indices of blocks to return features of. Possible values are:
                - 0: corresponds to output of first max pooling
                - 1: corresponds to output of second max pooling
                - 2: corresponds to output which is fed to aux classifier
                - 3: corresponds to output of final average pooling
        resize_input : bool
            If true, bilinearly resizes input to width and height 299 before
            feeding input to model. As the network without fully connected
            layers is fully convolutional, it should be able to handle inputs
            of arbitrary size, so resizing might not be strictly needed.
        normalize_input : bool
            If true, scales the input from range (0, 1) to the range the
            pretrained Inception network expects, namely (-1, 1).
        requires_grad : bool
            If true, parameters of the model require gradients. Possibly useful
            for finetuning the network.
        use_fid_inception : bool
            If true, uses the pretrained Inception model used in Tensorflow's
            FID implementation. If false, uses the pretrained Inception model
            available in torchvision. The FID Inception model has different
            weights and a slightly different structure from torchvision's
            Inception model. If you want to compute FID scores, you are
            strongly advised to set this parameter to true to get comparable
            results.

        Raises
        ------
        ValueError
            If ``output_blocks`` is empty or contains an index other than
            0, 1, 2 or 3. The check runs before any network is built, so no
            weights are downloaded for an invalid request.
        """
        super().__init__()

        # Materialize once: a generator would otherwise be exhausted by
        # ``sorted`` before ``max`` could see it.
        output_blocks = list(output_blocks)

        # Explicit checks instead of ``assert`` (stripped under ``python -O``),
        # so an empty selection or an out-of-range / negative index fails
        # loudly rather than silently returning fewer feature maps.
        valid_indices = set(self.BLOCK_INDEX_BY_DIM.values())
        if not output_blocks:
            raise ValueError("output_blocks must contain at least one block index")
        invalid = [idx for idx in output_blocks if idx not in valid_indices]
        if invalid:
            raise ValueError(
                "Invalid output block index(es) {}; possible values are "
                "0, 1, 2 and 3".format(invalid)
            )

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        self.blocks = nn.ModuleList()

        if use_fid_inception:
            inception = fid_inception_v3()
        else:
            inception = _inception_v3(weights="DEFAULT")

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2),
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps.

        Parameters
        ----------
        inp : torch.Tensor
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1).

        Returns
        -------
        list of torch.Tensor
            Feature maps of the selected output blocks, sorted ascending by
            block index.
        """
        outp = []
        x = inp

        if self.resize_input:
            x = F.interpolate(
                x, size=(299, 299), mode="bilinear", align_corners=False
            )

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            # Only blocks up to the last requested one were built, but stop
            # explicitly to make the intent clear.
            if idx == self.last_needed_block:
                break

        return outp
171
+
172
+
def _torchvision_version(version_string=None):
    """Return torchvision's ``(major, minor)`` version as a tuple of ints.

    Only the leading ``<major>.<minor>`` digits are parsed, so local and
    pre-release suffixes (e.g. ``"0.15.2+cu118"`` or ``"0.6a0"``) do not
    break the parse. If no such prefix exists, ``(0,)`` is returned; it
    compares lower than every real version (a conservative fallback).

    Parameters
    ----------
    version_string : str, optional
        Version string to parse; defaults to ``torchvision.__version__``.
    """
    import re

    if version_string is None:
        version_string = torchvision.__version__
    match = re.match(r"(\d+)\.(\d+)", version_string)
    if match is None:
        # Just a caution against weird version strings.
        return (0,)
    return tuple(int(part) for part in match.groups())


def _inception_v3(*args, **kwargs):
    """Wrap ``torchvision.models.inception_v3`` across torchvision versions.

    - On torchvision >= 0.6, passes ``init_weights=False`` to skip the
      default weight initialization
      (see https://github.com/mseitzer/pytorch-fid/issues/28).
    - On torchvision < 0.13, translates the ``weights`` keyword into the
      legacy ``pretrained`` flag (``"DEFAULT"`` -> True, ``None`` -> False).

    All other positional and keyword arguments are forwarded unchanged.

    Raises
    ------
    ValueError
        If ``weights`` has any other value on torchvision < 0.13.
    """
    version = _torchvision_version()

    # Skips default weight initialization if supported by torchvision
    # version. See https://github.com/mseitzer/pytorch-fid/issues/28.
    if version >= (0, 6):
        kwargs["init_weights"] = False

    # Backwards compatibility: `weights` argument was handled by `pretrained`
    # argument prior to version 0.13.
    if version < (0, 13) and "weights" in kwargs:
        weights = kwargs.pop("weights")
        if weights == "DEFAULT":
            kwargs["pretrained"] = True
        elif weights is None:
            kwargs["pretrained"] = False
        else:
            raise ValueError(
                "weights=={} not supported in torchvision {}".format(
                    weights, torchvision.__version__
                )
            )

    return torchvision.models.inception_v3(*args, **kwargs)
202
+
203
+
def fid_inception_v3():
    """Return torchvision's InceptionV3 patched to match the TF FID model.

    The torchvision architecture is first built without pretrained weights
    (1008 output classes, no auxiliary logits). The mixed blocks whose
    behavior differs in the TensorFlow FID graph are then replaced with the
    ``FIDInception*`` variants. Finally the ported FID weights are loaded
    from ``FID_WEIGHTS_URL``.
    """
    model = _inception_v3(num_classes=1008, aux_logits=False, weights=None)

    # Replacement modules keyed by the attribute they overwrite; built in
    # the same order as the attributes are patched.
    fid_blocks = {
        "Mixed_5b": FIDInceptionA(192, pool_features=32),
        "Mixed_5c": FIDInceptionA(256, pool_features=64),
        "Mixed_5d": FIDInceptionA(288, pool_features=64),
        "Mixed_6b": FIDInceptionC(768, channels_7x7=128),
        "Mixed_6c": FIDInceptionC(768, channels_7x7=160),
        "Mixed_6d": FIDInceptionC(768, channels_7x7=160),
        "Mixed_6e": FIDInceptionC(768, channels_7x7=192),
        "Mixed_7b": FIDInceptionE_1(1280),
        "Mixed_7c": FIDInceptionE_2(2048),
    }
    for attr_name, replacement in fid_blocks.items():
        setattr(model, attr_name, replacement)

    weights = load_state_dict_from_url(FID_WEIGHTS_URL, progress=False)
    model.load_state_dict(weights)
    return model
227
+
228
+
class FIDInceptionA(torchvision.models.inception.InceptionA):
    """InceptionA block patched for FID computation.

    Uses the same layers as torchvision's ``InceptionA``. The overridden
    ``forward`` computes the pooling branch with
    ``count_include_pad=False``, so that padded zeros are excluded from
    the average, matching TensorFlow's average pool.
    """

    def __init__(self, in_channels, pool_features):
        super().__init__(in_channels, pool_features)

    def forward(self, x):
        # 1x1 branch.
        out_1x1 = self.branch1x1(x)

        # 5x5 branch: two convolutions applied in sequence.
        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))

        # Double-3x3 branch: three convolutions applied in sequence.
        out_3x3dbl = x
        for conv in (self.branch3x3dbl_1, self.branch3x3dbl_2, self.branch3x3dbl_3):
            out_3x3dbl = conv(out_3x3dbl)

        # Pool branch. TensorFlow's average pool ignores padded zeros, which
        # PyTorch only does with count_include_pad=False (default is True).
        pooled = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        out_pool = self.branch_pool(pooled)

        # Concatenate the four branches along dim 1 (channels for NCHW input).
        return torch.cat((out_1x1, out_5x5, out_3x3dbl, out_pool), dim=1)
254
+
255
+
class FIDInceptionC(torchvision.models.inception.InceptionC):
    """InceptionC block patched for FID computation.

    Uses the same layers as torchvision's ``InceptionC``. The overridden
    ``forward`` computes the pooling branch with
    ``count_include_pad=False``, so that padded zeros are excluded from
    the average, matching TensorFlow's average pool.
    """

    def __init__(self, in_channels, channels_7x7):
        super().__init__(in_channels, channels_7x7)

    def forward(self, x):
        # 1x1 branch.
        single = self.branch1x1(x)

        # 7x7 branch: three convolutions applied in sequence.
        seven = x
        for layer in (self.branch7x7_1, self.branch7x7_2, self.branch7x7_3):
            seven = layer(seven)

        # Double-7x7 branch: five convolutions applied in sequence.
        seven_dbl = x
        for layer in (
            self.branch7x7dbl_1,
            self.branch7x7dbl_2,
            self.branch7x7dbl_3,
            self.branch7x7dbl_4,
            self.branch7x7dbl_5,
        ):
            seven_dbl = layer(seven_dbl)

        # Pool branch. TensorFlow's average pool ignores padded zeros, which
        # PyTorch only does with count_include_pad=False (default is True).
        pooled = self.branch_pool(
            F.avg_pool2d(
                x, kernel_size=3, stride=1, padding=1, count_include_pad=False
            )
        )

        # Concatenate the four branches along dim 1 (channels for NCHW input).
        return torch.cat((single, seven, seven_dbl, pooled), dim=1)
284
+
285
+
class FIDInceptionE_1(torchvision.models.inception.InceptionE):
    """First InceptionE block patched for FID computation.

    Uses the same layers as torchvision's ``InceptionE``. The overridden
    ``forward`` computes the pooling branch with
    ``count_include_pad=False``, so that padded zeros are excluded from
    the average, matching TensorFlow's average pool.
    """

    def __init__(self, in_channels):
        super().__init__(in_channels)

    def forward(self, x):
        # 1x1 branch.
        out_1x1 = self.branch1x1(x)

        # 3x3 branch: one stem convolution feeding two parallel convolutions
        # whose outputs are concatenated.
        stem = self.branch3x3_1(x)
        out_3x3 = torch.cat(
            (self.branch3x3_2a(stem), self.branch3x3_2b(stem)), dim=1
        )

        # Double-3x3 branch: two stem convolutions in sequence, then two
        # parallel convolutions whose outputs are concatenated.
        dbl_stem = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = torch.cat(
            (self.branch3x3dbl_3a(dbl_stem), self.branch3x3dbl_3b(dbl_stem)),
            dim=1,
        )

        # Pool branch. TensorFlow's average pool ignores padded zeros, which
        # PyTorch only does with count_include_pad=False (default is True).
        pooled = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        out_pool = self.branch_pool(pooled)

        # Concatenate the four branches along dim 1 (channels for NCHW input).
        return torch.cat((out_1x1, out_3x3, out_3x3dbl, out_pool), dim=1)
319
+
320
+
class FIDInceptionE_2(torchvision.models.inception.InceptionE):
    """Second InceptionE block patched for FID computation.

    Uses the same layers as torchvision's ``InceptionE``. The overridden
    ``forward`` uses max pooling (not average pooling) in the pooling
    branch, reproducing the FID Inception model.
    """

    def __init__(self, in_channels):
        super().__init__(in_channels)

    def forward(self, x):
        # 1x1 branch.
        out_1x1 = self.branch1x1(x)

        # 3x3 branch: one stem convolution feeding two parallel convolutions
        # whose outputs are concatenated.
        stem = self.branch3x3_1(x)
        out_3x3 = torch.cat(
            (self.branch3x3_2a(stem), self.branch3x3_2b(stem)), dim=1
        )

        # Double-3x3 branch: two stem convolutions in sequence, then two
        # parallel convolutions whose outputs are concatenated.
        dbl_stem = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = torch.cat(
            (self.branch3x3dbl_3a(dbl_stem), self.branch3x3dbl_3b(dbl_stem)),
            dim=1,
        )

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
        # implementation, as other Inception models use average pooling here
        # (which matches the description in the paper).
        out_pool = self.branch_pool(
            F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        )

        # Concatenate the four branches along dim 1 (channels for NCHW input).
        return torch.cat((out_1x1, out_3x3, out_3x3dbl, out_pool), dim=1)