shancx 1.9.33.109__py3-none-any.whl → 1.9.33.218__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. shancx/{Dsalgor → Algo}/__init__.py +37 -1
  2. shancx/Calmetrics/__init__.py +78 -9
  3. shancx/Calmetrics/calmetrics.py +14 -0
  4. shancx/Calmetrics/rmseR2score.py +14 -3
  5. shancx/{Command.py → Cmd.py} +20 -15
  6. shancx/Config_.py +26 -0
  7. shancx/Df/__init__.py +11 -0
  8. shancx/Df/tool.py +0 -1
  9. shancx/NN/__init__.py +200 -11
  10. shancx/{path.py → Path1.py} +2 -3
  11. shancx/Plot/__init__.py +129 -403
  12. shancx/Plot/draw_day_CR_PNG.py +4 -21
  13. shancx/Plot/exam.py +116 -0
  14. shancx/Plot/plotGlobal.py +325 -0
  15. shancx/Plot/radarNmc.py +1 -48
  16. shancx/Plot/single_china_map.py +1 -1
  17. shancx/Point.py +46 -0
  18. shancx/QC.py +223 -0
  19. shancx/Read.py +17 -10
  20. shancx/Resize.py +79 -0
  21. shancx/SN/__init__.py +8 -1
  22. shancx/Time/timeCycle.py +97 -23
  23. shancx/Train/makelist.py +161 -155
  24. shancx/__init__.py +79 -232
  25. shancx/bak.py +78 -53
  26. shancx/geosProj.py +2 -2
  27. shancx/wait.py +35 -1
  28. {shancx-1.9.33.109.dist-info → shancx-1.9.33.218.dist-info}/METADATA +12 -4
  29. shancx-1.9.33.218.dist-info/RECORD +91 -0
  30. {shancx-1.9.33.109.dist-info → shancx-1.9.33.218.dist-info}/WHEEL +1 -1
  31. shancx/Plot/Mip.py +0 -42
  32. shancx/Plot/border.py +0 -44
  33. shancx/Plot/draw_day_CR_PNGUS.py +0 -206
  34. shancx/Plot/draw_day_CR_SVG.py +0 -275
  35. shancx/Plot/draw_day_pre_PNGUS.py +0 -205
  36. shancx/Plot/radar_nmc_china_map_compare1.py +0 -50
  37. shancx/makenetCDFN.py +0 -42
  38. shancx-1.9.33.109.dist-info/RECORD +0 -91
  39. /shancx/{3DJU → 3D}/__init__.py +0 -0
  40. /shancx/{Dsalgor → Algo}/Class.py +0 -0
  41. /shancx/{Dsalgor → Algo}/CudaPrefetcher1.py +0 -0
  42. /shancx/{Dsalgor → Algo}/Fake_image.py +0 -0
  43. /shancx/{Dsalgor → Algo}/Hsml.py +0 -0
  44. /shancx/{Dsalgor → Algo}/L2Loss.py +0 -0
  45. /shancx/{Dsalgor → Algo}/MetricTracker.py +0 -0
  46. /shancx/{Dsalgor → Algo}/Normalize.py +0 -0
  47. /shancx/{Dsalgor → Algo}/OptimizerWScheduler.py +0 -0
  48. /shancx/{Dsalgor → Algo}/Rmageresize.py +0 -0
  49. /shancx/{Dsalgor → Algo}/Savemodel.py +0 -0
  50. /shancx/{Dsalgor → Algo}/SmoothL1_losses.py +0 -0
  51. /shancx/{Dsalgor → Algo}/Tqdm.py +0 -0
  52. /shancx/{Dsalgor → Algo}/checknan.py +0 -0
  53. /shancx/{Dsalgor → Algo}/dsalgor.py +0 -0
  54. /shancx/{Dsalgor → Algo}/iouJU.py +0 -0
  55. /shancx/{Dsalgor → Algo}/mask.py +0 -0
  56. /shancx/{Dsalgor → Algo}/psnr.py +0 -0
  57. /shancx/{Dsalgor → Algo}/ssim.py +0 -0
  58. /shancx/{Dsalgor → Algo}/structural_similarity.py +0 -0
  59. /shancx/{Dsalgor → Algo}/tool.py +0 -0
  60. /shancx/Calmetrics/{matrixLib.py → calmetricsmatrixLib.py} +0 -0
  61. /shancx/{Diffmodel → Diffm}/Psamples.py +0 -0
  62. /shancx/{Diffmodel → Diffm}/__init__.py +0 -0
  63. /shancx/{Diffmodel → Diffm}/test.py +0 -0
  64. /shancx/{Board → tensBoard}/__init__.py +0 -0
  65. {shancx-1.9.33.109.dist-info → shancx-1.9.33.218.dist-info}/top_level.txt +0 -0
shancx/Plot/__init__.py CHANGED
@@ -21,9 +21,9 @@ def plotGrey(img,name="plotGrey", saveDir="plotGrey",cmap='gray', title='Image')
21
21
  import matplotlib.pyplot as plt
22
22
  from shancx import crDir
23
23
  import datetime
24
- def plotMat(matrix,name='plotMat',saveDir="plotMat",title='Matrix Plot', xlabel='X-axis', ylabel='Y-axis', color_label='Value', cmap='viridis'):
24
+ def plotMat(matrix,name='plotMat',saveDir="plotMat",title='Matrix Plot', xlabel='X-axis', ylabel='Y-axis', color_label='Value', cmap='viridis',aspect="equal"): #aspect='auto'
25
25
  now_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
26
- plt.imshow(matrix, cmap=cmap, origin='upper', aspect='auto')
26
+ plt.imshow(matrix, cmap=cmap, origin='upper', aspect=f'{aspect}')
27
27
  plt.colorbar(label=color_label)
28
28
  plt.title(title)
29
29
  plt.xlabel(xlabel)
@@ -34,6 +34,35 @@ def plotMat(matrix,name='plotMat',saveDir="plotMat",title='Matrix Plot', xlabel=
34
34
  plt.savefig(outpath)
35
35
  plt.close()
36
36
 
37
+ import matplotlib.pyplot as plt
38
+ from shancx import crDir
39
+ import datetime
40
+ def plotMatplus(matrix, name='plotMat', saveDir="plotMat", title='Matrix Plot',
41
+ xlabel='Longitude', ylabel='Latitude', color_label='Value',
42
+ cmap='viridis', extent=None):
43
+ """
44
+ extent: [lon_min, lon_max, lat_min, lat_max]
45
+ """
46
+ now_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
47
+ plt.imshow(matrix, cmap=cmap, origin='upper', aspect='auto', extent=extent)
48
+ plt.colorbar(label=color_label)
49
+ plt.title(title)
50
+
51
+ # 添加度符号和方向标识
52
+ plt.xlabel(f'{xlabel} (°E)') # 东经
53
+ plt.ylabel(f'{ylabel} (°N)') # 北纬
54
+
55
+ plt.tight_layout()
56
+ outpath = f'./{saveDir}/{name}_{now_str}.png' if name=="plotMat" else f"./{saveDir}/{name}.png"
57
+ crDir(outpath)
58
+ plt.savefig(outpath)
59
+ plt.close()
60
+ """
61
+ latlon = [10.0, 37.0, 105.0, 125.0]
62
+ latmin, latmax, lonmin, lonmax = latlon
63
+ plotMatplus(data,extent=[lon_min, lon_max, lat_min, lat_max])
64
+ """
65
+
37
66
  import datetime
38
67
  from hjnwtx.colormap import cmp_hjnwtx
39
68
  from shancx import crDir
@@ -45,7 +74,6 @@ def plotRadar(array_dt,name="plotRadar", saveDir="plotRadar",ty="CR"):
45
74
  if len(array_dt.shape) == 2 and ty == "pre":
46
75
  fig, ax = plt.subplots()
47
76
  im = ax.imshow(array_dt, vmin=0, vmax=10, cmap=cmp_hjnwtx["pre_tqw"])
48
- # 创建与图像高度一致的colorbar
49
77
  divider = make_axes_locatable(ax)
50
78
  cax = divider.append_axes("right", size="5%", pad=0.05)
51
79
  plt.colorbar(im, cax=cax)
@@ -122,6 +150,27 @@ def plotScatter(df1,saveDir="plotScatter"):
122
150
  plt.savefig(f"./{saveDir}/plotScatter_{now_str}.png", dpi=300, bbox_inches="tight")
123
151
  plt.close()
124
152
 
153
+ import matplotlib.pyplot as plt
154
+ import os
155
+ def plotScatter1(true,pre,saveDir="plotScatter"):
156
+ now_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
157
+ plt.figure(figsize=(10, 8))
158
+ plt.scatter(
159
+ true,
160
+ pre,
161
+ s=25,
162
+ alpha=0.6,
163
+ edgecolor="black",
164
+ linewidth=0.5
165
+ )
166
+ plt.title("Scatter Plot of Ture Pre", fontsize=14)
167
+ plt.xlabel("Longitude", fontsize=12)
168
+ plt.ylabel("Latitude", fontsize=12)
169
+ plt.tight_layout()
170
+ os.makedirs(saveDir, exist_ok=True)
171
+ plt.savefig(f"./{saveDir}/plotScatter1_{now_str}.png", dpi=300, bbox_inches="tight")
172
+ plt.close()
173
+
125
174
  import numpy as np
126
175
  import matplotlib.pyplot as plt
127
176
  from mpl_toolkits.axes_grid1 import make_axes_locatable
@@ -148,7 +197,7 @@ def plotVal( epoch=0,*datasets, title=["input","prediction","truth"], saveDir="p
148
197
  plt.savefig(filename)
149
198
  plt.close(fig)
150
199
 
151
- """ 使用方法
200
+ """
152
201
  if total >= 3:
153
202
  break
154
203
  if epoch % 2 == 0:
@@ -200,11 +249,11 @@ def plotValplus(epoch=0, *datasets, title=["input", "prediction", "truth"], save
200
249
  filename = f"{saveDir}/epoch_{epoch}.png"
201
250
  plt.savefig(filename)
202
251
  plt.close(fig)
203
- """ 使用方法
252
+ """
204
253
  if total >= 3:
205
254
  break
206
255
  if epoch % 2 == 0:
207
- plotVal(epoch,
256
+ plotValplus(epoch,
208
257
  data[0][0].detach().cpu().numpy().squeeze(), # 使用 detach()
209
258
  output[0][0].detach().cpu().numpy().squeeze(), # 使用 detach()
210
259
  label[0][0].detach().cpu().numpy().squeeze(), # 使用 detach()
@@ -213,30 +262,6 @@ def plotValplus(epoch=0, *datasets, title=["input", "prediction", "truth"], save
213
262
  )
214
263
  """
215
264
 
216
- import numpy as np
217
- import matplotlib
218
- matplotlib.use("Agg")
219
- import matplotlib.pyplot as plt
220
- from mpl_toolkits.axes_grid1 import make_axes_locatable
221
- from hjnwtx.colormap import cmp_hjnwtx
222
- from shancx import crDir
223
- import os
224
- def plot_dataset(ax, data, title, cmap, vmin, vmax):
225
- """
226
- Helper function to plot a single dataset on a given axis.
227
- """
228
- im = ax.matshow(data, cmap=cmap, vmin=vmin, vmax=vmax)
229
- # Remove axis ticks
230
- ax.set_xticks([])
231
- ax.set_yticks([])
232
- # Add colorbar
233
- divider = make_axes_locatable(ax)
234
- cax = divider.append_axes("right", size="5%", pad=0.05)
235
- cbar = plt.colorbar(im, cax=cax, ticks=np.linspace(vmin, vmax, 15))
236
- cbar.set_ticks(np.linspace(vmin, vmax, 15))
237
- # Set title
238
- ax.set_title(title)
239
- return im
240
265
 
241
266
  def plotValplus1(epoch=0, *datasets, title=["input", "prediction", "truth"], saveDir="plotValplus", cmap='summer'):
242
267
  """
@@ -270,6 +295,43 @@ def plotValplus1(epoch=0, *datasets, title=["input", "prediction", "truth"], sav
270
295
  plt.savefig(filename)
271
296
  plt.close(fig)
272
297
 
298
+ """
299
+ if total >= 3:
300
+ break
301
+ if epoch % 2 == 0:
302
+ plotValplus1(epoch,
303
+ data[0][0].detach().cpu().numpy().squeeze(), # 使用 detach()
304
+ output[0][0].detach().cpu().numpy().squeeze(), # 使用 detach()
305
+ label[0][0].detach().cpu().numpy().squeeze(), # 使用 detach()
306
+ title=["input", "prediction", "groundtruth"],
307
+ saveDir="plot_train_dir"
308
+ )
309
+ """
310
+
311
+ import numpy as np
312
+ import matplotlib
313
+ matplotlib.use("Agg")
314
+ import matplotlib.pyplot as plt
315
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
316
+ from hjnwtx.colormap import cmp_hjnwtx
317
+ from shancx import crDir
318
+ import os
319
+ def plot_dataset(ax, data, title, cmap, vmin, vmax): #Cited methods
320
+ """
321
+ Helper function to plot a single dataset on a given axis.
322
+ """
323
+ im = ax.matshow(data, cmap=cmap, vmin=vmin, vmax=vmax)
324
+ # Remove axis ticks
325
+ ax.set_xticks([])
326
+ ax.set_yticks([])
327
+ # Add colorbar
328
+ divider = make_axes_locatable(ax)
329
+ cax = divider.append_axes("right", size="5%", pad=0.05)
330
+ cbar = plt.colorbar(im, cax=cax, ticks=np.linspace(vmin, vmax, 15))
331
+ cbar.set_ticks(np.linspace(vmin, vmax, 15))
332
+ # Set title
333
+ ax.set_title(title)
334
+ return im
273
335
 
274
336
  import matplotlib.pyplot as plt
275
337
  from mpl_toolkits.axes_grid1 import make_axes_locatable
@@ -285,27 +347,24 @@ def calculate_colorbar_range(data):
285
347
  vmin = int(np.nanmin(data))
286
348
  vmax = int(np.nanmax(data))
287
349
  return vmin, vmax
288
- def plot_grid_data(data, titles=None, save_dir="plots", name="temp",
289
- cmap="viridis", vmin=None, vmax=None):
350
+ def plotgriddata(data, titles=None,name="temp", save_dir="plots",
351
+ cmap="viridis", vmin=None, vmax=None): #Cited methods
290
352
  if not isinstance(data, np.ndarray) or data.ndim != 3:
291
- raise ValueError("输入数据必须为三维numpy数组 [num_images, height, width]")
353
+ raise ValueError("The input data must be a three-dimensional NumPy array [num_images, height, width]")
292
354
  num_images = data.shape[0]
293
355
  titles = titles or [f"Data {i}" for i in range(num_images)]
294
- if vmin is None or vmax is None:
295
- vmin, vmax = calculate_colorbar_range(data)
296
356
  ncols = int(np.ceil(np.sqrt(num_images)))
297
357
  nrows = int(np.ceil(num_images / ncols))
298
358
  fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 3, nrows * 3))
299
359
  axes = axes.ravel()
300
360
  for i in range(num_images):
301
361
  ax = axes[i]
302
- im = ax.imshow(data[i], cmap=cmap, vmin=vmin, vmax=vmax)
362
+ im = ax.imshow(data[i], cmap=cmap)
303
363
  ax.set_title(titles[i])
304
364
  ax.axis('off')
305
365
  divider = make_axes_locatable(ax)
306
366
  cax = divider.append_axes("right", size="5%", pad=0.05)
307
367
  cbar = plt.colorbar(im, cax=cax,
308
- ticks=np.linspace(vmin, vmax, 15),
309
368
  format='%.1f')
310
369
  cbar.ax.tick_params(labelsize=6)
311
370
  for j in range(num_images, len(axes)):
@@ -316,7 +375,7 @@ def plot_grid_data(data, titles=None, save_dir="plots", name="temp",
316
375
  filename = f"{name}_{timestamp}.png"
317
376
  plt.savefig(os.path.join(save_dir, filename), dpi=300)
318
377
  plt.close()
319
- def drawpic_com(basedata, save_dir="plots", name="temp", cmap="summer"):
378
+ def plotDrawpic(basedata, save_dir="plotDrawpic_com", name="temp", cmap="summer"):
320
379
  data_all = basedata[:,::2,::2]
321
380
  if isinstance(name, str):
322
381
  print("name str")
@@ -324,7 +383,7 @@ def drawpic_com(basedata, save_dir="plots", name="temp", cmap="summer"):
324
383
  else:
325
384
  titles = [f"{i}" for i in name.strftime("%Y%m%d%H%M%S")]
326
385
  name = name.strftime("%Y%m%d%H%M%S")[0]
327
- plot_grid_data(
386
+ plotgriddata(
328
387
  data=data_all,
329
388
  titles=titles,
330
389
  name=name,
@@ -333,398 +392,65 @@ def drawpic_com(basedata, save_dir="plots", name="temp", cmap="summer"):
333
392
  )
334
393
 
335
394
  """
336
- drawpic_com(Data_con, save_dir="plots_H9", name=timeList )
337
- """
338
-
339
- import os
340
- import matplotlib.pyplot as plt
341
- from matplotlib.ticker import MaxNLocator
342
- class trainingVis:
343
- def __init__(self, args=None, dataset_key_map=None, root_path="./"):
344
- self.args = args
345
- self.dataset_key_map = dataset_key_map
346
- self.root_path = root_path
347
- self.record = {
348
- "train_loss": [],
349
- "train_psnr": [],
350
- "val_loss": [],
351
- "val_psnr": [],
352
- }
353
- self.x_epoch = []
354
- self.output_dir = self._get_output_dir()
355
- def _get_output_dir(self):
356
- """生成输出目录路径并创建目录"""
357
- output_dir = os.path.join(
358
- self.root_path,
359
- "Rec",
360
- # "weights_dir",
361
- # self.dataset_key_map[self.args.dataset_key],
362
- "trainvalViscure",
363
- )
364
- os.makedirs(output_dir, exist_ok=True)
365
- return output_dir
366
- def _plot_curve(self, ax, x, y_train, y_val, y_label, train_color="blue", val_color="red"):
367
- """绘制单条曲线并优化坐标轴显示"""
368
- if y_train is not None:
369
- ax.plot(x, y_train, marker='o', linestyle='-', color=train_color, label="Train")
370
- if y_val is not None:
371
- ax.plot(x, y_val, marker='o', linestyle='-', color=val_color, label="Val")
372
- # 设置坐标轴标签和格式
373
- ax.set_xlabel("Epoch", fontsize=10)
374
- ax.set_ylabel(y_label, fontsize=10)
375
- ax.set_title(f"{y_label} Curve", fontsize=12)
376
- # 配置x轴刻度(确保整数显示)
377
- ax.xaxis.set_major_locator(MaxNLocator(integer=True))
378
- plt.setp(ax.get_xticklabels(), rotation=30, ha='right', fontsize=8)
379
- # 配置y轴刻度
380
- ax.yaxis.set_major_locator(MaxNLocator(nbins=6))
381
- plt.setp(ax.get_yticklabels(), fontsize=8)
382
- # 添加图例
383
- ax.legend(loc="upper right", fontsize=8)
384
- def draw_curve(self, epoch=None, train_loss=None, train_psnr=None, val_loss=None, val_psnr=None):
385
- """动态绘制训练曲线并根据数据存在性调整布局"""
386
- self.record["train_loss"].append(train_loss)
387
- self.record["val_loss"].append(val_loss)
388
- self.x_epoch.append(epoch)
389
- has_psnr = train_psnr is not None and val_psnr is not None
390
- if has_psnr:
391
- self.record["train_psnr"].append(train_psnr)
392
- self.record["val_psnr"].append(val_psnr)
393
- else:
394
- self.record["train_psnr"].append(None)
395
- self.record["val_psnr"].append(None)
396
- fig = plt.figure(figsize=(10, 4.5) if has_psnr else (6, 4.5))
397
- plt.subplots_adjust(wspace=0.3 if has_psnr else 0)
398
- ax0 = fig.add_subplot(111 if not has_psnr else 121)
399
- # 如果只有train_loss数据存在
400
- if train_loss is not None and val_loss is None:
401
- self._plot_curve(ax0
402
- ,self.x_epoch
403
- ,self.record["train_loss"]
404
- ,None
405
- ,"Loss"
406
- ,train_color="blue"
407
- ,val_color="red"
408
- )
409
- elif val_loss is not None and train_loss is None:
410
- self._plot_curve(ax0,
411
- self.x_epoch,
412
- None,
413
- self.record["val_loss"],
414
- "Loss",
415
- train_color="blue",
416
- val_color="red"
417
- )
418
- else:
419
- self._plot_curve(ax0,
420
- self.x_epoch,
421
- self.record["train_loss"],
422
- self.record["val_loss"],
423
- "Loss", train_color="blue",
424
- val_color="red"
425
- )
426
- if has_psnr:
427
- ax1 = fig.add_subplot(122)
428
- self._plot_curve(ax1,
429
- self.x_epoch,
430
- self.record["train_psnr"],
431
- self.record["val_psnr"],
432
- "PSNR",
433
- train_color="orange",
434
- val_color="grey"
435
- )
436
- plt.tight_layout()
437
- fig.savefig(
438
- os.path.join(self.output_dir, f"train_{epoch}.jpg"),
439
- dpi=300,
440
- bbox_inches="tight",
441
- pad_inches=0.1
442
- )
443
- plt.close(fig)
444
- """
445
- vis3 = TrainingVis(root_path="./Rec3")
446
- vis4 = TrainingVis(root_path="./Rec4")
447
- vis5 = TrainingVis(root_path="./Rec5")
448
- for epoch in range(t.epoch, t.epoch + args.num_epochs):
449
- train_loss, train_psnr = t.train(epoch)
450
- val_loss, val_psnr = t.val(epoch)
451
- if (epoch + 1) % 3 == 0:
452
- # t.draw_curve(fig, epoch, train_loss, train_psnr, val_loss, val_psnr)
453
- vis.draw_curve(epoch, train_loss, train_psnr, val_loss, val_psnr)
454
- vis1.draw_curve(epoch, train_loss,val_loss)
455
- vis2.draw_curve(epoch, train_loss,val_loss,train_psnr,val_psnr)
456
- vis3.draw_curve(epoch, train_loss,val_loss,train_psnr,val_psnr)
457
- vis4.draw_curve(epoch, train_loss)
458
- vis5.draw_curve(epoch, val_loss)
459
- ------------------------
460
- from shancx.Plot import trainingVis
461
- vis= trainingVis(root_path="./Rec3")
462
- if (epoch + 1) % 3== 0:
463
- vis.draw_curve(epoch, epoch_loss.detach().cpu().numpy(),epoch_val_loss.detach().cpu().numpy())
464
- """
465
-
466
- class trainingVisplus:
467
- def __init__(self, args=None, dataset_key_map=None, root_path="./"):
468
- self.args = args
469
- self.dataset_key_map = dataset_key_map
470
- self.root_path = root_path
471
- self.record = {
472
- "train_loss": [],
473
- "train_psnr": [],
474
- "train_acc": [],
475
- "val_loss": [],
476
- "val_psnr": [],
477
- "val_acc": [],
478
- }
479
- self.x_epoch = []
480
- self.output_dir = self._get_output_dir()
481
-
482
- def _get_output_dir(self):
483
- """生成输出目录路径并创建目录"""
484
- output_dir = os.path.join(
485
- self.root_path,
486
- "Rec",
487
- "trainvalViscure",
488
- )
489
- os.makedirs(output_dir, exist_ok=True)
490
- return output_dir
491
-
492
- def _plot_curve(
493
- self,
494
- ax,
495
- x,
496
- y_train,
497
- y_val,
498
- y_label,
499
- train_color="blue",
500
- val_color="red"
501
- ):
502
- """绘制单条曲线并优化坐标轴显示"""
503
- if y_train is not None:
504
- ax.plot(x, y_train,
505
- marker='o',
506
- linestyle='-',
507
- color=train_color,
508
- label="Train"
509
- )
510
- if y_val is not None:
511
- ax.plot(x, y_val,
512
- marker='o',
513
- linestyle='-',
514
- color=val_color,
515
- label="Val"
516
- )
517
- # 设置坐标轴标签和格式
518
- ax.set_xlabel("Epoch", fontsize=10)
519
- ax.set_ylabel(y_label, fontsize=10)
520
- ax.set_title(f"{y_label} Curve", fontsize=12)
521
- # 配置x轴刻度(确保整数显示)
522
- ax.xaxis.set_major_locator(MaxNLocator(integer=True))
523
- plt.setp(ax.get_xticklabels(), rotation=30, ha='right', fontsize=8)
524
- # 配置y轴刻度
525
- ax.yaxis.set_major_locator(MaxNLocator(nbins=6))
526
- plt.setp(ax.get_yticklabels(), fontsize=8)
527
- # 添加图例
528
- ax.legend(loc="upper right", fontsize=8)
529
-
530
- def draw_curve(
531
- self, epoch=None,
532
- train_loss=None,
533
- train_psnr=None,
534
- val_loss=None,
535
- val_psnr=None,
536
- train_acc=None,
537
- val_acc=None
538
- ):
539
- """动态绘制训练曲线并根据数据存在性调整布局"""
540
- # 更新训练记录
541
- self.record["train_loss"].append(train_loss)
542
- self.record["val_loss"].append(val_loss)
543
- self.x_epoch.append(epoch)
544
-
545
- # 有条件地更新PSNR和Acc记录
546
- has_psnr = train_psnr is not None and val_psnr is not None
547
- has_acc = train_acc is not None and val_acc is not None
548
-
549
- if has_psnr:
550
- self.record["train_psnr"].append(train_psnr)
551
- self.record["val_psnr"].append(val_psnr)
552
- else:
553
- # 用None占位保持数据对齐
554
- self.record["train_psnr"].append(None)
555
- self.record["val_psnr"].append(None)
556
-
557
- if has_acc:
558
- self.record["train_acc"].append(train_acc)
559
- self.record["val_acc"].append(val_acc)
560
- else:
561
- # 用None占位保持数据对齐
562
- self.record["train_acc"].append(None)
563
- self.record["val_acc"].append(None)
564
-
565
- # 创建自适应布局的画布
566
- num_plots = 1 + int(has_psnr) + int(has_acc)
567
- fig_width = 4 * num_plots # 每个子图宽度为4
568
- fig = plt.figure(figsize=(fig_width, 4.5))
569
- plt.subplots_adjust(wspace=0.3)
570
-
571
- # 绘制Loss曲线
572
- ax0 = fig.add_subplot(1, num_plots, 1)
573
- self._plot_curve(
574
- ax0,
575
- self.x_epoch,
576
- self.record["train_loss"],
577
- self.record["val_loss"],
578
- "Loss",
579
- train_color="blue",
580
- val_color="red"
581
- )
582
-
583
- # 绘制PSNR曲线(如果存在)
584
- if has_psnr:
585
- ax1 = fig.add_subplot(1, num_plots, 2)
586
- self._plot_curve(
587
- ax1,
588
- self.x_epoch,
589
- self.record["train_psnr"],
590
- self.record["val_psnr"],
591
- "PSNR",
592
- train_color="orange",
593
- val_color="grey"
594
- )
595
-
596
- # 绘制Acc曲线(如果存在)
597
- if has_acc:
598
- ax2 = fig.add_subplot(1, num_plots, num_plots)
599
- self._plot_curve(
600
- ax2,
601
- self.x_epoch,
602
- self.record["train_acc"],
603
- self.record["val_acc"],
604
- "Accuracy",
605
- train_color="grey",
606
- val_color="purple"
607
- )
608
-
609
- # 优化布局并保存
610
- plt.tight_layout()
611
- fig.savefig(
612
- os.path.join(self.output_dir, f"train_{epoch}.jpg"),
613
- dpi=300,
614
- bbox_inches="tight",
615
- pad_inches=0.1
616
- )
617
- plt.close(fig)
618
- """
619
-
620
- vis1= trainingVisplus(root_path="./Rec4")
621
- for epoch in range(t.epoch, t.epoch + args.num_epochs):
622
- train_loss, train_psnr = t.train(epoch)
623
- val_loss, val_psnr = t.val(epoch)
624
- if (epoch + 1) % 3 == 0:
625
- vis1.draw_curve(epoch=epoch,
626
- train_loss=epoch_loss.detach().cpu().numpy(),
627
- val_loss=epoch_val_loss.detach().cpu().numpy(),
628
- train_acc=epoch_accuracy,
629
- val_acc=epoch_val_accuracy
630
- )
631
-
632
- """
633
-
634
- # @staticmethod
635
- # def calculate_psnr(img1, img2):
636
- # return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))
637
- """ 使用方法
638
- psnr += self.calculate_psnr(fake_img, label).item()
639
- total += 1
640
- mean_psnr = psnr / total
641
- """
642
-
395
+ drawpic_com(Data_con, save_dir="plotDrawpic_com", name=timeList )
396
+ """
643
397
  from hjnwtx.colormap import cmp_hjnwtx
644
398
  import matplotlib.pyplot as plt
645
399
  from mpl_toolkits.axes_grid1 import make_axes_locatable
646
400
  import numpy as np
647
401
  import os
648
402
  import datetime
649
- def calculate_colorbar_range(data):
650
- """计算色标范围"""
651
- valid_data = data[~np.isnan(data)]
652
- vmin = int(np.min(valid_data))
653
- vmax = int(np.max(valid_data))
654
-
655
- return vmin, vmax
656
- def plot_grid_data(data, titles=None, saveDir="plots", name="temp",
657
- cmap="viridis", vmin=None, vmax=None):
403
+ def plot_grid_data(data, titles=None, saveDir="plots", name="temp", cmap="summer",radarnmc=1):
658
404
  if not isinstance(data, np.ndarray) or data.ndim != 3:
659
- raise ValueError("输入数据必须为三维numpy数组 [num_images, height, width]")
405
+ raise ValueError("The input data must be a three-dimensional NumPy array [num_images, height, width]")
660
406
  num_images = data.shape[0]
661
407
  titles = titles or [f"Data {i}" for i in range(num_images)]
662
- # 计算色标范围
663
- if vmin is None or vmax is None:
664
- vmin, vmax = calculate_colorbar_range(data)
665
- # 计算子图布局
666
408
  ncols = int(np.ceil(np.sqrt(num_images)))
667
409
  nrows = int(np.ceil(num_images / ncols))
668
- # 创建子图
669
- fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 3, nrows * 3))
670
- axes = axes.ravel()
410
+ fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 3))
411
+ if num_images == 1:
412
+ axes = np.array([[axes]])
413
+ elif axes.ndim == 1:
414
+ axes = axes.reshape(1, -1)
415
+ axes_flat = axes.ravel()
671
416
  for i in range(num_images):
672
- if i==int(num_images-1):
673
- vmin = 0
674
- vmax = 70
675
- ax = axes[i]
676
- im = ax.imshow(data[i], cmap=cmp_hjnwtx["radar_nmc"], vmin=vmin, vmax=vmax) #cmp_hjnwtx["radar_nmc"]
417
+ ax = axes_flat[i]
418
+ if i >= num_images - radarnmc:
419
+ im = ax.imshow(data[i], cmap=cmp_hjnwtx["radar_nmc"])
677
420
  else:
678
- vmin = 190
679
- vmax = 300
680
- ax = axes[i]
681
- im = ax.imshow(data[i], cmap=cmap, vmin=vmin, vmax=vmax)
421
+ im = ax.imshow(data[i], cmap=cmap)
682
422
  ax.set_title(titles[i])
683
- ax.axis('off')
684
- # 添加色标
423
+ ax.axis('off')
685
424
  divider = make_axes_locatable(ax)
686
425
  cax = divider.append_axes("right", size="5%", pad=0.05)
687
- cbar = plt.colorbar(im, cax=cax,
688
- ticks=np.linspace(vmin, vmax, 15), # 增加刻度密度
689
- format='%.1f') # 设置数值格式
690
- cbar.ax.tick_params(labelsize=6) # 调整刻度文字大小
691
- # 隐藏空子图
692
- for j in range(num_images, len(axes)):
693
- axes[j].axis('off')
694
- # 保存图像
426
+ plt.colorbar(im, cax=cax)
427
+ for j in range(num_images, len(axes_flat)):
428
+ axes_flat[j].axis('off')
695
429
  plt.tight_layout()
696
430
  os.makedirs(saveDir, exist_ok=True)
697
431
  timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
698
432
  filename = f"{name}_{timestamp}.png"
699
433
  plt.savefig(os.path.join(saveDir, filename), dpi=300)
700
434
  plt.close()
701
- def plotTr(base_up, base_down, name="plotTr", saveDir="plotTr", cmap="summer"): #viridis
702
- """
703
- 组合数据并调用绘图函数
704
- Args:
705
- base_up: 上部数据数组 [shape_len, H, W]
706
- base_down: 下部数据数组 [1, H, W]
707
- shape_len: 上部数据数量
708
- name: 输出文件名前缀
709
- cmap: 颜色映射
710
- """
711
- # 合并数据并生成标题
435
+
436
+ def plotTr(base_up, base_down, name="plotTr", saveDir="plotTr",cmap="summer",radarnmc=1):
712
437
  data_all = np.concatenate([base_up, base_down], axis=0)
713
- titles = [f"B_{i}" for i in range(base_up.shape[0])] + [f"radar_{i+1}" for i in range(base_down.shape[0])]
714
- # 调用绘图函数
438
+ titles = [f"Pic_{i}" for i in range(base_up.shape[0])] + [f"Pic_{i+1}" for i in range(base_down.shape[0])]
715
439
  plot_grid_data(
716
440
  data=data_all,
717
441
  titles=titles,
718
442
  name=name,
719
443
  saveDir=saveDir,
720
- cmap=cmap
721
- )
722
- """
723
- if __name__ == "__main__":
724
- base_up = np.random.rand(10, 50, 50) * 70
725
- base_down = np.random.rand(1, 50, 50) * 70
726
- plotTr(base_up, base_down, name="radar_plot") # radar_mask.detach().cpu().numpy() tensor转numpy
727
- """
444
+ cmap=cmap,
445
+ radarnmc = radarnmc
446
+ )
447
+
448
+ """
449
+ if __name__ == "__main__":
450
+ base_up = np.random.rand(10, 50, 50) * 70
451
+ base_down = np.random.rand(1, 50, 50) * 70
452
+ plotTr(base_up, base_down, name="radar_plot",cmap="summer",radarnmc=1) # radar_mask.detach().cpu().numpy() tensor转numpy
453
+ """
728
454
 
729
455
  import numpy as np
730
456
  import matplotlib.pyplot as plt