shancx 1.8.92__py3-none-any.whl → 1.9.33.218__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shancx/3D/__init__.py +25 -0
- shancx/Algo/Class.py +11 -0
- shancx/Algo/CudaPrefetcher1.py +112 -0
- shancx/Algo/Fake_image.py +24 -0
- shancx/Algo/Hsml.py +391 -0
- shancx/Algo/L2Loss.py +10 -0
- shancx/Algo/MetricTracker.py +132 -0
- shancx/Algo/Normalize.py +66 -0
- shancx/Algo/OptimizerWScheduler.py +38 -0
- shancx/Algo/Rmageresize.py +79 -0
- shancx/Algo/Savemodel.py +33 -0
- shancx/Algo/SmoothL1_losses.py +27 -0
- shancx/Algo/Tqdm.py +62 -0
- shancx/Algo/__init__.py +121 -0
- shancx/Algo/checknan.py +28 -0
- shancx/Algo/iouJU.py +83 -0
- shancx/Algo/mask.py +25 -0
- shancx/Algo/psnr.py +9 -0
- shancx/Algo/ssim.py +70 -0
- shancx/Algo/structural_similarity.py +308 -0
- shancx/Algo/tool.py +704 -0
- shancx/Calmetrics/__init__.py +97 -0
- shancx/Calmetrics/calmetrics.py +14 -0
- shancx/Calmetrics/calmetricsmatrixLib.py +147 -0
- shancx/Calmetrics/rmseR2score.py +35 -0
- shancx/Clip/__init__.py +50 -0
- shancx/Cmd.py +126 -0
- shancx/Config_.py +26 -0
- shancx/Df/DataFrame.py +11 -2
- shancx/Df/__init__.py +17 -0
- shancx/Df/tool.py +0 -0
- shancx/Diffm/Psamples.py +18 -0
- shancx/Diffm/__init__.py +0 -0
- shancx/Diffm/test.py +207 -0
- shancx/Doc/__init__.py +214 -0
- shancx/E/__init__.py +178 -152
- shancx/Fillmiss/__init__.py +0 -0
- shancx/Fillmiss/imgidwJU.py +46 -0
- shancx/Fillmiss/imgidwLatLonJU.py +82 -0
- shancx/Gpu/__init__.py +55 -0
- shancx/H9/__init__.py +126 -0
- shancx/H9/ahi_read_hsd.py +877 -0
- shancx/H9/ahisearchtable.py +298 -0
- shancx/H9/geometry.py +2439 -0
- shancx/Hug/__init__.py +81 -0
- shancx/Inst.py +22 -0
- shancx/Lib.py +31 -0
- shancx/Mos/__init__.py +37 -0
- shancx/NN/__init__.py +235 -106
- shancx/Path1.py +161 -0
- shancx/Plot/GlobMap.py +276 -116
- shancx/Plot/__init__.py +491 -1
- shancx/Plot/draw_day_CR_PNG.py +4 -21
- shancx/Plot/exam.py +116 -0
- shancx/Plot/plotGlobal.py +325 -0
- shancx/{radar_nmc.py → Plot/radarNmc.py} +4 -34
- shancx/{subplots_single_china_map.py → Plot/single_china_map.py} +1 -1
- shancx/Point.py +46 -0
- shancx/QC.py +223 -0
- shancx/RdPzl/__init__.py +32 -0
- shancx/Read.py +72 -0
- shancx/Resize.py +79 -0
- shancx/SN/__init__.py +62 -123
- shancx/Time/GetTime.py +9 -3
- shancx/Time/__init__.py +66 -1
- shancx/Time/timeCycle.py +302 -0
- shancx/Time/tool.py +0 -0
- shancx/Train/__init__.py +74 -0
- shancx/Train/makelist.py +187 -0
- shancx/Train/multiGpu.py +27 -0
- shancx/Train/prepare.py +161 -0
- shancx/Train/renet50.py +157 -0
- shancx/ZR.py +12 -0
- shancx/__init__.py +333 -262
- shancx/args.py +27 -0
- shancx/bak.py +768 -0
- shancx/df2database.py +62 -2
- shancx/geosProj.py +80 -0
- shancx/info.py +38 -0
- shancx/netdfJU.py +231 -0
- shancx/sendM.py +59 -0
- shancx/tensBoard/__init__.py +28 -0
- shancx/wait.py +246 -0
- {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/METADATA +15 -5
- shancx-1.9.33.218.dist-info/RECORD +91 -0
- {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/WHEEL +1 -1
- my_timer_decorator/__init__.py +0 -10
- shancx/Dsalgor/__init__.py +0 -19
- shancx/E/DFGRRIB.py +0 -30
- shancx/EN/DFGRRIB.py +0 -30
- shancx/EN/__init__.py +0 -148
- shancx/FileRead.py +0 -44
- shancx/Gray2RGB.py +0 -86
- shancx/M/__init__.py +0 -137
- shancx/MN/__init__.py +0 -133
- shancx/N/__init__.py +0 -131
- shancx/Plot/draw_day_CR_PNGUS.py +0 -206
- shancx/Plot/draw_day_CR_SVG.py +0 -275
- shancx/Plot/draw_day_pre_PNGUS.py +0 -205
- shancx/Plot/glob_nation_map.py +0 -116
- shancx/Plot/radar_nmc.py +0 -61
- shancx/Plot/radar_nmc_china_map_compare1.py +0 -50
- shancx/Plot/radar_nmc_china_map_f.py +0 -121
- shancx/Plot/radar_nmc_us_map_f.py +0 -128
- shancx/Plot/subplots_compare_devlop.py +0 -36
- shancx/Plot/subplots_single_china_map.py +0 -45
- shancx/S/__init__.py +0 -138
- shancx/W/__init__.py +0 -132
- shancx/WN/__init__.py +0 -132
- shancx/code.py +0 -331
- shancx/draw_day_CR_PNG.py +0 -200
- shancx/draw_day_CR_PNGUS.py +0 -206
- shancx/draw_day_CR_SVG.py +0 -275
- shancx/draw_day_pre_PNGUS.py +0 -205
- shancx/makenetCDFN.py +0 -42
- shancx/mkIMGSCX.py +0 -92
- shancx/netCDF.py +0 -130
- shancx/radar_nmc_china_map_compare1.py +0 -50
- shancx/radar_nmc_china_map_f.py +0 -125
- shancx/radar_nmc_us_map_f.py +0 -67
- shancx/subplots_compare_devlop.py +0 -36
- shancx/tool.py +0 -18
- shancx/user/H8mess.py +0 -317
- shancx/user/__init__.py +0 -137
- shancx/user/cinradHJN.py +0 -496
- shancx/user/examMeso.py +0 -293
- shancx/user/hjnDAAS.py +0 -26
- shancx/user/hjnFTP.py +0 -81
- shancx/user/hjnGIS.py +0 -320
- shancx/user/hjnGPU.py +0 -21
- shancx/user/hjnIDW.py +0 -68
- shancx/user/hjnKDTree.py +0 -75
- shancx/user/hjnLAPSTransform.py +0 -47
- shancx/user/hjnMiscellaneous.py +0 -182
- shancx/user/hjnProj.py +0 -162
- shancx/user/inotify.py +0 -41
- shancx/user/matplotlibMess.py +0 -87
- shancx/user/mkNCHJN.py +0 -623
- shancx/user/newTypeRadar.py +0 -492
- shancx/user/test.py +0 -6
- shancx/user/tlogP.py +0 -129
- shancx/util_log.py +0 -33
- shancx/wtx/H8mess.py +0 -315
- shancx/wtx/__init__.py +0 -151
- shancx/wtx/cinradHJN.py +0 -496
- shancx/wtx/colormap.py +0 -64
- shancx/wtx/examMeso.py +0 -298
- shancx/wtx/hjnDAAS.py +0 -26
- shancx/wtx/hjnFTP.py +0 -81
- shancx/wtx/hjnGIS.py +0 -330
- shancx/wtx/hjnGPU.py +0 -21
- shancx/wtx/hjnIDW.py +0 -68
- shancx/wtx/hjnKDTree.py +0 -75
- shancx/wtx/hjnLAPSTransform.py +0 -47
- shancx/wtx/hjnLog.py +0 -78
- shancx/wtx/hjnMiscellaneous.py +0 -201
- shancx/wtx/hjnProj.py +0 -161
- shancx/wtx/inotify.py +0 -41
- shancx/wtx/matplotlibMess.py +0 -87
- shancx/wtx/mkNCHJN.py +0 -613
- shancx/wtx/newTypeRadar.py +0 -492
- shancx/wtx/test.py +0 -6
- shancx/wtx/tlogP.py +0 -129
- shancx-1.8.92.dist-info/RECORD +0 -99
- /shancx/{Dsalgor → Algo}/dsalgor.py +0 -0
- {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/top_level.txt +0 -0
shancx/Algo/tool.py
ADDED
|
@@ -0,0 +1,704 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.optim as optim
|
|
4
|
+
from torch.utils.data import DataLoader, Dataset
|
|
5
|
+
import torchvision.transforms as transforms
|
|
6
|
+
from torchvision import transforms
|
|
7
|
+
import os
|
|
8
|
+
from PIL import Image
|
|
9
|
+
import numpy as np
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
import glob
|
|
12
|
+
|
|
13
|
+
# 简化的 ResUNet 架构
|
|
14
|
+
class ResUNet(nn.Module):
    """Simplified ResUNet: an encoder followed directly by a decoder.

    Args:
        in_channels: Number of channels of the input tensor, forwarded to
            ``ResNetEncoder``.
        out_channels: Number of channels of the output tensor, forwarded to
            ``UNetDecoder``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.encoder = ResNetEncoder(in_channels)
        self.decoder = UNetDecoder(out_channels)

    def forward(self, x):
        """Run ``x`` through the encoder and feed the features to the decoder."""
        return self.decoder(self.encoder(x))
|
|
24
|
+
|
|
25
|
+
class ResNetEncoder(nn.Module):
    """Stack of three 3x3 convolutions (in_channels -> 64 -> 128 -> 256).

    Every convolution uses ``padding=1`` with the default stride and is
    followed by a ReLU.

    Args:
        in_channels: Number of channels of the input tensor.
    """

    def __init__(self, in_channels):
        super().__init__()
        conv_kwargs = {"kernel_size": 3, "padding": 1}
        self.conv1 = nn.Conv2d(in_channels, 64, **conv_kwargs)
        self.conv2 = nn.Conv2d(64, 128, **conv_kwargs)
        self.conv3 = nn.Conv2d(128, 256, **conv_kwargs)
        # More layers can be appended here.

    def forward(self, x):
        """Apply conv1..conv3 in order, each followed by ReLU."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = torch.relu(conv(features))
        return features
|
|
38
|
+
|
|
39
|
+
class UNetDecoder(nn.Module):
    """Decoder: one transposed convolution followed by a 1x1 projection.

    The transposed convolution (kernel 2, stride 2) maps 256 channels to 128;
    the 1x1 convolution then maps 128 channels to ``out_channels``.

    Args:
        out_channels: Number of channels of the output tensor.
    """

    def __init__(self, out_channels):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.final_conv = nn.Conv2d(128, out_channels, kernel_size=1)

    def forward(self, x):
        """Upsample ``x`` with ``upconv`` and project it with ``final_conv``."""
        upsampled = self.upconv(x)
        return self.final_conv(upsampled)
|
|
49
|
+
|
|
50
|
+
# 自定义数据集类
|
|
51
|
+
class SatelliteDataset(Dataset):
    """Dataset of (image, mask) pairs read from two parallel directories.

    Images and masks are paired by position in the *sorted* directory
    listings, so an image and its mask must sort to the same position
    (typically by sharing a file name).

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the segmentation masks.
        transform: Optional callable applied to the image and, separately,
            to the mask. NOTE(review): a transform with random behaviour
            would be sampled independently for image and mask.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # os.listdir() returns entries in arbitrary order; sort both listings
        # so that index i refers to the same sample in both directories.
        self.image_filenames = sorted(os.listdir(image_dir))
        self.mask_filenames = sorted(os.listdir(mask_dir))
        self.transform = transform

    def __len__(self):
        """Return the number of entries found in ``image_dir``."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as an ``(image, mask)`` tuple.

        The image is converted to RGB and the mask to single-channel
        grayscale before the optional transform is applied to each.
        """
        img_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])

        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed right away instead of waiting for garbage collection.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert("L")  # single-channel grayscale

        if self.transform:
            img = self.transform(img)
            mask = self.transform(mask)

        return img, mask
|
|
74
|
+
|
|
75
|
+
# 数据增强与预处理
|
|
76
|
+
# Default locations of the training data; replace with real directories.
image_dir = 'path/to/your/images'  # input image directory
mask_dir = 'path/to/your/masks'  # segmentation mask directory
num_epochs = 20


def build_transform(size=(256, 256)):
    """Return the preprocessing pipeline: resize to ``size``, then ToTensor.

    ``ToTensor`` converts a PIL image to a float tensor scaled to [0, 1].
    """
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])


def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader`` and return the mean loss.

    Args:
        model: Module mapping an image batch to logits of shape (B, C, H, W).
        dataloader: Iterable of ``(imgs, masks)`` batches.
        criterion: Loss taking ``(logits, targets)`` of identical shape.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        The average of the per-batch loss values as a float.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0

    for imgs, masks in dataloader:
        imgs = imgs.to(device)
        masks = masks.to(device).float()

        # Forward pass.
        outputs = model(imgs)

        # The loss needs logits and targets of identical shape. Keep the
        # channel axis on both, and resize the logits when the model's output
        # resolution differs from the mask resolution (the decoder above
        # upsamples by 2 while the encoder never downsamples).
        if outputs.shape[-2:] != masks.shape[-2:]:
            outputs = nn.functional.interpolate(
                outputs, size=masks.shape[-2:], mode='bilinear', align_corners=False
            )
        loss = criterion(outputs, masks)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")
    return running_loss / num_batches


def main():
    """Train ResUNet on the configured directories and save its weights.

    Kept out of module scope so that importing this file does not touch the
    filesystem or start a training run.
    """
    # Load the dataset.
    dataset = SatelliteDataset(image_dir, mask_dir, build_transform())
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Model, loss and optimizer: 3 input channels (RGB), 1 output channel
    # (binary segmentation logits).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ResUNet(in_channels=3, out_channels=1).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Training loop.
    for epoch in range(num_epochs):
        avg_loss = train_one_epoch(model, dataloader, criterion, optimizer, device)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # Save the trained weights.
    torch.save(model.state_dict(), 'resunet_model.pth')


if __name__ == "__main__":
    main()
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
""" --------------Resunet-------------------Dataset
|
|
119
|
+
from torch.utils.data import DataLoader
|
|
120
|
+
from torch.utils.data import Dataset
|
|
121
|
+
class ToTensorTarget(object):
|
|
122
|
+
def __call__(self, sample):
|
|
123
|
+
sat_img, map_img = sample["sat_img"], sample["map_img"]
|
|
124
|
+
|
|
125
|
+
# swap color axis because
|
|
126
|
+
# numpy image: H x W x C
|
|
127
|
+
# torch image: C X H X W
|
|
128
|
+
|
|
129
|
+
return {
|
|
130
|
+
"sat_img": transforms.functional.to_tensor(sat_img).permute(1,2,0), #(H, W, C) -->transforms.functional.to_tensor(sat_img)-->(C, H, W) --.permute(1,2,0)-->(H, W, C) H 是图像的高度,W 是宽度,C 是通道数
|
|
131
|
+
"map_img": torch.from_numpy(map_img).unsqueeze(0).float(),
|
|
132
|
+
} # unsqueeze for the channel dimension
|
|
133
|
+
|
|
134
|
+
# 自定义数据集类
|
|
135
|
+
class npyDataset_regression(Dataset):
|
|
136
|
+
def __init__(self, args, train=True, transform=None):
|
|
137
|
+
self.train = train
|
|
138
|
+
self.path = args.train if train else args.valid
|
|
139
|
+
self.mask_list = glob.glob(
|
|
140
|
+
os.path.join(self.path, "mask", "*.npy"), recursive=True
|
|
141
|
+
)
|
|
142
|
+
self.transform = transform
|
|
143
|
+
def __len__(self):
|
|
144
|
+
return len(self.mask_list)
|
|
145
|
+
def __getitem__(self, idx):
|
|
146
|
+
try:
|
|
147
|
+
maskpath = self.mask_list[idx]
|
|
148
|
+
image = np.load(maskpath.replace("mask", "input")).astype(np.float32)
|
|
149
|
+
image = image[-2:,:,:]
|
|
150
|
+
image[image<15] = np.nan
|
|
151
|
+
### 5-15dbz
|
|
152
|
+
#image[image>20] = np.nan
|
|
153
|
+
#image[image<5] = np.nan
|
|
154
|
+
#mean = np.float32(9.81645766)
|
|
155
|
+
#std = np.float32(10.172995)
|
|
156
|
+
image_mask = image[-1,:,:].copy().reshape(256,256)
|
|
157
|
+
image_mask[~np.isnan(image_mask)]=1
|
|
158
|
+
#tmp = image[-2,:,:].reshape((256,256)) * image_mask
|
|
159
|
+
#image[-2,:,:] = tmp.reshape((1,256,256))
|
|
160
|
+
mask = np.load(maskpath).astype(np.float32)
|
|
161
|
+
mask = mask * image_mask
|
|
162
|
+
image[np.isnan(image)]=0
|
|
163
|
+
sample = {"x_img": image, "map_img": mask}
|
|
164
|
+
if self.transform:
|
|
165
|
+
sample = self.transform(sample)
|
|
166
|
+
sample['maskpath'] = maskpath
|
|
167
|
+
return sample
|
|
168
|
+
except Exception as e:
|
|
169
|
+
print(f"Error loading data at index {index}: {str(e)}")
|
|
170
|
+
# 可以选择跳过当前样本或者返回一个默认值
|
|
171
|
+
print(traceback.format_exc())
|
|
172
|
+
loggers.info(traceback.format_exc())
|
|
173
|
+
return None
|
|
174
|
+
|
|
175
|
+
dataset = npyDataset_regression(args, transform=transforms.Compose([ToTensorTarget()])) #=transforms.Compose([dataloader_radar10.ToTensorTarget()])
|
|
176
|
+
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
|
|
177
|
+
loader = tqdm(dataloader, desc="training")
|
|
178
|
+
for idx, data in enumerate(loader):
|
|
179
|
+
inputs = data["x_img"].cuda()
|
|
180
|
+
labels = data["map_img"].cuda()
|
|
181
|
+
optimizer.zero_grad()
|
|
182
|
+
outputs = model(inputs)
|
|
183
|
+
"""
|
|
184
|
+
|
|
185
|
+
"""" ------------Gan-------------
|
|
186
|
+
import torch
|
|
187
|
+
import torch.nn as nn
|
|
188
|
+
import torch.optim as optim
|
|
189
|
+
from torchvision import datasets, transforms
|
|
190
|
+
from torchvision.utils import save_image
|
|
191
|
+
import os
|
|
192
|
+
|
|
193
|
+
def save_model_weights(generator, discriminator, epoch):
|
|
194
|
+
torch.save(generator.state_dict(), f'generator_epoch_{epoch}.pt')
|
|
195
|
+
torch.save(discriminator.state_dict(), f'discriminator_epoch_{epoch}.pt')
|
|
196
|
+
|
|
197
|
+
def load_generator_weights(generator, weights_path, device='cuda:0'):
|
|
198
|
+
generator.load_state_dict(torch.load(weights_path, map_location=device))
|
|
199
|
+
generator.eval() # Ensure generator is in evaluation mode after loading weights
|
|
200
|
+
|
|
201
|
+
def generate_images(generator, num_images, device='cuda:0'):
|
|
202
|
+
generator.eval() # Set the generator to evaluation mode
|
|
203
|
+
noise = torch.randn(num_images, 100).to(device) # Generate random noise
|
|
204
|
+
with torch.no_grad():
|
|
205
|
+
generated_images = generator(noise).cpu() # Generate images from noise
|
|
206
|
+
return generated_images
|
|
207
|
+
|
|
208
|
+
# Generator model
|
|
209
|
+
class Generator(nn.Module):
|
|
210
|
+
def __init__(self):
|
|
211
|
+
super(Generator, self).__init__()
|
|
212
|
+
self.model = nn.Sequential(
|
|
213
|
+
nn.Linear(100, 256),
|
|
214
|
+
nn.ReLU(),
|
|
215
|
+
nn.Linear(256, 64 * 64),
|
|
216
|
+
nn.Tanh() # Output range [-1, 1]
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
def forward(self, x):
|
|
220
|
+
return self.model(x).view(-1, 1, 64, 64)
|
|
221
|
+
|
|
222
|
+
# Discriminator model
|
|
223
|
+
class Discriminator(nn.Module):
|
|
224
|
+
def __init__(self):
|
|
225
|
+
super(Discriminator, self).__init__()
|
|
226
|
+
self.model = nn.Sequential(
|
|
227
|
+
nn.Flatten(),
|
|
228
|
+
nn.Linear(64 * 64, 256),
|
|
229
|
+
nn.ReLU(),
|
|
230
|
+
nn.Linear(256, 1),
|
|
231
|
+
nn.Sigmoid() # Output [0, 1]
|
|
232
|
+
)
|
|
233
|
+
|
|
234
|
+
def forward(self, x):
|
|
235
|
+
return self.model(x)
|
|
236
|
+
|
|
237
|
+
# Training function
|
|
238
|
+
def train_gan(generator, discriminator, g_optimizer, d_optimizer, criterion, dataloader, epochs=50, device='cuda:0'):
|
|
239
|
+
for epoch in range(epochs):
|
|
240
|
+
for i, (real_images, _) in enumerate(dataloader):
|
|
241
|
+
batch_size = real_images.size(0)
|
|
242
|
+
real_images = real_images.to(device)
|
|
243
|
+
real_images = (real_images - 0.5) * 2 # Normalize to [-1, 1]
|
|
244
|
+
|
|
245
|
+
real_labels = torch.ones(batch_size, 1).to(device)
|
|
246
|
+
fake_labels = torch.zeros(batch_size, 1).to(device)
|
|
247
|
+
|
|
248
|
+
# Train discriminator
|
|
249
|
+
d_optimizer.zero_grad()
|
|
250
|
+
|
|
251
|
+
# Real image loss
|
|
252
|
+
outputs = discriminator(real_images)
|
|
253
|
+
d_loss_real = criterion(outputs, real_labels)
|
|
254
|
+
d_loss_real.backward()
|
|
255
|
+
|
|
256
|
+
# Fake image loss
|
|
257
|
+
noise = torch.randn(batch_size, 100).to(device)
|
|
258
|
+
fake_images = generator(noise)
|
|
259
|
+
outputs = discriminator(fake_images.detach())
|
|
260
|
+
d_loss_fake = criterion(outputs, fake_labels)
|
|
261
|
+
d_loss_fake.backward()
|
|
262
|
+
d_optimizer.step()
|
|
263
|
+
|
|
264
|
+
# Train generator
|
|
265
|
+
g_optimizer.zero_grad()
|
|
266
|
+
outputs = discriminator(fake_images)
|
|
267
|
+
g_loss = criterion(outputs, real_labels)
|
|
268
|
+
g_loss.backward()
|
|
269
|
+
g_optimizer.step()
|
|
270
|
+
|
|
271
|
+
# Print loss
|
|
272
|
+
if (i + 1) % 100 == 0:
|
|
273
|
+
print(f'Epoch [{epoch + 1}/{epochs}], Step [{i + 1}/{len(dataloader)}], '
|
|
274
|
+
f'D Loss: {d_loss_real.item() + d_loss_fake.item()}, G Loss: {g_loss.item()}')
|
|
275
|
+
|
|
276
|
+
# Save generated images and model weights every 10 epochs
|
|
277
|
+
if (epoch + 1) % 10 == 0:
|
|
278
|
+
save_image(fake_images.detach(), f'gan_images/epoch_{epoch + 1}.png')
|
|
279
|
+
save_model_weights(generator, discriminator, epoch + 1)
|
|
280
|
+
|
|
281
|
+
# Data preprocessing and loader
|
|
282
|
+
transform = transforms.Compose([
|
|
283
|
+
transforms.Resize(64),
|
|
284
|
+
transforms.ToTensor(),
|
|
285
|
+
transforms.Normalize((0.5,), (0.5,)),
|
|
286
|
+
])
|
|
287
|
+
|
|
288
|
+
if __name__ == '__main__':
|
|
289
|
+
# Use the second GPU (cuda:1) when CUDA is available, otherwise fall back to the CPU
|
|
290
|
+
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
|
|
291
|
+
generator = Generator().to(device)
|
|
292
|
+
discriminator = Discriminator().to(device)
|
|
293
|
+
|
|
294
|
+
# Load dataset
|
|
295
|
+
dataset = datasets.MNIST(root='data', train=True, transform=transform, download=True)
|
|
296
|
+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64,num_workers=20,shuffle=True) #num_workers=10,
|
|
297
|
+
|
|
298
|
+
# Define loss function and optimizers
|
|
299
|
+
criterion = nn.BCELoss()
|
|
300
|
+
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
|
|
301
|
+
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
|
|
302
|
+
|
|
303
|
+
# Create directory for saving images
|
|
304
|
+
os.makedirs('gan_images', exist_ok=True)
|
|
305
|
+
|
|
306
|
+
# Train GAN
|
|
307
|
+
train_gan(generator, discriminator, g_optimizer, d_optimizer, criterion, dataloader, epochs=10, device=device)
|
|
308
|
+
|
|
309
|
+
# Load generator weights (replace with actual path)
|
|
310
|
+
load_generator_weights(generator, 'generator_epoch_10.pt', device=device)
|
|
311
|
+
|
|
312
|
+
# Generate images
|
|
313
|
+
generated_images = generate_images(generator, 10, device=device)
|
|
314
|
+
|
|
315
|
+
# Save generated images
|
|
316
|
+
save_image(generated_images, 'generated_images.png', nrow=10, normalize=True)
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
"""
|
|
320
|
+
|
|
321
|
+
"""
|
|
322
|
+
######################################## diffusion model #########################################
|
|
323
|
+
import os
|
|
324
|
+
from PIL import Image
|
|
325
|
+
import torch
|
|
326
|
+
import torch.nn as nn
|
|
327
|
+
import torch.optim as optim
|
|
328
|
+
from torchvision.transforms import Compose, ToTensor, Resize
|
|
329
|
+
from torchvision.utils import save_image
|
|
330
|
+
from torch.utils.data import DataLoader, Dataset
|
|
331
|
+
import numpy as np
|
|
332
|
+
|
|
333
|
+
# 自定义数据集类
|
|
334
|
+
class ImageDataset(Dataset):
|
|
335
|
+
def __init__(self, folder_path, low_res_size=(64, 64), high_res_size=(256, 256)):
|
|
336
|
+
super().__init__()
|
|
337
|
+
self.image_paths = [os.path.join(folder_path, fname) for fname in os.listdir(folder_path) if fname.endswith(('png', 'jpg', 'jpeg'))]
|
|
338
|
+
self.low_res_transform = Compose([
|
|
339
|
+
Resize(low_res_size),
|
|
340
|
+
ToTensor()
|
|
341
|
+
])
|
|
342
|
+
self.high_res_transform = Compose([
|
|
343
|
+
Resize(high_res_size),
|
|
344
|
+
ToTensor()
|
|
345
|
+
])
|
|
346
|
+
|
|
347
|
+
def __len__(self):
|
|
348
|
+
return len(self.image_paths)
|
|
349
|
+
|
|
350
|
+
def __getitem__(self, idx):
|
|
351
|
+
image = Image.open(self.image_paths[idx]).convert("RGB")
|
|
352
|
+
low_res_image = self.low_res_transform(image)
|
|
353
|
+
high_res_image = self.high_res_transform(image)
|
|
354
|
+
return low_res_image, high_res_image
|
|
355
|
+
|
|
356
|
+
# 定义扩散模型中的去噪网络
|
|
357
|
+
class DenoisingUNet(nn.Module):
|
|
358
|
+
def __init__(self):
|
|
359
|
+
super(DenoisingUNet, self).__init__()
|
|
360
|
+
self.encoder = nn.Sequential(
|
|
361
|
+
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
|
|
362
|
+
nn.ReLU(),
|
|
363
|
+
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
|
|
364
|
+
nn.ReLU()
|
|
365
|
+
)
|
|
366
|
+
self.middle = nn.Sequential(
|
|
367
|
+
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
|
|
368
|
+
nn.ReLU()
|
|
369
|
+
)
|
|
370
|
+
self.decoder = nn.Sequential(
|
|
371
|
+
nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
|
|
372
|
+
nn.ReLU(),
|
|
373
|
+
nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
def forward(self, x):
|
|
377
|
+
encoded = self.encoder(x)
|
|
378
|
+
middle = self.middle(encoded)
|
|
379
|
+
decoded = self.decoder(middle)
|
|
380
|
+
return decoded
|
|
381
|
+
|
|
382
|
+
# 定义扩散过程的正向和反向过程
|
|
383
|
+
class DiffusionModel:
|
|
384
|
+
def __init__(self, denoising_model, timesteps=1000):
|
|
385
|
+
self.denoising_model = denoising_model
|
|
386
|
+
self.timesteps = timesteps
|
|
387
|
+
self.beta_schedule = self._linear_beta_schedule()
|
|
388
|
+
self.alphas = 1.0 - self.beta_schedule
|
|
389
|
+
self.alpha_cumprod = torch.tensor(np.cumprod(self.alphas).astype(np.float32)) # 转为Tensor类型
|
|
390
|
+
|
|
391
|
+
def _linear_beta_schedule(self):
|
|
392
|
+
return np.linspace(1e-4, 0.02, self.timesteps).astype(np.float32)
|
|
393
|
+
|
|
394
|
+
def forward_diffusion(self, x, t):
|
|
395
|
+
alpha_t = self.alpha_cumprod[t]
|
|
396
|
+
noise = torch.randn_like(x)
|
|
397
|
+
return torch.sqrt(alpha_t) * x + torch.sqrt(1 - alpha_t) * noise, noise
|
|
398
|
+
|
|
399
|
+
def reverse_diffusion(self, x, t):
|
|
400
|
+
alpha_t = self.alpha_cumprod[t]
|
|
401
|
+
pred_noise = self.denoising_model(x)
|
|
402
|
+
return (x - torch.sqrt(1 - alpha_t) * pred_noise) / torch.sqrt(alpha_t)
|
|
403
|
+
|
|
404
|
+
# 训练函数
|
|
405
|
+
def train_diffusion_model(model, dataloader, epochs=50, lr=1e-4):
|
|
406
|
+
optimizer = optim.Adam(model.denoising_model.parameters(), lr=lr)
|
|
407
|
+
loss_fn = nn.MSELoss()
|
|
408
|
+
|
|
409
|
+
for epoch in range(epochs):
|
|
410
|
+
epoch_loss = 0
|
|
411
|
+
for low_res, high_res in dataloader:
|
|
412
|
+
high_res = high_res.to('cuda') # 使用GPU
|
|
413
|
+
low_res = low_res.to('cuda') # 使用GPU
|
|
414
|
+
t = torch.randint(0, model.timesteps, (1,)).item() # 随机选择时间步
|
|
415
|
+
noisy_image, noise = model.forward_diffusion(high_res, t)
|
|
416
|
+
noisy_image = noisy_image.to('cuda') # 确保噪声图像在GPU上
|
|
417
|
+
predicted_noise = model.denoising_model(noisy_image)
|
|
418
|
+
|
|
419
|
+
loss = loss_fn(predicted_noise, noise)
|
|
420
|
+
optimizer.zero_grad()
|
|
421
|
+
loss.backward()
|
|
422
|
+
optimizer.step()
|
|
423
|
+
epoch_loss += loss.item()
|
|
424
|
+
print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss / len(dataloader)}")
|
|
425
|
+
|
|
426
|
+
# 示例:生成高分辨率图像
|
|
427
|
+
def generate_high_res_image(model, low_res_image):
|
|
428
|
+
t = model.timesteps - 1
|
|
429
|
+
x = low_res_image.to('cuda') # 使用GPU
|
|
430
|
+
for step in reversed(range(t)):
|
|
431
|
+
x = model.reverse_diffusion(x, step)
|
|
432
|
+
return x.cpu() # 返回到CPU
|
|
433
|
+
|
|
434
|
+
# 主函数
|
|
435
|
+
if __name__ == "__main__":
|
|
436
|
+
# 模型初始化
|
|
437
|
+
unet = DenoisingUNet().to('cuda') # 将模型加载到GPU
|
|
438
|
+
diffusion_model = DiffusionModel(denoising_model=unet)
|
|
439
|
+
|
|
440
|
+
# 数据集路径
|
|
441
|
+
folder_path = "/mnt/wtx_weather_forecast/scx/diffdataset/output_dataset/HR" # 替换为包含图像的文件夹路径
|
|
442
|
+
dataset = ImageDataset(folder_path)
|
|
443
|
+
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
|
|
444
|
+
|
|
445
|
+
# 训练模型
|
|
446
|
+
print("开始训练模型...")
|
|
447
|
+
train_diffusion_model(diffusion_model, dataloader)
|
|
448
|
+
|
|
449
|
+
# 生成高分辨率图像
|
|
450
|
+
print("生成高分辨率图像...")
|
|
451
|
+
low_res_input, _ = dataset[0] # 示例图像
|
|
452
|
+
low_res_input = low_res_input.unsqueeze(0) # 增加批次维度
|
|
453
|
+
high_res_output = generate_high_res_image(diffusion_model, low_res_input)
|
|
454
|
+
save_image(high_res_output, "high_res_output.png")
|
|
455
|
+
print("高分辨率图像已保存为 'high_res_output.png'")
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
"""
|
|
460
|
+
|
|
461
|
+
"""
|
|
462
|
+
from torch.utils.data import random_split, DataLoader, Dataset, Subset
|
|
463
|
+
"""
|
|
464
|
+
|
|
465
|
+
"""
|
|
466
|
+
##############潜在空间,gan 模型
|
|
467
|
+
import os
|
|
468
|
+
from PIL import Image
|
|
469
|
+
import torch
|
|
470
|
+
import torch.nn as nn
|
|
471
|
+
import torch.optim as optim
|
|
472
|
+
from torchvision.transforms import Compose, ToTensor, Resize
|
|
473
|
+
from torch.utils.data import DataLoader, Dataset
|
|
474
|
+
import torch.nn.functional as F
|
|
475
|
+
|
|
476
|
+
# 确保数据和模型均加载到指定 GPU 上
|
|
477
|
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
478
|
+
# 自定义数据集类
|
|
479
|
+
class ImageDataset(Dataset):
|
|
480
|
+
def __init__(self, folder_path, low_res_size=(64, 64), high_res_size=(256, 256)):
|
|
481
|
+
super().__init__()
|
|
482
|
+
self.image_paths = [os.path.join(folder_path, fname) for fname in os.listdir(folder_path) if fname.endswith(('png', 'jpg', 'jpeg'))]
|
|
483
|
+
self.low_res_transform = Compose([
|
|
484
|
+
Resize(low_res_size),
|
|
485
|
+
ToTensor()
|
|
486
|
+
])
|
|
487
|
+
self.high_res_transform = Compose([
|
|
488
|
+
Resize(high_res_size),
|
|
489
|
+
ToTensor()
|
|
490
|
+
])
|
|
491
|
+
|
|
492
|
+
def __len__(self):
|
|
493
|
+
return len(self.image_paths)
|
|
494
|
+
|
|
495
|
+
def __getitem__(self, idx):
|
|
496
|
+
image = Image.open(self.image_paths[idx]).convert("RGB")
|
|
497
|
+
low_res_image = self.low_res_transform(image)
|
|
498
|
+
high_res_image = self.high_res_transform(image)
|
|
499
|
+
return low_res_image, high_res_image
|
|
500
|
+
|
|
501
|
+
import torch
|
|
502
|
+
import torch.nn as nn
|
|
503
|
+
import torch.optim as optim
|
|
504
|
+
from torchvision import transforms
|
|
505
|
+
|
|
506
|
+
# 定义潜在空间编码器
|
|
507
|
+
class Encoder(nn.Module):
|
|
508
|
+
def __init__(self, input_dim, latent_dim):
|
|
509
|
+
super(Encoder, self).__init__()
|
|
510
|
+
self.encoder = nn.Sequential(
|
|
511
|
+
nn.Conv2d(input_dim, 64, kernel_size=4, stride=2, padding=1),
|
|
512
|
+
nn.ReLU(),
|
|
513
|
+
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
|
|
514
|
+
nn.ReLU(),
|
|
515
|
+
nn.Conv2d(128, latent_dim, kernel_size=4, stride=2, padding=1)
|
|
516
|
+
)
|
|
517
|
+
|
|
518
|
+
def forward(self, x):
|
|
519
|
+
return self.encoder(x)
|
|
520
|
+
|
|
521
|
+
# 定义潜在空间解码器
|
|
522
|
+
class Decoder(nn.Module):
|
|
523
|
+
def __init__(self, latent_dim, output_dim):
|
|
524
|
+
super(Decoder, self).__init__()
|
|
525
|
+
self.decoder = nn.Sequential(
|
|
526
|
+
nn.ConvTranspose2d(latent_dim, 128, kernel_size=4, stride=2, padding=1),
|
|
527
|
+
nn.ReLU(),
|
|
528
|
+
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
|
|
529
|
+
nn.ReLU(),
|
|
530
|
+
nn.ConvTranspose2d(64, output_dim, kernel_size=4, stride=2, padding=1)
|
|
531
|
+
)
|
|
532
|
+
|
|
533
|
+
def forward(self, x):
|
|
534
|
+
return self.decoder(x)
|
|
535
|
+
|
|
536
|
+
# 定义扩散模型
|
|
537
|
+
class DiffusionModel(nn.Module):
|
|
538
|
+
def __init__(self, latent_dim):
|
|
539
|
+
super(DiffusionModel, self).__init__()
|
|
540
|
+
self.diffusion = nn.Sequential(
|
|
541
|
+
nn.Conv2d(latent_dim, latent_dim, kernel_size=3, stride=1, padding=1),
|
|
542
|
+
nn.ReLU(),
|
|
543
|
+
nn.Conv2d(latent_dim, latent_dim, kernel_size=3, stride=1, padding=1)
|
|
544
|
+
)
|
|
545
|
+
|
|
546
|
+
def forward(self, x):
|
|
547
|
+
return self.diffusion(x)
|
|
548
|
+
|
|
549
|
+
# 定义超分辨率模块
|
|
550
|
+
class SuperResolution(nn.Module):
|
|
551
|
+
def __init__(self, input_dim, output_dim):
|
|
552
|
+
super(SuperResolution, self).__init__()
|
|
553
|
+
self.sr = nn.Sequential(
|
|
554
|
+
nn.Conv2d(input_dim, 64, kernel_size=3, stride=1, padding=1),
|
|
555
|
+
nn.ReLU(),
|
|
556
|
+
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
|
|
557
|
+
nn.ReLU(),
|
|
558
|
+
nn.Conv2d(64, output_dim, kernel_size=3, stride=1, padding=1)
|
|
559
|
+
)
|
|
560
|
+
|
|
561
|
+
def forward(self, x):
|
|
562
|
+
return self.sr(x)
|
|
563
|
+
|
|
564
|
+
# 定义GAN判别器
|
|
565
|
+
class Discriminator(nn.Module):
|
|
566
|
+
def __init__(self, input_dim):
|
|
567
|
+
super(Discriminator, self).__init__()
|
|
568
|
+
self.discriminator = nn.Sequential(
|
|
569
|
+
nn.Conv2d(input_dim, 64, kernel_size=4, stride=2, padding=1),
|
|
570
|
+
nn.ReLU(),
|
|
571
|
+
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
|
|
572
|
+
nn.ReLU(),
|
|
573
|
+
nn.Conv2d(128, 1, kernel_size=4, stride=2, padding=1)
|
|
574
|
+
)
|
|
575
|
+
|
|
576
|
+
def forward(self, x):
|
|
577
|
+
return self.discriminator(x)
|
|
578
|
+
|
|
579
|
+
# 构建模型
|
|
580
|
+
latent_dim = 256
|
|
581
|
+
input_dim = 3
|
|
582
|
+
output_dim = 3
|
|
583
|
+
|
|
584
|
+
encoder = Encoder(input_dim, latent_dim)
|
|
585
|
+
decoder = Decoder(latent_dim, output_dim)
|
|
586
|
+
diffusion_model = DiffusionModel(latent_dim)
|
|
587
|
+
super_resolution = SuperResolution(output_dim, output_dim)
|
|
588
|
+
discriminator = Discriminator(output_dim)
|
|
589
|
+
|
|
590
|
+
# 定义优化器和损失函数
|
|
591
|
+
gen_optimizer = optim.Adam(list(encoder.parameters()) +
|
|
592
|
+
list(decoder.parameters()) +
|
|
593
|
+
list(diffusion_model.parameters()) +
|
|
594
|
+
list(super_resolution.parameters()), lr=1e-4)
|
|
595
|
+
disc_optimizer = optim.Adam(discriminator.parameters(), lr=1e-4)
|
|
596
|
+
mse_loss = nn.MSELoss()
|
|
597
|
+
bce_loss = nn.BCEWithLogitsLoss()
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
# 训练过程函数
|
|
601
|
+
def train_step(generator_parts, discriminator, optimizers, images, device):
|
|
602
|
+
encoder, decoder, diffusion_model, super_resolution = generator_parts
|
|
603
|
+
gen_optimizer, disc_optimizer = optimizers
|
|
604
|
+
|
|
605
|
+
low_res_images, high_res_images = images
|
|
606
|
+
low_res_images = low_res_images.to(device)
|
|
607
|
+
high_res_images = high_res_images.to(device)
|
|
608
|
+
|
|
609
|
+
# 编码
|
|
610
|
+
latent = encoder(low_res_images)
|
|
611
|
+
|
|
612
|
+
# 扩散生成
|
|
613
|
+
latent_diffused = diffusion_model(latent)
|
|
614
|
+
|
|
615
|
+
# 解码
|
|
616
|
+
reconstructed = decoder(latent_diffused)
|
|
617
|
+
|
|
618
|
+
# 超分辨率生成
|
|
619
|
+
high_res_generated = super_resolution(reconstructed)
|
|
620
|
+
|
|
621
|
+
# 调整大小以匹配
|
|
622
|
+
downsampled_high_res = F.interpolate(high_res_images, size=(64, 64), mode='bilinear', align_corners=False)
|
|
623
|
+
upsampled_reconstructed = F.interpolate(reconstructed, size=(256, 256), mode='bilinear', align_corners=False)
|
|
624
|
+
|
|
625
|
+
# 生成器损失 (重建 + 超分辨率 + 对抗)
|
|
626
|
+
reconstruction_loss = nn.MSELoss()(reconstructed, downsampled_high_res)
|
|
627
|
+
super_res_loss = nn.MSELoss()(upsampled_reconstructed, high_res_images)
|
|
628
|
+
disc_fake = discriminator(high_res_generated)
|
|
629
|
+
adversarial_loss = nn.BCEWithLogitsLoss()(disc_fake, torch.ones_like(disc_fake))
|
|
630
|
+
gen_loss = reconstruction_loss + super_res_loss + adversarial_loss
|
|
631
|
+
|
|
632
|
+
gen_optimizer.zero_grad()
|
|
633
|
+
gen_loss.backward()
|
|
634
|
+
gen_optimizer.step()
|
|
635
|
+
|
|
636
|
+
# 判别器损失
|
|
637
|
+
disc_real = discriminator(high_res_images)
|
|
638
|
+
disc_loss_real = nn.BCEWithLogitsLoss()(disc_real, torch.ones_like(disc_real))
|
|
639
|
+
disc_loss_fake = nn.BCEWithLogitsLoss()(disc_fake.detach(), torch.zeros_like(disc_fake))
|
|
640
|
+
disc_loss = (disc_loss_real + disc_loss_fake) / 2
|
|
641
|
+
|
|
642
|
+
disc_optimizer.zero_grad()
|
|
643
|
+
disc_loss.backward()
|
|
644
|
+
disc_optimizer.step()
|
|
645
|
+
|
|
646
|
+
return gen_loss.item(), disc_loss.item()
|
|
647
|
+
from shancx.Dsalgor.CudaPrefetcher1 import CUDAPrefetcher1
|
|
648
|
+
def save_best_model(encoder, decoder, diffusion_model, super_resolution, discriminator, epoch):
|
|
649
|
+
checkpoint = {
|
|
650
|
+
'encoder': encoder.state_dict(),
|
|
651
|
+
'decoder': decoder.state_dict(),
|
|
652
|
+
'diffusion_model': diffusion_model.state_dict(),
|
|
653
|
+
'super_resolution': super_resolution.state_dict(),
|
|
654
|
+
'discriminator': discriminator.state_dict(),
|
|
655
|
+
}
|
|
656
|
+
torch.save(checkpoint, f"best_model_epoch_{epoch+1}.pth")
|
|
657
|
+
print(f"Best model saved at epoch {epoch+1}.")
|
|
658
|
+
# 主训练循环
|
|
659
|
+
if __name__ == "__main__":
|
|
660
|
+
# 模型初始化
|
|
661
|
+
latent_dim = 256
|
|
662
|
+
input_dim = 3
|
|
663
|
+
output_dim = 3
|
|
664
|
+
|
|
665
|
+
# 加载模型至指定设备
|
|
666
|
+
encoder = Encoder(input_dim, latent_dim).to(device)
|
|
667
|
+
decoder = Decoder(latent_dim, output_dim).to(device)
|
|
668
|
+
diffusion_model = DiffusionModel(latent_dim).to(device)
|
|
669
|
+
super_resolution = SuperResolution(output_dim, output_dim).to(device)
|
|
670
|
+
discriminator = Discriminator(output_dim).to(device)
|
|
671
|
+
|
|
672
|
+
generator_parts = (encoder, decoder, diffusion_model, super_resolution)
|
|
673
|
+
|
|
674
|
+
# 定义优化器
|
|
675
|
+
gen_optimizer = optim.Adam(
|
|
676
|
+
list(encoder.parameters()) +
|
|
677
|
+
list(decoder.parameters()) +
|
|
678
|
+
list(diffusion_model.parameters()) +
|
|
679
|
+
list(super_resolution.parameters()),
|
|
680
|
+
lr=1e-4
|
|
681
|
+
)
|
|
682
|
+
disc_optimizer = optim.Adam(discriminator.parameters(), lr=1e-4)
|
|
683
|
+
optimizers = (gen_optimizer, disc_optimizer)
|
|
684
|
+
|
|
685
|
+
# 数据加载
|
|
686
|
+
folder_path = "/mnt/wtx_weather_forecast/scx/diffdataset/output_dataset/HR" # 数据集路径
|
|
687
|
+
dataset = ImageDataset(folder_path)
|
|
688
|
+
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
|
|
689
|
+
dataloader = CUDAPrefetcher1(dataloader,device)
|
|
690
|
+
num_epochs = 3
|
|
691
|
+
best_gen_loss = float('inf') # 初始化生成器最小损失值
|
|
692
|
+
# 训练参数
|
|
693
|
+
for epoch in range(num_epochs):
|
|
694
|
+
for images in dataloader:
|
|
695
|
+
gen_loss, disc_loss = train_step(generator_parts, discriminator, optimizers, images, device)
|
|
696
|
+
print(f"Epoch [{epoch+1}/{num_epochs}], Gen Loss: {gen_loss:.4f}, Disc Loss: {disc_loss:.4f}")
|
|
697
|
+
# 每个epoch保存最佳模型
|
|
698
|
+
if gen_loss < best_gen_loss:
|
|
699
|
+
best_gen_loss = gen_loss
|
|
700
|
+
# 保存最佳模型权重
|
|
701
|
+
save_best_model(encoder, decoder, diffusion_model, super_resolution, discriminator, epoch)
|
|
702
|
+
print("save modle")
|
|
703
|
+
|
|
704
|
+
"""
|