yms-kan 0.0.1__tar.gz → 0.0.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yms_kan-0.0.2/MANIFEST.in +1 -0
- {yms_kan-0.0.1/yms_kan.egg-info → yms_kan-0.0.2}/PKG-INFO +1 -1
- {yms_kan-0.0.1 → yms_kan-0.0.2}/pyproject.toml +12 -0
- {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/MultKAN.py +5 -2
- yms_kan-0.0.2/yms_kan/assets/img/mult_symbol.png +0 -0
- yms_kan-0.0.2/yms_kan/assets/img/sum_symbol.png +0 -0
- yms_kan-0.0.2/yms_kan/train_eval_utils.py +318 -0
- yms_kan-0.0.2/yms_kan/version.py +1 -0
- {yms_kan-0.0.1 → yms_kan-0.0.2/yms_kan.egg-info}/PKG-INFO +1 -1
- {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan.egg-info/SOURCES.txt +4 -1
- yms_kan-0.0.1/yms_kan/train_eval_utils.py +0 -175
- yms_kan-0.0.1/yms_kan/version.py +0 -1
- {yms_kan-0.0.1 → yms_kan-0.0.2}/LICENSE +0 -0
- {yms_kan-0.0.1 → yms_kan-0.0.2}/README.md +0 -0
- {yms_kan-0.0.1 → yms_kan-0.0.2}/setup.cfg +0 -0
- {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/KANLayer.py +0 -0
- {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/LBFGS.py +0 -0
- {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/MLP.py +0 -0
- {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/Symbolic_KANLayer.py +0 -0
- {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/__init__.py +0 -0
- {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/compiler.py +0 -0
- {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/experiment.py +0 -0
- {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/feynman.py +0 -0
- {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/hypothesis.py +0 -0
- {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/spline.py +0 -0
- {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/tool.py +0 -0
- {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/utils.py +0 -0
- {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan.egg-info/dependency_links.txt +0 -0
- {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan.egg-info/top_level.txt +0 -0
@@ -0,0 +1 @@
|
|
1
|
+
recursive-include yms_kan/assets/img *.png
|
@@ -17,3 +17,15 @@ license-files = ["LICENSE"]
|
|
17
17
|
[tool.setuptools.dynamic]
|
18
18
|
# 明确指定版本号来源
|
19
19
|
version = {attr = "yms_kan.version.__version__"}
|
20
|
+
|
21
|
+
[tool.setuptools]
|
22
|
+
# 包含非代码文件
|
23
|
+
include-package-data = true
|
24
|
+
|
25
|
+
[tool.setuptools.package-data]
|
26
|
+
# 指定包内资源文件的匹配规则
|
27
|
+
yms_kan = [
|
28
|
+
"assets/img/*.png", # 包含所有png文件
|
29
|
+
"assets/img/*.svg", # 可扩展其他格式
|
30
|
+
"assets/**/*" # 递归包含子目录
|
31
|
+
]
|
@@ -3,6 +3,7 @@ import math
|
|
3
3
|
import os
|
4
4
|
import random
|
5
5
|
import sys
|
6
|
+
from importlib.resources import files
|
6
7
|
|
7
8
|
import matplotlib.pyplot as plt
|
8
9
|
import numpy as np
|
@@ -1299,7 +1300,8 @@ class MultKAN(nn.Module):
|
|
1299
1300
|
N = n = width_out[l + 1]
|
1300
1301
|
for j in range(n):
|
1301
1302
|
id_ = j
|
1302
|
-
path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/sum_symbol.png"
|
1303
|
+
# path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/sum_symbol.png"
|
1304
|
+
path = files('yms_kan') / "assets/img/sum_symbol.png"
|
1303
1305
|
im = plt.imread(path)
|
1304
1306
|
left = DC_to_NFC([1 / (2 * N) + id_ / N - y2, 0])[0]
|
1305
1307
|
right = DC_to_NFC([1 / (2 * N) + id_ / N + y2, 0])[0]
|
@@ -1315,7 +1317,8 @@ class MultKAN(nn.Module):
|
|
1315
1317
|
n_mult = width[l + 1][1]
|
1316
1318
|
for j in range(n_mult):
|
1317
1319
|
id_ = j + n_sum
|
1318
|
-
path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/mult_symbol.png"
|
1320
|
+
# path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/mult_symbol.png"
|
1321
|
+
path = files('yms_kan') / "assets/img/mult_symbol.png"
|
1319
1322
|
im = plt.imread(path)
|
1320
1323
|
left = DC_to_NFC([1 / (2 * N) + id_ / N - y2, 0])[0]
|
1321
1324
|
right = DC_to_NFC([1 / (2 * N) + id_ / N + y2, 0])[0]
|
Binary file
|
Binary file
|
@@ -0,0 +1,318 @@
|
|
1
|
+
import math
|
2
|
+
import os
|
3
|
+
import sys
|
4
|
+
from enum import Enum, auto
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
import torch
|
8
|
+
from matplotlib import pyplot as plt
|
9
|
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
10
|
+
from tqdm import tqdm
|
11
|
+
|
12
|
+
from yms_kan import LBFGS
|
13
|
+
|
14
|
+
|
15
|
+
def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", epochs=100, lamb=0.,
              lamb_l1=1., label=None, lamb_entropy=2., lamb_coef=0.,
              lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
              stop_grid_update_step=100,
              save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
              singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n'):
    """Train ``model`` on shuffled mini-batches and run a validation pass after every epoch.

    Args:
        model: KAN-style model. It must provide ``save_act``, ``device``, ``get_params``,
            ``forward``, ``get_reg``, ``update_grid``, ``disable_symbolic_in_fit``,
            ``log_history`` and ``symbolic_enabled``. It also needs ``attribute``,
            ``node_attribute`` or ``plot`` for the options that use them.
        dataset: dict with tensors under 'train_input', 'train_label', 'test_input'
            and 'test_label', indexed along dim 0.
        batch_size: mini-batch size for training.
        batch_size_test: mini-batch size for validation.
        opt: "Adam" or "LBFGS"; any other value selects SGD.
        epochs: number of passes over the training set.
        lamb: weight of the regularization term added to the loss. The term is only
            computed while ``model.save_act`` is true; otherwise it is 0.
        lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff: forwarded to ``model.get_reg``.
        label: optional tensor of candidate output values. When given, each validation
            output is snapped to the nearest candidate (argmin of ``|outputs - label|``
            over dim 1). The snapped values are collected in ``results['all_predictions']``.
            This presumably expects outputs of shape (B, 1) and a 1-D label tensor;
            confirm with callers.
        update_grid: whether to refit the grid at all.
        grid_update_num: number of grid refits spread over the update window.
            ``grid_update_num == 0`` disables grid updates.
        start_grid_update_step, stop_grid_update_step: the grid is refit on the current
            batch every ``int(stop_grid_update_step / grid_update_num)`` global steps
            (at least every step) while
            ``start_grid_update_step <= step < stop_grid_update_step``.
        loss_fn: callable ``(pred, target) -> loss``; defaults to mean squared error.
        lr: initial learning rate. ``ReduceLROnPlateau`` lowers it on the validation loss.
        save_fig, in_vars, out_vars, beta, save_fig_freq, img_folder: when ``save_fig``
            is true, ``model.plot`` is saved as ``<img_folder>/<epoch>.jpg`` every
            ``save_fig_freq`` epochs.
        singularity_avoiding, y_th: forwarded to ``model.forward``.
        reg_metric: forwarded to ``model.get_reg``. 'edge_backward' first calls
            ``model.attribute()``; 'node_backward' first calls ``model.node_attribute()``.

    Returns:
        dict containing:
            'train_loss', 'val_loss', 'regularize': floats from the final epoch.
            'all_predictions', 'all_labels': lists built from the final epoch's
            validation pass. Both are empty when ``label`` is None.
    """
    if lamb > 0. and not model.save_act:
        print('setting lamb=0. If you want to set lamb > 0, set model.save_act=True')

    old_save_act, old_symbolic_enabled = model.disable_symbolic_in_fit(lamb)
    if label is not None:
        label = label.to(model.device)

    if loss_fn is None:
        loss_fn = lambda x, y: torch.mean((x - y) ** 2)

    if grid_update_num:
        # int() truncates to 0 when stop_grid_update_step < grid_update_num, and
        # `step % 0` raised ZeroDivisionError. Fall back to updating every step.
        grid_update_freq = int(stop_grid_update_step / grid_update_num) or 1
    else:
        grid_update_freq = None  # grid_update_num == 0 -> never refit the grid

    if opt == "Adam":
        optimizer = torch.optim.Adam(model.get_params(), lr=lr)
    elif opt == "LBFGS":
        optimizer = LBFGS(model.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
                          tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
    else:
        optimizer = torch.optim.SGD(model.get_params(), lr=lr)

    lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)

    results = {'train_loss': .0, 'val_loss': .0, 'regularize': .0, 'all_predictions': [],
               'all_labels': []}

    steps = math.ceil(dataset['train_input'].shape[0] / batch_size)

    train_loss = torch.zeros(1).to(model.device)
    reg_ = torch.zeros(1).to(model.device)
    last_objective = None  # objective from the most recent closure evaluation

    def _regularization():
        # Regularization term; it is only computed while activations are being saved.
        if not model.save_act:
            return torch.tensor(0.)
        if reg_metric == 'edge_backward':
            model.attribute()
        if reg_metric == 'node_backward':
            model.node_attribute()
        return model.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)

    def closure():
        # One forward/backward pass on the current batch (shared by LBFGS and first-order optimizers).
        nonlocal reg_, last_objective
        optimizer.zero_grad()
        pred = model.forward(batch_train_input, singularity_avoiding=singularity_avoiding, y_th=y_th)
        loss = loss_fn(pred, batch_train_label)
        reg_ = _regularization()
        objective = loss + lamb * reg_
        objective.backward()
        last_objective = objective.detach()
        return objective

    if save_fig:
        os.makedirs(img_folder, exist_ok=True)

    for epoch in range(epochs):

        if epoch == epochs - 1 and old_save_act:
            model.save_act = True

        if save_fig and epoch % save_fig_freq == 0:
            # Force activation saving for this epoch (model.plot is called at its end);
            # the previous setting is restored after plotting.
            save_act = model.save_act
            model.save_act = True

        train_indices = np.arange(dataset['train_input'].shape[0])
        np.random.shuffle(train_indices)
        train_pbar = tqdm(range(steps), desc=f'Epoch {epoch + 1}/{epochs} Training', file=sys.stdout)
        for batch_num in train_pbar:
            step = epoch * steps + batch_num + 1
            i = batch_num * batch_size
            batch_train_id = train_indices[i:i + batch_size]
            batch_train_input = dataset['train_input'][batch_train_id].to(model.device)
            batch_train_label = dataset['train_label'][batch_train_id].to(model.device)

            if (update_grid and grid_update_freq is not None
                    and start_grid_update_step <= step < stop_grid_update_step
                    and step % grid_update_freq == 0):
                model.update_grid(batch_train_input)

            if opt == "LBFGS":
                optimizer.step(closure)
            else:
                closure()
                optimizer.step()

            # Running mean over batches, updated once per batch. LBFGS may evaluate
            # the closure several times per step, so this must not happen inside it.
            train_loss = (train_loss * batch_num + last_objective) / (batch_num + 1)
            train_pbar.set_postfix(loss=train_loss.item())

        # Only keep the final epoch's validation predictions (losses are final-epoch too).
        results['all_predictions'] = []
        results['all_labels'] = []
        val_loss = torch.zeros(1).to(model.device)
        with torch.no_grad():
            test_indices = np.arange(dataset['test_input'].shape[0])
            np.random.shuffle(test_indices)
            test_steps = math.ceil(dataset['test_input'].shape[0] / batch_size_test)
            test_pbar = tqdm(range(test_steps), desc=f'Epoch {epoch + 1}/{epochs} Validation', file=sys.stdout)
            for batch_num in test_pbar:
                i = batch_num * batch_size_test
                batch_test_id = test_indices[i:i + batch_size_test]
                batch_test_input = dataset['test_input'][batch_test_id].to(model.device)
                batch_test_label = dataset['test_label'][batch_test_id].to(model.device)

                outputs = model.forward(batch_test_input, singularity_avoiding=singularity_avoiding,
                                        y_th=y_th)

                loss = loss_fn(outputs, batch_test_label)

                val_loss = (val_loss * batch_num + loss.detach()) / (batch_num + 1)
                test_pbar.set_postfix(loss=loss.item(), val_loss=val_loss.item())
                if label is not None:
                    # Snap each output to the closest candidate value in `label`.
                    diffs = torch.abs(outputs - label)
                    closest_indices = torch.argmin(diffs, dim=1)
                    closest_values = label[closest_indices]
                    results['all_predictions'].extend(closest_values.detach().cpu().numpy())
                    results['all_labels'].extend(batch_test_label.detach().cpu().numpy())

        lr_scheduler.step(val_loss)

        results['train_loss'] = train_loss.item()
        results['val_loss'] = val_loss.item()
        results['regularize'] = reg_.item()

        if save_fig and epoch % save_fig_freq == 0:
            model.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
                       beta=beta)
            plt.savefig(os.path.join(img_folder, str(epoch) + '.jpg'), bbox_inches='tight', dpi=100)
            plt.close()
            model.save_act = save_act

    model.log_history('fit')
    model.symbolic_enabled = old_symbolic_enabled
    return results
|
165
|
+
|
166
|
+
# def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", epochs=100, lamb=0.,
|
167
|
+
# lamb_l1=1., label=None, lamb_entropy=2., lamb_coef=0.,
|
168
|
+
# lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
|
169
|
+
# stop_grid_update_step=100,
|
170
|
+
# save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
|
171
|
+
# singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n'):
|
172
|
+
# # result_info = ['epoch','train_losses', 'val_losses', 'regularize', 'accuracies',
|
173
|
+
# # 'precisions', 'recalls', 'f1-scores']
|
174
|
+
# # initialize_results_file(results_file, result_info)
|
175
|
+
# all_predictions = []
|
176
|
+
# all_labels = []
|
177
|
+
# if lamb > 0. and not model.save_act:
|
178
|
+
# print('setting lamb=0. If you want to set lamb > 0, set model.save_act=True')
|
179
|
+
#
|
180
|
+
# old_save_act, old_symbolic_enabled = model.disable_symbolic_in_fit(lamb)
|
181
|
+
# if label is not None:
|
182
|
+
# label = label.to(model.device)
|
183
|
+
#
|
184
|
+
# if loss_fn is None:
|
185
|
+
# loss_fn = lambda x, y: torch.mean((x - y) ** 2)
|
186
|
+
# else:
|
187
|
+
# loss_fn = loss_fn
|
188
|
+
#
|
189
|
+
# grid_update_freq = int(stop_grid_update_step / grid_update_num)
|
190
|
+
#
|
191
|
+
# if opt == "Adam":
|
192
|
+
# optimizer = torch.optim.Adam(model.get_params(), lr=lr)
|
193
|
+
# elif opt == "LBFGS":
|
194
|
+
# optimizer = LBFGS(model.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
|
195
|
+
# tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
|
196
|
+
# else:
|
197
|
+
# optimizer = torch.optim.SGD(model.get_params(), lr=lr)
|
198
|
+
#
|
199
|
+
# lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)
|
200
|
+
#
|
201
|
+
# results = {'train_losses': [], 'val_losses': [], 'regularize': [], 'accuracies': [],
|
202
|
+
# 'precisions': [], 'recalls': [], 'f1-scores': []}
|
203
|
+
#
|
204
|
+
# steps = math.ceil(dataset['train_input'].shape[0] / batch_size)
|
205
|
+
#
|
206
|
+
# train_loss = torch.zeros(1).to(model.device)
|
207
|
+
# reg_ = torch.zeros(1).to(model.device)
|
208
|
+
#
|
209
|
+
# def closure():
|
210
|
+
# nonlocal train_loss, reg_
|
211
|
+
# optimizer.zero_grad()
|
212
|
+
# pred = model.forward(batch_train_input, singularity_avoiding=singularity_avoiding, y_th=y_th)
|
213
|
+
# loss = loss_fn(pred, batch_train_label)
|
214
|
+
# if model.save_act:
|
215
|
+
# if reg_metric == 'edge_backward':
|
216
|
+
# model.attribute()
|
217
|
+
# if reg_metric == 'node_backward':
|
218
|
+
# model.node_attribute()
|
219
|
+
# reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
|
220
|
+
# else:
|
221
|
+
# reg_ = torch.tensor(0.)
|
222
|
+
# objective = loss + lamb * reg_
|
223
|
+
# train_loss = (train_loss * batch_num + objective.detach()) / (batch_num + 1)
|
224
|
+
# objective.backward()
|
225
|
+
# return objective
|
226
|
+
#
|
227
|
+
# if save_fig:
|
228
|
+
# if not os.path.exists(img_folder):
|
229
|
+
# os.makedirs(img_folder)
|
230
|
+
#
|
231
|
+
# for epoch in range(epochs):
|
232
|
+
#
|
233
|
+
# if epoch == epochs - 1 and old_save_act:
|
234
|
+
# model.save_act = True
|
235
|
+
#
|
236
|
+
# if save_fig and epoch % save_fig_freq == 0:
|
237
|
+
# save_act = model.save_act
|
238
|
+
# model.save_act = True
|
239
|
+
#
|
240
|
+
# train_indices = np.arange(dataset['train_input'].shape[0])
|
241
|
+
# np.random.shuffle(train_indices)
|
242
|
+
# train_pbar = tqdm(range(steps), desc=f'Epoch {epoch + 1}/{epochs} Training', file=sys.stdout)
|
243
|
+
# for batch_num in train_pbar:
|
244
|
+
# step = epoch * steps + batch_num + 1
|
245
|
+
# i = batch_num * batch_size
|
246
|
+
# batch_train_id = train_indices[i:i + batch_size]
|
247
|
+
# batch_train_input = dataset['train_input'][batch_train_id].to(model.device)
|
248
|
+
# batch_train_label = dataset['train_label'][batch_train_id].to(model.device)
|
249
|
+
#
|
250
|
+
# if step % grid_update_freq == 0 and step < stop_grid_update_step and update_grid and step >= start_grid_update_step:
|
251
|
+
# model.update_grid(batch_train_input)
|
252
|
+
#
|
253
|
+
# if opt == "LBFGS":
|
254
|
+
# optimizer.step(closure)
|
255
|
+
#
|
256
|
+
# else:
|
257
|
+
# optimizer.zero_grad()
|
258
|
+
# pred = model.forward(batch_train_input, singularity_avoiding=singularity_avoiding,
|
259
|
+
# y_th=y_th)
|
260
|
+
# loss = loss_fn(pred, batch_train_label)
|
261
|
+
# if model.save_act:
|
262
|
+
# if reg_metric == 'edge_backward':
|
263
|
+
# model.attribute()
|
264
|
+
# if reg_metric == 'node_backward':
|
265
|
+
# model.node_attribute()
|
266
|
+
# reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
|
267
|
+
# else:
|
268
|
+
# reg_ = torch.tensor(0.)
|
269
|
+
# loss = loss + lamb * reg_
|
270
|
+
# train_loss = (train_loss * batch_num + loss.detach()) / (batch_num + 1)
|
271
|
+
# loss.backward()
|
272
|
+
# optimizer.step()
|
273
|
+
# train_pbar.set_postfix(loss=train_loss.item())
|
274
|
+
#
|
275
|
+
# print(f'{epoch}/{epochs}:train_loss:{train_loss.item()}')
|
276
|
+
# val_loss = torch.zeros(1).to(model.device)
|
277
|
+
# with torch.no_grad():
|
278
|
+
# test_indices = np.arange(dataset['test_input'].shape[0])
|
279
|
+
# np.random.shuffle(test_indices)
|
280
|
+
# test_steps = math.ceil(dataset['test_input'].shape[0] / batch_size_test)
|
281
|
+
# test_pbar = tqdm(range(test_steps), desc=f'Epoch {epoch + 1}/{epochs} Validation', file=sys.stdout)
|
282
|
+
# for batch_num in test_pbar:
|
283
|
+
# i = batch_num * batch_size_test
|
284
|
+
# batch_test_id = test_indices[i:i + batch_size_test]
|
285
|
+
# batch_test_input = dataset['test_input'][batch_test_id].to(model.device)
|
286
|
+
# batch_test_label = dataset['test_label'][batch_test_id].to(model.device)
|
287
|
+
#
|
288
|
+
# outputs = model.forward(batch_test_input, singularity_avoiding=singularity_avoiding,
|
289
|
+
# y_th=y_th)
|
290
|
+
#
|
291
|
+
# loss = loss_fn(outputs, batch_test_label)
|
292
|
+
#
|
293
|
+
# val_loss = (val_loss * batch_num + loss.detach()) / (batch_num + 1)
|
294
|
+
# test_pbar.set_postfix(loss=loss.item(), val_loss=val_loss.item())
|
295
|
+
# if label is not None:
|
296
|
+
# diffs = torch.abs(outputs - label)
|
297
|
+
# closest_indices = torch.argmin(diffs, dim=1)
|
298
|
+
# closest_values = label[closest_indices]
|
299
|
+
# all_predictions.extend(closest_values.detach().cpu().numpy())
|
300
|
+
# all_labels.extend(batch_test_label.detach().cpu().numpy())
|
301
|
+
#
|
302
|
+
# lr_scheduler.step(val_loss)
|
303
|
+
#
|
304
|
+
# results['train_losses'].append(train_loss.cpu().item())
|
305
|
+
# results['val_losses'].append(val_loss.cpu().item())
|
306
|
+
# results['regularize'].append(reg_.cpu().item())
|
307
|
+
#
|
308
|
+
# if save_fig and epoch % save_fig_freq == 0:
|
309
|
+
# model.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
|
310
|
+
# beta=beta)
|
311
|
+
# plt.savefig(img_folder + '/' + str(epoch) + '.jpg', bbox_inches='tight', dpi=100)
|
312
|
+
# plt.close()
|
313
|
+
# model.save_act = save_act
|
314
|
+
#
|
315
|
+
# # append_to_results_file(results_file, results, result_info)
|
316
|
+
# model.log_history('fit')
|
317
|
+
# model.symbolic_enabled = old_symbolic_enabled
|
318
|
+
# return results
|
@@ -0,0 +1 @@
|
|
1
|
+
__version__ = "0.0.2" # 初始版本
|
@@ -1,4 +1,5 @@
|
|
1
1
|
LICENSE
|
2
|
+
MANIFEST.in
|
2
3
|
README.md
|
3
4
|
pyproject.toml
|
4
5
|
yms_kan/KANLayer.py
|
@@ -19,4 +20,6 @@ yms_kan/version.py
|
|
19
20
|
yms_kan.egg-info/PKG-INFO
|
20
21
|
yms_kan.egg-info/SOURCES.txt
|
21
22
|
yms_kan.egg-info/dependency_links.txt
|
22
|
-
yms_kan.egg-info/top_level.txt
|
23
|
+
yms_kan.egg-info/top_level.txt
|
24
|
+
yms_kan/assets/img/mult_symbol.png
|
25
|
+
yms_kan/assets/img/sum_symbol.png
|
@@ -1,175 +0,0 @@
|
|
1
|
-
import math
|
2
|
-
import os
|
3
|
-
import sys
|
4
|
-
from enum import Enum, auto
|
5
|
-
|
6
|
-
import numpy as np
|
7
|
-
import torch
|
8
|
-
from matplotlib import pyplot as plt
|
9
|
-
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
10
|
-
from tqdm import tqdm
|
11
|
-
|
12
|
-
from yms_kan import LBFGS
|
13
|
-
|
14
|
-
|
15
|
-
class TaskType(Enum):
    """Enumeration of task types with members ``classification`` and ``zlm``.

    The values are spelled out explicitly; they equal what ``enum.auto()``
    assigns (1, 2) in declaration order.
    """

    classification = 1
    zlm = 2
|
18
|
-
|
19
|
-
|
20
|
-
def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", epochs=100, lamb=0.,
              lamb_l1=1., label=None, lamb_entropy=2., lamb_coef=0.,
              lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
              stop_grid_update_step=100,
              save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
              singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n'):
    """Train ``model`` on shuffled mini-batches and validate it after every epoch.

    Args:
        model: KAN-style model. It must provide ``save_act``, ``device``, ``get_params``,
            ``forward``, ``get_reg``, ``update_grid``, ``disable_symbolic_in_fit``,
            ``log_history`` and ``symbolic_enabled``. It also needs ``attribute``,
            ``node_attribute`` or ``plot`` for the options that use them.
        dataset: dict with tensors under 'train_input', 'train_label', 'test_input'
            and 'test_label', indexed along dim 0.
        batch_size: mini-batch size for training.
        batch_size_test: mini-batch size for validation.
        opt: "Adam" or "LBFGS"; any other value selects SGD.
        epochs: number of passes over the training set.
        lamb: weight of the regularization term. The term is only computed while
            ``model.save_act`` is true; otherwise it is 0.
        lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff: forwarded to ``model.get_reg``.
        label: optional tensor of candidate output values used to snap validation
            outputs to the nearest candidate.
        update_grid: whether to refit the grid at all.
        grid_update_num: number of grid refits spread over the update window.
            ``grid_update_num == 0`` disables grid updates.
        start_grid_update_step, stop_grid_update_step: the grid is refit every
            ``int(stop_grid_update_step / grid_update_num)`` global steps (at least
            every step) while ``start_grid_update_step <= step < stop_grid_update_step``.
        loss_fn: callable ``(pred, target) -> loss``; defaults to mean squared error.
        lr: initial learning rate. ``ReduceLROnPlateau`` lowers it on the validation loss.
        save_fig, in_vars, out_vars, beta, save_fig_freq, img_folder: when ``save_fig``
            is true, ``model.plot`` is saved as ``<img_folder>/<epoch>.jpg`` every
            ``save_fig_freq`` epochs.
        singularity_avoiding, y_th: forwarded to ``model.forward``.
        reg_metric: forwarded to ``model.get_reg``. 'edge_backward' first calls
            ``model.attribute()``; 'node_backward' first calls ``model.node_attribute()``.

    Returns:
        dict of per-epoch histories. 'train_losses', 'val_losses' and 'regularize'
        each get one float appended per epoch. 'accuracies', 'precisions', 'recalls'
        and 'f1-scores' are present but never filled by this function.
    """
    # NOTE(review): these are collected during validation but never returned or read.
    all_predictions = []
    all_labels = []
    if lamb > 0. and not model.save_act:
        print('setting lamb=0. If you want to set lamb > 0, set model.save_act=True')

    old_save_act, old_symbolic_enabled = model.disable_symbolic_in_fit(lamb)
    if label is not None:
        label = label.to(model.device)

    if loss_fn is None:
        loss_fn = lambda x, y: torch.mean((x - y) ** 2)

    if grid_update_num:
        # int() truncates to 0 when stop_grid_update_step < grid_update_num, and
        # `step % 0` raised ZeroDivisionError. Fall back to updating every step.
        grid_update_freq = int(stop_grid_update_step / grid_update_num) or 1
    else:
        grid_update_freq = None  # grid_update_num == 0 -> never refit the grid

    if opt == "Adam":
        optimizer = torch.optim.Adam(model.get_params(), lr=lr)
    elif opt == "LBFGS":
        optimizer = LBFGS(model.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
                          tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
    else:
        optimizer = torch.optim.SGD(model.get_params(), lr=lr)

    lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)

    results = {'train_losses': [], 'val_losses': [], 'regularize': [], 'accuracies': [],
               'precisions': [], 'recalls': [], 'f1-scores': []}

    steps = math.ceil(dataset['train_input'].shape[0] / batch_size)

    train_loss = torch.zeros(1).to(model.device)
    reg_ = torch.zeros(1).to(model.device)
    last_objective = None  # objective from the most recent closure evaluation

    def _regularization():
        # Regularization term; it is only computed while activations are being saved.
        if not model.save_act:
            return torch.tensor(0.)
        if reg_metric == 'edge_backward':
            model.attribute()
        if reg_metric == 'node_backward':
            model.node_attribute()
        return model.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)

    def closure():
        # One forward/backward pass on the current batch (shared by LBFGS and first-order optimizers).
        nonlocal reg_, last_objective
        optimizer.zero_grad()
        pred = model.forward(batch_train_input, singularity_avoiding=singularity_avoiding, y_th=y_th)
        loss = loss_fn(pred, batch_train_label)
        reg_ = _regularization()
        objective = loss + lamb * reg_
        objective.backward()
        last_objective = objective.detach()
        return objective

    if save_fig:
        os.makedirs(img_folder, exist_ok=True)

    for epoch in range(epochs):

        if epoch == epochs - 1 and old_save_act:
            model.save_act = True

        if save_fig and epoch % save_fig_freq == 0:
            # Force activation saving for this epoch (model.plot is called at its end);
            # the previous setting is restored after plotting.
            save_act = model.save_act
            model.save_act = True

        train_indices = np.arange(dataset['train_input'].shape[0])
        np.random.shuffle(train_indices)
        train_pbar = tqdm(range(steps), desc=f'Epoch {epoch + 1}/{epochs} Training', file=sys.stdout)
        for batch_num in train_pbar:
            step = epoch * steps + batch_num + 1
            i = batch_num * batch_size
            batch_train_id = train_indices[i:i + batch_size]
            batch_train_input = dataset['train_input'][batch_train_id].to(model.device)
            batch_train_label = dataset['train_label'][batch_train_id].to(model.device)

            if (update_grid and grid_update_freq is not None
                    and start_grid_update_step <= step < stop_grid_update_step
                    and step % grid_update_freq == 0):
                model.update_grid(batch_train_input)

            if opt == "LBFGS":
                optimizer.step(closure)
            else:
                closure()
                optimizer.step()

            # Running mean over batches, updated once per batch. LBFGS may evaluate
            # the closure several times per step, so this must not happen inside it.
            train_loss = (train_loss * batch_num + last_objective) / (batch_num + 1)
            train_pbar.set_postfix(loss=train_loss.item())

        val_loss = torch.zeros(1).to(model.device)
        with torch.no_grad():
            test_indices = np.arange(dataset['test_input'].shape[0])
            np.random.shuffle(test_indices)
            test_steps = math.ceil(dataset['test_input'].shape[0] / batch_size_test)
            test_pbar = tqdm(range(test_steps), desc=f'Epoch {epoch + 1}/{epochs} Validation', file=sys.stdout)
            for batch_num in test_pbar:
                i = batch_num * batch_size_test
                batch_test_id = test_indices[i:i + batch_size_test]
                batch_test_input = dataset['test_input'][batch_test_id].to(model.device)
                batch_test_label = dataset['test_label'][batch_test_id].to(model.device)

                outputs = model.forward(batch_test_input, singularity_avoiding=singularity_avoiding,
                                        y_th=y_th)

                loss = loss_fn(outputs, batch_test_label)

                val_loss = (val_loss * batch_num + loss.detach()) / (batch_num + 1)
                test_pbar.set_postfix(loss=loss.item(), val_loss=val_loss.item())
                if label is not None:
                    # Snap each output to the closest candidate value in `label`.
                    diffs = torch.abs(outputs - label)
                    closest_indices = torch.argmin(diffs, dim=1)
                    closest_values = label[closest_indices]
                    all_predictions.extend(closest_values.detach().cpu().numpy())
                    all_labels.extend(batch_test_label.detach().cpu().numpy())

        lr_scheduler.step(val_loss)

        results['train_losses'].append(train_loss.cpu().item())
        results['val_losses'].append(val_loss.cpu().item())
        results['regularize'].append(reg_.cpu().item())

        if save_fig and epoch % save_fig_freq == 0:
            model.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
                       beta=beta)
            plt.savefig(os.path.join(img_folder, str(epoch) + '.jpg'), bbox_inches='tight', dpi=100)
            plt.close()
            model.save_act = save_act

    model.log_history('fit')
    model.symbolic_enabled = old_symbolic_enabled
    return results
|
172
|
-
|
173
|
-
|
174
|
-
if __name__ == '__main__':
    # Manual smoke check when run as a script: show the value bound to TaskType.zlm.
    zlm_value = TaskType.zlm.value
    print(zlm_value)
|
yms_kan-0.0.1/yms_kan/version.py
DELETED
@@ -1 +0,0 @@
|
|
1
|
-
__version__ = "0.0.1" # 初始版本
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|