yms-kan 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
yms_kan/tool.py CHANGED
@@ -1,8 +1,155 @@
  import os
+ import re
+ from datetime import datetime, timezone, timedelta
+ from typing import Optional, Dict, List

+ import click
  import numpy as np
- import torch
+ import pandas as pd
  import wandb
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, \
+     classification_report
+ from tqdm import tqdm
+
+
+ # Read the data of two different tables inside a txt file and return the results as a list of dicts
+ def read_multi_table_txt(file_path):
+     # Read the raw contents
+     with open(file_path, 'r') as f:
+         content = f.read()
+
+     # Split the content on table headers (assumes every new table starts with "epoch")
+     table_blocks = re.split(r'\n(?=epoch\s)', content.strip())
+
+     # Process each table block
+     table_dicts = []
+     for block in table_blocks:
+         lines = [line.strip() for line in block.split('\n') if line.strip()]
+
+         # Parse the column names (handles tabs and mixed spaces)
+         columns = re.split(r'\s{2,}|\t', lines[0])
+
+         # Parse the data rows (handles mixed separators)
+         data = []
+         for line in lines[1:]:
+             # Split on runs of two or more spaces, or on tabs
+             row = re.split(r'\s{2,}|\t', line)
+             data.append(row)
+
+         # Build a DataFrame and convert values to numeric types
+         df = pd.DataFrame(data, columns=columns)
+         df = df.apply(pd.to_numeric, errors='coerce')  # numeric columns are detected automatically; non-numeric values become NaN
+
+         # Convert the DataFrame to a dict, storing each column as a list
+         table_dict = df.to_dict(orient='list')
+         table_dicts.append(table_dict)
+
+     return table_dicts
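A minimal usage sketch for read_multi_table_txt; the file name and contents below are hypothetical, each table starting with a header row that begins with "epoch":

```python
# results.txt (hypothetical):
#   epoch  train_loss  val_loss
#   1      0.5231      0.6012
#   2      0.4107      0.5533
#   epoch  accuracy
#   1      0.8100
tables = read_multi_table_txt('results.txt')
print(tables[0]['train_loss'])  # [0.5231, 0.4107]
print(tables[1]['accuracy'])    # [0.81]
```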
+
+
+ def get_current_time(format_str="%Y-%m-%d %H:%M:%S"):
+     """
+     Get the current time in UTC+8 and return it as a string in the given format
+     :param format_str: time format (defaults to "%Y-%m-%d %H:%M:%S")
+     :return: the formatted time string
+     """
+
+     # Create the UTC+8 timezone object
+     utc8_timezone = timezone(timedelta(hours=8))
+
+     # Convert to UTC+8 time
+     utc8_time = datetime.now(utc8_timezone)
+
+     # Format as a string
+     formatted_time = utc8_time.strftime(format_str)
+     return formatted_time
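A quick sketch of the timestamps this yields; the compact second form is the one the upload/download helpers later in this file use to build run IDs (example outputs are illustrative):

```python
print(get_current_time())                  # e.g. '2025-04-05 17:30:00'
stamp = get_current_time('%y%m%d_%H%M%S')  # e.g. '250405_173000'
```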
+
+
+ # Computation of the result metrics for the val and test phases
+ def calculate_results(all_labels, all_predictions, classes, average='macro'):
+     results = {
+         'accuracy': accuracy_score(y_true=all_labels, y_pred=all_predictions),
+         'precision': precision_score(y_true=all_labels, y_pred=all_predictions, average=average),
+         'recall': recall_score(y_true=all_labels, y_pred=all_predictions, average=average),
+         'f1_score': f1_score(y_true=all_labels, y_pred=all_predictions, average=average),
+         'cm': confusion_matrix(y_true=all_labels, y_pred=all_predictions, labels=np.arange(len(classes)))
+     }
+     return results
+
+
+ def calculate_metric(all_labels, all_predictions, classes, class_metric=False, average='macro avg'):
+     metric = classification_report(y_true=all_labels, y_pred=all_predictions,
+                                    target_names=classes, digits=4, output_dict=True, zero_division=0)
+     if not class_metric:
+         metric = {
+             'accuracy': metric.get('accuracy'),
+             'precision': metric.get(average).get('precision'),
+             'recall': metric.get(average).get('recall'),
+             'f1-score': metric.get(average).get('f1-score'),
+         }
+     return metric
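A hedged usage sketch with made-up labels; class_metric=False collapses sklearn's report into the four summary numbers the training loop logs:

```python
labels      = [0, 1, 2, 2, 1, 0]
predictions = [0, 1, 1, 2, 1, 0]
summary = calculate_metric(labels, predictions, classes=['cat', 'dog', 'bird'])
# -> {'accuracy': ..., 'precision': ..., 'recall': ..., 'f1-score': ...} (macro-averaged)
```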
+
+
+ def dict_to_classification_report(report_dict, digits=2):
+     headers = ["precision", "recall", "f1-score", "support"]
+     average_options = ["micro avg", "macro avg", "weighted avg"]
+     if 'samples avg' in report_dict:
+         average_options.append('samples avg')
+     # Per-class rows only: 'accuracy' and the average rows are rendered separately below
+     target_names = [name for name in report_dict if name != 'accuracy' and name not in average_options]
+     longest_last_line_heading = "weighted avg"
+     name_width = max(len(cn) for cn in target_names)
+     width = max(name_width, len(longest_last_line_heading), digits)
+     head_fmt = "{:>{width}s} " + " {:>9}" * len(headers)
+     report = head_fmt.format("", *headers, width=width)
+     report += "\n\n"
+     row_fmt = "{:>{width}s} " + " {:>9.{digits}f}" * 3 + " {:>9}\n"
+     for target_name in target_names:
+         scores = [report_dict[target_name][h] for h in headers]
+         report += row_fmt.format(target_name, *scores, width=width, digits=digits)
+     report += "\n"
+
+     for average in average_options:
+         if average in report_dict:
+             scores = [report_dict[average][h] for h in headers]
+             report += row_fmt.format(average, *scores, width=width, digits=digits)
+
+     if 'accuracy' in report_dict:
+         row_fmt_accuracy = (
+             "{:>{width}s} "
+             + " {:>9.{digits}}" * 2
+             + " {:>9.{digits}f}"
+             + " {:>9}\n"
+         )
+         report += row_fmt_accuracy.format(
+             "accuracy", "", "", report_dict["accuracy"], "", width=width, digits=digits
+         )
+
+     return report
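Round-trip sketch: the output_dict-style report produced by calculate_metric(..., class_metric=True) renders back into sklearn's familiar text table (labels and classes as in the earlier example):

```python
report_dict = calculate_metric(labels, predictions, classes=['cat', 'dog', 'bird'],
                               class_metric=True)
print(dict_to_classification_report(report_dict, digits=4))
```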
+
+
+ # def append_metrics(metrics, metric, result, lr):
+ #     metrics['train_losses'].append(result['train_loss'])
+ #     metrics['val_losses'].append(result['val_loss'])
+ #     metrics['accuracies'].append(metric['accuracy'])
+ #     metrics['precisions'].append(metric['precision'])
+ #     metrics['recalls'].append(metric['recall'])
+ #     metrics['f1-scores'].append(metric['f1-score'])
+ #     metrics['lrs'].append(lr)
+ #     return metrics


  def initialize_results_file(results_file, result_info):
@@ -12,9 +159,10 @@ def initialize_results_file(results_file, result_info):
      Args:
          results_file (str): path of the results file.
          result_info (list): the first-line (header) contents to write.
+         space: spacing between column names (defaults to the width of two spaces)
      """
      # Process result_info, adding two spaces after each column name
-     result_info_str = "  ".join(result_info) + '\n'
+     result_info_str = '  '.join(result_info) + '\n'
      # Check whether the file exists
      if os.path.exists(results_file):
          # If the file exists, read its first line
@@ -35,86 +183,45 @@ def initialize_results_file(results_file, result_info):
          print(f"File {results_file} has been created and result_info written.")


- def write_results_file(file_path: str,
-                        data_dict: dict,
-                        column_order: list,
-                        float_precision: int = 5) -> None:
+ def is_similar_key(key1, key2):
      """
-     Generic formatted text-line writer (supports list-valued data)
+     Check whether two keys are similar, taking plural-form conversions into account.

-     Args:
-         file_path: target file path
-         data_dict: dict containing the data; keys are column names, values are lists
-         column_order: list giving the column order; its elements are dict keys
-         float_precision: number of float digits (default 5)
+     Args:
+         key1 (str): the first key
+         key2 (str): the second key
+
+     Returns:
+         bool: True if the two keys are similar (including plural-form conversions), otherwise False
      """
-     # Validate the data format
-     rows = None
-     for key in data_dict:
-         if not isinstance(data_dict[key], list):
-             raise ValueError(f"Value for key '{key}' is not a list")
-         if rows is None:
-             rows = len(data_dict[key])
-         else:
-             if len(data_dict[key]) != rows:
-                 raise ValueError("All lists in data_dict must have the same length")
+     if key1 == key2:
+         return True

-     # Helper: format a single value
-     def format_value(value, column_name):
-         if isinstance(value, (int, np.integer)):
-             return f"{value:d}"
-         elif isinstance(value, (float, np.floating)):
-             if column_name in ['train_losses', 'val_losses']:
-                 return f"{value:.{float_precision + 1}f}"
-             elif column_name == 'lrs':
-                 return f"{value:.8f}"
-             else:
-                 return f"{value:.{float_precision}f}"
-         elif isinstance(value, str):
-             return value
-         else:
-             return str(value)
+     # Check whether key2 is a plural form
+     if key2.endswith("ies"):
+         singular_candidate = key2.removesuffix("ies") + "y"
+         if key1 == singular_candidate:
+             return True

-     # Compute the column widths
-     column_widths = []
-     for col in column_order:
-         dict_key = 'val_accuracies' if col == 'accuracies' else col
-         if dict_key not in data_dict:
-             raise ValueError(f"Missing required column: {dict_key}")
-         values = data_dict[dict_key]
-
-         max_width = len(col)
-         for val in values:
-             fmt_val = format_value(val, col)
-             max_width = max(max_width, len(fmt_val))
-         column_widths.append(max_width)
-
-     # Generate the formatted lines
-     lines = []
-     for i in range(rows):
-         row = []
-         for j, col in enumerate(column_order):
-             dict_key = 'val_accuracies' if col == 'accuracies' else col
-             val = data_dict[dict_key][i]
-             fmt_val = format_value(val, col)
-
-             # Alignment
-             if j == len(column_order) - 1:
-                 fmt_val = fmt_val.ljust(column_widths[j])
-             else:
-                 fmt_val = fmt_val.rjust(column_widths[j])
-             row.append(fmt_val)
-         lines.append("  ".join(row) + '\n')
+     if key2.endswith("es"):
+         singular_candidate = key2.removesuffix("es")
+         if key1 == singular_candidate:
+             return True

-     # Write to the file
-     with open(file_path, 'a', encoding='utf-8') as f:
-         f.writelines(lines)
+     if key2.endswith("s"):
+         singular_candidate = key2.removesuffix("s")
+         if key1 == singular_candidate:
+             return True
+
+     return False
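Sanity-check sketch of the pluralization rules; the argument order is data_dict key first, column name second, matching how append_to_results_file below calls it:

```python
assert is_similar_key('accuracy', 'accuracies')   # "ies" -> "y"
assert is_similar_key('val_loss', 'val_losses')   # strip "es"
assert is_similar_key('lr', 'lrs')                # strip "s"
assert not is_similar_key('recall', 'precision')
```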


  def append_to_results_file(file_path: str,
                             data_dict: dict,
                             column_order: list,
-                            float_precision: int = 5) -> None:
+                            float_precision: int = 4,
+                            more_float: int = 2,
+                            custom_column_widths: dict = None) -> None:
      """
      Generic formatted text-line writer

@@ -123,182 +230,340 @@ def append_to_results_file(file_path: str,
          data_dict: dict containing the data; keys are column names
          column_order: list giving the column order; its elements are dict keys
          float_precision: number of float digits (default 4)
+         more_float: extra float-precision digits
+         custom_column_widths: dict of custom column widths; keys are column names, values are widths
      """
-     # Check whether the values in data_dict are lists
-     all_values_are_lists = all(isinstance(value, list) for value in data_dict.values())
-     if all_values_are_lists:
-         num_rows = len(next(iter(data_dict.values())))
-         # Process row by row
-         for row_index in range(num_rows):
-             formatted_data = []
-             column_widths = []
-             for col in column_order:
-                 # Handle dict-key aliases
-                 dict_key = 'val_accuracies' if col == 'accuracies' else col
-                 # Skip the column if the key is missing
-                 if dict_key not in data_dict:
-                     continue
-                 value_list = data_dict[dict_key]
-                 if row_index >= len(value_list):
-                     continue
-                 value = value_list[row_index]
-
-                 # Format according to the data type
-                 if isinstance(value, (int, np.integer)):
-                     fmt_value = f"{value:d}"
-                 elif isinstance(value, (float, np.floating)):
-                     if col in ['train_losses', 'val_losses']:  # 'train_losses'/'val_losses' keep float_precision + 1 digits
-                         fmt_value = f"{value:.{float_precision + 1}f}"
-                     elif col == 'lrs':  # the 'lrs' column keeps 8 decimal places
-                         fmt_value = f"{value:.8f}"
-                     else:
-                         fmt_value = f"{value:.{float_precision}f}"
-                 elif isinstance(value, str):
-                     fmt_value = value
-                 else:  # anything else is converted to a string
-                     fmt_value = str(value)
-
-                 # Column width is the max of the column-name length and the value length
-                 column_width = max(len(col), len(fmt_value))
-                 column_widths.append(column_width)
-
-                 # Apply column-width alignment
-                 if col == column_order[-1]:  # the last column is left-aligned
-                     fmt_value = fmt_value.ljust(column_width)
-                 else:
-                     fmt_value = fmt_value.rjust(column_width)
-
-                 formatted_data.append(fmt_value)
-
-             # Build the text line and write it, columns separated by two spaces
-             if formatted_data:
-                 line = "  ".join(formatted_data) + '\n'
-                 with open(file_path, 'a', encoding='utf-8') as f:
-                     f.write(line)
-     else:
-         # Non-list case, original logic
-         # Compute the maximum width of each column
-         column_widths = []
-         formatted_data = []
-         for col in column_order:
-             # Handle dict-key aliases
-             dict_key = 'val_accuracies' if col == 'accuracies' else col
-             # Skip the column if the key is missing
-             if dict_key not in data_dict:
-                 continue
-
-             value = data_dict[dict_key]
-
-             # Format according to the data type
-             if isinstance(value, (int, np.integer)):
-                 fmt_value = f"{value:d}"
-             elif isinstance(value, (float, np.floating)):
-                 if col in ['train_losses', 'val_losses']:  # 'train_losses'/'val_losses' keep float_precision + 1 digits
-                     fmt_value = f"{value:.{float_precision + 1}f}"
-                 elif col == 'lrs':  # the 'lrs' column keeps 8 decimal places
-                     fmt_value = f"{value:.8f}"
-                 else:
-                     fmt_value = f"{value:.{float_precision}f}"
-             elif isinstance(value, str):
-                 fmt_value = value
-             else:  # anything else is converted to a string
-                 fmt_value = str(value)
+     # Compute the maximum width of each column
+     column_widths = []
+     formatted_data = []
+     for col in column_order:
+         # Find a similar key in data_dict
+         dict_key = None
+         for key in data_dict:
+             if is_similar_key(key, col):
+                 dict_key = key
+                 break
+         if dict_key is None:
+             raise ValueError(f"Missing required column: {col}")
+
+         value = data_dict[dict_key]
+
+         # Format according to the data type
+         if isinstance(value, (int, np.integer)):
+             fmt_value = f"{value:d}"
+         elif isinstance(value, (float, np.floating)):
+             if col in ['train_losses', 'val_losses']:  # 'train_losses'/'val_losses' keep float_precision + more_float digits
+                 fmt_value = f"{value:.{float_precision + more_float}f}"
+             elif col == 'lrs':
+                 fmt_value = f"{value:.8f}"
+             else:
+                 fmt_value = f"{value:.{float_precision}f}"
+         elif isinstance(value, str):
+             fmt_value = value
+         else:  # anything else is converted to a string
+             fmt_value = str(value)

+         # Determine the column width
+         if custom_column_widths and col in custom_column_widths:
+             column_width = custom_column_widths[col]
+         else:
              # Column width is the max of the column-name length and the value length
              column_width = max(len(col), len(fmt_value))
-             column_widths.append(column_width)
+         column_widths.append(column_width)

-             # Apply column-width alignment
-             if col == column_order[-1]:  # the last column is left-aligned
-                 fmt_value = fmt_value.ljust(column_width)
-             else:
-                 fmt_value = fmt_value.rjust(column_width)
-
-             formatted_data.append(fmt_value)
-
-         # Build the text line and write it, columns separated by two spaces
-         if formatted_data:
-             line = "  ".join(formatted_data) + '\n'
-             with open(file_path, 'a', encoding='utf-8') as f:
-                 f.write(line)
-
-
- # def append_to_results_file(file_path: str,
- #                            data_dict: dict,
- #                            column_order: list,
- #                            column_widths: list = None,
- #                            float_precision: int = 5) -> None:
- #     """
- #     Generic formatted text-line writer
- #
- #     Args:
- #         file_path: target file path
- #         data_dict: dict containing the data; keys are column names
- #         column_order: list giving the column order; its elements are dict keys
- #         column_widths: list of character widths per column (optional)
- #         float_precision: number of float digits (default 4)
- #     """
- #     formatted_data = []
- #
- #     # Process the data in the specified column order
- #     for i, col in enumerate(column_order):
- #         # Handle dict-key aliases
- #         if col == 'accuracies':
- #             dict_key = 'val_accuracies'
- #         else:
- #             dict_key = col
- #
- #         if dict_key not in data_dict:
- #             raise ValueError(f"Missing required column: {dict_key}")
- #
- #         value = data_dict[dict_key]
- #
- #         # Format according to the data type
- #         if isinstance(value, (int, np.integer)):
- #             fmt_value = f"{value:d}"
- #         elif isinstance(value, (float, np.floating)):
- #             if col in ['train_losses', 'val_losses']:  # 'train_losses'/'val_losses' keep float_precision + 1 digits
- #                 fmt_value = f"{value:.{float_precision + 1}f}"
- #             elif col == 'lr':  # the 'lr' column keeps 8 decimal places
- #                 fmt_value = f"{value:.8f}"
- #             else:
- #                 fmt_value = f"{value:.{float_precision}f}"
- #         elif isinstance(value, str):
- #             fmt_value = value
- #         else:  # anything else is converted to a string
- #             fmt_value = str(value)
- #
- #         # Apply column-width alignment
- #         if column_widths and i < len(column_widths):
- #             try:
- #                 if i == len(column_order) - 1:  # the last column is left-aligned
- #                     fmt_value = fmt_value.ljust(column_widths[i])
- #                 else:
- #                     fmt_value = fmt_value.rjust(column_widths[i])
- #             except TypeError:  # handle non-string types
- #                 if i == len(column_order) - 1:  # the last column is left-aligned
- #                     fmt_value = str(fmt_value).ljust(column_widths[i])
- #                 else:
- #                     fmt_value = str(fmt_value).rjust(column_widths[i])
- #
- #         formatted_data.append(fmt_value)
- #
- #     # Build the text line and write it
- #     line = '\t'.join(formatted_data) + '\n'
- #     with open(file_path, 'a', encoding='utf-8') as f:
- #         f.write(line)
-
-
- def get_wandb_key(key_path='tools/wandb_key.txt'):
+         # Apply column-width alignment
+         if col == column_order[-1]:  # the last column is left-aligned
+             fmt_value = fmt_value.ljust(column_width)
+         else:
+             fmt_value = fmt_value.rjust(column_width)
+
+         formatted_data.append(fmt_value)
+
+     # Build the text line and write it, columns separated by two spaces
+     line = "  ".join(formatted_data) + '\n'
+     with open(file_path, 'a', encoding='utf-8') as f:
+         f.write(line)
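A hedged end-to-end sketch of the two file helpers (paths and values are hypothetical); initialize_results_file writes the header once, and append_to_results_file appends one aligned row per call, resolving singular keys against plural column names via is_similar_key:

```python
columns = ['epoch', 'train_losses', 'val_losses', 'lrs']
initialize_results_file('results.txt', columns)
row = {'epoch': 1, 'train_loss': 0.523163, 'val_loss': 0.601244, 'lr': 0.001}
append_to_results_file('results.txt', row, columns)
```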
+
+
+ def get_wandb_key(key_path):
      with open(key_path, 'r', encoding='utf-8') as f:
          key = f.read()
      return key


- def wandb_use(project=None, name=None, key_path='tools/wandb_key.txt'):
+ def wandb_init(project=None, key_path=None, name=None):
      run = None
      if project is not None:
+         if key_path is None:
+             raise ValueError("When 'project' is not None, 'key_path' should also not be None.")
          wandb_key = get_wandb_key(key_path)
          wandb.login(key=wandb_key)
          run = wandb.init(project=project, name=name)
      return run
+
304
+
305
+ def check_wandb_login_required():
306
+ """兼容旧版的登录检查函数"""
307
+ # 优先检查环境变量
308
+ if os.environ.get("WANDB_API_KEY"):
309
+ return False
310
+
311
+ try:
312
+ api = wandb.Api()
313
+ # 方法 1:通过 settings 检查(适用于旧版)
314
+ if hasattr(api, "settings") and api.settings.get("entity"):
315
+ return False
316
+
317
+ # 方法 2:通过 projects() 验证(通用性强)
318
+ api.projects(per_page=1) # 仅请求第一页的第一个项目
319
+ return False
320
+ except Exception as e:
321
+ print(f"检测到意外错误: {str(e)}")
322
+ return True # 保守返回需要登录
323
+
324
+
325
+ def get_wandb_runs(
326
+ project_path: str,
327
+ default_name: str = "未命名",
328
+ api_key: Optional[str] = None,
329
+ per_page: int = 1000
330
+ ) -> List[Dict[str, str]]:
331
+ """
332
+ 获取指定 WandB 项目的所有运行信息(ID 和 Name)
333
+
334
+ Args:
335
+ project_path (str): 项目路径,格式为 "username/project_name"
336
+ default_name (str): 当运行未命名时的默认显示名称(默认:"未命名")
337
+ api_key (str, optional): WandB API 密钥,若未设置环境变量则需传入
338
+ per_page (int): 分页查询每页数量(默认1000,用于处理大量运行)
339
+
340
+ Returns:
341
+ List[Dict]: 包含运行信息的字典列表,格式 [{"id": "...", "name": "..."}]
342
+
343
+ Raises:
344
+ ValueError: 项目路径格式错误
345
+ wandb.errors.UsageError: API 密钥无效或未登录
346
+ """
347
+ # 参数校验
348
+ if "/" not in project_path or len(project_path.split("/")) != 2:
349
+ raise ValueError("项目路径格式应为 'username/project_name'")
350
+
351
+ # 登录(仅在需要时)
352
+ if api_key:
353
+ wandb.login(key=api_key)
354
+ elif not wandb.api.api_key:
355
+ raise wandb.errors.UsageError("需要提供API密钥或预先调用wandb.login()")
356
+
357
+ # 初始化API
358
+ api = wandb.Api()
359
+
360
+ try:
361
+ # 分页获取所有运行(自动处理分页逻辑)
362
+ runs = api.runs(project_path, per_page=per_page)
363
+ print(f'共获取{len(runs)}个run')
364
+ return [
365
+ {
366
+ "id": run.id,
367
+ "name": run.name or default_name,
368
+ "url": run.url, # 增加实用字段
369
+ "state": run.state # 包含运行状态
370
+ }
371
+ for run in runs
372
+ ]
373
+
374
+ except wandb.errors.CommError as e:
375
+ raise ConnectionError(f"连接失败: {str(e)}") from e
376
+ except Exception as e:
377
+ raise RuntimeError(f"获取运行数据失败: {str(e)}") from e
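List-runs sketch (entity and project path are hypothetical):

```python
runs = get_wandb_runs('my-team/kan-experiments', api_key='...')
for info in runs:
    print(info['id'], info['name'], info['state'])
```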
+
+
+ def delete_runs(
+         project_path: str,
+         run_ids: Optional[List[str]] = None,
+         run_names: Optional[List[str]] = None,
+         delete_all: bool = False,
+         dry_run: bool = True,
+         api_key: Optional[str] = None,
+         per_page: int = 500
+ ) -> dict:
+     """
+     Multi-purpose WandB run-deletion tool
+
+     :param project_path: project path (format: username/project_name)
+     :param run_ids: list of run IDs to delete (regardless of state)
+     :param run_names: list of run names to delete (regardless of state)
+     # :param preserve_states: list of protected states (finished/running are protected by default)
+     :param delete_all: danger mode! delete every run (default False)
+     :param dry_run: simulation mode (default True)
+     :param api_key: WandB API key
+     :param per_page: page size for paginated queries
+     :return: dict of operation statistics
+
+     Use cases:
+         1. Delete specific runs: delete_runs(..., run_ids=["abc", "def"])
+         2. Default, delete failed runs: delete_runs(...)
+         3. Delete all runs: delete_runs(..., delete_all=True)
+     """
+     preserve_states: List[str] = ["finished", "running"]
+     # Parameter validation
+     if not project_path.count("/") == 1:
+         raise ValueError("project_path must have the form username/project_name")
+     if delete_all and (run_ids or run_names):
+         raise ValueError("delete_all mode cannot be combined with the other filter parameters")
+
+     # Authentication
+     if api_key:
+         wandb.login(key=api_key)
+     elif not wandb.api.api_key:
+         raise wandb.errors.UsageError("An API key or a prior login is required")
+
+     api = wandb.Api()
+     stats = {
+         "total": 0,
+         "candidates": 0,
+         "deleted": 0,
+         "failed": 0,
+         "dry_run": dry_run
+     }
+
+     try:
+         runs = api.runs(project_path, per_page=per_page)
+         stats["total"] = len(runs)
+
+         # Determine the deletion targets
+         if delete_all:
+             targets = runs
+             click.secho("\n⚠️ Dangerous operation: every run in the project will be deleted!", fg="red", bold=True)
+         elif run_ids or run_names:
+             targets = [
+                 run for run in runs
+                 if run.id in (run_ids or []) or run.name in (run_names or [])
+             ]
+             print(f"\nFound {len(targets)} matching runs")
+         else:
+             targets = [run for run in runs if run.state not in preserve_states]
+             print(f"\nFound {len(targets)} runs in abnormal states")
+
+         stats["candidates"] = len(targets)
+
+         if not targets:
+             print("No runs match the criteria")
+             return stats
+
+         # Print a preview
+         print("\nSample of the runs to delete:")
+         for run in targets[:3]:
+             state = click.style(run.state, fg="green" if run.state == "finished" else "red")
+             print(f"  • {run.id} | {run.name} | state: {state}")
+         if len(targets) > 3:
+             print(f"  ... ({len(targets)} in total)")
+
+         # Safety confirmation
+         if dry_run:
+             click.secho("\nDry-run mode: nothing will actually be deleted", fg="yellow")
+             return stats
+
+         if delete_all:
+             msg = click.style("Really delete ALL runs? This cannot be undone!", fg="red", bold=True)
+         else:
+             msg = f"Really delete {len(targets)} runs?"
+
+         if not click.confirm(msg, default=False):
+             print("Operation cancelled")
+             return stats
+
+         # Perform the deletion
+         print("\nDeletion progress:")
+         for i, run in enumerate(targets, 1):
+             try:
+                 run.delete()
+                 stats["deleted"] += 1
+                 print(click.style(f"  [{i}/{len(targets)}] deleted {run.id}", fg="green"))
+             except Exception as e:
+                 stats["failed"] += 1
+                 print(click.style(f"  [{i}/{len(targets)}] failed to delete {run.id}: {str(e)}", fg="red"))
+
+         return stats
+
+     except wandb.errors.CommError as e:
+         raise ConnectionError(f"Network error: {str(e)}") from e
+     except Exception as e:
+         raise RuntimeError(f"Operation failed: {str(e)}") from e
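Deletion sketch (project path hypothetical); the default dry_run=True makes the first call a safe preview of the failed/crashed runs that would be removed:

```python
stats = delete_runs('my-team/kan-experiments')  # preview only
print(stats)  # {'total': ..., 'candidates': ..., 'deleted': 0, 'failed': 0, 'dry_run': True}
stats = delete_runs('my-team/kan-experiments', dry_run=False)  # delete after confirming
```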
+
+
+ def get_all_artifacts_from_project(project_path, max_runs=None, run_id=None):
+     """Fetch every Artifact of a WandB project, or of one specific run
+
+     Args:
+         project_path (str): project path in the form "entity/project"
+         max_runs (int, optional): maximum number of runs to fetch (only used when run_id is not given)
+         run_id (str, optional): ID of a specific run to query
+
+     Returns:
+         list: list containing all Artifact objects
+     """
+     api = wandb.Api()
+     all_artifacts = []
+     seen_artifacts = set()  # used for de-duplication
+
+     try:
+         if run_id:
+             # Single-run case
+             run = api.run(f"{project_path}/{run_id}")
+             artifacts = run.logged_artifacts()
+
+             for artifact in artifacts:
+                 artifact_id = f"{artifact.name}:{artifact.version}"
+                 if artifact_id not in seen_artifacts:
+                     all_artifacts.append(artifact)
+                     seen_artifacts.add(artifact_id)
+
+             print(f"Found {len(all_artifacts)} artifacts in run {run_id}")
+         else:
+             # Whole-project case
+             runs = api.runs(project_path, per_page=500)
+             run_iterator = tqdm(runs[:max_runs] if max_runs else runs,
+                                 desc=f"Scanning {project_path}")
+
+             for run in run_iterator:
+                 try:
+                     artifacts = run.logged_artifacts()
+                     for artifact in artifacts:
+                         artifact_id = f"{artifact.name}:{artifact.version}"
+                         if artifact_id not in seen_artifacts:
+                             all_artifacts.append(artifact)
+                             seen_artifacts.add(artifact_id)
+                 except Exception as run_error:
+                     print(f"Error processing run {run.id}: {str(run_error)}")
+
+     except Exception as e:
+         print(f"Error: {str(e)}")
+         return []
+
+     return all_artifacts
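Artifact-listing sketch (project path hypothetical):

```python
artifacts = get_all_artifacts_from_project('my-team/kan-experiments', max_runs=10)
for art in artifacts:
    print(art.name, art.version, art.type)
```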
+
+
+ def upload_model_dataset(
+         artifact_dir: str,
+         artifact_name: str,
+         artifact_type: str) -> None:
+     run_id = f'yms_upload_{artifact_type}_' + get_current_time('%y%m%d_%H%M%S')
+     run = wandb.init(project='upload_model_dataset', name=artifact_name, id=run_id)
+     artifact = wandb.Artifact(artifact_name, artifact_type)
+     artifact.add_dir(artifact_dir)
+     run.log_artifact(artifact)
+     run.finish()
+
+
+ def download_model_dataset(
+         download_name: str,
+         run_name: str,
+         artifact_type: str,
+         download_dir: str = None,
+         entity: str = 'YNA-DeepLearning'
+ ) -> str:
+     run_id = f'yms_download_{artifact_type}_' + get_current_time('%y%m%d_%H%M%S')
+     run = wandb.init(project='download_model_dataset', name=run_name, id=run_id)
+     artifact = run.use_artifact(entity + '/upload_model_dataset/' + download_name, type=artifact_type)
+     artifact_dir = artifact.download(root=download_dir)
+     return artifact_dir
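Upload/download round-trip sketch (directory, names, and entity are hypothetical); note that download_model_dataset always pulls from the fixed upload_model_dataset project under the given entity:

```python
upload_model_dataset('./checkpoints', artifact_name='kan_model', artifact_type='model')
local_dir = download_model_dataset(download_name='kan_model:latest', run_name='fetch_kan',
                                   artifact_type='model', download_dir='./artifacts')
print(local_dir)
```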
yms_kan/train_eval_utils.py CHANGED
@@ -1,25 +1,52 @@
  import math
  import os
  import sys
- from enum import Enum, auto

  import numpy as np
  import torch
  from matplotlib import pyplot as plt
+ from sklearn.metrics import classification_report
  from torch.optim.lr_scheduler import ReduceLROnPlateau
  from tqdm import tqdm

  from yms_kan import LBFGS
+ from yms_kan.tool import initialize_results_file, append_to_results_file


- def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", epochs=100, lamb=0.,
-               lamb_l1=1., label=None, lamb_entropy=2., lamb_coef=0.,
+ def calculate_metric(all_labels, all_predictions, classes, class_metric=False, average='macro avg'):
+     metric = classification_report(y_true=all_labels, y_pred=all_predictions,
+                                    target_names=classes, digits=4, output_dict=True, zero_division=0)
+     if not class_metric:
+         metric = {
+             'accuracy': metric.get('accuracy'),
+             'precision': metric.get(average).get('precision'),
+             'recall': metric.get(average).get('recall'),
+             'f1-score': metric.get(average).get('f1-score'),
+         }
+     return metric
+
+
+ def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_file=None, opt="LBFGS", epochs=100,
+               lamb=0.,
+               lamb_l1=1., label=None, class_dict=None, lamb_entropy=2., lamb_coef=0.,
                lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
                stop_grid_update_step=100,
                save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
                singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n'):
-     # all_predictions = []
-     # all_labels = []
+     all_predictions = []
+     all_labels = []
+     best = -1
+     column_order = ['epoch', 'train_losses', 'val_losses', 'accuracies', 'precisions', 'recalls',
+                     'f1-scores', 'lrs']
+     custom_column_widths = {'epoch': 5, 'train_loss': 12, 'val_loss': 10, 'accuracy': 10, 'precision': 9,
+                             'recall': 7,
+                             'f1-score': 8,
+                             'lr': 3}
+     if txt_file is not None:
+         initialize_results_file(txt_file, column_order)
+
      if lamb > 0. and not model.save_act:
          print('setting lamb=0. If you want to set lamb > 0, set model.save_act=True')

@@ -44,8 +71,8 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", ep

      lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)

-     results = {'train_loss': .0, 'val_loss': .0, 'regularize': .0, 'all_predictions': [],
-                'all_labels': []}
+     results = {'train_losses': [], 'val_losses': [], 'regularize': [], 'accuracies': [], 'precisions': [],
+                'recalls': [], 'f1-scores': [], 'lrs': [], 'all_predictions': [], 'all_labels': []}

      steps = math.ceil(dataset['train_input'].shape[0] / batch_size)

@@ -118,7 +145,6 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", ep
              optimizer.step()
              train_pbar.set_postfix(loss=train_loss.item())

-         # print(f'{epoch}/{epochs}:train_loss:{train_loss.item()}')
          val_loss = torch.zeros(1).to(model.device)
          with torch.no_grad():
              test_indices = np.arange(dataset['test_input'].shape[0])
@@ -142,14 +168,34 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", ep
                  diffs = torch.abs(outputs - label)
                  closest_indices = torch.argmin(diffs, dim=1)
                  closest_values = label[closest_indices]
-                 results['all_predictions'].extend(closest_values.detach().cpu().numpy())
-                 results['all_labels'].extend(batch_test_label.detach().cpu().numpy())
+                 all_predictions.extend(closest_values.detach().cpu().numpy())
+                 all_labels.extend(batch_test_label.detach().cpu().numpy())

+         train_lr = lr_scheduler.get_last_lr()[0]
          lr_scheduler.step(val_loss)

-         results['train_loss'] = train_loss.item()
-         results['val_loss'] = val_loss.item()
-         results['regularize'] = reg_.item()
+         if label is not None:
+             m = calculate_metric(all_labels, all_predictions, class_dict)
+             print(m)
+             results["accuracies"].append(m["accuracy"])
+             results["precisions"].append(m["precision"])
+             results["recalls"].append(m["recall"])
+             results["f1-scores"].append(m["f1-score"])
+             results["lrs"].append(train_lr)
+             if best < m["f1-score"]:
+                 best = m["f1-score"]
+                 results['all_predictions'] = all_predictions
+                 results['all_labels'] = all_labels
+                 if save_path is not None:
+                     model.saveckpt(path=save_path + '/' + 'model')
+             if txt_file is not None:
+                 m.update({'lr': train_lr, 'epoch': epoch, 'train_loss': train_loss.item(), 'val_loss': val_loss.item()})
+                 append_to_results_file(txt_file, m, column_order,
+                                        custom_column_widths=custom_column_widths)
+
+         results["train_losses"].append(train_loss.item())
+         results["val_losses"].append(val_loss.item())
+         results["regularize"].append(reg_.item())

          if save_fig and epoch % save_fig_freq == 0:
              model.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
yms_kan/version.py CHANGED
@@ -1 +1 @@
- __version__ = "0.0.2"  # initial version
+ __version__ = "0.0.4"  # initial version
{yms_kan-0.0.2.dist-info → yms_kan-0.0.4.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: yms_kan
- Version: 0.0.2
+ Version: 0.0.4
  Summary: My awesome package
  Author-email: yms <11@qq.com>
  License-Expression: MIT
{yms_kan-0.0.2.dist-info → yms_kan-0.0.4.dist-info}/RECORD RENAMED
@@ -9,14 +9,14 @@ yms_kan/experiment.py,sha256=VnZq7hmcvRk08GNI7VIpkOjkaRZBoIw1C8SU_f1KbaA,1682
  yms_kan/feynman.py,sha256=Eisf69K49s4C6UlPEi5LnNK_p5TUJQLBKxMp-sW0a9w,33687
  yms_kan/hypothesis.py,sha256=Ec20xadfgOSSWeZHQaGn-h9F2PY7LWFU3iniNI2Zd_4,23165
  yms_kan/spline.py,sha256=ZXyGwl2Sc-UrnrcuUXeUQkBOMnetaWcHrbpZaqatCvs,4345
- yms_kan/tool.py,sha256=CLIsOYWwG-A5PJvoyIP8cRBzX8iRhEssW-2uXdLfi-U,12124
- yms_kan/train_eval_utils.py,sha256=73pA3-HDPDik_yCsDW0oF1dIvVu_vPeHbvJ08o26ygQ,14867
+ yms_kan/tool.py,sha256=rkRpqF3EcsAq7a3k1F1zKlxfJ4U9n-FzHyNCJgN4URY,21159
+ yms_kan/train_eval_utils.py,sha256=Cqw0heB7gOIK3pvOPBx0OIIWi2glfimPpyDqboFq2Tk,17186
  yms_kan/utils.py,sha256=J07L-tgmc1OfU6Tl6mGwHJRizjFN75EJK8BxejaZLUc,23860
- yms_kan/version.py,sha256=qeSnHAh3t9Zb2L0FPUF5OaQvWEJcfTki6FmrfynjWz4,39
+ yms_kan/version.py,sha256=eECtaVYZj2CuGnsLuv9pAmxQhOb0PZcTisjYg4JgF5c,39
  yms_kan/assets/img/mult_symbol.png,sha256=2f4xUKdweft-qUbHjFI5h9-smnEtc0FWq8hNYZhPAXY,6392
  yms_kan/assets/img/sum_symbol.png,sha256=94QkMUzmEjlCq_yf14nMEQmettaq86FmlGfdl22b4XE,6210
- yms_kan-0.0.2.dist-info/licenses/LICENSE,sha256=BJXDWyF4Groqtnp4Gi9puH4aLg7A2IC3MpHmC-cSxwc,1067
- yms_kan-0.0.2.dist-info/METADATA,sha256=jTD-nNMWFF64GiFO2-bYQePNxJh4J1-yi4eniWT1djQ,240
- yms_kan-0.0.2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
- yms_kan-0.0.2.dist-info/top_level.txt,sha256=Z_JDh6yZf-EiW1eKgL6ADsN2yqEMRMspi-o29JZ1WPo,8
- yms_kan-0.0.2.dist-info/RECORD,,
+ yms_kan-0.0.4.dist-info/licenses/LICENSE,sha256=BJXDWyF4Groqtnp4Gi9puH4aLg7A2IC3MpHmC-cSxwc,1067
+ yms_kan-0.0.4.dist-info/METADATA,sha256=VsDT6gWg7lsWcP424XM-o-jsTQ6eAlmGQWbb-hLxGmk,240
+ yms_kan-0.0.4.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+ yms_kan-0.0.4.dist-info/top_level.txt,sha256=Z_JDh6yZf-EiW1eKgL6ADsN2yqEMRMspi-o29JZ1WPo,8
+ yms_kan-0.0.4.dist-info/RECORD,,