shancx 1.8.92__py3-none-any.whl → 1.9.33.218__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. shancx/3D/__init__.py +25 -0
  2. shancx/Algo/Class.py +11 -0
  3. shancx/Algo/CudaPrefetcher1.py +112 -0
  4. shancx/Algo/Fake_image.py +24 -0
  5. shancx/Algo/Hsml.py +391 -0
  6. shancx/Algo/L2Loss.py +10 -0
  7. shancx/Algo/MetricTracker.py +132 -0
  8. shancx/Algo/Normalize.py +66 -0
  9. shancx/Algo/OptimizerWScheduler.py +38 -0
  10. shancx/Algo/Rmageresize.py +79 -0
  11. shancx/Algo/Savemodel.py +33 -0
  12. shancx/Algo/SmoothL1_losses.py +27 -0
  13. shancx/Algo/Tqdm.py +62 -0
  14. shancx/Algo/__init__.py +121 -0
  15. shancx/Algo/checknan.py +28 -0
  16. shancx/Algo/iouJU.py +83 -0
  17. shancx/Algo/mask.py +25 -0
  18. shancx/Algo/psnr.py +9 -0
  19. shancx/Algo/ssim.py +70 -0
  20. shancx/Algo/structural_similarity.py +308 -0
  21. shancx/Algo/tool.py +704 -0
  22. shancx/Calmetrics/__init__.py +97 -0
  23. shancx/Calmetrics/calmetrics.py +14 -0
  24. shancx/Calmetrics/calmetricsmatrixLib.py +147 -0
  25. shancx/Calmetrics/rmseR2score.py +35 -0
  26. shancx/Clip/__init__.py +50 -0
  27. shancx/Cmd.py +126 -0
  28. shancx/Config_.py +26 -0
  29. shancx/Df/DataFrame.py +11 -2
  30. shancx/Df/__init__.py +17 -0
  31. shancx/Df/tool.py +0 -0
  32. shancx/Diffm/Psamples.py +18 -0
  33. shancx/Diffm/__init__.py +0 -0
  34. shancx/Diffm/test.py +207 -0
  35. shancx/Doc/__init__.py +214 -0
  36. shancx/E/__init__.py +178 -152
  37. shancx/Fillmiss/__init__.py +0 -0
  38. shancx/Fillmiss/imgidwJU.py +46 -0
  39. shancx/Fillmiss/imgidwLatLonJU.py +82 -0
  40. shancx/Gpu/__init__.py +55 -0
  41. shancx/H9/__init__.py +126 -0
  42. shancx/H9/ahi_read_hsd.py +877 -0
  43. shancx/H9/ahisearchtable.py +298 -0
  44. shancx/H9/geometry.py +2439 -0
  45. shancx/Hug/__init__.py +81 -0
  46. shancx/Inst.py +22 -0
  47. shancx/Lib.py +31 -0
  48. shancx/Mos/__init__.py +37 -0
  49. shancx/NN/__init__.py +235 -106
  50. shancx/Path1.py +161 -0
  51. shancx/Plot/GlobMap.py +276 -116
  52. shancx/Plot/__init__.py +491 -1
  53. shancx/Plot/draw_day_CR_PNG.py +4 -21
  54. shancx/Plot/exam.py +116 -0
  55. shancx/Plot/plotGlobal.py +325 -0
  56. shancx/{radar_nmc.py → Plot/radarNmc.py} +4 -34
  57. shancx/{subplots_single_china_map.py → Plot/single_china_map.py} +1 -1
  58. shancx/Point.py +46 -0
  59. shancx/QC.py +223 -0
  60. shancx/RdPzl/__init__.py +32 -0
  61. shancx/Read.py +72 -0
  62. shancx/Resize.py +79 -0
  63. shancx/SN/__init__.py +62 -123
  64. shancx/Time/GetTime.py +9 -3
  65. shancx/Time/__init__.py +66 -1
  66. shancx/Time/timeCycle.py +302 -0
  67. shancx/Time/tool.py +0 -0
  68. shancx/Train/__init__.py +74 -0
  69. shancx/Train/makelist.py +187 -0
  70. shancx/Train/multiGpu.py +27 -0
  71. shancx/Train/prepare.py +161 -0
  72. shancx/Train/renet50.py +157 -0
  73. shancx/ZR.py +12 -0
  74. shancx/__init__.py +333 -262
  75. shancx/args.py +27 -0
  76. shancx/bak.py +768 -0
  77. shancx/df2database.py +62 -2
  78. shancx/geosProj.py +80 -0
  79. shancx/info.py +38 -0
  80. shancx/netdfJU.py +231 -0
  81. shancx/sendM.py +59 -0
  82. shancx/tensBoard/__init__.py +28 -0
  83. shancx/wait.py +246 -0
  84. {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/METADATA +15 -5
  85. shancx-1.9.33.218.dist-info/RECORD +91 -0
  86. {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/WHEEL +1 -1
  87. my_timer_decorator/__init__.py +0 -10
  88. shancx/Dsalgor/__init__.py +0 -19
  89. shancx/E/DFGRRIB.py +0 -30
  90. shancx/EN/DFGRRIB.py +0 -30
  91. shancx/EN/__init__.py +0 -148
  92. shancx/FileRead.py +0 -44
  93. shancx/Gray2RGB.py +0 -86
  94. shancx/M/__init__.py +0 -137
  95. shancx/MN/__init__.py +0 -133
  96. shancx/N/__init__.py +0 -131
  97. shancx/Plot/draw_day_CR_PNGUS.py +0 -206
  98. shancx/Plot/draw_day_CR_SVG.py +0 -275
  99. shancx/Plot/draw_day_pre_PNGUS.py +0 -205
  100. shancx/Plot/glob_nation_map.py +0 -116
  101. shancx/Plot/radar_nmc.py +0 -61
  102. shancx/Plot/radar_nmc_china_map_compare1.py +0 -50
  103. shancx/Plot/radar_nmc_china_map_f.py +0 -121
  104. shancx/Plot/radar_nmc_us_map_f.py +0 -128
  105. shancx/Plot/subplots_compare_devlop.py +0 -36
  106. shancx/Plot/subplots_single_china_map.py +0 -45
  107. shancx/S/__init__.py +0 -138
  108. shancx/W/__init__.py +0 -132
  109. shancx/WN/__init__.py +0 -132
  110. shancx/code.py +0 -331
  111. shancx/draw_day_CR_PNG.py +0 -200
  112. shancx/draw_day_CR_PNGUS.py +0 -206
  113. shancx/draw_day_CR_SVG.py +0 -275
  114. shancx/draw_day_pre_PNGUS.py +0 -205
  115. shancx/makenetCDFN.py +0 -42
  116. shancx/mkIMGSCX.py +0 -92
  117. shancx/netCDF.py +0 -130
  118. shancx/radar_nmc_china_map_compare1.py +0 -50
  119. shancx/radar_nmc_china_map_f.py +0 -125
  120. shancx/radar_nmc_us_map_f.py +0 -67
  121. shancx/subplots_compare_devlop.py +0 -36
  122. shancx/tool.py +0 -18
  123. shancx/user/H8mess.py +0 -317
  124. shancx/user/__init__.py +0 -137
  125. shancx/user/cinradHJN.py +0 -496
  126. shancx/user/examMeso.py +0 -293
  127. shancx/user/hjnDAAS.py +0 -26
  128. shancx/user/hjnFTP.py +0 -81
  129. shancx/user/hjnGIS.py +0 -320
  130. shancx/user/hjnGPU.py +0 -21
  131. shancx/user/hjnIDW.py +0 -68
  132. shancx/user/hjnKDTree.py +0 -75
  133. shancx/user/hjnLAPSTransform.py +0 -47
  134. shancx/user/hjnMiscellaneous.py +0 -182
  135. shancx/user/hjnProj.py +0 -162
  136. shancx/user/inotify.py +0 -41
  137. shancx/user/matplotlibMess.py +0 -87
  138. shancx/user/mkNCHJN.py +0 -623
  139. shancx/user/newTypeRadar.py +0 -492
  140. shancx/user/test.py +0 -6
  141. shancx/user/tlogP.py +0 -129
  142. shancx/util_log.py +0 -33
  143. shancx/wtx/H8mess.py +0 -315
  144. shancx/wtx/__init__.py +0 -151
  145. shancx/wtx/cinradHJN.py +0 -496
  146. shancx/wtx/colormap.py +0 -64
  147. shancx/wtx/examMeso.py +0 -298
  148. shancx/wtx/hjnDAAS.py +0 -26
  149. shancx/wtx/hjnFTP.py +0 -81
  150. shancx/wtx/hjnGIS.py +0 -330
  151. shancx/wtx/hjnGPU.py +0 -21
  152. shancx/wtx/hjnIDW.py +0 -68
  153. shancx/wtx/hjnKDTree.py +0 -75
  154. shancx/wtx/hjnLAPSTransform.py +0 -47
  155. shancx/wtx/hjnLog.py +0 -78
  156. shancx/wtx/hjnMiscellaneous.py +0 -201
  157. shancx/wtx/hjnProj.py +0 -161
  158. shancx/wtx/inotify.py +0 -41
  159. shancx/wtx/matplotlibMess.py +0 -87
  160. shancx/wtx/mkNCHJN.py +0 -613
  161. shancx/wtx/newTypeRadar.py +0 -492
  162. shancx/wtx/test.py +0 -6
  163. shancx/wtx/tlogP.py +0 -129
  164. shancx-1.8.92.dist-info/RECORD +0 -99
  165. /shancx/{Dsalgor → Algo}/dsalgor.py +0 -0
  166. {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,132 @@
1
+ # https://github.com/pytorch/examples/blob/master/imagenet/main.py
2
class MetricTracker(object):
    """Keep the latest value and the running weighted mean of a metric.

    Example::

        train_acc = metrics.MetricTracker()
        train_acc.update(metrics.acc(outputs, labels, threshold=0.1), outputs.size(0))
        train_acc.avg
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the running mean, the sum and the sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest value, weighted by ``n`` samples.

        :param val: newest metric value (e.g. a batch mean).
        :param n: number of samples ``val`` represents.
        """
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
21
+
22
+
23
+ from tqdm import tqdm
24
+
25
class TrainingManager:
    """Wrap a data loader in a tqdm progress bar and track/log running metrics.

    Register metrics with ``add_metric``, feed values with ``update_metrics``
    and call ``log_progress`` to refresh the bar description and, if a log
    file was given, append one CSV row of running averages.
    """

    def __init__(self, loader, description="Training Progress", log_file=None):
        """
        :param loader: iterable to wrap; iterate over ``self.loader``.
        :param description: progress-bar description.
        :param log_file: optional CSV log path (truncated here).
        """
        self.loader = tqdm(loader, desc=description)
        self.metrics = {}
        self.log_file = log_file
        # No metrics are registered yet, so writing the CSV header here would
        # produce just "Epoch,Step,". Defer it to the first log_progress()
        # call, but still truncate any previous log now, as before.
        self._header_written = False
        if log_file:
            with open(log_file, 'w'):
                pass

    def add_metric(self, name, initial_value=0.0):
        """Register metric ``name``; ``initial_value`` is reported until its first update."""
        self.metrics[name] = Metric(initial_value)

    def update_metrics(self, **kwargs):
        """Feed new values as ``name=value`` keyword arguments; unknown names are ignored."""
        for name, value in kwargs.items():
            if name in self.metrics:
                self.metrics[name].update(value)

    def log_progress(self, epoch, step):
        """Show running averages on the progress bar and append them to the CSV log.

        The CSV header (``Epoch,Step,<metric names>``) is written on the first
        call, using the metrics registered at that moment. Register every
        metric before the first call so the columns line up.

        :param epoch: current epoch, written to the description and the log.
        :param step: current step within the epoch, written to the log.
        """
        description = f"Epoch {epoch}: " + " ".join(
            f"{name}: {metric.avg:.4f}" for name, metric in self.metrics.items()
        )
        self.loader.set_description(description)

        if self.log_file:
            with open(self.log_file, 'a') as f:
                if not self._header_written:
                    f.write("Epoch,Step," + ",".join(self.metrics) + "\n")
                    self._header_written = True
                values = ",".join(f"{metric.avg:.4f}" for metric in self.metrics.values())
                f.write(f"{epoch},{step},{values}\n")

    def close(self):
        """Close the progress bar."""
        self.loader.close()
77
+
78
class Metric:
    """Running mean of a scalar metric."""

    def __init__(self, initial_value=0.0):
        """
        :param initial_value: value reported by ``avg`` before the first update.
        """
        # The accumulator starts at zero. ``count`` is 0 here, so folding
        # ``initial_value`` into ``total`` would silently bias every later average.
        self.total = 0.0
        self.count = 0
        self.avg = initial_value

    def update(self, value, count=1):
        """Add ``value`` weighted by ``count`` samples and refresh ``avg``.

        :param value: new value.
        :param count: number of samples ``value`` represents (default 1).
        """
        self.total += value * count
        self.count += count
        self.avg = self.total / self.count
97
+
98
+ # 示例代码
99
if __name__ == "__main__":
    # Demo: simulate a 100-item loader and log five synthetic metrics over 3 epochs.
    train_dataloader = [{"sat_img": None, "map_img": None} for _ in range(100)]

    manager = TrainingManager(train_dataloader, description="Training Progress", log_file="training_log.csv")
    for metric_name in ("Loss", "Accuracy", "Ts", "Fsc", "Far"):
        manager.add_metric(metric_name)

    for epoch in range(3):  # three simulated epochs
        for idx, _batch in enumerate(manager.loader):
            # Synthetic values that drift linearly with the step index.
            manager.update_metrics(
                Loss=0.5 + idx * 0.01,
                Accuracy=0.8 - idx * 0.001,
                Ts=0.7 + idx * 0.001,
                Fsc=0.6 + idx * 0.002,
                Far=0.2 + idx * 0.001,
            )
            manager.log_progress(epoch, idx)

    manager.close()
@@ -0,0 +1,66 @@
1
+
2
+ from torchvision.transforms import Resize# Function to calculate valid size based on scale factor
3
def get_valid_size(size, scale_factor=2):
    """Round ``size`` down to the nearest multiple of ``scale_factor``."""
    remainder = size % scale_factor
    return size - remainder
5
+ # Function to resize a high-resolution image to a low-resolution image
6
def resize_to_low_res(high_res_image, scale_factor):
    """Downscale a 3-D image tensor by an integer ``scale_factor``.

    The tensor is unpacked as (C, rows, cols). Both spatial sizes are first
    rounded down to a multiple of ``scale_factor`` and cropped, then
    torchvision's ``Resize`` produces (rows // scale_factor, cols // scale_factor).

    :param high_res_image: tensor with exactly three dimensions, presumably (C, H, W).
    :param scale_factor: integer downscaling factor.
    :return: the resized tensor.
    """
    _, rows, cols = high_res_image.size()
    # Crop to the largest size divisible by scale_factor so the downscale is exact.
    rows -= rows % scale_factor
    cols -= cols % scale_factor
    cropped = high_res_image[:, :rows, :cols]
    target_size = (rows // scale_factor, cols // scale_factor)
    return Resize(target_size)(cropped)
"""

cuda-version 11.8 hcce14f8_3
cudatoolkit 11.8.0 h6a678d5_0
cudnn 8.9.2.26 cuda11_0
nvidia-cuda-cupti-cu12 12.1.105 pypi_0 pypi
nvidia-cuda-nvrtc-cu12 12.1.105 pypi_0 pypi
nvidia-cuda-runtime-cu12 12.1.105 pypi_0 pypi
nvidia-cudnn-cu12 8.9.2.26 pypi_0 pypi
mqpf conda install pytorch torchvision torchaudio cudatoolkit=11.8 -c pytorch
conda install cudnn=8.9.2.26 cudatoolkit=11.8
resunet pip install torch==2.4.0 torchvision torchaudio
conda install cudnn==8.9.2.26 cudatoolkit==11.8.0
conda install pytorch=2.2.2 torchvision torchaudio cudatoolkit=11.8 -c pytorch
resunet pip install torch==2.4.0 torchvision torchaudio
pip install protobuf==3.20

my-envmf1
torch 2.3.0 pypi_0 pypi
torchvision 0.18.0 pypi_0 pypi

RES:
torch 2.4.0 pypi_0 pypi
torchaudio 2.2.2 py311_cpu pytorch
torchsummary 1.5.1 pypi_0 pypi
torchvision 0.19.0 pypi_0 pypi

mqpf:
torch 2.3.1 pypi_0 pypi
torchaudio 2.3.1 pypi_0 pypi
torchvision 0.18.1 pypi_0 pypi
onnxruntime-gpu 1.16.0
onnx 1.15.0
numpy 1.26.4

vllm:
torch 2.1.2 pypi_0 pypi
torchvision 0.15.1+cu118 pypi_0 pypi
vllm 0.2.7 pypi_0 pypi

import torch
print("CUDA available:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)
print("GPU device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")
nvidia-smi
nvcc --version
系统已经检测到物理 GPU(NVIDIA GeForce RTX 4090)和 NVIDIA 驱动,同时安装了 CUDA 12.1。然而,PyTorch 没有正确检测到 GPU,可能是因为 PyTorch 版本与 CUDA 驱动不兼容,或者环境变量未正确配置。

pip install torch==2.3.1 torchvision==0.18.1

"""
@@ -0,0 +1,38 @@
1
+ import torch
2
+
3
class OptimizerWithScheduler:
    """Bundle an Adam optimizer with a StepLR learning-rate scheduler.

    Example (PyTorch expects ``optimizer.step()`` before ``scheduler.step()``,
    so advance the scheduler once per epoch after the batch loop)::

        model = torch.nn.Linear(10, 1)
        opt = OptimizerWithScheduler(model, lr=0.001, step_size=40, gamma=0.1)
        for epoch in range(100):
            for idx, data in enumerate(loader):
                inputs = data["sat_img"].cuda()
                labels = data["map_img"].cuda()
                opt.zero_grad()
                loss = ...  # compute the loss from inputs / labels
                loss.backward()
                opt.step()
            opt.step_scheduler()
            print(f"Epoch {epoch + 1}, Learning Rate: {opt.get_lr()}")
    """

    def __init__(self, model, lr=0.001, step_size=40, gamma=0.1):
        """
        :param model: module whose ``parameters()`` are optimized.
        :param lr: initial learning rate.
        :param step_size: number of scheduler steps between LR decays.
        :param gamma: multiplicative LR decay factor.
        """
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.optimizer = optimizer
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    def zero_grad(self):
        """Clear the gradients of all optimized parameters."""
        self.optimizer.zero_grad()

    def step(self):
        """Apply one optimizer update."""
        self.optimizer.step()

    def step_scheduler(self):
        """Advance the StepLR schedule by one step."""
        self.lr_scheduler.step()

    def get_lr(self):
        """Return the learning rate of the first parameter group."""
        first_group = self.optimizer.param_groups[0]
        return first_group['lr']
@@ -0,0 +1,79 @@
1
+ import torch
2
+ from torchvision.transforms import Resize
3
+ from torchvision.transforms.functional import InterpolationMode
4
+ import numpy as np
5
+
6
def image_resize(image, scale_factor, antialiasing=True):
    """Resize an image by ``scale_factor`` using torchvision's ``Resize``.

    Args:
        image (np.ndarray or torch.Tensor): NumPy input may be (H, W) or
            (H, W, C); tensor input may be (H, W) or (C, H, W).
        scale_factor (float): multiplier for both spatial sizes; the output
            sizes are truncated with ``int()``.
        antialiasing (bool): selects bicubic interpolation when True and
            nearest-neighbour when False (default: True).

    Returns:
        The resized image in the same container type and axis layout as the
        input. NumPy input is returned as float32, with NaN replaced by 0 first.

    Raises:
        ValueError: if the input is neither 2-D nor 3-D.

    Example::

        resized_image = image_resize(image, scale_factor, antialiasing)
    """
    is_numpy = isinstance(image, np.ndarray)
    ndim = image.ndim
    if ndim not in (2, 3):
        raise ValueError(f"Expected a 2-D or 3-D image, got {ndim} dimensions")

    if is_numpy:
        # np.nan_to_num maps NaN -> 0 (and +/-inf -> the largest finite float).
        image = np.nan_to_num(image)
        if ndim == 3:
            image = torch.from_numpy(image.transpose(2, 0, 1)).float()  # (H, W, C) -> (C, H, W)
        else:
            image = torch.from_numpy(image[None, :, :]).float()  # (H, W) -> (1, H, W)
    elif ndim == 2:
        image = image.unsqueeze(0)  # (H, W) -> (1, H, W)

    _, in_h, in_w = image.shape
    out_h, out_w = int(in_h * scale_factor), int(in_w * scale_factor)

    mode = InterpolationMode.BICUBIC if antialiasing else InterpolationMode.NEAREST
    resized = Resize((out_h, out_w), interpolation=mode)(image)

    if ndim == 2:
        # Drop the channel axis added above so a 2-D input yields a 2-D output
        # (a trailing size-1 axis would also break e.g. plt.imshow).
        resized = resized[0]
        return resized.numpy() if is_numpy else resized
    if is_numpy:
        return resized.numpy().transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
    return resized
44
+
45
+ # Example for loading image data and processing NaNs
46
def process_image(file_path, scale_factor, antialiasing=True):
    """Load a ``.npy`` image, zero its NaNs, and resize it with ``image_resize``.

    Args:
        file_path (str): path to a NumPy ``.npy`` file.
        scale_factor (float): resize factor passed to ``image_resize``.
        antialiasing (bool): forwarded to ``image_resize`` (default: True).

    Returns:
        Whatever ``image_resize`` returns for the loaded NumPy array.
    """
    data = np.load(file_path)
    data[np.isnan(data)] = 0  # NaN -> 0 before resizing
    return image_resize(data, scale_factor, antialiasing)
65
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    # Demo: downscale a NumPy image stored on disk by 50% and display it.
    demo_path = 'example.npy'  # replace with the path to your .npy file
    demo_scale = 0.5
    result = process_image(demo_path, demo_scale)

    plt.title("Resized Image")
    plt.imshow(result.astype(np.uint8))
    plt.axis("off")
    plt.show()
@@ -0,0 +1,33 @@
1
+ import os
2
+ import torch
3
def save_model_checkpoint(step, epoch, model, optimizer, best_loss, best_acc, checkpoint_dir, name):
    """Save model and optimizer state to ``<checkpoint_dir>/<name>_checkpoint_<step>.pt``.

    ``step`` is zero-padded to at least four digits in the file name. The
    directory is created if needed; an empty ``checkpoint_dir`` means the
    current working directory.

    :return: path of the written checkpoint file.

    Example::

        if valid_loss < best_loss:
            best_loss = valid_loss
            wait = 0  # reset the early-stopping counter
            save_path = save_model_checkpoint(
                step=step, epoch=epoch, model=model, optimizer=optimizer,
                best_loss=best_loss, best_acc=best_acc,
                checkpoint_dir=checkpoint_dir, name=name,
            )
            print(f"Checkpoint saved at: {save_path}")
    """
    if checkpoint_dir:
        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
        # (e.g. several processes saving into the same new directory).
        os.makedirs(checkpoint_dir, exist_ok=True)
    save_path = os.path.join(checkpoint_dir, f"{name}_checkpoint_{step:04d}.pt")
    torch.save({
        "step": step,
        "epoch": epoch,
        "arch": model.__class__.__name__,
        "state_dict": model.state_dict(),
        "best_acc": best_acc,
        "best_loss": best_loss,
        "optimizer": optimizer.state_dict(),
    }, save_path)
    return save_path
@@ -0,0 +1,27 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
def smoothL1_loss(x, y):
    """Mean Smooth L1 loss between ``x`` and ``y`` (``nn.SmoothL1Loss`` defaults)."""
    return nn.SmoothL1Loss()(x, y)


def compute_smoothL1_losses(x, y, y_label, label_values=range(1, 5)):
    """Compute the Smooth L1 loss separately for each label class.

    Args:
        x (torch.Tensor): predictions.
        y (torch.Tensor): targets, broadcast-compatible with ``x``.
        y_label (torch.Tensor): labels used to select the elements of ``x`` and
            ``y`` that belong to each class.
        label_values (iterable): label values to evaluate, in order
            (default 1, 2, 3, 4).

    Returns:
        list[torch.Tensor]: one scalar loss per label value. When a label has
        no elements, the loss is a zero scalar on ``x``'s device and dtype.
    """
    losses = []
    for label_value in label_values:
        mask = y_label == label_value
        x_masked = torch.masked_select(x, mask)
        y_masked = torch.masked_select(y, mask)
        if x_masked.numel() > 0:
            losses.append(smoothL1_loss(x_masked, y_masked))
        else:
            # torch.tensor(0.0) would always be a CPU float32 tensor, which
            # breaks torch.stack / arithmetic with CUDA or float64 losses.
            losses.append(x.new_zeros(()))
    return losses
shancx/Algo/Tqdm.py ADDED
@@ -0,0 +1,62 @@
1
+ from tqdm import tqdm
2
class TrainingManager:
    """Wrap a data loader in a tqdm progress bar and track/log running metrics.

    Register metrics with ``add_metric``, feed values with ``update_metrics``
    and call ``log_progress`` to refresh the bar description and, if a log
    file was given, append one CSV row of running averages.
    """

    def __init__(self, loader, description="Training Progress", log_file=None):
        """
        :param loader: iterable to wrap; iterate over ``self.loader``.
        :param description: progress-bar description.
        :param log_file: optional CSV log path (truncated here).
        """
        self.loader = tqdm(loader, desc=description)
        self.metrics = {}
        self.log_file = log_file
        # No metrics are registered yet, so writing the CSV header here would
        # produce just "Epoch,Step,". Defer it to the first log_progress()
        # call, but still truncate any previous log now, as before.
        self._header_written = False
        if log_file:
            with open(log_file, 'w'):
                pass

    def add_metric(self, name, initial_value=0.0):
        """Register metric ``name``; a ``Metric`` object computes its running mean."""
        self.metrics[name] = Metric(initial_value)

    def update_metrics(self, **kwargs):
        """Feed new values as ``name=value`` keyword arguments; unknown names are ignored."""
        for name, value in kwargs.items():
            if name in self.metrics:
                self.metrics[name].update(value)

    def log_progress(self, epoch, step):
        """Show running averages on the progress bar and append them to the CSV log.

        The CSV header (``Epoch,Step,<metric names>``) is written on the first
        call, using the metrics registered at that moment. Register every
        metric before the first call so the columns line up.
        """
        description = f"Epoch {epoch}: " + " ".join(
            f"{name}: {metric.avg:.4f}" for name, metric in self.metrics.items()
        )
        self.loader.set_description(description)

        if self.log_file:
            with open(self.log_file, 'a') as f:
                if not self._header_written:
                    f.write("Epoch,Step," + ",".join(self.metrics) + "\n")
                    self._header_written = True
                values = ",".join(f"{metric.avg:.4f}" for metric in self.metrics.values())
                f.write(f"{epoch},{step},{values}\n")

    def log_progress1(self, description):
        """Set an arbitrary progress-bar description (no CSV logging)."""
        self.loader.set_description(description)

    def close(self):
        """Close the progress bar."""
        self.loader.close()
30
class Metric:
    """Running mean of a scalar metric."""

    def __init__(self, initial_value=0.0):
        """
        :param initial_value: value reported by ``avg`` before the first update.
        """
        # The accumulator starts at zero. ``count`` is 0 here, so folding
        # ``initial_value`` into ``total`` would silently bias every later average.
        self.total = 0.0
        self.count = 0
        self.avg = initial_value

    def update(self, value, count=1):
        """Add ``value`` weighted by ``count`` samples and refresh ``avg``."""
        self.total += value * count
        self.count += count
        self.avg = self.total / self.count
39
if __name__ == "__main__":
    # Demo: simulate a 100-item loader and log five synthetic metrics over 3 epochs.
    train_dataloader = [{"sat_img": None, "map_img": None} for _ in range(100)]

    manager = TrainingManager(train_dataloader, description="Training Progress", log_file="training_log.csv")
    for metric_name in ("Loss", "Accuracy", "Ts", "Fsc", "Far"):
        manager.add_metric(metric_name)

    for epoch in range(3):  # three simulated epochs
        for idx, _batch in enumerate(manager.loader):
            # Synthetic values that drift linearly with the step index.
            manager.update_metrics(
                Loss=0.5 + idx * 0.01,
                Accuracy=0.8 - idx * 0.001,
                Ts=0.7 + idx * 0.001,
                Fsc=0.6 + idx * 0.002,
                Far=0.2 + idx * 0.001,
            )
            manager.log_progress(epoch, idx)

    manager.close()
@@ -0,0 +1,121 @@
1
+ #!/usr/bin/python
2
+ # -*- coding: utf-8 -*-
3
+ # @Time : 2024/10/17 上午10:40
4
+ # @Author : shancx
5
+ # @File : __init__.py
6
+ # @email : shanhe12@163.com
7
+
8
def quick_sort(arr):
    """Return the items of ``arr`` in ascending order using a three-way quicksort.

    Inputs of length 0 or 1 are returned unchanged (the same object);
    otherwise a new list is built.
    """
    if len(arr) > 1:
        pivot = arr[len(arr) // 2]
        lower, same, upper = (
            [item for item in arr if item < pivot],
            [item for item in arr if item == pivot],
            [item for item in arr if item > pivot],
        )
        return quick_sort(lower) + same + quick_sort(upper)
    return arr
16
def sort_dict_by_key(d):
    """Return a new dict holding ``d``'s items, inserted in ascending key order."""
    ordered = {}
    for key in sorted(d.keys()):
        ordered[key] = d[key]
    return ordered
19
+ from tqdm import tqdm
20
class TrainingManager:
    """Wrap a data loader in a tqdm progress bar and track/log running metrics.

    Register metrics with ``add_metric``, feed values with ``update_metrics``
    and call ``log_progress`` to refresh the bar description and, if a log
    file was given, append one CSV row of running averages.
    """

    def __init__(self, loader, desc="Training Progress", log_file=None, description=None):
        """
        :param loader: iterable to wrap; iterate over ``self.loader``.
        :param desc: progress-bar description.
        :param log_file: optional CSV log path (truncated here).
        :param description: alias for ``desc``; when given, it takes precedence.
        """
        # Accept ``description=`` too, so callers written against the
        # ``TrainingManager(loader, description=...)`` signature keep working.
        if description is not None:
            desc = description
        self.loader = tqdm(loader, desc=desc)
        self.metrics = {}
        self.log_file = log_file
        # No metrics are registered yet, so writing the CSV header here would
        # produce just "Epoch,Step,". Defer it to the first log_progress()
        # call, but still truncate any previous log now, as before.
        self._header_written = False
        if log_file:
            with open(log_file, 'w'):
                pass

    def add_metric(self, name, initial_value=0.0):
        """Register metric ``name``; a ``Metric`` object computes its running mean."""
        self.metrics[name] = Metric(initial_value)

    def update_metrics(self, **kwargs):
        """Feed new values as ``name=value`` keyword arguments; unknown names are ignored."""
        for name, value in kwargs.items():
            if name in self.metrics:
                self.metrics[name].update(value)

    def log_progress(self, epoch, step):
        """Show running averages on the progress bar and append them to the CSV log.

        The CSV header (``Epoch,Step,<metric names>``) is written on the first
        call, using the metrics registered at that moment. Register every
        metric before the first call so the columns line up.
        """
        description = f"Epoch {epoch}: " + " ".join(
            f"{name}: {metric.avg:.4f}" for name, metric in self.metrics.items()
        )
        self.loader.set_description(description)

        if self.log_file:
            with open(self.log_file, 'a') as f:
                if not self._header_written:
                    f.write("Epoch,Step," + ",".join(self.metrics) + "\n")
                    self._header_written = True
                values = ",".join(f"{metric.avg:.4f}" for metric in self.metrics.values())
                f.write(f"{epoch},{step},{values}\n")

    # A free-form ``log_progress1(self, description)`` helper, which only calls
    # ``self.loader.set_description(description)``, is intentionally disabled
    # in this variant.

    def close(self):
        """Close the progress bar."""
        self.loader.close()
53
class Metric:
    """Running mean of a scalar metric."""

    def __init__(self, initial_value=0.0):
        """
        :param initial_value: value reported by ``avg`` before the first update.
        """
        # The accumulator starts at zero. ``count`` is 0 here, so folding
        # ``initial_value`` into ``total`` would silently bias every later average.
        self.total = 0.0
        self.count = 0
        self.avg = initial_value

    def update(self, value, count=1):
        """Add ``value`` weighted by ``count`` samples and refresh ``avg``."""
        self.total += value * count
        self.count += count
        self.avg = self.total / self.count
62
if __name__ == "__main__":
    # Demo: simulate a 100-item loader and log five synthetic metrics over 3 epochs.
    train_dataloader = [{"sat_img": None, "map_img": None} for _ in range(100)]

    # This TrainingManager takes the progress-bar text as ``desc``; passing
    # ``description=`` to the original signature raised a TypeError.
    manager = TrainingManager(train_dataloader, desc="Training Progress", log_file="training_log.csv")
    for metric_name in ("Loss", "Accuracy", "Ts", "Fsc", "Far"):
        manager.add_metric(metric_name)

    for epoch in range(3):  # three simulated epochs
        for idx, _batch in enumerate(manager.loader):
            # Synthetic values that drift linearly with the step index.
            manager.update_metrics(
                Loss=0.5 + idx * 0.01,
                Accuracy=0.8 - idx * 0.001,
                Ts=0.7 + idx * 0.001,
                Fsc=0.6 + idx * 0.002,
                Far=0.2 + idx * 0.001,
            )
            manager.log_progress(epoch, idx)

    manager.close()
86
+
87
+ import psutil
88
+ import os
89
def get_memory():
    """Return the resident set size (RSS) of the current process in MiB.

    Usage example (``logger`` and ``sUTC`` come from the caller)::

        initial_memory = get_memory()
        logger.info(f"内存使用前: {initial_memory:.2f} MB {sUTC}")       # before use
        load_memory = get_memory()
        logger.info(f"创建对象后: {load_memory:.2f} MB {sUTC}")          # after creating objects
        final_memory = get_memory()
        logger.info(f"删除对象后: {final_memory:.2f} MB {sUTC}")         # after deleting objects
        logger.info(f"内存变化: 初始{initial_memory:.2f} -> 峰值{load_memory:.2f} -> 最终{final_memory:.2f}")
    """
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / 1024 / 1024
102
+
103
+ import psutil
104
+ import sys
105
def check_memory_threshold(threshold=90):
    """Return True when system-wide memory usage is at or above ``threshold`` percent."""
    usage_percent = psutil.virtual_memory().percent
    return usage_percent >= threshold
108
# Import-time memory check: prints a warning when system memory usage is at or
# above 90%.
# NOTE(review): this runs every time the module is imported, and the
# RuntimeError is raised only to be caught immediately below, so the net effect
# is a printed message; the task is NOT actually terminated.
try:
    if check_memory_threshold(90):
        raise RuntimeError("The system memory usage is too high. Terminating the task. error")
except RuntimeError as e:
    print(f"Memory warning error: {e}")
except MemoryError:
    print("Insufficient memory, task failed. error")
# A catch-all ``except Exception`` handler printing an "unknown error" message
# is deliberately left disabled.
#
# To abort instead of only warning, use the check without try/except:
#
#     if check_memory_threshold(90):
#         raise RuntimeError("The system memory usage is too high. Terminating the task. error")
@@ -0,0 +1,28 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ def checkData(data, name="data"):
5
+ if isinstance(data, torch.Tensor):
6
+ # 检查 Tensor 数据
7
+ has_nan = torch.isnan(data).any().item()
8
+ has_inf = torch.isinf(data).any().item()
9
+ print(f"torch.isnan(data).any() {torch.isnan(data).any()} torch.isinf(data).any() {torch.isinf(data).any()} ")
10
+ elif isinstance(data, np.ndarray):
11
+ # 检查 NumPy 数据
12
+ has_nan = np.isnan(data).any()
13
+ has_inf = np.isinf(data).any()
14
+ print(f"np.isnan(data).any() {np.isnan(data).any()} np.isinf(data).any() {np.isinf(data).any()} ")
15
+ else:
16
+ raise TypeError(f"Unsupported data type: {type(data)}. Expected torch.Tensor or numpy.ndarray.")
17
+
18
+ # 如果有 NaN 或 Inf,打印日志并返回 False
19
+ if has_nan or has_inf:
20
+ print(f"{name} contains NaN or Inf values!"
21
+ f"has_nan: {has_nan}, has_inf: {has_inf}")
22
+ return False
23
+ return True
24
+ """ 训练循环跳过
25
+ if not check_data(data, "data") or not check_data(label, "label"):
26
+ logger.info("# 跳过异常数据")
27
+ continue # 跳过异常数据
28
+ """
shancx/Algo/iouJU.py ADDED
@@ -0,0 +1,83 @@
1
+
2
def iou(inputs, targets, thresholds=(0, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 80)):
    """Mean IoU of two NumPy arrays over several thresholds.

    For each threshold ``t``, values ``>= t`` become 1 and values ``< t``
    become 0. IoU is then intersection / union with a 1e-16 smoothing term, so
    two empty event sets score 1.0. The per-threshold IoUs are averaged.

    :param inputs: predicted field (NumPy array, any shape; flattened).
    :param targets: reference field with the same number of elements.
    :param thresholds: sequence of thresholds (a list also works).
    :return: mean IoU across ``thresholds``.
    """
    # Flatten once; only the per-threshold binarized copies are needed inside the loop.
    inputs_flat = inputs.flatten()
    targets_flat = targets.flatten()
    ts_sum = 0
    for threshold in thresholds:
        inputs_bin = inputs_flat.copy()
        targets_bin = targets_flat.copy()
        # NaN satisfies neither comparison, so it is left as-is (and propagates), as before.
        inputs_bin[inputs_flat < threshold] = 0
        inputs_bin[inputs_flat >= threshold] = 1
        targets_bin[targets_flat < threshold] = 0
        targets_bin[targets_flat >= threshold] = 1
        intersection = (inputs_bin * targets_bin).sum()
        union = (inputs_bin + targets_bin).sum() - intersection
        ts_sum += (intersection + 1e-16) / (union + 1e-16)
    return ts_sum / len(thresholds)
17
+ import torch
18
+
19
def iouPlus(inputs, targets, thresholds=(0, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 80)):
    """Mean IoU of two torch tensors over several thresholds.

    For each threshold ``t``, both tensors are binarized as ``value >= t``.
    IoU is intersection / union with a 1e-16 smoothing term, so two empty
    event sets score 1.0. The per-threshold IoUs are averaged.

    :param inputs: predicted tensor (any shape; flattened).
    :param targets: reference tensor with the same number of elements.
    :param thresholds: sequence of thresholds (a list also works).
    :return: mean IoU across ``thresholds`` (a 0-d tensor).
    """
    inputs_flat = inputs.flatten()  # flatten once, outside the loop
    targets_flat = targets.flatten()
    ts_sum = 0

    for threshold in thresholds:
        inputs_bin = (inputs_flat >= threshold).float()
        targets_bin = (targets_flat >= threshold).float()

        intersection = (inputs_bin * targets_bin).sum()
        union = inputs_bin.sum() + targets_bin.sum() - intersection

        ts_sum += (intersection + 1e-16) / (union + 1e-16)

    return ts_sum / len(thresholds)
39
def IouOrder(inputs, targets, thresholds=(15,)):
    """IoU of two NumPy arrays for values inside a closed threshold band.

    Elements with ``lower <= value <= upper`` count as events in both arrays.
    IoU is intersection / union with a 1e-16 smoothing term.

    :param inputs: predicted field (NumPy array, any shape; flattened).
    :param targets: reference field with the same number of elements.
    :param thresholds: ``(lower,)`` or ``(lower, upper)``. With a single value,
        the band has no upper bound. The default ``(15,)`` previously raised
        IndexError because ``thresholds[1]`` was read unconditionally.
    :return: IoU for the band.
    """
    lower_threshold = thresholds[0]
    upper_threshold = thresholds[1] if len(thresholds) > 1 else float("inf")
    inputs_flat = inputs.flatten()
    targets_flat = targets.flatten()
    inputs_bin = ((inputs_flat >= lower_threshold) & (inputs_flat <= upper_threshold)).astype(int)
    targets_bin = ((targets_flat >= lower_threshold) & (targets_flat <= upper_threshold)).astype(int)
    intersection = (inputs_bin * targets_bin).sum()
    union = (inputs_bin + targets_bin).sum() - intersection
    return (intersection + 1e-16) / (union + 1e-16)
51
+
52
def iouMask(inputs, targets, mask, thresholds=(0, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 80), type=None):
    """Mean IoU over several thresholds, restricted to the masked region.

    :param inputs: predicted field (torch tensor or NumPy array; flattened).
    :param targets: reference field with the same number of elements.
    :param mask: for tensors, a boolean mask (True = keep). For NumPy, an
        array where elements equal to 1 are kept.
    :param thresholds: sequence of thresholds (a list also works).
    :param type: truthy selects the torch code path, falsy the NumPy path.
        (The name shadows the ``type`` builtin but is kept for compatibility.)
    :return: mean IoU across ``thresholds``, with a 1e-16 smoothing term.
    """
    is_tensor = bool(type)
    inputs_flat = inputs.flatten()
    targets_flat = targets.flatten()
    mask_flat = mask.flatten()

    # Keep only the elements selected by the mask.
    keep = mask_flat if is_tensor else (mask_flat == 1)
    inputs_flat = inputs_flat[keep]
    targets_flat = targets_flat[keep]

    ts_sum = 0.0
    for threshold in thresholds:
        if is_tensor:
            inputs_bin = (inputs_flat >= threshold).float()
            targets_bin = (targets_flat >= threshold).float()
        else:
            inputs_bin = (inputs_flat >= threshold).astype(float)
            targets_bin = (targets_flat >= threshold).astype(float)
        intersection = (inputs_bin * targets_bin).sum()
        union = inputs_bin.sum() + targets_bin.sum() - intersection
        ts_sum += (intersection + 1e-16) / (union + 1e-16)
    return ts_sum / len(thresholds)
83
+