shancx 1.8.92__py3-none-any.whl → 1.9.33.218__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shancx/3D/__init__.py +25 -0
- shancx/Algo/Class.py +11 -0
- shancx/Algo/CudaPrefetcher1.py +112 -0
- shancx/Algo/Fake_image.py +24 -0
- shancx/Algo/Hsml.py +391 -0
- shancx/Algo/L2Loss.py +10 -0
- shancx/Algo/MetricTracker.py +132 -0
- shancx/Algo/Normalize.py +66 -0
- shancx/Algo/OptimizerWScheduler.py +38 -0
- shancx/Algo/Rmageresize.py +79 -0
- shancx/Algo/Savemodel.py +33 -0
- shancx/Algo/SmoothL1_losses.py +27 -0
- shancx/Algo/Tqdm.py +62 -0
- shancx/Algo/__init__.py +121 -0
- shancx/Algo/checknan.py +28 -0
- shancx/Algo/iouJU.py +83 -0
- shancx/Algo/mask.py +25 -0
- shancx/Algo/psnr.py +9 -0
- shancx/Algo/ssim.py +70 -0
- shancx/Algo/structural_similarity.py +308 -0
- shancx/Algo/tool.py +704 -0
- shancx/Calmetrics/__init__.py +97 -0
- shancx/Calmetrics/calmetrics.py +14 -0
- shancx/Calmetrics/calmetricsmatrixLib.py +147 -0
- shancx/Calmetrics/rmseR2score.py +35 -0
- shancx/Clip/__init__.py +50 -0
- shancx/Cmd.py +126 -0
- shancx/Config_.py +26 -0
- shancx/Df/DataFrame.py +11 -2
- shancx/Df/__init__.py +17 -0
- shancx/Df/tool.py +0 -0
- shancx/Diffm/Psamples.py +18 -0
- shancx/Diffm/__init__.py +0 -0
- shancx/Diffm/test.py +207 -0
- shancx/Doc/__init__.py +214 -0
- shancx/E/__init__.py +178 -152
- shancx/Fillmiss/__init__.py +0 -0
- shancx/Fillmiss/imgidwJU.py +46 -0
- shancx/Fillmiss/imgidwLatLonJU.py +82 -0
- shancx/Gpu/__init__.py +55 -0
- shancx/H9/__init__.py +126 -0
- shancx/H9/ahi_read_hsd.py +877 -0
- shancx/H9/ahisearchtable.py +298 -0
- shancx/H9/geometry.py +2439 -0
- shancx/Hug/__init__.py +81 -0
- shancx/Inst.py +22 -0
- shancx/Lib.py +31 -0
- shancx/Mos/__init__.py +37 -0
- shancx/NN/__init__.py +235 -106
- shancx/Path1.py +161 -0
- shancx/Plot/GlobMap.py +276 -116
- shancx/Plot/__init__.py +491 -1
- shancx/Plot/draw_day_CR_PNG.py +4 -21
- shancx/Plot/exam.py +116 -0
- shancx/Plot/plotGlobal.py +325 -0
- shancx/{radar_nmc.py → Plot/radarNmc.py} +4 -34
- shancx/{subplots_single_china_map.py → Plot/single_china_map.py} +1 -1
- shancx/Point.py +46 -0
- shancx/QC.py +223 -0
- shancx/RdPzl/__init__.py +32 -0
- shancx/Read.py +72 -0
- shancx/Resize.py +79 -0
- shancx/SN/__init__.py +62 -123
- shancx/Time/GetTime.py +9 -3
- shancx/Time/__init__.py +66 -1
- shancx/Time/timeCycle.py +302 -0
- shancx/Time/tool.py +0 -0
- shancx/Train/__init__.py +74 -0
- shancx/Train/makelist.py +187 -0
- shancx/Train/multiGpu.py +27 -0
- shancx/Train/prepare.py +161 -0
- shancx/Train/renet50.py +157 -0
- shancx/ZR.py +12 -0
- shancx/__init__.py +333 -262
- shancx/args.py +27 -0
- shancx/bak.py +768 -0
- shancx/df2database.py +62 -2
- shancx/geosProj.py +80 -0
- shancx/info.py +38 -0
- shancx/netdfJU.py +231 -0
- shancx/sendM.py +59 -0
- shancx/tensBoard/__init__.py +28 -0
- shancx/wait.py +246 -0
- {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/METADATA +15 -5
- shancx-1.9.33.218.dist-info/RECORD +91 -0
- {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/WHEEL +1 -1
- my_timer_decorator/__init__.py +0 -10
- shancx/Dsalgor/__init__.py +0 -19
- shancx/E/DFGRRIB.py +0 -30
- shancx/EN/DFGRRIB.py +0 -30
- shancx/EN/__init__.py +0 -148
- shancx/FileRead.py +0 -44
- shancx/Gray2RGB.py +0 -86
- shancx/M/__init__.py +0 -137
- shancx/MN/__init__.py +0 -133
- shancx/N/__init__.py +0 -131
- shancx/Plot/draw_day_CR_PNGUS.py +0 -206
- shancx/Plot/draw_day_CR_SVG.py +0 -275
- shancx/Plot/draw_day_pre_PNGUS.py +0 -205
- shancx/Plot/glob_nation_map.py +0 -116
- shancx/Plot/radar_nmc.py +0 -61
- shancx/Plot/radar_nmc_china_map_compare1.py +0 -50
- shancx/Plot/radar_nmc_china_map_f.py +0 -121
- shancx/Plot/radar_nmc_us_map_f.py +0 -128
- shancx/Plot/subplots_compare_devlop.py +0 -36
- shancx/Plot/subplots_single_china_map.py +0 -45
- shancx/S/__init__.py +0 -138
- shancx/W/__init__.py +0 -132
- shancx/WN/__init__.py +0 -132
- shancx/code.py +0 -331
- shancx/draw_day_CR_PNG.py +0 -200
- shancx/draw_day_CR_PNGUS.py +0 -206
- shancx/draw_day_CR_SVG.py +0 -275
- shancx/draw_day_pre_PNGUS.py +0 -205
- shancx/makenetCDFN.py +0 -42
- shancx/mkIMGSCX.py +0 -92
- shancx/netCDF.py +0 -130
- shancx/radar_nmc_china_map_compare1.py +0 -50
- shancx/radar_nmc_china_map_f.py +0 -125
- shancx/radar_nmc_us_map_f.py +0 -67
- shancx/subplots_compare_devlop.py +0 -36
- shancx/tool.py +0 -18
- shancx/user/H8mess.py +0 -317
- shancx/user/__init__.py +0 -137
- shancx/user/cinradHJN.py +0 -496
- shancx/user/examMeso.py +0 -293
- shancx/user/hjnDAAS.py +0 -26
- shancx/user/hjnFTP.py +0 -81
- shancx/user/hjnGIS.py +0 -320
- shancx/user/hjnGPU.py +0 -21
- shancx/user/hjnIDW.py +0 -68
- shancx/user/hjnKDTree.py +0 -75
- shancx/user/hjnLAPSTransform.py +0 -47
- shancx/user/hjnMiscellaneous.py +0 -182
- shancx/user/hjnProj.py +0 -162
- shancx/user/inotify.py +0 -41
- shancx/user/matplotlibMess.py +0 -87
- shancx/user/mkNCHJN.py +0 -623
- shancx/user/newTypeRadar.py +0 -492
- shancx/user/test.py +0 -6
- shancx/user/tlogP.py +0 -129
- shancx/util_log.py +0 -33
- shancx/wtx/H8mess.py +0 -315
- shancx/wtx/__init__.py +0 -151
- shancx/wtx/cinradHJN.py +0 -496
- shancx/wtx/colormap.py +0 -64
- shancx/wtx/examMeso.py +0 -298
- shancx/wtx/hjnDAAS.py +0 -26
- shancx/wtx/hjnFTP.py +0 -81
- shancx/wtx/hjnGIS.py +0 -330
- shancx/wtx/hjnGPU.py +0 -21
- shancx/wtx/hjnIDW.py +0 -68
- shancx/wtx/hjnKDTree.py +0 -75
- shancx/wtx/hjnLAPSTransform.py +0 -47
- shancx/wtx/hjnLog.py +0 -78
- shancx/wtx/hjnMiscellaneous.py +0 -201
- shancx/wtx/hjnProj.py +0 -161
- shancx/wtx/inotify.py +0 -41
- shancx/wtx/matplotlibMess.py +0 -87
- shancx/wtx/mkNCHJN.py +0 -613
- shancx/wtx/newTypeRadar.py +0 -492
- shancx/wtx/test.py +0 -6
- shancx/wtx/tlogP.py +0 -129
- shancx-1.8.92.dist-info/RECORD +0 -99
- /shancx/{Dsalgor → Algo}/dsalgor.py +0 -0
- {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/top_level.txt +0 -0
shancx/Algo/MetricTracker.py
ADDED
@@ -0,0 +1,132 @@
+# https://github.com/pytorch/examples/blob/master/imagenet/main.py
+class MetricTracker(object):
+    """Computes and stores the average and current value"""
+    def __init__(self):
+        self.reset()
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+"""
+train_acc = metrics.MetricTracker()
+train_acc.update(metrics.acc(outputs, labels, threshold=0.1), outputs.size(0))
+train_acc.avg
+"""
+
+
+from tqdm import tqdm
+
+class TrainingManager:
+    def __init__(self, loader, description="Training Progress", log_file=None):
+        """
+        Training manager that wraps a progress bar and log-file writing.
+        :param loader: data loader
+        :param description: progress bar description
+        :param log_file: path to the log file (optional)
+        """
+        self.loader = tqdm(loader, desc=description)
+        self.metrics = {}
+        self.log_file = log_file
+        if log_file:
+            with open(log_file, 'w') as f:
+                headers = "Epoch,Step," + ",".join(self.metrics.keys()) + "\n"
+                f.write(headers)
+
+    def add_metric(self, name, initial_value=0.0):
+        """
+        Register a metric to track.
+        :param name: metric name
+        :param initial_value: initial value
+        """
+        self.metrics[name] = Metric(initial_value)
+
+    def update_metrics(self, **kwargs):
+        """
+        Update the given metrics.
+        :param kwargs: keyword arguments mapping metric names to new values
+        """
+        for name, value in kwargs.items():
+            if name in self.metrics:
+                self.metrics[name].update(value)
+
+    def log_progress(self, epoch, step):
+        """
+        Refresh the progress bar and append to the log file.
+        :param epoch: current training epoch
+        :param step: current training step
+        """
+        description = f"Epoch {epoch}: " + " ".join(
+            [f"{name}: {metric.avg:.4f}" for name, metric in self.metrics.items()]
+        )
+        self.loader.set_description(description)
+
+        if self.log_file:
+            with open(self.log_file, 'a') as f:
+                line = f"{epoch},{step}," + ",".join([f"{metric.avg:.4f}" for metric in self.metrics.values()]) + "\n"
+                f.write(line)
+
+    def close(self):
+        """Close the progress bar"""
+        self.loader.close()
+
+class Metric:
+    def __init__(self, initial_value=0.0):
+        """
+        A tracked metric.
+        :param initial_value: initial value
+        """
+        self.total = initial_value
+        self.count = 0
+        self.avg = initial_value
+
+    def update(self, value, count=1):
+        """
+        Update the metric.
+        :param value: value to add
+        :param count: number of samples, default 1
+        """
+        self.total += value * count
+        self.count += count
+        self.avg = self.total / self.count
+
+# Example code
+if __name__ == "__main__":
+    # Assume train_dataloader is a data loader; this example simulates 100 samples
+    train_dataloader = [{"sat_img": None, "map_img": None} for _ in range(100)]
+
+    manager = TrainingManager(train_dataloader, description="Training Progress", log_file="training_log.csv")
+    manager.add_metric("Loss")
+    manager.add_metric("Accuracy")
+    manager.add_metric("Ts")
+    manager.add_metric("Fsc")
+    manager.add_metric("Far")
+
+    for epoch in range(3):  # simulate 3 epochs
+        for idx, data in enumerate(manager.loader):
+            # Simulated training computations
+            loss_value = 0.5 + idx * 0.01
+            accuracy_value = 0.8 - idx * 0.001
+            ts_value = 0.7 + idx * 0.001
+            fsc_value = 0.6 + idx * 0.002
+            far_value = 0.2 + idx * 0.001
+
+            # Update metrics
+            manager.update_metrics(
+                Loss=loss_value,
+                Accuracy=accuracy_value,
+                Ts=ts_value,
+                Fsc=fsc_value,
+                Far=far_value
+            )
+
+            # Record progress and log
+            manager.log_progress(epoch, idx)
+
+    # Close the progress bar
+    manager.close()
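
Not part of the package diff: a minimal sketch, assuming only the MetricTracker class added above, of how its running average is weighted when batch sizes differ. The numbers are illustrative.

    tracker = MetricTracker()
    tracker.update(0.50, n=32)   # batch of 32 samples with mean loss 0.50
    tracker.update(0.80, n=8)    # smaller batch of 8 samples with mean loss 0.80
    print(tracker.val)           # 0.8, the most recent value
    print(tracker.avg)           # (0.50*32 + 0.80*8) / 40 = 0.56, a sample-weighted mean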
shancx/Algo/Normalize.py
ADDED
@@ -0,0 +1,66 @@
+
+from torchvision.transforms import Resize  # Function to calculate valid size based on scale factor
+def get_valid_size(size, scale_factor=2):
+    return size - (size % scale_factor)
+# Function to resize a high-resolution image to a low-resolution image
+def resize_to_low_res(high_res_image, scale_factor):
+    _, width, height = high_res_image.size()
+    # Get valid width and height divisible by the scale factor
+    width = get_valid_size(width, scale_factor)
+    height = get_valid_size(height, scale_factor)
+    # Crop the high-res image to the valid dimensions
+    high_res_image = high_res_image[:, :width, :height]
+    # Create the low-res image by resizing the high-res image
+    low_res_image = Resize((width // scale_factor, height // scale_factor))(high_res_image)
+    return low_res_image
+"""
+
+cuda-version              11.8        hcce14f8_3
+cudatoolkit               11.8.0      h6a678d5_0
+cudnn                     8.9.2.26    cuda11_0
+nvidia-cuda-cupti-cu12    12.1.105    pypi_0    pypi
+nvidia-cuda-nvrtc-cu12    12.1.105    pypi_0    pypi
+nvidia-cuda-runtime-cu12  12.1.105    pypi_0    pypi
+nvidia-cudnn-cu12         8.9.2.26    pypi_0    pypi
+mqpf     conda install pytorch torchvision torchaudio cudatoolkit=11.8 -c pytorch
+         conda install cudnn=8.9.2.26 cudatoolkit=11.8
+resunet  pip install torch==2.4.0 torchvision torchaudio
+         conda install cudnn==8.9.2.26 cudatoolkit==11.8.0
+         conda install pytorch=2.2.2 torchvision torchaudio cudatoolkit=11.8 -c pytorch
+resunet  pip install torch==2.4.0 torchvision torchaudio
+pip install protobuf==3.20
+
+my-envmf1
+torch         2.3.0     pypi_0 pypi
+torchvision   0.18.0    pypi_0 pypi
+
+RES:
+torch         2.4.0     pypi_0 pypi
+torchaudio    2.2.2     py311_cpu pytorch
+torchsummary  1.5.1     pypi_0 pypi
+torchvision   0.19.0    pypi_0 pypi
+
+mqpf:
+torch            2.3.1   pypi_0 pypi
+torchaudio       2.3.1   pypi_0 pypi
+torchvision      0.18.1  pypi_0 pypi
+onnxruntime-gpu  1.16.0
+onnx             1.15.0
+numpy            1.26.4
+
+vllm:
+torch         2.1.2         pypi_0 pypi
+torchvision   0.15.1+cu118  pypi_0 pypi
+vllm          0.2.7         pypi_0 pypi
+
+import torch
+print("CUDA available:", torch.cuda.is_available())
+print("CUDA version:", torch.version.cuda)
+print("GPU device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")
+nvidia-smi
+nvcc --version
+The system detects the physical GPU (NVIDIA GeForce RTX 4090) and the NVIDIA driver, and CUDA 12.1 is installed. PyTorch, however, does not detect the GPU, likely because the PyTorch build is incompatible with the CUDA driver or the environment variables are not configured correctly.
+
+pip install torch==2.3.1 torchvision==0.18.1
+
+"""
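
Not part of the diff: a small usage sketch for resize_to_low_res as reconstructed above, assuming a CHW float tensor. It illustrates that the spatial dimensions are first trimmed to multiples of the scale factor before the low-resolution copy is produced.

    import torch
    hr = torch.rand(3, 257, 129)                 # C x H x W, spatial sizes not divisible by 4
    lr = resize_to_low_res(hr, scale_factor=4)   # internally cropped to 256 x 128, then resized
    print(lr.shape)                              # torch.Size([3, 64, 32])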
shancx/Algo/OptimizerWScheduler.py
ADDED
@@ -0,0 +1,38 @@
+import torch
+
+class OptimizerWithScheduler:
+    def __init__(self, model, lr=0.001, step_size=40, gamma=0.1):
+        """
+        Initialize the optimizer and the learning-rate scheduler.
+        :param model: model to optimize
+        :param lr: initial learning rate
+        :param step_size: number of epochs between learning-rate adjustments
+        :param gamma: multiplicative factor applied to the learning rate
+        """
+        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
+            self.optimizer, step_size=step_size, gamma=gamma
+        )
+    def zero_grad(self):
+        self.optimizer.zero_grad()
+    def step(self):
+        self.optimizer.step()
+    def step_scheduler(self):
+        self.lr_scheduler.step()
+    def get_lr(self):
+        return self.optimizer.param_groups[0]['lr']
+"""
+if __name__ == "__main__":
+    model = torch.nn.Linear(10, 1)
+    optimizer_with_scheduler = OptimizerWithScheduler(model, lr=0.001, step_size=40, gamma=0.1)
+    for epoch in range(100):
+        optimizer_with_scheduler.step_scheduler()
+        for idx, data in enumerate(loader):
+            inputs = data["sat_img"].cuda()
+            labels = data["map_img"].cuda()
+            optimizer_with_scheduler.zero_grad()
+            loss = torch.randn(1)  # example only
+            loss.backward()
+            optimizer_with_scheduler.step()
+        print(f"Epoch {epoch + 1}, Learning Rate: {optimizer_with_scheduler.get_lr()}")
+"""
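
Not part of the diff: a sketch of the resulting learning-rate schedule, assuming step_scheduler() is called once per epoch. With StepLR the rate is the initial lr times gamma ** (epochs_completed // step_size); PyTorch may warn that optimizer.step() was never called, since this sketch only tracks the schedule.

    import torch
    model = torch.nn.Linear(4, 1)
    ows = OptimizerWithScheduler(model, lr=0.001, step_size=40, gamma=0.1)
    for epoch in range(100):
        # per-batch forward/backward and ows.step() would go here
        ows.step_scheduler()          # one scheduler step per epoch
    print(ows.get_lr())               # 0.001 * 0.1 ** (100 // 40) ≈ 1e-05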
shancx/Algo/Rmageresize.py
ADDED
@@ -0,0 +1,79 @@
+import torch
+from torchvision.transforms import Resize
+from torchvision.transforms.functional import InterpolationMode
+import numpy as np
+
+def image_resize(image, scale_factor, antialiasing=True):
+    """
+    Resize the input image with a specified scale factor using torchvision's Resize.
+
+    Args:
+        image (np.ndarray or torch.Tensor): Input image of shape (H, W, C) or (C, H, W).
+        scale_factor (float): Scale factor to resize the image.
+        antialiasing (bool): Whether to use antialiasing (default: True).
+
+    Returns:
+        Resized image in the same format as input (NumPy or PyTorch Tensor).
+    """
+    # Check if input is a NumPy array
+    numpy_type = isinstance(image, np.ndarray)
+
+    # Handle missing values for NumPy arrays
+    if numpy_type:
+        image = np.nan_to_num(image)  # Replace NaNs with 0
+        if image.ndim == 3:  # (H, W, C)
+            image = torch.from_numpy(image.transpose(2, 0, 1)).float()  # Convert to (C, H, W)
+        elif image.ndim == 2:  # (H, W)
+            image = torch.from_numpy(image[None, :, :]).float()  # Convert to (C, H, W)
+
+    # Calculate new dimensions
+    _, in_h, in_w = image.shape  # Assuming (C, H, W) format
+    out_h, out_w = int(in_h * scale_factor), int(in_w * scale_factor)
+
+    # Perform resizing
+    mode = InterpolationMode.BICUBIC if antialiasing else InterpolationMode.NEAREST
+    resize_transform = Resize((out_h, out_w), interpolation=mode)
+    resized_image = resize_transform(image)
+
+    # Convert back to a NumPy array if the input was NumPy
+    if numpy_type:
+        resized_image = resized_image.numpy().transpose(1, 2, 0)  # Convert back to (H, W, C)
+
+    return resized_image
+# resized_image = image_resize(image, scale_factor, antialiasing)
+
+# Example for loading image data and processing NaNs
+def process_image(file_path, scale_factor, antialiasing=True):
+    """
+    Load an image from a file, handle NaN values, and resize it.
+
+    Args:
+        file_path (str): Path to the image file (e.g., .npy for NumPy arrays).
+        scale_factor (float): Scale factor for resizing.
+        antialiasing (bool): Use antialiasing during resizing (default: True).
+
+    Returns:
+        Resized image as a NumPy array.
+    """
+    # Load the image from a .npy file
+    image = np.load(file_path)
+    image[np.isnan(image)] = 0  # Replace NaN values with 0
+
+    # Resize the image
+    resized_image = image_resize(image, scale_factor, antialiasing)
+    return resized_image
+if __name__ == "__main__":
+    import matplotlib.pyplot as plt
+
+    # Path to a NumPy array image file
+    file_path = 'example.npy'  # Replace with your .npy file path
+
+    # Resize the image
+    scale_factor = 0.5  # Downscale by 50%
+    resized_image = process_image(file_path, scale_factor)
+
+    # Display the resized image
+    plt.title("Resized Image")
+    plt.imshow(resized_image.astype(np.uint8))
+    plt.axis("off")
+    plt.show()
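
Not part of the diff: a sketch exercising image_resize on an in-memory array instead of the .npy path used in the module's __main__ block. The array shape and NaN placement are arbitrary.

    import numpy as np
    img = np.random.rand(100, 200, 3).astype(np.float32)
    img[0, 0, 0] = np.nan                          # NaNs are replaced with 0 inside image_resize
    small = image_resize(img, scale_factor=0.5)    # returned as NumPy in (H, W, C) layout
    print(small.shape)                             # (50, 100, 3)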
shancx/Algo/Savemodel.py
ADDED
@@ -0,0 +1,33 @@
+import os
+import torch
+def save_model_checkpoint(step, epoch, model, optimizer, best_loss, best_acc, checkpoint_dir, name):
+    if not os.path.exists(checkpoint_dir):
+        os.makedirs(checkpoint_dir)
+    save_path = os.path.join(checkpoint_dir, f"{name}_checkpoint_{step:04d}.pt")
+    torch.save({
+        "step": step,
+        "epoch": epoch,
+        "arch": model.__class__.__name__,
+        "state_dict": model.state_dict(),
+        "best_acc": best_acc,
+        "best_loss": best_loss,
+        "optimizer": optimizer.state_dict(),
+    }, save_path)
+    return save_path
+
+"""
+if valid_loss < best_loss:
+    best_loss = valid_loss
+    wait = 0  # Reset wait counter
+    save_path = save_model_checkpoint(
+        step=step,
+        epoch=epoch,
+        model=model,
+        optimizer=optimizer,
+        best_loss=best_loss,
+        best_acc=best_acc,
+        checkpoint_dir=checkpoint_dir,
+        name=name
+    )
+    print(f"Checkpoint saved at: {save_path}")
+"""
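
Not part of the diff: a sketch of the round trip through save_model_checkpoint and back, assuming a model and optimizer constructed with matching settings; the step, epoch, directory, and metric values are hypothetical. The keys mirror the dict saved above.

    import torch
    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    path = save_model_checkpoint(step=42, epoch=3, model=model, optimizer=optimizer,
                                 best_loss=0.12, best_acc=0.91,
                                 checkpoint_dir="checkpoints", name="demo")   # hypothetical values
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    print(checkpoint["epoch"], checkpoint["best_loss"])                        # 3 0.12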
shancx/Algo/SmoothL1_losses.py
ADDED
@@ -0,0 +1,27 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+def smoothL1_loss(x, y):
+    smoothL1loss = nn.SmoothL1Loss()
+    return smoothL1loss(x, y)
+def compute_smoothL1_losses(x, y, y_label):
+    """
+    Compute the Smooth L1 loss for each label class.
+    Args:
+    - x (torch.Tensor): predictions.
+    - y (torch.Tensor): ground truth.
+    - y_label (torch.Tensor): label values used to select the predictions and targets of each class.
+    Returns:
+    - smoothL1_losses (list): Smooth L1 loss values, one per class.
+    """
+    smoothL1_losses = []
+    for label_value in range(1, 5):
+        mask = y_label == label_value
+        x_masked = torch.masked_select(x, mask)
+        y_masked = torch.masked_select(y, mask)
+        if x_masked.numel() > 0:
+            smoothL1_val = smoothL1_loss(x_masked, y_masked)  # , reduction='mean'
+            smoothL1_losses.append(smoothL1_val)
+        else:
+            smoothL1_losses.append(torch.tensor(0.0))  # fallback when a class has no samples
+    return smoothL1_losses
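
Not part of the diff: a sketch of compute_smoothL1_losses on random tensors. Labels 1 through 4 are scored separately; a class with no pixels contributes a zero tensor, as in the fallback branch above.

    import torch
    pred   = torch.rand(2, 1, 8, 8)
    target = torch.rand(2, 1, 8, 8)
    labels = torch.randint(0, 5, (2, 1, 8, 8))        # class ids 0..4; only 1..4 are scored
    losses = compute_smoothL1_losses(pred, target, labels)
    mean_loss = sum(losses) / len(losses)             # one way to reduce to a single scalar
    print([round(l.item(), 4) for l in losses], mean_loss.item())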
shancx/Algo/Tqdm.py
ADDED
@@ -0,0 +1,62 @@
+from tqdm import tqdm
+class TrainingManager:
+    def __init__(self, loader, description="Training Progress", log_file=None):
+        self.loader = tqdm(loader, desc=description)
+        self.metrics = {}
+        self.log_file = log_file
+        if log_file:
+            with open(log_file, 'w') as f:
+                headers = "Epoch,Step," + ",".join(self.metrics.keys()) + "\n"
+                f.write(headers)
+    def add_metric(self, name, initial_value=0.0):
+        self.metrics[name] = Metric(initial_value)  # uses the separate Metric class below to keep a running average
+    def update_metrics(self, **kwargs):
+        for name, value in kwargs.items():  # kwargs: metric names mapped to their new values
+            if name in self.metrics:
+                self.metrics[name].update(value)
+    def log_progress(self, epoch, step):
+        description = f"Epoch {epoch}: " + " ".join(
+            [f"{name}: {metric.avg:.4f}" for name, metric in self.metrics.items()]
+        )
+        self.loader.set_description(description)
+        if self.log_file:
+            with open(self.log_file, 'a') as f:
+                line = f"{epoch},{step}," + ",".join([f"{metric.avg:.4f}" for metric in self.metrics.values()]) + "\n"
+                f.write(line)
+    def log_progress1(self, description):
+        self.loader.set_description(description)
+    def close(self):
+        self.loader.close()
+class Metric:
+    def __init__(self, initial_value=0.0):
+        self.total = initial_value
+        self.count = 0
+        self.avg = initial_value
+    def update(self, value, count=1):
+        self.total += value * count
+        self.count += count
+        self.avg = self.total / self.count
+if __name__ == "__main__":
+    train_dataloader = [{"sat_img": None, "map_img": None} for _ in range(100)]
+    manager = TrainingManager(train_dataloader, description="Training Progress", log_file="training_log.csv")
+    manager.add_metric("Loss")
+    manager.add_metric("Accuracy")
+    manager.add_metric("Ts")
+    manager.add_metric("Fsc")
+    manager.add_metric("Far")
+    for epoch in range(3):  # simulate 3 epochs
+        for idx, data in enumerate(manager.loader):
+            loss_value = 0.5 + idx * 0.01
+            accuracy_value = 0.8 - idx * 0.001
+            ts_value = 0.7 + idx * 0.001
+            fsc_value = 0.6 + idx * 0.002
+            far_value = 0.2 + idx * 0.001
+            manager.update_metrics(
+                Loss=loss_value,
+                Accuracy=accuracy_value,
+                Ts=ts_value,
+                Fsc=fsc_value,
+                Far=far_value
+            )
+            manager.log_progress(epoch, idx)
+    manager.close()
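
One detail worth noting in the file above: the CSV header is written in __init__, before any add_metric call, so it contains only "Epoch,Step," while the data rows carry one extra column per metric. Not part of the diff: a sketch that reads the log back with explicit column names, assuming pandas is available.

    import pandas as pd
    cols = ["Epoch", "Step", "Loss", "Accuracy", "Ts", "Fsc", "Far"]   # must match the add_metric order
    log = pd.read_csv("training_log.csv", skiprows=1, names=cols)
    print(log.groupby("Epoch")["Loss"].last())                         # running average at the end of each epoch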
shancx/Algo/__init__.py
ADDED
@@ -0,0 +1,121 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+# @Time : 2024/10/17 10:40 AM
+# @Author : shancx
+# @File : __init__.py
+# @email : shanhe12@163.com
+
+def quick_sort(arr):
+    if len(arr) <= 1:
+        return arr
+    pivot = arr[len(arr) // 2]
+    left = [x for x in arr if x < pivot]
+    middle = [x for x in arr if x == pivot]
+    right = [x for x in arr if x > pivot]
+    return quick_sort(left) + middle + quick_sort(right)
+def sort_dict_by_key(d):
+    sorted_dict = {key: d[key] for key in sorted(d.keys())}
+    return sorted_dict
+from tqdm import tqdm
+class TrainingManager:
+    def __init__(self, loader, desc="Training Progress", log_file=None):
+        self.loader = tqdm(loader, desc=desc)
+        self.metrics = {}
+        self.log_file = log_file
+        if log_file:
+            with open(log_file, 'w') as f:
+                headers = "Epoch,Step," + ",".join(self.metrics.keys()) + "\n"
+                f.write(headers)
+    def add_metric(self, name, initial_value=0.0):
+        self.metrics[name] = Metric(initial_value)  # uses the separate Metric class below to keep a running average
+    def update_metrics(self, **kwargs):
+        for name, value in kwargs.items():  # kwargs: metric names mapped to their new values
+            if name in self.metrics:
+                self.metrics[name].update(value)
+    def log_progress(self, epoch, step):
+        description = f"Epoch {epoch}: " + " ".join(
+            [f"{name}: {metric.avg:.4f}" for name, metric in self.metrics.items()]
+        )
+        self.loader.set_description(description)
+        if self.log_file:
+            with open(self.log_file, 'a') as f:
+                line = f"{epoch},{step}," + ",".join([f"{metric.avg:.4f}" for metric in self.metrics.values()]) + "\n"
+                f.write(line)
+    """
+    description = f"Epoch {epoch}: " + " ".join(
+        [f"{name}: {metric.avg:.4f}" for name, metric in self.metrics.items()]
+    )
+    def log_progress1(self, description):
+        self.loader.set_description(description)
+    """
+    def close(self):
+        self.loader.close()
+class Metric:
+    def __init__(self, initial_value=0.0):
+        self.total = initial_value
+        self.count = 0
+        self.avg = initial_value
+    def update(self, value, count=1):
+        self.total += value * count
+        self.count += count
+        self.avg = self.total / self.count
+if __name__ == "__main__":
+    train_dataloader = [{"sat_img": None, "map_img": None} for _ in range(100)]
+    manager = TrainingManager(train_dataloader, desc="Training Progress", log_file="training_log.csv")
+    manager.add_metric("Loss")
+    manager.add_metric("Accuracy")
+    manager.add_metric("Ts")
+    manager.add_metric("Fsc")
+    manager.add_metric("Far")
+    for epoch in range(3):  # simulate 3 epochs
+        for idx, data in enumerate(manager.loader):
+            loss_value = 0.5 + idx * 0.01
+            accuracy_value = 0.8 - idx * 0.001
+            ts_value = 0.7 + idx * 0.001
+            fsc_value = 0.6 + idx * 0.002
+            far_value = 0.2 + idx * 0.001
+            manager.update_metrics(
+                Loss=loss_value,
+                Accuracy=accuracy_value,
+                Ts=ts_value,
+                Fsc=fsc_value,
+                Far=far_value
+            )
+            manager.log_progress(epoch, idx)
+    manager.close()
+
+import psutil
+import os
+def get_memory():
+    process = psutil.Process(os.getpid())
+    return process.memory_info().rss / 1024 / 1024
+
+'''
+initial_memory = get_memory()
+logger.info(f"Memory before: {initial_memory:.2f} MB {sUTC}")
+load_memory = get_memory()
+logger.info(f"After creating objects: {load_memory:.2f} MB {sUTC}")
+final_memory = get_memory()
+logger.info(f"After deleting objects: {final_memory:.2f} MB {sUTC}")
+logger.info(f"Memory change: initial {initial_memory:.2f} -> peak {load_memory:.2f} -> final {final_memory:.2f}")
+'''
+
+import psutil
+import sys
+def check_memory_threshold(threshold=90):
+    mem = psutil.virtual_memory()
+    return mem.percent >= threshold
+try:
+    if check_memory_threshold(90):
+        raise RuntimeError("The system memory usage is too high. Terminating the task. error")
+except RuntimeError as e:
+    print(f"Memory warning error: {e}")
+except MemoryError:
+    print("Insufficient memory, task failed. error")
+# except Exception as e:
+#     print(f"Unknown error: {e}")
+'''
+if check_memory_threshold(90):
+    raise RuntimeError("The system memory usage is too high. Terminating the task. error")
+
+'''
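
Not part of the diff: a quick sketch of the two small helpers at the top of this module; the inputs are arbitrary.

    print(quick_sort([40, 10, 30, 20, 10]))             # [10, 10, 20, 30, 40]
    print(sort_dict_by_key({"b": 2, "a": 1, "c": 3}))   # {'a': 1, 'b': 2, 'c': 3}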
shancx/Algo/checknan.py
ADDED
@@ -0,0 +1,28 @@
+import torch
+import numpy as np
+
+def checkData(data, name="data"):
+    if isinstance(data, torch.Tensor):
+        # Check Tensor data
+        has_nan = torch.isnan(data).any().item()
+        has_inf = torch.isinf(data).any().item()
+        print(f"torch.isnan(data).any() {torch.isnan(data).any()} torch.isinf(data).any() {torch.isinf(data).any()} ")
+    elif isinstance(data, np.ndarray):
+        # Check NumPy data
+        has_nan = np.isnan(data).any()
+        has_inf = np.isinf(data).any()
+        print(f"np.isnan(data).any() {np.isnan(data).any()} np.isinf(data).any() {np.isinf(data).any()} ")
+    else:
+        raise TypeError(f"Unsupported data type: {type(data)}. Expected torch.Tensor or numpy.ndarray.")
+
+    # If NaN or Inf is present, log it and return False
+    if has_nan or has_inf:
+        print(f"{name} contains NaN or Inf values! "
+              f"has_nan: {has_nan}, has_inf: {has_inf}")
+        return False
+    return True
+""" Skipping bad batches in the training loop:
+if not checkData(data, "data") or not checkData(label, "label"):
+    logger.info("# skipping bad data")
+    continue  # skip bad data
+"""
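
Not part of the diff: a sketch of checkData on clean and contaminated tensors; the values are arbitrary.

    import torch
    good = torch.rand(2, 3)
    bad = torch.tensor([[1.0, float("nan")], [2.0, float("inf")]])
    print(checkData(good, "good"))   # prints the isnan/isinf summary, returns True
    print(checkData(bad, "bad"))     # also prints the warning line, returns False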
shancx/Algo/iouJU.py
ADDED
@@ -0,0 +1,83 @@
+
+def iou(inputs, targets, thresholds=[0, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 80]):
+    ts_sum = 0
+    for i in range(len(thresholds)):
+        inputs_copy = inputs.flatten().copy()  # inputs_copy = inputs.flatten().clone()
+        targets_copy = targets.flatten().copy()
+        inputs_copy[inputs.flatten() < thresholds[i]] = 0
+        inputs_copy[inputs.flatten() >= thresholds[i]] = 1
+        targets_copy[targets.flatten() < thresholds[i]] = 0
+        targets_copy[targets.flatten() >= thresholds[i]] = 1
+        intersection = (inputs_copy.flatten() * targets_copy).sum()
+        total = (inputs_copy.flatten() + targets_copy).sum()
+        union = total - intersection
+        iou = (intersection + 1e-16) / (union + 1e-16)
+        ts_sum += iou
+    return ts_sum / len(thresholds)
+import torch
+
+def iouPlus(inputs, targets, thresholds=[0, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 80]):
+    inputs_flat = inputs.flatten()  # flatten up front
+    targets_flat = targets.flatten()
+    ts_sum = 0
+
+    for threshold in thresholds:
+        # Binarize at the threshold
+        inputs_bin = (inputs_flat >= threshold).float()
+        targets_bin = (targets_flat >= threshold).float()
+
+        # Intersection and union
+        intersection = (inputs_bin * targets_bin).sum()
+        union = inputs_bin.sum() + targets_bin.sum() - intersection
+
+        # IoU
+        iou = (intersection + 1e-16) / (union + 1e-16)
+        ts_sum += iou
+
+    # Mean IoU
+    return ts_sum / len(thresholds)
+def IouOrder(inputs, targets, thresholds=[15]):  # expects [lower, upper]; the one-element default would fail below
+    lower_threshold = thresholds[0]
+    upper_threshold = thresholds[1]
+    inputs_copy = (inputs.flatten() >= lower_threshold) & (inputs.flatten() <= upper_threshold)
+    targets_copy = (targets.flatten() >= lower_threshold) & (targets.flatten() <= upper_threshold)
+    inputs_copy = inputs_copy.astype(int)
+    targets_copy = targets_copy.astype(int)
+    intersection = (inputs_copy.flatten() * targets_copy).sum()
+    total = (inputs_copy.flatten() + targets_copy).sum()
+    union = total - intersection
+    iou = (intersection + 1e-16) / (union + 1e-16)
+    return iou
+
+def iouMask(inputs, targets, mask, thresholds=[0, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 80], type=None):
+    # Check the input type
+    is_tensor = type
+    # Keep inputs, targets, mask_data on the same device
+    targets = targets
+    mask_data = mask
+    # Build the mask: regions where mask_data == 0 are True
+    # Flatten inputs, targets, mask
+    inputs_flat = inputs.flatten()
+    targets_flat = targets.flatten()
+    mask_flat = mask.flatten()
+    # Compute IoU only within the mask region
+    if is_tensor:
+        inputs_flat = inputs_flat[mask_flat]  # keep only regions where the mask is True
+        targets_flat = targets_flat[mask_flat]
+    else:
+        inputs_flat = inputs_flat[mask_flat == 1]  # keep only regions where the mask is 1
+        targets_flat = targets_flat[mask_flat == 1]
+    ts_sum = 0.0
+    for threshold in thresholds:
+        if is_tensor:
+            inputs_bin = (inputs_flat >= threshold).float()
+            targets_bin = (targets_flat >= threshold).float()
+        else:
+            inputs_bin = (inputs_flat >= threshold).astype(float)
+            targets_bin = (targets_flat >= threshold).astype(float)
+        intersection = (inputs_bin * targets_bin).sum()
+        union = inputs_bin.sum() + targets_bin.sum() - intersection
+        iou = (intersection + 1e-16) / (union + 1e-16)
+        ts_sum += iou
+    return ts_sum / len(thresholds)
+