shancx 1.9.33.109__py3-none-any.whl → 1.9.33.218__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shancx/{Dsalgor → Algo}/__init__.py +37 -1
- shancx/Calmetrics/__init__.py +78 -9
- shancx/Calmetrics/calmetrics.py +14 -0
- shancx/Calmetrics/rmseR2score.py +14 -3
- shancx/{Command.py → Cmd.py} +20 -15
- shancx/Config_.py +26 -0
- shancx/Df/__init__.py +11 -0
- shancx/Df/tool.py +0 -1
- shancx/NN/__init__.py +200 -11
- shancx/{path.py → Path1.py} +2 -3
- shancx/Plot/__init__.py +129 -403
- shancx/Plot/draw_day_CR_PNG.py +4 -21
- shancx/Plot/exam.py +116 -0
- shancx/Plot/plotGlobal.py +325 -0
- shancx/Plot/radarNmc.py +1 -48
- shancx/Plot/single_china_map.py +1 -1
- shancx/Point.py +46 -0
- shancx/QC.py +223 -0
- shancx/Read.py +17 -10
- shancx/Resize.py +79 -0
- shancx/SN/__init__.py +8 -1
- shancx/Time/timeCycle.py +97 -23
- shancx/Train/makelist.py +161 -155
- shancx/__init__.py +79 -232
- shancx/bak.py +78 -53
- shancx/geosProj.py +2 -2
- shancx/wait.py +35 -1
- {shancx-1.9.33.109.dist-info → shancx-1.9.33.218.dist-info}/METADATA +12 -4
- shancx-1.9.33.218.dist-info/RECORD +91 -0
- {shancx-1.9.33.109.dist-info → shancx-1.9.33.218.dist-info}/WHEEL +1 -1
- shancx/Plot/Mip.py +0 -42
- shancx/Plot/border.py +0 -44
- shancx/Plot/draw_day_CR_PNGUS.py +0 -206
- shancx/Plot/draw_day_CR_SVG.py +0 -275
- shancx/Plot/draw_day_pre_PNGUS.py +0 -205
- shancx/Plot/radar_nmc_china_map_compare1.py +0 -50
- shancx/makenetCDFN.py +0 -42
- shancx-1.9.33.109.dist-info/RECORD +0 -91
- /shancx/{3DJU → 3D}/__init__.py +0 -0
- /shancx/{Dsalgor → Algo}/Class.py +0 -0
- /shancx/{Dsalgor → Algo}/CudaPrefetcher1.py +0 -0
- /shancx/{Dsalgor → Algo}/Fake_image.py +0 -0
- /shancx/{Dsalgor → Algo}/Hsml.py +0 -0
- /shancx/{Dsalgor → Algo}/L2Loss.py +0 -0
- /shancx/{Dsalgor → Algo}/MetricTracker.py +0 -0
- /shancx/{Dsalgor → Algo}/Normalize.py +0 -0
- /shancx/{Dsalgor → Algo}/OptimizerWScheduler.py +0 -0
- /shancx/{Dsalgor → Algo}/Rmageresize.py +0 -0
- /shancx/{Dsalgor → Algo}/Savemodel.py +0 -0
- /shancx/{Dsalgor → Algo}/SmoothL1_losses.py +0 -0
- /shancx/{Dsalgor → Algo}/Tqdm.py +0 -0
- /shancx/{Dsalgor → Algo}/checknan.py +0 -0
- /shancx/{Dsalgor → Algo}/dsalgor.py +0 -0
- /shancx/{Dsalgor → Algo}/iouJU.py +0 -0
- /shancx/{Dsalgor → Algo}/mask.py +0 -0
- /shancx/{Dsalgor → Algo}/psnr.py +0 -0
- /shancx/{Dsalgor → Algo}/ssim.py +0 -0
- /shancx/{Dsalgor → Algo}/structural_similarity.py +0 -0
- /shancx/{Dsalgor → Algo}/tool.py +0 -0
- /shancx/Calmetrics/{matrixLib.py → calmetricsmatrixLib.py} +0 -0
- /shancx/{Diffmodel → Diffm}/Psamples.py +0 -0
- /shancx/{Diffmodel → Diffm}/__init__.py +0 -0
- /shancx/{Diffmodel → Diffm}/test.py +0 -0
- /shancx/{Board → tensBoard}/__init__.py +0 -0
- {shancx-1.9.33.109.dist-info → shancx-1.9.33.218.dist-info}/top_level.txt +0 -0
shancx/Plot/__init__.py
CHANGED
|
@@ -21,9 +21,9 @@ def plotGrey(img,name="plotGrey", saveDir="plotGrey",cmap='gray', title='Image')
|
|
|
21
21
|
import matplotlib.pyplot as plt
|
|
22
22
|
from shancx import crDir
|
|
23
23
|
import datetime
|
|
24
|
-
def plotMat(matrix,name='plotMat',saveDir="plotMat",title='Matrix Plot', xlabel='X-axis', ylabel='Y-axis', color_label='Value', cmap='viridis'):
|
|
24
|
+
def plotMat(matrix,name='plotMat',saveDir="plotMat",title='Matrix Plot', xlabel='X-axis', ylabel='Y-axis', color_label='Value', cmap='viridis',aspect="equal"): #aspect='auto'
|
|
25
25
|
now_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
|
26
|
-
plt.imshow(matrix, cmap=cmap, origin='upper', aspect='
|
|
26
|
+
plt.imshow(matrix, cmap=cmap, origin='upper', aspect=f'{aspect}')
|
|
27
27
|
plt.colorbar(label=color_label)
|
|
28
28
|
plt.title(title)
|
|
29
29
|
plt.xlabel(xlabel)
|
|
@@ -34,6 +34,35 @@ def plotMat(matrix,name='plotMat',saveDir="plotMat",title='Matrix Plot', xlabel=
|
|
|
34
34
|
plt.savefig(outpath)
|
|
35
35
|
plt.close()
|
|
36
36
|
|
|
37
|
+
import matplotlib.pyplot as plt
|
|
38
|
+
from shancx import crDir
|
|
39
|
+
import datetime
|
|
40
|
+
def plotMatplus(matrix, name='plotMat', saveDir="plotMat", title='Matrix Plot',
|
|
41
|
+
xlabel='Longitude', ylabel='Latitude', color_label='Value',
|
|
42
|
+
cmap='viridis', extent=None):
|
|
43
|
+
"""
|
|
44
|
+
extent: [lon_min, lon_max, lat_min, lat_max]
|
|
45
|
+
"""
|
|
46
|
+
now_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
|
47
|
+
plt.imshow(matrix, cmap=cmap, origin='upper', aspect='auto', extent=extent)
|
|
48
|
+
plt.colorbar(label=color_label)
|
|
49
|
+
plt.title(title)
|
|
50
|
+
|
|
51
|
+
# 添加度符号和方向标识
|
|
52
|
+
plt.xlabel(f'{xlabel} (°E)') # 东经
|
|
53
|
+
plt.ylabel(f'{ylabel} (°N)') # 北纬
|
|
54
|
+
|
|
55
|
+
plt.tight_layout()
|
|
56
|
+
outpath = f'./{saveDir}/{name}_{now_str}.png' if name=="plotMat" else f"./{saveDir}/{name}.png"
|
|
57
|
+
crDir(outpath)
|
|
58
|
+
plt.savefig(outpath)
|
|
59
|
+
plt.close()
|
|
60
|
+
"""
|
|
61
|
+
latlon = [10.0, 37.0, 105.0, 125.0]
|
|
62
|
+
latmin, latmax, lonmin, lonmax = latlon
|
|
63
|
+
plotMatplus(data,extent=[lon_min, lon_max, lat_min, lat_max])
|
|
64
|
+
"""
|
|
65
|
+
|
|
37
66
|
import datetime
|
|
38
67
|
from hjnwtx.colormap import cmp_hjnwtx
|
|
39
68
|
from shancx import crDir
|
|
@@ -45,7 +74,6 @@ def plotRadar(array_dt,name="plotRadar", saveDir="plotRadar",ty="CR"):
|
|
|
45
74
|
if len(array_dt.shape) == 2 and ty == "pre":
|
|
46
75
|
fig, ax = plt.subplots()
|
|
47
76
|
im = ax.imshow(array_dt, vmin=0, vmax=10, cmap=cmp_hjnwtx["pre_tqw"])
|
|
48
|
-
# 创建与图像高度一致的colorbar
|
|
49
77
|
divider = make_axes_locatable(ax)
|
|
50
78
|
cax = divider.append_axes("right", size="5%", pad=0.05)
|
|
51
79
|
plt.colorbar(im, cax=cax)
|
|
@@ -122,6 +150,27 @@ def plotScatter(df1,saveDir="plotScatter"):
|
|
|
122
150
|
plt.savefig(f"./{saveDir}/plotScatter_{now_str}.png", dpi=300, bbox_inches="tight")
|
|
123
151
|
plt.close()
|
|
124
152
|
|
|
153
|
+
import matplotlib.pyplot as plt
|
|
154
|
+
import os
|
|
155
|
+
def plotScatter1(true,pre,saveDir="plotScatter"):
|
|
156
|
+
now_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
|
157
|
+
plt.figure(figsize=(10, 8))
|
|
158
|
+
plt.scatter(
|
|
159
|
+
true,
|
|
160
|
+
pre,
|
|
161
|
+
s=25,
|
|
162
|
+
alpha=0.6,
|
|
163
|
+
edgecolor="black",
|
|
164
|
+
linewidth=0.5
|
|
165
|
+
)
|
|
166
|
+
plt.title("Scatter Plot of Ture Pre", fontsize=14)
|
|
167
|
+
plt.xlabel("Longitude", fontsize=12)
|
|
168
|
+
plt.ylabel("Latitude", fontsize=12)
|
|
169
|
+
plt.tight_layout()
|
|
170
|
+
os.makedirs(saveDir, exist_ok=True)
|
|
171
|
+
plt.savefig(f"./{saveDir}/plotScatter1_{now_str}.png", dpi=300, bbox_inches="tight")
|
|
172
|
+
plt.close()
|
|
173
|
+
|
|
125
174
|
import numpy as np
|
|
126
175
|
import matplotlib.pyplot as plt
|
|
127
176
|
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
|
@@ -148,7 +197,7 @@ def plotVal( epoch=0,*datasets, title=["input","prediction","truth"], saveDir="p
|
|
|
148
197
|
plt.savefig(filename)
|
|
149
198
|
plt.close(fig)
|
|
150
199
|
|
|
151
|
-
"""
|
|
200
|
+
"""
|
|
152
201
|
if total >= 3:
|
|
153
202
|
break
|
|
154
203
|
if epoch % 2 == 0:
|
|
@@ -200,11 +249,11 @@ def plotValplus(epoch=0, *datasets, title=["input", "prediction", "truth"], save
|
|
|
200
249
|
filename = f"{saveDir}/epoch_{epoch}.png"
|
|
201
250
|
plt.savefig(filename)
|
|
202
251
|
plt.close(fig)
|
|
203
|
-
"""
|
|
252
|
+
"""
|
|
204
253
|
if total >= 3:
|
|
205
254
|
break
|
|
206
255
|
if epoch % 2 == 0:
|
|
207
|
-
|
|
256
|
+
plotValplus(epoch,
|
|
208
257
|
data[0][0].detach().cpu().numpy().squeeze(), # 使用 detach()
|
|
209
258
|
output[0][0].detach().cpu().numpy().squeeze(), # 使用 detach()
|
|
210
259
|
label[0][0].detach().cpu().numpy().squeeze(), # 使用 detach()
|
|
@@ -213,30 +262,6 @@ def plotValplus(epoch=0, *datasets, title=["input", "prediction", "truth"], save
|
|
|
213
262
|
)
|
|
214
263
|
"""
|
|
215
264
|
|
|
216
|
-
import numpy as np
|
|
217
|
-
import matplotlib
|
|
218
|
-
matplotlib.use("Agg")
|
|
219
|
-
import matplotlib.pyplot as plt
|
|
220
|
-
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
|
221
|
-
from hjnwtx.colormap import cmp_hjnwtx
|
|
222
|
-
from shancx import crDir
|
|
223
|
-
import os
|
|
224
|
-
def plot_dataset(ax, data, title, cmap, vmin, vmax):
|
|
225
|
-
"""
|
|
226
|
-
Helper function to plot a single dataset on a given axis.
|
|
227
|
-
"""
|
|
228
|
-
im = ax.matshow(data, cmap=cmap, vmin=vmin, vmax=vmax)
|
|
229
|
-
# Remove axis ticks
|
|
230
|
-
ax.set_xticks([])
|
|
231
|
-
ax.set_yticks([])
|
|
232
|
-
# Add colorbar
|
|
233
|
-
divider = make_axes_locatable(ax)
|
|
234
|
-
cax = divider.append_axes("right", size="5%", pad=0.05)
|
|
235
|
-
cbar = plt.colorbar(im, cax=cax, ticks=np.linspace(vmin, vmax, 15))
|
|
236
|
-
cbar.set_ticks(np.linspace(vmin, vmax, 15))
|
|
237
|
-
# Set title
|
|
238
|
-
ax.set_title(title)
|
|
239
|
-
return im
|
|
240
265
|
|
|
241
266
|
def plotValplus1(epoch=0, *datasets, title=["input", "prediction", "truth"], saveDir="plotValplus", cmap='summer'):
|
|
242
267
|
"""
|
|
@@ -270,6 +295,43 @@ def plotValplus1(epoch=0, *datasets, title=["input", "prediction", "truth"], sav
|
|
|
270
295
|
plt.savefig(filename)
|
|
271
296
|
plt.close(fig)
|
|
272
297
|
|
|
298
|
+
"""
|
|
299
|
+
if total >= 3:
|
|
300
|
+
break
|
|
301
|
+
if epoch % 2 == 0:
|
|
302
|
+
plotValplus1(epoch,
|
|
303
|
+
data[0][0].detach().cpu().numpy().squeeze(), # 使用 detach()
|
|
304
|
+
output[0][0].detach().cpu().numpy().squeeze(), # 使用 detach()
|
|
305
|
+
label[0][0].detach().cpu().numpy().squeeze(), # 使用 detach()
|
|
306
|
+
title=["input", "prediction", "groundtruth"],
|
|
307
|
+
saveDir="plot_train_dir"
|
|
308
|
+
)
|
|
309
|
+
"""
|
|
310
|
+
|
|
311
|
+
import numpy as np
|
|
312
|
+
import matplotlib
|
|
313
|
+
matplotlib.use("Agg")
|
|
314
|
+
import matplotlib.pyplot as plt
|
|
315
|
+
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
|
316
|
+
from hjnwtx.colormap import cmp_hjnwtx
|
|
317
|
+
from shancx import crDir
|
|
318
|
+
import os
|
|
319
|
+
def plot_dataset(ax, data, title, cmap, vmin, vmax): #Cited methods
|
|
320
|
+
"""
|
|
321
|
+
Helper function to plot a single dataset on a given axis.
|
|
322
|
+
"""
|
|
323
|
+
im = ax.matshow(data, cmap=cmap, vmin=vmin, vmax=vmax)
|
|
324
|
+
# Remove axis ticks
|
|
325
|
+
ax.set_xticks([])
|
|
326
|
+
ax.set_yticks([])
|
|
327
|
+
# Add colorbar
|
|
328
|
+
divider = make_axes_locatable(ax)
|
|
329
|
+
cax = divider.append_axes("right", size="5%", pad=0.05)
|
|
330
|
+
cbar = plt.colorbar(im, cax=cax, ticks=np.linspace(vmin, vmax, 15))
|
|
331
|
+
cbar.set_ticks(np.linspace(vmin, vmax, 15))
|
|
332
|
+
# Set title
|
|
333
|
+
ax.set_title(title)
|
|
334
|
+
return im
|
|
273
335
|
|
|
274
336
|
import matplotlib.pyplot as plt
|
|
275
337
|
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
|
@@ -285,27 +347,24 @@ def calculate_colorbar_range(data):
|
|
|
285
347
|
vmin = int(np.nanmin(data))
|
|
286
348
|
vmax = int(np.nanmax(data))
|
|
287
349
|
return vmin, vmax
|
|
288
|
-
def
|
|
289
|
-
cmap="viridis", vmin=None, vmax=None):
|
|
350
|
+
def plotgriddata(data, titles=None,name="temp", save_dir="plots",
|
|
351
|
+
cmap="viridis", vmin=None, vmax=None): #Cited methods
|
|
290
352
|
if not isinstance(data, np.ndarray) or data.ndim != 3:
|
|
291
|
-
raise ValueError("
|
|
353
|
+
raise ValueError("The input data must be a three-dimensional NumPy array [num_images, height, width]")
|
|
292
354
|
num_images = data.shape[0]
|
|
293
355
|
titles = titles or [f"Data {i}" for i in range(num_images)]
|
|
294
|
-
if vmin is None or vmax is None:
|
|
295
|
-
vmin, vmax = calculate_colorbar_range(data)
|
|
296
356
|
ncols = int(np.ceil(np.sqrt(num_images)))
|
|
297
357
|
nrows = int(np.ceil(num_images / ncols))
|
|
298
358
|
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 3, nrows * 3))
|
|
299
359
|
axes = axes.ravel()
|
|
300
360
|
for i in range(num_images):
|
|
301
361
|
ax = axes[i]
|
|
302
|
-
im = ax.imshow(data[i], cmap=cmap
|
|
362
|
+
im = ax.imshow(data[i], cmap=cmap)
|
|
303
363
|
ax.set_title(titles[i])
|
|
304
364
|
ax.axis('off')
|
|
305
365
|
divider = make_axes_locatable(ax)
|
|
306
366
|
cax = divider.append_axes("right", size="5%", pad=0.05)
|
|
307
367
|
cbar = plt.colorbar(im, cax=cax,
|
|
308
|
-
ticks=np.linspace(vmin, vmax, 15),
|
|
309
368
|
format='%.1f')
|
|
310
369
|
cbar.ax.tick_params(labelsize=6)
|
|
311
370
|
for j in range(num_images, len(axes)):
|
|
@@ -316,7 +375,7 @@ def plot_grid_data(data, titles=None, save_dir="plots", name="temp",
|
|
|
316
375
|
filename = f"{name}_{timestamp}.png"
|
|
317
376
|
plt.savefig(os.path.join(save_dir, filename), dpi=300)
|
|
318
377
|
plt.close()
|
|
319
|
-
def
|
|
378
|
+
def plotDrawpic(basedata, save_dir="plotDrawpic_com", name="temp", cmap="summer"):
|
|
320
379
|
data_all = basedata[:,::2,::2]
|
|
321
380
|
if isinstance(name, str):
|
|
322
381
|
print("name str")
|
|
@@ -324,7 +383,7 @@ def drawpic_com(basedata, save_dir="plots", name="temp", cmap="summer"):
|
|
|
324
383
|
else:
|
|
325
384
|
titles = [f"{i}" for i in name.strftime("%Y%m%d%H%M%S")]
|
|
326
385
|
name = name.strftime("%Y%m%d%H%M%S")[0]
|
|
327
|
-
|
|
386
|
+
plotgriddata(
|
|
328
387
|
data=data_all,
|
|
329
388
|
titles=titles,
|
|
330
389
|
name=name,
|
|
@@ -333,398 +392,65 @@ def drawpic_com(basedata, save_dir="plots", name="temp", cmap="summer"):
|
|
|
333
392
|
)
|
|
334
393
|
|
|
335
394
|
"""
|
|
336
|
-
drawpic_com(Data_con, save_dir="
|
|
337
|
-
"""
|
|
338
|
-
|
|
339
|
-
import os
|
|
340
|
-
import matplotlib.pyplot as plt
|
|
341
|
-
from matplotlib.ticker import MaxNLocator
|
|
342
|
-
class trainingVis:
|
|
343
|
-
def __init__(self, args=None, dataset_key_map=None, root_path="./"):
|
|
344
|
-
self.args = args
|
|
345
|
-
self.dataset_key_map = dataset_key_map
|
|
346
|
-
self.root_path = root_path
|
|
347
|
-
self.record = {
|
|
348
|
-
"train_loss": [],
|
|
349
|
-
"train_psnr": [],
|
|
350
|
-
"val_loss": [],
|
|
351
|
-
"val_psnr": [],
|
|
352
|
-
}
|
|
353
|
-
self.x_epoch = []
|
|
354
|
-
self.output_dir = self._get_output_dir()
|
|
355
|
-
def _get_output_dir(self):
|
|
356
|
-
"""生成输出目录路径并创建目录"""
|
|
357
|
-
output_dir = os.path.join(
|
|
358
|
-
self.root_path,
|
|
359
|
-
"Rec",
|
|
360
|
-
# "weights_dir",
|
|
361
|
-
# self.dataset_key_map[self.args.dataset_key],
|
|
362
|
-
"trainvalViscure",
|
|
363
|
-
)
|
|
364
|
-
os.makedirs(output_dir, exist_ok=True)
|
|
365
|
-
return output_dir
|
|
366
|
-
def _plot_curve(self, ax, x, y_train, y_val, y_label, train_color="blue", val_color="red"):
|
|
367
|
-
"""绘制单条曲线并优化坐标轴显示"""
|
|
368
|
-
if y_train is not None:
|
|
369
|
-
ax.plot(x, y_train, marker='o', linestyle='-', color=train_color, label="Train")
|
|
370
|
-
if y_val is not None:
|
|
371
|
-
ax.plot(x, y_val, marker='o', linestyle='-', color=val_color, label="Val")
|
|
372
|
-
# 设置坐标轴标签和格式
|
|
373
|
-
ax.set_xlabel("Epoch", fontsize=10)
|
|
374
|
-
ax.set_ylabel(y_label, fontsize=10)
|
|
375
|
-
ax.set_title(f"{y_label} Curve", fontsize=12)
|
|
376
|
-
# 配置x轴刻度(确保整数显示)
|
|
377
|
-
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
|
378
|
-
plt.setp(ax.get_xticklabels(), rotation=30, ha='right', fontsize=8)
|
|
379
|
-
# 配置y轴刻度
|
|
380
|
-
ax.yaxis.set_major_locator(MaxNLocator(nbins=6))
|
|
381
|
-
plt.setp(ax.get_yticklabels(), fontsize=8)
|
|
382
|
-
# 添加图例
|
|
383
|
-
ax.legend(loc="upper right", fontsize=8)
|
|
384
|
-
def draw_curve(self, epoch=None, train_loss=None, train_psnr=None, val_loss=None, val_psnr=None):
|
|
385
|
-
"""动态绘制训练曲线并根据数据存在性调整布局"""
|
|
386
|
-
self.record["train_loss"].append(train_loss)
|
|
387
|
-
self.record["val_loss"].append(val_loss)
|
|
388
|
-
self.x_epoch.append(epoch)
|
|
389
|
-
has_psnr = train_psnr is not None and val_psnr is not None
|
|
390
|
-
if has_psnr:
|
|
391
|
-
self.record["train_psnr"].append(train_psnr)
|
|
392
|
-
self.record["val_psnr"].append(val_psnr)
|
|
393
|
-
else:
|
|
394
|
-
self.record["train_psnr"].append(None)
|
|
395
|
-
self.record["val_psnr"].append(None)
|
|
396
|
-
fig = plt.figure(figsize=(10, 4.5) if has_psnr else (6, 4.5))
|
|
397
|
-
plt.subplots_adjust(wspace=0.3 if has_psnr else 0)
|
|
398
|
-
ax0 = fig.add_subplot(111 if not has_psnr else 121)
|
|
399
|
-
# 如果只有train_loss数据存在
|
|
400
|
-
if train_loss is not None and val_loss is None:
|
|
401
|
-
self._plot_curve(ax0
|
|
402
|
-
,self.x_epoch
|
|
403
|
-
,self.record["train_loss"]
|
|
404
|
-
,None
|
|
405
|
-
,"Loss"
|
|
406
|
-
,train_color="blue"
|
|
407
|
-
,val_color="red"
|
|
408
|
-
)
|
|
409
|
-
elif val_loss is not None and train_loss is None:
|
|
410
|
-
self._plot_curve(ax0,
|
|
411
|
-
self.x_epoch,
|
|
412
|
-
None,
|
|
413
|
-
self.record["val_loss"],
|
|
414
|
-
"Loss",
|
|
415
|
-
train_color="blue",
|
|
416
|
-
val_color="red"
|
|
417
|
-
)
|
|
418
|
-
else:
|
|
419
|
-
self._plot_curve(ax0,
|
|
420
|
-
self.x_epoch,
|
|
421
|
-
self.record["train_loss"],
|
|
422
|
-
self.record["val_loss"],
|
|
423
|
-
"Loss", train_color="blue",
|
|
424
|
-
val_color="red"
|
|
425
|
-
)
|
|
426
|
-
if has_psnr:
|
|
427
|
-
ax1 = fig.add_subplot(122)
|
|
428
|
-
self._plot_curve(ax1,
|
|
429
|
-
self.x_epoch,
|
|
430
|
-
self.record["train_psnr"],
|
|
431
|
-
self.record["val_psnr"],
|
|
432
|
-
"PSNR",
|
|
433
|
-
train_color="orange",
|
|
434
|
-
val_color="grey"
|
|
435
|
-
)
|
|
436
|
-
plt.tight_layout()
|
|
437
|
-
fig.savefig(
|
|
438
|
-
os.path.join(self.output_dir, f"train_{epoch}.jpg"),
|
|
439
|
-
dpi=300,
|
|
440
|
-
bbox_inches="tight",
|
|
441
|
-
pad_inches=0.1
|
|
442
|
-
)
|
|
443
|
-
plt.close(fig)
|
|
444
|
-
"""
|
|
445
|
-
vis3 = TrainingVis(root_path="./Rec3")
|
|
446
|
-
vis4 = TrainingVis(root_path="./Rec4")
|
|
447
|
-
vis5 = TrainingVis(root_path="./Rec5")
|
|
448
|
-
for epoch in range(t.epoch, t.epoch + args.num_epochs):
|
|
449
|
-
train_loss, train_psnr = t.train(epoch)
|
|
450
|
-
val_loss, val_psnr = t.val(epoch)
|
|
451
|
-
if (epoch + 1) % 3 == 0:
|
|
452
|
-
# t.draw_curve(fig, epoch, train_loss, train_psnr, val_loss, val_psnr)
|
|
453
|
-
vis.draw_curve(epoch, train_loss, train_psnr, val_loss, val_psnr)
|
|
454
|
-
vis1.draw_curve(epoch, train_loss,val_loss)
|
|
455
|
-
vis2.draw_curve(epoch, train_loss,val_loss,train_psnr,val_psnr)
|
|
456
|
-
vis3.draw_curve(epoch, train_loss,val_loss,train_psnr,val_psnr)
|
|
457
|
-
vis4.draw_curve(epoch, train_loss)
|
|
458
|
-
vis5.draw_curve(epoch, val_loss)
|
|
459
|
-
------------------------
|
|
460
|
-
from shancx.Plot import trainingVis
|
|
461
|
-
vis= trainingVis(root_path="./Rec3")
|
|
462
|
-
if (epoch + 1) % 3== 0:
|
|
463
|
-
vis.draw_curve(epoch, epoch_loss.detach().cpu().numpy(),epoch_val_loss.detach().cpu().numpy())
|
|
464
|
-
"""
|
|
465
|
-
|
|
466
|
-
class trainingVisplus:
|
|
467
|
-
def __init__(self, args=None, dataset_key_map=None, root_path="./"):
|
|
468
|
-
self.args = args
|
|
469
|
-
self.dataset_key_map = dataset_key_map
|
|
470
|
-
self.root_path = root_path
|
|
471
|
-
self.record = {
|
|
472
|
-
"train_loss": [],
|
|
473
|
-
"train_psnr": [],
|
|
474
|
-
"train_acc": [],
|
|
475
|
-
"val_loss": [],
|
|
476
|
-
"val_psnr": [],
|
|
477
|
-
"val_acc": [],
|
|
478
|
-
}
|
|
479
|
-
self.x_epoch = []
|
|
480
|
-
self.output_dir = self._get_output_dir()
|
|
481
|
-
|
|
482
|
-
def _get_output_dir(self):
|
|
483
|
-
"""生成输出目录路径并创建目录"""
|
|
484
|
-
output_dir = os.path.join(
|
|
485
|
-
self.root_path,
|
|
486
|
-
"Rec",
|
|
487
|
-
"trainvalViscure",
|
|
488
|
-
)
|
|
489
|
-
os.makedirs(output_dir, exist_ok=True)
|
|
490
|
-
return output_dir
|
|
491
|
-
|
|
492
|
-
def _plot_curve(
|
|
493
|
-
self,
|
|
494
|
-
ax,
|
|
495
|
-
x,
|
|
496
|
-
y_train,
|
|
497
|
-
y_val,
|
|
498
|
-
y_label,
|
|
499
|
-
train_color="blue",
|
|
500
|
-
val_color="red"
|
|
501
|
-
):
|
|
502
|
-
"""绘制单条曲线并优化坐标轴显示"""
|
|
503
|
-
if y_train is not None:
|
|
504
|
-
ax.plot(x, y_train,
|
|
505
|
-
marker='o',
|
|
506
|
-
linestyle='-',
|
|
507
|
-
color=train_color,
|
|
508
|
-
label="Train"
|
|
509
|
-
)
|
|
510
|
-
if y_val is not None:
|
|
511
|
-
ax.plot(x, y_val,
|
|
512
|
-
marker='o',
|
|
513
|
-
linestyle='-',
|
|
514
|
-
color=val_color,
|
|
515
|
-
label="Val"
|
|
516
|
-
)
|
|
517
|
-
# 设置坐标轴标签和格式
|
|
518
|
-
ax.set_xlabel("Epoch", fontsize=10)
|
|
519
|
-
ax.set_ylabel(y_label, fontsize=10)
|
|
520
|
-
ax.set_title(f"{y_label} Curve", fontsize=12)
|
|
521
|
-
# 配置x轴刻度(确保整数显示)
|
|
522
|
-
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
|
523
|
-
plt.setp(ax.get_xticklabels(), rotation=30, ha='right', fontsize=8)
|
|
524
|
-
# 配置y轴刻度
|
|
525
|
-
ax.yaxis.set_major_locator(MaxNLocator(nbins=6))
|
|
526
|
-
plt.setp(ax.get_yticklabels(), fontsize=8)
|
|
527
|
-
# 添加图例
|
|
528
|
-
ax.legend(loc="upper right", fontsize=8)
|
|
529
|
-
|
|
530
|
-
def draw_curve(
|
|
531
|
-
self, epoch=None,
|
|
532
|
-
train_loss=None,
|
|
533
|
-
train_psnr=None,
|
|
534
|
-
val_loss=None,
|
|
535
|
-
val_psnr=None,
|
|
536
|
-
train_acc=None,
|
|
537
|
-
val_acc=None
|
|
538
|
-
):
|
|
539
|
-
"""动态绘制训练曲线并根据数据存在性调整布局"""
|
|
540
|
-
# 更新训练记录
|
|
541
|
-
self.record["train_loss"].append(train_loss)
|
|
542
|
-
self.record["val_loss"].append(val_loss)
|
|
543
|
-
self.x_epoch.append(epoch)
|
|
544
|
-
|
|
545
|
-
# 有条件地更新PSNR和Acc记录
|
|
546
|
-
has_psnr = train_psnr is not None and val_psnr is not None
|
|
547
|
-
has_acc = train_acc is not None and val_acc is not None
|
|
548
|
-
|
|
549
|
-
if has_psnr:
|
|
550
|
-
self.record["train_psnr"].append(train_psnr)
|
|
551
|
-
self.record["val_psnr"].append(val_psnr)
|
|
552
|
-
else:
|
|
553
|
-
# 用None占位保持数据对齐
|
|
554
|
-
self.record["train_psnr"].append(None)
|
|
555
|
-
self.record["val_psnr"].append(None)
|
|
556
|
-
|
|
557
|
-
if has_acc:
|
|
558
|
-
self.record["train_acc"].append(train_acc)
|
|
559
|
-
self.record["val_acc"].append(val_acc)
|
|
560
|
-
else:
|
|
561
|
-
# 用None占位保持数据对齐
|
|
562
|
-
self.record["train_acc"].append(None)
|
|
563
|
-
self.record["val_acc"].append(None)
|
|
564
|
-
|
|
565
|
-
# 创建自适应布局的画布
|
|
566
|
-
num_plots = 1 + int(has_psnr) + int(has_acc)
|
|
567
|
-
fig_width = 4 * num_plots # 每个子图宽度为4
|
|
568
|
-
fig = plt.figure(figsize=(fig_width, 4.5))
|
|
569
|
-
plt.subplots_adjust(wspace=0.3)
|
|
570
|
-
|
|
571
|
-
# 绘制Loss曲线
|
|
572
|
-
ax0 = fig.add_subplot(1, num_plots, 1)
|
|
573
|
-
self._plot_curve(
|
|
574
|
-
ax0,
|
|
575
|
-
self.x_epoch,
|
|
576
|
-
self.record["train_loss"],
|
|
577
|
-
self.record["val_loss"],
|
|
578
|
-
"Loss",
|
|
579
|
-
train_color="blue",
|
|
580
|
-
val_color="red"
|
|
581
|
-
)
|
|
582
|
-
|
|
583
|
-
# 绘制PSNR曲线(如果存在)
|
|
584
|
-
if has_psnr:
|
|
585
|
-
ax1 = fig.add_subplot(1, num_plots, 2)
|
|
586
|
-
self._plot_curve(
|
|
587
|
-
ax1,
|
|
588
|
-
self.x_epoch,
|
|
589
|
-
self.record["train_psnr"],
|
|
590
|
-
self.record["val_psnr"],
|
|
591
|
-
"PSNR",
|
|
592
|
-
train_color="orange",
|
|
593
|
-
val_color="grey"
|
|
594
|
-
)
|
|
595
|
-
|
|
596
|
-
# 绘制Acc曲线(如果存在)
|
|
597
|
-
if has_acc:
|
|
598
|
-
ax2 = fig.add_subplot(1, num_plots, num_plots)
|
|
599
|
-
self._plot_curve(
|
|
600
|
-
ax2,
|
|
601
|
-
self.x_epoch,
|
|
602
|
-
self.record["train_acc"],
|
|
603
|
-
self.record["val_acc"],
|
|
604
|
-
"Accuracy",
|
|
605
|
-
train_color="grey",
|
|
606
|
-
val_color="purple"
|
|
607
|
-
)
|
|
608
|
-
|
|
609
|
-
# 优化布局并保存
|
|
610
|
-
plt.tight_layout()
|
|
611
|
-
fig.savefig(
|
|
612
|
-
os.path.join(self.output_dir, f"train_{epoch}.jpg"),
|
|
613
|
-
dpi=300,
|
|
614
|
-
bbox_inches="tight",
|
|
615
|
-
pad_inches=0.1
|
|
616
|
-
)
|
|
617
|
-
plt.close(fig)
|
|
618
|
-
"""
|
|
619
|
-
|
|
620
|
-
vis1= trainingVisplus(root_path="./Rec4")
|
|
621
|
-
for epoch in range(t.epoch, t.epoch + args.num_epochs):
|
|
622
|
-
train_loss, train_psnr = t.train(epoch)
|
|
623
|
-
val_loss, val_psnr = t.val(epoch)
|
|
624
|
-
if (epoch + 1) % 3 == 0:
|
|
625
|
-
vis1.draw_curve(epoch=epoch,
|
|
626
|
-
train_loss=epoch_loss.detach().cpu().numpy(),
|
|
627
|
-
val_loss=epoch_val_loss.detach().cpu().numpy(),
|
|
628
|
-
train_acc=epoch_accuracy,
|
|
629
|
-
val_acc=epoch_val_accuracy
|
|
630
|
-
)
|
|
631
|
-
|
|
632
|
-
"""
|
|
633
|
-
|
|
634
|
-
# @staticmethod
|
|
635
|
-
# def calculate_psnr(img1, img2):
|
|
636
|
-
# return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))
|
|
637
|
-
""" 使用方法
|
|
638
|
-
psnr += self.calculate_psnr(fake_img, label).item()
|
|
639
|
-
total += 1
|
|
640
|
-
mean_psnr = psnr / total
|
|
641
|
-
"""
|
|
642
|
-
|
|
395
|
+
drawpic_com(Data_con, save_dir="plotDrawpic_com", name=timeList )
|
|
396
|
+
"""
|
|
643
397
|
from hjnwtx.colormap import cmp_hjnwtx
|
|
644
398
|
import matplotlib.pyplot as plt
|
|
645
399
|
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
|
646
400
|
import numpy as np
|
|
647
401
|
import os
|
|
648
402
|
import datetime
|
|
649
|
-
def
|
|
650
|
-
"""计算色标范围"""
|
|
651
|
-
valid_data = data[~np.isnan(data)]
|
|
652
|
-
vmin = int(np.min(valid_data))
|
|
653
|
-
vmax = int(np.max(valid_data))
|
|
654
|
-
|
|
655
|
-
return vmin, vmax
|
|
656
|
-
def plot_grid_data(data, titles=None, saveDir="plots", name="temp",
|
|
657
|
-
cmap="viridis", vmin=None, vmax=None):
|
|
403
|
+
def plot_grid_data(data, titles=None, saveDir="plots", name="temp", cmap="summer",radarnmc=1):
|
|
658
404
|
if not isinstance(data, np.ndarray) or data.ndim != 3:
|
|
659
|
-
raise ValueError("
|
|
405
|
+
raise ValueError("The input data must be a three-dimensional NumPy array [num_images, height, width]")
|
|
660
406
|
num_images = data.shape[0]
|
|
661
407
|
titles = titles or [f"Data {i}" for i in range(num_images)]
|
|
662
|
-
# 计算色标范围
|
|
663
|
-
if vmin is None or vmax is None:
|
|
664
|
-
vmin, vmax = calculate_colorbar_range(data)
|
|
665
|
-
# 计算子图布局
|
|
666
408
|
ncols = int(np.ceil(np.sqrt(num_images)))
|
|
667
409
|
nrows = int(np.ceil(num_images / ncols))
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
410
|
+
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 3))
|
|
411
|
+
if num_images == 1:
|
|
412
|
+
axes = np.array([[axes]])
|
|
413
|
+
elif axes.ndim == 1:
|
|
414
|
+
axes = axes.reshape(1, -1)
|
|
415
|
+
axes_flat = axes.ravel()
|
|
671
416
|
for i in range(num_images):
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
ax = axes[i]
|
|
676
|
-
im = ax.imshow(data[i], cmap=cmp_hjnwtx["radar_nmc"], vmin=vmin, vmax=vmax) #cmp_hjnwtx["radar_nmc"]
|
|
417
|
+
ax = axes_flat[i]
|
|
418
|
+
if i >= num_images - radarnmc:
|
|
419
|
+
im = ax.imshow(data[i], cmap=cmp_hjnwtx["radar_nmc"])
|
|
677
420
|
else:
|
|
678
|
-
|
|
679
|
-
vmax = 300
|
|
680
|
-
ax = axes[i]
|
|
681
|
-
im = ax.imshow(data[i], cmap=cmap, vmin=vmin, vmax=vmax)
|
|
421
|
+
im = ax.imshow(data[i], cmap=cmap)
|
|
682
422
|
ax.set_title(titles[i])
|
|
683
|
-
ax.axis('off')
|
|
684
|
-
# 添加色标
|
|
423
|
+
ax.axis('off')
|
|
685
424
|
divider = make_axes_locatable(ax)
|
|
686
425
|
cax = divider.append_axes("right", size="5%", pad=0.05)
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
cbar.ax.tick_params(labelsize=6) # 调整刻度文字大小
|
|
691
|
-
# 隐藏空子图
|
|
692
|
-
for j in range(num_images, len(axes)):
|
|
693
|
-
axes[j].axis('off')
|
|
694
|
-
# 保存图像
|
|
426
|
+
plt.colorbar(im, cax=cax)
|
|
427
|
+
for j in range(num_images, len(axes_flat)):
|
|
428
|
+
axes_flat[j].axis('off')
|
|
695
429
|
plt.tight_layout()
|
|
696
430
|
os.makedirs(saveDir, exist_ok=True)
|
|
697
431
|
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
|
698
432
|
filename = f"{name}_{timestamp}.png"
|
|
699
433
|
plt.savefig(os.path.join(saveDir, filename), dpi=300)
|
|
700
434
|
plt.close()
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
组合数据并调用绘图函数
|
|
704
|
-
Args:
|
|
705
|
-
base_up: 上部数据数组 [shape_len, H, W]
|
|
706
|
-
base_down: 下部数据数组 [1, H, W]
|
|
707
|
-
shape_len: 上部数据数量
|
|
708
|
-
name: 输出文件名前缀
|
|
709
|
-
cmap: 颜色映射
|
|
710
|
-
"""
|
|
711
|
-
# 合并数据并生成标题
|
|
435
|
+
|
|
436
|
+
def plotTr(base_up, base_down, name="plotTr", saveDir="plotTr",cmap="summer",radarnmc=1):
|
|
712
437
|
data_all = np.concatenate([base_up, base_down], axis=0)
|
|
713
|
-
titles = [f"
|
|
714
|
-
# 调用绘图函数
|
|
438
|
+
titles = [f"Pic_{i}" for i in range(base_up.shape[0])] + [f"Pic_{i+1}" for i in range(base_down.shape[0])]
|
|
715
439
|
plot_grid_data(
|
|
716
440
|
data=data_all,
|
|
717
441
|
titles=titles,
|
|
718
442
|
name=name,
|
|
719
443
|
saveDir=saveDir,
|
|
720
|
-
cmap=cmap
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
444
|
+
cmap=cmap,
|
|
445
|
+
radarnmc = radarnmc
|
|
446
|
+
)
|
|
447
|
+
|
|
448
|
+
"""
|
|
449
|
+
if __name__ == "__main__":
|
|
450
|
+
base_up = np.random.rand(10, 50, 50) * 70
|
|
451
|
+
base_down = np.random.rand(1, 50, 50) * 70
|
|
452
|
+
plotTr(base_up, base_down, name="radar_plot",cmap="summer",radarnmc=1) # radar_mask.detach().cpu().numpy() tensor转numpy
|
|
453
|
+
"""
|
|
728
454
|
|
|
729
455
|
import numpy as np
|
|
730
456
|
import matplotlib.pyplot as plt
|