yms-kan 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff shows the changes between publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
@@ -5,27 +5,11 @@ import sys
5
5
  import numpy as np
6
6
  import torch
7
7
  from matplotlib import pyplot as plt
8
- from sklearn.metrics import classification_report
9
8
  from torch.optim.lr_scheduler import ReduceLROnPlateau
10
9
  from tqdm import tqdm
11
10
 
12
11
  from yms_kan import LBFGS
13
- from yms_kan.tool import initialize_results_file, append_to_results_file
14
-
15
-
16
- def calculate_metric(all_labels, all_predictions, classes, class_metric=False, average='macro avg'):
17
- metric = classification_report(y_true=all_labels, y_pred=all_predictions,
18
- target_names=classes, digits=4, output_dict=True, zero_division=0)
19
- if not class_metric:
20
- metric = {
21
- 'accuracy': metric.get('accuracy'),
22
- 'precision': metric.get(average).get('precision'),
23
- 'recall': metric.get(average).get('recall'),
24
- 'f1-score': metric.get(average).get('f1-score'),
25
- }
26
- return metric
27
- else:
28
- return metric
12
+ from yms_kan.tool import initialize_results_file, append_to_results_file, calculate_metric
29
13
 
30
14
 
31
15
  def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_file=None, opt="LBFGS", epochs=100,
@@ -72,7 +56,7 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_
72
56
  lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)
73
57
 
74
58
  results = {'train_losses': [], 'val_losses': [], 'accuracies': [], 'precisions': [], 'recalls': [], 'f1-scores': [],
75
- 'lrs': [], 'all_predictions': [], 'all_labels': []}
59
+ 'lrs': [], 'all_predictions': [], 'all_labels': [], 'regularize': []}
76
60
 
77
61
  steps = math.ceil(dataset['train_input'].shape[0] / batch_size)
78
62
 
yms_kan/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.0.4" # 初始版本
1
+ __version__ = "0.0.5" # 初始版本
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: yms_kan
3
- Version: 0.0.4
3
+ Version: 0.0.5
4
4
  Summary: My awesome package
5
5
  Author-email: yms <11@qq.com>
6
6
  License-Expression: MIT
@@ -10,13 +10,13 @@ yms_kan/feynman.py,sha256=Eisf69K49s4C6UlPEi5LnNK_p5TUJQLBKxMp-sW0a9w,33687
10
10
  yms_kan/hypothesis.py,sha256=Ec20xadfgOSSWeZHQaGn-h9F2PY7LWFU3iniNI2Zd_4,23165
11
11
  yms_kan/spline.py,sha256=ZXyGwl2Sc-UrnrcuUXeUQkBOMnetaWcHrbpZaqatCvs,4345
12
12
  yms_kan/tool.py,sha256=rkRpqF3EcsAq7a3k1F1zKlxfJ4U9n-FzHyNCJgN4URY,21159
13
- yms_kan/train_eval_utils.py,sha256=Cqw0heB7gOIK3pvOPBx0OIIWi2glfimPpyDqboFq2Tk,17186
13
+ yms_kan/train_eval_utils.py,sha256=e6XqlE3_i-AcGkWsWrVUjBsphTDK5T21sWR_RIJiTEs,16536
14
14
  yms_kan/utils.py,sha256=J07L-tgmc1OfU6Tl6mGwHJRizjFN75EJK8BxejaZLUc,23860
15
- yms_kan/version.py,sha256=eECtaVYZj2CuGnsLuv9pAmxQhOb0PZcTisjYg4JgF5c,39
15
+ yms_kan/version.py,sha256=OUUIPtK264nL8fStaLu6GUnNkVZ5X18jfSMcppcWpa4,39
16
16
  yms_kan/assets/img/mult_symbol.png,sha256=2f4xUKdweft-qUbHjFI5h9-smnEtc0FWq8hNYZhPAXY,6392
17
17
  yms_kan/assets/img/sum_symbol.png,sha256=94QkMUzmEjlCq_yf14nMEQmettaq86FmlGfdl22b4XE,6210
18
- yms_kan-0.0.4.dist-info/licenses/LICENSE,sha256=BJXDWyF4Groqtnp4Gi9puH4aLg7A2IC3MpHmC-cSxwc,1067
19
- yms_kan-0.0.4.dist-info/METADATA,sha256=VsDT6gWg7lsWcP424XM-o-jsTQ6eAlmGQWbb-hLxGmk,240
20
- yms_kan-0.0.4.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
21
- yms_kan-0.0.4.dist-info/top_level.txt,sha256=Z_JDh6yZf-EiW1eKgL6ADsN2yqEMRMspi-o29JZ1WPo,8
22
- yms_kan-0.0.4.dist-info/RECORD,,
18
+ yms_kan-0.0.5.dist-info/licenses/LICENSE,sha256=BJXDWyF4Groqtnp4Gi9puH4aLg7A2IC3MpHmC-cSxwc,1067
19
+ yms_kan-0.0.5.dist-info/METADATA,sha256=upxdcPXd-Frc3q0KyYZUahqV12F_toxb5Kh2TGGXNVs,240
20
+ yms_kan-0.0.5.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
21
+ yms_kan-0.0.5.dist-info/top_level.txt,sha256=Z_JDh6yZf-EiW1eKgL6ADsN2yqEMRMspi-o29JZ1WPo,8
22
+ yms_kan-0.0.5.dist-info/RECORD,,