shancx 1.8.92__py3-none-any.whl → 1.9.33.218__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shancx/3D/__init__.py +25 -0
- shancx/Algo/Class.py +11 -0
- shancx/Algo/CudaPrefetcher1.py +112 -0
- shancx/Algo/Fake_image.py +24 -0
- shancx/Algo/Hsml.py +391 -0
- shancx/Algo/L2Loss.py +10 -0
- shancx/Algo/MetricTracker.py +132 -0
- shancx/Algo/Normalize.py +66 -0
- shancx/Algo/OptimizerWScheduler.py +38 -0
- shancx/Algo/Rmageresize.py +79 -0
- shancx/Algo/Savemodel.py +33 -0
- shancx/Algo/SmoothL1_losses.py +27 -0
- shancx/Algo/Tqdm.py +62 -0
- shancx/Algo/__init__.py +121 -0
- shancx/Algo/checknan.py +28 -0
- shancx/Algo/iouJU.py +83 -0
- shancx/Algo/mask.py +25 -0
- shancx/Algo/psnr.py +9 -0
- shancx/Algo/ssim.py +70 -0
- shancx/Algo/structural_similarity.py +308 -0
- shancx/Algo/tool.py +704 -0
- shancx/Calmetrics/__init__.py +97 -0
- shancx/Calmetrics/calmetrics.py +14 -0
- shancx/Calmetrics/calmetricsmatrixLib.py +147 -0
- shancx/Calmetrics/rmseR2score.py +35 -0
- shancx/Clip/__init__.py +50 -0
- shancx/Cmd.py +126 -0
- shancx/Config_.py +26 -0
- shancx/Df/DataFrame.py +11 -2
- shancx/Df/__init__.py +17 -0
- shancx/Df/tool.py +0 -0
- shancx/Diffm/Psamples.py +18 -0
- shancx/Diffm/__init__.py +0 -0
- shancx/Diffm/test.py +207 -0
- shancx/Doc/__init__.py +214 -0
- shancx/E/__init__.py +178 -152
- shancx/Fillmiss/__init__.py +0 -0
- shancx/Fillmiss/imgidwJU.py +46 -0
- shancx/Fillmiss/imgidwLatLonJU.py +82 -0
- shancx/Gpu/__init__.py +55 -0
- shancx/H9/__init__.py +126 -0
- shancx/H9/ahi_read_hsd.py +877 -0
- shancx/H9/ahisearchtable.py +298 -0
- shancx/H9/geometry.py +2439 -0
- shancx/Hug/__init__.py +81 -0
- shancx/Inst.py +22 -0
- shancx/Lib.py +31 -0
- shancx/Mos/__init__.py +37 -0
- shancx/NN/__init__.py +235 -106
- shancx/Path1.py +161 -0
- shancx/Plot/GlobMap.py +276 -116
- shancx/Plot/__init__.py +491 -1
- shancx/Plot/draw_day_CR_PNG.py +4 -21
- shancx/Plot/exam.py +116 -0
- shancx/Plot/plotGlobal.py +325 -0
- shancx/{radar_nmc.py → Plot/radarNmc.py} +4 -34
- shancx/{subplots_single_china_map.py → Plot/single_china_map.py} +1 -1
- shancx/Point.py +46 -0
- shancx/QC.py +223 -0
- shancx/RdPzl/__init__.py +32 -0
- shancx/Read.py +72 -0
- shancx/Resize.py +79 -0
- shancx/SN/__init__.py +62 -123
- shancx/Time/GetTime.py +9 -3
- shancx/Time/__init__.py +66 -1
- shancx/Time/timeCycle.py +302 -0
- shancx/Time/tool.py +0 -0
- shancx/Train/__init__.py +74 -0
- shancx/Train/makelist.py +187 -0
- shancx/Train/multiGpu.py +27 -0
- shancx/Train/prepare.py +161 -0
- shancx/Train/renet50.py +157 -0
- shancx/ZR.py +12 -0
- shancx/__init__.py +333 -262
- shancx/args.py +27 -0
- shancx/bak.py +768 -0
- shancx/df2database.py +62 -2
- shancx/geosProj.py +80 -0
- shancx/info.py +38 -0
- shancx/netdfJU.py +231 -0
- shancx/sendM.py +59 -0
- shancx/tensBoard/__init__.py +28 -0
- shancx/wait.py +246 -0
- {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/METADATA +15 -5
- shancx-1.9.33.218.dist-info/RECORD +91 -0
- {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/WHEEL +1 -1
- my_timer_decorator/__init__.py +0 -10
- shancx/Dsalgor/__init__.py +0 -19
- shancx/E/DFGRRIB.py +0 -30
- shancx/EN/DFGRRIB.py +0 -30
- shancx/EN/__init__.py +0 -148
- shancx/FileRead.py +0 -44
- shancx/Gray2RGB.py +0 -86
- shancx/M/__init__.py +0 -137
- shancx/MN/__init__.py +0 -133
- shancx/N/__init__.py +0 -131
- shancx/Plot/draw_day_CR_PNGUS.py +0 -206
- shancx/Plot/draw_day_CR_SVG.py +0 -275
- shancx/Plot/draw_day_pre_PNGUS.py +0 -205
- shancx/Plot/glob_nation_map.py +0 -116
- shancx/Plot/radar_nmc.py +0 -61
- shancx/Plot/radar_nmc_china_map_compare1.py +0 -50
- shancx/Plot/radar_nmc_china_map_f.py +0 -121
- shancx/Plot/radar_nmc_us_map_f.py +0 -128
- shancx/Plot/subplots_compare_devlop.py +0 -36
- shancx/Plot/subplots_single_china_map.py +0 -45
- shancx/S/__init__.py +0 -138
- shancx/W/__init__.py +0 -132
- shancx/WN/__init__.py +0 -132
- shancx/code.py +0 -331
- shancx/draw_day_CR_PNG.py +0 -200
- shancx/draw_day_CR_PNGUS.py +0 -206
- shancx/draw_day_CR_SVG.py +0 -275
- shancx/draw_day_pre_PNGUS.py +0 -205
- shancx/makenetCDFN.py +0 -42
- shancx/mkIMGSCX.py +0 -92
- shancx/netCDF.py +0 -130
- shancx/radar_nmc_china_map_compare1.py +0 -50
- shancx/radar_nmc_china_map_f.py +0 -125
- shancx/radar_nmc_us_map_f.py +0 -67
- shancx/subplots_compare_devlop.py +0 -36
- shancx/tool.py +0 -18
- shancx/user/H8mess.py +0 -317
- shancx/user/__init__.py +0 -137
- shancx/user/cinradHJN.py +0 -496
- shancx/user/examMeso.py +0 -293
- shancx/user/hjnDAAS.py +0 -26
- shancx/user/hjnFTP.py +0 -81
- shancx/user/hjnGIS.py +0 -320
- shancx/user/hjnGPU.py +0 -21
- shancx/user/hjnIDW.py +0 -68
- shancx/user/hjnKDTree.py +0 -75
- shancx/user/hjnLAPSTransform.py +0 -47
- shancx/user/hjnMiscellaneous.py +0 -182
- shancx/user/hjnProj.py +0 -162
- shancx/user/inotify.py +0 -41
- shancx/user/matplotlibMess.py +0 -87
- shancx/user/mkNCHJN.py +0 -623
- shancx/user/newTypeRadar.py +0 -492
- shancx/user/test.py +0 -6
- shancx/user/tlogP.py +0 -129
- shancx/util_log.py +0 -33
- shancx/wtx/H8mess.py +0 -315
- shancx/wtx/__init__.py +0 -151
- shancx/wtx/cinradHJN.py +0 -496
- shancx/wtx/colormap.py +0 -64
- shancx/wtx/examMeso.py +0 -298
- shancx/wtx/hjnDAAS.py +0 -26
- shancx/wtx/hjnFTP.py +0 -81
- shancx/wtx/hjnGIS.py +0 -330
- shancx/wtx/hjnGPU.py +0 -21
- shancx/wtx/hjnIDW.py +0 -68
- shancx/wtx/hjnKDTree.py +0 -75
- shancx/wtx/hjnLAPSTransform.py +0 -47
- shancx/wtx/hjnLog.py +0 -78
- shancx/wtx/hjnMiscellaneous.py +0 -201
- shancx/wtx/hjnProj.py +0 -161
- shancx/wtx/inotify.py +0 -41
- shancx/wtx/matplotlibMess.py +0 -87
- shancx/wtx/mkNCHJN.py +0 -613
- shancx/wtx/newTypeRadar.py +0 -492
- shancx/wtx/test.py +0 -6
- shancx/wtx/tlogP.py +0 -129
- shancx-1.8.92.dist-info/RECORD +0 -99
- /shancx/{Dsalgor → Algo}/dsalgor.py +0 -0
- {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/top_level.txt +0 -0
shancx/Algo/mask.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
def apply_mask_and_select(labels, outputs):
    """Keep only the positions whose label is strictly positive.

    :param labels: label tensor; entries ``> 0`` are treated as valid.
    :param outputs: model-output tensor; ``torch.masked_select`` requires it to
        broadcast against the label mask.
    :return: ``(filtered_labels, filtered_outputs)``, two 1-D tensors holding the
        selected values.
    """
    valid = labels > 0
    return (
        torch.masked_select(labels, valid),
        torch.masked_select(outputs, valid),
    )
|
|
13
|
+
if __name__ == "__main__":
    # Small demo: entries whose label is <= 0 (here 0 and -1) are dropped.
    demo_labels = torch.tensor([1, 0, 2, -1, 3], dtype=torch.float)
    demo_outputs = torch.tensor([1.1, 0.5, 2.1, -0.1, 2.9], dtype=torch.float)
    kept_labels, kept_outputs = apply_mask_and_select(demo_labels, demo_outputs)
    for caption, values in (
        ("Filtered Labels:", kept_labels),
        ("Filtered Outputs:", kept_outputs),
    ):
        print(caption, values)
|
|
19
|
+
def apply_mask(img, label):  # only valid (positive) values take part in the computation
    """Zero out the non-positive entries of ``img`` and ``label``.

    Each input is masked independently with ``x > 0``.

    For torch tensors the invalid entries are replaced with ``torch.where``
    instead of being multiplied by the boolean mask. Multiplication does not
    clear them: ``nan * 0`` and ``-inf * 0`` are both ``nan``, so NaNs would
    leak into the "masked" result. Non-tensor inputs (e.g. NumPy arrays) keep
    the previous multiply-by-mask behaviour.

    :param img: image / prediction tensor.
    :param label: label tensor.
    :return: ``(img_masked, img_mask, label_masked, label_mask)``, where the
        masks are the boolean ``> 0`` masks of each input.
    """

    def _zero_invalid(values, mask):
        # Select rather than multiply so NaN/-inf at masked positions become 0.
        if isinstance(values, torch.Tensor):
            return torch.where(mask, values, torch.zeros_like(values))
        return values * mask

    non_zero_maskimg = img > 0
    non_zero_masklabel = label > 0
    return (
        _zero_invalid(img, non_zero_maskimg),
        non_zero_maskimg,
        _zero_invalid(label, non_zero_masklabel),
        non_zero_masklabel,
    )
|
|
25
|
+
|
shancx/Algo/psnr.py
ADDED
shancx/Algo/ssim.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    # structural_similarity.py ships in shancx/Algo/ in this release; the old
    # shancx.Dsalgor package was removed (its dsalgor.py moved to Algo/).
    from shancx.Algo.structural_similarity import MSSSIMLoss
except ImportError:
    from shancx.Dsalgor.structural_similarity import MSSSIMLoss
|
|
6
|
+
class LossFunctions():
    """Regression losses that can be selected by name, e.g. ``getattr(losses, "mse")``.

    Every loss method takes ``(output, target)`` tensors and returns a scalar
    tensor. Several methods weight the squared error per element by
    ``min(1, w_fact * exp(w_exponent * target))``. This weight grows with the
    target value and saturates at 1.

    Args:
        device: device on which the constant tensors are created (default
            ``'cpu'``). When computing a loss, the constants are moved to
            ``target.device``, so a mismatch with the data's device is tolerated.
    """

    def __init__(self,
                 device='cpu'
                 ):
        self.device = device
        # Mixing coefficient of the combined losses; which term it multiplies
        # depends on the method (see each docstring).
        self.alpha = 0.3
        # Passed to MSSSIMLoss as its ``data_range``.
        self.data_range = 1
        # Constants of the target-dependent weighting, stored as 1-element tensors.
        self.w_fact = torch.Tensor([0.01]).to(device)
        self.w_exponent = torch.Tensor([0.05]).to(device)
        self.zero = torch.Tensor([0]).to(self.device)
        self.one = torch.Tensor([1]).to(self.device)

    def _weights(self, target):
        """Return per-element weights ``min(1, w_fact * exp(w_exponent * target))``.

        The stored constants are moved to ``target``'s device first, so the
        losses also work when ``target`` is not on ``self.device``.
        """
        dev = target.device
        one = torch.as_tensor(self.one, device=dev)
        w_fact = torch.as_tensor(self.w_fact, device=dev)
        w_exponent = torch.as_tensor(self.w_exponent, device=dev)
        return torch.minimum(one, w_fact * torch.exp(w_exponent * target))

    def mse(self, output, target):
        """Mean squared error loss."""
        criterion = torch.nn.MSELoss()
        return criterion(output, target)

    def msssim(self, output, target):
        """Multi-scale structural-similarity loss computed by ``MSSSIMLoss``.

        ``MSSSIMLoss`` is built with its default ``dim`` and ``self.data_range``.
        """
        criterion = MSSSIMLoss(data_range=self.data_range)
        return criterion(output, target)

    def msssim_weighted_mse(self, output, target):
        """``alpha * weighted MSE + (1 - alpha) * MS-SSIM loss``."""
        weights = self._weights(target)
        criterion = MSSSIMLoss(data_range=self.data_range)
        return self.alpha * (weights * (output - target) ** 2).mean() \
            + (1. - self.alpha) * criterion(output, target)

    def mse_mae(self, output, target):
        """``(1 - alpha) * MSE + alpha * MAE``."""
        diff = output - target
        return (1. - self.alpha) * (diff ** 2).mean() \
            + self.alpha * torch.abs(diff).mean()

    def weighted_mse(self, output, target):
        """Mean of the target-weighted squared error."""
        weights = self._weights(target)
        return (weights * (output - target) ** 2).mean()

    def mae_weighted_mse(self, output, target):
        """``alpha * weighted MSE + (1 - alpha) * MAE``."""
        weights = self._weights(target)
        return self.alpha * (weights * (output - target) ** 2).mean() \
            + (1. - self.alpha) * (torch.abs(output - target)).mean()


# Usage example (previously kept as an inert string literal). ``device``,
# ``yhat`` and ``y`` are placeholders that the caller must supply:
#
#     losses = LossFunctions(device=device)
#     cost = getattr(losses, "msssim_weighted_mse")
#     loss = cost(yhat, y)
|
|
shancx/Algo/structural_similarity.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
1
|
+
# MS_SSIM implementation from https://github.com/VainF/pytorch-msssim
|
|
2
|
+
# MSSSIM loss from https://github.com/spcl/deep-weather/blob/48748598294f02acbe029dac543e2abcb5285c09/Uncertainty_Quantification/Pytorch/models.py
|
|
3
|
+
|
|
4
|
+
import warnings
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _fspecial_gauss_1d(size, sigma):
|
|
11
|
+
r"""Create 1-D gauss kernel
|
|
12
|
+
Args:
|
|
13
|
+
size (int): the size of gauss kernel
|
|
14
|
+
sigma (float): sigma of normal distribution
|
|
15
|
+
Returns:
|
|
16
|
+
torch.Tensor: 1D kernel (1 x 1 x size)
|
|
17
|
+
"""
|
|
18
|
+
coords = torch.arange(size).to(dtype=torch.float)
|
|
19
|
+
coords -= size // 2
|
|
20
|
+
|
|
21
|
+
g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
|
|
22
|
+
g /= g.sum()
|
|
23
|
+
|
|
24
|
+
return g.unsqueeze(0).unsqueeze(0)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def gaussian_filter(input, win):
    r"""Blur ``input`` with a separable 1-D kernel, one spatial axis at a time.

    Args:
        input (torch.Tensor): 4-D ``(N, C, H, W)`` or 5-D ``(N, C, D, H, W)`` batch.
        win (torch.Tensor): 1-D kernel laid out as ``(C, 1, ..., 1, K)``; every
            dimension between the first and the last must be 1 (asserted).

    Returns:
        torch.Tensor: blurred tensor. The convolution uses no padding, so each
        filtered spatial axis shrinks by ``K - 1``. Axes shorter than ``K`` are
        left unfiltered and a warning is emitted.

    Raises:
        NotImplementedError: if ``input`` is neither 4-D nor 5-D.
    """
    assert all(ws == 1 for ws in win.shape[1:-1]), win.shape
    conv_by_rank = {4: F.conv2d, 5: F.conv3d}
    if len(input.shape) not in conv_by_rank:
        raise NotImplementedError(input.shape)
    conv = conv_by_rank[len(input.shape)]

    channels = input.shape[1]
    kernel_len = win.shape[-1]
    blurred = input
    # Sizes are taken from the original input, one spatial axis per pass.
    for axis_offset, extent in enumerate(input.shape[2:]):
        if extent < kernel_len:
            warnings.warn(
                f"Skipping Gaussian Smoothing at dimension 2+{axis_offset} for input: {input.shape} and win size: {kernel_len}"
            )
            continue
        # Rotate the kernel so its taps lie along the current spatial axis.
        blurred = conv(
            blurred,
            weight=win.transpose(2 + axis_offset, -1),
            stride=1,
            padding=0,
            groups=channels,
        )
    return blurred
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _ssim(X, Y, data_range, win, size_average=True, K=(0.01, 0.03)):
    r"""Compute per-channel SSIM and contrast-structure terms for ``X`` and ``Y``.

    Args:
        X (torch.Tensor): first batch of images (4-D or 5-D).
        Y (torch.Tensor): second batch of images, same shape as ``X``.
        data_range (float or int): value range of the images (usually 1.0 or 255).
        win (torch.Tensor): 1-D Gaussian kernel in the layout ``gaussian_filter`` expects.
        size_average (bool, optional): accepted for API compatibility but unused
            here; callers reduce the result themselves.
        K (list or tuple, optional): stability constants ``(K1, K2)``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(ssim_per_channel, cs)``, each of
        shape ``(batch, channel)``: the spatial means of the SSIM map and of the
        contrast-structure map.
    """
    k1, k2 = K
    # Scale factor on the local (co)variances; 1.0 leaves them unchanged.
    compensation = 1.0
    c1 = (k1 * data_range) ** 2
    c2 = (k2 * data_range) ** 2

    win = win.to(X.device, dtype=X.dtype)

    def blur(t):
        return gaussian_filter(t, win)

    mu_x = blur(X)
    mu_y = blur(Y)
    mu_x_sq = mu_x.pow(2)
    mu_y_sq = mu_y.pow(2)
    mu_xy = mu_x * mu_y

    var_x = compensation * (blur(X * X) - mu_x_sq)
    var_y = compensation * (blur(Y * Y) - mu_y_sq)
    cov_xy = compensation * (blur(X * Y) - mu_xy)

    # alpha = beta = gamma = 1 in the SSIM definition.
    cs_map = (2 * cov_xy + c2) / (var_x + var_y + c2)
    ssim_map = ((2 * mu_xy + c1) / (mu_x_sq + mu_y_sq + c1)) * cs_map

    ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1)
    cs = torch.flatten(cs_map, 2).mean(-1)
    return ssim_per_channel, cs
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def ssim(
    X,
    Y,
    data_range=255,
    size_average=True,
    win_size=11,
    win_sigma=1.5,
    win=None,
    K=(0.01, 0.03),
    nonnegative_ssim=False,
):
    r"""Structural similarity between two image batches.

    Args:
        X (torch.Tensor): a batch of images, (N,C,H,W) or (N,C,D,H,W).
        Y (torch.Tensor): a batch of images with the same shape as ``X``.
        data_range (float or int, optional): value range of the images (usually 1.0 or 255).
        size_average (bool, optional): if True, return a scalar averaged over
            everything; otherwise return one value per batch item (mean over channels).
        win_size (int, optional): size of the Gaussian kernel; ignored when ``win`` is given.
        win_sigma (float, optional): sigma of the Gaussian kernel.
        win (torch.Tensor, optional): 1-D Gaussian kernel. If None, one is built
            from ``win_size`` and ``win_sigma``.
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2
            (e.g. 0.4) if you get negative or NaN results.
        nonnegative_ssim (bool, optional): clamp per-channel SSIM at 0 with ReLU.

    Returns:
        torch.Tensor: ssim results.

    Raises:
        ValueError: on mismatched shapes or dtypes, unsupported rank, or an even window size.
    """
    if X.shape != Y.shape:
        raise ValueError("Input images should have the same dimensions.")

    # Drop singleton spatial axes (never the batch or channel axes).
    for dim in reversed(range(2, X.ndim)):
        X = X.squeeze(dim=dim)
        Y = Y.squeeze(dim=dim)

    if X.ndim not in (4, 5):
        raise ValueError(
            f"Input images should be 4-d or 5-d tensors, but got {X.shape}"
        )

    if X.type() != Y.type():
        raise ValueError("Input images should have the same dtype.")

    if win is not None:
        win_size = win.shape[-1]  # an explicit kernel overrides win_size

    if win_size % 2 != 1:
        raise ValueError("Window size should be odd.")

    if win is None:
        # One kernel copy per channel for the grouped convolution.
        win = _fspecial_gauss_1d(win_size, win_sigma).repeat(
            [X.shape[1]] + [1] * (X.ndim - 1)
        )

    ssim_per_channel, _ = _ssim(
        X, Y, data_range=data_range, win=win, size_average=False, K=K
    )
    if nonnegative_ssim:
        ssim_per_channel = torch.relu(ssim_per_channel)

    return ssim_per_channel.mean() if size_average else ssim_per_channel.mean(1)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def ms_ssim(
    X,
    Y,
    data_range=255,
    size_average=True,
    win_size=11,
    win_sigma=1.5,
    win=None,
    weights=None,
    K=(0.01, 0.03),
):
    r"""Multi-scale SSIM between two image batches.

    Args:
        X (torch.Tensor): a batch of images, (N,C,[T,]H,W).
        Y (torch.Tensor): a batch of images with the same shape as ``X``.
        data_range (float or int, optional): value range of the images (usually 1.0 or 255).
        size_average (bool, optional): if True, return a scalar averaged over
            everything; otherwise return one value per batch item (mean over channels).
        win_size (int, optional): size of the Gaussian kernel; ignored when ``win`` is given.
        win_sigma (float, optional): sigma of the Gaussian kernel.
        win (torch.Tensor, optional): 1-D Gaussian kernel. If None, one is built
            from ``win_size`` and ``win_sigma``.
        weights (sequence, optional): per-scale weights; their count sets the
            number of scales. Defaults to the 5-scale weights
            ``[0.0448, 0.2856, 0.3001, 0.2363, 0.1333]``.
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2
            (e.g. 0.4) if you get negative or NaN results.

    Returns:
        torch.Tensor: ms-ssim results.

    Raises:
        ValueError: on mismatched shapes or dtypes, unsupported rank, an even
            window size, or empty ``weights``.
        AssertionError: if the smaller of the last two dimensions is not larger
            than ``(win_size - 1) * 2 ** (len(weights) - 1)``.
    """
    if not X.shape == Y.shape:
        raise ValueError("Input images should have the same dimensions.")

    # Drop singleton spatial axes (never the batch or channel axes).
    for d in range(len(X.shape) - 1, 1, -1):
        X = X.squeeze(dim=d)
        Y = Y.squeeze(dim=d)

    if not X.type() == Y.type():
        raise ValueError("Input images should have the same dtype.")

    if len(X.shape) == 4:
        avg_pool = F.avg_pool2d
    elif len(X.shape) == 5:
        avg_pool = F.avg_pool3d
    else:
        raise ValueError(
            f"Input images should be 4-d or 5-d tensors, but got {X.shape}"
        )

    if win is not None:  # an explicit kernel overrides win_size
        win_size = win.shape[-1]

    if not (win_size % 2 == 1):
        raise ValueError("Window size should be odd.")

    if weights is None:
        weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
    levels = len(weights)
    if levels < 1:
        raise ValueError("weights should contain at least one level.")

    # Every scale except the last halves the spatial size. The required input
    # size therefore depends on the actual number of scales, not a fixed 4.
    downsamplings = levels - 1
    min_side = (win_size - 1) * (2 ** downsamplings)
    smaller_side = min(X.shape[-2:])
    if not smaller_side > min_side:
        # Raised explicitly (not via `assert`) so the check survives `python -O`;
        # AssertionError is kept for callers that already catch it.
        raise AssertionError(
            "Image size should be larger than %d due to the %d downsamplings in ms-ssim"
            % (min_side, downsamplings)
        )

    weights = torch.FloatTensor(weights).to(X.device, dtype=X.dtype)

    if win is None:
        # One kernel copy per channel for the grouped convolution.
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))

    mcs = []
    for i in range(levels):
        ssim_per_channel, cs = _ssim(
            X, Y, win=win, data_range=data_range, size_average=False, K=K
        )

        if i < levels - 1:
            mcs.append(torch.relu(cs))
            # Pad odd-sized axes by one so the halved size rounds up.
            padding = [s % 2 for s in X.shape[2:]]
            X = avg_pool(X, kernel_size=2, padding=padding)
            Y = avg_pool(Y, kernel_size=2, padding=padding)

    ssim_per_channel = torch.relu(ssim_per_channel)  # (batch, channel)
    mcs_and_ssim = torch.stack(
        mcs + [ssim_per_channel], dim=0
    )  # (level, batch, channel)
    ms_ssim_val = torch.prod(mcs_and_ssim ** weights.view(-1, 1, 1), dim=0)

    if size_average:
        return ms_ssim_val.mean()
    else:
        return ms_ssim_val.mean(1)
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
class MS_SSIM(torch.nn.Module):
    r"""Module wrapper around ``ms_ssim`` with a pre-built Gaussian window.

    Args:
        data_range (float or int, optional): value range of the images (usually 1.0 or 255).
        size_average (bool, optional): if True, ``forward`` returns a scalar
            averaged over everything.
        win_size (int, optional): size of the Gaussian kernel.
        win_sigma (float, optional): sigma of the Gaussian kernel.
        channel (int, optional): number of input channels (default: 3).
        spatial_dims (int, optional): number of spatial axes of the inputs
            (2 for images, 3 for volumes); shapes the stored kernel.
        weights (list, optional): per-scale weights forwarded to ``ms_ssim``.
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2
            (e.g. 0.4) if you get negative or NaN results.
    """

    def __init__(
        self,
        data_range=255,
        size_average=True,
        win_size=11,
        win_sigma=1.5,
        channel=3,
        spatial_dims=2,
        weights=None,
        K=(0.01, 0.03),
    ):
        super().__init__()
        self.win_size = win_size
        base_kernel = _fspecial_gauss_1d(win_size, win_sigma)
        # One kernel copy per channel: shape (channel, 1, 1, ..., win_size).
        self.win = base_kernel.repeat([channel, 1] + [1] * spatial_dims)
        self.size_average = size_average
        self.data_range = data_range
        self.weights = weights
        self.K = K

    def forward(self, X, Y):
        """Return ``ms_ssim(X, Y)`` using the settings captured at construction."""
        settings = dict(
            data_range=self.data_range,
            size_average=self.size_average,
            win=self.win,
            weights=self.weights,
            K=self.K,
        )
        return ms_ssim(X, Y, **settings)
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
class MSSSIMLoss(torch.nn.Module):
    """Loss that returns ``1 - MS-SSIM`` of its two inputs.

    ``dim == 2`` builds a single-channel ``MS_SSIM``. Any other ``dim`` builds a
    two-channel one, which the original note ties to two pressure levels.
    """

    def __init__(self, dim=2, data_range=50):
        super().__init__()
        # Original TODO for the multi-channel case: "60 for mean, check again";
        # after standardisation the values were estimated at about 0 +- 60.
        n_channels = 1 if dim == 2 else 2
        self.msssim = MS_SSIM(data_range=data_range, channel=n_channels)

    def forward(self, x, target):
        similarity = self.msssim(x, target)
        return 1.0 - similarity
|
|
308
|
+
|