yms-kan 0.0.9__py3-none-any.whl → 0.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yms_kan/MLP.py +81 -71
- yms_kan/MultKAN.py +6 -4
- yms_kan/train_eval_utils.py +161 -159
- yms_kan/utils.py +75 -3
- yms_kan/version.py +1 -1
- {yms_kan-0.0.9.dist-info → yms_kan-0.0.10.dist-info}/METADATA +1 -1
- {yms_kan-0.0.9.dist-info → yms_kan-0.0.10.dist-info}/RECORD +10 -10
- {yms_kan-0.0.9.dist-info → yms_kan-0.0.10.dist-info}/WHEEL +1 -1
- {yms_kan-0.0.9.dist-info → yms_kan-0.0.10.dist-info}/licenses/LICENSE +0 -0
- {yms_kan-0.0.9.dist-info → yms_kan-0.0.10.dist-info}/top_level.txt +0 -0
yms_kan/MLP.py
CHANGED
@@ -3,43 +3,44 @@ import torch.nn as nn
 import matplotlib.pyplot as plt
 import numpy as np
 from tqdm import tqdm
+
+from . import plot_tree
 from .LBFGS import LBFGS
 
 seed = 0
 torch.manual_seed(seed)
 
+
 class MLP(nn.Module):
-
+
     def __init__(self, width, act='silu', save_act=True, seed=0, device='cpu'):
         super(MLP, self).__init__()
-
+
         torch.manual_seed(seed)
-
+
         linears = []
         self.width = width
         self.depth = depth = len(width) - 1
         for i in range(depth):
-            linears.append(nn.Linear(width[i], width[i+1]))
+            linears.append(nn.Linear(width[i], width[i + 1]))
         self.linears = nn.ModuleList(linears)
-
+
         #if activation == 'silu':
         self.act_fun = torch.nn.SiLU()
         self.save_act = save_act
         self.acts = None
-
+
         self.cache_data = None
-
+
         self.device = device
         self.to(device)
-
-
+
     def to(self, device):
         super(MLP, self).to(device)
         self.device = device
-
+
         return self
-
-
+
     def get_act(self, x=None):
         if isinstance(x, dict):
             x = x['train_input']
@@ -52,23 +53,23 @@ class MLP(nn.Module):
         self.save_act = True
         self.forward(x)
         self.save_act = save_act
-
+
     @property
     def w(self):
         return [self.linears[l].weight for l in range(self.depth)]
-
+
     def forward(self, x):
-
+
         # cache data
         self.cache_data = x
-
+
         self.acts = []
         self.acts_scale = []
         self.wa_forward = []
         self.a_forward = []
-
+
         for i in range(self.depth):
-
+
             if self.save_act:
                 act = x.clone()
                 act_scale = torch.std(x, dim=0)
@@ -77,7 +78,7 @@ class MLP(nn.Module):
                 if i > 0:
                     self.acts_scale.append(act_scale)
                 self.wa_forward.append(wa_forward)
-
+
             x = self.linears[i](x)
             if i < self.depth - 1:
                 x = self.act_fun(x)
@@ -85,9 +86,9 @@ class MLP(nn.Module):
                 if self.save_act:
                     act_scale = torch.std(x, dim=0)
                     self.acts_scale.append(act_scale)
-
+
         return x
-
+
     def attribute(self):
         if self.acts == None:
             self.get_act()
@@ -99,47 +100,47 @@ class MLP(nn.Module):
         node_score = torch.ones(self.width[-1]).requires_grad_(True).to(self.device)
         node_scores.append(node_score)
 
-        for l in range(self.depth,0,-1):
-            edge_score = torch.einsum('ij,i->ij', torch.abs(self.wa_forward[l-1]),
-                                      node_score/(self.acts_scale[l-1]+1e-4))
+        for l in range(self.depth, 0, -1):
+            edge_score = torch.einsum('ij,i->ij', torch.abs(self.wa_forward[l - 1]),
+                                      node_score / (self.acts_scale[l - 1] + 1e-4))
             edge_scores.append(edge_score)
 
             # this might be improper for MLPs (although reasonable for KANs)
-            node_score = torch.sum(edge_score, dim=0)/torch.sqrt(torch.tensor(self.width[l-1], device=self.device))
+            node_score = torch.sum(edge_score, dim=0) / torch.sqrt(torch.tensor(self.width[l - 1], device=self.device))
             #print(self.width[l])
             node_scores.append(node_score)
 
         self.node_scores = list(reversed(node_scores))
         self.edge_scores = list(reversed(edge_scores))
         self.wa_backward = self.edge_scores
-
+
     def plot(self, beta=3, scale=1., metric='w'):
         # metric = 'w', 'act' or 'fa'
-
+
         if metric == 'fa':
             self.attribute()
-
+
         depth = self.depth
         y0 = 0.5
-        fig, ax = plt.subplots(figsize=(3*scale,3*y0*depth*scale))
+        fig, ax = plt.subplots(figsize=(3 * scale, 3 * y0 * depth * scale))
         shp = self.width
-
-        min_spacing = 1/max(self.width)
+
+        min_spacing = 1 / max(self.width)
         for j in range(len(shp)):
             N = shp[j]
             for i in range(N):
                 plt.scatter(1 / (2 * N) + i / N, j * y0, s=min_spacing ** 2 * 5000 * scale ** 2, color='black')
-
-        plt.ylim(-0.1*y0,y0*depth+0.1*y0)
-        plt.xlim(-0.02,1.02)
+
+        plt.ylim(-0.1 * y0, y0 * depth + 0.1 * y0)
+        plt.xlim(-0.02, 1.02)
 
         linears = self.linears
-
+
         for ii in range(len(linears)):
             linear = linears[ii]
             p = linear.weight
             p_shp = p.shape
-
+
             if metric == 'w':
                 pass
             elif metric == 'act':
@@ -150,12 +151,15 @@ class MLP(nn.Module):
                 raise Exception('metric = \'{}\' not recognized. Choices are \'w\', \'act\', \'fa\'.'.format(metric))
             for i in range(p_shp[0]):
                 for j in range(p_shp[1]):
-                    plt.plot([1/(2*p_shp[0])+i/p_shp[0], 1/(2*p_shp[1])+j/p_shp[1]],
-                             [y0*(ii+1),y0*ii], lw=0.5*scale, alpha=np.tanh(beta*np.abs(p[i,j].cpu().detach().numpy())), color="blue" if p[i,j]>0 else "red")
+                    plt.plot([1 / (2 * p_shp[0]) + i / p_shp[0], 1 / (2 * p_shp[1]) + j / p_shp[1]],
+                             [y0 * (ii + 1), y0 * ii], lw=0.5 * scale,
+                             alpha=np.tanh(beta * np.abs(p[i, j].cpu().detach().numpy())),
+                             color="blue" if p[i, j] > 0 else "red")
+
         ax.axis('off')
-
+
     def reg(self, reg_metric, lamb_l1, lamb_entropy):
-
+
         if reg_metric == 'w':
             acts_scale = self.w
         if reg_metric == 'act':
@@ -164,7 +168,7 @@ class MLP(nn.Module):
             acts_scale = self.wa_backward
         if reg_metric == 'a':
             acts_scale = self.acts_scale
-
+
         if len(acts_scale[0].shape) == 2:
             reg_ = 0.
 
@@ -178,9 +182,9 @@ class MLP(nn.Module):
                 entropy_row = - torch.mean(torch.sum(p_row * torch.log2(p_row + 1e-4), dim=1))
                 entropy_col = - torch.mean(torch.sum(p_col * torch.log2(p_col + 1e-4), dim=0))
                 reg_ += lamb_l1 * l1 + lamb_entropy * (entropy_row + entropy_col)
-
+
         elif len(acts_scale[0].shape) == 1:
-
+
             reg_ = 0.
 
             for i in range(len(acts_scale)):
@@ -193,20 +197,21 @@ class MLP(nn.Module):
                 reg_ += lamb_l1 * l1 + lamb_entropy * entropy
 
         return reg_
-
+
     def get_reg(self, reg_metric, lamb_l1, lamb_entropy):
         return self.reg(reg_metric, lamb_l1, lamb_entropy)
-
-    def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., loss_fn=None, lr=1.,
-            batch=-1, metrics=None, device='cpu', reg_metric='w', display_metrics=None):
+
+    def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., loss_fn=None, lr=1.,
+            batch=-1,
+            metrics=None, device='cpu', reg_metric='w', display_metrics=None):
 
         if lamb > 0. and not self.save_act:
             print('setting lamb=0. If you want to set lamb > 0, set =True')
-
+
         old_save_act = self.save_act
         if lamb == 0.:
             self.save_act = False
-
+
         pbar = tqdm(range(steps), desc='description', ncols=100)
 
         if loss_fn == None:
@@ -217,7 +222,8 @@ class MLP(nn.Module):
         if opt == "Adam":
             optimizer = torch.optim.Adam(self.parameters(), lr=lr)
         elif opt == "LBFGS":
-            optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
+            optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
+                              tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
 
         results = {}
         results['train_loss'] = []
@@ -252,10 +258,10 @@ class MLP(nn.Module):
             return objective
 
         for _ in pbar:
-
-            if _ == steps-1 and old_save_act:
+
+            if _ == steps - 1 and old_save_act:
                 self.save_act = True
-
+
             train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
             test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
 
@@ -274,9 +280,9 @@ class MLP(nn.Module):
                 loss.backward()
                 optimizer.step()
 
-            test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(self.device)),
-                                     dataset['test_label'][test_id].to(self.device))
-
+            test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(self.device)),
+                                     dataset['test_label'][test_id].to(self.device))
+
             if metrics != None:
                 for i in range(len(metrics)):
                     results[metrics[i].__name__].append(metrics[i]().item())
@@ -287,7 +293,9 @@ class MLP(nn.Module):
 
             if _ % log == 0:
                 if display_metrics == None:
-                    pbar.set_description("| train_loss: %.2e | test_loss: %.2e | reg: %.2e | " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy()))
+                    pbar.set_description("| train_loss: %.2e | test_loss: %.2e | reg: %.2e | " % (
+                        torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(),
+                        reg_.cpu().detach().numpy()))
                 else:
                     string = ''
                     data = ()
@@ -299,9 +307,9 @@ class MLP(nn.Module):
                             raise Exception(f'{metric} not recognized')
                         data += (results[metric][-1],)
                     pbar.set_description(string % data)
-
+
         return results
-
+
     @property
     def connection_cost(self):
 
@@ -309,8 +317,9 @@ class MLP(nn.Module):
         cc = 0.
         for linear in self.linears:
            t = torch.abs(linear.weight)
+
            def get_coordinate(n):
-                return torch.linspace(0,1,steps=n+1, device=self.device)[:n] + 1/(2*n)
+                return torch.linspace(0, 1, steps=n + 1, device=self.device)[:n] + 1 / (2 * n)
 
            in_dim = t.shape[0]
            x_in = get_coordinate(in_dim)
@@ -318,44 +327,45 @@ class MLP(nn.Module):
            out_dim = t.shape[1]
            x_out = get_coordinate(out_dim)
 
-            dist = torch.abs(x_in[:,None] - x_out[None,:])
+            dist = torch.abs(x_in[:, None] - x_out[None, :])
            cc += torch.sum(dist * t)
 
        return cc
-
+
    def swap(self, l, i1, i2):
 
        def swap_row(data, i1, i2):
            data[i1], data[i2] = data[i2].clone(), data[i1].clone()
 
        def swap_col(data, i1, i2):
-            data[:,i1], data[:,i2] = data[:,i2].clone(), data[:,i1].clone()
+            data[:, i1], data[:, i2] = data[:, i2].clone(), data[:, i1].clone()
 
-        swap_row(self.linears[l-1].weight.data, i1, i2)
-        swap_row(self.linears[l-1].bias.data, i1, i2)
+        swap_row(self.linears[l - 1].weight.data, i1, i2)
+        swap_row(self.linears[l - 1].bias.data, i1, i2)
        swap_col(self.linears[l].weight.data, i1, i2)
-
+
    def auto_swap_l(self, l):
 
        num = self.width[l]
        for i in range(num):
            ccs = []
            for j in range(num):
-                self.swap(l,i,j)
+                self.swap(l, i, j)
                self.get_act()
                self.attribute()
                cc = self.connection_cost.detach().clone()
                ccs.append(cc)
-                self.swap(l,i,j)
+                self.swap(l, i, j)
            j = torch.argmin(torch.tensor(ccs))
-            self.swap(l,i,j)
+            self.swap(l, i, j)
 
    def auto_swap(self):
        depth = self.depth
        for l in range(1, depth):
            self.auto_swap_l(l)
-
+
    def tree(self, x=None, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False):
        if x == None:
            x = self.cache_data
-        plot_tree(self, x, in_var=in_var, style=style, sym_th=sym_th, sep_th=sep_th, skip_sep_test=skip_sep_test, verbose=verbose)
+        plot_tree(self, x, in_var=in_var, style=style, sym_th=sym_th, sep_th=sep_th, skip_sep_test=skip_sep_test,
+                  verbose=verbose)
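Note: aside from the new `from . import plot_tree` import (which `MLP.tree()` relies on) and a few rewrapped long lines, the MLP.py changes above are PEP 8 whitespace normalization with no behavioral effect. Below is a minimal usage sketch of the class as shipped in 0.0.10; the toy dataset and hyperparameters are illustrative and not part of the package.

import torch
from yms_kan.MLP import MLP

# Hypothetical regression data in the dict layout MLP.fit expects.
X = torch.rand(200, 2)
y = X.sum(dim=1, keepdim=True)
dataset = {'train_input': X, 'train_label': y,
           'test_input': X.clone(), 'test_label': y.clone()}

model = MLP(width=[2, 8, 1], device='cpu')                   # width = [in, hidden, out]
results = model.fit(dataset, opt='Adam', steps=50, lr=1e-2)  # training loop shown in the diff above
model.plot(metric='w')                                       # weight-based connectivity plot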
yms_kan/MultKAN.py
CHANGED
@@ -794,6 +794,8 @@ class MultKAN(nn.Module):
         # >>> print(model(x, singularity_avoiding=True))
         # >>> print(model(x, singularity_avoiding=True, y_th=1.))
         """
+        # x = abs(torch.fft(x, dim=-1,norm='forward'))
+        # _,x = x.chunk(2,dim=-1)
         x = x[:, self.input_id.long()]
         assert x.shape[1] == self.width_in[0]
 
@@ -1063,7 +1065,7 @@ class MultKAN(nn.Module):
                           ha='center', va='center', transform=ax.transData)
 
     def plot(self, folder="./figures", beta=3, metric='backward', scale=0.5, tick=False, sample=False, in_vars=None,
-             out_vars=None, title=None, varscale=1.0):
+             out_vars=None, title=None, varscale=1.0, dpi=100):
         '''
         plot KAN
 
@@ -1164,7 +1166,7 @@ class MultKAN(nn.Module):
                                 s=400 * scale ** 2)
                     plt.gca().spines[:].set_color(color)
 
-                    plt.savefig(f'{folder}/sp_{l}_{i}_{j}.png', bbox_inches="tight", dpi=
+                    plt.savefig(f'{folder}/sp_{l}_{i}_{j}.png', bbox_inches="tight", dpi=dpi)
                     plt.close()
 
     def score2alpha(score):
@@ -1647,7 +1649,7 @@ class MultKAN(nn.Module):
 
            if save_fig and _ % save_fig_freq == 0:
                self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(_), beta=beta)
-                plt.savefig(img_folder + '/' + str(_) + '.jpg', bbox_inches='tight'
+                plt.savefig(img_folder + '/' + str(_) + '.jpg', bbox_inches='tight')
                plt.close()
                self.save_act = save_act
 
@@ -1857,7 +1859,7 @@ class MultKAN(nn.Module):
            if save_fig and epoch % save_fig_freq == 0:
                self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
                          beta=beta)
-                plt.savefig(img_folder + '/' + str(epoch) + '.jpg', bbox_inches='tight'
+                plt.savefig(img_folder + '/' + str(epoch) + '.jpg', bbox_inches='tight')
                plt.close()
                self.save_act = save_act
 
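The substantive MultKAN.py change is the new `dpi` parameter on `plot()`, which is now forwarded to `plt.savefig` for the per-edge spline figures (the removed line carried a fixed `dpi=` value, truncated in this view); the `torch.fft` lines are added only as comments, and the two training loops drop an extra argument from their `plt.savefig` calls. A hedged sketch of the new parameter; the constructor arguments follow the upstream pykan convention and are illustrative:

import torch
from yms_kan.MultKAN import MultKAN

model = MultKAN(width=[2, 5, 1], grid=5, k=3, seed=0)
x = torch.rand(100, 2)
model(x)                                # a forward pass caches the activations plot() draws
model.plot(beta=3, scale=0.5, dpi=300)  # new in 0.0.10: dpi controls the saved spline PNGs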
yms_kan/train_eval_utils.py
CHANGED
@@ -13,11 +13,11 @@ from yms_kan.plotting import plot_confusion_matrix
 from yms_kan.tool import initialize_results_file, append_to_results_file, calculate_metric
 
 
-def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_file=None, opt="LBFGS", epochs=100,
+def train_val(model, dataset: dict, batch_size, batch_size_test, save_path=None, txt_file=None, opt="LBFGS", epochs=100,
              lamb=0.,
              lamb_l1=1., label=None, class_dict=None, lamb_entropy=2., lamb_coef=0.,
              lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
-
+              stop_grid_update_epoch=100,
              save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
              singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n'):
    best = -1
@@ -42,7 +42,6 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_
    else:
        loss_fn = loss_fn
 
-    grid_update_freq = int(stop_grid_update_step / grid_update_num)
 
    if opt == "Adam":
        optimizer = torch.optim.Adam(model.get_params(), lr=lr)
@@ -55,9 +54,11 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_
    lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)
 
    results = {'train_losses': [], 'val_losses': [], 'accuracies': [], 'precisions': [], 'recalls': [], 'f1-scores': [],
-               'lrs': [], '
+               'lrs': [], 'regularize': []}  # , 'all_predictions': [], 'all_labels': []
 
    steps = math.ceil(dataset['train_input'].shape[0] / batch_size)
+    stop_grid_update_step = stop_grid_update_epoch * steps
+    grid_update_freq = int(stop_grid_update_step / grid_update_num)
 
    train_loss = torch.zeros(1).to(model.device)
    reg_ = torch.zeros(1).to(model.device)
@@ -85,7 +86,6 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_
        os.makedirs(img_folder)
 
    for epoch in range(epochs):
-
        if epoch == epochs - 1 and old_save_act:
            model.save_act = True
 
@@ -172,7 +172,9 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_
                best = m["f1-score"]
                results['all_predictions'] = all_predictions
                results['all_labels'] = all_labels
-
+
+            # if save_path is not None:
+            #     plot_confusion_matrix(all_labels, all_predictions, class_dict, save_path)
            # if save_path is not None:
            #     model.saveckpt(path=(os.path.join(save_path, 'save_model') + '/' + 'model'))
            if txt_file is not None:
@@ -196,156 +198,156 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_
    model.symbolic_enabled = old_symbolic_enabled
    return results
 
-# ... 153 lines removed (old lines 199-351): a fully commented-out draft of fit(); the comment text is not preserved in this diff view
+
+def fit(model, dataset, batch_size, opt="LBFGS", epochs=100, lamb=0., lamb_l1=1., lamb_entropy=2., loss_fn=None,
+        lr=1., label=None, class_dict=None,
+        txt_file=None,
+        reg_metric='w'):
+    best = -1
+    column_order = ['epoch', 'train_losses', 'val_losses', 'accuracies', 'precisions', 'recalls',
+                    'f1-scores', 'lrs']
+    custom_column_widths = {'epoch': 5, 'train_loss': 12, 'val_loss': 10, 'accuracy': 10, 'precision': 9,
+                            'recall': 7,
+                            'f1-score': 8,
+                            'lr': 3}
+    if txt_file is not None:
+        initialize_results_file(txt_file, column_order)
+    if lamb > 0. and not model.save_act:
+        print('setting lamb=0. If you want to set lamb > 0, set =True')
+
+    old_save_act = model.save_act
+    if lamb == 0.:
+        model.save_act = False
+
+    # pbar = tqdm(range(steps), desc='description', ncols=100)
+
+    if loss_fn == None:
+        loss_fn = lambda x, y: torch.mean((x - y) ** 2)
+    else:
+        loss_fn = loss_fn
+
+    if opt == "Adam":
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+    elif opt == "LBFGS":
+        optimizer = LBFGS(model.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
+                          tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
+    else:
+        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True)
+
+    lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)
+
+    results = {'train_losses': [], 'val_losses': [], 'accuracies': [], 'precisions': [], 'recalls': [], 'f1-scores': [],
+               'lrs': [], 'regularize': []}
+
+    steps = math.ceil(dataset['train_input'].shape[0] / batch_size)
+
+    train_loss = torch.zeros(1).to(model.device)
+    reg_ = torch.zeros(1).to(model.device)
+
+    def closure():
+        nonlocal train_loss, reg_
+        optimizer.zero_grad()
+        pred = model.forward(batch_train_input)
+        loss = loss_fn(pred, batch_train_label)
+        if model.save_act:
+            if reg_metric == 'edge_backward':
+                model.attribute()
+            if reg_metric == 'node_backward':
+                model.node_attribute()
+            reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy)
+        else:
+            reg_ = torch.tensor(0.)
+        objective = loss + lamb * reg_
+        train_loss = (train_loss * batch_num + objective.detach()) / (batch_num + 1)
+        objective.backward()
+        return objective
+
+    for epoch in range(epochs):
+
+        if epoch == steps - 1 and old_save_act:
+            model.save_act = True
+
+        train_indices = np.arange(dataset['train_input'].shape[0])
+        np.random.shuffle(train_indices)
+        train_pbar = tqdm(range(steps), desc=f'Epoch {epoch + 1}/{epochs} Training', file=sys.stdout)
+
+        for batch_num in train_pbar:
+            step = epoch * steps + batch_num + 1
+            i = batch_num * batch_size
+            batch_train_id = train_indices[i:i + batch_size]
+            batch_train_input = dataset['train_input'][batch_train_id].to(model.device)
+            batch_train_label = dataset['train_label'][batch_train_id].to(model.device)
+
+            if opt == "LBFGS":
+                optimizer.step(closure)
+
+            if opt == "Adam":
+                optimizer.zero_grad()
+                pred = model.forward(batch_train_input)
+                train_loss = loss_fn(pred, batch_train_input)
+                if model.save_act:
+                    reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy)
+                else:
+                    reg_ = torch.tensor(0.)
+                loss = train_loss + lamb * reg_
+                train_loss = (train_loss * batch_num + loss.detach()) / (batch_num + 1)
+
+                loss.backward()
+                optimizer.step()
+
+            train_pbar.set_postfix(loss=train_loss.item())
+
+        val_loss = torch.zeros(1).to(model.device)
+        with torch.no_grad():
+            all_predictions = []
+            all_labels = []
+            test_indices = np.arange(dataset['test_input'].shape[0])
+            np.random.shuffle(test_indices)
+            test_steps = math.ceil(dataset['test_input'].shape[0] / batch_size)
+            test_pbar = tqdm(range(test_steps), desc=f'Epoch {epoch + 1}/{epochs} Validation', file=sys.stdout)
+            for batch_num in test_pbar:
+                i = batch_num * batch_size
+                batch_test_id = test_indices[i:i + batch_size]
+                batch_test_input = dataset['test_input'][batch_test_id].to(model.device)
+                batch_test_label = dataset['test_label'][batch_test_id].to(model.device)
+                outputs = model.forward(batch_test_input)
+                loss = loss_fn(outputs, batch_test_label)
+                val_loss = (val_loss * batch_num + loss.detach()) / (batch_num + 1)
+                test_pbar.set_postfix(val_loss=val_loss.item())
+                if label is not None:
+                    diffs = torch.abs(outputs - label)
+                    closest_indices = torch.argmin(diffs, dim=1)
+                    closest_values = label[closest_indices]
+                    all_predictions.extend(closest_values.detach().cpu().numpy())
+                    all_labels.extend(batch_test_label.detach().cpu().numpy())
+
+        train_lr = lr_scheduler.get_last_lr()[0]
+        lr_scheduler.step(val_loss)
+
+        if label is not None:
+            m = calculate_metric(all_labels, all_predictions, class_dict)
+            print(m)
+            results["accuracies"].append(m["accuracy"])
+            results["precisions"].append(m["precision"])
+            results["recalls"].append(m["recall"])
+            results["f1-scores"].append(m["f1-score"])
+            results["lrs"].append(train_lr)
+            if best < m["f1-score"]:
+                best = m["f1-score"]
+                results['all_predictions'] = all_predictions
+                results['all_labels'] = all_labels
+
+            # if save_path is not None:
+            #     plot_confusion_matrix(all_labels, all_predictions, class_dict, save_path)
+            # if save_path is not None:
+            #     model.saveckpt(path=(os.path.join(save_path, 'save_model') + '/' + 'model'))
+            if txt_file is not None:
+                m.update({'lr': train_lr, 'epoch': epoch, 'train_loss': train_loss.item(), 'val_loss': val_loss.item()})
+                append_to_results_file(txt_file, m, column_order,
+                                       custom_column_widths=custom_column_widths)
+
+        results["train_losses"].append(train_loss.item())
+        results["val_losses"].append(val_loss.item())
+        results["regularize"].append(reg_.item())
+
+    return results
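Two things change in train_eval_utils.py: `train_val` now schedules grid updates in epochs rather than raw optimizer steps (`stop_grid_update_epoch` replaces the old step-based parameter, and `stop_grid_update_step = stop_grid_update_epoch * steps` is derived once `steps`, the number of batches per epoch, is known), and the previously commented-out `fit` function is activated. A worked example of the new schedule, using illustrative numbers and the signature defaults above:

import math

# Assumed: 10,000 training samples and batch_size=64.
steps = math.ceil(10_000 / 64)                      # 157 batches per epoch
stop_grid_update_step = 100 * steps                 # stop_grid_update_epoch=100 -> 15,700 steps
grid_update_freq = int(stop_grid_update_step / 10)  # grid_update_num=10 -> every 1,570 steps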
yms_kan/utils.py
CHANGED
@@ -75,9 +75,9 @@ SYMBOLIC_LIB = {'x': (lambda x: x, lambda x: x, 1, lambda x, y_th: ((), x)),
                '0': (lambda x: x * 0, lambda x: x * 0, 0, lambda x, y_th: ((), x * 0)),
                'gaussian': (lambda x: torch.exp(-x ** 2), lambda x: sympy.exp(-x ** 2), 3,
                             lambda x, y_th: ((), torch.exp(-x ** 2))),
-                #'cosh': (lambda x: torch.cosh(x), lambda x: sympy.cosh(x), 5),
-                #'sigmoid': (lambda x: torch.sigmoid(x), sympy.Function('sigmoid'), 4),
-                #'relu': (lambda x: torch.relu(x), relu),
+                # 'cosh': (lambda x: torch.cosh(x), lambda x: sympy.cosh(x), 5),
+                # 'sigmoid': (lambda x: torch.sigmoid(x), sympy.Function('sigmoid'), 4),
+                # 'relu': (lambda x: torch.relu(x), relu),
                }
 
 
@@ -465,6 +465,78 @@ def batch_hessian(model, x, create_graph=False):
    return torch.autograd.functional.jacobian(_jac_sum, x, create_graph=create_graph).permute(1, 0, 2)
 
 
+# def create_from_data(inputs, labels, ratios=[0.8, 0.2], device='cpu'):
+#     # Parameter validation
+#     if not (2 <= len(ratios) <= 3):
+#         raise ValueError("ratios must be a list of length 2 or 3")
+#     if not np.isclose(sum(ratios), 1.0):
+#         raise ValueError("the elements of ratios must sum to 1")
+#     for r in ratios:
+#         if r < 0 or r > 1:
+#             raise ValueError("each element of ratios must be between 0 and 1")
+#
+#     from collections import defaultdict
+#     class_indices = defaultdict(list)
+#     for idx, label in enumerate(labels):
+#         class_indices[label.item()].append(idx)
+#
+#     # Initialize index lists for each split
+#     split_indices = defaultdict(list)
+#
+#     # Stratified sampling
+#     for class_label, indices in class_indices.items():
+#         if not indices:
+#             continue
+#
+#         num_samples = len(indices)
+#         np.random.shuffle(indices)  # shuffle samples within the class first
+#
+#         # Compute the split points for each subset
+#         split_points = np.cumsum(ratios).tolist()[:-1]  # split points (excluding the final 1)
+#         splits = np.split(indices, [int(num_samples * p) for p in split_points])
+#
+#         # Assign each subset to the corresponding split
+#         for i, subset in enumerate(splits):
+#             if i == 0:
+#                 split_indices['train'].extend(subset)
+#             elif i == 1:
+#                 split_indices['test'].extend(subset)
+#             elif i == 2:
+#                 split_indices['val'].extend(subset)
+#
+#     # Merge the splits (validation set is optional)
+#     train_val = {}
+#     if 'test' in split_indices:
+#         # Merge train + val and shuffle
+#         train_val_idx = np.concatenate([
+#             np.array(split_indices['train']),
+#             np.array(split_indices['val'])
+#         ])
+#         np.random.shuffle(train_val_idx)
+#         train_val = {
+#             'train_input': inputs[train_val_idx[:len(split_indices['train'])]].detach().to(device),
+#             'train_label': labels[train_val_idx[:len(split_indices['train'])]].detach().to(device),
+#             'test_input': inputs[train_val_idx[len(split_indices['train']):]].detach().to(device),
+#             'test_label': labels[train_val_idx[len(split_indices['train']):]].detach().to(device)
+#         }
+#     else:
+#         # Training split only
+#         train_idx = np.array(split_indices['train'])
+#         np.random.shuffle(train_idx)
+#         train_val = {
+#             'train_input': inputs[train_idx].detach().to(device),
+#             'train_label': labels[train_idx].detach().to(device)
+#         }
+#
+#     # Handle the test split
+#     test_idx = np.array(split_indices.get('val', []))
+#     np.random.shuffle(test_idx)
+#     test_set = {
+#         'val_input': inputs[test_idx].detach().to(device) if test_idx.size else None,
+#         'val_label': labels[test_idx].detach().to(device) if test_idx.size else None
+#     }
+#
+#     return train_val, test_set
 def create_from_data(inputs, labels, train_ratio=0.8, device='cpu'):
    from collections import defaultdict
    class_indices = defaultdict(list)
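The utils.py addition is a commented-out multi-ratio draft of `create_from_data` (comments translated from Chinese above); the active implementation that follows keeps a single `train_ratio` split. For reference, here is a compact stratified two-way split along the lines of that draft. It is a sketch, not the package's API, and the helper name is hypothetical:

import numpy as np
import torch
from collections import defaultdict

def stratified_split(inputs, labels, train_ratio=0.8, device='cpu'):
    # Hypothetical helper mirroring the commented-out draft: group indices by class label.
    class_indices = defaultdict(list)
    for idx, label in enumerate(labels):
        class_indices[label.item()].append(idx)
    train_idx, test_idx = [], []
    for indices in class_indices.values():
        np.random.shuffle(indices)                 # shuffle within each class
        cut = int(len(indices) * train_ratio)
        train_idx.extend(indices[:cut])
        test_idx.extend(indices[cut:])
    return {'train_input': inputs[train_idx].detach().to(device),
            'train_label': labels[train_idx].detach().to(device),
            'test_input': inputs[test_idx].detach().to(device),
            'test_label': labels[test_idx].detach().to(device)}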
yms_kan/version.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.0.9"
+__version__ = "0.0.10"  # initial version
{yms_kan-0.0.9.dist-info → yms_kan-0.0.10.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
 yms_kan/KANLayer.py,sha256=-V2Fh5wvPYvfF1tmQVxJKWvvaAHiwo2EiFpd8VDgB1c,14149
 yms_kan/LBFGS.py,sha256=OPeRPDp40jaVH4qPoBDMEub7TPhyvw7pbqwQar3OZ1A,17620
-yms_kan/MLP.py,sha256=
-yms_kan/MultKAN.py,sha256=
+yms_kan/MLP.py,sha256=JFLogd1EPFVCrBJJtvNMNu68ejdUJ2O6qYc7l4pfFFI,12728
+yms_kan/MultKAN.py,sha256=eFh5jCGRPUzrpeWkvonGGzNWK4w0aODN8Q6M9wZ5IaY,122193
 yms_kan/Symbolic_KANLayer.py,sha256=WhJzC5IMIpXI_K7aYamOrWTK7uckxVdsM9N4oLZMO3I,9897
 yms_kan/__init__.py,sha256=O2c6DIG4PHavXF2v7K9jNqMbJXWr4-gTN3Vs1YSlc64,120
 yms_kan/compiler.py,sha256=7bVwDNX0xmLAjQ8V1FdmkIIIibmy_W5eaeSKBlYL0Vc,18632
@@ -11,13 +11,13 @@ yms_kan/hypothesis.py,sha256=Ec20xadfgOSSWeZHQaGn-h9F2PY7LWFU3iniNI2Zd_4,23165
 yms_kan/plotting.py,sha256=Moi6QTJQxHjutGMgxR9oSsqZSzYY3TP-7WNapdCIqzw,18097
 yms_kan/spline.py,sha256=ZXyGwl2Sc-UrnrcuUXeUQkBOMnetaWcHrbpZaqatCvs,4345
 yms_kan/tool.py,sha256=rkRpqF3EcsAq7a3k1F1zKlxfJ4U9n-FzHyNCJgN4URY,21159
-yms_kan/train_eval_utils.py,sha256=
-yms_kan/utils.py,sha256=
-yms_kan/version.py,sha256=
+yms_kan/train_eval_utils.py,sha256=3WPtCKLcrotU92s4S0uuIa1rXOAHxyAfDwFJGwUxvy0,16210
+yms_kan/utils.py,sha256=k1fZvv9P6vBBV7LMysoTL2j-bglkBWO0l31dNkWI_Jo,26763
+yms_kan/version.py,sha256=ts9Xi3n2P07g5eVEUSK46avv_nYOOTJ_EeHV2X6IfhM,40
 yms_kan/assets/img/mult_symbol.png,sha256=2f4xUKdweft-qUbHjFI5h9-smnEtc0FWq8hNYZhPAXY,6392
 yms_kan/assets/img/sum_symbol.png,sha256=94QkMUzmEjlCq_yf14nMEQmettaq86FmlGfdl22b4XE,6210
-yms_kan-0.0.9.dist-info/licenses/LICENSE,sha256=
-yms_kan-0.0.9.dist-info/METADATA,sha256=
-yms_kan-0.0.9.dist-info/WHEEL,sha256=
-yms_kan-0.0.9.dist-info/top_level.txt,sha256=
-yms_kan-0.0.9.dist-info/RECORD,,
+yms_kan-0.0.10.dist-info/licenses/LICENSE,sha256=BJXDWyF4Groqtnp4Gi9puH4aLg7A2IC3MpHmC-cSxwc,1067
+yms_kan-0.0.10.dist-info/METADATA,sha256=pjZJ0E5OIBv93IoXD10lXpQCPN9ixDIRDuNUwov6Fls,241
+yms_kan-0.0.10.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
+yms_kan-0.0.10.dist-info/top_level.txt,sha256=Z_JDh6yZf-EiW1eKgL6ADsN2yqEMRMspi-o29JZ1WPo,8
+yms_kan-0.0.10.dist-info/RECORD,,
{yms_kan-0.0.9.dist-info → yms_kan-0.0.10.dist-info}/licenses/LICENSE
File without changes
{yms_kan-0.0.9.dist-info → yms_kan-0.0.10.dist-info}/top_level.txt
File without changes