yms-kan 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yms_kan/plotting.py +468 -0
- yms_kan/train_eval_utils.py +7 -4
- yms_kan/version.py +1 -1
- {yms_kan-0.0.5.dist-info → yms_kan-0.0.6.dist-info}/METADATA +1 -1
- {yms_kan-0.0.5.dist-info → yms_kan-0.0.6.dist-info}/RECORD +8 -7
- {yms_kan-0.0.5.dist-info → yms_kan-0.0.6.dist-info}/WHEEL +0 -0
- {yms_kan-0.0.5.dist-info → yms_kan-0.0.6.dist-info}/licenses/LICENSE +0 -0
- {yms_kan-0.0.5.dist-info → yms_kan-0.0.6.dist-info}/top_level.txt +0 -0
yms_kan/plotting.py
ADDED
@@ -0,0 +1,468 @@
|
|
1
|
+
import os
|
2
|
+
from typing import Union, List
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
from matplotlib import pyplot as plt, rcParams
|
6
|
+
from sklearn.manifold import TSNE
|
7
|
+
from sklearn.metrics import ConfusionMatrixDisplay, roc_curve, auc, precision_recall_curve
|
8
|
+
from sklearn.preprocessing import label_binarize
|
9
|
+
|
10
|
+
|
11
|
+
# def plot_confusion_matrix(cm, classes,
|
12
|
+
# save_path='confusion_matrix_D1.png',
|
13
|
+
# normalize=False,
|
14
|
+
# title='Confusion matrix',
|
15
|
+
# cmap=plt.cm.Blues):
|
16
|
+
# """
|
17
|
+
# 绘制混淆矩阵的函数
|
18
|
+
# 这个函数不修改原始数据,但会返回混淆矩阵。
|
19
|
+
# """
|
20
|
+
# plt.figure()
|
21
|
+
# if normalize:
|
22
|
+
# cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
|
23
|
+
# print("Normalized confusion matrix")
|
24
|
+
# else:
|
25
|
+
# print('Confusion matrix, without normalization')
|
26
|
+
#
|
27
|
+
# plt.imshow(cm, interpolation='nearest', cmap=cmap)
|
28
|
+
# plt.title(title)
|
29
|
+
# plt.colorbar()
|
30
|
+
# tick_marks = np.arange(len(classes))
|
31
|
+
# plt.xticks(tick_marks, classes, rotation=45)
|
32
|
+
# plt.yticks(tick_marks, classes)
|
33
|
+
#
|
34
|
+
# fmt = '.2f' if normalize else 'd'
|
35
|
+
# # 用于判断变量normalize的值。如果normalize为True,则将格式化字符串.2f赋值给变量fmt;否则,将格式化字符串'd'赋值给变量fmt。
|
36
|
+
# # 其中,.2f表示保留两位小数,'d'表示以十进制形式显示。
|
37
|
+
# thresh = cm.max() / 2.
|
38
|
+
# for i, j in np.ndindex(cm.shape):
|
39
|
+
# plt.text(j, i, format(cm[i, j], fmt),
|
40
|
+
# horizontalalignment="center",
|
41
|
+
# color="white" if cm[i, j] > thresh else "black")
|
42
|
+
#
|
43
|
+
# plt.tight_layout()
|
44
|
+
# plt.ylabel('True label')
|
45
|
+
# plt.xlabel('Predicted label')
|
46
|
+
# plt.savefig(save_path)
|
47
|
+
# plt.close()
|
48
|
+
|
49
|
+
def plot_confusion_matrix(all_labels,
                          all_predictions,
                          classes,
                          path,
                          name='confusion_matrix.png',
                          normalize=None,
                          cmap=plt.cm.Blues,
                          ):
    """Draw a confusion matrix from true/predicted labels and save it as an image.

    Args:
        all_labels: Ground-truth class labels.
        all_predictions: Predicted class labels, same length as ``all_labels``.
        classes: Names shown on the matrix axes (forwarded as ``display_labels``).
        path: Directory the image is written into; created if it does not exist.
            When ``None``, nothing is drawn or saved.
        name: File name of the saved image inside ``path``.
        normalize: Forwarded to ``ConfusionMatrixDisplay.from_predictions``
            (``None``, ``'true'``, ``'pred'`` or ``'all'``).
        cmap: Colormap used for the matrix cells.
    """
    # Training code may run without an output directory (save_path=None);
    # skip instead of failing inside os.path.join(None, ...).
    if path is None:
        return
    os.makedirs(path, exist_ok=True)

    display = ConfusionMatrixDisplay.from_predictions(
        all_labels,
        all_predictions,
        display_labels=classes,
        cmap=cmap,
        normalize=normalize,
        xticks_rotation=45,
    )
    # Save/close the display's own figure rather than whatever is "current".
    fig = display.figure_
    try:
        # bbox_inches='tight' keeps the 45-degree rotated tick labels in frame.
        fig.savefig(os.path.join(path, name), dpi=500, bbox_inches='tight')
    finally:
        plt.close(fig)
|
66
|
+
|
67
|
+
|
68
|
+
def _to_one_vs_rest_matrix(y, n_classes):
    """Return an (n_samples, n_classes) matrix usable for one-vs-rest curves.

    A 1-D array of integer labels is one-hot encoded; labels outside
    ``range(n_classes)`` become all-zero rows. A 2-D array is treated as
    per-class scores/probabilities and returned unchanged after a shape check.
    """
    y = np.asarray(y)
    if y.ndim == 2:
        if y.shape[1] != n_classes:
            raise ValueError(
                f"expected {n_classes} score columns, got {y.shape[1]}")
        return y
    # Unlike sklearn's label_binarize, this always yields one column per
    # class, including the two-class case (label_binarize returns 1 column).
    return (y.reshape(-1, 1) == np.arange(n_classes)).astype(int)


def plot_multi_class_curves(y_true, y_pred, target_names, save):
    """Plot one-vs-rest precision-recall and ROC curves for every class.

    Writes ``precision_recall_curve.png`` and ``roc_curve.png`` into ``save``
    (created if missing). Each legend entry shows the class name and the area
    under that class's curve.

    Args:
        y_true: Integer ground-truth labels in ``range(len(target_names))``.
        y_pred: Either integer predicted labels (same form as ``y_true``), or
            a 2-D array of shape ``(n_samples, len(target_names))`` holding
            per-class scores/probabilities. Hard labels give curves with only
            a few points; scores give proper curves.
        target_names: Display name of each class, indexed by class id.
        save: Output directory.
    """
    # One class per entry; len(set(...)) undercounted duplicated display names.
    n_classes = len(target_names)
    y_true_bin = _to_one_vs_rest_matrix(y_true, n_classes)
    y_score = _to_one_vs_rest_matrix(y_pred, n_classes)

    os.makedirs(save, exist_ok=True)
    cmap = plt.get_cmap('tab10')  # colours cycle when there are > 10 classes

    # Precision-recall curves.
    fig = plt.figure()
    try:
        for i in range(n_classes):
            precision, recall, _ = precision_recall_curve(y_true_bin[:, i], y_score[:, i])
            plt.plot(recall, precision, color=cmap(i % cmap.N), lw=2,
                     label=f'{target_names[i]}:{auc(recall, precision):0.4f}')
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.legend(loc="best")
        fig.savefig(os.path.join(save, 'precision_recall_curve.png'), dpi=500)
    finally:
        plt.close(fig)

    # ROC curves with the chance diagonal for reference.
    fig = plt.figure()
    try:
        for i in range(n_classes):
            fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_score[:, i])
            plt.plot(fpr, tpr, color=cmap(i % cmap.N), lw=2,
                     label=f'{target_names[i]}:{auc(fpr, tpr):0.4f}')
        plt.plot([0, 1], [0, 1], 'k--', lw=2)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.0])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.legend(loc="best")
        fig.savefig(os.path.join(save, 'roc_curve.png'), dpi=500)
    finally:
        plt.close(fig)
|
125
|
+
|
126
|
+
|
127
|
+
def _draw_metric_panel(ax, metrics_dict, name, num_epochs):
    """Draw one metric, or the combined train/val loss panel, onto ``ax``."""
    epochs = range(1, num_epochs + 1)
    if name == 'train_val_loss':
        # Synthetic panel: overlay training and validation loss.
        ax.plot(epochs, metrics_dict['train_losses'], label='Training Loss')
        ax.plot(epochs, metrics_dict['val_losses'], label='Validation Loss')
        ax.set_title('Loss over epochs')
    else:
        ax.plot(epochs, metrics_dict[name], label=f'{name}')
        ax.set_title(f'{name} over epochs')
    ax.set_xlabel('Epochs')
    ax.set_ylabel(f'{name}')
    ax.legend()
    ax.grid(True)


def plot_all_metrics(metrics_dict, num_epochs, img_name, save_path, plot_metric=False):
    """Plot every per-epoch metric in ``metrics_dict`` and save the figures.

    All metrics are drawn into one grid figure saved as
    ``<save_path>/<img_name>_metrics.png``. When both ``'train_losses'`` and
    ``'val_losses'`` are present, an extra panel overlays the two.

    Args:
        metrics_dict: Mapping of metric name -> per-epoch values. An ``'epoch'``
            key, if present, is not plotted.
        num_epochs: Number of epochs; the x-axis runs 1..num_epochs, so every
            plotted sequence must have this length.
        img_name: Prefix of the saved image file names.
        save_path: Output directory (created if it does not exist).
        plot_metric: If True, also save one figure per metric as
            ``<save_path>/<img_name>_<metric>.png``.
    """
    metric_names = [key for key in metrics_dict if key != 'epoch']
    if 'train_losses' in metrics_dict and 'val_losses' in metrics_dict:
        metric_names.append('train_val_loss')
    if not metric_names:
        return  # nothing to plot (previously crashed with a NameError)

    os.makedirs(save_path, exist_ok=True)

    # Near-square grid: rows = ceil(sqrt(k)), cols = ceil(k / rows).
    num_metrics = len(metric_names)
    rows = int(np.ceil(np.sqrt(num_metrics)))
    cols = int(np.ceil(num_metrics / rows))
    # squeeze=False always yields a 2-D array, so a single metric (1x1 grid)
    # no longer breaks .ravel()/.flatten().
    fig, axes = plt.subplots(rows, cols, figsize=(12 * cols, 6 * rows), squeeze=False)
    axes = axes.ravel()
    for ax, name in zip(axes, metric_names):
        _draw_metric_panel(ax, metrics_dict, name, num_epochs)
    for ax in axes[num_metrics:]:
        fig.delaxes(ax)  # drop unused grid cells

    fig.tight_layout()
    fig.savefig(os.path.join(save_path, f'{img_name}_metrics.png'), dpi=500)
    plt.close(fig)

    if plot_metric:
        for name in metric_names:
            fig, ax = plt.subplots(figsize=(12, 6))
            _draw_metric_panel(ax, metrics_dict, name, num_epochs)
            fig.savefig(os.path.join(save_path, f'{img_name}_{name}.png'), dpi=500)
            plt.close(fig)
|
209
|
+
|
210
|
+
|
211
|
+
def plot_metrics(metric1, metric2, num_epochs, name, save_path='metrics.png'):
    """Plot the training and validation curves of one metric on shared axes.

    Args:
        metric1: Per-epoch training values.
        metric2: Per-epoch validation values.
        num_epochs: Number of epochs; the x-axis runs 1..num_epochs and must
            match the length of both sequences.
        name: Metric name used in the title, y-axis label and legend.
        save_path: File path of the saved image.
    """
    epoch_axis = range(1, num_epochs + 1)
    curves = (
        (metric1, f'Training {name}'),
        (metric2, f'Validation {name}'),
    )

    plt.figure(figsize=(12, 6))
    for values, legend_text in curves:
        plt.plot(epoch_axis, values, label=legend_text)
    plt.title(f'{name} over epochs')
    plt.xlabel('Epochs')
    plt.ylabel(f'{name}')
    plt.legend()
    plt.grid(True)
    plt.savefig(save_path, dpi=500)
    plt.close()
|
222
|
+
|
223
|
+
|
224
|
+
def plot_single(met, num_epochs, name, save_path='metrics.png'):
    """Plot a single per-epoch metric curve (no legend) and save it to a file.

    Args:
        met: Per-epoch values; length must equal ``num_epochs``.
        num_epochs: Number of epochs; the x-axis runs 1..num_epochs.
        name: Metric name used in the title and y-axis label.
        save_path: File path of the saved image.
    """
    epoch_axis = range(1, num_epochs + 1)
    fig = plt.figure(figsize=(12, 6))
    plt.plot(epoch_axis, met)
    plt.title(f'{name} over epochs')
    plt.xlabel('Epochs')
    plt.ylabel(f'{name}')
    plt.grid(True)
    fig.savefig(save_path, dpi=500)
    plt.close(fig)
|
233
|
+
|
234
|
+
|
235
|
+
def _read_tab_separated_metrics(path):
    """Read a tab-separated metrics file into ``{header: [float, ...]}``.

    The first line holds the column headers; every following non-blank line
    holds one row of numeric values. Blank lines (e.g. trailing empty lines)
    are skipped instead of failing in ``float('')``.
    """
    with open(path, 'r') as file:
        headers = file.readline().strip().split('\t')
        data = {header: [] for header in headers}
        for line in file:
            line = line.strip()
            if not line:
                continue
            for header, value in zip(headers, line.split('\t')):
                data[header].append(float(value))
    return data


def plot_data_from_files(file_paths, exclude_headers=None, save_path='metrics.png'):
    """Plot the metrics of several tab-separated result files, one panel per metric.

    Panels are laid out in a fixed 2-row grid. If every file has the same file
    name, the parent folder name is used as the curve label; otherwise the file
    name is used.

    Args:
        file_paths (list): Paths of the files to read. The first file decides
            which metrics are plotted.
        exclude_headers (list): Column headers not to plot; defaults to
            ``['epoch', 'lr']``.
        save_path (str): File path of the saved image.

    Raises:
        ValueError: If ``file_paths`` is empty or no plottable column remains.
    """
    if not file_paths:
        raise ValueError('file_paths must contain at least one file')
    if exclude_headers is None:
        exclude_headers = ['epoch', 'lr']  # default excluded columns

    all_data = [_read_tab_separated_metrics(path) for path in file_paths]
    metric_keys = [key for key in all_data[0] if key not in exclude_headers]
    if not metric_keys:
        raise ValueError('no plottable columns left after excluding headers')

    # Identical file names -> distinguish runs by their folder name instead.
    file_names = [os.path.basename(path) for path in file_paths]
    if len(set(file_names)) == 1:
        labels = [os.path.basename(os.path.dirname(path)) for path in file_paths]
    else:
        labels = file_names

    # Scope the CJK-capable font settings to this figure instead of mutating
    # global rcParams; the fallback families keep non-Windows systems working.
    font_settings = {
        'font.family': ['Microsoft YaHei', 'PingFang SC', 'sans-serif'],
        'axes.unicode_minus': False,  # render the minus sign correctly with CJK fonts
    }
    with plt.rc_context(font_settings):
        # ceil(n / 2) columns for 2 rows; the old formula added a spare
        # column for odd n, leaving a whole row empty.
        num_cols = (len(metric_keys) + 1) // 2
        fig, axs = plt.subplots(2, num_cols, figsize=(15, 8),
                                constrained_layout=True, squeeze=False)
        try:
            axs = axs.ravel()
            for ax, key in zip(axs, metric_keys):
                for j, (data, label) in enumerate(zip(all_data, labels)):
                    if key not in data:
                        continue  # this file does not record that metric
                    # Each file uses its own epoch column, so runs of different
                    # length can be compared; fall back to 1..N if absent.
                    epochs = data.get('epoch') or list(range(1, len(data[key]) + 1))
                    ax.plot(epochs, data[key], label=f'{label} {key}', color=f'C{j}')
                ax.set_title(f'{key} over Epochs')
                ax.set_xlabel('Epoch')
                ax.set_ylabel(key)
                ax.legend()

            for ax in axs[len(metric_keys):]:
                ax.axis('off')  # hide unused grid cells

            fig.savefig(save_path, dpi=500)
        finally:
            plt.close(fig)
|
300
|
+
|
301
|
+
|
302
|
+
def visualize_features(
        features: Union[List[List[float]], np.ndarray],
        labels: Union[List[int], np.ndarray],
        class_names: List[str],
        n_components: int = 2,
        perplexity: int = 30,
        learning_rate: float = 200.0,
        title: str = "Feature Visualization (t-SNE)",
        save_path: Union[str, None] = None,
        backend: str = "agg",
        markers: Union[str, List[str]] = 'o'
) -> None:
    """Embed features with t-SNE in 2-D or 3-D and draw a per-class scatter plot.

    Args:
        features: Feature matrix, one row per sample.
        labels: Integer class index per sample, in ``[0, len(class_names))``.
        class_names: Display name of each class index.
        n_components: t-SNE output dimensionality; must be 2 or 3.
        perplexity: Forwarded to ``sklearn.manifold.TSNE``.
        learning_rate: Forwarded to ``sklearn.manifold.TSNE``.
        title: Plot title.
        save_path: Image path to save to. A path without a file extension is
            written as PNG content at exactly that path.
        backend: Matplotlib backend switched to before plotting. Without a
            ``save_path`` the figure is shown only when the backend is
            ``'tkagg'`` (case-insensitive); otherwise it is simply closed.
        markers: One marker for every class, or a list with one marker per
            class. The list may have one entry per class present in ``labels``
            (ascending label order) or one entry per ``class_names`` item.
            Filled markers (matplotlib's ``Line2D.filled_markers``, e.g.
            'o', 's', '^', 'v', 'd', 'p', '*') get a white edge; unfilled
            markers ('x', '+', '|', '_') are drawn in the class colour only.
            3-D plots support only 'o', 's', '^', 'v', 'x', '+'.

    Raises:
        TypeError: If features/labels cannot be converted to numeric arrays.
        ValueError: On empty or mismatched inputs, out-of-range labels,
            invalid markers, or ``n_components`` not in (2, 3).
    """
    if n_components not in (2, 3):
        raise ValueError(f"n_components must be 2 or 3, got {n_components}")

    # 1. Environment. TSNE is seeded via random_state below, so the global
    #    NumPy RNG is no longer reseeded (that silently reset callers' state).
    plt.switch_backend(backend)

    # 2. Input validation.
    try:
        features = np.asarray(features, dtype=np.float32)
        labels = np.asarray(labels, dtype=np.int32)
    except Exception as e:
        raise TypeError(f"数据格式错误: {str(e)}") from e

    if features.shape[0] != len(labels):
        raise ValueError(f"特征数量({features.shape[0]})与标签数量({len(labels)})不匹配")
    if labels.size == 0:
        raise ValueError("features/labels must not be empty")
    if np.any(labels < 0) or np.max(labels) >= len(class_names):
        raise ValueError(f"标签范围非法: 期望[0,{len(class_names) - 1}], 实际[{np.min(labels)},{np.max(labels)}]")

    # 3. Resolve one marker per present class, then validate them.
    unique_labels = np.unique(labels)
    n_classes = len(unique_labels)

    if isinstance(markers, str):
        class_markers = [markers] * n_classes
    elif isinstance(markers, list) and len(markers) == n_classes:
        class_markers = list(markers)
    elif isinstance(markers, list) and len(markers) == len(class_names):
        # One marker per class name: pick the ones for the classes present.
        class_markers = [markers[label] for label in unique_labels]
    else:
        # The message previously interpolated the class-name list, not the length.
        raise ValueError(f"markers需为字符串或长度为{n_classes}的列表")

    filled_markers = plt.Line2D.filled_markers  # matplotlib's filled marker set
    known_markers = filled_markers + ('x', '+', '|', '_', '*')
    for m in class_markers:
        if n_components == 3 and m not in ('o', 's', '^', 'v', 'x', '+'):
            raise ValueError("3D仅支持: o/s/^/v/x/+")
        if m not in known_markers:
            raise ValueError(f"未知标记'{m}',参考Matplotlib文档")

    # 4. t-SNE embedding (verbose=1 prints progress).
    print(">>> 开始T-SNE降维...")
    tsne = TSNE(
        n_components=n_components,
        perplexity=perplexity,
        learning_rate=learning_rate,
        max_iter=1000,
        random_state=42,
        verbose=1
    )
    reduced = tsne.fit_transform(features)
    print(">>> 降维完成")

    # 5. Colours indexed by label value; wrap around for labels >= 10 instead of
    #    letting them all fall onto tab10's single "over" colour.
    cmap = plt.colormaps['tab10']
    colors = cmap(unique_labels % cmap.N)

    # 6. Plot. Exactly one axes is created; previously a 2-D axes was always
    #    added first and stayed visible behind the 3-D one.
    is_3d = n_components == 3
    fig = plt.figure(figsize=(12, 10) if is_3d else (10, 8))
    ax = fig.add_subplot(111, projection='3d') if is_3d else fig.add_subplot(111)

    for color, label, m in zip(colors, unique_labels, class_markers):
        mask = labels == label
        coords = [reduced[mask, dim] for dim in range(n_components)]
        if m in filled_markers:
            # Filled markers: class colour face with a white edge (thinner in 3-D).
            style = (dict(facecolors=color, edgecolors='white', linewidths=0.8, alpha=0.7, s=40)
                     if is_3d else
                     dict(facecolors=color, edgecolors='white', linewidths=1.2, alpha=0.8, s=60))
        else:
            # Unfilled markers: colour the strokes directly, drawn slightly larger.
            style = (dict(color=color, lw=1.2, alpha=0.8, s=50)
                     if is_3d else
                     dict(color=color, lw=1.5, alpha=0.9, s=80))
        # Legend text follows the label value, not its position among the
        # labels present (they differ when some classes are absent).
        ax.scatter(*coords, marker=m, label=class_names[label], **style)

    if is_3d:
        ax.set(xlabel='Dim 1', ylabel='Dim 2', zlabel='Dim 3')
    else:
        ax.set(xlabel='t-SNE Dim 1', ylabel='t-SNE Dim 2')

    # 7. Common decoration.
    ax.set_title(title, fontsize=14, pad=20)
    ax.legend(
        title='Classes',
        bbox_to_anchor=(1.05, 1),
        loc='upper left',
        fontsize=11,
        frameon=True,
        framealpha=0.9
    )
    plt.tight_layout(pad=4)  # avoid clipped labels

    # 8. Save (best effort: failures are reported, not raised) or show.
    if save_path:
        has_extension = bool(os.path.splitext(save_path)[1])
        try:
            plt.savefig(
                save_path,
                dpi=500,
                bbox_inches='tight',
                facecolor='white',
                # Without an extension, force PNG content at exactly save_path.
                format=None if has_extension else 'png'
            )
            print(f"✅ 图片已保存至: {save_path} (尺寸:{fig.get_size_inches()})")
        except Exception as e:
            print(f"❌ 保存失败: {str(e)}")
        finally:
            plt.close(fig)
    else:
        if backend.lower() == 'tkagg':
            plt.show()
        # Release the figure whether or not it was shown (it used to leak).
        plt.close(fig)
|
yms_kan/train_eval_utils.py
CHANGED
@@ -9,6 +9,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|
9
9
|
from tqdm import tqdm
|
10
10
|
|
11
11
|
from yms_kan import LBFGS
|
12
|
+
from yms_kan.plotting import plot_confusion_matrix
|
12
13
|
from yms_kan.tool import initialize_results_file, append_to_results_file, calculate_metric
|
13
14
|
|
14
15
|
|
@@ -19,8 +20,6 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_
|
|
19
20
|
stop_grid_update_step=100,
|
20
21
|
save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
|
21
22
|
singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n'):
|
22
|
-
all_predictions = []
|
23
|
-
all_labels = []
|
24
23
|
best = -1
|
25
24
|
column_order = ['epoch', 'train_losses', 'val_losses', 'accuracies', 'precisions', 'recalls',
|
26
25
|
'f1-scores', 'lrs']
|
@@ -131,6 +130,9 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_
|
|
131
130
|
|
132
131
|
val_loss = torch.zeros(1).to(model.device)
|
133
132
|
with torch.no_grad():
|
133
|
+
all_predictions = []
|
134
|
+
all_labels = []
|
135
|
+
|
134
136
|
test_indices = np.arange(dataset['test_input'].shape[0])
|
135
137
|
np.random.shuffle(test_indices)
|
136
138
|
test_steps = math.ceil(dataset['test_input'].shape[0] / batch_size_test)
|
@@ -147,7 +149,7 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_
|
|
147
149
|
loss = loss_fn(outputs, batch_test_label)
|
148
150
|
|
149
151
|
val_loss = (val_loss * batch_num + loss.detach()) / (batch_num + 1)
|
150
|
-
test_pbar.set_postfix(
|
152
|
+
test_pbar.set_postfix(val_loss=val_loss.item())
|
151
153
|
if label is not None:
|
152
154
|
diffs = torch.abs(outputs - label)
|
153
155
|
closest_indices = torch.argmin(diffs, dim=1)
|
@@ -170,8 +172,9 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_
|
|
170
172
|
best = m["f1-score"]
|
171
173
|
results['all_predictions'] = all_predictions
|
172
174
|
results['all_labels'] = all_labels
|
175
|
+
plot_confusion_matrix(all_labels, all_predictions, class_dict, save_path)
|
173
176
|
if save_path is not None:
|
174
|
-
model.saveckpt(path=save_path + '/' + 'model')
|
177
|
+
model.saveckpt(path=(os.path.join(save_path, 'save_model') + '/' + 'model'))
|
175
178
|
if txt_file is not None:
|
176
179
|
m.update({'lr': train_lr, 'epoch': epoch, 'train_loss': train_loss.item(), 'val_loss': val_loss.item()})
|
177
180
|
append_to_results_file(txt_file, m, column_order,
|
yms_kan/version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.0.
|
1
|
+
__version__ = "0.0.6" # Current package version string.
|
@@ -8,15 +8,16 @@ yms_kan/compiler.py,sha256=7bVwDNX0xmLAjQ8V1FdmkIIIibmy_W5eaeSKBlYL0Vc,18632
|
|
8
8
|
yms_kan/experiment.py,sha256=VnZq7hmcvRk08GNI7VIpkOjkaRZBoIw1C8SU_f1KbaA,1682
|
9
9
|
yms_kan/feynman.py,sha256=Eisf69K49s4C6UlPEi5LnNK_p5TUJQLBKxMp-sW0a9w,33687
|
10
10
|
yms_kan/hypothesis.py,sha256=Ec20xadfgOSSWeZHQaGn-h9F2PY7LWFU3iniNI2Zd_4,23165
|
11
|
+
yms_kan/plotting.py,sha256=Moi6QTJQxHjutGMgxR9oSsqZSzYY3TP-7WNapdCIqzw,18097
|
11
12
|
yms_kan/spline.py,sha256=ZXyGwl2Sc-UrnrcuUXeUQkBOMnetaWcHrbpZaqatCvs,4345
|
12
13
|
yms_kan/tool.py,sha256=rkRpqF3EcsAq7a3k1F1zKlxfJ4U9n-FzHyNCJgN4URY,21159
|
13
|
-
yms_kan/train_eval_utils.py,sha256=
|
14
|
+
yms_kan/train_eval_utils.py,sha256=zekRLiPMb0JWw0hWK6zLTF4Ub2mT7PSP81sIvYl-VfY,16709
|
14
15
|
yms_kan/utils.py,sha256=J07L-tgmc1OfU6Tl6mGwHJRizjFN75EJK8BxejaZLUc,23860
|
15
|
-
yms_kan/version.py,sha256=
|
16
|
+
yms_kan/version.py,sha256=hlRkaxPm359oqYesPhlxNh2ehVXK30hleVe2-mln1Rg,39
|
16
17
|
yms_kan/assets/img/mult_symbol.png,sha256=2f4xUKdweft-qUbHjFI5h9-smnEtc0FWq8hNYZhPAXY,6392
|
17
18
|
yms_kan/assets/img/sum_symbol.png,sha256=94QkMUzmEjlCq_yf14nMEQmettaq86FmlGfdl22b4XE,6210
|
18
|
-
yms_kan-0.0.
|
19
|
-
yms_kan-0.0.
|
20
|
-
yms_kan-0.0.
|
21
|
-
yms_kan-0.0.
|
22
|
-
yms_kan-0.0.
|
19
|
+
yms_kan-0.0.6.dist-info/licenses/LICENSE,sha256=BJXDWyF4Groqtnp4Gi9puH4aLg7A2IC3MpHmC-cSxwc,1067
|
20
|
+
yms_kan-0.0.6.dist-info/METADATA,sha256=rGVjdD3NuHuziwieG_JABRnPR9WCYbgV2trlm-pDmnI,240
|
21
|
+
yms_kan-0.0.6.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
22
|
+
yms_kan-0.0.6.dist-info/top_level.txt,sha256=Z_JDh6yZf-EiW1eKgL6ADsN2yqEMRMspi-o29JZ1WPo,8
|
23
|
+
yms_kan-0.0.6.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|