scxfxt 1.0.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103) hide show
  1. scxfxt-1.0.0.0/LICENSE +21 -0
  2. scxfxt-1.0.0.0/MANIFEST.in +2 -0
  3. scxfxt-1.0.0.0/PKG-INFO +26 -0
  4. scxfxt-1.0.0.0/README.md +9 -0
  5. scxfxt-1.0.0.0/scxfxt/3D/__init__.py +25 -0
  6. scxfxt-1.0.0.0/scxfxt/Algo/Class.py +11 -0
  7. scxfxt-1.0.0.0/scxfxt/Algo/CudaPrefetcher1.py +112 -0
  8. scxfxt-1.0.0.0/scxfxt/Algo/Fake_image.py +24 -0
  9. scxfxt-1.0.0.0/scxfxt/Algo/Hsml.py +391 -0
  10. scxfxt-1.0.0.0/scxfxt/Algo/L2Loss.py +10 -0
  11. scxfxt-1.0.0.0/scxfxt/Algo/MetricTracker.py +132 -0
  12. scxfxt-1.0.0.0/scxfxt/Algo/Normalize.py +66 -0
  13. scxfxt-1.0.0.0/scxfxt/Algo/OptimizerWScheduler.py +38 -0
  14. scxfxt-1.0.0.0/scxfxt/Algo/Rmageresize.py +79 -0
  15. scxfxt-1.0.0.0/scxfxt/Algo/Savemodel.py +33 -0
  16. scxfxt-1.0.0.0/scxfxt/Algo/SmoothL1_losses.py +27 -0
  17. scxfxt-1.0.0.0/scxfxt/Algo/Tqdm.py +62 -0
  18. scxfxt-1.0.0.0/scxfxt/Algo/__init__.py +121 -0
  19. scxfxt-1.0.0.0/scxfxt/Algo/checknan.py +28 -0
  20. scxfxt-1.0.0.0/scxfxt/Algo/dsalgor.py +12 -0
  21. scxfxt-1.0.0.0/scxfxt/Algo/iouJU.py +83 -0
  22. scxfxt-1.0.0.0/scxfxt/Algo/mask.py +25 -0
  23. scxfxt-1.0.0.0/scxfxt/Algo/psnr.py +9 -0
  24. scxfxt-1.0.0.0/scxfxt/Algo/ssim.py +70 -0
  25. scxfxt-1.0.0.0/scxfxt/Algo/structural_similarity.py +308 -0
  26. scxfxt-1.0.0.0/scxfxt/Algo/tool.py +704 -0
  27. scxfxt-1.0.0.0/scxfxt/Calmetrics/__init__.py +97 -0
  28. scxfxt-1.0.0.0/scxfxt/Calmetrics/calmetrics.py +14 -0
  29. scxfxt-1.0.0.0/scxfxt/Calmetrics/calmetricsmatrixLib.py +147 -0
  30. scxfxt-1.0.0.0/scxfxt/Calmetrics/rmseR2score.py +35 -0
  31. scxfxt-1.0.0.0/scxfxt/Clip/__init__.py +50 -0
  32. scxfxt-1.0.0.0/scxfxt/Cmd.py +126 -0
  33. scxfxt-1.0.0.0/scxfxt/Config_.py +26 -0
  34. scxfxt-1.0.0.0/scxfxt/Df/DataFrame.py +18 -0
  35. scxfxt-1.0.0.0/scxfxt/Df/__init__.py +34 -0
  36. scxfxt-1.0.0.0/scxfxt/Df/tool.py +0 -0
  37. scxfxt-1.0.0.0/scxfxt/Diffm/Psamples.py +18 -0
  38. scxfxt-1.0.0.0/scxfxt/Diffm/__init__.py +0 -0
  39. scxfxt-1.0.0.0/scxfxt/Diffm/test.py +207 -0
  40. scxfxt-1.0.0.0/scxfxt/Doc/__init__.py +214 -0
  41. scxfxt-1.0.0.0/scxfxt/E/__init__.py +178 -0
  42. scxfxt-1.0.0.0/scxfxt/Fillmiss/__init__.py +0 -0
  43. scxfxt-1.0.0.0/scxfxt/Fillmiss/imgidwJU.py +46 -0
  44. scxfxt-1.0.0.0/scxfxt/Fillmiss/imgidwLatLonJU.py +82 -0
  45. scxfxt-1.0.0.0/scxfxt/Gpu/__init__.py +55 -0
  46. scxfxt-1.0.0.0/scxfxt/H/__init__.py +0 -0
  47. scxfxt-1.0.0.0/scxfxt/H/optSchler.py +19 -0
  48. scxfxt-1.0.0.0/scxfxt/H/simple.py +8 -0
  49. scxfxt-1.0.0.0/scxfxt/H9/__init__.py +126 -0
  50. scxfxt-1.0.0.0/scxfxt/H9/ahi_read_hsd.py +877 -0
  51. scxfxt-1.0.0.0/scxfxt/H9/ahisearchtable.py +298 -0
  52. scxfxt-1.0.0.0/scxfxt/H9/geometry.py +2439 -0
  53. scxfxt-1.0.0.0/scxfxt/Hug/__init__.py +78 -0
  54. scxfxt-1.0.0.0/scxfxt/Inst.py +22 -0
  55. scxfxt-1.0.0.0/scxfxt/Lib.py +31 -0
  56. scxfxt-1.0.0.0/scxfxt/Mos/__init__.py +37 -0
  57. scxfxt-1.0.0.0/scxfxt/NN/__init__.py +256 -0
  58. scxfxt-1.0.0.0/scxfxt/NN/chainMul.py +175 -0
  59. scxfxt-1.0.0.0/scxfxt/Path1.py +161 -0
  60. scxfxt-1.0.0.0/scxfxt/Plot/GlobMap.py +276 -0
  61. scxfxt-1.0.0.0/scxfxt/Plot/Gray2RGB.py +86 -0
  62. scxfxt-1.0.0.0/scxfxt/Plot/__init__.py +498 -0
  63. scxfxt-1.0.0.0/scxfxt/Plot/draw_day_CR_PNG.py +183 -0
  64. scxfxt-1.0.0.0/scxfxt/Plot/exam.py +116 -0
  65. scxfxt-1.0.0.0/scxfxt/Plot/plotGlobal.py +449 -0
  66. scxfxt-1.0.0.0/scxfxt/Plot/radarNmc.py +31 -0
  67. scxfxt-1.0.0.0/scxfxt/Plot/single_china_map.py +45 -0
  68. scxfxt-1.0.0.0/scxfxt/Point.py +46 -0
  69. scxfxt-1.0.0.0/scxfxt/QC.py +223 -0
  70. scxfxt-1.0.0.0/scxfxt/RdPzl/__init__.py +32 -0
  71. scxfxt-1.0.0.0/scxfxt/Read.py +148 -0
  72. scxfxt-1.0.0.0/scxfxt/Resize.py +79 -0
  73. scxfxt-1.0.0.0/scxfxt/SN/__init__.py +77 -0
  74. scxfxt-1.0.0.0/scxfxt/Time/GetTime.py +38 -0
  75. scxfxt-1.0.0.0/scxfxt/Time/__init__.py +114 -0
  76. scxfxt-1.0.0.0/scxfxt/Time/timeCycle.py +302 -0
  77. scxfxt-1.0.0.0/scxfxt/Time/tool.py +0 -0
  78. scxfxt-1.0.0.0/scxfxt/Train/__init__.py +74 -0
  79. scxfxt-1.0.0.0/scxfxt/Train/makelist.py +187 -0
  80. scxfxt-1.0.0.0/scxfxt/Train/multiGpu.py +27 -0
  81. scxfxt-1.0.0.0/scxfxt/Train/prepare.py +161 -0
  82. scxfxt-1.0.0.0/scxfxt/Train/renet50.py +157 -0
  83. scxfxt-1.0.0.0/scxfxt/ZR.py +12 -0
  84. scxfxt-1.0.0.0/scxfxt/__init__.py +334 -0
  85. scxfxt-1.0.0.0/scxfxt/args.py +27 -0
  86. scxfxt-1.0.0.0/scxfxt/bak.py +768 -0
  87. scxfxt-1.0.0.0/scxfxt/cmp.py +46 -0
  88. scxfxt-1.0.0.0/scxfxt/df2database.py +89 -0
  89. scxfxt-1.0.0.0/scxfxt/geosProj.py +80 -0
  90. scxfxt-1.0.0.0/scxfxt/getResponse.py +25 -0
  91. scxfxt-1.0.0.0/scxfxt/info.py +38 -0
  92. scxfxt-1.0.0.0/scxfxt/netdfJU.py +231 -0
  93. scxfxt-1.0.0.0/scxfxt/npz.py +29 -0
  94. scxfxt-1.0.0.0/scxfxt/sendM.py +58 -0
  95. scxfxt-1.0.0.0/scxfxt/tensBoard/__init__.py +28 -0
  96. scxfxt-1.0.0.0/scxfxt/wait.py +246 -0
  97. scxfxt-1.0.0.0/scxfxt.egg-info/PKG-INFO +26 -0
  98. scxfxt-1.0.0.0/scxfxt.egg-info/SOURCES.txt +102 -0
  99. scxfxt-1.0.0.0/scxfxt.egg-info/dependency_links.txt +1 -0
  100. scxfxt-1.0.0.0/scxfxt.egg-info/requires.txt +3 -0
  101. scxfxt-1.0.0.0/scxfxt.egg-info/top_level.txt +1 -0
  102. scxfxt-1.0.0.0/setup.cfg +4 -0
  103. scxfxt-1.0.0.0/setup.py +44 -0
scxfxt-1.0.0.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 [scxfxt]
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,2 @@
1
+ include LICENSE
2
+ include README.md
@@ -0,0 +1,26 @@
1
+ Metadata-Version: 2.1
2
+ Name: scxfxt
3
+ Version: 1.0.0.0
4
+ Summary: scxfxt
5
+ Home-page: https://gitee.com/scxfxt
6
+ Author: scxfxt_scxfxt
7
+ Author-email: shanhe12@163.com
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Operating System :: OS Independent
11
+ Requires-Python: >=3.6
12
+ Description-Content-Type: text/markdown
13
+ License-File: LICENSE
14
+ Requires-Dist: tqdm
15
+ Requires-Dist: pandas
16
+ Requires-Dist: matplotlib
17
+
18
+ # Welcome to scxfxt
19
+
20
+ A simple Python package that provides a timer decorator to measure the execution time of functions.
21
+
22
+ ## Installation
23
+
24
+ pip install scxfxt
25
+
26
+ pip install scxfxt -i https://mirrors.cloud.tencent.com/pypi/simple
@@ -0,0 +1,9 @@
1
+ # Welcome to scxfxt
2
+
3
+ A simple Python package that provides a timer decorator to measure the execution time of functions.
4
+
5
+ ## Installation
6
+
7
+ pip install scxfxt
8
+
9
+ pip install scxfxt -i https://mirrors.cloud.tencent.com/pypi/simple
@@ -0,0 +1,25 @@
1
+ import torch
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ # from mpl_toolkits.mplot3d import Axes3D
5
+ from shancxn import crDir
6
def plot3DJU(x, path="./3D_test1.png", step=5):
    """Save a 3-D scatter plot of a rank-3 tensor, coloured by value.

    The two trailing axes are subsampled by ``step`` to keep the number of
    points manageable; the leading axis is kept at full resolution.

    Args:
        x: rank-3 ``torch.Tensor`` (CPU or CUDA; gradients are ignored).
        path: output image path; passed to ``crDir`` before saving
            (presumably creates the parent directory — confirm in shancxn).
        step: subsampling stride for axes 1 and 2 (default 5, as before).
    """
    # detach() so tensors that require grad can still be converted to NumPy.
    data = x.detach().cpu().numpy()
    full_shape = tuple(data.shape)
    data = data[:, ::step, ::step]
    # indexing="ij" keeps the coordinate grids in the same (dim0, dim1, dim2)
    # order as ``data``; the default "xy" swaps the first two axes, which
    # scrambles the colour-to-position mapping whenever dim0 != dim1.
    g0, g1, g2 = np.meshgrid(
        np.arange(data.shape[0]),
        np.arange(data.shape[1]),
        np.arange(data.shape[2]),
        indexing="ij",
    )
    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection='3d')
    scatter = ax.scatter(g0.ravel(), g1.ravel(), g2.ravel(),
                         c=data.ravel(), cmap='viridis')
    # Report the real input shape instead of a hard-coded one.
    ax.set_title(f"3D Scatter Plot of Shape {full_shape}")
    plt.colorbar(scatter)
    crDir(path)
    plt.savefig(path)
    plt.close(fig)
@@ -0,0 +1,11 @@
1
+ import torch
2
+ import copy
3
def classify1h(pre0):
    """Replace the values of ``pre0`` by band labels, on a copy.

    Bands (evaluated on the original values):
        [0.1, 2.5] -> 1, (2.5, 8] -> 2, (8, 16] -> 3, (16, 300] -> 4,
        > 300 -> -1, NaN -> -1.
    Values matching no band (e.g. below 0.1 or negative) are left unchanged.
    Presumably hourly precipitation amounts (name "1h") — confirm with caller.

    Args:
        pre0: ``torch.Tensor``; not modified.

    Returns:
        A deep copy of ``pre0`` with the labels written in.
    """
    # Every mask is computed on the untouched input, so the bands cannot
    # interfere with each other regardless of assignment order.
    bands = (
        (1, (pre0 >= 0.1) & (pre0 <= 2.5)),
        (2, (pre0 > 2.5) & (pre0 <= 8)),
        (3, (pre0 > 8) & (pre0 <= 16)),
        (4, (pre0 > 16) & (pre0 <= 300)),
        (-1, pre0 > 300),
        (-1, torch.isnan(pre0)),
    )
    labelled = copy.deepcopy(pre0)
    for label, selector in bands:
        labelled[selector] = label
    return labelled
@@ -0,0 +1,112 @@
1
+
2
+ from torch.utils.data import Dataset, DataLoader
3
+ import torch
4
class CUDAPrefetcher1:
    """
    Use the CUDA side to accelerate data reading.

    While the caller works on the batch returned by ``__next__``, the
    following batch is already being copied to ``device`` on a dedicated
    CUDA stream (``self.stream``).

    Args:
        dataloader (DataLoader): Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
        device (torch.device): Specify running device.

    Notes:
        Batches may be a tensor, a dict, or a list/tuple. Top-level tensors
        in dicts/lists are moved; other values pass through unchanged; tuples
        come back as lists; nested containers are not traversed. Any other
        batch type raises ``TypeError``.
        Per the PyTorch docs, ``non_blocking`` host-to-device copies only
        overlap with compute when the source is pinned memory
        (e.g. ``DataLoader(pin_memory=True)``).
    """

    def __init__(self, dataloader: DataLoader, device: torch.device):
        self.batch_data = None
        self.original_dataloader = dataloader
        self.device = device

        self.data = iter(dataloader)
        self.stream = torch.cuda.Stream()
        self.preload()

    @staticmethod
    def _iter_tensors(batch):
        """Return the top-level tensors of a batch (dict values, list items, or the batch itself)."""
        if isinstance(batch, dict):
            values = batch.values()
        elif isinstance(batch, (list, tuple)):
            values = batch
        else:
            values = (batch,)
        return [v for v in values if torch.is_tensor(v)]

    def preload(self):
        """
        Load the next batch of data and move it to the specified device asynchronously.

        Sets ``self.batch_data`` to ``None`` once the dataloader is exhausted.
        """
        try:
            self.batch_data = next(self.data)
        except StopIteration:
            self.batch_data = None
            return

        # Issue the copies on the side stream so they overlap with work
        # queued on the current stream.
        with torch.cuda.stream(self.stream):
            if isinstance(self.batch_data, dict):
                self.batch_data = {
                    k: v.to(self.device, non_blocking=True)
                    if torch.is_tensor(v) else v
                    for k, v in self.batch_data.items()
                }
            elif isinstance(self.batch_data, (list, tuple)):
                self.batch_data = [
                    x.to(self.device, non_blocking=True)
                    if torch.is_tensor(x) else x
                    for x in self.batch_data
                ]
            elif torch.is_tensor(self.batch_data):
                self.batch_data = self.batch_data.to(self.device, non_blocking=True)
            else:
                raise TypeError(
                    f"Unsupported data type {type(self.batch_data)}. Ensure the dataloader outputs tensors, dicts, or lists."
                )

    def __iter__(self):
        """
        Make the object iterable by resetting the dataloader and returning self.
        """
        self.reset()
        return self

    def __next__(self):
        """
        Wait for the current stream, return the current batch, and preload the next batch.
        """
        if self.batch_data is None:
            raise StopIteration

        current = torch.cuda.current_stream()
        current.wait_stream(self.stream)
        batch_data = self.batch_data
        # These tensors were allocated on ``self.stream`` but are consumed on
        # the current stream. record_stream tells the caching allocator about
        # that use, so their memory is not recycled for the next preload on
        # the side stream while kernels reading them are still queued.
        for tensor in self._iter_tensors(batch_data):
            if tensor.is_cuda:
                tensor.record_stream(current)
        self.preload()
        return batch_data

    def reset(self):
        """
        Reset the dataloader iterator and preload the first batch.
        """
        self.data = iter(self.original_dataloader)
        self.preload()

    def __len__(self):
        """
        Return the length of the dataloader.
        """
        return len(self.original_dataloader)
84
+
85
+ # train_data_prefetcher = CUDAPrefetcher1(degenerated_train_dataloader, device)
86
+
87
+ import torch
88
+ from torch.utils.data import DataLoader
89
+ import torch.nn.functional as F
90
def custom_collate_fn(batch):
    """
    Custom collate function to handle batches with varying tensor sizes by padding them to the same size.

    - dict samples: collated key by key (keys taken from the first sample).
    - tensor samples: zero-padded at the end of their trailing dims (the last
      two, e.g. H and W of (C, H, W); the only dim for 1-D tensors) up to the
      batch maximum, then stacked along a new leading batch dim. 0-d tensors
      are stacked directly. Leading dims (e.g. channels) must already match.
    - anything else: returned unchanged as the list of samples.
    - an empty batch is returned unchanged.
    """
    if not batch:
        return batch
    first = batch[0]
    if isinstance(first, dict):
        return {key: custom_collate_fn([d[key] for d in batch]) for key in first}
    if isinstance(first, torch.Tensor):
        if first.dim() == 0:
            return torch.stack(batch)
        n_pad = min(first.dim(), 2)
        # targets[k-1] is the maximum size of dim -k across the batch; F.pad
        # takes (left, right) pairs starting from the last dim, in this order.
        targets = [max(item.shape[-k] for item in batch) for k in range(1, n_pad + 1)]
        padded_batch = []
        for item in batch:
            padding = []
            for k, target in enumerate(targets, start=1):
                padding += [0, target - item.shape[-k]]
            padded_batch.append(F.pad(item, tuple(padding), mode="constant", value=0))
        return torch.stack(padded_batch)
    return batch
111
+ # device = torch.device("cuda:0")
112
+ # paired_test_dataloader = DataLoader(datasetdata,device)
@@ -0,0 +1,24 @@
1
+ # 加载最佳模型权重
2
+ import torch
3
+ import matplotlib.pyplot as plt
4
+ from shancxn import crDir
5
def Fakeimage(generator, device, make_grid=None,
              weights_path='/home/scx/train_Gan/best_generator_weights.pth',
              out_path='./fake_images/best_weights_fake_images.png',
              num_images=64, latent_dim=100):
    """Load the best generator weights, sample images and save them as one grid image.

    Args:
        generator: module called on ``(num_images, latent_dim)`` Gaussian noise;
            its state dict is loaded from ``weights_path``. Its output is
            assumed to lie in [-1, 1] — confirm for your generator.
        device: device the noise is created on and the checkpoint is mapped to.
        make_grid: callable with ``torchvision.utils.make_grid``'s signature;
            defaults to ``torchvision.utils.make_grid``.
        weights_path: generator checkpoint to load.
        out_path: image destination, passed to ``crDir`` before saving
            (presumably creates the parent directory).
        num_images: number of samples (the grid has 8 per row).
        latent_dim: length of each noise vector.
    """
    if make_grid is None:
        # Imported lazily: torchvision is not among this module's imports.
        from torchvision.utils import make_grid
    # Load the best model weights.
    generator.load_state_dict(torch.load(weights_path, map_location=device))
    # Switch the model to evaluation mode.
    generator.eval()
    with torch.no_grad():
        # Create the noise directly on the target device.
        fixed_noise = torch.randn(num_images, latent_dim, device=device)
        fake_images = generator(fixed_noise)
        # Report the generated image shape.
        print(f"生成的图像形状: {fake_images.shape}")
        # De-normalize: map pixel values from [-1, 1] to [0, 1].
        fake_images = (fake_images + 1) / 2
        print(f"图像最小值: {fake_images.min()}, 最大值: {fake_images.max()}")
        # normalize=True rescales the grid to [0, 1] for correct display.
        grid = make_grid(fake_images, nrow=8, normalize=True)
        grid_cpu = grid.cpu().detach().numpy()
        # Reorder channels to (H, W, C) for imshow.
        plt.imshow(grid_cpu.transpose(1, 2, 0))
        plt.axis('off')
        crDir(out_path)
        plt.savefig(out_path)
        plt.close()
@@ -0,0 +1,391 @@
1
+
2
+ import torchvision.models as models
3
+ from torch import nn
4
+ import torch
5
+ from torchvision.models import vgg19
6
+ # Define VGG Loss
7
class VGGLoss(nn.Module):
    """Perceptual loss: MSE between frozen VGG19 conv features of input and target.

    Uses ``vgg19().features[:35]``. The first conv is replaced by a
    1-input-channel conv so single-channel ``(B, 1, H, W)`` tensors can be fed
    directly. All VGG parameters are frozen, but gradients still flow to the
    inputs.

    Args:
        weights_path: optional path to a torchvision VGG19 state dict
            (e.g. ``vgg19-dcbb9e9d.pth``), whose keys look like
            ``features.<idx>.weight``.
        device: device the feature extractor is moved to.
    """

    def __init__(self, weights_path=None, device=None):
        super().__init__()
        self.vgg = models.vgg19(pretrained=False).features[:35].eval().to(device)
        if weights_path:
            pretrained_weights = torch.load(weights_path, map_location="cpu")
            matched = self._match_vgg_keys(pretrained_weights, self.vgg.state_dict().keys())
            self.vgg.load_state_dict(matched, strict=False)
        # Collapse the (pretrained) RGB kernels into a grayscale kernel instead
        # of a fresh random conv that would then be frozen.
        self.vgg[0] = self._grayscale_conv(self.vgg[0]).to(device)
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.loss = nn.MSELoss()

    @staticmethod
    def _match_vgg_keys(state_dict, own_keys):
        """Map full-VGG checkpoint keys onto the ``features`` slice's keys.

        ``self.vgg`` is the ``features`` Sequential itself, so its keys are
        ``"<idx>.weight"`` while torchvision checkpoints use
        ``"features.<idx>.weight"``. Without stripping that prefix,
        ``strict=False`` silently loads nothing. Keys that do not belong to
        the slice (classifier, later layers) are dropped.
        """
        prefix = "features."
        own_keys = set(own_keys)
        renamed = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
        return {k: v for k, v in renamed.items() if k in own_keys}

    @staticmethod
    def _grayscale_conv(rgb_conv):
        """Build a 1-input-channel conv equivalent to ``rgb_conv`` on a replicated gray image."""
        gray = nn.Conv2d(1, rgb_conv.out_channels, kernel_size=rgb_conv.kernel_size,
                         stride=rgb_conv.stride, padding=rgb_conv.padding)
        with torch.no_grad():
            gray.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
            gray.bias.copy_(rgb_conv.bias)
        return gray

    def forward(self, input, target):
        input_features = self.vgg(input)
        target_features = self.vgg(target)
        return self.loss(input_features, target_features)
22
+ """
23
+ vgg_loss = VGGLoss(weights_path="/mnt/wtx_weather_forecast/scx/stat/sat/sat2radar/vgg19-dcbb9e9d.pth").to(device)
24
+ """
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ from segment_anything import sam_model_registry
30
+
31
class SAMLoss(nn.Module):
    """Perceptual loss: MSE between frozen SAM image-encoder features of input and target.

    Inputs are expanded to 3 channels if single-channel, normalized with
    ImageNet mean/std, and bilinearly resized to ``input_size x input_size``.
    The encoder weights are frozen, but gradients flow back to ``input`` so
    the loss can train whatever produced it.
    """

    def __init__(self, model_type='vit_b', checkpoint_path=None, input_size=1024):
        """
        SAM-based perceptual loss with resolution handling

        Args:
            model_type (str): SAM model type (vit_b, vit_l, vit_h)
            checkpoint_path (str): Path to SAM checkpoint weights
            input_size (int): Target input size for SAM (default 1024)
        """
        super().__init__()
        self.input_size = input_size
        # Build the SAM architecture without weights; load them below if given.
        self.sam = sam_model_registry[model_type](checkpoint=None)
        if checkpoint_path:
            # map_location="cpu" lets GPU-saved checkpoints load anywhere;
            # load_state_dict copies into the module's own parameters.
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            self.sam.load_state_dict(state_dict)
        # Use the image encoder only, with frozen parameters.
        self.image_encoder = self.sam.image_encoder.eval()
        for param in self.image_encoder.parameters():
            param.requires_grad = False
        self.loss = nn.MSELoss()
        # ImageNet normalization statistics.
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def preprocess(self, x):
        """
        Preprocess input to match SAM requirements:
        1. Convert to 3-channel if needed
        2. Normalize using ImageNet stats
        3. Resize to target size
        """
        if x.shape[1] == 1:
            x = x.expand(-1, 3, -1, -1)  # a view, cheaper than repeat
        x = (x - self.mean) / self.std
        if x.shape[-2:] != (self.input_size, self.input_size):
            x = F.interpolate(x, size=(self.input_size, self.input_size),
                              mode='bilinear', align_corners=False)
        return x

    def forward(self, input, target):
        input = self.preprocess(input)
        target = self.preprocess(target)
        # Encode in chunks to bound peak memory; adjust for your GPU.
        batch_size = 4
        input_features = []
        target_features = []
        for i in range(0, input.size(0), batch_size):
            # Input features keep their autograd graph: wrapping them in
            # no_grad would make the loss constant w.r.t. the input.
            input_features.append(self.image_encoder(input[i:i + batch_size]))
            # Targets are fixed references; no graph is needed for them.
            with torch.no_grad():
                target_features.append(self.image_encoder(target[i:i + batch_size]))
        return self.loss(torch.cat(input_features), torch.cat(target_features))
90
+
91
+
92
+ import torch
93
+ import torch.nn as nn
94
+ import torch.nn.functional as F
95
+ from segment_anything import sam_model_registry
96
+
97
+
98
class SAMLoss(nn.Module):
    """Perceptual loss: MSE between frozen SAM image-encoder features of input and target.

    Inputs are expanded to 3 channels if single-channel, normalized with
    ImageNet mean/std, and resized down to the nearest multiple of the
    encoder's patch size. The encoder's absolute position embedding is
    interpolated to match the resulting token grid. Encoder weights are
    frozen, but gradients flow back to ``input``.

    Args:
        model_type: key into ``segment_anything.sam_model_registry``.
        checkpoint_path: optional path to a SAM state dict.
    """

    def __init__(self, model_type='vit_b', checkpoint_path=None):
        super().__init__()
        # Build SAM without weights, then load the pretrained ones if given.
        self.sam = sam_model_registry[model_type](checkpoint=None)
        if checkpoint_path:
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            self.sam.load_state_dict(state_dict)

        # Keep only the image encoder and freeze it.
        self.image_encoder = self.sam.image_encoder.eval()
        for param in self.image_encoder.parameters():
            param.requires_grad = False

        # Patch size of the patch-embedding conv (kernel size).
        self.patch_size = self.image_encoder.patch_embed.proj.kernel_size[0]

        # Keep the original pos_embed so every resize starts from it.
        self.original_pos_embed = self.image_encoder.pos_embed

        # ImageNet normalization statistics.
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        self.loss = nn.MSELoss()

    def adjust_pos_embed(self, x):
        """
        Resize the encoder's pos_embed to match the token grid of ``x``.

        pos_embed is laid out as (1, H_tokens, W_tokens, C), as the permutes
        below assume.

        Args:
            x: input tensor (B, C, H, W)
        """
        _, _, H, W = x.shape
        # Token grid resolution: (H / patch_size, W / patch_size).
        h, w = H // self.patch_size, W // self.patch_size

        if (h, w) != tuple(self.original_pos_embed.shape[1:3]):
            # (1, h0, w0, C) -> (1, C, h0, w0) -> interpolate -> (1, h, w, C)
            pos_embed = self.original_pos_embed.permute(0, 3, 1, 2)
            pos_embed = F.interpolate(
                pos_embed,
                size=(h, w),
                mode='bilinear',
                align_corners=False
            ).permute(0, 2, 3, 1)
            self.image_encoder.pos_embed = nn.Parameter(pos_embed, requires_grad=False)
        else:
            self.image_encoder.pos_embed = self.original_pos_embed

    def preprocess(self, x):
        """Expand to 3 channels, normalize, and resize to a multiple of patch_size."""
        if x.shape[1] == 1:
            x = x.expand(-1, 3, -1, -1)  # single channel -> three channels
        x = (x - self.mean) / self.std
        H, W = x.shape[-2:]
        H_new = (H // self.patch_size) * self.patch_size
        W_new = (W // self.patch_size) * self.patch_size
        if H != H_new or W != W_new:
            x = F.interpolate(
                x,
                size=(H_new, W_new),
                mode='bilinear',
                align_corners=False
            )
        return x

    def forward(self, input, target):
        input = self.preprocess(input)
        target = self.preprocess(target)

        self.adjust_pos_embed(input)

        # Input features keep their autograd graph: wrapping them in no_grad
        # would make the loss constant w.r.t. the input.
        input_feat = self.image_encoder(input)
        with torch.no_grad():
            target_feat = self.image_encoder(target)

        return self.loss(input_feat, target_feat)
181
+
182
+ # saml = SAMLoss(checkpoint_path = "/path/to/sam_vit_b_01ec64.pth.1").to(device)
183
+
184
+
185
+ import torch
186
+ import torch.nn as nn
187
+ from torchvision import models
188
class WeightedVGGLoss(nn.Module):
    """VGG19 perceptual loss with optional value-dependent weighting of feature errors.

    Uses ``vgg19().features[:35]``. The first conv is replaced by a
    1-input-channel conv (derived from the RGB kernel), and all VGG
    parameters are frozen. With ``apply_weighting`` the squared feature error
    is scaled by ``custom_weight_4d(target_features)``; otherwise plain MSE is
    used.

    Args:
        weights_path: optional path to a torchvision VGG19 state dict
            (keys like ``features.<idx>.weight``).
        device: device the feature extractor is moved to.
        apply_weighting: enable the weighted squared error.
    """

    def __init__(self, weights_path=None, device=None, apply_weighting=True):
        super().__init__()
        self.device = device
        self.apply_weighting = apply_weighting
        self.vgg = models.vgg19(pretrained=False).features[:35].eval().to(device)
        if weights_path:
            pretrained_weights = torch.load(weights_path, map_location="cpu")
            matched = self._match_vgg_keys(pretrained_weights, self.vgg.state_dict().keys())
            self.vgg.load_state_dict(matched, strict=False)
        # Collapse the (pretrained) RGB kernels into a grayscale kernel instead
        # of a fresh random conv that would then be frozen.
        self.vgg[0] = self._grayscale_conv(self.vgg[0]).to(device)
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.mse_loss = nn.MSELoss()

    @staticmethod
    def _match_vgg_keys(state_dict, own_keys):
        """Map full-VGG checkpoint keys onto the ``features`` slice's keys.

        ``self.vgg`` is the ``features`` Sequential itself, so its keys are
        ``"<idx>.weight"`` while torchvision checkpoints use
        ``"features.<idx>.weight"``. Without stripping that prefix,
        ``strict=False`` silently loads nothing. Keys outside the slice are
        dropped.
        """
        prefix = "features."
        own_keys = set(own_keys)
        renamed = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
        return {k: v for k, v in renamed.items() if k in own_keys}

    @staticmethod
    def _grayscale_conv(rgb_conv):
        """Build a 1-input-channel conv equivalent to ``rgb_conv`` on a replicated gray image."""
        gray = nn.Conv2d(1, rgb_conv.out_channels, kernel_size=rgb_conv.kernel_size,
                         stride=rgb_conv.stride, padding=rgb_conv.padding)
        with torch.no_grad():
            gray.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
            gray.bias.copy_(rgb_conv.bias)
        return gray

    def custom_weight_4d(self, target):
        """Piecewise weight map: 0.1 by default, a rising curve inside each
        [start, end) band, and 1.0 for values >= 70.

        Each segment is (start, end, steepness, base_w).
        """
        EPS = 1e-6
        segments = [
            [10, 20, 0.5, 0.1], [20, 30, 0.4, 0.12],
            [30, 40, 0.3, 0.14], [40, 50, 0.2, 0.16],
            [50, 60, 0.1, 0.18], [60, 70, 0.05, 0.2]
        ]
        weights = torch.full_like(target, 0.1)
        for start, end, steepness, base_w in segments:
            mask = (target >= start) & (target < end)
            x = ((target - start) / (end - start + EPS)).clamp(0, 1)
            delta_w = (1.0 - base_w) * (0.7 - steepness)
            seg_weights = base_w + delta_w * (x / (x + steepness * (1 - x) + EPS))
            weights = torch.where(mask, seg_weights, weights)
        weights[target >= 70] = 1.0
        return weights

    def forward(self, input, target):
        input_features = self.vgg(input)
        target_features = self.vgg(target)
        if self.apply_weighting:
            # NOTE(review): the thresholds are applied to VGG feature
            # activations, not to the raw target values — confirm intended.
            weights = self.custom_weight_4d(target_features.detach())
            loss = torch.mean(weights * (input_features - target_features) ** 2)
        else:
            loss = self.mse_loss(input_features, target_features)
        return loss
226
+ # vgg_loss = WeightedVGGLoss(weights_path="/mnt/wtx_weather_forecast/scx/stat/sat/sat2radar/vgg19-dcbb9e9d.pth",device=device).to(device)
227
+
228
+ import torch
229
+ import torch.nn as nn
230
+ import torch.nn.functional as F
231
class WeightedL1Loss(nn.Module):
    """Element-wise L1 loss scaled by a piecewise weight computed from the target.

    Each segment is ``(start, end, steepness, base_w)``. Target values in
    ``[start, end)`` get a weight rising from ``base_w`` along
    ``frac / (frac + steepness * (1 - frac))``, scaled by
    ``(1 - base_w) * (0.7 - steepness)``. Values >= 70 get 1.0, and all
    others get 0.1.

    Args:
        segment_ranges: list of segments; a falsy value selects the defaults.
        device: stored as ``self.device``; not otherwise used here.
    """

    def __init__(self, segment_ranges=None, device=None):
        super().__init__()
        self.device = device
        if not segment_ranges:
            segment_ranges = [
                (10, 20, 0.5, 0.1), (20, 30, 0.4, 0.12),
                (30, 40, 0.3, 0.14), (40, 50, 0.2, 0.16),
                (50, 60, 0.1, 0.18), (60, 70, 0.05, 0.2)
            ]
        self.segment_ranges = segment_ranges

    def compute_weights(self, target):
        """Return a weight tensor shaped like ``target`` (default weight 0.1)."""
        eps = 1e-6
        weight_map = torch.full_like(target, 0.1)
        for lo, hi, steep, floor_w in self.segment_ranges:
            in_band = (target >= lo) & (target < hi)
            frac = ((target - lo) / (hi - lo + eps)).clamp(0, 1)
            gain = (1.0 - floor_w) * (0.7 - steep)
            curve = frac / (frac + steep * (1 - frac) + eps)
            weight_map = torch.where(in_band, floor_w + gain * curve, weight_map)
        weight_map[target >= 70] = 1.0
        return weight_map

    def forward(self, output, target):
        """
        Unreduced (reduction='none') weighted L1:
        1. element-wise L1 loss
        2. weight map from the (detached) target
        3. return the weighted loss tensor with the input's shape
        """
        # Weights come from a detached copy so no gradient flows through them.
        weights = self.compute_weights(target.detach())
        return weights * F.l1_loss(output, target, reduction='none')
264
+ # L1loss = WeightedL1Loss(device= device)
265
+
266
+ import torch
267
+ import torch.nn as nn
268
+ import torch.nn.functional as F
269
+ from segment_anything import sam_model_registry
270
class SAMLoss1(nn.Module):
    """Perceptual loss: MSE between frozen SAM image-encoder features of input and target.

    Inputs are expanded to 3 channels if single-channel, normalized with
    ImageNet mean/std, and resized down to the nearest multiple of the
    encoder's patch size. The position embedding is interpolated to the
    resulting token grid. Encoder weights are frozen, but gradients flow back
    to ``input``.

    Args:
        model_type: key into ``segment_anything.sam_model_registry``.
        checkpoint_path: optional path to a SAM state dict.
    """

    def __init__(self, model_type='vit_b', checkpoint_path=None):
        super().__init__()
        self.sam = sam_model_registry[model_type](checkpoint=None)
        if checkpoint_path:
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            self.sam.load_state_dict(state_dict)
        self.image_encoder = self.sam.image_encoder.eval()
        for param in self.image_encoder.parameters():
            param.requires_grad = False
        self.patch_size = self.image_encoder.patch_embed.proj.kernel_size[0]
        # Keep the original pos_embed so every resize starts from it.
        self.original_pos_embed = self.image_encoder.pos_embed
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        self.loss = nn.MSELoss()

    def adjust_pos_embed(self, x):
        """
        Resize the encoder's pos_embed, laid out as (1, H_tokens, W_tokens, C),
        to match the token grid of ``x``.

        Args:
            x: input tensor (B, C, H, W)
        """
        _, _, H, W = x.shape
        h, w = H // self.patch_size, W // self.patch_size
        if (h, w) != tuple(self.original_pos_embed.shape[1:3]):
            pos_embed = self.original_pos_embed.permute(0, 3, 1, 2)  # (1, C, H_orig, W_orig)
            pos_embed = F.interpolate(
                pos_embed,
                size=(h, w),
                mode='bilinear',
                align_corners=False
            ).permute(0, 2, 3, 1)  # (1, h, w, C)
            self.image_encoder.pos_embed = nn.Parameter(pos_embed, requires_grad=False)
        else:
            self.image_encoder.pos_embed = self.original_pos_embed

    def preprocess(self, x):
        """Expand to 3 channels, normalize, and resize to a multiple of patch_size."""
        if x.shape[1] == 1:
            x = x.expand(-1, 3, -1, -1)  # single channel -> three channels
        x = (x - self.mean) / self.std
        H, W = x.shape[-2:]
        H_new = (H // self.patch_size) * self.patch_size
        W_new = (W // self.patch_size) * self.patch_size
        if H != H_new or W != W_new:
            x = F.interpolate(
                x,
                size=(H_new, W_new),
                mode='bilinear',
                align_corners=False
            )
        return x

    def forward(self, input, target):
        input = self.preprocess(input)
        target = self.preprocess(target)
        self.adjust_pos_embed(input)
        # Input features keep their autograd graph: wrapping them in no_grad
        # would make the loss constant w.r.t. the input.
        input_feat = self.image_encoder(input)
        with torch.no_grad():
            target_feat = self.image_encoder(target)
        return self.loss(input_feat, target_feat)
328
+
329
+ import torch
330
+ import torch.nn as nn
331
+ import torch.nn.functional as F
332
+ from segment_anything import sam_model_registry
333
class SAMLoss2(nn.Module):
    """Perceptual loss on frozen SAM image-encoder features, without input normalization.

    Inputs are expanded to 3 channels if single-channel (no mean/std
    normalization) and resized down to the nearest multiple of the encoder's
    patch size. The position embedding is interpolated to the resulting token
    grid. Encoder weights are frozen, but gradients flow back to ``input``.

    Args:
        model_type: key into ``segment_anything.sam_model_registry``.
        checkpoint_path: optional path to a SAM state dict.
    """

    def __init__(self, model_type='vit_b', checkpoint_path=None):
        super().__init__()
        self.sam = sam_model_registry[model_type](checkpoint=None)
        if checkpoint_path:
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            self.sam.load_state_dict(state_dict)
        self.image_encoder = self.sam.image_encoder.eval()
        for param in self.image_encoder.parameters():
            param.requires_grad = False
        self.patch_size = self.image_encoder.patch_embed.proj.kernel_size[0]
        # Keep the original pos_embed so every resize starts from it.
        self.original_pos_embed = self.image_encoder.pos_embed
        self.loss = nn.MSELoss()

    def adjust_pos_embed(self, x):
        """
        Resize the encoder's pos_embed, laid out as (1, H_tokens, W_tokens, C),
        to match the token grid of ``x``.

        Args:
            x: input tensor (B, C, H, W)
        """
        _, _, H, W = x.shape
        h, w = H // self.patch_size, W // self.patch_size
        if (h, w) != tuple(self.original_pos_embed.shape[1:3]):
            pos_embed = self.original_pos_embed.permute(0, 3, 1, 2)  # (1, C, H_orig, W_orig)
            pos_embed = F.interpolate(
                pos_embed,
                size=(h, w),
                mode='bilinear',
                align_corners=False
            ).permute(0, 2, 3, 1)  # (1, h, w, C)
            self.image_encoder.pos_embed = nn.Parameter(pos_embed, requires_grad=False)
        else:
            self.image_encoder.pos_embed = self.original_pos_embed

    def preprocess(self, x):
        """Channel expansion + resize to a multiple of patch_size (no normalization)."""
        if x.shape[1] == 1:
            x = x.expand(-1, 3, -1, -1)  # single channel -> three channels
        H, W = x.shape[-2:]
        H_new = (H // self.patch_size) * self.patch_size
        W_new = (W // self.patch_size) * self.patch_size
        if H != H_new or W != W_new:
            x = F.interpolate(
                x,
                size=(H_new, W_new),
                mode='bilinear',
                align_corners=False
            )
        return x

    def forward(self, input, target):
        # Preprocess (neither input nor target is normalized).
        input = self.preprocess(input)
        target = self.preprocess(target)
        self.adjust_pos_embed(input)
        # Input features keep their autograd graph: wrapping them in no_grad
        # would make the loss constant w.r.t. the input.
        input_feat = self.image_encoder(input)
        with torch.no_grad():
            target_feat = self.image_encoder(target)
        return self.loss(input_feat, target_feat)
391
+ #saml = SAMLoss(checkpoint_path = "/mnt/wtx_weather_forecast/scx/stat/sat/sat2radar/sam_vit_b_01ec64.pth.1").to(device)
@@ -0,0 +1,10 @@
1
+ import torch
2
def L2loss(model, loss, lambda_reg=0.01):
    """Add an L2 penalty (sum of squared parameter norms) to ``loss``.

    Args:
        model: module whose ``parameters()`` are penalised.
        loss: base loss (tensor or number).
        lambda_reg: penalty coefficient.

    Returns:
        ``loss + lambda_reg * sum(||p||^2 for p in model.parameters())``.
        A new object is returned; the caller's ``loss`` tensor is not
        modified in place.
    """
    # Sum of squared norms of all parameters; 0.0 when the model has none.
    l2_reg = sum((torch.norm(param) ** 2 for param in model.parameters()), 0.0)
    return loss + lambda_reg * l2_reg
9
+
10
+