jacksung-dev 0.0.4.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jacksung/__init__.py +1 -0
- jacksung/ai/GeoAttX.py +356 -0
- jacksung/ai/GeoNet/__init__.py +0 -0
- jacksung/ai/GeoNet/m_block.py +393 -0
- jacksung/ai/GeoNet/m_blockV2.py +442 -0
- jacksung/ai/GeoNet/m_network.py +107 -0
- jacksung/ai/GeoNet/m_networkV2.py +91 -0
- jacksung/ai/__init__.py +0 -0
- jacksung/ai/latex_tool.py +199 -0
- jacksung/ai/metrics.py +181 -0
- jacksung/ai/utils/__init__.py +0 -0
- jacksung/ai/utils/cmorph.py +42 -0
- jacksung/ai/utils/data_parallelV2.py +90 -0
- jacksung/ai/utils/fy.py +333 -0
- jacksung/ai/utils/goes.py +161 -0
- jacksung/ai/utils/gsmap.py +24 -0
- jacksung/ai/utils/imerg.py +159 -0
- jacksung/ai/utils/metsat.py +164 -0
- jacksung/ai/utils/norm_util.py +109 -0
- jacksung/ai/utils/util.py +300 -0
- jacksung/libs/times.ttf +0 -0
- jacksung/utils/__init__.py +1 -0
- jacksung/utils/base_db.py +72 -0
- jacksung/utils/cache.py +71 -0
- jacksung/utils/data_convert.py +273 -0
- jacksung/utils/exception.py +27 -0
- jacksung/utils/fastnumpy.py +115 -0
- jacksung/utils/figure.py +251 -0
- jacksung/utils/hash.py +26 -0
- jacksung/utils/image.py +221 -0
- jacksung/utils/log.py +86 -0
- jacksung/utils/login.py +149 -0
- jacksung/utils/mean_std.py +66 -0
- jacksung/utils/multi_task.py +129 -0
- jacksung/utils/number.py +6 -0
- jacksung/utils/nvidia.py +140 -0
- jacksung/utils/time.py +87 -0
- jacksung/utils/web.py +63 -0
- jacksung_dev-0.0.4.15.dist-info/LICENSE +201 -0
- jacksung_dev-0.0.4.15.dist-info/METADATA +228 -0
- jacksung_dev-0.0.4.15.dist-info/RECORD +44 -0
- jacksung_dev-0.0.4.15.dist-info/WHEEL +5 -0
- jacksung_dev-0.0.4.15.dist-info/entry_points.txt +3 -0
- jacksung_dev-0.0.4.15.dist-info/top_level.txt +1 -0
jacksung/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from jacksung import *
|
jacksung/ai/GeoAttX.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
1
|
+
import os.path
|
|
2
|
+
import random
|
|
3
|
+
from datetime import timedelta
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from jacksung.utils.time import RemainTime, Stopwatch, cur_timestamp_str
|
|
7
|
+
from jacksung.ai.utils.norm_util import PredNormalization, PrecNormalization, Normalization
|
|
8
|
+
import numpy as np
|
|
9
|
+
from jacksung.utils.data_convert import np2tif
|
|
10
|
+
from jacksung.utils.cache import Cache
|
|
11
|
+
from jacksung.ai.utils import fy, goes, metsat
|
|
12
|
+
from einops import rearrange
|
|
13
|
+
from jacksung.ai.utils.util import parse_config, data_to_device, clipSatelliteNP
|
|
14
|
+
from jacksung.ai.GeoNet.m_network import GeoNet
|
|
15
|
+
from jacksung.ai.GeoNet.m_networkV2 import GeoNet as GeoNetV2
|
|
16
|
+
from jacksung.utils.exception import NoFileException, NanNPException
|
|
17
|
+
import torch.nn as nn
|
|
18
|
+
|
|
19
|
+
# Sensor identifiers, compared against ``args.sensor_type`` from the config.
# AGRI: FY-4 imager, ABI: GOES imager, SEVIRI: Meteosat imager; AHI and FCI are
# declared here but not handled by any class in this module.
AGRI, ABI, SEVIRI, AHI, FCI = 'agri', 'abi', 'seviri', 'ahi', 'fci'
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class GeoAttX:
    """Shared base for the GeoAttX predictors.

    Parses the YAML config (via ``parse_config``), derives a unique output
    directory and provides ``load_model`` for GeoNet / GeoNetV2 checkpoints.
    """

    def __init__(self, config=None, root_path=None, task_type=None, area=((100, 140, 10), (20, 60, 10)), lock=None):
        """Initialise config, device and output directory.

        Args:
            config: config file passed to ``parse_config``.
            root_path: parent directory for outputs; ``args.save_path`` when falsy.
            task_type: task tag used in the directory name and by GeoNet v1;
                falls back to ``args.task`` when falsy.
            area: clip area forwarded to the satellite helpers (format defined by
                ``clipSatelliteNP`` / ``getFY_coord_clip``).
            lock: optional object stored on the instance; not used in this class.
        """
        self.root_path = None
        self.timestamp = None
        self.dir_name = None
        self.device, self.args = parse_config(config)
        self.task_type = task_type
        self.set_root_path(root_path)
        self.area = area
        self.lock = lock

    def get_root_path(self):
        """Return the full output directory (``<root>/<dir_name>``)."""
        return self.root_path

    def get_dir_name(self):
        """Return the name of the output sub-directory."""
        return self.dir_name

    def _task_name(self):
        """Return the explicit task type, or ``args.task`` when none was given."""
        return self.task_type if self.task_type else self.args.task

    def set_root_path(self, root_path=None, dir_name=None):
        """Recompute ``self.root_path`` (and refresh ``self.timestamp``).

        Args:
            root_path: parent directory; defaults to ``args.save_path``.
            dir_name: sub-directory name; when falsy a unique
                ``<task>-<model>-<timestamp>_<4 random digits>`` name is generated.
        """
        self.timestamp = cur_timestamp_str()
        root_path = root_path if root_path else self.args.save_path
        if not dir_name:
            # Previously ``self.task_type + ...`` raised TypeError when task_type was None.
            dir_name = f'{self._task_name()}-{self.args.model}-{self.timestamp}_{random.randint(1000, 9999)}'
        self.dir_name = dir_name
        self.root_path = os.path.join(root_path, dir_name)

    def load_model(self, path, version=1, c_in=None):
        """Build a GeoNet (version 1) or GeoNetV2 model and load a checkpoint.

        Args:
            path: checkpoint file; must contain a ``'model_state_dict'`` entry.
            version: 1 selects ``GeoNet``, anything else selects ``GeoNetV2``.
            c_in: input channel count. For version 1 it overrides ``args.c_in``
                when given; for version 2 it is passed through as-is.

        Returns:
            The model on ``self.device``, in eval mode, with gradients disabled.
        """
        if version == 1:
            model = GeoNet(window_sizes=self.args.window_sizes, n_lgab=self.args.n_lgab,
                           c_in=self.args.c_in if c_in is None else c_in,
                           c_lgan=self.args.c_lgan, r_expand=self.args.r_expand, down_sample=self.args.down_sample,
                           num_heads=self.args.num_heads, task=self._task_name(),
                           downstage=self.args.downstage)
        else:
            model = GeoNetV2(window_sizes=self.args.window_sizes, n_lgab=self.args.n_lgab, c_in=c_in,
                             c_lgan=self.args.c_lgan, r_expand=self.args.r_expand, down_sample=self.args.down_sample,
                             num_heads=self.args.num_heads, downstage=self.args.downstage)

        ckpt = torch.load(path, map_location=torch.device(self.device))
        model.load(ckpt['model_state_dict'])
        model = model.to(self.device)
        model = model.eval()
        model.requires_grad_(False)
        return model
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class GeoAttX_I(GeoAttX):
    """FY-based image prediction ("pred" task) using the ``x1``/``x4``/``x12`` models.

    Judging by ``predict``, model ``x<k>`` advances the scene by ``k`` steps of
    15 minutes, taking the frame ``k`` steps earlier and the current frame.
    """

    def __init__(self, data_path, x1_path, x4_path, x12_path, root_path=None, config='config_predict.yml',
                 area=((100, 140, 10), (20, 60, 10)), cache_size=1, lock=None, device=None):
        """Load the three step models and the normalisation statistics.

        Args:
            data_path: root of the ``downloaded_file/<Y>/<M>/<D>/`` input tree.
            x1_path, x4_path, x12_path: checkpoints for the 1-, 4- and 12-step models.
            root_path, config, area, lock: see ``GeoAttX.__init__``.
            cache_size: size passed to ``Cache``.
            device: overrides the device parsed from the config when not None.
        """
        super().__init__(config=config, root_path=root_path, task_type='pred', area=area, lock=lock)
        if device is not None:
            # Must be set before the models and statistics are moved to the device.
            self.device = device
        self.f, self.n, self.ys = None, None, None
        self.data_path = data_path
        self.x1 = self.load_model(x1_path)
        self.x4 = self.load_model(x4_path)
        self.x12 = self.load_model(x12_path)
        self.norm = PredNormalization(self.args.pred_data_path)
        self.norm.mean, self.norm.std = data_to_device([self.norm.mean, self.norm.std], self.device, self.args.fp)
        self.ld = None
        self.cache = Cache(cache_size)

    def save(self, file_name, ys):
        """Write each prediction and, when the file exists, the observed target as GeoTIFF.

        Args:
            file_name: FY input file name (parsed with ``fy.prase_filename``).
            ys: mapping ``datetime -> array`` as returned by ``predict``.

        Returns:
            ``self.root_path``.
        """
        file_info = fy.prase_filename(file_name)
        # Use the position of *this* file instead of relying on a previous predict() call.
        ld = int(file_info["position"])
        # getFY_coord_clip depends only on self.area, so compute it once.
        coord = fy.getFY_coord_clip(self.area)
        for idx, (k, y) in enumerate(ys.items()):
            np2tif(y, save_path=self.root_path, out_name=f'{k.strftime("%Y%m%d_%H%M%S")}', coord=coord,
                   dtype=np.float32, print_log=False, dim_value=[{'value': [str(x) for x in range(9, 16)]}])
            td = k - file_info['start']
            mins = td.days * 24 * 60 + td.seconds // 60
            p_path = self.get_path_by_filename(self.get_filename(file_name, mins))
            # idx 0 is the observed input frame itself, so only later frames get a target.
            if idx >= 1 and os.path.exists(p_path):
                p_data = fy.getNPfromHDF(p_path)
                p_data = clipSatelliteNP(p_data, ld=ld, area=self.area)
                # clipSatelliteNP reports failure with a str (see get_exist_by_filename_and_mins).
                if p_data is not None and not isinstance(p_data, str):
                    p_data = p_data[2:, :, :]
                    np2tif(p_data, save_path=self.root_path, out_name=f'target_{k.strftime("%Y%m%d_%H%M%S")}',
                           print_log=False, coord=coord, dtype=np.float32,
                           dim_value=[{'value': [str(x) for x in range(9, 16)]}])
        print(f'data saved in {self.root_path}')
        # With an empty ``ys`` nothing above creates the directory.
        os.makedirs(self.root_path, exist_ok=True)
        with open(os.path.join(self.root_path, 'info.log'), 'w', encoding='utf-8') as f:
            f.write(f'输入数据:{file_info["start"]} {file_info["position"]} {file_info["end"]}\n')
            for k in ys:
                f.write(f'预测:{k}\n')
        return self.root_path

    def get_filename(self, file_name, mins):
        """Return ``file_name`` with its start/end timestamps shifted by ``mins`` minutes."""
        file_info = fy.prase_filename(file_name)
        new_file_name = file_name.replace(
            f'{file_info["start"].strftime("%Y%m%d%H%M%S")}_{file_info["end"].strftime("%Y%m%d%H%M%S")}',
            f'{(file_info["start"] + timedelta(minutes=mins)).strftime("%Y%m%d%H%M%S")}_'
            f'{(file_info["end"] + timedelta(minutes=mins)).strftime("%Y%m%d%H%M%S")}')
        return new_file_name

    def get_path_by_filename(self, file_name):
        """Return ``<data_path>/downloaded_file/<year>/<month>/<day>/<file_name>`` (no zero padding)."""
        start = fy.prase_filename(file_name)["start"]
        return os.path.join(self.data_path, 'downloaded_file', str(start.year), str(start.month),
                            str(start.day), file_name)

    def numpy2tensor(self, f_data):
        """Move a ``(c, h, w)`` array to the device and add a leading batch axis."""
        f_data = torch.from_numpy(f_data)
        f = data_to_device([f_data], self.device, self.args.fp)[0]
        f = rearrange(f, '(b c) h w -> b c h w', b=1)
        return f

    def get_exist_by_filename_and_mins(self, file_name, mins):
        """Load the FY frame ``mins`` minutes away from ``file_name`` as a batched tensor.

        The first two channels are dropped after clipping to ``self.area``.

        Raises:
            NoFileException: the shifted file does not exist.
            NanNPException: ``clipSatelliteNP`` returned a str instead of data.
        """
        f_path = self.get_path_by_filename(self.get_filename(file_name, mins))
        if not os.path.exists(f_path):
            raise NoFileException(f_path)
        f_data = fy.getNPfromHDF(f_path)
        f_data = clipSatelliteNP(f_data, ld=self.ld, area=self.area)
        if isinstance(f_data, str):
            raise NanNPException(f_path)
        f_data = f_data[2:, :, :]
        return self.numpy2tensor(f_data)

    def mean_std2Tensor(self, in_data, h, w):
        """Broadcast per-channel statistics of shape ``(b, c)`` to ``(b, c, h, w)``."""
        in_data = rearrange(in_data, '(b h w) c->b c h w', h=1, w=1)
        # -1 keeps the batch/channel sizes instead of hard-coding (1, 7).
        return in_data.expand(-1, -1, h, w)

    def secdOrderStd(self, t_data, mean0, std0, mean, std):
        """Re-standardise ``t_data`` from (mean0, std0) to (mean, std), per channel."""
        h, w = t_data.shape[2], t_data.shape[3]
        t_data = (t_data - self.mean_std2Tensor(mean0, h, w)) / self.mean_std2Tensor(std0, h, w)
        t_data = t_data * self.mean_std2Tensor(std, h, w) + self.mean_std2Tensor(mean, h, w)
        return t_data

    def _plan_steps(self, n_steps, p_steps):
        """Greedily split ``n_steps`` into available model step sizes, smallest first.

        Only positive sizes whose ``x<size>`` model is loaded are used. Raises
        ValueError when the remaining steps cannot be covered; previously this
        looped forever.
        """
        sizes = sorted({p for p in p_steps if p > 0 and getattr(self, f'x{p}', None) is not None}, reverse=True)
        plan = []
        while n_steps > 0:
            for size in sizes:
                if n_steps >= size:
                    plan.append(size)
                    n_steps -= size
                    break
            else:
                raise ValueError(f'cannot cover {n_steps} remaining 15-minute steps with model steps {sizes}')
        plan.reverse()
        return plan

    def predict(self, file_name, step=360, p_steps=(48, 12, 4, 1), print_log=True):
        """Roll the models forward from the observation in ``file_name``.

        Args:
            file_name: FY input file name; its start time is the forecast origin.
            step: horizon in minutes (truncated to a multiple of 15).
            p_steps: candidate model step sizes in 15-minute units; sizes without
                a loaded ``x<size>`` model are skipped.
            print_log: print progress messages.

        Returns:
            ``dict`` mapping ``datetime`` to a ``(c, h, w)`` array (observation at
            the origin, denormalised predictions afterwards), or ``{}`` when an
            input file is missing or invalid (the error is appended to ``err.log``).

        Raises:
            ValueError: ``p_steps`` cannot decompose the requested horizon.
        """
        try:
            file_info = fy.prase_filename(file_name)
            start = file_info["start"]
            self.ld = int(file_info["position"])
            n_steps = step // 15
            if print_log:
                print(f'当前时刻:{file_info["start"]}\n预测长度:{n_steps * 15}分钟')
            task_progress = self._plan_steps(n_steps, p_steps)
            if print_log:
                print(f'正在预测:{file_info["start"] + timedelta(minutes=sum(task_progress) * 15)}...')
            n = self.get_exist_by_filename_and_mins(file_name, 0)
            now_date = start
            process_list = {now_date: n.detach().cpu().numpy()[0]}
            n = self.norm.norm(n)
            # Statistics of the observed frame, used to re-standardise every prediction.
            o_mean = torch.mean(n, dim=(2, 3))
            o_std = torch.std(n, dim=(2, 3))
            for size in task_progress:
                pre_date = now_date - timedelta(minutes=15 * size)
                now_date += timedelta(minutes=15 * size)
                if print_log:
                    print(f'正在预测 {now_date}...')
                if pre_date in process_list:
                    f = self.numpy2tensor(process_list[pre_date])
                else:
                    # total_seconds() is signed; timedelta.seconds is wrong for negative deltas.
                    offset = int((pre_date - start).total_seconds() // 60)
                    f = self.get_exist_by_filename_and_mins(file_name, offset)
                f = self.norm.norm(f)
                st = Stopwatch()
                y_ = getattr(self, f'x{size}')(f, n)
                if print_log:
                    print(f'预测耗时: {st.reset()} 秒')
                # Second-order standardisation: match the prediction's per-channel
                # mean/std to those of the observed frame.
                y_mean = torch.mean(y_, dim=(2, 3))
                y_std = torch.std(y_, dim=(2, 3))
                y_ = self.secdOrderStd(y_, y_mean, y_std, o_mean, o_std)
                del f
                n = y_
                process_list[now_date] = self.norm.denorm(y_).detach().cpu().numpy()[0]
        except (NoFileException, NanNPException) as e:
            os.makedirs(self.root_path, exist_ok=True)
            with open(os.path.join(self.root_path, 'err.log'), 'a', encoding='utf-8') as f:
                filename = e.file_name.split(os.sep)[-1]
                file_info = fy.prase_filename(filename)
                f.write(f'{e.__class__}: {file_info["start"]} {e.file_name}\n')
            return {}
        return process_list
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
class GeoAttX_P(GeoAttX):
    """Precipitation retrieval ("prec" task) from one FY scene with a single GeoNet model."""

    def __init__(self, model_path, root_path=None, config='predict_qpe.yml', area=((100, 140, 10), (20, 60, 10)),
                 device=None):
        """Load the retrieval model.

        Args:
            model_path: GeoNet (version 1) checkpoint.
            root_path, config, area: see ``GeoAttX.__init__``.
            device: overrides the device parsed from the config when not None.
        """
        super().__init__(config=config, root_path=root_path, task_type='prec', area=area)
        if device is not None:
            self.device = device
        self.model = self.load_model(model_path)

    def save(self, y, save_name, info_log=True, print_log=True):
        """Write ``y`` as a single-band ('qpe') GeoTIFF named ``save_name``.

        Optionally prints the output directory and (over)writes ``info.log``.
        Returns ``self.root_path``.
        """
        np2tif(y, save_path=self.root_path, out_name=save_name, coord=fy.getFY_coord_clip(self.area), dtype=np.float32,
               print_log=False, dim_value=[{'value': ['qpe']}])
        if print_log:
            print(f'data saved in {self.root_path}')
        if info_log:
            # Do not rely on np2tif having created the directory.
            os.makedirs(self.root_path, exist_ok=True)
            with open(os.path.join(self.root_path, 'info.log'), 'w', encoding='utf-8') as f:
                f.write(f'QPE 反演:{save_name}\n')
        return self.root_path

    def predict(self, fy_npy):
        """Run the retrieval on an FY array.

        Args:
            fy_npy: path to a ``.npy`` file or a numpy array holding the
                normalisation-ready FY channels (reshaped to ``(1, 1, c, h, w)``).

        Returns:
            The denormalised ``qpe`` prediction for the single batch item as a
            numpy array, or None when the file is missing (logged to ``err.log``).

        Raises:
            TypeError: ``fy_npy`` is neither a str path nor a numpy array.
        """
        try:
            if isinstance(fy_npy, str):
                print(f'正在反演:{fy_npy}...')
                if not os.path.exists(fy_npy):
                    raise NoFileException(fy_npy)
                n_data = np.load(fy_npy)
            elif isinstance(fy_npy, np.ndarray):
                n_data = fy_npy
            else:
                # TypeError is still an Exception, so existing handlers keep working.
                raise TypeError('输入数据类型错误,仅支持文件路径或numpy数组')
            n_data = torch.from_numpy(n_data)
            norm = PrecNormalization(self.args.prec_data_path)
            norm.mean_fy, norm.mean_qpe, norm.std_fy, norm.std_qpe = data_to_device(
                [norm.mean_fy, norm.mean_qpe, norm.std_fy, norm.std_qpe], self.device, self.args.fp)
            n_data = data_to_device([n_data], self.device, self.args.fp)[0]
            n_data = rearrange(n_data, '(b t c) h w -> b t c h w', b=1, t=1)
            n = norm.norm(n_data, norm_type='fy')[:, 0, :, :, :]
            y_ = self.model(n, n)
            return norm.denorm(y_, norm_type='qpe').detach().cpu().numpy()[0]
        except NoFileException as e:
            os.makedirs(self.root_path, exist_ok=True)
            with open(os.path.join(self.root_path, 'err.log'), 'a', encoding='utf-8') as f:
                filename = e.file_name.split(os.sep)[-1]
                file_info = fy.prase_filename(filename)
                f.write(f'Not exist {file_info["start"]} {e.file_name}\n')
            return None
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
class Huayu(GeoAttX):
    """IMERG-style precipitation retrieval ("prem" task) from AGRI, ABI or SEVIRI data with GeoNetV2."""

    # sensor -> (satellite_num, satellite_channel). The channel count configures the model;
    # the product of the two gives the offset of the 3 IMERG statistics in mean_std.npy.
    _SENSOR_LAYOUT = {AGRI: (2, 7), ABI: (3, 9), SEVIRI: (2, 7)}

    def __init__(self, model_path, root_path=None, config='predict_imerg.yml', area=((100, 140, 10), (20, 60, 10)),
                 device=None):
        """Select the sensor layout from ``args.sensor_type`` and load the GeoNetV2 model.

        Raises:
            ValueError: ``args.sensor_type`` is not AGRI, ABI or SEVIRI.
        """
        super().__init__(config=config, root_path=root_path, task_type='prem', area=area)
        if device is not None:
            self.device = device
        sensor_type = self.args.sensor_type
        if sensor_type not in self._SENSOR_LAYOUT:
            raise ValueError(f'Unknown sensor type: {sensor_type}')
        self.sensor = sensor_type
        self.satellite_num, self.satellite_channel = self._SENSOR_LAYOUT[sensor_type]
        self.model = self.load_model(model_path, version=2, c_in=self.satellite_channel)

    def save(self, y, save_name, info_log=True, print_log=True):
        """Write ``y`` as a single-band ('imerg') GeoTIFF named ``save_name``.

        Optionally prints the output directory and (over)writes ``info.log``.
        Returns ``self.root_path``.
        """
        np2tif(y, save_path=self.root_path, out_name=save_name, coord=fy.getFY_coord_clip(self.area), dtype=np.float32,
               print_log=False, dim_value=[{'value': ['imerg']}])
        if print_log:
            print(f'data saved in {self.root_path}')
        if info_log:
            os.makedirs(self.root_path, exist_ok=True)
            with open(os.path.join(self.root_path, 'info.log'), 'w', encoding='utf-8') as f:
                f.write(f'Imerg 反演:{save_name}\n')
        return self.root_path

    def predict(self, satellite_file, npy_path=None, smooth=True, up=True, area=None, satellite_date=None):
        """Retrieve precipitation for one satellite scene.

        Args:
            satellite_file: sensor input (HDF for AGRI, directory for ABI, NAT for
                SEVIRI); ignored when ``npy_path`` is given.
            npy_path: optional pre-clipped array to load instead of ``satellite_file``.
            smooth: replace interior pixels with a 3x3 average.
            up: predict each of the 4 pixel-unshuffled sub-images and recombine
                them at full resolution; otherwise average them and return the
                half-resolution prediction.
            area: clip area overriding ``self.area``.
            satellite_date: passed to the ABI reader only.

        Returns:
            numpy array of shape ``(1, H, W)`` from output channel 0, set to 0
            where channel 1 > channel 2 or where it is negative. Presumably rain
            rate plus two rain/no-rain scores; TODO confirm. Returns None when a
            reader raises NoFileException (logged to ``err.log``).
        """
        try:
            if npy_path is None:
                if self.sensor == AGRI:
                    n_data, coord = fy.getNPfromHDF(satellite_file, return_coord=True)
                elif self.sensor == ABI:
                    n_data, coord = goes.getNPfromDir(satellite_file, satellite_date, return_coord=True)
                elif self.sensor == SEVIRI:
                    n_data, coord = metsat.getNPfromNAT(satellite_file, return_coord=True)
                else:
                    raise ValueError(f'Unknown sensor type: {self.args.sensor_type}')
                n_data = clipSatelliteNP(n_data, coord.ld, self.area if area is None else area)
            else:
                n_data = np.load(npy_path)
            mean_std_npy = np.load(os.path.join(str(self.args.data_path), 'mean_std.npy'))
            # Satellite statistics come first; the 3 IMERG statistics follow all satellite slots.
            imerg_start = self.satellite_channel * self.satellite_num
            satellite_norm = Normalization(mean_std_npy, (0, self.satellite_channel))
            imerg_norm = Normalization(mean_std_npy, (imerg_start, imerg_start + 3))
            satellite_norm.mean, satellite_norm.std, imerg_norm.mean, imerg_norm.std = data_to_device(
                [satellite_norm.mean, satellite_norm.std, imerg_norm.mean, imerg_norm.std],
                self.device, self.args.fp)
            n_data = data_to_device([n_data], self.device, self.args.fp)[0]
            n_data = rearrange(n_data, '(b c) h w -> b c h w', b=1)
            n = satellite_norm.norm(n_data)
            ps = nn.PixelShuffle(2)
            ups = nn.PixelUnshuffle(2)
            smoother = nn.AvgPool2d(kernel_size=3, stride=1, padding=1) if smooth else None
            # Split the scene into 4 half-resolution sub-images stacked along the batch axis.
            n = ups(n)
            n = rearrange(n, 'b (c dsize) h w -> (b dsize) c h w', dsize=4)
            if not up:
                n = n.mean(dim=0, keepdim=True)
            y_ = self.model(n)
            y_ = rearrange(y_, '(b dsize) c h w -> b (c dsize) h w', b=1)
            if up:
                y_ = ps(y_)
            y = imerg_norm.denorm(y_)[0]
            y[0][y[1] > y[2]] = 0
            y[0][y[0] < 0] = 0
            y = rearrange(y[0], '(b h) w -> b h w', b=1)
            _, H, W = y.shape
            if smoother is not None:
                # Border pixels keep their raw values; only the interior is smoothed.
                y[0, 1:H - 1, 1:W - 1] = smoother(y)[0, 1:H - 1, 1:W - 1]
            return y.detach().cpu().numpy()
        except NoFileException as e:
            os.makedirs(self.root_path, exist_ok=True)
            with open(os.path.join(self.root_path, 'err.log'), 'a', encoding='utf-8') as f:
                filename = e.file_name.split(os.sep)[-1]
                file_info = fy.prase_filename(filename)
                f.write(f'Not exist {file_info["start"]} {e.file_name}\n')
            return None
|
|
File without changes
|