shancx 1.8.92__py3-none-any.whl → 1.9.33.218__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166) hide show
  1. shancx/3D/__init__.py +25 -0
  2. shancx/Algo/Class.py +11 -0
  3. shancx/Algo/CudaPrefetcher1.py +112 -0
  4. shancx/Algo/Fake_image.py +24 -0
  5. shancx/Algo/Hsml.py +391 -0
  6. shancx/Algo/L2Loss.py +10 -0
  7. shancx/Algo/MetricTracker.py +132 -0
  8. shancx/Algo/Normalize.py +66 -0
  9. shancx/Algo/OptimizerWScheduler.py +38 -0
  10. shancx/Algo/Rmageresize.py +79 -0
  11. shancx/Algo/Savemodel.py +33 -0
  12. shancx/Algo/SmoothL1_losses.py +27 -0
  13. shancx/Algo/Tqdm.py +62 -0
  14. shancx/Algo/__init__.py +121 -0
  15. shancx/Algo/checknan.py +28 -0
  16. shancx/Algo/iouJU.py +83 -0
  17. shancx/Algo/mask.py +25 -0
  18. shancx/Algo/psnr.py +9 -0
  19. shancx/Algo/ssim.py +70 -0
  20. shancx/Algo/structural_similarity.py +308 -0
  21. shancx/Algo/tool.py +704 -0
  22. shancx/Calmetrics/__init__.py +97 -0
  23. shancx/Calmetrics/calmetrics.py +14 -0
  24. shancx/Calmetrics/calmetricsmatrixLib.py +147 -0
  25. shancx/Calmetrics/rmseR2score.py +35 -0
  26. shancx/Clip/__init__.py +50 -0
  27. shancx/Cmd.py +126 -0
  28. shancx/Config_.py +26 -0
  29. shancx/Df/DataFrame.py +11 -2
  30. shancx/Df/__init__.py +17 -0
  31. shancx/Df/tool.py +0 -0
  32. shancx/Diffm/Psamples.py +18 -0
  33. shancx/Diffm/__init__.py +0 -0
  34. shancx/Diffm/test.py +207 -0
  35. shancx/Doc/__init__.py +214 -0
  36. shancx/E/__init__.py +178 -152
  37. shancx/Fillmiss/__init__.py +0 -0
  38. shancx/Fillmiss/imgidwJU.py +46 -0
  39. shancx/Fillmiss/imgidwLatLonJU.py +82 -0
  40. shancx/Gpu/__init__.py +55 -0
  41. shancx/H9/__init__.py +126 -0
  42. shancx/H9/ahi_read_hsd.py +877 -0
  43. shancx/H9/ahisearchtable.py +298 -0
  44. shancx/H9/geometry.py +2439 -0
  45. shancx/Hug/__init__.py +81 -0
  46. shancx/Inst.py +22 -0
  47. shancx/Lib.py +31 -0
  48. shancx/Mos/__init__.py +37 -0
  49. shancx/NN/__init__.py +235 -106
  50. shancx/Path1.py +161 -0
  51. shancx/Plot/GlobMap.py +276 -116
  52. shancx/Plot/__init__.py +491 -1
  53. shancx/Plot/draw_day_CR_PNG.py +4 -21
  54. shancx/Plot/exam.py +116 -0
  55. shancx/Plot/plotGlobal.py +325 -0
  56. shancx/{radar_nmc.py → Plot/radarNmc.py} +4 -34
  57. shancx/{subplots_single_china_map.py → Plot/single_china_map.py} +1 -1
  58. shancx/Point.py +46 -0
  59. shancx/QC.py +223 -0
  60. shancx/RdPzl/__init__.py +32 -0
  61. shancx/Read.py +72 -0
  62. shancx/Resize.py +79 -0
  63. shancx/SN/__init__.py +62 -123
  64. shancx/Time/GetTime.py +9 -3
  65. shancx/Time/__init__.py +66 -1
  66. shancx/Time/timeCycle.py +302 -0
  67. shancx/Time/tool.py +0 -0
  68. shancx/Train/__init__.py +74 -0
  69. shancx/Train/makelist.py +187 -0
  70. shancx/Train/multiGpu.py +27 -0
  71. shancx/Train/prepare.py +161 -0
  72. shancx/Train/renet50.py +157 -0
  73. shancx/ZR.py +12 -0
  74. shancx/__init__.py +333 -262
  75. shancx/args.py +27 -0
  76. shancx/bak.py +768 -0
  77. shancx/df2database.py +62 -2
  78. shancx/geosProj.py +80 -0
  79. shancx/info.py +38 -0
  80. shancx/netdfJU.py +231 -0
  81. shancx/sendM.py +59 -0
  82. shancx/tensBoard/__init__.py +28 -0
  83. shancx/wait.py +246 -0
  84. {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/METADATA +15 -5
  85. shancx-1.9.33.218.dist-info/RECORD +91 -0
  86. {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/WHEEL +1 -1
  87. my_timer_decorator/__init__.py +0 -10
  88. shancx/Dsalgor/__init__.py +0 -19
  89. shancx/E/DFGRRIB.py +0 -30
  90. shancx/EN/DFGRRIB.py +0 -30
  91. shancx/EN/__init__.py +0 -148
  92. shancx/FileRead.py +0 -44
  93. shancx/Gray2RGB.py +0 -86
  94. shancx/M/__init__.py +0 -137
  95. shancx/MN/__init__.py +0 -133
  96. shancx/N/__init__.py +0 -131
  97. shancx/Plot/draw_day_CR_PNGUS.py +0 -206
  98. shancx/Plot/draw_day_CR_SVG.py +0 -275
  99. shancx/Plot/draw_day_pre_PNGUS.py +0 -205
  100. shancx/Plot/glob_nation_map.py +0 -116
  101. shancx/Plot/radar_nmc.py +0 -61
  102. shancx/Plot/radar_nmc_china_map_compare1.py +0 -50
  103. shancx/Plot/radar_nmc_china_map_f.py +0 -121
  104. shancx/Plot/radar_nmc_us_map_f.py +0 -128
  105. shancx/Plot/subplots_compare_devlop.py +0 -36
  106. shancx/Plot/subplots_single_china_map.py +0 -45
  107. shancx/S/__init__.py +0 -138
  108. shancx/W/__init__.py +0 -132
  109. shancx/WN/__init__.py +0 -132
  110. shancx/code.py +0 -331
  111. shancx/draw_day_CR_PNG.py +0 -200
  112. shancx/draw_day_CR_PNGUS.py +0 -206
  113. shancx/draw_day_CR_SVG.py +0 -275
  114. shancx/draw_day_pre_PNGUS.py +0 -205
  115. shancx/makenetCDFN.py +0 -42
  116. shancx/mkIMGSCX.py +0 -92
  117. shancx/netCDF.py +0 -130
  118. shancx/radar_nmc_china_map_compare1.py +0 -50
  119. shancx/radar_nmc_china_map_f.py +0 -125
  120. shancx/radar_nmc_us_map_f.py +0 -67
  121. shancx/subplots_compare_devlop.py +0 -36
  122. shancx/tool.py +0 -18
  123. shancx/user/H8mess.py +0 -317
  124. shancx/user/__init__.py +0 -137
  125. shancx/user/cinradHJN.py +0 -496
  126. shancx/user/examMeso.py +0 -293
  127. shancx/user/hjnDAAS.py +0 -26
  128. shancx/user/hjnFTP.py +0 -81
  129. shancx/user/hjnGIS.py +0 -320
  130. shancx/user/hjnGPU.py +0 -21
  131. shancx/user/hjnIDW.py +0 -68
  132. shancx/user/hjnKDTree.py +0 -75
  133. shancx/user/hjnLAPSTransform.py +0 -47
  134. shancx/user/hjnMiscellaneous.py +0 -182
  135. shancx/user/hjnProj.py +0 -162
  136. shancx/user/inotify.py +0 -41
  137. shancx/user/matplotlibMess.py +0 -87
  138. shancx/user/mkNCHJN.py +0 -623
  139. shancx/user/newTypeRadar.py +0 -492
  140. shancx/user/test.py +0 -6
  141. shancx/user/tlogP.py +0 -129
  142. shancx/util_log.py +0 -33
  143. shancx/wtx/H8mess.py +0 -315
  144. shancx/wtx/__init__.py +0 -151
  145. shancx/wtx/cinradHJN.py +0 -496
  146. shancx/wtx/colormap.py +0 -64
  147. shancx/wtx/examMeso.py +0 -298
  148. shancx/wtx/hjnDAAS.py +0 -26
  149. shancx/wtx/hjnFTP.py +0 -81
  150. shancx/wtx/hjnGIS.py +0 -330
  151. shancx/wtx/hjnGPU.py +0 -21
  152. shancx/wtx/hjnIDW.py +0 -68
  153. shancx/wtx/hjnKDTree.py +0 -75
  154. shancx/wtx/hjnLAPSTransform.py +0 -47
  155. shancx/wtx/hjnLog.py +0 -78
  156. shancx/wtx/hjnMiscellaneous.py +0 -201
  157. shancx/wtx/hjnProj.py +0 -161
  158. shancx/wtx/inotify.py +0 -41
  159. shancx/wtx/matplotlibMess.py +0 -87
  160. shancx/wtx/mkNCHJN.py +0 -613
  161. shancx/wtx/newTypeRadar.py +0 -492
  162. shancx/wtx/test.py +0 -6
  163. shancx/wtx/tlogP.py +0 -129
  164. shancx-1.8.92.dist-info/RECORD +0 -99
  165. /shancx/{Dsalgor → Algo}/dsalgor.py +0 -0
  166. {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/top_level.txt +0 -0
shancx/Algo/tool.py ADDED
@@ -0,0 +1,704 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader, Dataset
5
+ import torchvision.transforms as transforms
6
+ from torchvision import transforms
7
+ import os
8
+ from PIL import Image
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+ import glob
12
+
13
# Simplified encoder-decoder segmentation network.
# NOTE(review): despite the name, no residual or skip connections are present;
# it is a plain encoder followed by a decoder.
class ResUNet(nn.Module):
    """Compose a ResNetEncoder with a UNetDecoder for image segmentation."""

    def __init__(self, in_channels, out_channels):
        super(ResUNet, self).__init__()
        self.encoder = ResNetEncoder(in_channels)
        self.decoder = UNetDecoder(out_channels)

    def forward(self, x):
        """Encode the input and decode the features to a segmentation map."""
        features = self.encoder(x)
        return self.decoder(features)
24
+
25
class ResNetEncoder(nn.Module):
    """Three-layer convolutional feature encoder.

    Bug fix: the original used three stride-1 convolutions, so the feature map
    kept the full input resolution while ``UNetDecoder`` upsamples by 2x —
    the composed ``ResUNet`` therefore emitted outputs at twice the input size
    and the BCE loss against same-sized masks could not work. The last
    convolution now downsamples by 2 so the decoder's single 2x transposed
    convolution restores the input resolution.
    """

    def __init__(self, in_channels):
        super(ResNetEncoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # stride=2 halves H and W; the decoder's ConvTranspose2d undoes it.
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        # More layers could be stacked here.

    def forward(self, x):
        """Return a (N, 256, ceil(H/2), ceil(W/2)) feature map for (N, C, H, W) input."""
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        return x
38
+
39
class UNetDecoder(nn.Module):
    """Decoder head: one 2x transposed-convolution upsample, then a 1x1
    projection down to the requested number of output channels (raw logits,
    no activation applied)."""

    def __init__(self, out_channels):
        super(UNetDecoder, self).__init__()
        self.upconv = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.final_conv = nn.Conv2d(128, out_channels, kernel_size=1)

    def forward(self, x):
        upsampled = self.upconv(x)
        return self.final_conv(upsampled)
49
+
50
# Custom dataset pairing satellite images with segmentation masks.
class SatelliteDataset(Dataset):
    """Pairs RGB input images with single-channel grayscale masks.

    Bug fix: ``os.listdir`` returns entries in arbitrary, filesystem-dependent
    order, so the original paired images and masks essentially at random.
    Both listings are now sorted so the i-th image is matched with the i-th
    mask (assumes corresponding filenames in both directories — TODO confirm
    with the data layout).
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # Sort so image/mask pairing is deterministic and consistent.
        self.image_filenames = sorted(os.listdir(image_dir))
        self.mask_filenames = sorted(os.listdir(mask_dir))
        self.transform = transform

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_name = self.image_filenames[idx]
        mask_name = self.mask_filenames[idx]

        img = Image.open(os.path.join(self.image_dir, img_name)).convert("RGB")
        # Single-channel grayscale mask.
        mask = Image.open(os.path.join(self.mask_dir, mask_name)).convert("L")

        if self.transform:
            img = self.transform(img)
            mask = self.transform(mask)

        return img, mask
74
+
75
# Data preprocessing: resize to a fixed size and convert to a [0, 1] tensor.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])


def train_resunet(image_dir='path/to/your/images',
                  mask_dir='path/to/your/masks',
                  num_epochs=20,
                  batch_size=8,
                  lr=1e-4,
                  model_path='resunet_model.pth'):
    """Train a binary-segmentation ResUNet and save its weights.

    Bug fix: this code previously ran at module import time, so merely
    importing this library module tried to list the placeholder directories
    and crashed. It is now a function invoked only under a ``__main__`` guard.

    Args:
        image_dir: directory of input RGB images.
        mask_dir: directory of segmentation masks (paired by sorted name).
        num_epochs: number of training epochs (default matches the original 20).
        batch_size: DataLoader batch size (default matches the original 8).
        lr: Adam learning rate (default matches the original 1e-4).
        model_path: where to write the trained ``state_dict``.

    Returns:
        The trained model.
    """
    dataset = SatelliteDataset(image_dir, mask_dir, transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 3 input channels (RGB), 1 output channel (binary segmentation logits).
    model = ResUNet(in_channels=3, out_channels=1).to(device)
    criterion = nn.BCEWithLogitsLoss()  # binary-classification loss on raw logits
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for imgs, masks in dataloader:
            imgs, masks = imgs.to(device), masks.to(device)

            # Forward pass.
            outputs = model(imgs)

            # Drop the channel dimension and compare against float masks.
            # NOTE(review): masks from ToTensor are (N, 1, H, W) while
            # outputs.squeeze(1) is (N, H, W) — confirm shapes line up for
            # BCEWithLogitsLoss with the real data pipeline.
            loss = criterion(outputs.squeeze(1), masks.float())
            running_loss += loss.item()

            # Backward pass and optimization step.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        avg_loss = running_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # Persist the trained weights.
    torch.save(model.state_dict(), model_path)
    return model


if __name__ == '__main__':
    train_resunet()
116
+
117
+
118
+ """ --------------Resunet-------------------Dataset
119
+ from torch.utils.data import DataLoader
120
+ from torch.utils.data import Dataset
121
+ class ToTensorTarget(object):
122
+ def __call__(self, sample):
123
+ sat_img, map_img = sample["sat_img"], sample["map_img"]
124
+
125
+ # swap color axis because
126
+ # numpy image: H x W x C
127
+ # torch image: C X H X W
128
+
129
+ return {
130
+ "sat_img": transforms.functional.to_tensor(sat_img).permute(1,2,0), #(H, W, C) -->transforms.functional.to_tensor(sat_img)-->(C, H, W) --.permute(1,2,0)-->(H, W, C) H 是图像的高度,W 是宽度,C 是通道数
131
+ "map_img": torch.from_numpy(map_img).unsqueeze(0).float(),
132
+ } # unsqueeze for the channel dimension
133
+
134
+ # 自定义数据集类
135
+ class npyDataset_regression(Dataset):
136
+ def __init__(self, args, train=True, transform=None):
137
+ self.train = train
138
+ self.path = args.train if train else args.valid
139
+ self.mask_list = glob.glob(
140
+ os.path.join(self.path, "mask", "*.npy"), recursive=True
141
+ )
142
+ self.transform = transform
143
+ def __len__(self):
144
+ return len(self.mask_list)
145
+ def __getitem__(self, idx):
146
+ try:
147
+ maskpath = self.mask_list[idx]
148
+ image = np.load(maskpath.replace("mask", "input")).astype(np.float32)
149
+ image = image[-2:,:,:]
150
+ image[image<15] = np.nan
151
+ ### 5-15dbz
152
+ #image[image>20] = np.nan
153
+ #image[image<5] = np.nan
154
+ #mean = np.float32(9.81645766)
155
+ #std = np.float32(10.172995)
156
+ image_mask = image[-1,:,:].copy().reshape(256,256)
157
+ image_mask[~np.isnan(image_mask)]=1
158
+ #tmp = image[-2,:,:].reshape((256,256)) * image_mask
159
+ #image[-2,:,:] = tmp.reshape((1,256,256))
160
+ mask = np.load(maskpath).astype(np.float32)
161
+ mask = mask * image_mask
162
+ image[np.isnan(image)]=0
163
+ sample = {"x_img": image, "map_img": mask}
164
+ if self.transform:
165
+ sample = self.transform(sample)
166
+ sample['maskpath'] = maskpath
167
+ return sample
168
+ except Exception as e:
169
+ print(f"Error loading data at index {idx}: {str(e)}")
170
+ # 可以选择跳过当前样本或者返回一个默认值
171
+ print(traceback.format_exc())
172
+ loggers.info(traceback.format_exc())
173
+ return None
174
+
175
+ dataset = npyDataset_regression(args, transform=transforms.Compose([ToTensorTarget()])) #=transforms.Compose([dataloader_radar10.ToTensorTarget()])
176
+ dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
177
+ loader = tqdm(dataloader, desc="training")
178
+ for idx, data in enumerate(loader):
179
+ inputs = data["x_img"].cuda()
180
+ labels = data["map_img"].cuda()
181
+ optimizer.zero_grad()
182
+ outputs = model(inputs)
183
+ """
184
+
185
+ """ ------------Gan-------------
186
+ import torch
187
+ import torch.nn as nn
188
+ import torch.optim as optim
189
+ from torchvision import datasets, transforms
190
+ from torchvision.utils import save_image
191
+ import os
192
+
193
+ def save_model_weights(generator, discriminator, epoch):
194
+ torch.save(generator.state_dict(), f'generator_epoch_{epoch}.pt')
195
+ torch.save(discriminator.state_dict(), f'discriminator_epoch_{epoch}.pt')
196
+
197
+ def load_generator_weights(generator, weights_path, device='cuda:0'):
198
+ generator.load_state_dict(torch.load(weights_path, map_location=device))
199
+ generator.eval() # Ensure generator is in evaluation mode after loading weights
200
+
201
+ def generate_images(generator, num_images, device='cuda:0'):
202
+ generator.eval() # Set the generator to evaluation mode
203
+ noise = torch.randn(num_images, 100).to(device) # Generate random noise
204
+ with torch.no_grad():
205
+ generated_images = generator(noise).cpu() # Generate images from noise
206
+ return generated_images
207
+
208
+ # Generator model
209
+ class Generator(nn.Module):
210
+ def __init__(self):
211
+ super(Generator, self).__init__()
212
+ self.model = nn.Sequential(
213
+ nn.Linear(100, 256),
214
+ nn.ReLU(),
215
+ nn.Linear(256, 64 * 64),
216
+ nn.Tanh() # Output range [-1, 1]
217
+ )
218
+
219
+ def forward(self, x):
220
+ return self.model(x).view(-1, 1, 64, 64)
221
+
222
+ # Discriminator model
223
+ class Discriminator(nn.Module):
224
+ def __init__(self):
225
+ super(Discriminator, self).__init__()
226
+ self.model = nn.Sequential(
227
+ nn.Flatten(),
228
+ nn.Linear(64 * 64, 256),
229
+ nn.ReLU(),
230
+ nn.Linear(256, 1),
231
+ nn.Sigmoid() # Output [0, 1]
232
+ )
233
+
234
+ def forward(self, x):
235
+ return self.model(x)
236
+
237
+ # Training function
238
+ def train_gan(generator, discriminator, g_optimizer, d_optimizer, criterion, dataloader, epochs=50, device='cuda:0'):
239
+ for epoch in range(epochs):
240
+ for i, (real_images, _) in enumerate(dataloader):
241
+ batch_size = real_images.size(0)
242
+ real_images = real_images.to(device)
243
+ real_images = (real_images - 0.5) * 2 # Normalize to [-1, 1]
244
+
245
+ real_labels = torch.ones(batch_size, 1).to(device)
246
+ fake_labels = torch.zeros(batch_size, 1).to(device)
247
+
248
+ # Train discriminator
249
+ d_optimizer.zero_grad()
250
+
251
+ # Real image loss
252
+ outputs = discriminator(real_images)
253
+ d_loss_real = criterion(outputs, real_labels)
254
+ d_loss_real.backward()
255
+
256
+ # Fake image loss
257
+ noise = torch.randn(batch_size, 100).to(device)
258
+ fake_images = generator(noise)
259
+ outputs = discriminator(fake_images.detach())
260
+ d_loss_fake = criterion(outputs, fake_labels)
261
+ d_loss_fake.backward()
262
+ d_optimizer.step()
263
+
264
+ # Train generator
265
+ g_optimizer.zero_grad()
266
+ outputs = discriminator(fake_images)
267
+ g_loss = criterion(outputs, real_labels)
268
+ g_loss.backward()
269
+ g_optimizer.step()
270
+
271
+ # Print loss
272
+ if (i + 1) % 100 == 0:
273
+ print(f'Epoch [{epoch + 1}/{epochs}], Step [{i + 1}/{len(dataloader)}], '
274
+ f'D Loss: {d_loss_real.item() + d_loss_fake.item()}, G Loss: {g_loss.item()}')
275
+
276
+ # Save generated images and model weights every 10 epochs
277
+ if (epoch + 1) % 10 == 0:
278
+ save_image(fake_images.detach(), f'gan_images/epoch_{epoch + 1}.png')
279
+ save_model_weights(generator, discriminator, epoch + 1)
280
+
281
+ # Data preprocessing and loader
282
+ transform = transforms.Compose([
283
+ transforms.Resize(64),
284
+ transforms.ToTensor(),
285
+ transforms.Normalize((0.5,), (0.5,)),
286
+ ])
287
+
288
+ if __name__ == '__main__':
289
+ # Set device to the first GPU (cuda:0)
290
+ device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
291
+ generator = Generator().to(device)
292
+ discriminator = Discriminator().to(device)
293
+
294
+ # Load dataset
295
+ dataset = datasets.MNIST(root='data', train=True, transform=transform, download=True)
296
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=64,num_workers=20,shuffle=True) #num_workers=10,
297
+
298
+ # Define loss function and optimizers
299
+ criterion = nn.BCELoss()
300
+ g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
301
+ d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
302
+
303
+ # Create directory for saving images
304
+ os.makedirs('gan_images', exist_ok=True)
305
+
306
+ # Train GAN
307
+ train_gan(generator, discriminator, g_optimizer, d_optimizer, criterion, dataloader, epochs=10, device=device)
308
+
309
+ # Load generator weights (replace with actual path)
310
+ load_generator_weights(generator, 'generator_epoch_10.pt', device=device)
311
+
312
+ # Generate images
313
+ generated_images = generate_images(generator, 10, device=device)
314
+
315
+ # Save generated images
316
+ save_image(generated_images, 'generated_images.png', nrow=10, normalize=True)
317
+
318
+
319
+ """
320
+
321
+ """
322
+ ######################################## diffusion modle #########################################
323
+ import os
324
+ from PIL import Image
325
+ import torch
326
+ import torch.nn as nn
327
+ import torch.optim as optim
328
+ from torchvision.transforms import Compose, ToTensor, Resize
329
+ from torchvision.utils import save_image
330
+ from torch.utils.data import DataLoader, Dataset
331
+ import numpy as np
332
+
333
+ # 自定义数据集类
334
+ class ImageDataset(Dataset):
335
+ def __init__(self, folder_path, low_res_size=(64, 64), high_res_size=(256, 256)):
336
+ super().__init__()
337
+ self.image_paths = [os.path.join(folder_path, fname) for fname in os.listdir(folder_path) if fname.endswith(('png', 'jpg', 'jpeg'))]
338
+ self.low_res_transform = Compose([
339
+ Resize(low_res_size),
340
+ ToTensor()
341
+ ])
342
+ self.high_res_transform = Compose([
343
+ Resize(high_res_size),
344
+ ToTensor()
345
+ ])
346
+
347
+ def __len__(self):
348
+ return len(self.image_paths)
349
+
350
+ def __getitem__(self, idx):
351
+ image = Image.open(self.image_paths[idx]).convert("RGB")
352
+ low_res_image = self.low_res_transform(image)
353
+ high_res_image = self.high_res_transform(image)
354
+ return low_res_image, high_res_image
355
+
356
+ # 定义扩散模型中的去噪网络
357
+ class DenoisingUNet(nn.Module):
358
+ def __init__(self):
359
+ super(DenoisingUNet, self).__init__()
360
+ self.encoder = nn.Sequential(
361
+ nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
362
+ nn.ReLU(),
363
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
364
+ nn.ReLU()
365
+ )
366
+ self.middle = nn.Sequential(
367
+ nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
368
+ nn.ReLU()
369
+ )
370
+ self.decoder = nn.Sequential(
371
+ nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
372
+ nn.ReLU(),
373
+ nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)
374
+ )
375
+
376
+ def forward(self, x):
377
+ encoded = self.encoder(x)
378
+ middle = self.middle(encoded)
379
+ decoded = self.decoder(middle)
380
+ return decoded
381
+
382
+ # 定义扩散过程的正向和反向过程
383
+ class DiffusionModel:
384
+ def __init__(self, denoising_model, timesteps=1000):
385
+ self.denoising_model = denoising_model
386
+ self.timesteps = timesteps
387
+ self.beta_schedule = self._linear_beta_schedule()
388
+ self.alphas = 1.0 - self.beta_schedule
389
+ self.alpha_cumprod = torch.tensor(np.cumprod(self.alphas).astype(np.float32)) # 转为Tensor类型
390
+
391
+ def _linear_beta_schedule(self):
392
+ return np.linspace(1e-4, 0.02, self.timesteps).astype(np.float32)
393
+
394
+ def forward_diffusion(self, x, t):
395
+ alpha_t = self.alpha_cumprod[t]
396
+ noise = torch.randn_like(x)
397
+ return torch.sqrt(alpha_t) * x + torch.sqrt(1 - alpha_t) * noise, noise
398
+
399
+ def reverse_diffusion(self, x, t):
400
+ alpha_t = self.alpha_cumprod[t]
401
+ pred_noise = self.denoising_model(x)
402
+ return (x - torch.sqrt(1 - alpha_t) * pred_noise) / torch.sqrt(alpha_t)
403
+
404
+ # 训练函数
405
+ def train_diffusion_model(model, dataloader, epochs=50, lr=1e-4):
406
+ optimizer = optim.Adam(model.denoising_model.parameters(), lr=lr)
407
+ loss_fn = nn.MSELoss()
408
+
409
+ for epoch in range(epochs):
410
+ epoch_loss = 0
411
+ for low_res, high_res in dataloader:
412
+ high_res = high_res.to('cuda') # 使用GPU
413
+ low_res = low_res.to('cuda') # 使用GPU
414
+ t = torch.randint(0, model.timesteps, (1,)).item() # 随机选择时间步
415
+ noisy_image, noise = model.forward_diffusion(high_res, t)
416
+ noisy_image = noisy_image.to('cuda') # 确保噪声图像在GPU上
417
+ predicted_noise = model.denoising_model(noisy_image)
418
+
419
+ loss = loss_fn(predicted_noise, noise)
420
+ optimizer.zero_grad()
421
+ loss.backward()
422
+ optimizer.step()
423
+ epoch_loss += loss.item()
424
+ print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss / len(dataloader)}")
425
+
426
+ # 示例:生成高分辨率图像
427
+ def generate_high_res_image(model, low_res_image):
428
+ t = model.timesteps - 1
429
+ x = low_res_image.to('cuda') # 使用GPU
430
+ for step in reversed(range(t)):
431
+ x = model.reverse_diffusion(x, step)
432
+ return x.cpu() # 返回到CPU
433
+
434
+ # 主函数
435
+ if __name__ == "__main__":
436
+ # 模型初始化
437
+ unet = DenoisingUNet().to('cuda') # 将模型加载到GPU
438
+ diffusion_model = DiffusionModel(denoising_model=unet)
439
+
440
+ # 数据集路径
441
+ folder_path = "/mnt/wtx_weather_forecast/scx/diffdataset/output_dataset/HR" # 替换为包含图像的文件夹路径
442
+ dataset = ImageDataset(folder_path)
443
+ dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
444
+
445
+ # 训练模型
446
+ print("开始训练模型...")
447
+ train_diffusion_model(diffusion_model, dataloader)
448
+
449
+ # 生成高分辨率图像
450
+ print("生成高分辨率图像...")
451
+ low_res_input, _ = dataset[0] # 示例图像
452
+ low_res_input = low_res_input.unsqueeze(0) # 增加批次维度
453
+ high_res_output = generate_high_res_image(diffusion_model, low_res_input)
454
+ save_image(high_res_output, "high_res_output.png")
455
+ print("高分辨率图像已保存为 'high_res_output.png'")
456
+
457
+
458
+
459
+ """
460
+
461
+ """
462
+ from torch.utils.data import random_split, DataLoader, Dataset, Subset
463
+ """
464
+
465
+ """
466
+ ##############潜在空间,gan 模型
467
+ import os
468
+ from PIL import Image
469
+ import torch
470
+ import torch.nn as nn
471
+ import torch.optim as optim
472
+ from torchvision.transforms import Compose, ToTensor, Resize
473
+ from torch.utils.data import DataLoader, Dataset
474
+ import torch.nn.functional as F
475
+
476
+ # 确保数据和模型均加载到指定 GPU 上
477
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
478
+ # 自定义数据集类
479
+ class ImageDataset(Dataset):
480
+ def __init__(self, folder_path, low_res_size=(64, 64), high_res_size=(256, 256)):
481
+ super().__init__()
482
+ self.image_paths = [os.path.join(folder_path, fname) for fname in os.listdir(folder_path) if fname.endswith(('png', 'jpg', 'jpeg'))]
483
+ self.low_res_transform = Compose([
484
+ Resize(low_res_size),
485
+ ToTensor()
486
+ ])
487
+ self.high_res_transform = Compose([
488
+ Resize(high_res_size),
489
+ ToTensor()
490
+ ])
491
+
492
+ def __len__(self):
493
+ return len(self.image_paths)
494
+
495
+ def __getitem__(self, idx):
496
+ image = Image.open(self.image_paths[idx]).convert("RGB")
497
+ low_res_image = self.low_res_transform(image)
498
+ high_res_image = self.high_res_transform(image)
499
+ return low_res_image, high_res_image
500
+
501
+ import torch
502
+ import torch.nn as nn
503
+ import torch.optim as optim
504
+ from torchvision import transforms
505
+
506
+ # 定义潜在空间编码器
507
+ class Encoder(nn.Module):
508
+ def __init__(self, input_dim, latent_dim):
509
+ super(Encoder, self).__init__()
510
+ self.encoder = nn.Sequential(
511
+ nn.Conv2d(input_dim, 64, kernel_size=4, stride=2, padding=1),
512
+ nn.ReLU(),
513
+ nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
514
+ nn.ReLU(),
515
+ nn.Conv2d(128, latent_dim, kernel_size=4, stride=2, padding=1)
516
+ )
517
+
518
+ def forward(self, x):
519
+ return self.encoder(x)
520
+
521
+ # 定义潜在空间解码器
522
+ class Decoder(nn.Module):
523
+ def __init__(self, latent_dim, output_dim):
524
+ super(Decoder, self).__init__()
525
+ self.decoder = nn.Sequential(
526
+ nn.ConvTranspose2d(latent_dim, 128, kernel_size=4, stride=2, padding=1),
527
+ nn.ReLU(),
528
+ nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
529
+ nn.ReLU(),
530
+ nn.ConvTranspose2d(64, output_dim, kernel_size=4, stride=2, padding=1)
531
+ )
532
+
533
+ def forward(self, x):
534
+ return self.decoder(x)
535
+
536
+ # 定义扩散模型
537
+ class DiffusionModel(nn.Module):
538
+ def __init__(self, latent_dim):
539
+ super(DiffusionModel, self).__init__()
540
+ self.diffusion = nn.Sequential(
541
+ nn.Conv2d(latent_dim, latent_dim, kernel_size=3, stride=1, padding=1),
542
+ nn.ReLU(),
543
+ nn.Conv2d(latent_dim, latent_dim, kernel_size=3, stride=1, padding=1)
544
+ )
545
+
546
+ def forward(self, x):
547
+ return self.diffusion(x)
548
+
549
+ # 定义超分辨率模块
550
+ class SuperResolution(nn.Module):
551
+ def __init__(self, input_dim, output_dim):
552
+ super(SuperResolution, self).__init__()
553
+ self.sr = nn.Sequential(
554
+ nn.Conv2d(input_dim, 64, kernel_size=3, stride=1, padding=1),
555
+ nn.ReLU(),
556
+ nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
557
+ nn.ReLU(),
558
+ nn.Conv2d(64, output_dim, kernel_size=3, stride=1, padding=1)
559
+ )
560
+
561
+ def forward(self, x):
562
+ return self.sr(x)
563
+
564
+ # 定义GAN判别器
565
+ class Discriminator(nn.Module):
566
+ def __init__(self, input_dim):
567
+ super(Discriminator, self).__init__()
568
+ self.discriminator = nn.Sequential(
569
+ nn.Conv2d(input_dim, 64, kernel_size=4, stride=2, padding=1),
570
+ nn.ReLU(),
571
+ nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
572
+ nn.ReLU(),
573
+ nn.Conv2d(128, 1, kernel_size=4, stride=2, padding=1)
574
+ )
575
+
576
+ def forward(self, x):
577
+ return self.discriminator(x)
578
+
579
+ # 构建模型
580
+ latent_dim = 256
581
+ input_dim = 3
582
+ output_dim = 3
583
+
584
+ encoder = Encoder(input_dim, latent_dim)
585
+ decoder = Decoder(latent_dim, output_dim)
586
+ diffusion_model = DiffusionModel(latent_dim)
587
+ super_resolution = SuperResolution(output_dim, output_dim)
588
+ discriminator = Discriminator(output_dim)
589
+
590
+ # 定义优化器和损失函数
591
+ gen_optimizer = optim.Adam(list(encoder.parameters()) +
592
+ list(decoder.parameters()) +
593
+ list(diffusion_model.parameters()) +
594
+ list(super_resolution.parameters()), lr=1e-4)
595
+ disc_optimizer = optim.Adam(discriminator.parameters(), lr=1e-4)
596
+ mse_loss = nn.MSELoss()
597
+ bce_loss = nn.BCEWithLogitsLoss()
598
+
599
+
600
+ # 训练过程函数
601
+ def train_step(generator_parts, discriminator, optimizers, images, device):
602
+ encoder, decoder, diffusion_model, super_resolution = generator_parts
603
+ gen_optimizer, disc_optimizer = optimizers
604
+
605
+ low_res_images, high_res_images = images
606
+ low_res_images = low_res_images.to(device)
607
+ high_res_images = high_res_images.to(device)
608
+
609
+ # 编码
610
+ latent = encoder(low_res_images)
611
+
612
+ # 扩散生成
613
+ latent_diffused = diffusion_model(latent)
614
+
615
+ # 解码
616
+ reconstructed = decoder(latent_diffused)
617
+
618
+ # 超分辨率生成
619
+ high_res_generated = super_resolution(reconstructed)
620
+
621
+ # 调整大小以匹配
622
+ downsampled_high_res = F.interpolate(high_res_images, size=(64, 64), mode='bilinear', align_corners=False)
623
+ upsampled_reconstructed = F.interpolate(reconstructed, size=(256, 256), mode='bilinear', align_corners=False)
624
+
625
+ # 生成器损失 (重建 + 超分辨率 + 对抗)
626
+ reconstruction_loss = nn.MSELoss()(reconstructed, downsampled_high_res)
627
+ super_res_loss = nn.MSELoss()(upsampled_reconstructed, high_res_images)
628
+ disc_fake = discriminator(high_res_generated)
629
+ adversarial_loss = nn.BCEWithLogitsLoss()(disc_fake, torch.ones_like(disc_fake))
630
+ gen_loss = reconstruction_loss + super_res_loss + adversarial_loss
631
+
632
+ gen_optimizer.zero_grad()
633
+ gen_loss.backward()
634
+ gen_optimizer.step()
635
+
636
+ # 判别器损失
637
+ disc_real = discriminator(high_res_images)
638
+ disc_loss_real = nn.BCEWithLogitsLoss()(disc_real, torch.ones_like(disc_real))
639
+ disc_loss_fake = nn.BCEWithLogitsLoss()(disc_fake.detach(), torch.zeros_like(disc_fake))
640
+ disc_loss = (disc_loss_real + disc_loss_fake) / 2
641
+
642
+ disc_optimizer.zero_grad()
643
+ disc_loss.backward()
644
+ disc_optimizer.step()
645
+
646
+ return gen_loss.item(), disc_loss.item()
647
+ from shancx.Dsalgor.CudaPrefetcher1 import CUDAPrefetcher1
648
+ def save_best_model(encoder, decoder, diffusion_model, super_resolution, discriminator, epoch):
649
+ checkpoint = {
650
+ 'encoder': encoder.state_dict(),
651
+ 'decoder': decoder.state_dict(),
652
+ 'diffusion_model': diffusion_model.state_dict(),
653
+ 'super_resolution': super_resolution.state_dict(),
654
+ 'discriminator': discriminator.state_dict(),
655
+ }
656
+ torch.save(checkpoint, f"best_model_epoch_{epoch+1}.pth")
657
+ print(f"Best model saved at epoch {epoch+1}.")
658
+ # 主训练循环
659
+ if __name__ == "__main__":
660
+ # 模型初始化
661
+ latent_dim = 256
662
+ input_dim = 3
663
+ output_dim = 3
664
+
665
+ # 加载模型至指定设备
666
+ encoder = Encoder(input_dim, latent_dim).to(device)
667
+ decoder = Decoder(latent_dim, output_dim).to(device)
668
+ diffusion_model = DiffusionModel(latent_dim).to(device)
669
+ super_resolution = SuperResolution(output_dim, output_dim).to(device)
670
+ discriminator = Discriminator(output_dim).to(device)
671
+
672
+ generator_parts = (encoder, decoder, diffusion_model, super_resolution)
673
+
674
+ # 定义优化器
675
+ gen_optimizer = optim.Adam(
676
+ list(encoder.parameters()) +
677
+ list(decoder.parameters()) +
678
+ list(diffusion_model.parameters()) +
679
+ list(super_resolution.parameters()),
680
+ lr=1e-4
681
+ )
682
+ disc_optimizer = optim.Adam(discriminator.parameters(), lr=1e-4)
683
+ optimizers = (gen_optimizer, disc_optimizer)
684
+
685
+ # 数据加载
686
+ folder_path = "/mnt/wtx_weather_forecast/scx/diffdataset/output_dataset/HR" # 数据集路径
687
+ dataset = ImageDataset(folder_path)
688
+ dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
689
+ dataloader = CUDAPrefetcher1(dataloader,device)
690
+ num_epochs = 3
691
+ best_gen_loss = float('inf') # 初始化生成器最小损失值
692
+ # 训练参数
693
+ for epoch in range(num_epochs):
694
+ for images in dataloader:
695
+ gen_loss, disc_loss = train_step(generator_parts, discriminator, optimizers, images, device)
696
+ print(f"Epoch [{epoch+1}/{num_epochs}], Gen Loss: {gen_loss:.4f}, Disc Loss: {disc_loss:.4f}")
697
+ # 每个epoch保存最佳模型
698
+ if gen_loss < best_gen_loss:
699
+ best_gen_loss = gen_loss
700
+ # 保存最佳模型权重
701
+ save_best_model(encoder, decoder, diffusion_model, super_resolution, discriminator, epoch)
702
+ print("save model")
703
+
704
+ """