yms-kan 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yms_kan/tool.py +499 -234
- yms_kan/train_eval_utils.py +59 -13
- yms_kan/version.py +1 -1
- {yms_kan-0.0.2.dist-info → yms_kan-0.0.3.dist-info}/METADATA +1 -1
- {yms_kan-0.0.2.dist-info → yms_kan-0.0.3.dist-info}/RECORD +8 -8
- {yms_kan-0.0.2.dist-info → yms_kan-0.0.3.dist-info}/WHEEL +0 -0
- {yms_kan-0.0.2.dist-info → yms_kan-0.0.3.dist-info}/licenses/LICENSE +0 -0
- {yms_kan-0.0.2.dist-info → yms_kan-0.0.3.dist-info}/top_level.txt +0 -0
yms_kan/tool.py
CHANGED
@@ -1,8 +1,155 @@
 import os
+import re
+from datetime import datetime, timezone, timedelta
+from typing import Optional, Dict, List
 
+import click
 import numpy as np
-import
+import pandas as pd
 import wandb
+from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, \
+    classification_report
+from tqdm import tqdm
+
+
+# Read two differently formatted tables from one txt file and return them as a list of dicts
+def read_multi_table_txt(file_path):
+    # Read the raw contents
+    with open(file_path, 'r') as f:
+        content = f.read()
+
+    # Split the content on table headers (each new table is assumed to start with "epoch")
+    table_blocks = re.split(r'\n(?=epoch\s)', content.strip())
+
+    # Process each table block
+    table_dicts = []
+    for block in table_blocks:
+        lines = [line.strip() for line in block.split('\n') if line.strip()]
+
+        # Parse the column names (handles tabs and mixed spaces)
+        columns = re.split(r'\s{2,}|\t', lines[0])
+
+        # Parse the data rows (handles mixed separators)
+        data = []
+        for line in lines[1:]:
+            # Split on runs of consecutive spaces/tabs with a regex
+            row = re.split(r'\s{2,}|\t', line)
+            data.append(row)
+
+        # Build a DataFrame and convert numeric values automatically
+        df = pd.DataFrame(data, columns=columns)
+        df = df.apply(pd.to_numeric, errors='coerce')  # numeric columns are detected automatically; non-numeric values become NaN
+
+        # Convert the DataFrame to a dict, keeping each column as a list
+        table_dict = df.to_dict(orient='list')
+        table_dicts.append(table_dict)
+
+    return table_dicts
+
+
+def get_current_time(format_str="%Y-%m-%d %H:%M:%S"):
+    """
+    Get the current time in UTC+8 and return it as a string in the given format
+    :param format_str: time format (defaults to "%Y-%m-%d %H:%M:%S")
+    :return: the formatted time string
+    """
+
+    # Create the UTC+8 timezone object
+    utc8_timezone = timezone(timedelta(hours=8))
+
+    # Get the current time in UTC+8
+    utc8_time = datetime.now(utc8_timezone)
+
+    # Format it as a string
+    formatted_time = utc8_time.strftime(format_str)
+    return formatted_time
+
+
+# Metric computation for the val and test phases
+def calculate_results(all_labels, all_predictions, classes, average='macro'):
+    results = {
+        'accuracy': accuracy_score(y_true=all_labels, y_pred=all_predictions),
+        'precision': precision_score(y_true=all_labels, y_pred=all_predictions, average=average),
+        'recall': recall_score(y_true=all_labels, y_pred=all_predictions, average=average),
+        'f1_score': f1_score(y_true=all_labels, y_pred=all_predictions, average=average),
+        'cm': confusion_matrix(y_true=all_labels, y_pred=all_predictions, labels=np.arange(len(classes)))
+    }
+    return results
+
+
+def calculate_metric(all_labels, all_predictions, classes, class_metric=False, average='macro avg'):
+    metric = classification_report(y_true=all_labels, y_pred=all_predictions,
+                                   target_names=classes, digits=4, output_dict=True, zero_division=0)
+    if not class_metric:
+        metric = {
+            'accuracy': metric.get('accuracy'),
+            'precision': metric.get(average).get('precision'),
+            'recall': metric.get(average).get('recall'),
+            'f1-score': metric.get(average).get('f1-score'),
+        }
+        return metric
+    else:
+        return metric
+
+
+def dict_to_classification_report(report_dict, digits=2):
+    headers = ["precision", "recall", "f1-score", "support"]
+    target_names = list(report_dict.keys())
+    target_names.remove('accuracy') if 'accuracy' in target_names else None
+    longest_last_line_heading = "weighted avg"
+    name_width = max(len(cn) for cn in target_names)
+    width = max(name_width, len(longest_last_line_heading), digits)
+    head_fmt = "{:>{width}s} " + " {:>9}" * len(headers)
+    report = head_fmt.format("", *headers, width=width)
+    report += "\n\n"
+    row_fmt = "{:>{width}s} " + " {:>9.{digits}f}" * 3 + " {:>9}\n"
+    for target_name in target_names:
+        scores = [report_dict[target_name][h] for h in headers]
+        report += row_fmt.format(target_name, *scores, width=width, digits=digits)
+    report += "\n"
+
+    average_options = ["micro avg", "macro avg", "weighted avg"]
+    if 'samples avg' in report_dict:
+        average_options.append('samples avg')
+    for average in average_options:
+        if average in report_dict:
+            scores = [report_dict[average][h] for h in headers]
+            if average == "accuracy":
+                row_fmt_accuracy = (
+                    "{:>{width}s} "
+                    + " {:>9.{digits}}" * 2
+                    + " {:>9.{digits}f}"
+                    + " {:>9}\n"
+                )
+                report += row_fmt_accuracy.format(
+                    average, "", "", *scores[2:], width=width, digits=digits
+                )
+            else:
+                report += row_fmt.format(average, *scores, width=width, digits=digits)
+
+    if 'accuracy' in report_dict:
+        row_fmt_accuracy = (
+            "{:>{width}s} "
+            + " {:>9.{digits}}" * 2
+            + " {:>9.{digits}f}"
+            + " {:>9}\n"
+        )
+        report += row_fmt_accuracy.format(
+            "accuracy", "", "", report_dict["accuracy"], "", width=width, digits=digits
+        )
+
+    return report
+
+
+# def append_metrics(metrics, metric, result, lr):
+#     metrics['train_losses'].append(result['train_loss'])
+#     metrics['val_losses'].append(result['val_loss'])
+#     metrics['accuracies'].append(metric['accuracy'])
+#     metrics['precisions'].append(metric['precision'])
+#     metrics['recalls'].append(metric['recall'])
+#     metrics['f1-scores'].append(metric['f1-score'])
+#     metrics['lrs'].append(lr)
+#     return metrics
 
 
 def initialize_results_file(results_file, result_info):
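For context, here is a minimal usage sketch of the new read_multi_table_txt helper added above. It is not part of the diff; it assumes yms-kan 0.0.3 is installed, and the file name and contents are illustrative:

from yms_kan.tool import read_multi_table_txt

# Write a small two-table file of the kind the helper expects
# (each table starts with a header line beginning with "epoch").
with open('demo_tables.txt', 'w') as f:
    f.write('epoch  train_losses  val_losses\n'
            '1      0.812345      0.912345\n'
            '2      0.712345      0.812345\n'
            'epoch  accuracies\n'
            '1      0.8123\n'
            '2      0.9123\n')

tables = read_multi_table_txt('demo_tables.txt')  # one dict per table
print(tables[0]['train_losses'])  # [0.812345, 0.712345]

Each table becomes a column-name-to-list dict, with values coerced to numbers via pd.to_numeric (non-numeric cells become NaN).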
@@ -12,9 +159,10 @@ def initialize_results_file(results_file, result_info):
     Args:
         results_file (str): Path to the results file.
         result_info (list): List of column names to write as the first line.
+        space: spacing between the column names (two spaces by default)
     """
     # Process result_info, adding two spaces after each word
-    result_info_str =
+    result_info_str = '  '.join(result_info) + '\n'
     # Check whether the file exists
     if os.path.exists(results_file):
         # If the file exists, read its first line
@@ -35,86 +183,45 @@ def initialize_results_file(results_file, result_info):
     print(f"File {results_file} has been created and result_info written to it.")
 
 
-def
-                            data_dict: dict,
-                            column_order: list,
-                            float_precision: int = 5) -> None:
+def is_similar_key(key1, key2):
     """
-
+    Check whether two keys are similar, accounting for plural-form conversions.
 
-
-
-
-
-
+    Args:
+        key1 (str): the first key
+        key2 (str): the second key
+
+    Returns:
+        bool: True if the two keys are similar (including plural-form conversions), otherwise False
     """
-
-
-    for key in data_dict:
-        if not isinstance(data_dict[key], list):
-            raise ValueError(f"Value for key '{key}' is not a list")
-        if rows is None:
-            rows = len(data_dict[key])
-        else:
-            if len(data_dict[key]) != rows:
-                raise ValueError("All lists in data_dict must have the same length")
+    if key1 == key2:
+        return True
 
-    #
-
-
-
-
-    if column_name in ['train_losses', 'val_losses']:
-        return f"{value:.{float_precision + 1}f}"
-    elif column_name == 'lrs':
-        return f"{value:.8f}"
-    else:
-        return f"{value:.{float_precision}f}"
-    elif isinstance(value, str):
-        return value
-    else:
-        return str(value)
+    # Check whether key2 is a plural form
+    if key2.endswith("ies"):
+        singular_candidate = key2.removesuffix("ies") + "y"
+        if key1 == singular_candidate:
+            return True
 
-
-
-
-
-    if dict_key not in data_dict:
-        raise ValueError(f"Missing required column: {dict_key}")
-    values = data_dict[dict_key]
-
-    max_width = len(col)
-    for val in values:
-        fmt_val = format_value(val, col)
-        max_width = max(max_width, len(fmt_val))
-    column_widths.append(max_width)
-
-    # Generate the formatted lines
-    lines = []
-    for i in range(rows):
-        row = []
-        for j, col in enumerate(column_order):
-            dict_key = 'val_accuracies' if col == 'accuracies' else col
-            val = data_dict[dict_key][i]
-            fmt_val = format_value(val, col)
-
-            # Alignment handling
-            if j == len(column_order) - 1:
-                fmt_val = fmt_val.ljust(column_widths[j])
-            else:
-                fmt_val = fmt_val.rjust(column_widths[j])
-            row.append(fmt_val)
-        lines.append("  ".join(row) + '\n')
+    if key2.endswith("es"):
+        singular_candidate = key2.removesuffix("es")
+        if key1 == singular_candidate:
+            return True
 
-
-
-
+    if key2.endswith("s"):
+        singular_candidate = key2.removesuffix("s")
+        if key1 == singular_candidate:
+            return True
+
+    return False
 
 
 def append_to_results_file(file_path: str,
                            data_dict: dict,
                            column_order: list,
-                           float_precision: int =
+                           float_precision: int = 4,
+                           more_float: int = 2,
+                           custom_column_widths: dict = None) -> None:
     """
     Generic formatted-line writer
 
@@ -123,182 +230,340 @@ def append_to_results_file(file_path: str,
         data_dict: dict containing the data, keyed by column name
         column_order: list giving the column order; its elements are dict keys
         float_precision: number of decimal places for floats (default 5)
+        more_float: extra decimal places for floats
+        custom_column_widths: dict of custom column widths, keyed by column name
     """
-    #
-
-
-
-    #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    elif isinstance(value, str):
-        fmt_value = value
-    else:  # convert any other type to a string
-        fmt_value = str(value)
-
-    # Use the larger of the column-name length and the value length as the column width
-    column_width = max(len(col), len(fmt_value))
-    column_widths.append(column_width)
-
-    # Apply column-width alignment
-    if col == column_order[-1]:  # the last column is left-aligned
-        fmt_value = fmt_value.ljust(column_width)
-    else:
-        fmt_value = fmt_value.rjust(column_width)
-
-    formatted_data.append(fmt_value)
-
-    # Build the line and write it; columns are separated by two spaces
-    if formatted_data:
-        line = "  ".join(formatted_data) + '\n'
-        with open(file_path, 'a', encoding='utf-8') as f:
-            f.write(line)
-    else:
-        # Non-list case, handled by the original logic
-        # Compute the maximum width of each column
-        column_widths = []
-        formatted_data = []
-        for col in column_order:
-            # Handle dict-key aliases
-            dict_key = 'val_accuracies' if col == 'accuracies' else col
-            # Skip the column if the key is missing
-            if dict_key not in data_dict:
-                continue
-
-            value = data_dict[dict_key]
-
-            # Format according to data type
-            if isinstance(value, (int, np.integer)):
-                fmt_value = f"{value:d}"
-            elif isinstance(value, (float, np.floating)):
-                if col in ['train_losses', 'val_losses']:  # for the loss columns keep float_precision + 1 decimals
-                    fmt_value = f"{value:.{float_precision + 1}f}"
-                elif col == 'lrs':  # for the 'lrs' column keep 8 decimals
-                    fmt_value = f"{value:.8f}"
-                else:
-                    fmt_value = f"{value:.{float_precision}f}"
-            elif isinstance(value, str):
-                fmt_value = value
-            else:  # convert any other type to a string
-                fmt_value = str(value)
+    # Compute the maximum width of each column
+    column_widths = []
+    formatted_data = []
+    for col in column_order:
+        # Find a similar key in data_dict
+        dict_key = None
+        for key in data_dict:
+            if is_similar_key(key, col):
+                dict_key = key
+                break
+        if dict_key is None:
+            raise ValueError(f"Missing required column: {col}")
+
+        value = data_dict[dict_key]
+
+        # Format according to data type
+        if isinstance(value, (int, np.integer)):
+            fmt_value = f"{value:d}"
+        elif isinstance(value, (float, np.floating)):
+            if col in ['train_losses', 'val_losses']:  # for the loss columns keep float_precision + more_float decimals
+                fmt_value = f"{value:.{float_precision + more_float}f}"
+            elif col == 'lrs':
+                fmt_value = f"{value:.8f}"
+            else:
+                fmt_value = f"{value:.{float_precision}f}"
+        elif isinstance(value, str):
+            fmt_value = value
+        else:  # convert any other type to a string
+            fmt_value = str(value)
 
+        # Determine the column width
+        if custom_column_widths and col in custom_column_widths:
+            column_width = custom_column_widths[col]
+        else:
             # Use the larger of the column-name length and the value length as the column width
             column_width = max(len(col), len(fmt_value))
-
+        column_widths.append(column_width)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-# def append_to_results_file(file_path: str,
-#                            data_dict: dict,
-#                            column_order: list,
-#                            column_widths: list = None,
-#                            float_precision: int = 5) -> None:
-#     """
-#     Generic formatted-line writer
-#
-#     Parameters:
-#         file_path: target file path
-#         data_dict: dict containing the data, keyed by column name
-#         column_order: list giving the column order; its elements are dict keys
-#         column_widths: list of per-column character widths (optional)
-#         float_precision: number of decimal places for floats (default 4)
-#     """
-#     formatted_data = []
-#
-#     # Process the data in the specified column order
-#     for i, col in enumerate(column_order):
-#         # Handle dict-key aliases
-#         if col == 'accuracies':
-#             dict_key = 'val_accuracies'
-#         else:
-#             dict_key = col
-#
-#         if dict_key not in data_dict:
-#             raise ValueError(f"Missing required column: {dict_key}")
-#
-#         value = data_dict[dict_key]
-#
-#         # Format according to data type
-#         if isinstance(value, (int, np.integer)):
-#             fmt_value = f"{value:d}"
-#         elif isinstance(value, (float, np.floating)):
-#             if col in ['train_losses', 'val_losses']:  # for the loss columns keep float_precision + 1 decimals
-#                 fmt_value = f"{value:.{float_precision + 1}f}"
-#             elif col == 'lr':  # for the 'lr' column keep 8 decimals
-#                 fmt_value = f"{value:.8f}"
-#             else:
-#                 fmt_value = f"{value:.{float_precision}f}"
-#         elif isinstance(value, str):
-#             fmt_value = value
-#         else:  # convert any other type to a string
-#             fmt_value = str(value)
-#
-#         # Apply column-width alignment
-#         if column_widths and i < len(column_widths):
-#             try:
-#                 if i == len(column_order) - 1:  # the last column is left-aligned
-#                     fmt_value = fmt_value.ljust(column_widths[i])
-#                 else:
-#                     fmt_value = fmt_value.rjust(column_widths[i])
-#             except TypeError:  # handle non-string types
-#                 if i == len(column_order) - 1:  # the last column is left-aligned
-#                     fmt_value = str(fmt_value).ljust(column_widths[i])
-#                 else:
-#                     fmt_value = str(fmt_value).rjust(column_widths[i])
-#
-#         formatted_data.append(fmt_value)
-#
-#     # Build the line and write it
-#     line = '\t'.join(formatted_data) + '\n'
-#     with open(file_path, 'a', encoding='utf-8') as f:
-#         f.write(line)
-
-
-def get_wandb_key(key_path='tools/wandb_key.txt'):
+        # Apply column-width alignment
+        if col == column_order[-1]:  # the last column is left-aligned
+            fmt_value = fmt_value.ljust(column_width)
+        else:
+            fmt_value = fmt_value.rjust(column_width)
+
+        formatted_data.append(fmt_value)
+
+    # Build the line and write it; columns are separated by two spaces
+    line = "  ".join(formatted_data) + '\n'
+    with open(file_path, 'a', encoding='utf-8') as f:
+        f.write(line)
+
+
+def get_wandb_key(key_path):
     with open(key_path, 'r', encoding='utf-8') as f:
         key = f.read()
     return key
 
 
-def
+def wandb_init(project=None, key_path=None, name=None):
     run = None
     if project is not None:
+        if key_path is None:
+            raise ValueError("When 'project' is not None, 'key_path' should also not be None.")
        wandb_key = get_wandb_key(key_path)
        wandb.login(key=wandb_key)
        run = wandb.init(project=project, name=name)
     return run
+
+
+def check_wandb_login_required():
+    """Login-check function compatible with older wandb versions"""
+    # Check the environment variable first
+    if os.environ.get("WANDB_API_KEY"):
+        return False
+
+    try:
+        api = wandb.Api()
+        # Method 1: check via settings (works on older versions)
+        if hasattr(api, "settings") and api.settings.get("entity"):
+            return False
+
+        # Method 2: verify via projects() (more broadly applicable)
+        api.projects(per_page=1)  # request only the first project of the first page
+        return False
+    except Exception as e:
+        print(f"Unexpected error detected: {str(e)}")
+        return True  # conservatively report that login is required
+
+
+def get_wandb_runs(
+        project_path: str,
+        default_name: str = "unnamed",
+        api_key: Optional[str] = None,
+        per_page: int = 1000
+) -> List[Dict[str, str]]:
+    """
+    Fetch the info (ID and Name) of every run in a given WandB project
+
+    Args:
+        project_path (str): project path in the form "username/project_name"
+        default_name (str): display name used when a run has no name (default: "unnamed")
+        api_key (str, optional): WandB API key; must be passed if the environment variable is not set
+        per_page (int): page size for paginated queries (default 1000, for projects with many runs)
+
+    Returns:
+        List[Dict]: list of run-info dicts of the form [{"id": "...", "name": "..."}]
+
+    Raises:
+        ValueError: malformed project path
+        wandb.errors.UsageError: invalid API key or not logged in
+    """
+    # Validate arguments
+    if "/" not in project_path or len(project_path.split("/")) != 2:
+        raise ValueError("The project path must be of the form 'username/project_name'")
+
+    # Log in (only when needed)
+    if api_key:
+        wandb.login(key=api_key)
+    elif not wandb.api.api_key:
+        raise wandb.errors.UsageError("An API key must be provided, or wandb.login() must be called beforehand")
+
+    # Initialize the API
+    api = wandb.Api()
+
+    try:
+        # Fetch all runs page by page (pagination is handled automatically)
+        runs = api.runs(project_path, per_page=per_page)
+        print(f'Fetched {len(runs)} runs in total')
+        return [
+            {
+                "id": run.id,
+                "name": run.name or default_name,
+                "url": run.url,  # extra field for convenience
+                "state": run.state  # include the run state
+            }
+            for run in runs
+        ]
+
+    except wandb.errors.CommError as e:
+        raise ConnectionError(f"Connection failed: {str(e)}") from e
+    except Exception as e:
+        raise RuntimeError(f"Failed to fetch run data: {str(e)}") from e
+
+
+def delete_runs(
+        project_path: str,
+        run_ids: Optional[List[str]] = None,
+        run_names: Optional[List[str]] = None,
+        delete_all: bool = False,
+        dry_run: bool = True,
+        api_key: Optional[str] = None,
+        per_page: int = 500
+) -> dict:
+    """
+    Multi-purpose WandB run deletion tool
+
+    :param project_path: project path (format: username/project_name)
+    :param run_ids: list of run IDs to delete (state is ignored)
+    :param run_names: list of run names to delete (state is ignored)
+    # :param preserve_states: list of protected states (finished/running are protected by default)
+    :param delete_all: danger mode! deletes every run (default False)
+    :param dry_run: simulation mode (default True)
+    :param api_key: WandB API key
+    :param per_page: page size for paginated queries
+    :return: dict of operation statistics
+
+    Usage scenarios:
+    1. Delete specific runs: delete_runs(..., run_ids=["abc","def"])
+    2. Default, delete failed runs: delete_runs(...)
+    3. Delete all runs: delete_runs(..., delete_all=True)
+    """
+    preserve_states: List[str] = ["finished", "running"]
+    # Validate arguments
+    if not project_path.count("/") == 1:
+        raise ValueError("The project path must be of the form username/project_name")
+    if delete_all and (run_ids or run_names):
+        raise ValueError("delete_all mode cannot be combined with the other filter arguments")
+
+    # Authenticate
+    if api_key:
+        wandb.login(key=api_key)
+    elif not wandb.api.api_key:
+        raise wandb.errors.UsageError("An API key is required, or log in beforehand")
+
+    api = wandb.Api()
+    stats = {
+        "total": 0,
+        "candidates": 0,
+        "deleted": 0,
+        "failed": 0,
+        "dry_run": dry_run
+    }
+
+    try:
+        runs = api.runs(project_path, per_page=per_page)
+        stats["total"] = len(runs)
+
+        # Determine the deletion targets
+        if delete_all:
+            targets = runs
+            click.secho("\n⚠️ Dangerous operation: every run in the project will be deleted!", fg="red", bold=True)
+        elif run_ids or run_names:
+            targets = [
+                run for run in runs
+                if run.id in (run_ids or []) or run.name in (run_names or [])
+            ]
+            print(f"\nFound {len(targets)} specified runs")
+        else:
+            targets = [run for run in runs if run.state not in preserve_states]
+            print(f"\nFound {len(targets)} runs in abnormal states")
+
+        stats["candidates"] = len(targets)
+
+        if not targets:
+            print("No runs match the criteria")
+            return stats
+
+        # Print a preview
+        print("\nSample of runs to be deleted:")
+        for run in targets[:3]:
+            state = click.style(run.state, fg="green" if run.state == "finished" else "red")
+            print(f"  • {run.id} | {run.name} | state: {state}")
+        if len(targets) > 3:
+            print(f"  ... ({len(targets)} in total)")
+
+        # Safety confirmation
+        if dry_run:
+            click.secho("\nDry-run mode: nothing will actually be deleted", fg="yellow")
+            return stats
+
+        if delete_all:
+            msg = click.style("Really delete ALL runs? This cannot be undone!", fg="red", bold=True)
+        else:
+            msg = f"Really delete {len(targets)} runs?"
+
+        if not click.confirm(msg, default=False):
+            print("Operation cancelled")
+            return stats
+
+        # Perform the deletion
+        print("\nDeletion progress:")
+        for i, run in enumerate(targets, 1):
+            try:
+                run.delete()
+                stats["deleted"] += 1
+                print(click.style(f"  [{i}/{len(targets)}] deleted {run.id}", fg="green"))
+            except Exception as e:
+                stats["failed"] += 1
+                print(click.style(f"  [{i}/{len(targets)}] failed to delete {run.id}: {str(e)}", fg="red"))
+
+        return stats
+
+    except wandb.errors.CommError as e:
+        raise ConnectionError(f"Network error: {str(e)}")
+    except Exception as e:
+        raise RuntimeError(f"Operation failed: {str(e)}")
+
+
+def get_all_artifacts_from_project(project_path, max_runs=None, run_id=None):
+    """Fetch all artifacts of a WandB project, or of one specified run
+
+    Args:
+        project_path (str): project path in the form "entity/project"
+        max_runs (int, optional): maximum number of runs to scan (only used when run_id is not given)
+        run_id (str, optional): ID of the specific run to query
+
+    Returns:
+        list: list containing all Artifact objects
+    """
+    api = wandb.Api()
+    all_artifacts = []
+    seen_artifacts = set()  # for de-duplication
+
+    try:
+        if run_id:
+            # Single-run case
+            run = api.run(f"{project_path}/{run_id}")
+            artifacts = run.logged_artifacts()
+
+            for artifact in artifacts:
+                artifact_id = f"{artifact.name}:{artifact.version}"
+                if artifact_id not in seen_artifacts:
+                    all_artifacts.append(artifact)
+                    seen_artifacts.add(artifact_id)
+
+            print(f"Found {len(all_artifacts)} artifacts in run {run_id}")
+        else:
+            # Whole-project case
+            runs = api.runs(project_path, per_page=500)
+            run_iterator = tqdm(runs[:max_runs] if max_runs else runs,
+                                desc=f"Scanning {project_path}")
+
+            for run in run_iterator:
+                try:
+                    artifacts = run.logged_artifacts()
+                    for artifact in artifacts:
+                        artifact_id = f"{artifact.name}:{artifact.version}"
+                        if artifact_id not in seen_artifacts:
+                            all_artifacts.append(artifact)
+                            seen_artifacts.add(artifact_id)
+                except Exception as run_error:
+                    print(f"Error processing run {run.id}: {str(run_error)}")
+
+    except Exception as e:
+        print(f"Error: {str(e)}")
+        return []
+
+    return all_artifacts
+
+
+def upload_model_dataset(
+        artifact_dir: str,
+        artifact_name: str,
+        artifact_type: str) -> None:
+    run_id = f'yms_upload_{artifact_type}_' + get_current_time('%y%m%d_%H%M%S')
+    run = wandb.init(project='upload_model_dataset', name=artifact_name, id=run_id)
+    artifact = wandb.Artifact(artifact_name, artifact_type)
+    artifact.add_dir(artifact_dir)
+    run.log_artifact(artifact)
+    run.finish()
+
+
+def download_model_dataset(
+        download_name: str,
+        run_name: str,
+        artifact_type: str,
+        download_dir: str = None,
+        entity: str = 'YNA-DeepLearning'
+) -> str:
+    run_id = f'yms_download_{artifact_type}_' + get_current_time('%y%m%d_%H%M%S')
+    run = wandb.init(project='download_model_dataset', name=run_name, id=run_id)
+    artifact = run.use_artifact(entity + '/upload_model_dataset/' + download_name, type=artifact_type)
+    artifact_dir = artifact.download(root=download_dir)
+    return artifact_dir
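Taken together, the new helpers replace the old hard-coded 'val_accuracies' alias with plural-aware key matching. A hedged sketch of how they combine (not part of the diff; it assumes yms-kan 0.0.3 is installed, and the file name and metric values are illustrative):

from yms_kan.tool import initialize_results_file, append_to_results_file, is_similar_key

# is_similar_key lets plural column names match singular dict keys:
assert is_similar_key("accuracy", "accuracies")   # "ies" -> "y"
assert is_similar_key("f1-score", "f1-scores")    # trailing "s" stripped

column_order = ['epoch', 'train_losses', 'val_losses', 'accuracies', 'lrs']
initialize_results_file('results.txt', column_order)   # writes the header once
append_to_results_file('results.txt',
                       {'epoch': 1, 'train_loss': 0.4321, 'val_loss': 0.5678,
                        'accuracy': 0.9123, 'lr': 1e-3},
                       column_order,
                       custom_column_widths={'epoch': 5})

With the defaults float_precision=4 and more_float=2, the loss columns are written with 6 decimal places, 'lrs' with 8, and everything else with 4.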
yms_kan/train_eval_utils.py
CHANGED
@@ -1,25 +1,52 @@
 import math
 import os
 import sys
-from enum import Enum, auto
 
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
+from sklearn.metrics import classification_report
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from tqdm import tqdm
 
 from yms_kan import LBFGS
+from yms_kan.tool import initialize_results_file, append_to_results_file
 
 
-def
-
+def calculate_metric(all_labels, all_predictions, classes, class_metric=False, average='macro avg'):
+    metric = classification_report(y_true=all_labels, y_pred=all_predictions,
+                                   target_names=classes, digits=4, output_dict=True, zero_division=0)
+    if not class_metric:
+        metric = {
+            'accuracy': metric.get('accuracy'),
+            'precision': metric.get(average).get('precision'),
+            'recall': metric.get(average).get('recall'),
+            'f1-score': metric.get(average).get('f1-score'),
+        }
+        return metric
+    else:
+        return metric
+
+
+def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_file=None, opt="LBFGS", epochs=100,
+              lamb=0.,
+              lamb_l1=1., label=None, class_dict=None, lamb_entropy=2., lamb_coef=0.,
               lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
               stop_grid_update_step=100,
               save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
               singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n'):
-
-
+    all_predictions = []
+    all_labels = []
+    best = -1
+    column_order = ['epoch', 'train_losses', 'val_losses', 'accuracies', 'precisions', 'recalls',
+                    'f1-scores', 'lrs']
+    custom_column_widths = {'epoch': 5, 'train_loss': 12, 'val_loss': 10, 'accuracy': 10, 'precision': 9,
+                            'recall': 7,
+                            'f1-score': 8,
+                            'lr': 3}
+    if txt_file is not None:
+        initialize_results_file(txt_file, column_order)
+
     if lamb > 0. and not model.save_act:
         print('setting lamb=0. If you want to set lamb > 0, set model.save_act=True')
 
@@ -44,8 +71,8 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", ep
 
     lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)
 
-    results = {'
-               'all_labels': []}
+    results = {'train_losses': [], 'val_losses': [], 'accuracies': [], 'precisions': [], 'recalls': [], 'f1-scores': [],
+               'lrs': [], 'all_predictions': [], 'all_labels': []}
 
     steps = math.ceil(dataset['train_input'].shape[0] / batch_size)
 
@@ -118,7 +145,6 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", ep
                 optimizer.step()
             train_pbar.set_postfix(loss=train_loss.item())
 
-        # print(f'{epoch}/{epochs}:train_loss:{train_loss.item()}')
         val_loss = torch.zeros(1).to(model.device)
         with torch.no_grad():
             test_indices = np.arange(dataset['test_input'].shape[0])
@@ -142,14 +168,34 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", ep
                 diffs = torch.abs(outputs - label)
                 closest_indices = torch.argmin(diffs, dim=1)
                 closest_values = label[closest_indices]
-
-
+                all_predictions.extend(closest_values.detach().cpu().numpy())
+                all_labels.extend(batch_test_label.detach().cpu().numpy())
 
+        train_lr = lr_scheduler.get_last_lr()[0]
         lr_scheduler.step(val_loss)
 
-
-
-
+        if label is not None:
+            m = calculate_metric(all_labels, all_predictions, class_dict)
+            print(m)
+            results["accuracy"].append(m["accuracy"])
+            results["precisions"].append(m["precision"])
+            results["recalls"].append(m["recall"])
+            results["f1-scores"].append(m["f1-score"])
+            results["lrs"].append(train_lr)
+            if best < m["f1-score"]:
+                best = m["f1-score"]
+                results['all_predictions'] = all_predictions
+                results['all_labels'] = all_labels
+                if save_path is not None:
+                    model.saveckpt(path=save_path + '/' + 'model')
+            if txt_file is not None:
+                m.update({'lr': train_lr, 'epoch': epoch, 'train_loss': train_loss.item(), 'val_loss': val_loss.item()})
+                append_to_results_file(txt_file, m, column_order,
+                                       custom_column_widths=custom_column_widths)
+
+        results["train_losses"].append(train_loss.item())
+        results["val_losses"].append(val_loss.item())
+        results["regularize"].append(reg_.item())
 
         if save_fig and epoch % save_fig_freq == 0:
             model.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
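A hedged, runnable sketch of the calculate_metric helper duplicated into this module (it assumes yms-kan 0.0.3 and scikit-learn are installed; the labels and predictions are illustrative):

from yms_kan.train_eval_utils import calculate_metric

labels      = [0, 1, 1, 2, 2, 2]
predictions = [0, 1, 0, 2, 2, 1]

# class_metric=False collapses sklearn's classification_report
# to the four macro-averaged headline numbers:
print(calculate_metric(labels, predictions, classes=['cat', 'dog', 'bird']))
# e.g. {'accuracy': 0.6667 (approx.), 'precision': ..., 'recall': ..., 'f1-score': ...}

With class_metric=True, the full per-class report dict is returned instead.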
yms_kan/version.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.0.2"
+__version__ = "0.0.3"  # initial version
{yms_kan-0.0.2.dist-info → yms_kan-0.0.3.dist-info}/RECORD
CHANGED
@@ -9,14 +9,14 @@ yms_kan/experiment.py,sha256=VnZq7hmcvRk08GNI7VIpkOjkaRZBoIw1C8SU_f1KbaA,1682
 yms_kan/feynman.py,sha256=Eisf69K49s4C6UlPEi5LnNK_p5TUJQLBKxMp-sW0a9w,33687
 yms_kan/hypothesis.py,sha256=Ec20xadfgOSSWeZHQaGn-h9F2PY7LWFU3iniNI2Zd_4,23165
 yms_kan/spline.py,sha256=ZXyGwl2Sc-UrnrcuUXeUQkBOMnetaWcHrbpZaqatCvs,4345
-yms_kan/tool.py,sha256=
-yms_kan/train_eval_utils.py,sha256=
+yms_kan/tool.py,sha256=rkRpqF3EcsAq7a3k1F1zKlxfJ4U9n-FzHyNCJgN4URY,21159
+yms_kan/train_eval_utils.py,sha256=y5eI6-kJU51pKTgB3TdwGyu1QKTACwbamZ9ZOdhPogc,17184
 yms_kan/utils.py,sha256=J07L-tgmc1OfU6Tl6mGwHJRizjFN75EJK8BxejaZLUc,23860
-yms_kan/version.py,sha256=
+yms_kan/version.py,sha256=ue5T-H1rqmrk8ISYQmYosD_ZfIp5J-L-wsfBrW8sgCw,39
 yms_kan/assets/img/mult_symbol.png,sha256=2f4xUKdweft-qUbHjFI5h9-smnEtc0FWq8hNYZhPAXY,6392
 yms_kan/assets/img/sum_symbol.png,sha256=94QkMUzmEjlCq_yf14nMEQmettaq86FmlGfdl22b4XE,6210
-yms_kan-0.0.2.dist-info/licenses/LICENSE,sha256=BJXDWyF4Groqtnp4Gi9puH4aLg7A2IC3MpHmC-cSxwc,1067
-yms_kan-0.0.2.dist-info/METADATA,sha256=
-yms_kan-0.0.2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-yms_kan-0.0.2.dist-info/top_level.txt,sha256=Z_JDh6yZf-EiW1eKgL6ADsN2yqEMRMspi-o29JZ1WPo,8
-yms_kan-0.0.2.dist-info/RECORD,,
+yms_kan-0.0.3.dist-info/licenses/LICENSE,sha256=BJXDWyF4Groqtnp4Gi9puH4aLg7A2IC3MpHmC-cSxwc,1067
+yms_kan-0.0.3.dist-info/METADATA,sha256=o53cYpZ1jV7K8ptCYWV5aG-jsekuKcb6wuzJW_sxsWo,240
+yms_kan-0.0.3.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+yms_kan-0.0.3.dist-info/top_level.txt,sha256=Z_JDh6yZf-EiW1eKgL6ADsN2yqEMRMspi-o29JZ1WPo,8
+yms_kan-0.0.3.dist-info/RECORD,,
{yms_kan-0.0.2.dist-info → yms_kan-0.0.3.dist-info}/WHEEL
File without changes
{yms_kan-0.0.2.dist-info → yms_kan-0.0.3.dist-info}/licenses/LICENSE
File without changes
{yms_kan-0.0.2.dist-info → yms_kan-0.0.3.dist-info}/top_level.txt
File without changes