shancx-1.8.92-py3-none-any.whl → shancx-1.9.33.218-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. shancx/3D/__init__.py +25 -0
  2. shancx/Algo/Class.py +11 -0
  3. shancx/Algo/CudaPrefetcher1.py +112 -0
  4. shancx/Algo/Fake_image.py +24 -0
  5. shancx/Algo/Hsml.py +391 -0
  6. shancx/Algo/L2Loss.py +10 -0
  7. shancx/Algo/MetricTracker.py +132 -0
  8. shancx/Algo/Normalize.py +66 -0
  9. shancx/Algo/OptimizerWScheduler.py +38 -0
  10. shancx/Algo/Rmageresize.py +79 -0
  11. shancx/Algo/Savemodel.py +33 -0
  12. shancx/Algo/SmoothL1_losses.py +27 -0
  13. shancx/Algo/Tqdm.py +62 -0
  14. shancx/Algo/__init__.py +121 -0
  15. shancx/Algo/checknan.py +28 -0
  16. shancx/Algo/iouJU.py +83 -0
  17. shancx/Algo/mask.py +25 -0
  18. shancx/Algo/psnr.py +9 -0
  19. shancx/Algo/ssim.py +70 -0
  20. shancx/Algo/structural_similarity.py +308 -0
  21. shancx/Algo/tool.py +704 -0
  22. shancx/Calmetrics/__init__.py +97 -0
  23. shancx/Calmetrics/calmetrics.py +14 -0
  24. shancx/Calmetrics/calmetricsmatrixLib.py +147 -0
  25. shancx/Calmetrics/rmseR2score.py +35 -0
  26. shancx/Clip/__init__.py +50 -0
  27. shancx/Cmd.py +126 -0
  28. shancx/Config_.py +26 -0
  29. shancx/Df/DataFrame.py +11 -2
  30. shancx/Df/__init__.py +17 -0
  31. shancx/Df/tool.py +0 -0
  32. shancx/Diffm/Psamples.py +18 -0
  33. shancx/Diffm/__init__.py +0 -0
  34. shancx/Diffm/test.py +207 -0
  35. shancx/Doc/__init__.py +214 -0
  36. shancx/E/__init__.py +178 -152
  37. shancx/Fillmiss/__init__.py +0 -0
  38. shancx/Fillmiss/imgidwJU.py +46 -0
  39. shancx/Fillmiss/imgidwLatLonJU.py +82 -0
  40. shancx/Gpu/__init__.py +55 -0
  41. shancx/H9/__init__.py +126 -0
  42. shancx/H9/ahi_read_hsd.py +877 -0
  43. shancx/H9/ahisearchtable.py +298 -0
  44. shancx/H9/geometry.py +2439 -0
  45. shancx/Hug/__init__.py +81 -0
  46. shancx/Inst.py +22 -0
  47. shancx/Lib.py +31 -0
  48. shancx/Mos/__init__.py +37 -0
  49. shancx/NN/__init__.py +235 -106
  50. shancx/Path1.py +161 -0
  51. shancx/Plot/GlobMap.py +276 -116
  52. shancx/Plot/__init__.py +491 -1
  53. shancx/Plot/draw_day_CR_PNG.py +4 -21
  54. shancx/Plot/exam.py +116 -0
  55. shancx/Plot/plotGlobal.py +325 -0
  56. shancx/{radar_nmc.py → Plot/radarNmc.py} +4 -34
  57. shancx/{subplots_single_china_map.py → Plot/single_china_map.py} +1 -1
  58. shancx/Point.py +46 -0
  59. shancx/QC.py +223 -0
  60. shancx/RdPzl/__init__.py +32 -0
  61. shancx/Read.py +72 -0
  62. shancx/Resize.py +79 -0
  63. shancx/SN/__init__.py +62 -123
  64. shancx/Time/GetTime.py +9 -3
  65. shancx/Time/__init__.py +66 -1
  66. shancx/Time/timeCycle.py +302 -0
  67. shancx/Time/tool.py +0 -0
  68. shancx/Train/__init__.py +74 -0
  69. shancx/Train/makelist.py +187 -0
  70. shancx/Train/multiGpu.py +27 -0
  71. shancx/Train/prepare.py +161 -0
  72. shancx/Train/renet50.py +157 -0
  73. shancx/ZR.py +12 -0
  74. shancx/__init__.py +333 -262
  75. shancx/args.py +27 -0
  76. shancx/bak.py +768 -0
  77. shancx/df2database.py +62 -2
  78. shancx/geosProj.py +80 -0
  79. shancx/info.py +38 -0
  80. shancx/netdfJU.py +231 -0
  81. shancx/sendM.py +59 -0
  82. shancx/tensBoard/__init__.py +28 -0
  83. shancx/wait.py +246 -0
  84. {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/METADATA +15 -5
  85. shancx-1.9.33.218.dist-info/RECORD +91 -0
  86. {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/WHEEL +1 -1
  87. my_timer_decorator/__init__.py +0 -10
  88. shancx/Dsalgor/__init__.py +0 -19
  89. shancx/E/DFGRRIB.py +0 -30
  90. shancx/EN/DFGRRIB.py +0 -30
  91. shancx/EN/__init__.py +0 -148
  92. shancx/FileRead.py +0 -44
  93. shancx/Gray2RGB.py +0 -86
  94. shancx/M/__init__.py +0 -137
  95. shancx/MN/__init__.py +0 -133
  96. shancx/N/__init__.py +0 -131
  97. shancx/Plot/draw_day_CR_PNGUS.py +0 -206
  98. shancx/Plot/draw_day_CR_SVG.py +0 -275
  99. shancx/Plot/draw_day_pre_PNGUS.py +0 -205
  100. shancx/Plot/glob_nation_map.py +0 -116
  101. shancx/Plot/radar_nmc.py +0 -61
  102. shancx/Plot/radar_nmc_china_map_compare1.py +0 -50
  103. shancx/Plot/radar_nmc_china_map_f.py +0 -121
  104. shancx/Plot/radar_nmc_us_map_f.py +0 -128
  105. shancx/Plot/subplots_compare_devlop.py +0 -36
  106. shancx/Plot/subplots_single_china_map.py +0 -45
  107. shancx/S/__init__.py +0 -138
  108. shancx/W/__init__.py +0 -132
  109. shancx/WN/__init__.py +0 -132
  110. shancx/code.py +0 -331
  111. shancx/draw_day_CR_PNG.py +0 -200
  112. shancx/draw_day_CR_PNGUS.py +0 -206
  113. shancx/draw_day_CR_SVG.py +0 -275
  114. shancx/draw_day_pre_PNGUS.py +0 -205
  115. shancx/makenetCDFN.py +0 -42
  116. shancx/mkIMGSCX.py +0 -92
  117. shancx/netCDF.py +0 -130
  118. shancx/radar_nmc_china_map_compare1.py +0 -50
  119. shancx/radar_nmc_china_map_f.py +0 -125
  120. shancx/radar_nmc_us_map_f.py +0 -67
  121. shancx/subplots_compare_devlop.py +0 -36
  122. shancx/tool.py +0 -18
  123. shancx/user/H8mess.py +0 -317
  124. shancx/user/__init__.py +0 -137
  125. shancx/user/cinradHJN.py +0 -496
  126. shancx/user/examMeso.py +0 -293
  127. shancx/user/hjnDAAS.py +0 -26
  128. shancx/user/hjnFTP.py +0 -81
  129. shancx/user/hjnGIS.py +0 -320
  130. shancx/user/hjnGPU.py +0 -21
  131. shancx/user/hjnIDW.py +0 -68
  132. shancx/user/hjnKDTree.py +0 -75
  133. shancx/user/hjnLAPSTransform.py +0 -47
  134. shancx/user/hjnMiscellaneous.py +0 -182
  135. shancx/user/hjnProj.py +0 -162
  136. shancx/user/inotify.py +0 -41
  137. shancx/user/matplotlibMess.py +0 -87
  138. shancx/user/mkNCHJN.py +0 -623
  139. shancx/user/newTypeRadar.py +0 -492
  140. shancx/user/test.py +0 -6
  141. shancx/user/tlogP.py +0 -129
  142. shancx/util_log.py +0 -33
  143. shancx/wtx/H8mess.py +0 -315
  144. shancx/wtx/__init__.py +0 -151
  145. shancx/wtx/cinradHJN.py +0 -496
  146. shancx/wtx/colormap.py +0 -64
  147. shancx/wtx/examMeso.py +0 -298
  148. shancx/wtx/hjnDAAS.py +0 -26
  149. shancx/wtx/hjnFTP.py +0 -81
  150. shancx/wtx/hjnGIS.py +0 -330
  151. shancx/wtx/hjnGPU.py +0 -21
  152. shancx/wtx/hjnIDW.py +0 -68
  153. shancx/wtx/hjnKDTree.py +0 -75
  154. shancx/wtx/hjnLAPSTransform.py +0 -47
  155. shancx/wtx/hjnLog.py +0 -78
  156. shancx/wtx/hjnMiscellaneous.py +0 -201
  157. shancx/wtx/hjnProj.py +0 -161
  158. shancx/wtx/inotify.py +0 -41
  159. shancx/wtx/matplotlibMess.py +0 -87
  160. shancx/wtx/mkNCHJN.py +0 -613
  161. shancx/wtx/newTypeRadar.py +0 -492
  162. shancx/wtx/test.py +0 -6
  163. shancx/wtx/tlogP.py +0 -129
  164. shancx-1.8.92.dist-info/RECORD +0 -99
  165. /shancx/{Dsalgor → Algo}/dsalgor.py +0 -0
  166. {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/top_level.txt +0 -0
shancx/3D/__init__.py ADDED
@@ -0,0 +1,25 @@
+ import torch
+ import numpy as np
+ import matplotlib.pyplot as plt
+ # from mpl_toolkits.mplot3d import Axes3D
+ from shancx import crDir
+ def plot3DJU(x, path="./3D_test1.png"):
+     data = x.cpu().numpy() if x.is_cuda else x.numpy()
+     data = data[:, ::5, ::5]  # subsample the last two axes to keep the scatter plot tractable
+     x1 = np.arange(data.shape[0])  # x-axis range, set by the first dimension of data
+     y1 = np.arange(data.shape[1])  # y-axis range, set by the second dimension of data
+     z1 = np.arange(data.shape[2])  # z-axis range, set by the third dimension of data
+     x1, y1, z1 = np.meshgrid(x1, y1, z1, indexing='ij')  # 'ij' keeps coordinate order aligned with data.flatten()
+     x1_flat = x1.flatten()
+     y1_flat = y1.flatten()
+     z1_flat = z1.flatten()
+     colors = data.flatten()
+     fig = plt.figure(figsize=(10, 7))
+     ax = fig.add_subplot(111, projection='3d')
+     scatter = ax.scatter(x1_flat, y1_flat, z1_flat, c=colors, cmap='viridis')
+     ax.set_title(f"3D Scatter Plot of Shape {tuple(data.shape)}")
+     plt.colorbar(scatter)
+     outpath = path
+     crDir(outpath)
+     plt.savefig(outpath)
+     plt.close()
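Usage note: "3D" is not a valid Python identifier, so from shancx.3D import plot3DJU will not parse; the module has to be loaded dynamically. A minimal sketch, assuming the wheel installs the subpackage as shancx/3D:

import importlib
import torch

mod = importlib.import_module("shancx.3D")   # importlib sidesteps the invalid "3D" name
vol = torch.rand(40, 64, 40)                 # any 3-D tensor; CUDA tensors are moved to CPU internally
mod.plot3DJU(vol, path="./3D_demo.png")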
shancx/Algo/Class.py ADDED
@@ -0,0 +1,11 @@
+ import torch
+ import copy
+ def classify1h(pre0):
+     pre = copy.deepcopy(pre0)
+     pre[torch.logical_and(pre0 >= 0.1, pre0 <= 2.5)] = 1
+     pre[torch.logical_and(pre0 > 2.5, pre0 <= 8)] = 2
+     pre[torch.logical_and(pre0 > 8, pre0 <= 16)] = 3
+     pre[torch.logical_and(pre0 > 16, pre0 <= 300)] = 4
+     pre[pre0 > 300] = -1
+     pre[torch.isnan(pre0)] = -1
+     return pre
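The thresholds bin hourly precipitation (mm/h, by the look of the ranges) into classes 1-4, flag values above 300 or NaN as -1, and pass anything below 0.1 through unchanged. A minimal sketch of the expected behavior:

import torch

rain = torch.tensor([0.05, 1.0, 5.0, 12.0, 100.0, 500.0, float("nan")])
print(classify1h(rain))
# tensor([ 0.0500,  1.0000,  2.0000,  3.0000,  4.0000, -1.0000, -1.0000])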
shancx/Algo/CudaPrefetcher1.py ADDED
@@ -0,0 +1,112 @@
+
+ from torch.utils.data import Dataset, DataLoader
+ import torch
+ class CUDAPrefetcher1:
+     """
+     Use the CUDA side to accelerate data reading.
+
+     Args:
+         dataloader (DataLoader): Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
+         device (torch.device): Specify running device.
+     """
+
+     def __init__(self, dataloader: DataLoader, device: torch.device):
+         self.batch_data = None
+         self.original_dataloader = dataloader
+         self.device = device
+
+         self.data = iter(dataloader)
+         self.stream = torch.cuda.Stream()
+         self.preload()
+
+     def preload(self):
+         """
+         Load the next batch of data and move it to the specified device asynchronously.
+         """
+         try:
+             self.batch_data = next(self.data)
+         except StopIteration:
+             self.batch_data = None
+             return
+
+         # Asynchronously move the data to the GPU
+         with torch.cuda.stream(self.stream):
+             if isinstance(self.batch_data, dict):
+                 self.batch_data = {
+                     k: v.to(self.device, non_blocking=True)
+                     if torch.is_tensor(v) else v
+                     for k, v in self.batch_data.items()
+                 }
+             elif isinstance(self.batch_data, (list, tuple)):
+                 self.batch_data = [
+                     x.to(self.device, non_blocking=True)
+                     if torch.is_tensor(x) else x
+                     for x in self.batch_data
+                 ]
+             elif torch.is_tensor(self.batch_data):
+                 self.batch_data = self.batch_data.to(self.device, non_blocking=True)
+             else:
+                 raise TypeError(
+                     f"Unsupported data type {type(self.batch_data)}. Ensure the dataloader outputs tensors, dicts, or lists."
+                 )
+
+     def __iter__(self):
+         """
+         Make the object iterable by resetting the dataloader and returning self.
+         """
+         self.reset()
+         return self
+
+     def __next__(self):
+         """
+         Wait for the current stream, return the current batch, and preload the next batch.
+         """
+         if self.batch_data is None:
+             raise StopIteration
+
+         torch.cuda.current_stream().wait_stream(self.stream)
+         batch_data = self.batch_data
+         self.preload()
+         return batch_data
+
+     def reset(self):
+         """
+         Reset the dataloader iterator and preload the first batch.
+         """
+         self.data = iter(self.original_dataloader)
+         self.preload()
+
+     def __len__(self):
+         """
+         Return the length of the dataloader.
+         """
+         return len(self.original_dataloader)
+
+ # train_data_prefetcher = CUDAPrefetcher1(degenerated_train_dataloader, device)
+
+ import torch
+ from torch.utils.data import DataLoader
+ import torch.nn.functional as F
+ def custom_collate_fn(batch):
+     """
+     Custom collate function to handle batches with varying tensor sizes by padding them to the same size.
+     """
+     if isinstance(batch[0], dict):
+         collated = {key: custom_collate_fn([d[key] for d in batch]) for key in batch[0]}
+         return collated
+     elif isinstance(batch[0], torch.Tensor):
+         # Find max dimensions
+         max_height = max(item.shape[1] for item in batch)
+         max_width = max(item.shape[2] for item in batch)
+
+         # Pad each tensor to the same size
+         padded_batch = []
+         for item in batch:
+             padding = (0, max_width - item.shape[2], 0, max_height - item.shape[1])
+             padded_item = F.pad(item, padding, mode="constant", value=0)
+             padded_batch.append(padded_item)
+         return torch.stack(padded_batch)
+     else:
+         return batch
+ # device = torch.device("cuda:0")
+ # paired_test_dataloader = DataLoader(datasetdata, collate_fn=custom_collate_fn)
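A minimal sketch of wiring the pieces together (the TensorDataset is a stand-in; pin_memory=True is what lets the non_blocking copies actually overlap with compute):

import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda:0")
dataset = TensorDataset(torch.rand(100, 1, 256, 256))   # stand-in dataset
# For variable-sized (C, H, W) samples, pass collate_fn=custom_collate_fn here.
loader = DataLoader(dataset, batch_size=8, num_workers=4, pin_memory=True)
prefetcher = CUDAPrefetcher1(loader, device)
for (x,) in prefetcher:                                  # batches arrive already on the GPU
    assert x.device == device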
shancx/Algo/Fake_image.py ADDED
@@ -0,0 +1,24 @@
+ # Load the best model weights
+ import torch
+ import torchvision
+ import matplotlib.pyplot as plt
+ from shancx import crDir
+ def Fakeimage(generator, device):
+     generator.load_state_dict(torch.load('/home/scx/train_Gan/best_generator_weights.pth'))
+     # Put the model in evaluation mode
+     generator.eval()
+     with torch.no_grad():
+         fixed_noise = torch.randn(64, 100, device=device)  # create the noise tensor directly on the GPU
+         fake_images = generator(fixed_noise)
+     # Check the shape of the generated images
+     print(f"Generated image shape: {fake_images.shape}")
+     # Undo the normalization
+     fake_images = (fake_images + 1) / 2  # map pixel values from [-1, 1] to [0, 1]
+     print(f"Image min: {fake_images.min()}, max: {fake_images.max()}")
+     grid = torchvision.utils.make_grid(fake_images, nrow=8, normalize=True)  # normalize=True ensures correct scaling
+     grid_cpu = grid.cpu().detach().numpy()
+     plt.imshow(grid_cpu.transpose(1, 2, 0))  # reorder channels to (H, W, C)
+     plt.axis('off')  # hide the axes
+     outPath = './fake_images/best_weights_fake_images.png'
+     crDir(outPath)
+     plt.savefig(outPath)
+     plt.close()
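The function assumes a generator whose forward maps a (64, 100) noise batch to images in [-1, 1], and a checkpoint at the hard-coded path. A minimal, hedged invocation sketch:

import torch

# `generator` must be the GAN generator used at training time, so that the
# hard-coded state_dict at /home/scx/train_Gan/best_generator_weights.pth matches it.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Fakeimage(generator.to(device), device)   # writes ./fake_images/best_weights_fake_images.png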
shancx/Algo/Hsml.py ADDED
@@ -0,0 +1,391 @@
+
+ import torchvision.models as models
+ from torch import nn
+ import torch
+ from torchvision.models import vgg19
+ # Define VGG Loss
+ class VGGLoss(nn.Module):
+     def __init__(self, weights_path=None, device=None):
+         super().__init__()
+         self.vgg = models.vgg19(pretrained=False).features[:35].eval().to(device)
+         if weights_path:
+             pretrained_weights = torch.load(weights_path)
+             self.vgg.load_state_dict(pretrained_weights, strict=False)
+         self.vgg[0] = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1).to(device)
+         for param in self.vgg.parameters():
+             param.requires_grad = False
+         self.loss = nn.MSELoss()
+     def forward(self, input, target):
+         input_features = self.vgg(input)
+         target_features = self.vgg(target)
+         return self.loss(input_features, target_features)
+ """
+ vgg_loss = VGGLoss(weights_path="/mnt/wtx_weather_forecast/scx/stat/sat/sat2radar/vgg19-dcbb9e9d.pth").to(device)
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from segment_anything import sam_model_registry
+
+ class SAMLoss(nn.Module):
+     def __init__(self, model_type='vit_b', checkpoint_path=None, input_size=1024):
+         """
+         SAM-based perceptual loss with resolution handling
+
+         Args:
+             model_type (str): SAM model type (vit_b, vit_l, vit_h)
+             checkpoint_path (str): Path to SAM checkpoint weights
+             input_size (int): Target input size for SAM (default 1024)
+         """
+         super().__init__()
+         self.input_size = input_size
+         # Initialize SAM model
+         self.sam = sam_model_registry[model_type](checkpoint=None)
+         # Load pretrained weights if provided
+         if checkpoint_path:
+             state_dict = torch.load(checkpoint_path)
+             self.sam.load_state_dict(state_dict)
+         # Use image encoder only and freeze parameters
+         self.image_encoder = self.sam.image_encoder.eval()
+         for param in self.image_encoder.parameters():
+             param.requires_grad = False
+         # Define loss function
+         self.loss = nn.MSELoss()
+         # Normalization parameters
+         self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+         self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+     def preprocess(self, x):
+         """
+         Preprocess input to match SAM requirements:
+         1. Convert to 3-channel if needed
+         2. Normalize using ImageNet stats
+         3. Resize to target size
+         """
+         if x.shape[1] == 1:
+             x = x.expand(-1, 3, -1, -1)  # More memory efficient than repeat
+         # Normalize
+         x = (x - self.mean) / self.std
+         # Resize
+         if x.shape[-2:] != (self.input_size, self.input_size):
+             x = F.interpolate(x, size=(self.input_size, self.input_size),
+                               mode='bilinear', align_corners=False)
+         return x
+     def forward(self, input, target):
+         # Preprocess
+         input = self.preprocess(input)
+         target = self.preprocess(target)
+         # Process in batches if needed
+         batch_size = 4  # Adjust based on your GPU memory
+         input_features = []
+         target_features = []
+         with torch.no_grad():
+             for i in range(0, input.size(0), batch_size):
+                 input_batch = input[i:i+batch_size]
+                 target_batch = target[i:i+batch_size]
+                 input_features.append(self.image_encoder(input_batch))
+                 target_features.append(self.image_encoder(target_batch))
+         return self.loss(torch.cat(input_features), torch.cat(target_features))
+
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from segment_anything import sam_model_registry
+
+ # NOTE: this second SAMLoss definition shadows the fixed-1024 version above.
+ class SAMLoss(nn.Module):
+     def __init__(self, model_type='vit_b', checkpoint_path=None):
+         super().__init__()
+         # Initialize SAM and load pretrained weights
+         self.sam = sam_model_registry[model_type](checkpoint=None)
+         if checkpoint_path:
+             state_dict = torch.load(checkpoint_path)
+             self.sam.load_state_dict(state_dict)
+
+         # Extract the image_encoder and freeze it
+         self.image_encoder = self.sam.image_encoder.eval()
+         for param in self.image_encoder.parameters():
+             param.requires_grad = False
+
+         # Get the patch size (usually 16)
+         self.patch_size = self.image_encoder.patch_embed.proj.kernel_size[0]
+
+         # Keep the original pos_embed for dynamic adjustment
+         self.original_pos_embed = self.image_encoder.pos_embed
+
+         # Normalization parameters
+         self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+         self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+         self.loss = nn.MSELoss()
+
+     def adjust_pos_embed(self, x):
+         """
+         Dynamically resize pos_embed to match the input size
+         Args:
+             x: input tensor (B, C, H, W)
+         """
+         B, _, H, W = x.shape
+         # Resolution of the current feature map (H/patch_size, W/patch_size)
+         h, w = H // self.patch_size, W // self.patch_size
+
+         # Interpolate pos_embed if its size does not match
+         if (h, w) != self.original_pos_embed.shape[1:3]:
+             # Interpolate pos_embed to the target size: (1, h, w, C) -> (1, C, h, w) -> interpolate -> restore shape
+             pos_embed = self.original_pos_embed.permute(0, 3, 1, 2)  # (1, C, H_orig, W_orig)
+             pos_embed = F.interpolate(
+                 pos_embed,
+                 size=(h, w),
+                 mode='bilinear',
+                 align_corners=False
+             ).permute(0, 2, 3, 1)  # (1, h, w, C)
+
+             # Wrap the resized pos_embed in an nn.Parameter
+             self.image_encoder.pos_embed = nn.Parameter(pos_embed, requires_grad=False)
+         else:
+             self.image_encoder.pos_embed = self.original_pos_embed
+     def preprocess(self, x):
+         """Preprocessing: channel expansion + normalization + size adjustment"""
+         if x.shape[1] == 1:
+             x = x.expand(-1, 3, -1, -1)  # single channel -> three channels
+         # Normalize
+         x = (x - self.mean) / self.std
+         # Make sure the size is an integer multiple of patch_size
+         B, C, H, W = x.shape
+         H_new = (H // self.patch_size) * self.patch_size
+         W_new = (W // self.patch_size) * self.patch_size
+         if H != H_new or W != W_new:
+             x = F.interpolate(
+                 x,
+                 size=(H_new, W_new),
+                 mode='bilinear',
+                 align_corners=False
+             )
+         return x
+     def forward(self, input, target):
+         # Preprocess
+         input = self.preprocess(input)
+         target = self.preprocess(target)
+
+         # Dynamically adjust pos_embed
+         self.adjust_pos_embed(input)
+
+         # Compute features
+         with torch.no_grad():
+             input_feat = self.image_encoder(input)
+             target_feat = self.image_encoder(target)
+
+         return self.loss(input_feat, target_feat)
+
+ # saml = SAMLoss(checkpoint_path = "/path/to/sam_vit_b_01ec64.pth.1").to(device)
+
+
+ import torch
+ import torch.nn as nn
+ from torchvision import models
+ class WeightedVGGLoss(nn.Module):
+     def __init__(self, weights_path=None, device=None, apply_weighting=True):
+         super().__init__()
+         self.device = device
+         self.apply_weighting = apply_weighting
+         self.vgg = models.vgg19(pretrained=False).features[:35].eval().to(device)
+         if weights_path:
+             pretrained_weights = torch.load(weights_path)
+             self.vgg.load_state_dict(pretrained_weights, strict=False)
+         self.vgg[0] = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1).to(device)
+         for param in self.vgg.parameters():
+             param.requires_grad = False
+         self.mse_loss = nn.MSELoss()
+     def custom_weight_4d(self, target):
+         EPS = 1e-6
+         segments = [
+             [10, 20, 0.5, 0.1], [20, 30, 0.4, 0.12],
+             [30, 40, 0.3, 0.14], [40, 50, 0.2, 0.16],
+             [50, 60, 0.1, 0.18], [60, 70, 0.05, 0.2]
+         ]
+         weights = torch.full_like(target, 0.1)
+         for start, end, steepness, base_w in segments:
+             mask = (target >= start) & (target < end)
+             x = ((target - start) / (end - start + EPS)).clamp(0, 1)
+             delta_w = (1.0 - base_w) * (0.7 - steepness)
+             seg_weights = base_w + delta_w * (x / (x + steepness*(1 - x) + EPS))
+             weights = torch.where(mask, seg_weights, weights)
+         weights[target >= 70] = 1.0
+         return weights
+     def forward(self, input, target):
+         input_features = self.vgg(input)
+         target_features = self.vgg(target)
+         if self.apply_weighting:
+             weights = self.custom_weight_4d(target_features.detach())
+             loss = torch.mean(weights * (input_features - target_features)**2)
+         else:
+             loss = self.mse_loss(input_features, target_features)
+         return loss
+ # vgg_loss = WeightedVGGLoss(weights_path="/mnt/wtx_weather_forecast/scx/stat/sat/sat2radar/vgg19-dcbb9e9d.pth",device=device).to(device)
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ class WeightedL1Loss(nn.Module):
+     def __init__(self, segment_ranges=None, device=None):
+         super().__init__()
+         self.device = device
+         self.segment_ranges = segment_ranges or [
+             (10, 20, 0.5, 0.1), (20, 30, 0.4, 0.12),
+             (30, 40, 0.3, 0.14), (40, 50, 0.2, 0.16),
+             (50, 60, 0.1, 0.18), (60, 70, 0.05, 0.2)
+         ]
+     def compute_weights(self, target):
+         EPS = 1e-6
+         weights = torch.full_like(target, 0.1)  # default weight 0.1
+         for start, end, steepness, base_w in self.segment_ranges:
+             mask = (target >= start) & (target < end)
+             x = ((target - start) / (end - start + EPS)).clamp(0, 1)
+             delta_w = (1.0 - base_w) * (0.7 - steepness)
+             seg_weights = base_w + delta_w * (x / (x + steepness*(1 - x) + EPS))
+             weights = torch.where(mask, seg_weights, weights)
+         weights[target >= 70] = 1.0
+         return weights
+     def forward(self, output, target):
+         """
+         Non-reduced (reduction='none') computation:
+         1. Compute the element-wise L1 loss
+         2. Build the dynamic weight matrix
+         3. Return the weighted loss tensor (input dimensions preserved)
+         """
+         # L1 loss with dimensions preserved (B,C,H,W)
+         l1_loss = F.l1_loss(output, target, reduction='none')
+         weights = self.compute_weights(target.detach())  # detach to cut gradients through the weights
+         # weighted_loss = torch.mean(weights * l1_loss)
+         weighted_loss = weights * l1_loss
+         return weighted_loss
+ # L1loss = WeightedL1Loss(device= device)
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from segment_anything import sam_model_registry
+ class SAMLoss1(nn.Module):
+     def __init__(self, model_type='vit_b', checkpoint_path=None):
+         super().__init__()
+         self.sam = sam_model_registry[model_type](checkpoint=None)
+         if checkpoint_path:
+             state_dict = torch.load(checkpoint_path)
+             self.sam.load_state_dict(state_dict)
+         self.image_encoder = self.sam.image_encoder.eval()
+         for param in self.image_encoder.parameters():
+             param.requires_grad = False
+         self.patch_size = self.image_encoder.patch_embed.proj.kernel_size[0]
+         self.original_pos_embed = self.image_encoder.pos_embed
+         self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+         self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+         self.loss = nn.MSELoss()
+     def adjust_pos_embed(self, x):
+         """
+         Dynamically resize pos_embed to match the input size
+         Args:
+             x: input tensor (B, C, H, W)
+         """
+         B, _, H, W = x.shape
+         h, w = H // self.patch_size, W // self.patch_size
+         if (h, w) != self.original_pos_embed.shape[1:3]:
+             pos_embed = self.original_pos_embed.permute(0, 3, 1, 2)  # (1, C, H_orig, W_orig)
+             pos_embed = F.interpolate(
+                 pos_embed,
+                 size=(h, w),
+                 mode='bilinear',
+                 align_corners=False
+             ).permute(0, 2, 3, 1)  # (1, h, w, C)
+             self.image_encoder.pos_embed = nn.Parameter(pos_embed, requires_grad=False)
+         else:
+             self.image_encoder.pos_embed = self.original_pos_embed
+     def preprocess(self, x):
+         """Preprocessing: channel expansion + normalization + size adjustment"""
+         if x.shape[1] == 1:
+             x = x.expand(-1, 3, -1, -1)  # single channel -> three channels
+         x = (x - self.mean) / self.std
+         B, C, H, W = x.shape
+         H_new = (H // self.patch_size) * self.patch_size
+         W_new = (W // self.patch_size) * self.patch_size
+         if H != H_new or W != W_new:
+             x = F.interpolate(
+                 x,
+                 size=(H_new, W_new),
+                 mode='bilinear',
+                 align_corners=False
+             )
+         return x
+     def forward(self, input, target):
+         input = self.preprocess(input)
+         target = self.preprocess(target)
+         self.adjust_pos_embed(input)
+         with torch.no_grad():
+             input_feat = self.image_encoder(input)
+             target_feat = self.image_encoder(target)
+         return self.loss(input_feat, target_feat)
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from segment_anything import sam_model_registry
+ class SAMLoss2(nn.Module):
+     def __init__(self, model_type='vit_b', checkpoint_path=None):
+         super().__init__()
+         self.sam = sam_model_registry[model_type](checkpoint=None)
+         if checkpoint_path:
+             state_dict = torch.load(checkpoint_path)
+             self.sam.load_state_dict(state_dict)
+         self.image_encoder = self.sam.image_encoder.eval()
+         for param in self.image_encoder.parameters():
+             param.requires_grad = False
+         self.patch_size = self.image_encoder.patch_embed.proj.kernel_size[0]
+         self.original_pos_embed = self.image_encoder.pos_embed
+         self.loss = nn.MSELoss()
+     def adjust_pos_embed(self, x):
+         """
+         Dynamically resize pos_embed to match the input size
+         Args:
+             x: input tensor (B, C, H, W)
+         """
+         B, _, H, W = x.shape
+         h, w = H // self.patch_size, W // self.patch_size
+         if (h, w) != self.original_pos_embed.shape[1:3]:
+             pos_embed = self.original_pos_embed.permute(0, 3, 1, 2)  # (1, C, H_orig, W_orig)
+             pos_embed = F.interpolate(
+                 pos_embed,
+                 size=(h, w),
+                 mode='bilinear',
+                 align_corners=False
+             ).permute(0, 2, 3, 1)  # (1, h, w, C)
+             self.image_encoder.pos_embed = nn.Parameter(pos_embed, requires_grad=False)
+         else:
+             self.image_encoder.pos_embed = self.original_pos_embed
+     def preprocess(self, x):
+         """Preprocessing: channel expansion + size adjustment only (normalization skipped)"""
+         if x.shape[1] == 1:
+             x = x.expand(-1, 3, -1, -1)  # single channel -> three channels
+         B, C, H, W = x.shape
+         H_new = (H // self.patch_size) * self.patch_size
+         W_new = (W // self.patch_size) * self.patch_size
+         if H != H_new or W != W_new:
+             x = F.interpolate(
+                 x,
+                 size=(H_new, W_new),
+                 mode='bilinear',
+                 align_corners=False
+             )
+         return x
+     def forward(self, input, target):
+         # Preprocess (neither input nor target is normalized)
+         input = self.preprocess(input)
+         target = self.preprocess(target)
+         # Dynamically adjust pos_embed
+         self.adjust_pos_embed(input)
+         # Compute features
+         with torch.no_grad():
+             input_feat = self.image_encoder(input)
+             target_feat = self.image_encoder(target)
+         return self.loss(input_feat, target_feat)
+ # saml = SAMLoss(checkpoint_path = "/mnt/wtx_weather_forecast/scx/stat/sat/sat2radar/sam_vit_b_01ec64.pth.1").to(device)
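A minimal training-step sketch combining the losses above (model, batch_input, batch_target are hypothetical; the weights path is the one in the file's own commented example). WeightedL1Loss returns a non-reduced tensor, so it has to be reduced before backward():

device = torch.device("cuda:0")
vgg_loss = WeightedVGGLoss(
    weights_path="/mnt/wtx_weather_forecast/scx/stat/sat/sat2radar/vgg19-dcbb9e9d.pth",
    device=device).to(device)
l1_loss = WeightedL1Loss(device=device)

pred = model(batch_input)                              # hypothetical (B, 1, H, W) prediction
loss = vgg_loss(pred, batch_target) + l1_loss(pred, batch_target).mean()
loss.backward()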
shancx/Algo/L2Loss.py ADDED
@@ -0,0 +1,10 @@
+ import torch
+ def L2loss(model, loss, lambda_reg=0.01):
+     l2_reg = 0.0
+     for param in model.parameters():
+         l2_reg += torch.norm(param)**2  # sum of squared parameter norms
+     print(l2_reg)  # debug print of the penalty magnitude
+     loss += lambda_reg * l2_reg
+     return loss
+
+
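A minimal usage sketch (the model and batch below are hypothetical stand-ins). In practice the same L2 penalty is usually applied more cheaply through the optimizer's weight_decay argument:

import torch
import torch.nn as nn

model = nn.Linear(8, 1)                      # hypothetical model
x, y = torch.rand(4, 8), torch.rand(4, 1)    # hypothetical batch
base_loss = nn.MSELoss()(model(x), y)
total_loss = L2loss(model, base_loss, lambda_reg=0.01)
total_loss.backward()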