shancx 1.8.92__py3-none-any.whl → 1.9.33.218__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shancx/3D/__init__.py +25 -0
- shancx/Algo/Class.py +11 -0
- shancx/Algo/CudaPrefetcher1.py +112 -0
- shancx/Algo/Fake_image.py +24 -0
- shancx/Algo/Hsml.py +391 -0
- shancx/Algo/L2Loss.py +10 -0
- shancx/Algo/MetricTracker.py +132 -0
- shancx/Algo/Normalize.py +66 -0
- shancx/Algo/OptimizerWScheduler.py +38 -0
- shancx/Algo/Rmageresize.py +79 -0
- shancx/Algo/Savemodel.py +33 -0
- shancx/Algo/SmoothL1_losses.py +27 -0
- shancx/Algo/Tqdm.py +62 -0
- shancx/Algo/__init__.py +121 -0
- shancx/Algo/checknan.py +28 -0
- shancx/Algo/iouJU.py +83 -0
- shancx/Algo/mask.py +25 -0
- shancx/Algo/psnr.py +9 -0
- shancx/Algo/ssim.py +70 -0
- shancx/Algo/structural_similarity.py +308 -0
- shancx/Algo/tool.py +704 -0
- shancx/Calmetrics/__init__.py +97 -0
- shancx/Calmetrics/calmetrics.py +14 -0
- shancx/Calmetrics/calmetricsmatrixLib.py +147 -0
- shancx/Calmetrics/rmseR2score.py +35 -0
- shancx/Clip/__init__.py +50 -0
- shancx/Cmd.py +126 -0
- shancx/Config_.py +26 -0
- shancx/Df/DataFrame.py +11 -2
- shancx/Df/__init__.py +17 -0
- shancx/Df/tool.py +0 -0
- shancx/Diffm/Psamples.py +18 -0
- shancx/Diffm/__init__.py +0 -0
- shancx/Diffm/test.py +207 -0
- shancx/Doc/__init__.py +214 -0
- shancx/E/__init__.py +178 -152
- shancx/Fillmiss/__init__.py +0 -0
- shancx/Fillmiss/imgidwJU.py +46 -0
- shancx/Fillmiss/imgidwLatLonJU.py +82 -0
- shancx/Gpu/__init__.py +55 -0
- shancx/H9/__init__.py +126 -0
- shancx/H9/ahi_read_hsd.py +877 -0
- shancx/H9/ahisearchtable.py +298 -0
- shancx/H9/geometry.py +2439 -0
- shancx/Hug/__init__.py +81 -0
- shancx/Inst.py +22 -0
- shancx/Lib.py +31 -0
- shancx/Mos/__init__.py +37 -0
- shancx/NN/__init__.py +235 -106
- shancx/Path1.py +161 -0
- shancx/Plot/GlobMap.py +276 -116
- shancx/Plot/__init__.py +491 -1
- shancx/Plot/draw_day_CR_PNG.py +4 -21
- shancx/Plot/exam.py +116 -0
- shancx/Plot/plotGlobal.py +325 -0
- shancx/{radar_nmc.py → Plot/radarNmc.py} +4 -34
- shancx/{subplots_single_china_map.py → Plot/single_china_map.py} +1 -1
- shancx/Point.py +46 -0
- shancx/QC.py +223 -0
- shancx/RdPzl/__init__.py +32 -0
- shancx/Read.py +72 -0
- shancx/Resize.py +79 -0
- shancx/SN/__init__.py +62 -123
- shancx/Time/GetTime.py +9 -3
- shancx/Time/__init__.py +66 -1
- shancx/Time/timeCycle.py +302 -0
- shancx/Time/tool.py +0 -0
- shancx/Train/__init__.py +74 -0
- shancx/Train/makelist.py +187 -0
- shancx/Train/multiGpu.py +27 -0
- shancx/Train/prepare.py +161 -0
- shancx/Train/renet50.py +157 -0
- shancx/ZR.py +12 -0
- shancx/__init__.py +333 -262
- shancx/args.py +27 -0
- shancx/bak.py +768 -0
- shancx/df2database.py +62 -2
- shancx/geosProj.py +80 -0
- shancx/info.py +38 -0
- shancx/netdfJU.py +231 -0
- shancx/sendM.py +59 -0
- shancx/tensBoard/__init__.py +28 -0
- shancx/wait.py +246 -0
- {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/METADATA +15 -5
- shancx-1.9.33.218.dist-info/RECORD +91 -0
- {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/WHEEL +1 -1
- my_timer_decorator/__init__.py +0 -10
- shancx/Dsalgor/__init__.py +0 -19
- shancx/E/DFGRRIB.py +0 -30
- shancx/EN/DFGRRIB.py +0 -30
- shancx/EN/__init__.py +0 -148
- shancx/FileRead.py +0 -44
- shancx/Gray2RGB.py +0 -86
- shancx/M/__init__.py +0 -137
- shancx/MN/__init__.py +0 -133
- shancx/N/__init__.py +0 -131
- shancx/Plot/draw_day_CR_PNGUS.py +0 -206
- shancx/Plot/draw_day_CR_SVG.py +0 -275
- shancx/Plot/draw_day_pre_PNGUS.py +0 -205
- shancx/Plot/glob_nation_map.py +0 -116
- shancx/Plot/radar_nmc.py +0 -61
- shancx/Plot/radar_nmc_china_map_compare1.py +0 -50
- shancx/Plot/radar_nmc_china_map_f.py +0 -121
- shancx/Plot/radar_nmc_us_map_f.py +0 -128
- shancx/Plot/subplots_compare_devlop.py +0 -36
- shancx/Plot/subplots_single_china_map.py +0 -45
- shancx/S/__init__.py +0 -138
- shancx/W/__init__.py +0 -132
- shancx/WN/__init__.py +0 -132
- shancx/code.py +0 -331
- shancx/draw_day_CR_PNG.py +0 -200
- shancx/draw_day_CR_PNGUS.py +0 -206
- shancx/draw_day_CR_SVG.py +0 -275
- shancx/draw_day_pre_PNGUS.py +0 -205
- shancx/makenetCDFN.py +0 -42
- shancx/mkIMGSCX.py +0 -92
- shancx/netCDF.py +0 -130
- shancx/radar_nmc_china_map_compare1.py +0 -50
- shancx/radar_nmc_china_map_f.py +0 -125
- shancx/radar_nmc_us_map_f.py +0 -67
- shancx/subplots_compare_devlop.py +0 -36
- shancx/tool.py +0 -18
- shancx/user/H8mess.py +0 -317
- shancx/user/__init__.py +0 -137
- shancx/user/cinradHJN.py +0 -496
- shancx/user/examMeso.py +0 -293
- shancx/user/hjnDAAS.py +0 -26
- shancx/user/hjnFTP.py +0 -81
- shancx/user/hjnGIS.py +0 -320
- shancx/user/hjnGPU.py +0 -21
- shancx/user/hjnIDW.py +0 -68
- shancx/user/hjnKDTree.py +0 -75
- shancx/user/hjnLAPSTransform.py +0 -47
- shancx/user/hjnMiscellaneous.py +0 -182
- shancx/user/hjnProj.py +0 -162
- shancx/user/inotify.py +0 -41
- shancx/user/matplotlibMess.py +0 -87
- shancx/user/mkNCHJN.py +0 -623
- shancx/user/newTypeRadar.py +0 -492
- shancx/user/test.py +0 -6
- shancx/user/tlogP.py +0 -129
- shancx/util_log.py +0 -33
- shancx/wtx/H8mess.py +0 -315
- shancx/wtx/__init__.py +0 -151
- shancx/wtx/cinradHJN.py +0 -496
- shancx/wtx/colormap.py +0 -64
- shancx/wtx/examMeso.py +0 -298
- shancx/wtx/hjnDAAS.py +0 -26
- shancx/wtx/hjnFTP.py +0 -81
- shancx/wtx/hjnGIS.py +0 -330
- shancx/wtx/hjnGPU.py +0 -21
- shancx/wtx/hjnIDW.py +0 -68
- shancx/wtx/hjnKDTree.py +0 -75
- shancx/wtx/hjnLAPSTransform.py +0 -47
- shancx/wtx/hjnLog.py +0 -78
- shancx/wtx/hjnMiscellaneous.py +0 -201
- shancx/wtx/hjnProj.py +0 -161
- shancx/wtx/inotify.py +0 -41
- shancx/wtx/matplotlibMess.py +0 -87
- shancx/wtx/mkNCHJN.py +0 -613
- shancx/wtx/newTypeRadar.py +0 -492
- shancx/wtx/test.py +0 -6
- shancx/wtx/tlogP.py +0 -129
- shancx-1.8.92.dist-info/RECORD +0 -99
- /shancx/{Dsalgor → Algo}/dsalgor.py +0 -0
- {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/top_level.txt +0 -0
shancx/3D/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
# from mpl_toolkits.mplot3d import Axes3D
|
|
5
|
+
from shancx import crDir
|
|
6
|
+
def plot3DJU(x, path="./3D_test1.png", step=5):
    """Render a 3-D tensor as a coloured 3-D scatter plot and save it to ``path``.

    Every element of the (strided) volume becomes one scatter point whose
    position is its (dim0, dim1, dim2) index and whose colour is its value.

    Args:
        x: 3-D ``torch.Tensor`` on CPU or CUDA. It is detached before
            conversion, so tensors that require grad are accepted.
        path: Output image path; parent directories are created via ``crDir``.
        step: Stride applied to the last two axes before plotting, to keep the
            number of points manageable. The first axis is not strided.
    """
    # detach() allows grad-tracking tensors; cpu() is a no-op on CPU tensors.
    data = x.detach().cpu().numpy()
    full_shape = tuple(data.shape)
    data = data[:, ::step, ::step]

    # indexing="ij" yields grids laid out exactly like data (dim0, dim1, dim2),
    # so ravel() of the grids lines up with ravel() of the colours. The default
    # "xy" indexing swaps the first two axes and mis-pairs coordinates and
    # values whenever dim0 != dim1.
    grid_x, grid_y, grid_z = np.meshgrid(
        np.arange(data.shape[0]),
        np.arange(data.shape[1]),
        np.arange(data.shape[2]),
        indexing="ij",
    )

    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection='3d')
    scatter = ax.scatter(
        grid_x.ravel(), grid_y.ravel(), grid_z.ravel(),
        c=data.ravel(), cmap='viridis',
    )
    # Report the real input shape instead of a hard-coded one.
    ax.set_title(f"3D Scatter Plot of Shape {full_shape}")
    plt.colorbar(scatter)

    crDir(path)
    plt.savefig(path)
    plt.close(fig)
|
shancx/Algo/Class.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import copy
|
|
3
|
+
def classify1h(pre0):
    """Map 1-hour accumulation values to intensity class codes.

    Thresholds (units presumably mm/h — confirm with the data source):
        0  : 0 <= v < 0.1
        1  : 0.1 <= v <= 2.5
        2  : 2.5 < v <= 8
        3  : 8 < v <= 16
        4  : 16 < v <= 300
        -1 : v > 300 or NaN (invalid)
    Negative values are left unchanged, as before.

    Args:
        pre0: Floating-point tensor. It is not modified.

    Returns:
        A new tensor (same shape and dtype, detached from autograd) holding the
        class codes.
    """
    # detach().clone() instead of copy.deepcopy: deepcopy fails on non-leaf
    # tensors, and in-place writes on a grad-requiring leaf copy would fail too.
    src = pre0.detach()
    pre = src.clone()
    # All masks are computed on the untouched source, so earlier assignments
    # never leak into later comparisons.
    pre[(src >= 0) & (src < 0.1)] = 0  # previously left as the raw value
    pre[(src >= 0.1) & (src <= 2.5)] = 1
    pre[(src > 2.5) & (src <= 8)] = 2
    pre[(src > 8) & (src <= 16)] = 3
    pre[(src > 16) & (src <= 300)] = 4
    pre[src > 300] = -1
    pre[torch.isnan(src)] = -1
    return pre
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
|
|
2
|
+
from torch.utils.data import Dataset, DataLoader
|
|
3
|
+
import torch
|
|
4
|
+
class CUDAPrefetcher1:
    """
    Use the CUDA side to accelerate data reading.

    A dedicated CUDA stream copies the *next* batch to ``device`` while the
    caller is still working on the current one.

    Args:
        dataloader (DataLoader): Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
            Each batch must be a tensor, a dict or a list/tuple. Tensor values
            are moved to ``device``, other values pass through unchanged, and
            tuples are returned as lists.
        device (torch.device): Specify running device.

    Example:
        train_data_prefetcher = CUDAPrefetcher1(train_dataloader, device)
        for batch in train_data_prefetcher:
            ...
    """

    def __init__(self, dataloader: DataLoader, device: torch.device):
        self.batch_data = None
        self.original_dataloader = dataloader
        self.device = device

        self.data = iter(dataloader)
        self.stream = torch.cuda.Stream()
        # Prime the first batch so next() works even without calling iter().
        self.preload()

    def preload(self):
        """
        Load the next batch of data and move it to the specified device asynchronously.

        Sets ``self.batch_data`` to ``None`` once the loader is exhausted.

        Raises:
            TypeError: If the batch is not a tensor, dict, list or tuple.
        """
        try:
            self.batch_data = next(self.data)
        except StopIteration:
            self.batch_data = None
            return

        # Asynchronously move the data to the GPU on the side stream
        with torch.cuda.stream(self.stream):
            if isinstance(self.batch_data, dict):
                self.batch_data = {
                    k: v.to(self.device, non_blocking=True)
                    if torch.is_tensor(v) else v
                    for k, v in self.batch_data.items()
                }
            elif isinstance(self.batch_data, (list, tuple)):
                self.batch_data = [
                    x.to(self.device, non_blocking=True)
                    if torch.is_tensor(x) else x
                    for x in self.batch_data
                ]
            elif torch.is_tensor(self.batch_data):
                self.batch_data = self.batch_data.to(self.device, non_blocking=True)
            else:
                raise TypeError(
                    f"Unsupported data type {type(self.batch_data)}. Ensure the dataloader outputs tensors, dicts, or lists."
                )

    @staticmethod
    def _record_stream(batch, stream):
        """Mark every CUDA tensor in ``batch`` as in use on ``stream``.

        The tensors were allocated on the side stream. Without ``record_stream``,
        the caching allocator may hand their memory to the next prefetch while
        ``stream`` is still reading them. Non-CUDA tensors and non-tensor values
        are skipped.
        """
        if torch.is_tensor(batch):
            values = (batch,)
        elif isinstance(batch, dict):
            values = batch.values()
        elif isinstance(batch, (list, tuple)):
            values = batch
        else:
            values = ()
        for value in values:
            if torch.is_tensor(value) and value.is_cuda:
                value.record_stream(stream)

    def __iter__(self):
        """
        Make the object iterable by resetting the dataloader and returning self.
        """
        self.reset()
        return self

    def __next__(self):
        """
        Wait for the current stream, return the current batch, and preload the next batch.

        Raises:
            StopIteration: When no batch is pending.
        """
        if self.batch_data is None:
            raise StopIteration

        current_stream = torch.cuda.current_stream()
        current_stream.wait_stream(self.stream)
        batch_data = self.batch_data
        self._record_stream(batch_data, current_stream)
        self.preload()
        return batch_data

    def reset(self):
        """
        Reset the dataloader iterator and preload the first batch.
        """
        self.data = iter(self.original_dataloader)
        self.preload()

    def __len__(self):
        """
        Return the length of the dataloader.
        """
        return len(self.original_dataloader)
|
|
84
|
+
|
|
85
|
+
# train_data_prefetcher = CUDAPrefetcher1(degenerated_train_dataloader, device)
|
|
86
|
+
|
|
87
|
+
import torch
|
|
88
|
+
from torch.utils.data import DataLoader
|
|
89
|
+
import torch.nn.functional as F
|
|
90
|
+
def custom_collate_fn(batch):
    """
    Custom collate function to handle batches with varying tensor sizes by padding them to the same size.

    - dict samples are collated key by key, recursively;
    - tensors with ndim >= 2 are zero-padded at the bottom/right of their last
      two dimensions to the batch maximum, then stacked along a new dim 0;
    - tensors with ndim < 2 are stacked as-is (their shapes must match);
    - any other value type is returned as the list it arrived in.

    Args:
        batch: List of samples produced by a Dataset.

    Returns:
        The collated batch. An empty batch is returned unchanged.
    """
    if not batch:
        return batch

    first = batch[0]
    if isinstance(first, dict):
        return {key: custom_collate_fn([sample[key] for sample in batch]) for key in first}

    if isinstance(first, torch.Tensor):
        if first.dim() < 2:
            return torch.stack(batch)
        # F.pad's 4-tuple (left, right, top, bottom) pads the LAST two dims, so
        # the target sizes must be measured on those same dims.
        max_height = max(item.shape[-2] for item in batch)
        max_width = max(item.shape[-1] for item in batch)
        padded_batch = [
            F.pad(
                item,
                (0, max_width - item.shape[-1], 0, max_height - item.shape[-2]),
                mode="constant",
                value=0,
            )
            for item in batch
        ]
        return torch.stack(padded_batch)

    return batch
|
|
111
|
+
# device = torch.device("cuda:0")
|
|
112
|
+
# paired_test_dataloader = CUDAPrefetcher1(DataLoader(datasetdata, collate_fn=custom_collate_fn), device)
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# 加载最佳模型权重
|
|
2
|
+
import torch
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
from shancx import crDir
|
|
5
|
+
def Fakeimage(
    generator,
    device,
    weights_path='/home/scx/train_Gan/best_generator_weights.pth',
    outPath='./fake_images/best_weights_fake_images.png',
    num_images=64,
    latent_dim=100,
    nrow=8,
):
    """Load the best generator weights, sample images and save them as a grid.

    Args:
        generator: ``nn.Module`` mapping a (num_images, latent_dim) noise batch
            to images; assumed to output values in [-1, 1] (TODO confirm).
        device: Device on which noise is created and weights are mapped.
        weights_path: State-dict checkpoint for ``generator``.
        outPath: Where the image grid is written; parent dirs are created.
        num_images: Number of samples to draw.
        latent_dim: Size of each noise vector.
        nrow: Images per row in the grid.

    Returns:
        ``outPath``.
    """
    import torchvision  # make_grid; torchvision is not imported at module level

    # Load the best model weights; map_location lets GPU-saved weights load anywhere.
    generator.load_state_dict(torch.load(weights_path, map_location=device))
    # Put the model in evaluation mode
    generator.eval()
    with torch.no_grad():
        # Create the noise tensor directly on the target device
        fixed_noise = torch.randn(num_images, latent_dim, device=device)
        fake_images = generator(fixed_noise)
        # Check the shape of the generated images
        print(f"生成的图像形状: {fake_images.shape}")
        # De-normalize: map pixel values from [-1, 1] to [0, 1]
        fake_images = (fake_images + 1) / 2
        print(f"图像最小值: {fake_images.min()}, 最大值: {fake_images.max()}")
        # normalize=True lets make_grid rescale to [0, 1] for display
        grid = torchvision.utils.make_grid(fake_images, nrow=nrow, normalize=True)
        grid_cpu = grid.cpu().detach().numpy()

    plt.figure()
    plt.imshow(grid_cpu.transpose(1, 2, 0))  # reorder channels to (H, W, C)
    plt.axis('off')  # hide the axes
    crDir(outPath)
    plt.savefig(outPath)
    plt.close()
    return outPath
|
shancx/Algo/Hsml.py
ADDED
|
@@ -0,0 +1,391 @@
|
|
|
1
|
+
|
|
2
|
+
import torchvision.models as models
|
|
3
|
+
from torch import nn
|
|
4
|
+
import torch
|
|
5
|
+
from torchvision.models import vgg19
|
|
6
|
+
# Define VGG Loss
|
|
7
|
+
class VGGLoss(nn.Module):
    """Perceptual MSE loss on VGG19 features for single-channel images.

    Uses ``vgg19().features[:35]`` (in torchvision's VGG19 this stops at conv5_4,
    before its ReLU). The 3-channel stem is replaced by a 1-channel conv whose
    weights are the RGB stem's weights summed over input channels. A grayscale
    image therefore produces the same stem response as that image replicated
    to RGB. All VGG parameters are frozen.

    Args:
        weights_path: Optional VGG19 checkpoint. Full-model state dicts
            (``features.*`` / ``classifier.*`` keys, e.g. torchvision's
            ``vgg19-dcbb9e9d.pth``) and features-only dicts are both accepted.
        device: Device the VGG layers are moved to.

    Example:
        vgg_loss = VGGLoss(weights_path=".../vgg19-dcbb9e9d.pth", device=device).to(device)
    """

    def __init__(self, weights_path=None, device=None):
        super().__init__()
        self.vgg = models.vgg19(pretrained=False).features[:35].eval().to(device)
        if weights_path:
            pretrained_weights = torch.load(weights_path, map_location="cpu")
            result = self.vgg.load_state_dict(
                self._features_state_dict(pretrained_weights), strict=False
            )
            if result.missing_keys:
                # strict=False used to hide the fact that nothing was loaded.
                import warnings
                warnings.warn(f"VGGLoss: no pretrained weights for {result.missing_keys}")
        self.vgg[0] = self._gray_conv(self.vgg[0]).to(device)
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.loss = nn.MSELoss()

    @staticmethod
    def _features_state_dict(state_dict):
        """Return the ``features.*`` entries with the prefix stripped.

        A full VGG checkpoint is keyed ``features.N.weight``, while the
        ``features`` Sequential expects ``N.weight``. Dicts without that prefix
        are returned unchanged.
        """
        prefix = "features."
        if any(key.startswith(prefix) for key in state_dict):
            return {
                key[len(prefix):]: value
                for key, value in state_dict.items()
                if key.startswith(prefix)
            }
        return state_dict

    @staticmethod
    def _gray_conv(rgb_conv):
        """Build a 1-input-channel copy of ``rgb_conv`` with channel-summed weights."""
        gray = nn.Conv2d(
            1, rgb_conv.out_channels,
            kernel_size=rgb_conv.kernel_size,
            stride=rgb_conv.stride,
            padding=rgb_conv.padding,
        )
        with torch.no_grad():
            gray.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
            if rgb_conv.bias is not None:
                gray.bias.copy_(rgb_conv.bias)
        return gray

    def forward(self, input, target):
        """Return the MSE between VGG features of ``input`` and ``target``."""
        input_features = self.vgg(input)
        target_features = self.vgg(target)
        return self.loss(input_features, target_features)
|
|
22
|
+
"""
|
|
23
|
+
vgg_loss = VGGLoss(weights_path="/mnt/wtx_weather_forecast/scx/stat/sat/sat2radar/vgg19-dcbb9e9d.pth").to(device)
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
import torch
|
|
27
|
+
import torch.nn as nn
|
|
28
|
+
import torch.nn.functional as F
|
|
29
|
+
from segment_anything import sam_model_registry
|
|
30
|
+
|
|
31
|
+
class SAMLoss(nn.Module):
    """Perceptual MSE loss on frozen SAM image-encoder features.

    Inputs are expanded to 3 channels, ImageNet-normalized and resized to
    ``input_size`` x ``input_size`` before encoding.
    """

    def __init__(self, model_type='vit_b', checkpoint_path=None, input_size=1024):
        """
        SAM-based perceptual loss with resolution handling

        Args:
            model_type (str): SAM model type (vit_b, vit_l, vit_h)
            checkpoint_path (str): Path to SAM checkpoint weights
            input_size (int): Target input size for SAM (default 1024)
        """
        super().__init__()
        self.input_size = input_size
        # Initialize SAM model
        self.sam = sam_model_registry[model_type](checkpoint=None)
        # Load pretrained weights if provided (map to CPU; .to() moves them later)
        if checkpoint_path:
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            self.sam.load_state_dict(state_dict)
        # Use image encoder only and freeze parameters
        self.image_encoder = self.sam.image_encoder.eval()
        for param in self.image_encoder.parameters():
            param.requires_grad = False
        # Define loss function
        self.loss = nn.MSELoss()
        # Normalization parameters
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def preprocess(self, x):
        """
        Preprocess input to match SAM requirements:
        1. Convert to 3-channel if needed
        2. Normalize using ImageNet stats
        3. Resize to target size
        """
        if x.shape[1] == 1:
            x = x.expand(-1, 3, -1, -1)  # More memory efficient than repeat
        # Normalize
        x = (x - self.mean) / self.std
        # Resize
        if x.shape[-2:] != (self.input_size, self.input_size):
            x = F.interpolate(x, size=(self.input_size, self.input_size),
                              mode='bilinear', align_corners=False)
        return x

    def forward(self, input, target):
        """Return the MSE between encoder features of ``input`` and ``target``.

        Gradients flow back to ``input``. Only the target branch runs under
        ``no_grad``; the encoder weights are frozen via ``requires_grad=False``.
        Previously both branches ran under ``no_grad``, which made the loss
        non-differentiable.
        """
        input = self.preprocess(input)
        target = self.preprocess(target)
        # Process in chunks to bound activation memory (adjust to your GPU).
        batch_size = 4
        input_features = []
        target_features = []
        for i in range(0, input.size(0), batch_size):
            chunk = slice(i, i + batch_size)
            input_features.append(self.image_encoder(input[chunk]))
            with torch.no_grad():
                target_features.append(self.image_encoder(target[chunk]))
        return self.loss(torch.cat(input_features), torch.cat(target_features))
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
import torch
|
|
93
|
+
import torch.nn as nn
|
|
94
|
+
import torch.nn.functional as F
|
|
95
|
+
from segment_anything import sam_model_registry
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class SAMLoss(nn.Module):
    """Perceptual MSE loss on frozen SAM image-encoder features at native resolution.

    Unlike a fixed-size variant, inputs are only snapped to a multiple of the
    patch size. The encoder's absolute position embedding is interpolated to
    match the resulting patch grid.
    """

    def __init__(self, model_type='vit_b', checkpoint_path=None):
        super().__init__()
        # Build SAM and load pretrained weights (map to CPU; .to() moves them later)
        self.sam = sam_model_registry[model_type](checkpoint=None)
        if checkpoint_path:
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            self.sam.load_state_dict(state_dict)

        # Keep only the image encoder, frozen
        self.image_encoder = self.sam.image_encoder.eval()
        for param in self.image_encoder.parameters():
            param.requires_grad = False

        # Patch size from the patch-embedding conv (usually 16)
        self.patch_size = self.image_encoder.patch_embed.proj.kernel_size[0]

        # Keep the original pos_embed so it can be resized per input
        self.original_pos_embed = self.image_encoder.pos_embed

        # Normalization parameters (ImageNet mean/std)
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        self.loss = nn.MSELoss()

    def adjust_pos_embed(self, x):
        """
        Resize pos_embed to match the input size.

        The embedding is laid out as (1, H_grid, W_grid, C). It is bilinearly
        interpolated from ``original_pos_embed`` to
        (H // patch_size, W // patch_size) and installed as a frozen parameter.
        The original is restored when the grid matches it.

        Args:
            x: input tensor (B, C, H, W), already preprocessed
        """
        _, _, H, W = x.shape
        h, w = H // self.patch_size, W // self.patch_size

        current = getattr(self.image_encoder, "pos_embed", None)
        if current is not None and tuple(current.shape[1:3]) == (h, w):
            return  # already sized for this grid; skip re-interpolation

        if (h, w) != tuple(self.original_pos_embed.shape[1:3]):
            # (1, H0, W0, C) -> (1, C, H0, W0) -> interpolate -> (1, h, w, C)
            pos_embed = F.interpolate(
                self.original_pos_embed.permute(0, 3, 1, 2),
                size=(h, w),
                mode='bilinear',
                align_corners=False,
            ).permute(0, 2, 3, 1)
            self.image_encoder.pos_embed = nn.Parameter(pos_embed, requires_grad=False)
        else:
            self.image_encoder.pos_embed = self.original_pos_embed

    def preprocess(self, x):
        """Preprocess: channel expansion + normalization + snap size to a patch multiple."""
        if x.shape[1] == 1:
            x = x.expand(-1, 3, -1, -1)  # single channel -> three channels
        x = (x - self.mean) / self.std
        _, _, H, W = x.shape
        p = self.patch_size
        # At least one patch per side, so tiny inputs do not collapse to size 0.
        H_new = max(H // p, 1) * p
        W_new = max(W // p, 1) * p
        if H != H_new or W != W_new:
            x = F.interpolate(
                x,
                size=(H_new, W_new),
                mode='bilinear',
                align_corners=False,
            )
        return x

    def forward(self, input, target):
        """Return the MSE between encoder features of ``input`` and ``target``.

        Only the target branch runs under ``no_grad``, so gradients reach
        ``input``; the encoder weights stay frozen via ``requires_grad=False``.
        """
        input = self.preprocess(input)
        target = self.preprocess(target)
        self.adjust_pos_embed(input)
        input_feat = self.image_encoder(input)
        with torch.no_grad():
            target_feat = self.image_encoder(target)
        return self.loss(input_feat, target_feat)
|
|
181
|
+
|
|
182
|
+
# saml = SAMLoss(checkpoint_path = "/path/to/sam_vit_b_01ec64.pth.1").to(device)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
import torch
|
|
186
|
+
import torch.nn as nn
|
|
187
|
+
from torchvision import models
|
|
188
|
+
class WeightedVGGLoss(nn.Module):
    """VGG19 perceptual loss with optional value-dependent weighting.

    Feature extractor: ``vgg19().features[:35]`` with its RGB stem replaced by
    a 1-channel conv seeded from the channel-summed RGB weights. All
    parameters are frozen. With ``apply_weighting`` the squared feature error
    is multiplied element-wise by ``custom_weight_4d(target_features)`` before
    averaging; otherwise plain MSE is used.

    Args:
        weights_path: Optional VGG19 checkpoint. Full-model state dicts
            (``features.*`` / ``classifier.*`` keys, e.g. torchvision's
            ``vgg19-dcbb9e9d.pth``) and features-only dicts are both accepted.
        device: Device the VGG layers are moved to.
        apply_weighting: Toggle the weighted MSE.

    Example:
        vgg_loss = WeightedVGGLoss(weights_path=".../vgg19-dcbb9e9d.pth", device=device).to(device)
    """

    def __init__(self, weights_path=None, device=None, apply_weighting=True):
        super().__init__()
        self.device = device
        self.apply_weighting = apply_weighting
        self.vgg = models.vgg19(pretrained=False).features[:35].eval().to(device)
        if weights_path:
            pretrained_weights = torch.load(weights_path, map_location="cpu")
            result = self.vgg.load_state_dict(
                self._features_state_dict(pretrained_weights), strict=False
            )
            if result.missing_keys:
                # strict=False used to hide the fact that nothing was loaded.
                import warnings
                warnings.warn(f"WeightedVGGLoss: no pretrained weights for {result.missing_keys}")
        self.vgg[0] = self._gray_conv(self.vgg[0]).to(device)
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.mse_loss = nn.MSELoss()

    @staticmethod
    def _features_state_dict(state_dict):
        """Return the ``features.*`` entries with the prefix stripped.

        Dicts without that prefix are returned unchanged.
        """
        prefix = "features."
        if any(key.startswith(prefix) for key in state_dict):
            return {
                key[len(prefix):]: value
                for key, value in state_dict.items()
                if key.startswith(prefix)
            }
        return state_dict

    @staticmethod
    def _gray_conv(rgb_conv):
        """Build a 1-input-channel copy of ``rgb_conv`` with channel-summed weights."""
        gray = nn.Conv2d(
            1, rgb_conv.out_channels,
            kernel_size=rgb_conv.kernel_size,
            stride=rgb_conv.stride,
            padding=rgb_conv.padding,
        )
        with torch.no_grad():
            gray.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
            if rgb_conv.bias is not None:
                gray.bias.copy_(rgb_conv.bias)
        return gray

    def custom_weight_4d(self, target):
        """Piecewise weight map: 0.1 below 10, a saturating ramp per 10-wide segment, 1.0 at >= 70.

        Each segment row is [start, end, steepness, base_weight]; within a
        segment the weight rises from ``base_weight`` as the value moves from
        ``start`` towards ``end``.
        """
        EPS = 1e-6
        segments = [
            [10, 20, 0.5, 0.1], [20, 30, 0.4, 0.12],
            [30, 40, 0.3, 0.14], [40, 50, 0.2, 0.16],
            [50, 60, 0.1, 0.18], [60, 70, 0.05, 0.2]
        ]
        weights = torch.full_like(target, 0.1)
        for start, end, steepness, base_w in segments:
            mask = (target >= start) & (target < end)
            x = ((target - start) / (end - start + EPS)).clamp(0, 1)
            delta_w = (1.0 - base_w) * (0.7 - steepness)
            seg_weights = base_w + delta_w * (x / (x + steepness*(1 - x) + EPS))
            weights = torch.where(mask, seg_weights, weights)
        weights[target >= 70] = 1.0
        return weights

    def forward(self, input, target):
        """Return the (optionally weighted) MSE between VGG features."""
        input_features = self.vgg(input)
        target_features = self.vgg(target)
        if not self.apply_weighting:
            return self.mse_loss(input_features, target_features)
        # NOTE(review): the 10-70 breakpoints look like image-space value
        # thresholds, but they are evaluated on VGG feature activations here.
        # Confirm this is intended.
        weights = self.custom_weight_4d(target_features.detach())
        return torch.mean(weights * (input_features - target_features)**2)
|
|
226
|
+
# vgg_loss = WeightedVGGLoss(weights_path="/mnt/wtx_weather_forecast/scx/stat/sat/sat2radar/vgg19-dcbb9e9d.pth",device=device).to(device)
|
|
227
|
+
|
|
228
|
+
import torch
|
|
229
|
+
import torch.nn as nn
|
|
230
|
+
import torch.nn.functional as F
|
|
231
|
+
class WeightedL1Loss(nn.Module):
    """Element-wise L1 loss weighted by the target's magnitude (no reduction).

    Args:
        segment_ranges: Iterable of (start, end, steepness, base_weight)
            segments. Within [start, end) the weight ramps up from
            ``base_weight``. ``None`` selects the default six 10-wide segments
            covering 10-70. An explicit empty list disables the ramps.
        device: Stored for API compatibility; not used by this class.
    """

    def __init__(self, segment_ranges=None, device=None):
        super().__init__()
        self.device = device
        # `is None` rather than `or`: an explicit [] must not be replaced.
        if segment_ranges is None:
            segment_ranges = [
                (10, 20, 0.5, 0.1), (20, 30, 0.4, 0.12),
                (30, 40, 0.3, 0.14), (40, 50, 0.2, 0.16),
                (50, 60, 0.1, 0.18), (60, 70, 0.05, 0.2)
            ]
        self.segment_ranges = segment_ranges

    def compute_weights(self, target):
        """Return a weight map shaped like ``target``.

        Values outside every segment get 0.1, values >= 70 get 1.0, and
        values inside a segment follow its saturating ramp.
        """
        EPS = 1e-6
        weights = torch.full_like(target, 0.1)  # default weight 0.1
        for start, end, steepness, base_w in self.segment_ranges:
            mask = (target >= start) & (target < end)
            x = ((target - start) / (end - start + EPS)).clamp(0, 1)
            delta_w = (1.0 - base_w) * (0.7 - steepness)
            seg_weights = base_w + delta_w * (x / (x + steepness*(1 - x) + EPS))
            weights = torch.where(mask, seg_weights, weights)
        weights[target >= 70] = 1.0
        return weights

    def forward(self, output, target):
        """
        Unreduced (reduction='none') computation:
        1. element-wise L1 loss
        2. dynamic weight map from the target
        3. return the weighted loss tensor (same shape as the inputs)
        """
        # Shape-preserving L1 loss, e.g. (B, C, H, W)
        l1_loss = F.l1_loss(output, target, reduction='none')
        weights = self.compute_weights(target.detach())  # weights carry no gradient
        return weights * l1_loss
|
|
264
|
+
# L1loss = WeightedL1Loss(device= device)
|
|
265
|
+
|
|
266
|
+
import torch
|
|
267
|
+
import torch.nn as nn
|
|
268
|
+
import torch.nn.functional as F
|
|
269
|
+
from segment_anything import sam_model_registry
|
|
270
|
+
class SAMLoss1(nn.Module):
    """Perceptual MSE loss on frozen SAM image-encoder features at native resolution.

    Inputs are expanded to 3 channels, ImageNet-normalized and snapped to a
    multiple of the patch size. The encoder's absolute position embedding is
    interpolated to match the resulting patch grid.
    """

    def __init__(self, model_type='vit_b', checkpoint_path=None):
        super().__init__()
        self.sam = sam_model_registry[model_type](checkpoint=None)
        if checkpoint_path:
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            self.sam.load_state_dict(state_dict)
        self.image_encoder = self.sam.image_encoder.eval()
        for param in self.image_encoder.parameters():
            param.requires_grad = False
        self.patch_size = self.image_encoder.patch_embed.proj.kernel_size[0]
        self.original_pos_embed = self.image_encoder.pos_embed
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        self.loss = nn.MSELoss()

    def adjust_pos_embed(self, x):
        """
        Resize pos_embed to match the input size.

        The (1, H_grid, W_grid, C) embedding is interpolated from
        ``original_pos_embed`` to (H // patch_size, W // patch_size). The
        original is restored when the grid matches it.

        Args:
            x: input tensor (B, C, H, W), already preprocessed
        """
        _, _, H, W = x.shape
        h, w = H // self.patch_size, W // self.patch_size

        current = getattr(self.image_encoder, "pos_embed", None)
        if current is not None and tuple(current.shape[1:3]) == (h, w):
            return  # already sized for this grid

        if (h, w) != tuple(self.original_pos_embed.shape[1:3]):
            pos_embed = F.interpolate(
                self.original_pos_embed.permute(0, 3, 1, 2),  # (1, C, H_orig, W_orig)
                size=(h, w),
                mode='bilinear',
                align_corners=False,
            ).permute(0, 2, 3, 1)  # (1, h, w, C)
            self.image_encoder.pos_embed = nn.Parameter(pos_embed, requires_grad=False)
        else:
            self.image_encoder.pos_embed = self.original_pos_embed

    def preprocess(self, x):
        """Preprocess: channel expansion + normalization + snap size to a patch multiple."""
        if x.shape[1] == 1:
            x = x.expand(-1, 3, -1, -1)  # single channel -> three channels
        x = (x - self.mean) / self.std
        _, _, H, W = x.shape
        p = self.patch_size
        # At least one patch per side, so tiny inputs do not collapse to size 0.
        H_new = max(H // p, 1) * p
        W_new = max(W // p, 1) * p
        if H != H_new or W != W_new:
            x = F.interpolate(
                x,
                size=(H_new, W_new),
                mode='bilinear',
                align_corners=False,
            )
        return x

    def forward(self, input, target):
        """Return the MSE between encoder features; gradients reach ``input``.

        Only the target branch runs under ``no_grad``; encoder weights are
        frozen via ``requires_grad=False``.
        """
        input = self.preprocess(input)
        target = self.preprocess(target)
        self.adjust_pos_embed(input)
        input_feat = self.image_encoder(input)
        with torch.no_grad():
            target_feat = self.image_encoder(target)
        return self.loss(input_feat, target_feat)
|
|
328
|
+
|
|
329
|
+
import torch
|
|
330
|
+
import torch.nn as nn
|
|
331
|
+
import torch.nn.functional as F
|
|
332
|
+
from segment_anything import sam_model_registry
|
|
333
|
+
class SAMLoss2(nn.Module):
    """SAM-encoder perceptual MSE loss without input normalization.

    Same as the normalized variant, except inputs are only expanded to 3
    channels and snapped to a patch multiple. Feed data already in the range
    the encoder expects.
    """

    def __init__(self, model_type='vit_b', checkpoint_path=None):
        super().__init__()
        self.sam = sam_model_registry[model_type](checkpoint=None)
        if checkpoint_path:
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            self.sam.load_state_dict(state_dict)
        self.image_encoder = self.sam.image_encoder.eval()
        for param in self.image_encoder.parameters():
            param.requires_grad = False
        self.patch_size = self.image_encoder.patch_embed.proj.kernel_size[0]
        self.original_pos_embed = self.image_encoder.pos_embed
        self.loss = nn.MSELoss()

    def adjust_pos_embed(self, x):
        """
        Resize pos_embed to match the input size.

        The (1, H_grid, W_grid, C) embedding is interpolated from
        ``original_pos_embed`` to (H // patch_size, W // patch_size). The
        original is restored when the grid matches it.

        Args:
            x: input tensor (B, C, H, W), already preprocessed
        """
        _, _, H, W = x.shape
        h, w = H // self.patch_size, W // self.patch_size

        current = getattr(self.image_encoder, "pos_embed", None)
        if current is not None and tuple(current.shape[1:3]) == (h, w):
            return  # already sized for this grid

        if (h, w) != tuple(self.original_pos_embed.shape[1:3]):
            pos_embed = F.interpolate(
                self.original_pos_embed.permute(0, 3, 1, 2),  # (1, C, H_orig, W_orig)
                size=(h, w),
                mode='bilinear',
                align_corners=False,
            ).permute(0, 2, 3, 1)  # (1, h, w, C)
            self.image_encoder.pos_embed = nn.Parameter(pos_embed, requires_grad=False)
        else:
            self.image_encoder.pos_embed = self.original_pos_embed

    def preprocess(self, x):
        """Preprocess: channel expansion + snap size to a patch multiple (no normalization)."""
        if x.shape[1] == 1:
            x = x.expand(-1, 3, -1, -1)  # single channel -> three channels
        _, _, H, W = x.shape
        p = self.patch_size
        # At least one patch per side, so tiny inputs do not collapse to size 0.
        H_new = max(H // p, 1) * p
        W_new = max(W // p, 1) * p
        if H != H_new or W != W_new:
            x = F.interpolate(
                x,
                size=(H_new, W_new),
                mode='bilinear',
                align_corners=False,
            )
        return x

    def forward(self, input, target):
        """Return the MSE between encoder features; gradients reach ``input``.

        Neither input nor target is normalized. Only the target branch runs
        under ``no_grad``; encoder weights are frozen via ``requires_grad=False``.
        """
        input = self.preprocess(input)
        target = self.preprocess(target)
        self.adjust_pos_embed(input)
        input_feat = self.image_encoder(input)
        with torch.no_grad():
            target_feat = self.image_encoder(target)
        return self.loss(input_feat, target_feat)
|
|
391
|
+
#saml = SAMLoss(checkpoint_path = "/mnt/wtx_weather_forecast/scx/stat/sat/sat2radar/sam_vit_b_01ec64.pth.1").to(device)
|