yms-kan 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yms_kan/MultKAN.py +5 -2
- yms_kan/assets/img/mult_symbol.png +0 -0
- yms_kan/assets/img/sum_symbol.png +0 -0
- yms_kan/train_eval_utils.py +163 -20
- yms_kan/version.py +1 -1
- {yms_kan-0.0.1.dist-info → yms_kan-0.0.2.dist-info}/METADATA +1 -1
- {yms_kan-0.0.1.dist-info → yms_kan-0.0.2.dist-info}/RECORD +10 -8
- {yms_kan-0.0.1.dist-info → yms_kan-0.0.2.dist-info}/WHEEL +0 -0
- {yms_kan-0.0.1.dist-info → yms_kan-0.0.2.dist-info}/licenses/LICENSE +0 -0
- {yms_kan-0.0.1.dist-info → yms_kan-0.0.2.dist-info}/top_level.txt +0 -0
yms_kan/MultKAN.py
CHANGED
@@ -3,6 +3,7 @@ import math
|
|
3
3
|
import os
|
4
4
|
import random
|
5
5
|
import sys
|
6
|
+
from importlib.resources import files
|
6
7
|
|
7
8
|
import matplotlib.pyplot as plt
|
8
9
|
import numpy as np
|
@@ -1299,7 +1300,8 @@ class MultKAN(nn.Module):
|
|
1299
1300
|
N = n = width_out[l + 1]
|
1300
1301
|
for j in range(n):
|
1301
1302
|
id_ = j
|
1302
|
-
path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/sum_symbol.png"
|
1303
|
+
# path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/sum_symbol.png"
|
1304
|
+
path = files('yms_kan') / "assets/img/sum_symbol.png"
|
1303
1305
|
im = plt.imread(path)
|
1304
1306
|
left = DC_to_NFC([1 / (2 * N) + id_ / N - y2, 0])[0]
|
1305
1307
|
right = DC_to_NFC([1 / (2 * N) + id_ / N + y2, 0])[0]
|
@@ -1315,7 +1317,8 @@ class MultKAN(nn.Module):
|
|
1315
1317
|
n_mult = width[l + 1][1]
|
1316
1318
|
for j in range(n_mult):
|
1317
1319
|
id_ = j + n_sum
|
1318
|
-
path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/mult_symbol.png"
|
1320
|
+
# path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/mult_symbol.png"
|
1321
|
+
path = files('yms_kan') / "assets/img/mult_symbol.png"
|
1319
1322
|
im = plt.imread(path)
|
1320
1323
|
left = DC_to_NFC([1 / (2 * N) + id_ / N - y2, 0])[0]
|
1321
1324
|
right = DC_to_NFC([1 / (2 * N) + id_ / N + y2, 0])[0]
|
Binary file
|
Binary file
|
yms_kan/train_eval_utils.py
CHANGED
@@ -12,22 +12,14 @@ from tqdm import tqdm
|
|
12
12
|
from yms_kan import LBFGS
|
13
13
|
|
14
14
|
|
15
|
-
class TaskType(Enum):
|
16
|
-
classification = auto()
|
17
|
-
zlm = auto()
|
18
|
-
|
19
|
-
|
20
15
|
def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", epochs=100, lamb=0.,
|
21
16
|
lamb_l1=1., label=None, lamb_entropy=2., lamb_coef=0.,
|
22
17
|
lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
|
23
18
|
stop_grid_update_step=100,
|
24
19
|
save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
|
25
20
|
singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n'):
|
26
|
-
#
|
27
|
-
#
|
28
|
-
# initialize_results_file(results_file, result_info)
|
29
|
-
all_predictions = []
|
30
|
-
all_labels = []
|
21
|
+
# all_predictions = []
|
22
|
+
# all_labels = []
|
31
23
|
if lamb > 0. and not model.save_act:
|
32
24
|
print('setting lamb=0. If you want to set lamb > 0, set model.save_act=True')
|
33
25
|
|
@@ -52,8 +44,8 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", ep
|
|
52
44
|
|
53
45
|
lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)
|
54
46
|
|
55
|
-
results = {'
|
56
|
-
'
|
47
|
+
results = {'train_loss': .0, 'val_loss': .0, 'regularize': .0, 'all_predictions': [],
|
48
|
+
'all_labels': []}
|
57
49
|
|
58
50
|
steps = math.ceil(dataset['train_input'].shape[0] / batch_size)
|
59
51
|
|
@@ -126,6 +118,7 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", ep
|
|
126
118
|
optimizer.step()
|
127
119
|
train_pbar.set_postfix(loss=train_loss.item())
|
128
120
|
|
121
|
+
# print(f'{epoch}/{epochs}:train_loss:{train_loss.item()}')
|
129
122
|
val_loss = torch.zeros(1).to(model.device)
|
130
123
|
with torch.no_grad():
|
131
124
|
test_indices = np.arange(dataset['test_input'].shape[0])
|
@@ -149,14 +142,14 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", ep
|
|
149
142
|
diffs = torch.abs(outputs - label)
|
150
143
|
closest_indices = torch.argmin(diffs, dim=1)
|
151
144
|
closest_values = label[closest_indices]
|
152
|
-
all_predictions.extend(closest_values.detach().cpu().numpy())
|
153
|
-
all_labels.extend(batch_test_label.detach().cpu().numpy())
|
145
|
+
results['all_predictions'].extend(closest_values.detach().cpu().numpy())
|
146
|
+
results['all_labels'].extend(batch_test_label.detach().cpu().numpy())
|
154
147
|
|
155
148
|
lr_scheduler.step(val_loss)
|
156
149
|
|
157
|
-
results['
|
158
|
-
results['
|
159
|
-
results['regularize']
|
150
|
+
results['train_loss'] = train_loss.item()
|
151
|
+
results['val_loss'] = val_loss.item()
|
152
|
+
results['regularize'] = reg_.item()
|
160
153
|
|
161
154
|
if save_fig and epoch % save_fig_freq == 0:
|
162
155
|
model.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
|
@@ -170,6 +163,156 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", ep
|
|
170
163
|
model.symbolic_enabled = old_symbolic_enabled
|
171
164
|
return results
|
172
165
|
|
173
|
-
|
174
|
-
|
175
|
-
|
166
|
+
# def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", epochs=100, lamb=0.,
|
167
|
+
# lamb_l1=1., label=None, lamb_entropy=2., lamb_coef=0.,
|
168
|
+
# lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
|
169
|
+
# stop_grid_update_step=100,
|
170
|
+
# save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
|
171
|
+
# singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n'):
|
172
|
+
# # result_info = ['epoch','train_losses', 'val_losses', 'regularize', 'accuracies',
|
173
|
+
# # 'precisions', 'recalls', 'f1-scores']
|
174
|
+
# # initialize_results_file(results_file, result_info)
|
175
|
+
# all_predictions = []
|
176
|
+
# all_labels = []
|
177
|
+
# if lamb > 0. and not model.save_act:
|
178
|
+
# print('setting lamb=0. If you want to set lamb > 0, set model.save_act=True')
|
179
|
+
#
|
180
|
+
# old_save_act, old_symbolic_enabled = model.disable_symbolic_in_fit(lamb)
|
181
|
+
# if label is not None:
|
182
|
+
# label = label.to(model.device)
|
183
|
+
#
|
184
|
+
# if loss_fn is None:
|
185
|
+
# loss_fn = lambda x, y: torch.mean((x - y) ** 2)
|
186
|
+
# else:
|
187
|
+
# loss_fn = loss_fn
|
188
|
+
#
|
189
|
+
# grid_update_freq = int(stop_grid_update_step / grid_update_num)
|
190
|
+
#
|
191
|
+
# if opt == "Adam":
|
192
|
+
# optimizer = torch.optim.Adam(model.get_params(), lr=lr)
|
193
|
+
# elif opt == "LBFGS":
|
194
|
+
# optimizer = LBFGS(model.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
|
195
|
+
# tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
|
196
|
+
# else:
|
197
|
+
# optimizer = torch.optim.SGD(model.get_params(), lr=lr)
|
198
|
+
#
|
199
|
+
# lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)
|
200
|
+
#
|
201
|
+
# results = {'train_losses': [], 'val_losses': [], 'regularize': [], 'accuracies': [],
|
202
|
+
# 'precisions': [], 'recalls': [], 'f1-scores': []}
|
203
|
+
#
|
204
|
+
# steps = math.ceil(dataset['train_input'].shape[0] / batch_size)
|
205
|
+
#
|
206
|
+
# train_loss = torch.zeros(1).to(model.device)
|
207
|
+
# reg_ = torch.zeros(1).to(model.device)
|
208
|
+
#
|
209
|
+
# def closure():
|
210
|
+
# nonlocal train_loss, reg_
|
211
|
+
# optimizer.zero_grad()
|
212
|
+
# pred = model.forward(batch_train_input, singularity_avoiding=singularity_avoiding, y_th=y_th)
|
213
|
+
# loss = loss_fn(pred, batch_train_label)
|
214
|
+
# if model.save_act:
|
215
|
+
# if reg_metric == 'edge_backward':
|
216
|
+
# model.attribute()
|
217
|
+
# if reg_metric == 'node_backward':
|
218
|
+
# model.node_attribute()
|
219
|
+
# reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
|
220
|
+
# else:
|
221
|
+
# reg_ = torch.tensor(0.)
|
222
|
+
# objective = loss + lamb * reg_
|
223
|
+
# train_loss = (train_loss * batch_num + objective.detach()) / (batch_num + 1)
|
224
|
+
# objective.backward()
|
225
|
+
# return objective
|
226
|
+
#
|
227
|
+
# if save_fig:
|
228
|
+
# if not os.path.exists(img_folder):
|
229
|
+
# os.makedirs(img_folder)
|
230
|
+
#
|
231
|
+
# for epoch in range(epochs):
|
232
|
+
#
|
233
|
+
# if epoch == epochs - 1 and old_save_act:
|
234
|
+
# model.save_act = True
|
235
|
+
#
|
236
|
+
# if save_fig and epoch % save_fig_freq == 0:
|
237
|
+
# save_act = model.save_act
|
238
|
+
# model.save_act = True
|
239
|
+
#
|
240
|
+
# train_indices = np.arange(dataset['train_input'].shape[0])
|
241
|
+
# np.random.shuffle(train_indices)
|
242
|
+
# train_pbar = tqdm(range(steps), desc=f'Epoch {epoch + 1}/{epochs} Training', file=sys.stdout)
|
243
|
+
# for batch_num in train_pbar:
|
244
|
+
# step = epoch * steps + batch_num + 1
|
245
|
+
# i = batch_num * batch_size
|
246
|
+
# batch_train_id = train_indices[i:i + batch_size]
|
247
|
+
# batch_train_input = dataset['train_input'][batch_train_id].to(model.device)
|
248
|
+
# batch_train_label = dataset['train_label'][batch_train_id].to(model.device)
|
249
|
+
#
|
250
|
+
# if step % grid_update_freq == 0 and step < stop_grid_update_step and update_grid and step >= start_grid_update_step:
|
251
|
+
# model.update_grid(batch_train_input)
|
252
|
+
#
|
253
|
+
# if opt == "LBFGS":
|
254
|
+
# optimizer.step(closure)
|
255
|
+
#
|
256
|
+
# else:
|
257
|
+
# optimizer.zero_grad()
|
258
|
+
# pred = model.forward(batch_train_input, singularity_avoiding=singularity_avoiding,
|
259
|
+
# y_th=y_th)
|
260
|
+
# loss = loss_fn(pred, batch_train_label)
|
261
|
+
# if model.save_act:
|
262
|
+
# if reg_metric == 'edge_backward':
|
263
|
+
# model.attribute()
|
264
|
+
# if reg_metric == 'node_backward':
|
265
|
+
# model.node_attribute()
|
266
|
+
# reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
|
267
|
+
# else:
|
268
|
+
# reg_ = torch.tensor(0.)
|
269
|
+
# loss = loss + lamb * reg_
|
270
|
+
# train_loss = (train_loss * batch_num + loss.detach()) / (batch_num + 1)
|
271
|
+
# loss.backward()
|
272
|
+
# optimizer.step()
|
273
|
+
# train_pbar.set_postfix(loss=train_loss.item())
|
274
|
+
#
|
275
|
+
# print(f'{epoch}/{epochs}:train_loss:{train_loss.item()}')
|
276
|
+
# val_loss = torch.zeros(1).to(model.device)
|
277
|
+
# with torch.no_grad():
|
278
|
+
# test_indices = np.arange(dataset['test_input'].shape[0])
|
279
|
+
# np.random.shuffle(test_indices)
|
280
|
+
# test_steps = math.ceil(dataset['test_input'].shape[0] / batch_size_test)
|
281
|
+
# test_pbar = tqdm(range(test_steps), desc=f'Epoch {epoch + 1}/{epochs} Validation', file=sys.stdout)
|
282
|
+
# for batch_num in test_pbar:
|
283
|
+
# i = batch_num * batch_size_test
|
284
|
+
# batch_test_id = test_indices[i:i + batch_size_test]
|
285
|
+
# batch_test_input = dataset['test_input'][batch_test_id].to(model.device)
|
286
|
+
# batch_test_label = dataset['test_label'][batch_test_id].to(model.device)
|
287
|
+
#
|
288
|
+
# outputs = model.forward(batch_test_input, singularity_avoiding=singularity_avoiding,
|
289
|
+
# y_th=y_th)
|
290
|
+
#
|
291
|
+
# loss = loss_fn(outputs, batch_test_label)
|
292
|
+
#
|
293
|
+
# val_loss = (val_loss * batch_num + loss.detach()) / (batch_num + 1)
|
294
|
+
# test_pbar.set_postfix(loss=loss.item(), val_loss=val_loss.item())
|
295
|
+
# if label is not None:
|
296
|
+
# diffs = torch.abs(outputs - label)
|
297
|
+
# closest_indices = torch.argmin(diffs, dim=1)
|
298
|
+
# closest_values = label[closest_indices]
|
299
|
+
# all_predictions.extend(closest_values.detach().cpu().numpy())
|
300
|
+
# all_labels.extend(batch_test_label.detach().cpu().numpy())
|
301
|
+
#
|
302
|
+
# lr_scheduler.step(val_loss)
|
303
|
+
#
|
304
|
+
# results['train_losses'].append(train_loss.cpu().item())
|
305
|
+
# results['val_losses'].append(val_loss.cpu().item())
|
306
|
+
# results['regularize'].append(reg_.cpu().item())
|
307
|
+
#
|
308
|
+
# if save_fig and epoch % save_fig_freq == 0:
|
309
|
+
# model.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
|
310
|
+
# beta=beta)
|
311
|
+
# plt.savefig(img_folder + '/' + str(epoch) + '.jpg', bbox_inches='tight', dpi=100)
|
312
|
+
# plt.close()
|
313
|
+
# model.save_act = save_act
|
314
|
+
#
|
315
|
+
# # append_to_results_file(results_file, results, result_info)
|
316
|
+
# model.log_history('fit')
|
317
|
+
# model.symbolic_enabled = old_symbolic_enabled
|
318
|
+
# return results
|
yms_kan/version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.0.
|
1
|
+
__version__ = "0.0.2" # 初始版本
|
@@ -1,7 +1,7 @@
|
|
1
1
|
yms_kan/KANLayer.py,sha256=-V2Fh5wvPYvfF1tmQVxJKWvvaAHiwo2EiFpd8VDgB1c,14149
|
2
2
|
yms_kan/LBFGS.py,sha256=OPeRPDp40jaVH4qPoBDMEub7TPhyvw7pbqwQar3OZ1A,17620
|
3
3
|
yms_kan/MLP.py,sha256=ryLzSuBrsGlSHRLwnQZCNj-Ru9BwXJYoHNkAwX14N64,12804
|
4
|
-
yms_kan/MultKAN.py,sha256=
|
4
|
+
yms_kan/MultKAN.py,sha256=n58W6tORBDuh_-pUVp-ER-R9KNJKdYjFaisDx8wJpWw,122113
|
5
5
|
yms_kan/Symbolic_KANLayer.py,sha256=WhJzC5IMIpXI_K7aYamOrWTK7uckxVdsM9N4oLZMO3I,9897
|
6
6
|
yms_kan/__init__.py,sha256=O2c6DIG4PHavXF2v7K9jNqMbJXWr4-gTN3Vs1YSlc64,120
|
7
7
|
yms_kan/compiler.py,sha256=7bVwDNX0xmLAjQ8V1FdmkIIIibmy_W5eaeSKBlYL0Vc,18632
|
@@ -10,11 +10,13 @@ yms_kan/feynman.py,sha256=Eisf69K49s4C6UlPEi5LnNK_p5TUJQLBKxMp-sW0a9w,33687
|
|
10
10
|
yms_kan/hypothesis.py,sha256=Ec20xadfgOSSWeZHQaGn-h9F2PY7LWFU3iniNI2Zd_4,23165
|
11
11
|
yms_kan/spline.py,sha256=ZXyGwl2Sc-UrnrcuUXeUQkBOMnetaWcHrbpZaqatCvs,4345
|
12
12
|
yms_kan/tool.py,sha256=CLIsOYWwG-A5PJvoyIP8cRBzX8iRhEssW-2uXdLfi-U,12124
|
13
|
-
yms_kan/train_eval_utils.py,sha256=
|
13
|
+
yms_kan/train_eval_utils.py,sha256=73pA3-HDPDik_yCsDW0oF1dIvVu_vPeHbvJ08o26ygQ,14867
|
14
14
|
yms_kan/utils.py,sha256=J07L-tgmc1OfU6Tl6mGwHJRizjFN75EJK8BxejaZLUc,23860
|
15
|
-
yms_kan/version.py,sha256=
|
16
|
-
yms_kan
|
17
|
-
yms_kan
|
18
|
-
yms_kan-0.0.
|
19
|
-
yms_kan-0.0.
|
20
|
-
yms_kan-0.0.
|
15
|
+
yms_kan/version.py,sha256=qeSnHAh3t9Zb2L0FPUF5OaQvWEJcfTki6FmrfynjWz4,39
|
16
|
+
yms_kan/assets/img/mult_symbol.png,sha256=2f4xUKdweft-qUbHjFI5h9-smnEtc0FWq8hNYZhPAXY,6392
|
17
|
+
yms_kan/assets/img/sum_symbol.png,sha256=94QkMUzmEjlCq_yf14nMEQmettaq86FmlGfdl22b4XE,6210
|
18
|
+
yms_kan-0.0.2.dist-info/licenses/LICENSE,sha256=BJXDWyF4Groqtnp4Gi9puH4aLg7A2IC3MpHmC-cSxwc,1067
|
19
|
+
yms_kan-0.0.2.dist-info/METADATA,sha256=jTD-nNMWFF64GiFO2-bYQePNxJh4J1-yi4eniWT1djQ,240
|
20
|
+
yms_kan-0.0.2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
21
|
+
yms_kan-0.0.2.dist-info/top_level.txt,sha256=Z_JDh6yZf-EiW1eKgL6ADsN2yqEMRMspi-o29JZ1WPo,8
|
22
|
+
yms_kan-0.0.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|