yms-kan 0.0.1__tar.gz → 0.0.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. yms_kan-0.0.3/MANIFEST.in +1 -0
  2. {yms_kan-0.0.1/yms_kan.egg-info → yms_kan-0.0.3}/PKG-INFO +1 -1
  3. {yms_kan-0.0.1 → yms_kan-0.0.3}/pyproject.toml +12 -0
  4. {yms_kan-0.0.1 → yms_kan-0.0.3}/yms_kan/MultKAN.py +5 -2
  5. yms_kan-0.0.3/yms_kan/assets/img/mult_symbol.png +0 -0
  6. yms_kan-0.0.3/yms_kan/assets/img/sum_symbol.png +0 -0
  7. yms_kan-0.0.3/yms_kan/tool.py +569 -0
  8. yms_kan-0.0.3/yms_kan/train_eval_utils.py +364 -0
  9. yms_kan-0.0.3/yms_kan/version.py +1 -0
  10. {yms_kan-0.0.1 → yms_kan-0.0.3/yms_kan.egg-info}/PKG-INFO +1 -1
  11. {yms_kan-0.0.1 → yms_kan-0.0.3}/yms_kan.egg-info/SOURCES.txt +4 -1
  12. yms_kan-0.0.1/yms_kan/tool.py +0 -304
  13. yms_kan-0.0.1/yms_kan/train_eval_utils.py +0 -175
  14. yms_kan-0.0.1/yms_kan/version.py +0 -1
  15. {yms_kan-0.0.1 → yms_kan-0.0.3}/LICENSE +0 -0
  16. {yms_kan-0.0.1 → yms_kan-0.0.3}/README.md +0 -0
  17. {yms_kan-0.0.1 → yms_kan-0.0.3}/setup.cfg +0 -0
  18. {yms_kan-0.0.1 → yms_kan-0.0.3}/yms_kan/KANLayer.py +0 -0
  19. {yms_kan-0.0.1 → yms_kan-0.0.3}/yms_kan/LBFGS.py +0 -0
  20. {yms_kan-0.0.1 → yms_kan-0.0.3}/yms_kan/MLP.py +0 -0
  21. {yms_kan-0.0.1 → yms_kan-0.0.3}/yms_kan/Symbolic_KANLayer.py +0 -0
  22. {yms_kan-0.0.1 → yms_kan-0.0.3}/yms_kan/__init__.py +0 -0
  23. {yms_kan-0.0.1 → yms_kan-0.0.3}/yms_kan/compiler.py +0 -0
  24. {yms_kan-0.0.1 → yms_kan-0.0.3}/yms_kan/experiment.py +0 -0
  25. {yms_kan-0.0.1 → yms_kan-0.0.3}/yms_kan/feynman.py +0 -0
  26. {yms_kan-0.0.1 → yms_kan-0.0.3}/yms_kan/hypothesis.py +0 -0
  27. {yms_kan-0.0.1 → yms_kan-0.0.3}/yms_kan/spline.py +0 -0
  28. {yms_kan-0.0.1 → yms_kan-0.0.3}/yms_kan/utils.py +0 -0
  29. {yms_kan-0.0.1 → yms_kan-0.0.3}/yms_kan.egg-info/dependency_links.txt +0 -0
  30. {yms_kan-0.0.1 → yms_kan-0.0.3}/yms_kan.egg-info/top_level.txt +0 -0
@@ -0,0 +1 @@
1
+ recursive-include yms_kan/assets/img *.png
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: yms_kan
3
- Version: 0.0.1
3
+ Version: 0.0.3
4
4
  Summary: My awesome package
5
5
  Author-email: yms <11@qq.com>
6
6
  License-Expression: MIT
@@ -17,3 +17,15 @@ license-files = ["LICENSE"]
17
17
  [tool.setuptools.dynamic]
18
18
  # 明确指定版本号来源
19
19
  version = {attr = "yms_kan.version.__version__"}
20
+
21
+ [tool.setuptools]
22
+ # 包含非代码文件
23
+ include-package-data = true
24
+
25
+ [tool.setuptools.package-data]
26
+ # 指定包内资源文件的匹配规则
27
+ yms_kan = [
28
+ "assets/img/*.png", # 包含所有png文件
29
+ "assets/img/*.svg", # 可扩展其他格式
30
+ "assets/**/*" # 递归包含子目录
31
+ ]
@@ -3,6 +3,7 @@ import math
3
3
  import os
4
4
  import random
5
5
  import sys
6
+ from importlib.resources import files
6
7
 
7
8
  import matplotlib.pyplot as plt
8
9
  import numpy as np
@@ -1299,7 +1300,8 @@ class MultKAN(nn.Module):
1299
1300
  N = n = width_out[l + 1]
1300
1301
  for j in range(n):
1301
1302
  id_ = j
1302
- path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/sum_symbol.png"
1303
+ # path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/sum_symbol.png"
1304
+ path = files('yms_kan') / "assets/img/sum_symbol.png"
1303
1305
  im = plt.imread(path)
1304
1306
  left = DC_to_NFC([1 / (2 * N) + id_ / N - y2, 0])[0]
1305
1307
  right = DC_to_NFC([1 / (2 * N) + id_ / N + y2, 0])[0]
@@ -1315,7 +1317,8 @@ class MultKAN(nn.Module):
1315
1317
  n_mult = width[l + 1][1]
1316
1318
  for j in range(n_mult):
1317
1319
  id_ = j + n_sum
1318
- path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/mult_symbol.png"
1320
+ # path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/mult_symbol.png"
1321
+ path = files('yms_kan') / "assets/img/mult_symbol.png"
1319
1322
  im = plt.imread(path)
1320
1323
  left = DC_to_NFC([1 / (2 * N) + id_ / N - y2, 0])[0]
1321
1324
  right = DC_to_NFC([1 / (2 * N) + id_ / N + y2, 0])[0]
@@ -0,0 +1,569 @@
1
+ import os
2
+ import re
3
+ from datetime import datetime, timezone, timedelta
4
+ from typing import Optional, Dict, List
5
+
6
+ import click
7
+ import numpy as np
8
+ import pandas as pd
9
+ import wandb
10
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, \
11
+ classification_report
12
+ from tqdm import tqdm
13
+
14
+
15
+ # 读取txt内两个不同表格的数据,并将结果转换为字典列表输出
16
+ def read_multi_table_txt(file_path):
17
+ # 读取原始内容
18
+ with open(file_path, 'r') as f:
19
+ content = f.read()
20
+
21
+ # 按表格标题分割内容(假设每个新表格以"epoch"开头)
22
+ table_blocks = re.split(r'\n(?=epoch\s)', content.strip())
23
+
24
+ # 处理每个表格块
25
+ table_dicts = []
26
+ for block in table_blocks:
27
+ lines = [line.strip() for line in block.split('\n') if line.strip()]
28
+
29
+ # 解析列名(处理制表符和混合空格)
30
+ columns = re.split(r'\s{2,}|\t', lines[0])
31
+
32
+ # 解析数据行(处理混合分隔符)
33
+ data = []
34
+ for line in lines[1:]:
35
+ # 使用正则表达式分割多个连续空格/制表符
36
+ row = re.split(r'\s{2,}|\t', line)
37
+ data.append(row)
38
+
39
+ # 创建DataFrame并自动转换数值类型
40
+ df = pd.DataFrame(data, columns=columns)
41
+ df = df.apply(pd.to_numeric, errors='coerce') # 自动识别数值列,非数值转换为NaN
42
+
43
+ # 将DataFrame转换为字典,每列以列表形式保存
44
+ table_dict = df.to_dict(orient='list')
45
+ table_dicts.append(table_dict)
46
+
47
+ return table_dicts
48
+
49
+
50
+ def get_current_time(format_str="%Y-%m-%d %H:%M:%S"):
51
+ """
52
+ 获取东八区(UTC+8)的当前时间,并返回指定格式的字符串
53
+ :param format_str: 时间格式(默认为 "%Y-%m-%d %H:%M:%S")
54
+ :return: 格式化后的时间字符串
55
+ """
56
+
57
+ # 创建东八区的时区对象
58
+ utc8_timezone = timezone(timedelta(hours=8))
59
+
60
+ # 转换为东八区时间
61
+ utc8_time = datetime.now(utc8_timezone)
62
+
63
+ # 格式化为字符串
64
+ formatted_time = utc8_time.strftime(format_str)
65
+ return formatted_time
66
+
67
+
68
+ # val和test时的相关结果指标计算
69
+ def calculate_results(all_labels, all_predictions, classes, average='macro'):
70
+ results = {
71
+ 'accuracy': accuracy_score(y_true=all_labels, y_pred=all_predictions),
72
+ 'precision': precision_score(y_true=all_labels, y_pred=all_predictions, average=average),
73
+ 'recall': recall_score(y_true=all_labels, y_pred=all_predictions, average=average),
74
+ 'f1_score': f1_score(y_true=all_labels, y_pred=all_predictions, average=average),
75
+ 'cm': confusion_matrix(y_true=all_labels, y_pred=all_predictions, labels=np.arange(len(classes)))
76
+ }
77
+ return results
78
+
79
+
80
+ def calculate_metric(all_labels, all_predictions, classes, class_metric=False, average='macro avg'):
81
+ metric = classification_report(y_true=all_labels, y_pred=all_predictions,
82
+ target_names=classes, digits=4, output_dict=True, zero_division=0)
83
+ if not class_metric:
84
+ metric = {
85
+ 'accuracy': metric.get('accuracy'),
86
+ 'precision': metric.get(average).get('precision'),
87
+ 'recall': metric.get(average).get('recall'),
88
+ 'f1-score': metric.get(average).get('f1-score'),
89
+ }
90
+ return metric
91
+ else:
92
+ return metric
93
+
94
+
95
+ def dict_to_classification_report(report_dict, digits=2):
96
+ headers = ["precision", "recall", "f1-score", "support"]
97
+ target_names = list(report_dict.keys())
98
+ target_names.remove('accuracy') if 'accuracy' in target_names else None
99
+ longest_last_line_heading = "weighted avg"
100
+ name_width = max(len(cn) for cn in target_names)
101
+ width = max(name_width, len(longest_last_line_heading), digits)
102
+ head_fmt = "{:>{width}s} " + " {:>9}" * len(headers)
103
+ report = head_fmt.format("", *headers, width=width)
104
+ report += "\n\n"
105
+ row_fmt = "{:>{width}s} " + " {:>9.{digits}f}" * 3 + " {:>9}\n"
106
+ for target_name in target_names:
107
+ scores = [report_dict[target_name][h] for h in headers]
108
+ report += row_fmt.format(target_name, *scores, width=width, digits=digits)
109
+ report += "\n"
110
+
111
+ average_options = ["micro avg", "macro avg", "weighted avg"]
112
+ if 'samples avg' in report_dict:
113
+ average_options.append('samples avg')
114
+ for average in average_options:
115
+ if average in report_dict:
116
+ scores = [report_dict[average][h] for h in headers]
117
+ if average == "accuracy":
118
+ row_fmt_accuracy = (
119
+ "{:>{width}s} "
120
+ + " {:>9.{digits}}" * 2
121
+ + " {:>9.{digits}f}"
122
+ + " {:>9}\n"
123
+ )
124
+ report += row_fmt_accuracy.format(
125
+ average, "", "", *scores[2:], width=width, digits=digits
126
+ )
127
+ else:
128
+ report += row_fmt.format(average, *scores, width=width, digits=digits)
129
+
130
+ if 'accuracy' in report_dict:
131
+ row_fmt_accuracy = (
132
+ "{:>{width}s} "
133
+ + " {:>9.{digits}}" * 2
134
+ + " {:>9.{digits}f}"
135
+ + " {:>9}\n"
136
+ )
137
+ report += row_fmt_accuracy.format(
138
+ "accuracy", "", "", report_dict["accuracy"], "", width=width, digits=digits
139
+ )
140
+
141
+ return report
142
+
143
+
144
+ # def append_metrics(metrics, metric, result, lr):
145
+ # metrics['train_losses'].append(result['train_loss'])
146
+ # metrics['val_losses'].append(result['val_loss'])
147
+ # metrics['accuracies'].append(metric['accuracy'])
148
+ # metrics['precisions'].append(metric['precision'])
149
+ # metrics['recalls'].append(metric['recall'])
150
+ # metrics['f1-scores'].append(metric['f1-score'])
151
+ # metrics['lrs'].append(lr)
152
+ # return metrics
153
+
154
+
155
+ def initialize_results_file(results_file, result_info):
156
+ """
157
+ 初始化结果文件,确保文件存在且第一行包含指定的内容。
158
+
159
+ 参数:
160
+ results_file (str): 结果文件的路径。
161
+ result_info (list): 需要写入的第一行内容列表。
162
+ space:列名间隔(默认两个空格的距离)
163
+ """
164
+ # 处理 result_info,在每个单词后添加两个空格
165
+ result_info_str = ' '.join(result_info) + '\n'
166
+ # 检查文件是否存在
167
+ if os.path.exists(results_file):
168
+ # 如果文件存在,读取第一行
169
+ with open(results_file, "r") as f:
170
+ first_line = f.readline().strip()
171
+ # 检查第一行是否与 result_info 一致
172
+ if first_line == result_info_str.strip():
173
+ print(f"文件 {results_file} 已存在且第一行已包含 result_info,不进行写入。")
174
+ else:
175
+ # 如果不一致,写入 result_info
176
+ with open(results_file, "w") as f:
177
+ f.write(result_info_str)
178
+ print(f"文件 {results_file} 已被重新初始化。")
179
+ else:
180
+ # 如果文件不存在,创建并写入 result_info
181
+ with open(results_file, "w") as f:
182
+ f.write(result_info_str)
183
+ print(f"文件 {results_file} 已创建并写入 result_info。")
184
+
185
+
186
+ def is_similar_key(key1, key2):
187
+ """
188
+ 检查两个键是否相似,考虑复数形式的转换。
189
+
190
+ Args:
191
+ key1 (str): 第一个键
192
+ key2 (str): 第二个键
193
+
194
+ Returns:
195
+ bool: 如果两个键相似(包括复数形式的转换),返回 True,否则返回 False
196
+ """
197
+ if key1 == key2:
198
+ return True
199
+
200
+ # 检查 key2 是否是复数形式
201
+ if key2.endswith("ies"):
202
+ singular_candidate = key2.removesuffix("ies") + "y"
203
+ if key1 == singular_candidate:
204
+ return True
205
+
206
+ if key2.endswith("es"):
207
+ singular_candidate = key2.removesuffix("es")
208
+ if key1 == singular_candidate:
209
+ return True
210
+
211
+ if key2.endswith("s"):
212
+ singular_candidate = key2.removesuffix("s")
213
+ if key1 == singular_candidate:
214
+ return True
215
+
216
+ return False
217
+
218
+
219
+ def append_to_results_file(file_path: str,
220
+ data_dict: dict,
221
+ column_order: list,
222
+ float_precision: int = 4,
223
+ more_float: int = 2,
224
+ custom_column_widths: dict = None) -> None:
225
+ """
226
+ 通用格式化文本行写入函数
227
+
228
+ 参数:
229
+ file_path: 目标文件路径
230
+ data_dict: 包含数据的字典,键为列名
231
+ column_order: 列顺序列表,元素为字典键
232
+ float_precision: 浮点数精度位数 (默认5位)
233
+ more_float: 额外的浮点数精度位数
234
+ custom_column_widths: 自定义列宽的字典,键为列名,值为列宽
235
+ """
236
+ # 计算每列的最大宽度
237
+ column_widths = []
238
+ formatted_data = []
239
+ for col in column_order:
240
+ # 查找 data_dict 中相似的键
241
+ dict_key = None
242
+ for key in data_dict:
243
+ if is_similar_key(key, col):
244
+ dict_key = key
245
+ break
246
+ if dict_key is None:
247
+ raise ValueError(f"Missing required column: {col}")
248
+
249
+ value = data_dict[dict_key]
250
+
251
+ # 根据数据类型进行格式化
252
+ if isinstance(value, (int, np.integer)):
253
+ fmt_value = f"{value:d}"
254
+ elif isinstance(value, (float, np.floating)):
255
+ if col in ['train_losses', 'val_losses']: # 如果列名是'train_losses'或'val_losses',保留浮点数精度位数+1位
256
+ fmt_value = f"{value:.{float_precision + more_float}f}"
257
+ elif col == 'lrs':
258
+ fmt_value = f"{value:.8f}"
259
+ else:
260
+ fmt_value = f"{value:.{float_precision}f}"
261
+ elif isinstance(value, str):
262
+ fmt_value = value
263
+ else: # 处理其他类型转换为字符串
264
+ fmt_value = str(value)
265
+
266
+ # 确定列宽
267
+ if custom_column_widths and col in custom_column_widths:
268
+ column_width = custom_column_widths[col]
269
+ else:
270
+ # 取列名长度和数值长度的最大值作为列宽
271
+ column_width = max(len(col), len(fmt_value))
272
+ column_widths.append(column_width)
273
+
274
+ # 应用列宽对齐
275
+ if col == column_order[-1]: # 最后一列左边对齐
276
+ fmt_value = fmt_value.ljust(column_width)
277
+ else:
278
+ fmt_value = fmt_value.rjust(column_width)
279
+
280
+ formatted_data.append(fmt_value)
281
+
282
+ # 构建文本行并写入,列之间用两个空格分隔
283
+ line = " ".join(formatted_data) + '\n'
284
+ with open(file_path, 'a', encoding='utf-8') as f:
285
+ f.write(line)
286
+
287
+
288
+ def get_wandb_key(key_path):
289
+ with open(key_path, 'r', encoding='utf-8') as f:
290
+ key = f.read()
291
+ return key
292
+
293
+
294
+ def wandb_init(project=None, key_path=None, name=None):
295
+ run = None
296
+ if project is not None:
297
+ if key_path is None:
298
+ raise ValueError("When 'project' is not None, 'key_path' should also not be None.")
299
+ wandb_key = get_wandb_key(key_path)
300
+ wandb.login(key=wandb_key)
301
+ run = wandb.init(project=project, name=name)
302
+ return run
303
+
304
+
305
+ def check_wandb_login_required():
306
+ """兼容旧版的登录检查函数"""
307
+ # 优先检查环境变量
308
+ if os.environ.get("WANDB_API_KEY"):
309
+ return False
310
+
311
+ try:
312
+ api = wandb.Api()
313
+ # 方法 1:通过 settings 检查(适用于旧版)
314
+ if hasattr(api, "settings") and api.settings.get("entity"):
315
+ return False
316
+
317
+ # 方法 2:通过 projects() 验证(通用性强)
318
+ api.projects(per_page=1) # 仅请求第一页的第一个项目
319
+ return False
320
+ except Exception as e:
321
+ print(f"检测到意外错误: {str(e)}")
322
+ return True # 保守返回需要登录
323
+
324
+
325
+ def get_wandb_runs(
326
+ project_path: str,
327
+ default_name: str = "未命名",
328
+ api_key: Optional[str] = None,
329
+ per_page: int = 1000
330
+ ) -> List[Dict[str, str]]:
331
+ """
332
+ 获取指定 WandB 项目的所有运行信息(ID 和 Name)
333
+
334
+ Args:
335
+ project_path (str): 项目路径,格式为 "username/project_name"
336
+ default_name (str): 当运行未命名时的默认显示名称(默认:"未命名")
337
+ api_key (str, optional): WandB API 密钥,若未设置环境变量则需传入
338
+ per_page (int): 分页查询每页数量(默认1000,用于处理大量运行)
339
+
340
+ Returns:
341
+ List[Dict]: 包含运行信息的字典列表,格式 [{"id": "...", "name": "..."}]
342
+
343
+ Raises:
344
+ ValueError: 项目路径格式错误
345
+ wandb.errors.UsageError: API 密钥无效或未登录
346
+ """
347
+ # 参数校验
348
+ if "/" not in project_path or len(project_path.split("/")) != 2:
349
+ raise ValueError("项目路径格式应为 'username/project_name'")
350
+
351
+ # 登录(仅在需要时)
352
+ if api_key:
353
+ wandb.login(key=api_key)
354
+ elif not wandb.api.api_key:
355
+ raise wandb.errors.UsageError("需要提供API密钥或预先调用wandb.login()")
356
+
357
+ # 初始化API
358
+ api = wandb.Api()
359
+
360
+ try:
361
+ # 分页获取所有运行(自动处理分页逻辑)
362
+ runs = api.runs(project_path, per_page=per_page)
363
+ print(f'共获取{len(runs)}个run')
364
+ return [
365
+ {
366
+ "id": run.id,
367
+ "name": run.name or default_name,
368
+ "url": run.url, # 增加实用字段
369
+ "state": run.state # 包含运行状态
370
+ }
371
+ for run in runs
372
+ ]
373
+
374
+ except wandb.errors.CommError as e:
375
+ raise ConnectionError(f"连接失败: {str(e)}") from e
376
+ except Exception as e:
377
+ raise RuntimeError(f"获取运行数据失败: {str(e)}") from e
378
+
379
+
380
+ def delete_runs(
381
+ project_path: str,
382
+ run_ids: Optional[List[str]] = None,
383
+ run_names: Optional[List[str]] = None,
384
+ delete_all: bool = False,
385
+ dry_run: bool = True,
386
+ api_key: Optional[str] = None,
387
+ per_page: int = 500
388
+ ) -> dict:
389
+ """
390
+ 多功能WandB运行删除工具
391
+
392
+ :param project_path: 项目路径(格式:username/project_name)
393
+ :param run_ids: 指定要删除的运行ID列表(无视状态)
394
+ :param run_names: 指定要删除的运行名称列表(无视状态)
395
+ # :param preserve_states: 保护状态列表(默认保护 finished/running)
396
+ :param delete_all: 危险模式!删除所有运行(默认False)
397
+ :param dry_run: 模拟运行模式(默认True)
398
+ :param api_key: WandB API密钥
399
+ :param per_page: 分页查询数量
400
+ :return: 操作统计字典
401
+
402
+ 使用场景:
403
+ 1. 删除指定运行:delete_runs(..., run_ids=["abc","def"])
404
+ 2. 默认删除失败运行:delete_runs(...)
405
+ 3. 删除所有运行:delete_runs(..., delete_all=True)
406
+ """
407
+ preserve_states: List[str] = ["finished", "running"]
408
+ # 参数校验
409
+ if not project_path.count("/") == 1:
410
+ raise ValueError("项目路径格式应为 username/project_name")
411
+ if delete_all and (run_ids or run_names):
412
+ raise ValueError("delete_all模式不能与其他筛选参数同时使用")
413
+
414
+ # 身份验证
415
+ if api_key:
416
+ wandb.login(key=api_key)
417
+ elif not wandb.api.api_key:
418
+ raise wandb.errors.UsageError("需要API密钥或预先登录")
419
+
420
+ api = wandb.Api()
421
+ stats = {
422
+ "total": 0,
423
+ "candidates": 0,
424
+ "deleted": 0,
425
+ "failed": 0,
426
+ "dry_run": dry_run
427
+ }
428
+
429
+ try:
430
+ runs = api.runs(project_path, per_page=per_page)
431
+ stats["total"] = len(runs)
432
+
433
+ # 确定删除目标
434
+ if delete_all:
435
+ targets = runs
436
+ click.secho("\n⚠️ 危险操作:将删除项目所有运行!", fg="red", bold=True)
437
+ elif run_ids or run_names:
438
+ targets = [
439
+ run for run in runs
440
+ if run.id in (run_ids or []) or run.name in (run_names or [])
441
+ ]
442
+ print(f"\n找到 {len(targets)} 个指定运行")
443
+ else:
444
+ targets = [run for run in runs if run.state not in preserve_states]
445
+ print(f"\n找到 {len(targets)} 个非正常状态运行")
446
+
447
+ stats["candidates"] = len(targets)
448
+
449
+ if not targets:
450
+ print("没有符合条件的运行")
451
+ return stats
452
+
453
+ # 打印预览
454
+ print("\n待删除运行示例:")
455
+ for run in targets[:3]:
456
+ state = click.style(run.state, fg="green" if run.state == "finished" else "red")
457
+ print(f" • {run.id} | {run.name} | 状态:{state}")
458
+ if len(targets) > 3:
459
+ print(f" ...(共 {len(targets)} 条)")
460
+
461
+ # 安全确认
462
+ if dry_run:
463
+ click.secho("\n模拟运行模式:不会实际删除", fg="yellow")
464
+ return stats
465
+
466
+ if delete_all:
467
+ msg = click.style("确认要删除所有运行吗?此操作不可逆!", fg="red", bold=True)
468
+ else:
469
+ msg = f"确认要删除 {len(targets)} 个运行吗?"
470
+
471
+ if not click.confirm(msg, default=False):
472
+ print("操作已取消")
473
+ return stats
474
+
475
+ # 执行删除
476
+ print("\n删除进度:")
477
+ for i, run in enumerate(targets, 1):
478
+ try:
479
+ run.delete()
480
+ stats["deleted"] += 1
481
+ print(click.style(f" [{i}/{len(targets)}] 已删除 {run.id}", fg="green"))
482
+ except Exception as e:
483
+ stats["failed"] += 1
484
+ print(click.style(f" [{i}/{len(targets)}] 删除失败 {run.id}: {str(e)}", fg="red"))
485
+
486
+ return stats
487
+
488
+ except wandb.errors.CommError as e:
489
+ raise ConnectionError(f"网络错误: {str(e)}")
490
+ except Exception as e:
491
+ raise RuntimeError(f"操作失败: {str(e)}")
492
+
493
+
494
+ def get_all_artifacts_from_project(project_path, max_runs=None, run_id=None):
495
+ """获取WandB项目或指定Run的所有Artifact
496
+
497
+ Args:
498
+ project_path (str): 项目路径,格式为 "entity/project"
499
+ max_runs (int, optional): 最大获取Run数量(仅当未指定run_id时生效)
500
+ run_id (str, optional): 指定要查询的Run ID
501
+
502
+ Returns:
503
+ list: 包含所有Artifact对象的列表
504
+ """
505
+ api = wandb.Api()
506
+ all_artifacts = []
507
+ seen_artifacts = set() # 用于去重
508
+
509
+ try:
510
+ if run_id:
511
+ # 处理单个Run的情况
512
+ run = api.run(f"{project_path}/{run_id}")
513
+ artifacts = run.logged_artifacts()
514
+
515
+ for artifact in artifacts:
516
+ artifact_id = f"{artifact.name}:{artifact.version}"
517
+ if artifact_id not in seen_artifacts:
518
+ all_artifacts.append(artifact)
519
+ seen_artifacts.add(artifact_id)
520
+
521
+ print(f"Found {len(all_artifacts)} artifacts in run {run_id}")
522
+ else:
523
+ # 处理整个项目的情况
524
+ runs = api.runs(project_path, per_page=500)
525
+ run_iterator = tqdm(runs[:max_runs] if max_runs else runs,
526
+ desc=f"Scanning {project_path}")
527
+
528
+ for run in run_iterator:
529
+ try:
530
+ artifacts = run.logged_artifacts()
531
+ for artifact in artifacts:
532
+ artifact_id = f"{artifact.name}:{artifact.version}"
533
+ if artifact_id not in seen_artifacts:
534
+ all_artifacts.append(artifact)
535
+ seen_artifacts.add(artifact_id)
536
+ except Exception as run_error:
537
+ print(f"Error processing run {run.id}: {str(run_error)}")
538
+
539
+ except Exception as e:
540
+ print(f"Error: {str(e)}")
541
+ return []
542
+
543
+ return all_artifacts
544
+
545
+
546
+ def upload_model_dataset(
547
+ artifact_dir: str,
548
+ artifact_name: str,
549
+ artifact_type: str) -> None:
550
+ run_id = f'yms_upload_{artifact_type}_' + get_current_time('%y%m%d_%H%M%S')
551
+ run = wandb.init(project='upload_model_dataset', name=artifact_name, id=run_id)
552
+ artifact = wandb.Artifact(artifact_name, artifact_type)
553
+ artifact.add_dir(artifact_dir)
554
+ run.log_artifact(artifact)
555
+ run.finish()
556
+
557
+
558
+ def download_model_dataset(
559
+ download_name: str,
560
+ run_name: str,
561
+ artifact_type: str,
562
+ download_dir: str = None,
563
+ entity: str = 'YNA-DeepLearning'
564
+ ) -> str:
565
+ run_id = f'yms_download_{artifact_type}_' + get_current_time('%y%m%d_%H%M%S')
566
+ run = wandb.init(project='download_model_dataset', name=run_name, id=run_id)
567
+ artifact = run.use_artifact(entity + '/upload_model_dataset/' + download_name, type=artifact_type)
568
+ artifact_dir = artifact.download(root=download_dir)
569
+ return artifact_dir