shancx 1.8.92__py3-none-any.whl → 1.9.33.218__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. shancx/3D/__init__.py +25 -0
  2. shancx/Algo/Class.py +11 -0
  3. shancx/Algo/CudaPrefetcher1.py +112 -0
  4. shancx/Algo/Fake_image.py +24 -0
  5. shancx/Algo/Hsml.py +391 -0
  6. shancx/Algo/L2Loss.py +10 -0
  7. shancx/Algo/MetricTracker.py +132 -0
  8. shancx/Algo/Normalize.py +66 -0
  9. shancx/Algo/OptimizerWScheduler.py +38 -0
  10. shancx/Algo/Rmageresize.py +79 -0
  11. shancx/Algo/Savemodel.py +33 -0
  12. shancx/Algo/SmoothL1_losses.py +27 -0
  13. shancx/Algo/Tqdm.py +62 -0
  14. shancx/Algo/__init__.py +121 -0
  15. shancx/Algo/checknan.py +28 -0
  16. shancx/Algo/iouJU.py +83 -0
  17. shancx/Algo/mask.py +25 -0
  18. shancx/Algo/psnr.py +9 -0
  19. shancx/Algo/ssim.py +70 -0
  20. shancx/Algo/structural_similarity.py +308 -0
  21. shancx/Algo/tool.py +704 -0
  22. shancx/Calmetrics/__init__.py +97 -0
  23. shancx/Calmetrics/calmetrics.py +14 -0
  24. shancx/Calmetrics/calmetricsmatrixLib.py +147 -0
  25. shancx/Calmetrics/rmseR2score.py +35 -0
  26. shancx/Clip/__init__.py +50 -0
  27. shancx/Cmd.py +126 -0
  28. shancx/Config_.py +26 -0
  29. shancx/Df/DataFrame.py +11 -2
  30. shancx/Df/__init__.py +17 -0
  31. shancx/Df/tool.py +0 -0
  32. shancx/Diffm/Psamples.py +18 -0
  33. shancx/Diffm/__init__.py +0 -0
  34. shancx/Diffm/test.py +207 -0
  35. shancx/Doc/__init__.py +214 -0
  36. shancx/E/__init__.py +178 -152
  37. shancx/Fillmiss/__init__.py +0 -0
  38. shancx/Fillmiss/imgidwJU.py +46 -0
  39. shancx/Fillmiss/imgidwLatLonJU.py +82 -0
  40. shancx/Gpu/__init__.py +55 -0
  41. shancx/H9/__init__.py +126 -0
  42. shancx/H9/ahi_read_hsd.py +877 -0
  43. shancx/H9/ahisearchtable.py +298 -0
  44. shancx/H9/geometry.py +2439 -0
  45. shancx/Hug/__init__.py +81 -0
  46. shancx/Inst.py +22 -0
  47. shancx/Lib.py +31 -0
  48. shancx/Mos/__init__.py +37 -0
  49. shancx/NN/__init__.py +235 -106
  50. shancx/Path1.py +161 -0
  51. shancx/Plot/GlobMap.py +276 -116
  52. shancx/Plot/__init__.py +491 -1
  53. shancx/Plot/draw_day_CR_PNG.py +4 -21
  54. shancx/Plot/exam.py +116 -0
  55. shancx/Plot/plotGlobal.py +325 -0
  56. shancx/{radar_nmc.py → Plot/radarNmc.py} +4 -34
  57. shancx/{subplots_single_china_map.py → Plot/single_china_map.py} +1 -1
  58. shancx/Point.py +46 -0
  59. shancx/QC.py +223 -0
  60. shancx/RdPzl/__init__.py +32 -0
  61. shancx/Read.py +72 -0
  62. shancx/Resize.py +79 -0
  63. shancx/SN/__init__.py +62 -123
  64. shancx/Time/GetTime.py +9 -3
  65. shancx/Time/__init__.py +66 -1
  66. shancx/Time/timeCycle.py +302 -0
  67. shancx/Time/tool.py +0 -0
  68. shancx/Train/__init__.py +74 -0
  69. shancx/Train/makelist.py +187 -0
  70. shancx/Train/multiGpu.py +27 -0
  71. shancx/Train/prepare.py +161 -0
  72. shancx/Train/renet50.py +157 -0
  73. shancx/ZR.py +12 -0
  74. shancx/__init__.py +333 -262
  75. shancx/args.py +27 -0
  76. shancx/bak.py +768 -0
  77. shancx/df2database.py +62 -2
  78. shancx/geosProj.py +80 -0
  79. shancx/info.py +38 -0
  80. shancx/netdfJU.py +231 -0
  81. shancx/sendM.py +59 -0
  82. shancx/tensBoard/__init__.py +28 -0
  83. shancx/wait.py +246 -0
  84. {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/METADATA +15 -5
  85. shancx-1.9.33.218.dist-info/RECORD +91 -0
  86. {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/WHEEL +1 -1
  87. my_timer_decorator/__init__.py +0 -10
  88. shancx/Dsalgor/__init__.py +0 -19
  89. shancx/E/DFGRRIB.py +0 -30
  90. shancx/EN/DFGRRIB.py +0 -30
  91. shancx/EN/__init__.py +0 -148
  92. shancx/FileRead.py +0 -44
  93. shancx/Gray2RGB.py +0 -86
  94. shancx/M/__init__.py +0 -137
  95. shancx/MN/__init__.py +0 -133
  96. shancx/N/__init__.py +0 -131
  97. shancx/Plot/draw_day_CR_PNGUS.py +0 -206
  98. shancx/Plot/draw_day_CR_SVG.py +0 -275
  99. shancx/Plot/draw_day_pre_PNGUS.py +0 -205
  100. shancx/Plot/glob_nation_map.py +0 -116
  101. shancx/Plot/radar_nmc.py +0 -61
  102. shancx/Plot/radar_nmc_china_map_compare1.py +0 -50
  103. shancx/Plot/radar_nmc_china_map_f.py +0 -121
  104. shancx/Plot/radar_nmc_us_map_f.py +0 -128
  105. shancx/Plot/subplots_compare_devlop.py +0 -36
  106. shancx/Plot/subplots_single_china_map.py +0 -45
  107. shancx/S/__init__.py +0 -138
  108. shancx/W/__init__.py +0 -132
  109. shancx/WN/__init__.py +0 -132
  110. shancx/code.py +0 -331
  111. shancx/draw_day_CR_PNG.py +0 -200
  112. shancx/draw_day_CR_PNGUS.py +0 -206
  113. shancx/draw_day_CR_SVG.py +0 -275
  114. shancx/draw_day_pre_PNGUS.py +0 -205
  115. shancx/makenetCDFN.py +0 -42
  116. shancx/mkIMGSCX.py +0 -92
  117. shancx/netCDF.py +0 -130
  118. shancx/radar_nmc_china_map_compare1.py +0 -50
  119. shancx/radar_nmc_china_map_f.py +0 -125
  120. shancx/radar_nmc_us_map_f.py +0 -67
  121. shancx/subplots_compare_devlop.py +0 -36
  122. shancx/tool.py +0 -18
  123. shancx/user/H8mess.py +0 -317
  124. shancx/user/__init__.py +0 -137
  125. shancx/user/cinradHJN.py +0 -496
  126. shancx/user/examMeso.py +0 -293
  127. shancx/user/hjnDAAS.py +0 -26
  128. shancx/user/hjnFTP.py +0 -81
  129. shancx/user/hjnGIS.py +0 -320
  130. shancx/user/hjnGPU.py +0 -21
  131. shancx/user/hjnIDW.py +0 -68
  132. shancx/user/hjnKDTree.py +0 -75
  133. shancx/user/hjnLAPSTransform.py +0 -47
  134. shancx/user/hjnMiscellaneous.py +0 -182
  135. shancx/user/hjnProj.py +0 -162
  136. shancx/user/inotify.py +0 -41
  137. shancx/user/matplotlibMess.py +0 -87
  138. shancx/user/mkNCHJN.py +0 -623
  139. shancx/user/newTypeRadar.py +0 -492
  140. shancx/user/test.py +0 -6
  141. shancx/user/tlogP.py +0 -129
  142. shancx/util_log.py +0 -33
  143. shancx/wtx/H8mess.py +0 -315
  144. shancx/wtx/__init__.py +0 -151
  145. shancx/wtx/cinradHJN.py +0 -496
  146. shancx/wtx/colormap.py +0 -64
  147. shancx/wtx/examMeso.py +0 -298
  148. shancx/wtx/hjnDAAS.py +0 -26
  149. shancx/wtx/hjnFTP.py +0 -81
  150. shancx/wtx/hjnGIS.py +0 -330
  151. shancx/wtx/hjnGPU.py +0 -21
  152. shancx/wtx/hjnIDW.py +0 -68
  153. shancx/wtx/hjnKDTree.py +0 -75
  154. shancx/wtx/hjnLAPSTransform.py +0 -47
  155. shancx/wtx/hjnLog.py +0 -78
  156. shancx/wtx/hjnMiscellaneous.py +0 -201
  157. shancx/wtx/hjnProj.py +0 -161
  158. shancx/wtx/inotify.py +0 -41
  159. shancx/wtx/matplotlibMess.py +0 -87
  160. shancx/wtx/mkNCHJN.py +0 -613
  161. shancx/wtx/newTypeRadar.py +0 -492
  162. shancx/wtx/test.py +0 -6
  163. shancx/wtx/tlogP.py +0 -129
  164. shancx-1.8.92.dist-info/RECORD +0 -99
  165. /shancx/{Dsalgor → Algo}/dsalgor.py +0 -0
  166. {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/top_level.txt +0 -0
shancx/Algo/mask.py ADDED
@@ -0,0 +1,25 @@
1
import torch

def apply_mask_and_select(labels, outputs):
    """Filter out invalid (non-positive) entries from labels and outputs.

    :param labels: label tensor; entries <= 0 are treated as invalid
    :param outputs: model output tensor, broadcast-compatible with ``labels``
    :return: tuple ``(filtered_labels, filtered_outputs)`` of 1-D tensors
        containing only the positions where ``labels > 0``
    """
    mask = labels > 0
    filtered_labels = torch.masked_select(labels, mask)
    filtered_outputs = torch.masked_select(outputs, mask)
    return filtered_labels, filtered_outputs

def apply_mask(img, label):
    """Zero out non-positive entries so only valid values join a computation.

    :param img: image tensor; entries <= 0 are zeroed
    :param label: label tensor; entries <= 0 are zeroed
    :return: ``(img, img_mask, label, label_mask)`` where each mask is the
        boolean ``> 0`` mask and each tensor has been multiplied by its mask.
    """
    non_zero_maskimg = img > 0
    img = img * non_zero_maskimg
    non_zero_masklabel = label > 0
    label = label * non_zero_masklabel
    return img, non_zero_maskimg, label, non_zero_masklabel

if __name__ == "__main__":
    # Smoke test: only positions with strictly positive labels survive.
    # (Moved after all definitions; it previously sat between them.)
    labels = torch.tensor([1, 0, 2, -1, 3], dtype=torch.float)
    outputs = torch.tensor([1.1, 0.5, 2.1, -0.1, 2.9], dtype=torch.float)
    filtered_labels, filtered_outputs = apply_mask_and_select(labels, outputs)
    print("Filtered Labels:", filtered_labels)
    print("Filtered Outputs:", filtered_outputs)
shancx/Algo/psnr.py ADDED
@@ -0,0 +1,9 @@
1
import torch

def calculate_psnr(img1, img2):
    """Peak signal-to-noise ratio (dB) for images normalized to [0, 1].

    NOTE: returns ``inf`` when the two images are identical (zero MSE).
    The stray ``@staticmethod`` decorator was removed: on a module-level
    function it produced a non-callable object on Python < 3.10.
    """
    # PSNR = 10 * log10(MAX^2 / MSE), with MAX = 1.0 for normalized data.
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))

""" Usage:
psnr += calculate_psnr(fake_img, label).item()
total += 1
mean_psnr = psnr / total
"""
shancx/Algo/ssim.py ADDED
@@ -0,0 +1,70 @@
1
import torch
import torch.nn as nn
import torch.nn.functional as F

class LossFunctions():
    """Regression losses for image-to-image model training.

    Scalar hyper-parameters that participate in tensor math are moved to
    ``device`` once at construction so the per-call losses allocate nothing
    beyond the arithmetic itself.
    """

    def __init__(self, device='cpu'):
        """
        :param device: torch device the constant weight tensors live on
        """
        self.device = device
        # Mixing factor between the two terms of the combined losses.
        self.alpha = 0.3
        # Value range assumed by MS-SSIM (inputs normalized to [0, 1]).
        self.data_range = 1
        # w = min(1, w_fact * exp(w_exponent * target)) up-weights
        # high-intensity targets in the weighted-MSE variants.
        self.w_fact = torch.Tensor([0.01]).to(device)
        self.w_exponent = torch.Tensor([0.05]).to(device)
        self.zero = torch.Tensor([0]).to(self.device)
        self.one = torch.Tensor([1]).to(self.device)

    def _weights(self, target):
        """Exponential intensity weights, clipped at 1 (shared helper)."""
        return torch.minimum(self.one, self.w_fact * torch.exp(self.w_exponent * target))

    def mse(self, output, target):
        """ Mean Squared Error Loss """
        criterion = torch.nn.MSELoss()
        return criterion(output, target)

    def msssim(self, output, target):
        """ Multi-Scale Structural Similarity Index Loss """
        # Imported lazily: the old top-level import referenced the removed
        # ``shancx.Dsalgor`` package; the module now lives under
        # ``shancx.Algo`` -- TODO confirm against the packaged layout.
        from shancx.Algo.structural_similarity import MSSSIMLoss
        criterion = MSSSIMLoss(data_range=self.data_range)
        return criterion(output, target)

    def msssim_weighted_mse(self, output, target):
        """ MS-SSIM with Weighted Mean Squared Error Loss """
        from shancx.Algo.structural_similarity import MSSSIMLoss  # see msssim()
        weights = self._weights(target)
        criterion = MSSSIMLoss(data_range=self.data_range)
        return self.alpha * (weights * (output - target) ** 2).mean() \
            + (1. - self.alpha) * criterion(output, target)

    def mse_mae(self, output, target):
        """ Combined Mean Squared Error and Mean Absolute Error Loss """
        return (1. - self.alpha) * ((output - target) ** 2).mean() \
            + self.alpha * (abs(output - target)).mean()

    def weighted_mse(self, output, target):
        """ Weighted Mean Squared Error Loss """
        return (self._weights(target) * (output - target) ** 2).mean()

    def mae_weighted_mse(self, output, target):
        """ Weighted Mean Squared Error and Mean Absolute Error Loss """
        weights = self._weights(target)
        return self.alpha * (weights * (output - target) ** 2).mean() \
            + (1. - self.alpha) * (torch.abs(output - target)).mean()

""" Usage:
if __name__ == '__main__':
    losses = LossFunctions(device=device)
    cost = getattr(losses, "msssim_weighted_mse")
    loss = cost(yhat, y)
"""
@@ -0,0 +1,308 @@
1
+ # MS_SSIM implementation from https://github.com/VainF/pytorch-msssim
2
+ # MSSSIM loss from https://github.com/spcl/deep-weather/blob/48748598294f02acbe029dac543e2abcb5285c09/Uncertainty_Quantification/Pytorch/models.py
3
+
4
+ import warnings
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+
10
+ def _fspecial_gauss_1d(size, sigma):
11
+ r"""Create 1-D gauss kernel
12
+ Args:
13
+ size (int): the size of gauss kernel
14
+ sigma (float): sigma of normal distribution
15
+ Returns:
16
+ torch.Tensor: 1D kernel (1 x 1 x size)
17
+ """
18
+ coords = torch.arange(size).to(dtype=torch.float)
19
+ coords -= size // 2
20
+
21
+ g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
22
+ g /= g.sum()
23
+
24
+ return g.unsqueeze(0).unsqueeze(0)
25
+
26
+
27
def gaussian_filter(input, win):
    r"""Blur input with 1-D kernel
    Args:
        input (torch.Tensor): a batch of tensors to be blurred
        win (torch.Tensor): 1-D gauss kernel
    Returns:
        torch.Tensor: blurred tensors
    """
    # The kernel must be 1-D along its last axis (separable filtering).
    assert all(ws == 1 for ws in win.shape[1:-1]), win.shape

    ndim = len(input.shape)
    if ndim == 4:
        conv = F.conv2d
    elif ndim == 5:
        conv = F.conv3d
    else:
        raise NotImplementedError(input.shape)

    groups = input.shape[1]  # depthwise: one kernel per channel
    win_size = win.shape[-1]
    out = input
    # Convolve each spatial axis in turn with the transposed 1-D kernel.
    for axis, extent in enumerate(input.shape[2:]):
        if extent < win_size:
            warnings.warn(
                f"Skipping Gaussian Smoothing at dimension 2+{axis} for input: {input.shape} and win size: {win.shape[-1]}"
            )
            continue
        out = conv(
            out, weight=win.transpose(2 + axis, -1), stride=1, padding=0, groups=groups
        )

    return out
56
+
57
+
58
def _ssim(X, Y, data_range, win, size_average=True, K=(0.01, 0.03)):
    r"""Calculate ssim index for X and Y
    Args:
        X (torch.Tensor): images
        Y (torch.Tensor): images
        win (torch.Tensor): 1-D gauss kernel
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
    Returns:
        torch.Tensor: ssim results.
    """
    K1, K2 = K
    compensation = 1.0

    # Stabilizing constants scaled to the dynamic range of the inputs.
    C1 = (K1 * data_range) ** 2
    C2 = (K2 * data_range) ** 2

    win = win.to(X.device, dtype=X.dtype)

    # Local means via Gaussian smoothing.
    mu1 = gaussian_filter(X, win)
    mu2 = gaussian_filter(Y, win)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local (co)variances: E[x^2] - E[x]^2 etc.
    sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq)
    sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq)
    sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2)

    # Contrast-structure term and full SSIM map (alpha=beta=gamma=1).
    cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
    ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map

    # Average over all spatial positions, keeping (batch, channel).
    ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1)
    cs = torch.flatten(cs_map, 2).mean(-1)
    return ssim_per_channel, cs
96
+
97
+
98
def ssim(
    X,
    Y,
    data_range=255,
    size_average=True,
    win_size=11,
    win_sigma=1.5,
    win=None,
    K=(0.01, 0.03),
    nonnegative_ssim=False,
):
    r"""interface of ssim
    Args:
        X (torch.Tensor): a batch of images, (N,C,H,W)
        Y (torch.Tensor): a batch of images, (N,C,H,W)
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
        win_size: (int, optional): the size of gauss kernel
        win_sigma: (float, optional): sigma of normal distribution
        win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
        nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu
    Returns:
        torch.Tensor: ssim results
    """
    if X.shape != Y.shape:
        raise ValueError("Input images should have the same dimensions.")

    # Drop trailing singleton axes, e.g. (N,C,H,W,1) -> (N,C,H,W).
    for d in range(len(X.shape) - 1, 1, -1):
        X = X.squeeze(dim=d)
        Y = Y.squeeze(dim=d)

    if len(X.shape) not in (4, 5):
        raise ValueError(
            f"Input images should be 4-d or 5-d tensors, but got {X.shape}"
        )

    if X.type() != Y.type():
        raise ValueError("Input images should have the same dtype.")

    # A caller-supplied kernel dictates the effective window size.
    if win is not None:
        win_size = win.shape[-1]

    if win_size % 2 != 1:
        raise ValueError("Window size should be odd.")

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))

    ssim_per_channel, _cs = _ssim(
        X, Y, data_range=data_range, win=win, size_average=False, K=K
    )
    if nonnegative_ssim:
        ssim_per_channel = torch.relu(ssim_per_channel)

    # Scalar over everything, or per-image mean over channels.
    return ssim_per_channel.mean() if size_average else ssim_per_channel.mean(1)
158
+
159
+
160
def ms_ssim(
    X,
    Y,
    data_range=255,
    size_average=True,
    win_size=11,
    win_sigma=1.5,
    win=None,
    weights=None,
    K=(0.01, 0.03),
):
    r"""interface of ms-ssim
    Args:
        X (torch.Tensor): a batch of images, (N,C,[T,]H,W)
        Y (torch.Tensor): a batch of images, (N,C,[T,]H,W)
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
        win_size: (int, optional): the size of gauss kernel
        win_sigma: (float, optional): sigma of normal distribution
        win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
        weights (list, optional): weights for different levels
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
    Returns:
        torch.Tensor: ms-ssim results
    """
    if X.shape != Y.shape:
        raise ValueError("Input images should have the same dimensions.")

    # Drop trailing singleton axes so 4-d/5-d detection below is robust.
    for d in range(len(X.shape) - 1, 1, -1):
        X = X.squeeze(dim=d)
        Y = Y.squeeze(dim=d)

    if X.type() != Y.type():
        raise ValueError("Input images should have the same dtype.")

    ndim = len(X.shape)
    if ndim == 4:
        avg_pool = F.avg_pool2d
    elif ndim == 5:
        avg_pool = F.avg_pool3d
    else:
        raise ValueError(
            f"Input images should be 4-d or 5-d tensors, but got {X.shape}"
        )

    if win is not None:  # caller-supplied kernel dictates the window size
        win_size = win.shape[-1]

    if win_size % 2 != 1:
        raise ValueError("Window size should be odd.")

    # Each of the 4 downsamplings halves the image; the window must still fit.
    smaller_side = min(X.shape[-2:])
    assert smaller_side > (win_size - 1) * (
        2 ** 4
    ), "Image size should be larger than %d due to the 4 downsamplings in ms-ssim" % (
        (win_size - 1) * (2 ** 4)
    )

    if weights is None:
        # Canonical 5-level weights from the original MS-SSIM paper.
        weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
    weights = torch.FloatTensor(weights).to(X.device, dtype=X.dtype)

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))

    levels = weights.shape[0]
    mcs = []
    for level in range(levels):
        ssim_per_channel, cs = _ssim(
            X, Y, win=win, data_range=data_range, size_average=False, K=K
        )
        if level < levels - 1:
            # Keep the contrast-structure term, then halve the resolution.
            mcs.append(torch.relu(cs))
            padding = [s % 2 for s in X.shape[2:]]
            X = avg_pool(X, kernel_size=2, padding=padding)
            Y = avg_pool(Y, kernel_size=2, padding=padding)

    ssim_per_channel = torch.relu(ssim_per_channel)  # (batch, channel)
    # Stack to (level, batch, channel) and combine as a weighted product.
    mcs_and_ssim = torch.stack(mcs + [ssim_per_channel], dim=0)
    ms_ssim_val = torch.prod(mcs_and_ssim ** weights.view(-1, 1, 1), dim=0)

    return ms_ssim_val.mean() if size_average else ms_ssim_val.mean(1)
249
+
250
+
251
class MS_SSIM(torch.nn.Module):
    def __init__(
        self,
        data_range=255,
        size_average=True,
        win_size=11,
        win_sigma=1.5,
        channel=3,
        spatial_dims=2,
        weights=None,
        K=(0.01, 0.03),
    ):
        r"""class for ms-ssim
        Args:
            data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
            size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
            win_size: (int, optional): the size of gauss kernel
            win_sigma: (float, optional): sigma of normal distribution
            channel (int, optional): input channels (default: 3)
            weights (list, optional): weights for different levels
            K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
        """
        super(MS_SSIM, self).__init__()
        self.win_size = win_size
        # Pre-build the separable Gaussian window once, replicated per channel.
        kernel = _fspecial_gauss_1d(win_size, win_sigma)
        self.win = kernel.repeat([channel, 1] + [1] * spatial_dims)
        self.size_average = size_average
        self.data_range = data_range
        self.weights = weights
        self.K = K

    def forward(self, X, Y):
        # Delegate to the functional interface with the cached window.
        return ms_ssim(
            X,
            Y,
            data_range=self.data_range,
            size_average=self.size_average,
            win=self.win,
            weights=self.weights,
            K=self.K,
        )
294
+
295
+
296
class MSSSIMLoss(torch.nn.Module):
    """1 - MS_SSIM, so a perfect match yields zero loss."""

    def __init__(self, dim=2, data_range=50):
        super(MSSSIMLoss, self).__init__()
        if dim == 2:
            # Single-channel 2-D images.
            self.msssim = MS_SSIM(data_range=data_range, channel=1)
        else:
            # channel=2 for 2 pressure levels; inputs after standardization
            # are roughly 0 +- 60 through analysis (TODO check again).
            self.msssim = MS_SSIM(data_range=data_range, channel=2)

    def forward(self, x, target):
        return 1.0 - self.msssim(x, target)
308
+