shancx 1.8.92__py3-none-any.whl → 1.9.33.218__py3-none-any.whl
This diff shows the changes between publicly available package versions as they appear in their respective public registries; it is provided for informational purposes only.
- shancx/3D/__init__.py +25 -0
- shancx/Algo/Class.py +11 -0
- shancx/Algo/CudaPrefetcher1.py +112 -0
- shancx/Algo/Fake_image.py +24 -0
- shancx/Algo/Hsml.py +391 -0
- shancx/Algo/L2Loss.py +10 -0
- shancx/Algo/MetricTracker.py +132 -0
- shancx/Algo/Normalize.py +66 -0
- shancx/Algo/OptimizerWScheduler.py +38 -0
- shancx/Algo/Rmageresize.py +79 -0
- shancx/Algo/Savemodel.py +33 -0
- shancx/Algo/SmoothL1_losses.py +27 -0
- shancx/Algo/Tqdm.py +62 -0
- shancx/Algo/__init__.py +121 -0
- shancx/Algo/checknan.py +28 -0
- shancx/Algo/iouJU.py +83 -0
- shancx/Algo/mask.py +25 -0
- shancx/Algo/psnr.py +9 -0
- shancx/Algo/ssim.py +70 -0
- shancx/Algo/structural_similarity.py +308 -0
- shancx/Algo/tool.py +704 -0
- shancx/Calmetrics/__init__.py +97 -0
- shancx/Calmetrics/calmetrics.py +14 -0
- shancx/Calmetrics/calmetricsmatrixLib.py +147 -0
- shancx/Calmetrics/rmseR2score.py +35 -0
- shancx/Clip/__init__.py +50 -0
- shancx/Cmd.py +126 -0
- shancx/Config_.py +26 -0
- shancx/Df/DataFrame.py +11 -2
- shancx/Df/__init__.py +17 -0
- shancx/Df/tool.py +0 -0
- shancx/Diffm/Psamples.py +18 -0
- shancx/Diffm/__init__.py +0 -0
- shancx/Diffm/test.py +207 -0
- shancx/Doc/__init__.py +214 -0
- shancx/E/__init__.py +178 -152
- shancx/Fillmiss/__init__.py +0 -0
- shancx/Fillmiss/imgidwJU.py +46 -0
- shancx/Fillmiss/imgidwLatLonJU.py +82 -0
- shancx/Gpu/__init__.py +55 -0
- shancx/H9/__init__.py +126 -0
- shancx/H9/ahi_read_hsd.py +877 -0
- shancx/H9/ahisearchtable.py +298 -0
- shancx/H9/geometry.py +2439 -0
- shancx/Hug/__init__.py +81 -0
- shancx/Inst.py +22 -0
- shancx/Lib.py +31 -0
- shancx/Mos/__init__.py +37 -0
- shancx/NN/__init__.py +235 -106
- shancx/Path1.py +161 -0
- shancx/Plot/GlobMap.py +276 -116
- shancx/Plot/__init__.py +491 -1
- shancx/Plot/draw_day_CR_PNG.py +4 -21
- shancx/Plot/exam.py +116 -0
- shancx/Plot/plotGlobal.py +325 -0
- shancx/{radar_nmc.py → Plot/radarNmc.py} +4 -34
- shancx/{subplots_single_china_map.py → Plot/single_china_map.py} +1 -1
- shancx/Point.py +46 -0
- shancx/QC.py +223 -0
- shancx/RdPzl/__init__.py +32 -0
- shancx/Read.py +72 -0
- shancx/Resize.py +79 -0
- shancx/SN/__init__.py +62 -123
- shancx/Time/GetTime.py +9 -3
- shancx/Time/__init__.py +66 -1
- shancx/Time/timeCycle.py +302 -0
- shancx/Time/tool.py +0 -0
- shancx/Train/__init__.py +74 -0
- shancx/Train/makelist.py +187 -0
- shancx/Train/multiGpu.py +27 -0
- shancx/Train/prepare.py +161 -0
- shancx/Train/renet50.py +157 -0
- shancx/ZR.py +12 -0
- shancx/__init__.py +333 -262
- shancx/args.py +27 -0
- shancx/bak.py +768 -0
- shancx/df2database.py +62 -2
- shancx/geosProj.py +80 -0
- shancx/info.py +38 -0
- shancx/netdfJU.py +231 -0
- shancx/sendM.py +59 -0
- shancx/tensBoard/__init__.py +28 -0
- shancx/wait.py +246 -0
- {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/METADATA +15 -5
- shancx-1.9.33.218.dist-info/RECORD +91 -0
- {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/WHEEL +1 -1
- my_timer_decorator/__init__.py +0 -10
- shancx/Dsalgor/__init__.py +0 -19
- shancx/E/DFGRRIB.py +0 -30
- shancx/EN/DFGRRIB.py +0 -30
- shancx/EN/__init__.py +0 -148
- shancx/FileRead.py +0 -44
- shancx/Gray2RGB.py +0 -86
- shancx/M/__init__.py +0 -137
- shancx/MN/__init__.py +0 -133
- shancx/N/__init__.py +0 -131
- shancx/Plot/draw_day_CR_PNGUS.py +0 -206
- shancx/Plot/draw_day_CR_SVG.py +0 -275
- shancx/Plot/draw_day_pre_PNGUS.py +0 -205
- shancx/Plot/glob_nation_map.py +0 -116
- shancx/Plot/radar_nmc.py +0 -61
- shancx/Plot/radar_nmc_china_map_compare1.py +0 -50
- shancx/Plot/radar_nmc_china_map_f.py +0 -121
- shancx/Plot/radar_nmc_us_map_f.py +0 -128
- shancx/Plot/subplots_compare_devlop.py +0 -36
- shancx/Plot/subplots_single_china_map.py +0 -45
- shancx/S/__init__.py +0 -138
- shancx/W/__init__.py +0 -132
- shancx/WN/__init__.py +0 -132
- shancx/code.py +0 -331
- shancx/draw_day_CR_PNG.py +0 -200
- shancx/draw_day_CR_PNGUS.py +0 -206
- shancx/draw_day_CR_SVG.py +0 -275
- shancx/draw_day_pre_PNGUS.py +0 -205
- shancx/makenetCDFN.py +0 -42
- shancx/mkIMGSCX.py +0 -92
- shancx/netCDF.py +0 -130
- shancx/radar_nmc_china_map_compare1.py +0 -50
- shancx/radar_nmc_china_map_f.py +0 -125
- shancx/radar_nmc_us_map_f.py +0 -67
- shancx/subplots_compare_devlop.py +0 -36
- shancx/tool.py +0 -18
- shancx/user/H8mess.py +0 -317
- shancx/user/__init__.py +0 -137
- shancx/user/cinradHJN.py +0 -496
- shancx/user/examMeso.py +0 -293
- shancx/user/hjnDAAS.py +0 -26
- shancx/user/hjnFTP.py +0 -81
- shancx/user/hjnGIS.py +0 -320
- shancx/user/hjnGPU.py +0 -21
- shancx/user/hjnIDW.py +0 -68
- shancx/user/hjnKDTree.py +0 -75
- shancx/user/hjnLAPSTransform.py +0 -47
- shancx/user/hjnMiscellaneous.py +0 -182
- shancx/user/hjnProj.py +0 -162
- shancx/user/inotify.py +0 -41
- shancx/user/matplotlibMess.py +0 -87
- shancx/user/mkNCHJN.py +0 -623
- shancx/user/newTypeRadar.py +0 -492
- shancx/user/test.py +0 -6
- shancx/user/tlogP.py +0 -129
- shancx/util_log.py +0 -33
- shancx/wtx/H8mess.py +0 -315
- shancx/wtx/__init__.py +0 -151
- shancx/wtx/cinradHJN.py +0 -496
- shancx/wtx/colormap.py +0 -64
- shancx/wtx/examMeso.py +0 -298
- shancx/wtx/hjnDAAS.py +0 -26
- shancx/wtx/hjnFTP.py +0 -81
- shancx/wtx/hjnGIS.py +0 -330
- shancx/wtx/hjnGPU.py +0 -21
- shancx/wtx/hjnIDW.py +0 -68
- shancx/wtx/hjnKDTree.py +0 -75
- shancx/wtx/hjnLAPSTransform.py +0 -47
- shancx/wtx/hjnLog.py +0 -78
- shancx/wtx/hjnMiscellaneous.py +0 -201
- shancx/wtx/hjnProj.py +0 -161
- shancx/wtx/inotify.py +0 -41
- shancx/wtx/matplotlibMess.py +0 -87
- shancx/wtx/mkNCHJN.py +0 -613
- shancx/wtx/newTypeRadar.py +0 -492
- shancx/wtx/test.py +0 -6
- shancx/wtx/tlogP.py +0 -129
- shancx-1.8.92.dist-info/RECORD +0 -99
- /shancx/{Dsalgor → Algo}/dsalgor.py +0 -0
- {shancx-1.8.92.dist-info → shancx-1.9.33.218.dist-info}/top_level.txt +0 -0
shancx/Train/prepare.py
ADDED
@@ -0,0 +1,161 @@
import argparse
from io import BytesIO
import multiprocessing
from multiprocessing import Lock, Process, RawValue
from functools import partial
from multiprocessing.sharedctypes import RawValue
from PIL import Image
from tqdm import tqdm
from torchvision.transforms import functional as trans_fn
import os
from pathlib import Path
import lmdb
import numpy as np
import time

def resize_and_convert(img, size, resample):
    if img.size[0] != size:
        img = trans_fn.resize(img, size, resample)
        img = trans_fn.center_crop(img, size)
    return img

def image_convert_bytes(img):
    buffer = BytesIO()
    img.save(buffer, format='png')
    return buffer.getvalue()

def resize_multiple(img, sizes=(16, 128), resample=Image.BICUBIC, lmdb_save=False):
    lr_img = resize_and_convert(img, sizes[0], resample)
    hr_img = resize_and_convert(img, sizes[1], resample)
    sr_img = resize_and_convert(lr_img, sizes[1], resample)
    if lmdb_save:
        lr_img = image_convert_bytes(lr_img)
        hr_img = image_convert_bytes(hr_img)
        sr_img = image_convert_bytes(sr_img)
    return [lr_img, hr_img, sr_img]

def resize_worker(img_file, sizes, resample, lmdb_save=False):
    img = Image.open(img_file)
    img = img.convert('RGB')
    out = resize_multiple(
        img, sizes=sizes, resample=resample, lmdb_save=lmdb_save)
    return img_file.name.split('.')[0], out

class WorkingContext():
    def __init__(self, resize_fn, lmdb_save, out_path, env, sizes):
        self.resize_fn = resize_fn
        self.lmdb_save = lmdb_save
        self.out_path = out_path
        self.env = env
        self.sizes = sizes
        self.counter = RawValue('i', 0)
        self.counter_lock = Lock()

    def inc_get(self):
        with self.counter_lock:
            self.counter.value += 1
            return self.counter.value

    def value(self):
        with self.counter_lock:
            return self.counter.value

def prepare_process_worker(wctx, file_subset):
    for file in file_subset:
        i, imgs = wctx.resize_fn(file)
        lr_img, hr_img, sr_img = imgs
        if not wctx.lmdb_save:
            lr_img.save(
                '{}/lr_{}/{}.png'.format(wctx.out_path, wctx.sizes[0], i.zfill(5)))
            hr_img.save(
                '{}/hr_{}/{}.png'.format(wctx.out_path, wctx.sizes[1], i.zfill(5)))
            sr_img.save(
                '{}/sr_{}_{}/{}.png'.format(wctx.out_path, wctx.sizes[0], wctx.sizes[1], i.zfill(5)))
        else:
            with wctx.env.begin(write=True) as txn:
                txn.put('lr_{}_{}'.format(
                    wctx.sizes[0], i.zfill(5)).encode('utf-8'), lr_img)
                txn.put('hr_{}_{}'.format(
                    wctx.sizes[1], i.zfill(5)).encode('utf-8'), hr_img)
                txn.put('sr_{}_{}_{}'.format(
                    wctx.sizes[0], wctx.sizes[1], i.zfill(5)).encode('utf-8'), sr_img)
        curr_total = wctx.inc_get()
        if wctx.lmdb_save:
            with wctx.env.begin(write=True) as txn:
                txn.put('length'.encode('utf-8'), str(curr_total).encode('utf-8'))

def all_threads_inactive(worker_threads):
    for thread in worker_threads:
        if thread.is_alive():
            return False
    return True

def prepare(img_path, out_path, n_worker=3, sizes=(16, 128), resample=Image.BICUBIC, lmdb_save=False):
    resize_fn = partial(resize_worker, sizes=sizes,
                        resample=resample, lmdb_save=lmdb_save)
    files = [p for p in Path('{}'.format(img_path)).glob('**/*')]

    if not lmdb_save:
        os.makedirs(out_path, exist_ok=True)
        os.makedirs('{}/lr_{}'.format(out_path, sizes[0]), exist_ok=True)
        os.makedirs('{}/hr_{}'.format(out_path, sizes[1]), exist_ok=True)
        os.makedirs('{}/sr_{}_{}'.format(out_path,
                                         sizes[0], sizes[1]), exist_ok=True)
    else:
        env = lmdb.open(out_path, map_size=1024 ** 4, readahead=False)

    if n_worker > 1:
        # prepare data subsets
        multi_env = None
        if lmdb_save:
            multi_env = env
        file_subsets = np.array_split(files, n_worker)
        worker_threads = []
        wctx = WorkingContext(resize_fn, lmdb_save, out_path, multi_env, sizes)
        # start worker processes, monitor results
        for i in range(n_worker):
            proc = Process(target=prepare_process_worker, args=(wctx, file_subsets[i]))
            proc.start()
            worker_threads.append(proc)
        total_count = str(len(files))
        while not all_threads_inactive(worker_threads):
            print("\r{}/{} images processed".format(wctx.value(), total_count), end=" ")
            time.sleep(0.1)
    else:
        total = 0
        for file in tqdm(files):
            i, imgs = resize_fn(file)
            lr_img, hr_img, sr_img = imgs
            if not lmdb_save:
                lr_img.save(
                    '{}/lr_{}/{}.png'.format(out_path, sizes[0], i.zfill(5)))
                hr_img.save(
                    '{}/hr_{}/{}.png'.format(out_path, sizes[1], i.zfill(5)))
                sr_img.save(
                    '{}/sr_{}_{}/{}.png'.format(out_path, sizes[0], sizes[1], i.zfill(5)))
            else:
                with env.begin(write=True) as txn:
                    txn.put('lr_{}_{}'.format(
                        sizes[0], i.zfill(5)).encode('utf-8'), lr_img)
                    txn.put('hr_{}_{}'.format(
                        sizes[1], i.zfill(5)).encode('utf-8'), hr_img)
                    txn.put('sr_{}_{}_{}'.format(
                        sizes[0], sizes[1], i.zfill(5)).encode('utf-8'), sr_img)
            total += 1
            if lmdb_save:
                with env.begin(write=True) as txn:
                    txn.put('length'.encode('utf-8'), str(total).encode('utf-8'))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', '-p', type=str,
                        default='{}/Dataset/celebahq_256'.format(Path.home()))
    parser.add_argument('--out', '-o', type=str,
                        default='./dataset/celebahq')
    parser.add_argument('--size', type=str, default='64,512')
    parser.add_argument('--n_worker', type=int, default=3)
    parser.add_argument('--resample', type=str, default='bicubic')
    # saves PNG files by default; pass --lmdb to write an LMDB database instead
    parser.add_argument('--lmdb', '-l', action='store_true')

    args = parser.parse_args()

    resample_map = {'bilinear': Image.BILINEAR, 'bicubic': Image.BICUBIC}
    resample = resample_map[args.resample]
    sizes = [int(s.strip()) for s in args.size.split(',')]

    args.out = '{}_{}_{}'.format(args.out, sizes[0], sizes[1])

    # 1. input and output directories
    prepare(args.path, args.out, args.n_worker,
            sizes=sizes, resample=resample, lmdb_save=args.lmdb)
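For context, prepare.py builds matched low-resolution / high-resolution / bicubic-upsampled (lr/hr/sr) image sets for super-resolution training, written either as PNG folders or as a single LMDB database. Below is a minimal sketch of calling it programmatically; it assumes the wheel is installed and importable as shancx.Train.prepare, and the paths ./faces_256 and ./sr_data_16_128 are placeholders, not names from the package.

# Sketch only: ./faces_256 and ./sr_data_16_128 are placeholder paths.
from PIL import Image
from shancx.Train.prepare import prepare

# Writes ./sr_data_16_128/lr_16, hr_128 and sr_16_128 PNG folders.
# n_worker=1 takes the single-process tqdm path; n_worker>1 spawns
# multiprocessing workers that share a RawValue progress counter.
prepare(
    img_path='./faces_256',
    out_path='./sr_data_16_128',
    n_worker=1,
    sizes=(16, 128),
    resample=Image.BICUBIC,
    lmdb_save=False,   # True would write an LMDB database at out_path instead
)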
shancx/Train/renet50.py
ADDED
@@ -0,0 +1,157 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
import torch
import torchvision
import torchvision.transforms as transforms
import cv2
import numpy as np
from typing import List, Tuple
from PIL import Image
import os

# class ModelLoader:
#     """Handles loading and managing the pretrained model"""
#
#     def __init__(self, model_name: str = 'resnet50'):
#         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#         self.model = self._load_pretrained_model(model_name)
#
#     def _load_pretrained_model(self, model_name: str) -> torch.nn.Module:
#         """Load pretrained model from torchvision.models"""
#         model = getattr(torchvision.models, model_name)(pretrained=True)
#         return model.eval().to(self.device)

class ModelLoader:
    """Handles loading and managing the pretrained model"""

    def __init__(self, model_name: str = 'resnet50'):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_dir = "models"
        os.makedirs(self.model_dir, exist_ok=True)
        self.model = self._load_pretrained_model(model_name)

    def _load_pretrained_model(self, model_name: str) -> torch.nn.Module:
        """Load pretrained model from torchvision.models. If not found locally, download and save it."""
        model_path = os.path.join(self.model_dir, f"{model_name}.pth")

        try:
            model = getattr(torchvision.models, model_name)(pretrained=False)
            model.load_state_dict(torch.load(model_path))
            print(f"Loaded {model_name} from local directory: {model_path}")
        except FileNotFoundError:
            print(f"Model {model_name} not found locally. Downloading...")
            model = getattr(torchvision.models, model_name)(pretrained=True)
            torch.save(model.state_dict(), model_path)
            print(f"Downloaded and saved {model_name} to: {model_path}")
        except Exception as e:
            raise RuntimeError(f"Failed to load or download the model: {e}")

        return model.eval().to(self.device)


class ImageProcessor:
    """Handles image preprocessing and transformations"""

    def __init__(self):
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def preprocess_image(self, image_path: str) -> torch.Tensor:
        """Load and preprocess image for model input"""
        try:
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image file not found at {image_path}")

            image = cv2.imread(image_path)
            if image is None:
                raise ValueError(f"Unable to read image at {image_path}")

            # Convert BGR to RGB
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # Convert to PIL Image for torchvision transforms
            image = Image.fromarray(image)
            return self.transform(image).unsqueeze(0)
        except Exception as e:
            print(f"Error processing image: {e}")
            raise

class Predictor:
    """Handles model predictions and visualization"""

    def __init__(self, model_loader: ModelLoader, labels_path: str = 'imagenet_classes.txt'):
        self.model = model_loader.model
        self.device = model_loader.device
        self.labels = self._load_labels(labels_path)

    def _load_labels(self, labels_path: str) -> List[str]:
        """Load class labels from file"""
        try:
            with open(labels_path) as f:
                labels = [line.strip() for line in f.readlines()]
            if len(labels) != 1000:
                raise ValueError(f"Expected 1000 ImageNet classes, got {len(labels)}")
            return labels
        except Exception as e:
            print(f"Error loading labels: {e}")
            raise

    def predict(self, input_tensor: torch.Tensor) -> Tuple[int, str]:
        """Run model prediction on input tensor"""
        with torch.no_grad():
            input_tensor = input_tensor.to(self.device)
            pred = self.model(input_tensor)
            pred_index = torch.argmax(pred, 1).cpu().detach().numpy()[0]

        if pred_index >= len(self.labels):
            raise ValueError(f"Prediction index {pred_index} out of range for {len(self.labels)} classes")

        return pred_index, self.labels[pred_index]

    def visualize_prediction(self, image_path: str, class_name: str):
        """Display image with predicted class label"""
        try:
            image = cv2.imread(image_path)
            if image is None:
                raise FileNotFoundError(f"Image not found at {image_path}")

            cv2.putText(image, class_name, (50, 50),
                        cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
            # cv2.imshow("Prediction", image)
            # cv2.waitKey(0)
            # cv2.destroyAllWindows()
            output_path = f"output_image_{class_name}.jpg"  # replace with the path where you want to save the result
            cv2.imwrite(output_path, image)
            print(f"Image saved to {output_path}")
        except Exception as e:
            print(f"Error visualizing prediction: {e}")
            raise

def main():
    try:
        # Initialize components
        model_loader = ModelLoader()
        image_processor = ImageProcessor()
        predictor = Predictor(model_loader)

        # Process image and make prediction
        image_path = "./space_shuttle.jpg"
        input_tensor = image_processor.preprocess_image(image_path)
        pred_index, class_name = predictor.predict(input_tensor)

        # Display results
        print(f"Predicted class index: {pred_index}")
        print(f"Predicted class name: {class_name}")
        predictor.visualize_prediction(image_path, class_name)

    except Exception as e:
        print(f"Error in main execution: {e}")

# if __name__ == "__main__":
#     main()
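As shipped, the __main__ guard at the bottom of renet50.py is commented out, so the module is import-only. A minimal usage sketch follows; it assumes the module is importable as shancx.Train.renet50 and that an imagenet_classes.txt file with the 1000 ImageNet class names (one per line) exists in the working directory, and the image path is a placeholder.

# Sketch only: ./space_shuttle.jpg and imagenet_classes.txt are placeholders
# that must exist locally; ResNet-50 weights are cached under ./models/.
from shancx.Train.renet50 import ModelLoader, ImageProcessor, Predictor

loader = ModelLoader('resnet50')          # downloads weights on first run
processor = ImageProcessor()
predictor = Predictor(loader, labels_path='imagenet_classes.txt')

tensor = processor.preprocess_image('./space_shuttle.jpg')
idx, name = predictor.predict(tensor)
print(idx, name)
predictor.visualize_prediction('./space_shuttle.jpg', name)  # writes output_image_<name>.jpg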
shancx/ZR.py
ADDED
@@ -0,0 +1,12 @@
def dbz2rfl(d):  # module-level helper
    return 10. ** (d / 10.)

def rfl2mmh(z, a=200., b=1.6):  # module-level helper
    return (z / a) ** (1. / b)

def ZR1(conf):
    dbz = conf[0]  # radar reflectivity in dBZ (expected to be a NumPy array)
    z_values = dbz2rfl(dbz)  # call the module-level helpers directly
    rainfall_mmh = rfl2mmh(z_values)
    rainfall_mmh[rainfall_mmh < 0.1] = 0
    return rainfall_mmh
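For reference, these helpers implement the standard radar Z-R conversion: dbz2rfl turns reflectivity in dBZ into the linear reflectivity factor Z = 10^(dBZ/10), and rfl2mmh inverts the Marshall-Palmer relation Z = a*R^b (defaults a = 200, b = 1.6) to give the rain rate R in mm/h. A short usage sketch, assuming the module is importable as shancx.ZR and that the input is a NumPy array, which the masked assignment in ZR1 requires:

import numpy as np
from shancx.ZR import ZR1, dbz2rfl, rfl2mmh

# Scalar check: 35 dBZ -> Z = 10**3.5 ~ 3162 -> R = (3162/200)**(1/1.6) ~ 5.6 mm/h
print(rfl2mmh(dbz2rfl(35.0)))

# Array use via ZR1: rain rates below 0.1 mm/h are zeroed in place
dbz = np.array([10.0, 30.0, 50.0])
print(ZR1([dbz]))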