dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of dsipts has been flagged as potentially problematic; the full contents of the release are shown in the diff below.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
@@ -0,0 +1,630 @@
+ # -*-Encoding: utf-8 -*-
+
+ import torch.distributed as dist
+ def average_tensor(t, is_distributed):
+     if is_distributed:
+         size = float(dist.get_world_size())
+         dist.all_reduce(t.data, op=dist.ReduceOp.SUM)
+         t.data /= size
+
+ def get_stride_for_cell_type(cell_type):
+     if cell_type.startswith('normal') or cell_type.startswith('combiner'):
+         stride = 1
+     elif cell_type.startswith('down'):
+         stride = 2
+     elif cell_type.startswith('up'):
+         stride = -1
+     else:
+         raise NotImplementedError(cell_type)
+
+     return stride
+
+
+ def get_input_size(dataset):
+     if dataset in {'mnist', 'omniglot'}:
+         return 32
+     elif dataset == 'cifar10':
+         return 32
+     elif dataset.startswith('celeba') or dataset.startswith('imagenet') or dataset.startswith('lsun'):
+         size = int(dataset.split('_')[-1])
+         return size
+     elif dataset == 'ffhq':
+         return 256
+     else:
+         raise NotImplementedError
+
+ def groups_per_scale(num_scales, num_groups_per_scale, is_adaptive, divider=2, minimum_groups=1):
+     g = []
+     n = num_groups_per_scale
+     for s in range(num_scales):
+         assert n >= 1
+         g.append(n)
+         if is_adaptive:
+             n = n // divider
+             n = max(minimum_groups, n)
+     return g
+
+
+
+ def get_arch_cells(arch_type):
+     if arch_type == 'res_elu':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['res_elu', 'res_elu']
+         arch_cells['down_enc'] = ['res_elu', 'res_elu']
+         arch_cells['normal_dec'] = ['res_elu', 'res_elu']
+         arch_cells['up_dec'] = ['res_elu', 'res_elu']
+         arch_cells['normal_pre'] = ['res_elu', 'res_elu']
+         arch_cells['down_pre'] = ['res_elu', 'res_elu']
+         arch_cells['normal_post'] = ['res_elu', 'res_elu']
+         arch_cells['up_post'] = ['res_elu', 'res_elu']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'res_bnelu':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['res_bnelu', 'res_bnelu']
+         arch_cells['down_enc'] = ['res_bnelu', 'res_bnelu']
+         arch_cells['normal_dec'] = ['res_bnelu', 'res_bnelu']
+         arch_cells['up_dec'] = ['res_bnelu', 'res_bnelu']
+         arch_cells['normal_pre'] = ['res_bnelu', 'res_bnelu']
+         arch_cells['down_pre'] = ['res_bnelu', 'res_bnelu']
+         arch_cells['normal_post'] = ['res_bnelu', 'res_bnelu']
+         arch_cells['up_post'] = ['res_bnelu', 'res_bnelu']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'res_bnswish':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['down_enc'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['normal_dec'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['up_dec'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['normal_pre'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['down_pre'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['normal_post'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['up_post'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'mbconv_sep':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['mconv_e6k5g0']
+         arch_cells['down_enc'] = ['mconv_e6k5g0']
+         arch_cells['normal_dec'] = ['mconv_e6k5g0']
+         arch_cells['up_dec'] = ['mconv_e6k5g0']
+         arch_cells['normal_pre'] = ['mconv_e3k5g0']
+         arch_cells['down_pre'] = ['mconv_e3k5g0']
+         arch_cells['normal_post'] = ['mconv_e3k5g0']
+         arch_cells['up_post'] = ['mconv_e3k5g0']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'mbconv_sep11':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['mconv_e6k11g0']
+         arch_cells['down_enc'] = ['mconv_e6k11g0']
+         arch_cells['normal_dec'] = ['mconv_e6k11g0']
+         arch_cells['up_dec'] = ['mconv_e6k11g0']
+         arch_cells['normal_pre'] = ['mconv_e3k5g0']
+         arch_cells['down_pre'] = ['mconv_e3k5g0']
+         arch_cells['normal_post'] = ['mconv_e3k5g0']
+         arch_cells['up_post'] = ['mconv_e3k5g0']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'res_mbconv':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['down_enc'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['normal_dec'] = ['mconv_e6k5g0']
+         arch_cells['up_dec'] = ['mconv_e6k5g0']
+         arch_cells['normal_pre'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['down_pre'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['normal_post'] = ['mconv_e3k5g0']
+         arch_cells['up_post'] = ['mconv_e3k5g0']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'res53_mbconv':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['res_bnswish5', 'res_bnswish']
+         arch_cells['down_enc'] = ['res_bnswish5', 'res_bnswish']
+         arch_cells['normal_dec'] = ['mconv_e6k5g0']
+         arch_cells['up_dec'] = ['mconv_e6k5g0']
+         arch_cells['normal_pre'] = ['res_bnswish5', 'res_bnswish']
+         arch_cells['down_pre'] = ['res_bnswish5', 'res_bnswish']
+         arch_cells['normal_post'] = ['mconv_e3k5g0']
+         arch_cells['up_post'] = ['mconv_e3k5g0']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'res35_mbconv':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['res_bnswish', 'res_bnswish5']
+         arch_cells['down_enc'] = ['res_bnswish', 'res_bnswish5']
+         arch_cells['normal_dec'] = ['mconv_e6k5g0']
+         arch_cells['up_dec'] = ['mconv_e6k5g0']
+         arch_cells['normal_pre'] = ['res_bnswish', 'res_bnswish5']
+         arch_cells['down_pre'] = ['res_bnswish', 'res_bnswish5']
+         arch_cells['normal_post'] = ['mconv_e3k5g0']
+         arch_cells['up_post'] = ['mconv_e3k5g0']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'res55_mbconv':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['res_bnswish5', 'res_bnswish5']
+         arch_cells['down_enc'] = ['res_bnswish5', 'res_bnswish5']
+         arch_cells['normal_dec'] = ['mconv_e6k5g0']
+         arch_cells['up_dec'] = ['mconv_e6k5g0']
+         arch_cells['normal_pre'] = ['res_bnswish5', 'res_bnswish5']
+         arch_cells['down_pre'] = ['res_bnswish5', 'res_bnswish5']
+         arch_cells['normal_post'] = ['mconv_e3k5g0']
+         arch_cells['up_post'] = ['mconv_e3k5g0']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'res_mbconv9':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['down_enc'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['normal_dec'] = ['mconv_e6k9g0']
+         arch_cells['up_dec'] = ['mconv_e6k9g0']
+         arch_cells['normal_pre'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['down_pre'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['normal_post'] = ['mconv_e3k9g0']
+         arch_cells['up_post'] = ['mconv_e3k9g0']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'mbconv_res':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['mconv_e6k5g0']
+         arch_cells['down_enc'] = ['mconv_e6k5g0']
+         arch_cells['normal_dec'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['up_dec'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['normal_pre'] = ['mconv_e3k5g0']
+         arch_cells['down_pre'] = ['mconv_e3k5g0']
+         arch_cells['normal_post'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['up_post'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'mbconv_den':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['mconv_e6k5g0']
+         arch_cells['down_enc'] = ['mconv_e6k5g0']
+         arch_cells['normal_dec'] = ['mconv_e6k5g0']
+         arch_cells['up_dec'] = ['mconv_e6k5g0']
+         arch_cells['normal_pre'] = ['mconv_e3k5g8']
+         arch_cells['down_pre'] = ['mconv_e3k5g8']
+         arch_cells['normal_post'] = ['mconv_e3k5g8']
+         arch_cells['up_post'] = ['mconv_e3k5g8']
+         arch_cells['ar_nn'] = ['']
+     else:
+         raise NotImplementedError
+     return arch_cells
+
+
+ '''
+ """
+ Authors:
+     Li,Yan (liyan22021121@gmail.com)
+ """
+ import logging
+ import os
+ import shutil
+ import time
+ from datetime import timedelta
+ import sys
+
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import torch.distributed as dist
+
+ import torch.nn.functional as F
+ from distributions import Normal, DiscMixLogistic
+
+ class AvgrageMeter(object):
+     def __init__(self):
+         self.reset()
+
+     def reset(self):
+         self.avg = 0
+         self.sum = 0
+         self.cnt = 0
+
+     def update(self, val, n=1):
+         self.sum += val * n
+         self.cnt += n
+         self.avg = self.sum / self.cnt
+
+
+ class ExpMovingAvgrageMeter(object):
+
+     def __init__(self, momentum=0.9):
+         self.momentum = momentum
+         self.reset()
+
+     def reset(self):
+         self.avg = 0
+
+     def update(self, val):
+         self.avg = (1. - self.momentum) * self.avg + self.momentum * val
+
+
+ class DummyDDP(nn.Module):
+     def __init__(self, model):
+         super(DummyDDP, self).__init__()
+         self.module = model
+
+     def forward(self, *input, **kwargs):
+         return self.module(*input, **kwargs)
+
+
+ def count_parameters_in_M(model):
+     return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6
+
+
+ def save_checkpoint(state, is_best, save):
+     filename = os.path.join(save, 'checkpoint.pth.tar')
+     torch.save(state, filename)
+     if is_best:
+         best_filename = os.path.join(save, 'model_best.pth.tar')
+         shutil.copyfile(filename, best_filename)
+
+
+ def save(model, model_path):
+     torch.save(model.state_dict(), model_path)
+
+
+ def load(model, model_path):
+     model.load_state_dict(torch.load(model_path))
+
+
+ def create_exp_dir(path, scripts_to_save=None):
+     if not os.path.exists(path):
+         os.makedirs(path, exist_ok=True)
+     print('Experiment dir : {}'.format(path))
+
+     if scripts_to_save is not None:
+         if not os.path.exists(os.path.join(path, 'scripts')):
+             os.mkdir(os.path.join(path, 'scripts'))
+         for script in scripts_to_save:
+             dst_file = os.path.join(path, 'scripts', os.path.basename(script))
+             shutil.copyfile(script, dst_file)
+
+
+ class Logger(object):
+     def __init__(self, rank, save):
+         # other libraries may set logging before arriving at this line.
+         # by reloading logging, we can get rid of previous configs set by other libraries.
+         from importlib import reload
+         reload(logging)
+         self.rank = rank
+         if self.rank == 0:
+             log_format = '%(asctime)s %(message)s'
+             logging.basicConfig(stream=sys.stdout, level=logging.INFO,
+                                 format=log_format, datefmt='%m/%d %I:%M:%S %p')
+             fh = logging.FileHandler(os.path.join(save, 'log.txt'))
+             fh.setFormatter(logging.Formatter(log_format))
+             logging.getLogger().addHandler(fh)
+         self.start_time = time.time()
+
+     def info(self, string, *args):
+         if self.rank == 0:
+             elapsed_time = time.time() - self.start_time
+             elapsed_time = time.strftime(
+                 '(Elapsed: %H:%M:%S) ', time.gmtime(elapsed_time))
+             if isinstance(string, str):
+                 string = elapsed_time + string
+             else:
+                 logging.info(elapsed_time)
+             logging.info(string, *args)
+
+
+ def reduce_tensor(tensor, world_size):
+     rt = tensor.clone()
+     dist.all_reduce(rt, op=dist.ReduceOp.SUM)
+     rt /= world_size
+     return rt
+
+
+ def get_stride_for_cell_type(cell_type):
+     if cell_type.startswith('normal') or cell_type.startswith('combiner'):
+         stride = 1
+     elif cell_type.startswith('down'):
+         stride = 2
+     elif cell_type.startswith('up'):
+         stride = -1
+     else:
+         raise NotImplementedError(cell_type)
+
+     return stride
+
+
+ def get_cout(cin, stride):
+     if stride == 1:
+         cout = cin
+     elif stride == -1:
+         cout = cin // 2
+     elif stride == 2:
+         cout = 2 * cin
+     return cout
+
+
+ def kl_balancer_coeff(num_scales, groups_per_scale, fun):
+     if fun == 'equal':
+         coeff = torch.cat([torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)], dim=0).cuda()
+     elif fun == 'linear':
+         coeff = torch.cat([(2 ** i) * torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)], dim=0).cuda()
+     elif fun == 'sqrt':
+         coeff = torch.cat([np.sqrt(2 ** i) * torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)], dim=0).cuda()
+     elif fun == 'square':
+         coeff = torch.cat([np.square(2 ** i) / groups_per_scale[num_scales - i - 1] * torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)], dim=0).cuda()
+     else:
+         raise NotImplementedError
+     # convert min to 1.
+     coeff /= torch.min(coeff)
+     return coeff
+
+
+ def kl_per_group(kl_all):
+     kl_vals = torch.mean(kl_all, dim=0)
+     kl_coeff_i = torch.abs(kl_all)
+     kl_coeff_i = torch.mean(kl_coeff_i, dim=0, keepdim=True) + 0.01
+
+     return kl_coeff_i, kl_vals
+
+
+ def kl_balancer(kl_all, kl_coeff=1.0, kl_balance=False, alpha_i=None):
+     if kl_balance and kl_coeff < 1.0:
+         alpha_i = alpha_i.unsqueeze(0)
+
+         kl_all = torch.stack(kl_all, dim=1)
+         kl_coeff_i, kl_vals = kl_per_group(kl_all)
+         total_kl = torch.sum(kl_coeff_i)
+
+         kl_coeff_i = kl_coeff_i / alpha_i * total_kl
+         kl_coeff_i = kl_coeff_i / torch.mean(kl_coeff_i, dim=1, keepdim=True)
+         kl = torch.sum(kl_all * kl_coeff_i.detach(), dim=1)
+
+         # for reporting
+         kl_coeffs = kl_coeff_i.squeeze(0)
+     else:
+         kl_all = torch.stack(kl_all, dim=1)
+         kl_vals = torch.mean(kl_all, dim=0)
+         kl = torch.sum(kl_all, dim=1)
+         kl_coeffs = torch.ones(size=(len(kl_vals),))
+
+     return kl_coeff * kl, kl_coeffs, kl_vals
+
+
+ def kl_coeff(step, total_step, constant_step, min_kl_coeff):
+     return max(min((step - constant_step) / total_step, 1.0), min_kl_coeff)
+
+
+ def log_iw(decoder, x, log_q, log_p, crop=False):
+     recon = reconstruction_loss(decoder, x, crop)
+     return - recon - log_q + log_p
+
+
+ def reconstruction_loss(decoder, x, crop=False):
+
+
+     recon = decoder.log_prob(x)
+
+     if crop:
+         recon = recon[:, :, 2:30, 2:30]
+
+     if isinstance(decoder, DiscMixLogistic):
+         return - torch.sum(recon, dim=[1, 2])  # summation over RGB is done.
+     else:
+         return - torch.sum(recon, dim=[1, 2, 3])
+
+
+ def tile_image(batch_image, n):
+     assert n * n == batch_image.size(0)
+     channels, height, width = batch_image.size(1), batch_image.size(2), batch_image.size(3)
+     batch_image = batch_image.view(n, n, channels, height, width)
+     batch_image = batch_image.permute(2, 0, 3, 1, 4)  # n, height, n, width, c
+     batch_image = batch_image.contiguous().view(channels, n * height, n * width)
+     return batch_image
+
+
+ def average_gradients(params, is_distributed):
+     """ Gradient averaging. """
+     if is_distributed:
+         size = float(dist.get_world_size())
+         for param in params:
+             if param.requires_grad:
+                 # print(param)
+                 dist.all_reduce(param, op=dist.ReduceOp.SUM)
+                 param = param//size
+
+
+ def average_params(params, is_distributed):
+     """ parameter averaging. """
+     if is_distributed:
+         size = float(dist.get_world_size())
+         for param in params:
+             dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
+             param.data /= size
+
+
+
+
+ def one_hot(indices, depth, dim):
+     indices = indices.unsqueeze(dim)
+     size = list(indices.size())
+     size[dim] = depth
+     y_onehot = torch.zeros(size).cuda()
+     y_onehot.zero_()
+     y_onehot.scatter_(dim, indices, 1)
+
+     return y_onehot
+
+
+ def num_output(dataset):
+     if dataset in {'mnist', 'omniglot'}:
+         return 28 * 28
+     elif dataset == 'cifar10':
+         return 3 * 32 * 32
+     elif dataset.startswith('celeba') or dataset.startswith('imagenet') or dataset.startswith('lsun'):
+         size = int(dataset.split('_')[-1])
+         return 3 * size * size
+     elif dataset == 'ffhq':
+         return 3 * 256 * 256
+     else:
+         raise NotImplementedError
+
+
+ def get_input_size(dataset):
+     if dataset in {'mnist', 'omniglot'}:
+         return 32
+     elif dataset == 'cifar10':
+         return 32
+     elif dataset.startswith('celeba') or dataset.startswith('imagenet') or dataset.startswith('lsun'):
+         size = int(dataset.split('_')[-1])
+         return size
+     elif dataset == 'ffhq':
+         return 256
+     else:
+         raise NotImplementedError
+
+
+ def pre_process(x, num_bits):
+     if num_bits != 8:
+         x = torch.floor(x * 255 / 2 ** (8 - num_bits))
+         x /= (2 ** num_bits - 1)
+     return x
+
+
+ def get_arch_cells(arch_type):
+     if arch_type == 'res_elu':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['res_elu', 'res_elu']
+         arch_cells['down_enc'] = ['res_elu', 'res_elu']
+         arch_cells['normal_dec'] = ['res_elu', 'res_elu']
+         arch_cells['up_dec'] = ['res_elu', 'res_elu']
+         arch_cells['normal_pre'] = ['res_elu', 'res_elu']
+         arch_cells['down_pre'] = ['res_elu', 'res_elu']
+         arch_cells['normal_post'] = ['res_elu', 'res_elu']
+         arch_cells['up_post'] = ['res_elu', 'res_elu']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'res_bnelu':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['res_bnelu', 'res_bnelu']
+         arch_cells['down_enc'] = ['res_bnelu', 'res_bnelu']
+         arch_cells['normal_dec'] = ['res_bnelu', 'res_bnelu']
+         arch_cells['up_dec'] = ['res_bnelu', 'res_bnelu']
+         arch_cells['normal_pre'] = ['res_bnelu', 'res_bnelu']
+         arch_cells['down_pre'] = ['res_bnelu', 'res_bnelu']
+         arch_cells['normal_post'] = ['res_bnelu', 'res_bnelu']
+         arch_cells['up_post'] = ['res_bnelu', 'res_bnelu']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'res_bnswish':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['down_enc'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['normal_dec'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['up_dec'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['normal_pre'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['down_pre'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['normal_post'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['up_post'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'mbconv_sep':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['mconv_e6k5g0']
+         arch_cells['down_enc'] = ['mconv_e6k5g0']
+         arch_cells['normal_dec'] = ['mconv_e6k5g0']
+         arch_cells['up_dec'] = ['mconv_e6k5g0']
+         arch_cells['normal_pre'] = ['mconv_e3k5g0']
+         arch_cells['down_pre'] = ['mconv_e3k5g0']
+         arch_cells['normal_post'] = ['mconv_e3k5g0']
+         arch_cells['up_post'] = ['mconv_e3k5g0']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'mbconv_sep11':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['mconv_e6k11g0']
+         arch_cells['down_enc'] = ['mconv_e6k11g0']
+         arch_cells['normal_dec'] = ['mconv_e6k11g0']
+         arch_cells['up_dec'] = ['mconv_e6k11g0']
+         arch_cells['normal_pre'] = ['mconv_e3k5g0']
+         arch_cells['down_pre'] = ['mconv_e3k5g0']
+         arch_cells['normal_post'] = ['mconv_e3k5g0']
+         arch_cells['up_post'] = ['mconv_e3k5g0']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'res_mbconv':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['down_enc'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['normal_dec'] = ['mconv_e6k5g0']
+         arch_cells['up_dec'] = ['mconv_e6k5g0']
+         arch_cells['normal_pre'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['down_pre'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['normal_post'] = ['mconv_e3k5g0']
+         arch_cells['up_post'] = ['mconv_e3k5g0']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'res53_mbconv':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['res_bnswish5', 'res_bnswish']
+         arch_cells['down_enc'] = ['res_bnswish5', 'res_bnswish']
+         arch_cells['normal_dec'] = ['mconv_e6k5g0']
+         arch_cells['up_dec'] = ['mconv_e6k5g0']
+         arch_cells['normal_pre'] = ['res_bnswish5', 'res_bnswish']
+         arch_cells['down_pre'] = ['res_bnswish5', 'res_bnswish']
+         arch_cells['normal_post'] = ['mconv_e3k5g0']
+         arch_cells['up_post'] = ['mconv_e3k5g0']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'res35_mbconv':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['res_bnswish', 'res_bnswish5']
+         arch_cells['down_enc'] = ['res_bnswish', 'res_bnswish5']
+         arch_cells['normal_dec'] = ['mconv_e6k5g0']
+         arch_cells['up_dec'] = ['mconv_e6k5g0']
+         arch_cells['normal_pre'] = ['res_bnswish', 'res_bnswish5']
+         arch_cells['down_pre'] = ['res_bnswish', 'res_bnswish5']
+         arch_cells['normal_post'] = ['mconv_e3k5g0']
+         arch_cells['up_post'] = ['mconv_e3k5g0']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'res55_mbconv':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['res_bnswish5', 'res_bnswish5']
+         arch_cells['down_enc'] = ['res_bnswish5', 'res_bnswish5']
+         arch_cells['normal_dec'] = ['mconv_e6k5g0']
+         arch_cells['up_dec'] = ['mconv_e6k5g0']
+         arch_cells['normal_pre'] = ['res_bnswish5', 'res_bnswish5']
+         arch_cells['down_pre'] = ['res_bnswish5', 'res_bnswish5']
+         arch_cells['normal_post'] = ['mconv_e3k5g0']
+         arch_cells['up_post'] = ['mconv_e3k5g0']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'res_mbconv9':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['down_enc'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['normal_dec'] = ['mconv_e6k9g0']
+         arch_cells['up_dec'] = ['mconv_e6k9g0']
+         arch_cells['normal_pre'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['down_pre'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['normal_post'] = ['mconv_e3k9g0']
+         arch_cells['up_post'] = ['mconv_e3k9g0']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'mbconv_res':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['mconv_e6k5g0']
+         arch_cells['down_enc'] = ['mconv_e6k5g0']
+         arch_cells['normal_dec'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['up_dec'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['normal_pre'] = ['mconv_e3k5g0']
+         arch_cells['down_pre'] = ['mconv_e3k5g0']
+         arch_cells['normal_post'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['up_post'] = ['res_bnswish', 'res_bnswish']
+         arch_cells['ar_nn'] = ['']
+     elif arch_type == 'mbconv_den':
+         arch_cells = dict()
+         arch_cells['normal_enc'] = ['mconv_e6k5g0']
+         arch_cells['down_enc'] = ['mconv_e6k5g0']
+         arch_cells['normal_dec'] = ['mconv_e6k5g0']
+         arch_cells['up_dec'] = ['mconv_e6k5g0']
+         arch_cells['normal_pre'] = ['mconv_e3k5g8']
+         arch_cells['down_pre'] = ['mconv_e3k5g8']
+         arch_cells['normal_post'] = ['mconv_e3k5g8']
+         arch_cells['up_post'] = ['mconv_e3k5g8']
+         arch_cells['ar_nn'] = ['']
+     else:
+         raise NotImplementedError
+     return arch_cells
+
+
+ def groups_per_scale(num_scales, num_groups_per_scale, is_adaptive, divider=2, minimum_groups=1):
+     g = []
+     n = num_groups_per_scale
+     for s in range(num_scales):
+         assert n >= 1
+         g.append(n)
+         if is_adaptive:
+             n = n // divider
+             n = max(minimum_groups, n)
+     return g
+ '''
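For orientation, a minimal usage sketch of the module-level helpers defined above (only the functions outside the trailing triple-quoted block are actually importable). The import path is inferred from the file list (dsipts/models/d3vae/utils.py); the sketch is illustrative only and is not part of the packaged file.

# Illustrative only -- not part of the packaged file shown above.
# Assumes the module is importable as dsipts.models.d3vae.utils, per the file list.
from dsipts.models.d3vae.utils import (
    groups_per_scale,
    get_stride_for_cell_type,
    get_arch_cells,
    get_input_size,
)

# Latent groups per scale; halved at each scale (down to minimum_groups) when adaptive.
print(groups_per_scale(num_scales=3, num_groups_per_scale=8, is_adaptive=True))   # [8, 4, 2]
print(groups_per_scale(num_scales=3, num_groups_per_scale=8, is_adaptive=False))  # [8, 8, 8]

# Cell-type prefixes map to strides: 'normal'/'combiner' -> 1, 'down' -> 2, 'up' -> -1.
print(get_stride_for_cell_type('down_enc'))    # 2
print(get_stride_for_cell_type('normal_dec'))  # 1

# Named architecture presets return the cell list used at each block position.
cells = get_arch_cells('res_mbconv')
print(cells['normal_enc'])  # ['res_bnswish', 'res_bnswish']
print(cells['up_dec'])      # ['mconv_e6k5g0']

# Input resolution is derived from the dataset name.
print(get_input_size('cifar10'))    # 32
print(get_input_size('celeba_64'))  # 64

The remaining helpers (the torch-based KL balancing, logging, and distributed utilities) sit inside the trailing triple-quoted string and are therefore inert in this release.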