jacksung_dev-0.0.4.15-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jacksung/__init__.py +1 -0
- jacksung/ai/GeoAttX.py +356 -0
- jacksung/ai/GeoNet/__init__.py +0 -0
- jacksung/ai/GeoNet/m_block.py +393 -0
- jacksung/ai/GeoNet/m_blockV2.py +442 -0
- jacksung/ai/GeoNet/m_network.py +107 -0
- jacksung/ai/GeoNet/m_networkV2.py +91 -0
- jacksung/ai/__init__.py +0 -0
- jacksung/ai/latex_tool.py +199 -0
- jacksung/ai/metrics.py +181 -0
- jacksung/ai/utils/__init__.py +0 -0
- jacksung/ai/utils/cmorph.py +42 -0
- jacksung/ai/utils/data_parallelV2.py +90 -0
- jacksung/ai/utils/fy.py +333 -0
- jacksung/ai/utils/goes.py +161 -0
- jacksung/ai/utils/gsmap.py +24 -0
- jacksung/ai/utils/imerg.py +159 -0
- jacksung/ai/utils/metsat.py +164 -0
- jacksung/ai/utils/norm_util.py +109 -0
- jacksung/ai/utils/util.py +300 -0
- jacksung/libs/times.ttf +0 -0
- jacksung/utils/__init__.py +1 -0
- jacksung/utils/base_db.py +72 -0
- jacksung/utils/cache.py +71 -0
- jacksung/utils/data_convert.py +273 -0
- jacksung/utils/exception.py +27 -0
- jacksung/utils/fastnumpy.py +115 -0
- jacksung/utils/figure.py +251 -0
- jacksung/utils/hash.py +26 -0
- jacksung/utils/image.py +221 -0
- jacksung/utils/log.py +86 -0
- jacksung/utils/login.py +149 -0
- jacksung/utils/mean_std.py +66 -0
- jacksung/utils/multi_task.py +129 -0
- jacksung/utils/number.py +6 -0
- jacksung/utils/nvidia.py +140 -0
- jacksung/utils/time.py +87 -0
- jacksung/utils/web.py +63 -0
- jacksung_dev-0.0.4.15.dist-info/LICENSE +201 -0
- jacksung_dev-0.0.4.15.dist-info/METADATA +228 -0
- jacksung_dev-0.0.4.15.dist-info/RECORD +44 -0
- jacksung_dev-0.0.4.15.dist-info/WHEEL +5 -0
- jacksung_dev-0.0.4.15.dist-info/entry_points.txt +3 -0
- jacksung_dev-0.0.4.15.dist-info/top_level.txt +1 -0
jacksung/utils/cache.py
ADDED
@@ -0,0 +1,71 @@
+import threading
+import time
+
+
+class Cache:
+    def __init__(self, cache_len=120):
+        self.cache_L = threading.Lock()
+        self.cache = {}
+        self.cache_list = []
+        self.cache_len = cache_len
+
+    # Check whether key is in the cache: if so, return its value; otherwise return None and block further queries for that key until its data is added
+    def get_key_in_cache(self, key):
+        try:
+            while True:
+                self.cache_L.acquire()
+                if key in self.cache.keys():
+                    if self.cache[key].is_ok:
+                        result = self.cache[key].value
+                        self.cache_list.remove(key)
+                        self.cache_list.append(key)
+                        self.cache_L.release()
+                        return result
+                else:
+                    break
+                self.cache_L.release()
+                time.sleep(0.5)
+            self.__add_key(key)
+            self.cache_L.release()
+            return None
+        except Exception as e:
+            self.cache_L.release()
+            raise e
+
+    def add_key(self, key, value):
+        try:
+            self.cache_L.acquire()
+            if key not in self.cache.keys():
+                self.__add_key(key)
+            self.cache[key].set_value(value)
+            self.cache_L.release()
+            return value
+        except Exception as e:
+            self.cache_L.release()
+            raise e
+
+    def __add_key(self, key):
+        self.cache_list.append(key)
+        self.cache[key] = self.CacheClass(key)
+        if len(self.cache_list) > self.cache_len:
+            del_key = self.cache_list.pop(0)
+            del self.cache[del_key]
+
+    class CacheClass:
+        def __init__(self, key):
+            self.is_ok = False
+            self.key = key
+            self.value = None
+
+        def set_value(self, value):
+            self.value = value
+            self.is_ok = True
+
+
+if __name__ == '__main__':
+    cache = Cache(2)
+    cache.add_key('a', 1)
+    cache.add_key('b', 2)
+    print(cache.get_key_in_cache('a'))
+    cache.add_key('c', 3)
+    print(cache)
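A minimal usage sketch of the Cache API above, illustrative only: it relies on the reconstructed semantics that get_key_in_cache returns None on a miss (marking the key as pending) and that add_key fills the pending slot. expensive_load and fetch are hypothetical helpers, not part of the package.

    from jacksung.utils.cache import Cache

    cache = Cache(cache_len=16)


    def expensive_load(key):
        # hypothetical stand-in for a slow computation or I/O
        return key.upper()


    def fetch(key):
        value = cache.get_key_in_cache(key)   # None on a miss; the key is now marked as pending
        if value is None:
            value = cache.add_key(key, expensive_load(key))  # fill the slot, unblocking waiting callers
        return value


    print(fetch('a'))  # computed
    print(fetch('a'))  # served from the cache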
jacksung/utils/data_convert.py
ADDED
@@ -0,0 +1,273 @@
+import os.path
+import numpy as np
+from einops import rearrange
+from rasterio.transform import from_origin
+import netCDF4 as nc
+import rasterio
+from rasterio.transform import from_gcps
+from rasterio.control import GroundControlPoint as GCP
+from typing import Tuple
+
+
+class Coordinate:
+    def __init__(self, left, top, x_res=None, y_res=None, right=None, bottom=None, h=None, w=None):
+        self.left = left
+        self.top = top
+        self.x_res = x_res
+        self.y_res = y_res
+        self.h = h
+        self.w = w
+        self.right = right
+        self.bottom = bottom
+        if self.x_res is None:
+            if self.right is None or self.w is None:
+                raise Exception('x_res is None, right or w is also None')
+            else:
+                self.x_res = (self.right - self.left) / self.w
+        if self.y_res is None:
+            if self.bottom is None or self.h is None:
+                raise Exception('y_res is None, bottom or h is also None')
+            else:
+                self.y_res = (self.top - self.bottom) / self.h
+        if self.right is None:
+            if self.x_res is None or self.w is None:
+                raise Exception('right is None, x_res or w is also None')
+            else:
+                self.right = self.left + self.x_res * self.w
+        if self.bottom is None:
+            if self.y_res is None or self.h is None:
+                raise Exception('bottom is None, y_res or h is also None')
+            else:
+                self.bottom = self.top - self.y_res * self.h
+        if self.w is None:
+            if self.right is None or self.x_res is None:
+                raise Exception('w is None, right or x_res is also None')
+            else:
+                self.w = (self.right - self.left) / self.x_res
+        if self.h is None:
+            if self.bottom is None or self.y_res is None:
+                raise Exception('h is None, bottom or y_res is also None')
+            else:
+                self.h = (self.top - self.bottom) / self.y_res
+        self.h, self.w = int(self.h), int(self.w)
+        self.ld = (self.left + self.right) / 2
+
+
+def dms_to_d10(data):
+    d, m, s = float(data[0]), float(data[1]), float(data[2])
+    return d + 1 / 60 * m + 1 / 60 / 60 * s
+
+
+def make_dms(s):
+    print(s)
+    temp = s.split('°')
+    s_d = temp[0]
+    temp = temp[1].split('′')
+    s_m = temp[0]
+    if len(temp) > 1 and temp[1]:
+        s_s = temp[1].split('″')[0]
+    else:
+        s_s = 0
+    result = [float(s_d), float(s_m), float(s_s)]
+    return result
+
+
+def _save_np2tif(np_data, output_dir, out_name, coordinate=None, resolution=None, dtype=None, print_log=False,
+                 transform=None):
+    h, w = np_data.shape
+    os.makedirs(output_dir, exist_ok=True)
+    save_path = os.path.join(output_dir, out_name)
+    if coordinate:
+        # Create a GeoTIFF file with geographic coordinate information
+        left, top = coordinate
+        res1, res2 = resolution
+        # Upper-left corner coordinates, longitude first, latitude second
+        transform = from_origin(left, top, res1, res2)
+        with rasterio.open(save_path, "w", driver="GTiff", width=w, height=h, count=1,
+                           dtype=dtype if dtype else np_data.dtype, crs="EPSG:4326", transform=transform) as dst:
+            dst.write(np_data, 1)
+        if print_log:
+            print(f"GeoTIFF '{save_path}' generated with geographic coordinates.")
+    elif transform:
+        with rasterio.open(save_path, "w", driver="GTiff", width=w, height=h, count=1,
+                           dtype=dtype if dtype else np_data.dtype, crs="EPSG:4326", transform=transform) as dst:
+            dst.write(np_data, 1)
+        if print_log:
+            print(f"GeoTIFF '{save_path}' generated with geographic coordinates.")
+    else:
+        # Write the data to a plain TIFF file
+        with rasterio.open(save_path, "w", width=w, height=h, count=1, dtype=dtype if dtype else np_data.dtype) as dst:
+            dst.write(np_data, 1)
+        if print_log:
+            print(f"TIFF image saved as '{save_path}'")
+
+
+def np2tif(input_data, save_path='np2tif_dir', out_name='', left=None, top=None, x_res=None, y_res=None, dtype=None,
+           dim_value=None, coord=None, print_log=True, transform=None):
+    if type(input_data) == str:
+        np_data = np.load(input_data)
+        if np_data is None:
+            print(f'load {input_data} is None')
+    else:
+        np_data = input_data
+    shape = np_data.shape
+    if len(shape) < 2:
+        raise Exception(str(shape) + 'is less than 2 Dimensions')
+    mode_list = ['d' + str(i) for i in range(len(shape))]
+
+    mode_str = ' '.join(mode_list) + '->(' + ' '.join(mode_list[:-2]) + ') ' + mode_list[-2] + ' ' + mode_list[-1]
+    np_data = rearrange(np_data, mode_str)
+    if left is not None and top is not None:
+        coordinate = (left, top)
+    elif coord is not None:
+        coordinate = (coord.left, coord.top)
+        x_res, y_res = coord.x_res, coord.y_res
+    else:
+        coordinate = None
+    for idx, single_np in enumerate(np_data):
+        name = ''
+        idx_tmp = idx
+        for s in range(len(shape[:-2])):
+            name += '-'
+            temp = int(idx_tmp // np.prod(shape[s + 1:-2], axis=None))
+            if dim_value is not None:
+                plus_name = str(dim_value[s]['value'][temp])
+            else:
+                plus_name = str(temp)
+            name += plus_name
+            idx_tmp -= temp * np.prod(shape[s + 1:-2], axis=None)
+        name = out_name + name + '.tif'
+        # if name == '':
+        #     name = out_name + '.tif'
+        # else:
+        #     name = out_name + '-' + name + '.tif'
+        _save_np2tif(single_np, save_path, name, coordinate=coordinate, resolution=(x_res, y_res), dtype=dtype,
+                     print_log=print_log, transform=transform)
+
+
+def nc2tif(input_data, save_path='np2tif_dir', lock=None):
+    np_data, dim_value = nc2np(input_data, lock)
+    np2tif(np_data, save_path, dim_value=dim_value)
+
+
+def nc2np(input_data, lock=None, return_dim=True):
+    if type(input_data) == str:
+        if lock:
+            lock.acquire()
+        nc_data = nc.Dataset(input_data)  # read the .nc file; nc_data now holds all of its contents
+        if lock:
+            lock.release()
+    else:
+        nc_data = input_data
+    vars = []
+    max_shape = 0
+    dimensions = {}
+    for name, var in nc_data.variables.items():
+        if len(var.shape) > max_shape:
+            max_shape = len(var.shape)
+            vars = [name]
+        elif len(var.shape) == max_shape:
+            vars.append(name)
+        if len(var.shape) == 1:
+            dimensions[name] = list(nc_data[name][:])
+    np_data = []
+    for var in vars:
+        np_data.append(np.array(nc_data[var][:]))
+    np_data = np.array(np_data)
+    value_idx = 0
+    while 'value' + str(value_idx) in dimensions:
+        value_idx += 1
+    value_key = 'value' + str(value_idx)
+    dimensions[value_key] = vars
+    if return_dim:
+        np_idx = [value_key] + list(nc_data[var].dimensions)
+        dim_value = [{'dim_name': key, 'value': dimensions[key]} for key in np_idx]
+    else:
+        dim_value = None
+    if type(input_data) == str:
+        nc_data.close()
+    return np_data, dim_value
+
+
+def add_None(a, b):
+    if a is None:
+        return b
+    else:
+        return a + b
+
+
+def get_transform_from_lonlat_matrices(
+        lon_array: np.ndarray,
+        lat_array: np.ndarray,
+        gcp_density: int = 10,
+        print_log=False,
+        crs: str = "EPSG:4326"
+) -> Tuple[rasterio.Affine, float]:
+    """
+    Fit a rasterio transform (affine transformation matrix) from per-pixel longitude/latitude matrices.
+    Args:
+        lon_array: 2D numpy array of shape [height, width] holding the longitude of each pixel
+        lat_array: 2D numpy array of shape [height, width] holding the latitude of each pixel
+        gcp_density: density of ground control points (GCPs sampled per side), default 10 (total GCPs ≈ 10×10 = 100);
+            the larger the covered area, the larger this should be (e.g. 20-50) for higher fitting accuracy
+        crs: coordinate reference system string (default WGS84 lon/lat, EPSG:4326)
+
+    Returns:
+        transform: rasterio.Affine object, the affine transform from pixel coordinates to lon/lat
+        avg_error_km: average fitting error (km), used to verify accuracy
+    """
+    # 1. Validate the input matrices
+    if lon_array.shape != lat_array.shape:
+        raise ValueError(f"Longitude and latitude matrices have mismatched shapes! lon_shape={lon_array.shape}, lat_shape={lat_array.shape}")
+    height, width = lon_array.shape
+    if height < 2 or width < 2:
+        raise ValueError("Matrices are too small (at least 2×2 pixels required) to fit a transform")
+    # 2. Sample ground control points (GCPs) uniformly - avoid edge-only or clustered sampling, ensure global coverage
+    # Generate uniformly distributed pixel coordinates (col, row)
+    col_indices = np.linspace(0, width - 1, gcp_density, dtype=int)
+    row_indices = np.linspace(0, height - 1, gcp_density, dtype=int)
+    col_grid, row_grid = np.meshgrid(col_indices, row_indices)  # grid-shaped GCP distribution
+    # 3. Build the GCP list (pixel coordinates → lon/lat)
+    gcps = []
+    for row, col in zip(row_grid.flatten(), col_grid.flatten()):
+        lon = lon_array[row, col]
+        lat = lat_array[row, col]
+        # Skip invalid lon/lat values (e.g. NaN)
+        if np.isnan(lon) or np.isnan(lat):
+            continue
+        # GCP format: GCP(pixel column, pixel row, longitude, latitude)
+        gcps.append(GCP(col, row, lon, lat))
+    if len(gcps) < 3:
+        raise ValueError(f"Fewer than 3 valid control points (only {len(gcps)}), cannot fit an affine transform")
+    # 4. Fit the transform from the GCPs (least squares)
+    transform = from_gcps(gcps)
+    # 5. Compute the fitting error (accuracy check)
+    errors_km = []
+    for gcp in gcps:
+        # Use the fitted transform to predict lon/lat
+        pred_lon, pred_lat = transform * (gcp.col, gcp.row)
+        # Use the haversine formula to compute the distance (km) between the actual and predicted lon/lat
+        error_km = haversine_distance(gcp.x, gcp.y, pred_lon, pred_lat)
+        errors_km.append(error_km)
+    avg_error_km = np.mean(errors_km)
+    max_error_km = np.max(errors_km)
+    if print_log:
+        print(f"Fit complete: average error={avg_error_km:.3f}km, max error={max_error_km:.3f}km")
+        print(f"Hint: if the error is too large (>0.5km), increase gcp_density (currently {gcp_density})")
+    return transform, avg_error_km
+
+
+def haversine_distance(lon1: float, lat1: float, lon2: float, lat2: float) -> float:
+    """Helper: great-circle distance (km) between two points via the haversine formula"""
+    lon1, lat1, lon2, lat2 = map(np.radians, [lon1, lat1, lon2, lat2])
+    dlon = lon2 - lon1
+    dlat = lat2 - lat1
+    a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2
+    c = 2 * np.arcsin(np.sqrt(a))
+    return 6371 * c  # mean Earth radius ≈ 6371 km
+
+
+if __name__ == '__main__':
+    np_data, dim = nc2np(r'C:\Users\jackSung\Desktop\download.nc')
+    np2tif(np_data, 'com', dim_value=dim)
+    print(dim)
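An illustrative call of np2tif from the module above, not part of the diff: a synthetic (2, 64, 64) array is written as one GeoTIFF per leading index on a 0.1° grid anchored at 100°E, 40°N. The array, output directory, and grid values are made up for the example.

    import numpy as np

    from jacksung.utils.data_convert import np2tif

    data = np.random.rand(2, 64, 64).astype('float32')
    np2tif(data, save_path='np2tif_demo', out_name='band',
           left=100.0, top=40.0, x_res=0.1, y_res=0.1)
    # -> np2tif_demo/band-0.tif and np2tif_demo/band-1.tif, one file per slice along the first axis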
jacksung/utils/exception.py
ADDED
@@ -0,0 +1,27 @@
+import time
+from jacksung.utils.log import oprint as print
+
+
+class NoFileException(Exception):
+    def __init__(self, file_name):
+        self.file_name = file_name
+        super().__init__(f'No such file: {file_name}')
+
+
+class NanNPException(Exception):
+    def __init__(self, file_name):
+        self.file_name = file_name
+        super().__init__(f'Nan value in np data: {file_name}')
+
+
+def wait_fun(fun, args, catch_exception=Exception, sleep_time=0.5, wait_time=5, open_log=True):
+    try:
+        return fun(*args)
+    except catch_exception as e:
+        if open_log:
+            print(f'Task {args} failed, retry in {sleep_time}s, remain waiting time: {wait_time}s')
+        if wait_time <= 0:
+            raise e
+        else:
+            time.sleep(sleep_time)
+            return wait_fun(fun, args, catch_exception, sleep_time=sleep_time, wait_time=wait_time - sleep_time)
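A brief usage sketch of wait_fun above, illustrative only: flaky_read and its path are hypothetical. With the default sleep_time=0.5 and wait_time=5, the call is retried every half second until the 5-second budget is exhausted, after which the last exception is re-raised.

    from jacksung.utils.exception import wait_fun, NoFileException


    def flaky_read(path):
        # hypothetical stand-in for an operation that may fail transiently
        raise NoFileException(path)


    try:
        wait_fun(flaky_read, ('data/missing.npy',), catch_exception=NoFileException)
    except NoFileException as e:
        print('gave up after retries:', e)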
jacksung/utils/fastnumpy.py
ADDED
@@ -0,0 +1,115 @@
+import sys
+import numpy as np
+import numpy.lib.format
+import struct
+
+
+class Add():
+    def __init__(self):
+        self.datas = None
+        self.count = 0
+
+    def add(self, np_data):
+        if self.datas is None:
+            self.datas = np_data.astype(np.float64)
+        else:
+            self.datas += np_data.astype(np.float64)
+        self.count += 1
+
+    def get_mean(self):
+        return self.datas / self.count
+
+
+def save(file, array):
+    magic_string = b"\x93NUMPY\x01\x00v\x00"
+    header = bytes(("{'descr': '" + array.dtype.descr[0][1] + "', 'fortran_order': False, 'shape': " + str(
+        array.shape) + ", }").ljust(127 - len(magic_string)) + "\n", 'utf-8')
+    if type(file) == str:
+        file = open(file, "wb")
+    file.write(magic_string)
+    file.write(header)
+    file.write(array.tobytes())
+
+
+def pack(array):
+    size = len(array.shape)
+    return bytes(array.dtype.byteorder.replace('=', '<' if sys.byteorder == 'little' else '>') + array.dtype.kind,
+                 'utf-8') + array.dtype.itemsize.to_bytes(1, byteorder='little') + struct.pack(f'<B{size}I', size,
+                                                                                               *array.shape) + array.tobytes()
+
+
+def load(file):
+    if type(file) == str:
+        file = open(file, "rb")
+    header = file.read(128)
+    if not header:
+        return None
+    descr = str(header[19:25], 'utf-8').replace("'", "").replace(" ", "")
+    shape = tuple(int(num) for num in
+                  str(header[60:120], 'utf-8').replace(', }', '')
+                  .replace('(', '').replace(')', '').replace(' ', '').split(',') if num)
+    datasize = numpy.lib.format.descr_to_dtype(descr).itemsize
+    for dimension in shape:
+        datasize *= dimension
+    return np.ndarray(shape, dtype=descr, buffer=file.read(datasize))
+
+
+def unpack(data):
+    dtype = str(data[:2], 'utf-8')
+    dtype += str(data[2])
+    size = data[3]
+    shape = struct.unpack_from(f'<{size}I', data, 4)
+    datasize = data[2]
+    for dimension in shape:
+        datasize *= dimension
+    return np.ndarray(shape, dtype=dtype, buffer=data[4 + size * 4:4 + size * 4 + datasize])
+
+
+if __name__ == "__main__":
+    import io
+    from timeit import default_timer as timer
+    from datetime import timedelta
+
+    iterations = 100000
+    testarray = np.random.rand(3, 64, 64).astype('float32')
+
+    start = timer()
+    for i in range(iterations):
+        buffer = io.BytesIO()
+        np.save(buffer, testarray)
+        numpy_save_data = buffer.getvalue()
+    print("numpy.save:", timedelta(seconds=timer() - start))
+
+    start = timer()
+    for i in range(iterations):
+        buffer = io.BytesIO()
+        save(buffer, testarray)
+        fastnumpyio_save_data = buffer.getvalue()
+    print("fastnumpyio.save:", timedelta(seconds=timer() - start))
+
+    start = timer()
+    for i in range(iterations):
+        fastnumpyio_pack_data = pack(testarray)
+    print("fastnumpyio.pack:", timedelta(seconds=timer() - start))
+
+    start = timer()
+    for i in range(iterations):
+        buffer = io.BytesIO(numpy_save_data)
+        test_numpy_save = np.load(buffer)
+    print("numpy.load:", timedelta(seconds=timer() - start))
+
+    start = timer()
+    for i in range(iterations):
+        buffer = io.BytesIO(fastnumpyio_save_data)
+        test_fastnumpyio_save = load(buffer)
+    print("fastnumpyio.load:", timedelta(seconds=timer() - start))
+
+    start = timer()
+    for i in range(iterations):
+        test_fastnumpyio_pack = unpack(fastnumpyio_pack_data)
+    print("fastnumpyio.unpack:", timedelta(seconds=timer() - start))
+
+    print("numpy.save+numpy.load == fastnumpyio.save+fastnumpyio.load:",
+          np.array_equal(test_numpy_save, test_fastnumpyio_save))
+    print("numpy.save+numpy.load == fastnumpyio.pack+fastnumpyio.unpack:",
+          np.array_equal(test_numpy_save, test_fastnumpyio_pack))