yms-kan 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yms_kan/train_eval_utils.py +3 -19
- yms_kan/version.py +1 -1
- {yms_kan-0.0.3.dist-info → yms_kan-0.0.5.dist-info}/METADATA +1 -1
- {yms_kan-0.0.3.dist-info → yms_kan-0.0.5.dist-info}/RECORD +7 -7
- {yms_kan-0.0.3.dist-info → yms_kan-0.0.5.dist-info}/WHEEL +0 -0
- {yms_kan-0.0.3.dist-info → yms_kan-0.0.5.dist-info}/licenses/LICENSE +0 -0
- {yms_kan-0.0.3.dist-info → yms_kan-0.0.5.dist-info}/top_level.txt +0 -0
yms_kan/train_eval_utils.py
CHANGED
@@ -5,27 +5,11 @@ import sys
|
|
5
5
|
import numpy as np
|
6
6
|
import torch
|
7
7
|
from matplotlib import pyplot as plt
|
8
|
-
from sklearn.metrics import classification_report
|
9
8
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
10
9
|
from tqdm import tqdm
|
11
10
|
|
12
11
|
from yms_kan import LBFGS
|
13
|
-
from yms_kan.tool import initialize_results_file, append_to_results_file
|
14
|
-
|
15
|
-
|
16
|
-
def calculate_metric(all_labels, all_predictions, classes, class_metric=False, average='macro avg'):
|
17
|
-
metric = classification_report(y_true=all_labels, y_pred=all_predictions,
|
18
|
-
target_names=classes, digits=4, output_dict=True, zero_division=0)
|
19
|
-
if not class_metric:
|
20
|
-
metric = {
|
21
|
-
'accuracy': metric.get('accuracy'),
|
22
|
-
'precision': metric.get(average).get('precision'),
|
23
|
-
'recall': metric.get(average).get('recall'),
|
24
|
-
'f1-score': metric.get(average).get('f1-score'),
|
25
|
-
}
|
26
|
-
return metric
|
27
|
-
else:
|
28
|
-
return metric
|
12
|
+
from yms_kan.tool import initialize_results_file, append_to_results_file, calculate_metric
|
29
13
|
|
30
14
|
|
31
15
|
def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_file=None, opt="LBFGS", epochs=100,
|
@@ -72,7 +56,7 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_
|
|
72
56
|
lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)
|
73
57
|
|
74
58
|
results = {'train_losses': [], 'val_losses': [], 'accuracies': [], 'precisions': [], 'recalls': [], 'f1-scores': [],
|
75
|
-
'lrs': [], 'all_predictions': [], 'all_labels': []}
|
59
|
+
'lrs': [], 'all_predictions': [], 'all_labels': [], 'regularize': []}
|
76
60
|
|
77
61
|
steps = math.ceil(dataset['train_input'].shape[0] / batch_size)
|
78
62
|
|
@@ -177,7 +161,7 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_
|
|
177
161
|
if label is not None:
|
178
162
|
m = calculate_metric(all_labels, all_predictions, class_dict)
|
179
163
|
print(m)
|
180
|
-
results["
|
164
|
+
results["accuracies"].append(m["accuracy"])
|
181
165
|
results["precisions"].append(m["precision"])
|
182
166
|
results["recalls"].append(m["recall"])
|
183
167
|
results["f1-scores"].append(m["f1-score"])
|
yms_kan/version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.0.
|
1
|
+
__version__ = "0.0.5" # 初始版本
|
@@ -10,13 +10,13 @@ yms_kan/feynman.py,sha256=Eisf69K49s4C6UlPEi5LnNK_p5TUJQLBKxMp-sW0a9w,33687
|
|
10
10
|
yms_kan/hypothesis.py,sha256=Ec20xadfgOSSWeZHQaGn-h9F2PY7LWFU3iniNI2Zd_4,23165
|
11
11
|
yms_kan/spline.py,sha256=ZXyGwl2Sc-UrnrcuUXeUQkBOMnetaWcHrbpZaqatCvs,4345
|
12
12
|
yms_kan/tool.py,sha256=rkRpqF3EcsAq7a3k1F1zKlxfJ4U9n-FzHyNCJgN4URY,21159
|
13
|
-
yms_kan/train_eval_utils.py,sha256=
|
13
|
+
yms_kan/train_eval_utils.py,sha256=e6XqlE3_i-AcGkWsWrVUjBsphTDK5T21sWR_RIJiTEs,16536
|
14
14
|
yms_kan/utils.py,sha256=J07L-tgmc1OfU6Tl6mGwHJRizjFN75EJK8BxejaZLUc,23860
|
15
|
-
yms_kan/version.py,sha256=
|
15
|
+
yms_kan/version.py,sha256=OUUIPtK264nL8fStaLu6GUnNkVZ5X18jfSMcppcWpa4,39
|
16
16
|
yms_kan/assets/img/mult_symbol.png,sha256=2f4xUKdweft-qUbHjFI5h9-smnEtc0FWq8hNYZhPAXY,6392
|
17
17
|
yms_kan/assets/img/sum_symbol.png,sha256=94QkMUzmEjlCq_yf14nMEQmettaq86FmlGfdl22b4XE,6210
|
18
|
-
yms_kan-0.0.
|
19
|
-
yms_kan-0.0.
|
20
|
-
yms_kan-0.0.
|
21
|
-
yms_kan-0.0.
|
22
|
-
yms_kan-0.0.
|
18
|
+
yms_kan-0.0.5.dist-info/licenses/LICENSE,sha256=BJXDWyF4Groqtnp4Gi9puH4aLg7A2IC3MpHmC-cSxwc,1067
|
19
|
+
yms_kan-0.0.5.dist-info/METADATA,sha256=upxdcPXd-Frc3q0KyYZUahqV12F_toxb5Kh2TGGXNVs,240
|
20
|
+
yms_kan-0.0.5.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
21
|
+
yms_kan-0.0.5.dist-info/top_level.txt,sha256=Z_JDh6yZf-EiW1eKgL6ADsN2yqEMRMspi-o29JZ1WPo,8
|
22
|
+
yms_kan-0.0.5.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|