xlin 0.1.15__py2.py3-none-any.whl → 0.1.17__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xlin/__init__.py +1 -0
- xlin/statistic.py +149 -1
- xlin/timing.py +43 -0
- xlin/util.py +9 -1
- {xlin-0.1.15.dist-info → xlin-0.1.17.dist-info}/LICENSE +1 -1
- {xlin-0.1.15.dist-info → xlin-0.1.17.dist-info}/METADATA +2 -2
- xlin-0.1.17.dist-info/RECORD +15 -0
- xlin-0.1.15.dist-info/RECORD +0 -14
- {xlin-0.1.15.dist-info → xlin-0.1.17.dist-info}/WHEEL +0 -0
xlin/__init__.py
CHANGED
xlin/statistic.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
from typing import List
|
2
|
+
from collections import defaultdict
|
2
3
|
|
3
4
|
import pandas as pd
|
4
5
|
|
@@ -115,4 +116,151 @@ def draw_pie(numbers: List[int], title="Pie Chart of Numbers"):
|
|
115
116
|
|
116
117
|
plt.pie(numbers, labels=[str(i) for i in range(len(numbers))], autopct='%1.1f%%')
|
117
118
|
plt.title(title)
|
118
|
-
plt.show()
|
119
|
+
plt.show()
|
120
|
+
|
121
|
+
|
122
|
+
def generate_classification_report(predictions: List[str], labels: List[str]) -> dict:
    """Build a detailed evaluation report for a single-label classification task.

    Every distinct value in ``labels`` is treated as a class. Any prediction that
    is not one of those classes is counted under the extra pseudo-class
    ``"out_of_class"``.

    Args:
        predictions: Predicted label for each sample.
        labels: Ground-truth label for each sample, aligned with ``predictions``.

    Returns:
        A dict with the following keys:
        - accuracy: fraction of samples where prediction == label
          (0.0 for empty input).
        - confusion_matrix: DataFrame with true labels as rows and predicted
          labels as columns (including "out_of_class"), plus an extra "recall"
          column and an extra "precision" row.
        - class_report: DataFrame with per-class precision, recall, f1_score and
          support (including the "out_of_class" row).
        - micro_stats: dict of micro-averaged precision, recall and f1_score.
        - macro_stats: dict of precision, recall and f1_score averaged over the
          real classes (the "out_of_class" row is excluded).
        - total_samples: number of samples.
        - time_generated: local time of report creation, "%Y-%m-%d %H:%M:%S".

    Raises:
        ValueError: if ``predictions`` and ``labels`` have different lengths.
    """
    import datetime

    # Explicit check rather than ``assert`` so it still runs under ``python -O``;
    # otherwise ``zip`` below would silently truncate the longer list.
    if len(predictions) != len(labels):
        raise ValueError("预测结果与标签长度不一致")

    report = {}

    # Known classes come from the ground truth only.
    classes = sorted(set(labels))
    class_set = set(classes)  # O(1) membership checks in the counting loop
    error_label = "out_of_class"
    extend_classes = classes + [error_label]

    total = len(labels)
    correct = sum(p == l for p, l in zip(predictions, labels))

    # 1. Accuracy (empty input yields 0.0 instead of ZeroDivisionError).
    report["accuracy"] = correct / total if total > 0 else 0.0

    # 2. Confusion counts keyed by (true_label, predicted_label); predictions
    #    outside the known classes are folded into ``error_label``.
    confusion = defaultdict(int)
    for true_label, pred_label in zip(labels, predictions):
        if pred_label not in class_set:
            pred_label = error_label
        confusion[(true_label, pred_label)] += 1

    confusion_matrix = pd.DataFrame(index=extend_classes, columns=extend_classes, data=0)
    for (true, pred), count in confusion.items():
        confusion_matrix.loc[true, pred] = count

    # 3. Per-class precision / recall / F1, accumulating the micro-average counts.
    #    ``.get`` is used so these read-only lookups do not insert keys.
    micro_tp = 0
    micro_fp = 0
    micro_fn = 0
    class_stats = []
    for cls in extend_classes:
        tp = confusion.get((cls, cls), 0)
        fp = sum(confusion.get((other, cls), 0) for other in extend_classes if other != cls)
        fn = sum(confusion.get((cls, other), 0) for other in extend_classes if other != cls)
        micro_tp += tp
        micro_fp += fp
        micro_fn += fn

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        class_stats.append(
            {
                "class": cls,
                "precision": precision,
                "recall": recall,
                "f1_score": f1,
                # Support = number of samples whose true label is ``cls``.
                "support": sum(confusion.get((cls, other), 0) for other in extend_classes),
            },
        )

    class_df = pd.DataFrame(class_stats)
    report["class_report"] = class_df

    # Decorate the confusion matrix: a "recall" column per true-label row and a
    # "precision" row per predicted-label column (None under the recall column).
    confusion_matrix["recall"] = class_df["recall"].values.tolist()
    p = class_df["precision"].values.tolist() + [None]
    tail = pd.DataFrame([p], index=["precision"], columns=confusion_matrix.columns)
    confusion_matrix = pd.concat([confusion_matrix, tail], axis=0)
    confusion_matrix.index.name = "True \\ Label"
    report["confusion_matrix"] = confusion_matrix

    micro_precision = micro_tp / (micro_tp + micro_fp) if (micro_tp + micro_fp) > 0 else 0
    micro_recall = micro_tp / (micro_tp + micro_fn) if (micro_tp + micro_fn) > 0 else 0
    micro_f1 = (
        2 * (micro_precision * micro_recall) / (micro_precision + micro_recall)
        if (micro_precision + micro_recall) > 0
        else 0
    )
    report["micro_stats"] = {
        "precision": micro_precision,
        "recall": micro_recall,
        "f1_score": micro_f1,
    }

    # Macro averages cover only the real classes, not the out-of-class bucket.
    real_classes = class_df[class_df["class"] != error_label]
    report["macro_stats"] = {
        "precision": real_classes["precision"].mean(),
        "recall": real_classes["recall"].mean(),
        "f1_score": real_classes["f1_score"].mean(),
    }

    # 4. Metadata.
    report["total_samples"] = total
    report["time_generated"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    return report
|
225
|
+
|
226
|
+
|
227
|
+
def print_classification_report(predictions: List[str], labels: List[str]):
    """Compute a classification report and print it to stdout.

    Prints the overall accuracy, sample count and generation time. It then
    prints the micro- and macro-averaged precision / recall / F1, the confusion
    matrix and the per-class report, all taken from
    ``generate_classification_report``.

    Args:
        predictions: Predicted label for each sample.
        labels: Ground-truth label for each sample, aligned with ``predictions``.

    Raises:
        Propagates any exception raised by ``generate_classification_report``
        (e.g. when the two lists differ in length).
    """
    report = generate_classification_report(predictions, labels)

    print(f"准确率: {report['accuracy']:.2%}")
    print(f"总样本数: {report['total_samples']}, 生成时间: {report['time_generated']}")
    print()

    # Micro (微观) then macro (宏观) averages. The precision line is labeled
    # 精确率 (precision); the previous 准确率 label means "accuracy".
    for title, key in (("=== 微观统计 ===", "micro_stats"), ("=== 宏观统计 ===", "macro_stats")):
        stats = report[key]
        print(title)
        print(f"精确率: {stats['precision']:.2%}")
        print(f"召回率: {stats['recall']:.2%}")
        print(f"F1分数: {stats['f1_score']:.2%}")
        print()

    # Confusion matrix (混淆矩阵).
    print("=== 混淆矩阵 ===")
    print(report["confusion_matrix"])
    print()

    # Per-class report (分类报告).
    print("=== 分类报告 ===")
    print(report["class_report"])
    print()
|
259
|
+
|
260
|
+
|
261
|
+
if __name__ == "__main__":
    # Demo: two real classes ("cat", "dog") plus two predictions that fall
    # outside the ground-truth label set and land in the out-of-class bucket.
    demo_predictions = ["cat", "dog", "cat", "dog", "extra1", "extra2"]
    demo_labels = ["cat", "cat", "dog", "dog", "dog", "dog"]

    print_classification_report(predictions=demo_predictions, labels=demo_labels)
|
xlin/timing.py
ADDED
@@ -0,0 +1,43 @@
|
|
1
|
+
from timeit import default_timer as timer
|
2
|
+
from functools import wraps
|
3
|
+
import time
|
4
|
+
|
5
|
+
|
6
|
+
class Benchmark:
    """Context manager that reports how long its ``with`` block took.

    On exit it prints ``"<msg> : <elapsed> seconds"``, with ``elapsed`` rendered
    through the %-style format string ``fmt``. It then stores the elapsed
    seconds on ``self.time``. Exceptions raised inside the block are not
    suppressed, because ``__exit__`` returns None.
    """

    def __init__(self, msg, fmt="%0.3g"):
        self.msg, self.fmt = msg, fmt

    def __enter__(self):
        self.start = timer()
        return self

    def __exit__(self, *exc_info):
        elapsed = timer() - self.start
        line_template = "%s : " + self.fmt + " seconds"
        print(line_template % (self.msg, elapsed))
        self.time = elapsed
|
20
|
+
|
21
|
+
|
22
|
+
def timing(f):
    """Decorator that prints how long each call to ``f`` took.

    After each call that returns normally, it prints the function name, the
    ``repr`` of the positional and keyword arguments, and the elapsed seconds.
    It then returns ``f``'s result unchanged. Nothing is printed if ``f`` raises.
    """

    @wraps(f)
    def wrap(*args, **kw):
        # perf_counter is monotonic and high-resolution. time.time() is
        # wall-clock time and can jump (NTP sync, manual clock changes), which
        # yields wrong or even negative durations.
        ts = time.perf_counter()
        result = f(*args, **kw)
        te = time.perf_counter()
        print(f'func:{f.__name__!r} args:[{args!r}, {kw!r}] took: {te - ts:2.4f} sec')
        return result

    return wrap
|
32
|
+
|
33
|
+
|
34
|
+
class Timer:
    """Context manager that measures the duration of its ``with`` block.

    Uses ``time.perf_counter``. After the block exits, ``start``, ``end`` and
    ``interval`` (``end - start``, in seconds) are available as attributes.
    Nothing is printed.
    """

    def __enter__(self):
        now = time.perf_counter
        self.start = now()
        return self

    def __exit__(self, *exc_info):
        finished = time.perf_counter()
        self.end = finished
        self.interval = finished - self.start
|
xlin/util.py
CHANGED
@@ -133,10 +133,18 @@ def cp(
|
|
133
133
|
base_input_dir = input_paths[0].parent
|
134
134
|
base_input_dir = Path(base_input_dir)
|
135
135
|
output_dir_path = Path(output_dir_path)
|
136
|
+
if output_dir_path.exists() and not output_dir_path.is_dir():
|
137
|
+
raise Exception(f"output_dir_path exists and is not a directory: {output_dir_path}")
|
138
|
+
if not output_dir_path.exists():
|
139
|
+
output_dir_path.mkdir(parents=True, exist_ok=True)
|
140
|
+
logger.warning(f"创建文件夹 {output_dir_path}")
|
141
|
+
if not base_input_dir.exists():
|
142
|
+
raise Exception(f"base_input_dir does not exist: {base_input_dir}")
|
143
|
+
if not base_input_dir.is_dir():
|
144
|
+
raise Exception(f"base_input_dir is not a directory: {base_input_dir}")
|
136
145
|
for input_path in input_paths:
|
137
146
|
relative_path = input_path.relative_to(base_input_dir)
|
138
147
|
output_path = output_dir_path / relative_path
|
139
|
-
output_path.parent.mkdir(parents=True, exist_ok=True)
|
140
148
|
copy_file(input_path, output_path, force_overwrite, verbose)
|
141
149
|
|
142
150
|
|
@@ -1,9 +1,9 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: xlin
|
3
|
-
Version: 0.1.15
|
3
|
+
Version: 0.1.17
|
4
4
|
Summary: toolbox for LinXueyuan
|
5
5
|
License: MIT
|
6
|
-
Author:
|
6
|
+
Author: LinXueyuanStdio
|
7
7
|
Author-email: 23211526+LinXueyuanStdio@users.noreply.github.com
|
8
8
|
Classifier: License :: OSI Approved :: MIT License
|
9
9
|
Classifier: Programming Language :: Python :: 2
|
@@ -0,0 +1,15 @@
|
|
1
|
+
xlin/__init__.py,sha256=MWWCNPgJFS_oV2US52ULa4yg4Ku61qjn40NVKqcp9-c,248
|
2
|
+
xlin/ischinese.py,sha256=Ia9IMQ6q-UHkdLwqS70L1fTnfSPbluFrv_I1UqsKquo,293
|
3
|
+
xlin/jsonl.py,sha256=DvVM241a9VgQlp5WIMPRv-JIolT0RdSxw47IG_fc7xE,6690
|
4
|
+
xlin/metric.py,sha256=N7wJ35y-C-IaBr1I1CJ_37lTG7gA69zmn9Xg6xSwKoI,1690
|
5
|
+
xlin/multiprocess_mapping.py,sha256=pmzyEUYpbpIZ_ezyvWWWRpr7D7n4t3E3jW1nGXBbVck,7652
|
6
|
+
xlin/read_as_dataframe.py,sha256=P8bOYW-zm8uGhehCldZI9ZQhHHLGqDPDbSMNWI2li6g,8885
|
7
|
+
xlin/statistic.py,sha256=i0Z1gbW2IYHCA0lb16w1Ncrk0Q7Q1Ttm0n4we-ki6II,9301
|
8
|
+
xlin/timing.py,sha256=XMT8dMcMolOMohDvAZOIM_BAiPMREhGQKnO1kc5s6PU,998
|
9
|
+
xlin/util.py,sha256=TTWJaqF5D_r-gAZ_fj0kyHomvCagjwHXQZ2OPSgwd54,10976
|
10
|
+
xlin/xls2xlsx.py,sha256=5zfcM0gmunFQOcOj9nYd9Dj0HMhU7-cPKnPIy6Ot9iU,930
|
11
|
+
xlin/yaml.py,sha256=kICi7G3Td5q2MaSXXt85qNTWoHMgjzt7pvn7r3C4dME,183
|
12
|
+
xlin-0.1.17.dist-info/LICENSE,sha256=60ys6rRtc1dZOP8UjSUr9fAqhZudT3WpKe5WbMCralM,1066
|
13
|
+
xlin-0.1.17.dist-info/METADATA,sha256=Lg-wFcZRx0nvtw2tvaB6HCrLrPjRYnVELCp1Duz_IKI,1098
|
14
|
+
xlin-0.1.17.dist-info/WHEEL,sha256=IrRNNNJ-uuL1ggO5qMvT1GGhQVdQU54d6ZpYqEZfEWo,92
|
15
|
+
xlin-0.1.17.dist-info/RECORD,,
|
xlin-0.1.15.dist-info/RECORD
DELETED
@@ -1,14 +0,0 @@
|
|
1
|
-
xlin/__init__.py,sha256=xH5nS8y2RhQ8IDMM2pVkD5W0lxEFuymUSpzSWKo-358,226
|
2
|
-
xlin/ischinese.py,sha256=Ia9IMQ6q-UHkdLwqS70L1fTnfSPbluFrv_I1UqsKquo,293
|
3
|
-
xlin/jsonl.py,sha256=DvVM241a9VgQlp5WIMPRv-JIolT0RdSxw47IG_fc7xE,6690
|
4
|
-
xlin/metric.py,sha256=N7wJ35y-C-IaBr1I1CJ_37lTG7gA69zmn9Xg6xSwKoI,1690
|
5
|
-
xlin/multiprocess_mapping.py,sha256=pmzyEUYpbpIZ_ezyvWWWRpr7D7n4t3E3jW1nGXBbVck,7652
|
6
|
-
xlin/read_as_dataframe.py,sha256=P8bOYW-zm8uGhehCldZI9ZQhHHLGqDPDbSMNWI2li6g,8885
|
7
|
-
xlin/statistic.py,sha256=kp2P-Hr5Kb-R3dNgUXQieG8--iitjidg7SJuSiCpKdM,4131
|
8
|
-
xlin/util.py,sha256=RJHMBKC1xVwso3NfYXxIY3qqAfahzDDgzuU7jvNhQBA,10494
|
9
|
-
xlin/xls2xlsx.py,sha256=5zfcM0gmunFQOcOj9nYd9Dj0HMhU7-cPKnPIy6Ot9iU,930
|
10
|
-
xlin/yaml.py,sha256=kICi7G3Td5q2MaSXXt85qNTWoHMgjzt7pvn7r3C4dME,183
|
11
|
-
xlin-0.1.15.dist-info/LICENSE,sha256=KX0dDCYlO4DskqMZY8qeY94EZMrDRNnNqlGLkXVlKyM,1063
|
12
|
-
xlin-0.1.15.dist-info/METADATA,sha256=GI2Hz1o2lX6rOSEm12phfhejUx1jG3yC29tkLUen6IA,1089
|
13
|
-
xlin-0.1.15.dist-info/WHEEL,sha256=IrRNNNJ-uuL1ggO5qMvT1GGhQVdQU54d6ZpYqEZfEWo,92
|
14
|
-
xlin-0.1.15.dist-info/RECORD,,
|
File without changes
|