yms-kan 0.0.7__tar.gz → 0.0.10__tar.gz
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- yms_kan-0.0.10/MANIFEST.in +1 -0
- yms_kan-0.0.10/PKG-INFO +11 -0
- yms_kan-0.0.10/README.md +1 -0
- yms_kan-0.0.10/pyproject.toml +31 -0
- {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/KANLayer.py +3 -3
- {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/MLP.py +81 -71
- {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/MultKAN.py +138 -135
- yms_kan-0.0.10/yms_kan/__init__.py +4 -0
- yms_kan-0.0.10/yms_kan/assets/img/mult_symbol.png +0 -0
- yms_kan-0.0.10/yms_kan/assets/img/sum_symbol.png +0 -0
- {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/compiler.py +4 -4
- {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/hypothesis.py +10 -10
- yms_kan-0.0.10/yms_kan/plotting.py +468 -0
- {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/spline.py +1 -1
- yms_kan-0.0.10/yms_kan/tool.py +569 -0
- yms_kan-0.0.10/yms_kan/train_eval_utils.py +353 -0
- {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/utils.py +90 -18
- yms_kan-0.0.10/yms_kan/version.py +1 -0
- yms_kan-0.0.10/yms_kan.egg-info/PKG-INFO +11 -0
- yms_kan-0.0.10/yms_kan.egg-info/SOURCES.txt +26 -0
- yms_kan-0.0.10/yms_kan.egg-info/top_level.txt +1 -0
- yms_kan-0.0.7/PKG-INFO +0 -18
- yms_kan-0.0.7/README.md +0 -1
- yms_kan-0.0.7/kan/__init__.py +0 -3
- yms_kan-0.0.7/kan/dataset.py +0 -27
- yms_kan-0.0.7/setup.py +0 -96
- yms_kan-0.0.7/yms_kan.egg-info/PKG-INFO +0 -18
- yms_kan-0.0.7/yms_kan.egg-info/SOURCES.txt +0 -20
- yms_kan-0.0.7/yms_kan.egg-info/top_level.txt +0 -1
- {yms_kan-0.0.7 → yms_kan-0.0.10}/LICENSE +0 -0
- {yms_kan-0.0.7 → yms_kan-0.0.10}/setup.cfg +0 -0
- {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/LBFGS.py +0 -0
- {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/Symbolic_KANLayer.py +0 -0
- {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/experiment.py +0 -0
- {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/feynman.py +0 -0
- {yms_kan-0.0.7 → yms_kan-0.0.10}/yms_kan.egg-info/dependency_links.txt +0 -0
yms_kan-0.0.10/MANIFEST.in ADDED
@@ -0,0 +1 @@
+recursive-include yms_kan/assets/img *.png
yms_kan-0.0.10/PKG-INFO ADDED

yms_kan-0.0.10/README.md ADDED
@@ -0,0 +1 @@
+### Test
yms_kan-0.0.10/pyproject.toml ADDED
@@ -0,0 +1,31 @@
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "yms_kan"
+dynamic = ["version"]
+
+description = "My awesome package"
+authors = [{name = "yms", email = "11@qq.com"}]
+readme = "README.md"
+license = "MIT"  # use the plain string form directly
+
+# optional: specify the license file
+license-files = ["LICENSE"]
+
+[tool.setuptools.dynamic]
+# explicitly state where the version number comes from
+version = {attr = "yms_kan.version.__version__"}
+
+[tool.setuptools]
+# include non-code files
+include-package-data = true
+
+[tool.setuptools.package-data]
+# match rules for resource files inside the package
+yms_kan = [
+    "assets/img/*.png",  # include all PNG files
+    "assets/img/*.svg",  # can be extended to other formats
+    "assets/**/*"  # recursively include subdirectories
+]
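The `include-package-data` flag and the `package-data` globs above ship the PNG assets (e.g. `mult_symbol.png` and `sum_symbol.png` from the file list) inside the wheel, so they can be read back at runtime through the standard library. A minimal sketch, assuming the package is installed; the variable names are illustrative, not part of the package's API:

```python
from importlib.resources import files

# Read an image that pyproject.toml packages under yms_kan/assets/img.
asset = files("yms_kan").joinpath("assets/img/mult_symbol.png")
png_bytes = asset.read_bytes()
print(f"mult_symbol.png: {len(png_bytes)} bytes")
```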
{yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/KANLayer.py
@@ -84,7 +84,7 @@ class KANLayer(nn.Module):
 
         Example
         -------
-        >>> from
+        >>> from yms_kan.KANLayer import *
         >>> model = KANLayer(in_dim=3, out_dim=5)
         >>> (model.in_dim, model.out_dim)
         '''
@@ -144,7 +144,7 @@ class KANLayer(nn.Module):
 
         Example
         -------
-        >>> from
+        >>> from yms_kan.KANLayer import *
         >>> model = KANLayer(in_dim=3, out_dim=5)
         >>> x = torch.normal(0,1,size=(100,3))
         >>> y, preacts, postacts, postspline = model(x)
@@ -342,7 +342,7 @@ class KANLayer(nn.Module):
 
         Example
         -------
-        >>> from
+        >>> from yms_kan.KANLayer import *
         >>> model = KANLayer(in_dim=2, out_dim=2, num=5, k=3)
         >>> print(model.coef)
         >>> model.swap(0,1,mode='in')
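All three docstring fixes repair the same truncated import line. Stitched together, the corrected doctests give a runnable smoke test of the layer; this sketch only restates what the docstrings above already show:

```python
import torch
from yms_kan.KANLayer import *

# Forward pass through a 3-in, 5-out KAN layer, as in the fixed doctests.
model = KANLayer(in_dim=3, out_dim=5)
x = torch.normal(0, 1, size=(100, 3))
y, preacts, postacts, postspline = model(x)
print(model.in_dim, model.out_dim, y.shape)
```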
{yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/MLP.py
@@ -3,43 +3,44 @@ import torch.nn as nn
 import matplotlib.pyplot as plt
 import numpy as np
 from tqdm import tqdm
+
+from . import plot_tree
 from .LBFGS import LBFGS
 
 seed = 0
 torch.manual_seed(seed)
 
+
 class MLP(nn.Module):
-
+
     def __init__(self, width, act='silu', save_act=True, seed=0, device='cpu'):
         super(MLP, self).__init__()
-
+
         torch.manual_seed(seed)
-
+
         linears = []
         self.width = width
         self.depth = depth = len(width) - 1
         for i in range(depth):
-            linears.append(nn.Linear(width[i], width[i+1]))
+            linears.append(nn.Linear(width[i], width[i + 1]))
         self.linears = nn.ModuleList(linears)
-
+
         #if activation == 'silu':
         self.act_fun = torch.nn.SiLU()
         self.save_act = save_act
         self.acts = None
-
+
         self.cache_data = None
-
+
         self.device = device
         self.to(device)
-
-
+
     def to(self, device):
         super(MLP, self).to(device)
         self.device = device
-
+
         return self
-
-
+
     def get_act(self, x=None):
         if isinstance(x, dict):
             x = x['train_input']
@@ -52,23 +53,23 @@ class MLP(nn.Module):
         self.save_act = True
         self.forward(x)
         self.save_act = save_act
-
+
     @property
     def w(self):
         return [self.linears[l].weight for l in range(self.depth)]
-
+
     def forward(self, x):
-
+
         # cache data
         self.cache_data = x
-
+
         self.acts = []
         self.acts_scale = []
         self.wa_forward = []
         self.a_forward = []
-
+
         for i in range(self.depth):
-
+
             if self.save_act:
                 act = x.clone()
                 act_scale = torch.std(x, dim=0)
@@ -77,7 +78,7 @@ class MLP(nn.Module):
                 if i > 0:
                     self.acts_scale.append(act_scale)
                     self.wa_forward.append(wa_forward)
-
+
             x = self.linears[i](x)
             if i < self.depth - 1:
                 x = self.act_fun(x)
@@ -85,9 +86,9 @@ class MLP(nn.Module):
         if self.save_act:
             act_scale = torch.std(x, dim=0)
             self.acts_scale.append(act_scale)
-
+
         return x
-
+
     def attribute(self):
         if self.acts == None:
             self.get_act()
@@ -99,47 +100,47 @@ class MLP(nn.Module):
         node_score = torch.ones(self.width[-1]).requires_grad_(True).to(self.device)
         node_scores.append(node_score)
 
-        for l in range(self.depth,0,-1):
-
-            edge_score = torch.einsum('ij,i->ij', torch.abs(self.wa_forward[l-1]), node_score/(self.acts_scale[l-1]+1e-4))
+        for l in range(self.depth, 0, -1):
+            edge_score = torch.einsum('ij,i->ij', torch.abs(self.wa_forward[l - 1]),
+                                      node_score / (self.acts_scale[l - 1] + 1e-4))
             edge_scores.append(edge_score)
 
             # this might be improper for MLPs (although reasonable for KANs)
-            node_score = torch.sum(edge_score, dim=0)/torch.sqrt(torch.tensor(self.width[l-1], device=self.device))
+            node_score = torch.sum(edge_score, dim=0) / torch.sqrt(torch.tensor(self.width[l - 1], device=self.device))
             #print(self.width[l])
             node_scores.append(node_score)
 
         self.node_scores = list(reversed(node_scores))
         self.edge_scores = list(reversed(edge_scores))
         self.wa_backward = self.edge_scores
-
+
     def plot(self, beta=3, scale=1., metric='w'):
         # metric = 'w', 'act' or 'fa'
-
+
         if metric == 'fa':
             self.attribute()
-
+
         depth = self.depth
         y0 = 0.5
-        fig, ax = plt.subplots(figsize=(3*scale,3*y0*depth*scale))
+        fig, ax = plt.subplots(figsize=(3 * scale, 3 * y0 * depth * scale))
         shp = self.width
-
-        min_spacing = 1/max(self.width)
+
+        min_spacing = 1 / max(self.width)
         for j in range(len(shp)):
             N = shp[j]
             for i in range(N):
                 plt.scatter(1 / (2 * N) + i / N, j * y0, s=min_spacing ** 2 * 5000 * scale ** 2, color='black')
-
-        plt.ylim(-0.1*y0,y0*depth+0.1*y0)
-        plt.xlim(-0.02,1.02)
+
+        plt.ylim(-0.1 * y0, y0 * depth + 0.1 * y0)
+        plt.xlim(-0.02, 1.02)
 
         linears = self.linears
-
+
         for ii in range(len(linears)):
             linear = linears[ii]
             p = linear.weight
             p_shp = p.shape
-
+
             if metric == 'w':
                 pass
             elif metric == 'act':
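For readers unfamiliar with the einsum reformatted in `attribute()` above: `torch.einsum('ij,i->ij', A, v)` scales row `i` of `A` by `v[i]`, which is how the edge score combines the absolute weights with the incoming node score. A tiny equivalence sketch; the tensors are stand-ins, not package API:

```python
import torch

A = torch.abs(torch.randn(3, 4))  # stands in for torch.abs(self.wa_forward[l - 1])
v = torch.rand(3)                 # stands in for node_score / (self.acts_scale[l - 1] + 1e-4)

# Row-wise scaling: both forms produce the same edge-score matrix.
edge_einsum = torch.einsum('ij,i->ij', A, v)
edge_broadcast = A * v[:, None]
assert torch.allclose(edge_einsum, edge_broadcast)
```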
@@ -150,12 +151,15 @@ class MLP(nn.Module):
                 raise Exception('metric = \'{}\' not recognized. Choices are \'w\', \'act\', \'fa\'.'.format(metric))
             for i in range(p_shp[0]):
                 for j in range(p_shp[1]):
-                    plt.plot([1/(2*p_shp[0])+i/p_shp[0], 1/(2*p_shp[1])+j/p_shp[1]],
-                             [y0*(ii+1),y0*ii], lw=0.5*scale, alpha=np.tanh(beta*np.abs(p[i,j].cpu().detach().numpy())), color="blue" if p[i,j]>0 else "red")
+                    plt.plot([1 / (2 * p_shp[0]) + i / p_shp[0], 1 / (2 * p_shp[1]) + j / p_shp[1]],
+                             [y0 * (ii + 1), y0 * ii], lw=0.5 * scale,
+                             alpha=np.tanh(beta * np.abs(p[i, j].cpu().detach().numpy())),
+                             color="blue" if p[i, j] > 0 else "red")
+
         ax.axis('off')
-
+
     def reg(self, reg_metric, lamb_l1, lamb_entropy):
-
+
         if reg_metric == 'w':
             acts_scale = self.w
         if reg_metric == 'act':
@@ -164,7 +168,7 @@ class MLP(nn.Module):
             acts_scale = self.wa_backward
         if reg_metric == 'a':
             acts_scale = self.acts_scale
-
+
         if len(acts_scale[0].shape) == 2:
             reg_ = 0.
 
@@ -178,9 +182,9 @@ class MLP(nn.Module):
                 entropy_row = - torch.mean(torch.sum(p_row * torch.log2(p_row + 1e-4), dim=1))
                 entropy_col = - torch.mean(torch.sum(p_col * torch.log2(p_col + 1e-4), dim=0))
                 reg_ += lamb_l1 * l1 + lamb_entropy * (entropy_row + entropy_col)
-
+
         elif len(acts_scale[0].shape) == 1:
-
+
             reg_ = 0.
 
             for i in range(len(acts_scale)):
@@ -193,20 +197,21 @@ class MLP(nn.Module):
             reg_ += lamb_l1 * l1 + lamb_entropy * entropy
 
         return reg_
-
+
     def get_reg(self, reg_metric, lamb_l1, lamb_entropy):
         return self.reg(reg_metric, lamb_l1, lamb_entropy)
-
-    def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., loss_fn=None, lr=1.,
-            batch=-1, metrics=None, device='cpu', reg_metric='w', display_metrics=None):
+
+    def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., loss_fn=None, lr=1.,
+            batch=-1,
+            metrics=None, device='cpu', reg_metric='w', display_metrics=None):
 
         if lamb > 0. and not self.save_act:
             print('setting lamb=0. If you want to set lamb > 0, set =True')
-
+
         old_save_act = self.save_act
         if lamb == 0.:
             self.save_act = False
-
+
         pbar = tqdm(range(steps), desc='description', ncols=100)
 
         if loss_fn == None:
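The reflowed `fit` signature above takes the dataset as a dict of tensors; the keys it reads (`train_input`, `train_label`, `test_input`, `test_label`) appear in the hunks below. A minimal usage sketch with synthetic data; the toy problem, hyperparameters, and the assumption that `results` carries a `test_loss` series are illustrative, not confirmed by this diff:

```python
import torch
from yms_kan.MLP import MLP

# Toy regression data, keyed the way MLP.fit() indexes the dataset dict.
x = torch.randn(1000, 2)
y = (x ** 2).sum(dim=1, keepdim=True)
dataset = {'train_input': x[:800], 'train_label': y[:800],
           'test_input': x[800:], 'test_label': y[800:]}

model = MLP(width=[2, 8, 1], device='cpu')
results = model.fit(dataset, opt='Adam', steps=50, lr=1e-2)
print(results['train_loss'][-1], results['test_loss'][-1])
```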
@@ -217,7 +222,8 @@ class MLP(nn.Module):
         if opt == "Adam":
             optimizer = torch.optim.Adam(self.parameters(), lr=lr)
         elif opt == "LBFGS":
-            optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
+            optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
+                              tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
 
         results = {}
         results['train_loss'] = []
@@ -252,10 +258,10 @@ class MLP(nn.Module):
             return objective
 
         for _ in pbar:
-
-            if _ == steps-1 and old_save_act:
+
+            if _ == steps - 1 and old_save_act:
                 self.save_act = True
-
+
             train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
             test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
 
@@ -274,9 +280,9 @@ class MLP(nn.Module):
                 loss.backward()
                 optimizer.step()
 
-            test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(self.device)),
-                                     dataset['test_label'][test_id].to(self.device))
-
+            test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(self.device)),
+                                     dataset['test_label'][test_id].to(self.device))
+
             if metrics != None:
                 for i in range(len(metrics)):
                     results[metrics[i].__name__].append(metrics[i]().item())
@@ -287,7 +293,9 @@ class MLP(nn.Module):
 
             if _ % log == 0:
                 if display_metrics == None:
-                    pbar.set_description("| train_loss: %.2e | test_loss: %.2e | reg: %.2e | " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy()))
+                    pbar.set_description("| train_loss: %.2e | test_loss: %.2e | reg: %.2e | " % (
+                        torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(),
+                        reg_.cpu().detach().numpy()))
                 else:
                     string = ''
                     data = ()
@@ -299,9 +307,9 @@ class MLP(nn.Module):
                             raise Exception(f'{metric} not recognized')
                         data += (results[metric][-1],)
                     pbar.set_description(string % data)
-
+
         return results
-
+
     @property
     def connection_cost(self):
 
@@ -309,8 +317,9 @@ class MLP(nn.Module):
         cc = 0.
         for linear in self.linears:
             t = torch.abs(linear.weight)
+
             def get_coordinate(n):
-                return torch.linspace(0,1,steps=n+1, device=self.device)[:n] + 1/(2*n)
+                return torch.linspace(0, 1, steps=n + 1, device=self.device)[:n] + 1 / (2 * n)
 
             in_dim = t.shape[0]
             x_in = get_coordinate(in_dim)
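`connection_cost` places each layer's neurons at evenly spaced coordinates on [0, 1] (the `get_coordinate` helper above) and weights every absolute weight by the distance between its endpoints (the `dist` line in the next hunk). A small standalone sketch of that geometry; the sizes are illustrative:

```python
import torch

def get_coordinate(n):
    # Midpoints of n equal bins on [0, 1], matching the helper above.
    return torch.linspace(0, 1, steps=n + 1)[:n] + 1 / (2 * n)

t = torch.abs(torch.randn(4, 3))  # |weight| of one nn.Linear layer
x_in = get_coordinate(t.shape[0])
x_out = get_coordinate(t.shape[1])
dist = torch.abs(x_in[:, None] - x_out[None, :])  # pairwise endpoint distances
cc = torch.sum(dist * t)  # this layer's contribution to the cost
print(cc)
```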
@@ -318,44 +327,45 @@ class MLP(nn.Module):
             out_dim = t.shape[1]
             x_out = get_coordinate(out_dim)
 
-            dist = torch.abs(x_in[:,None] - x_out[None,:])
+            dist = torch.abs(x_in[:, None] - x_out[None, :])
             cc += torch.sum(dist * t)
 
         return cc
-
+
     def swap(self, l, i1, i2):
 
         def swap_row(data, i1, i2):
             data[i1], data[i2] = data[i2].clone(), data[i1].clone()
 
         def swap_col(data, i1, i2):
-            data[:,i1], data[:,i2] = data[:,i2].clone(), data[:,i1].clone()
+            data[:, i1], data[:, i2] = data[:, i2].clone(), data[:, i1].clone()
 
-        swap_row(self.linears[l-1].weight.data, i1, i2)
-        swap_row(self.linears[l-1].bias.data, i1, i2)
+        swap_row(self.linears[l - 1].weight.data, i1, i2)
+        swap_row(self.linears[l - 1].bias.data, i1, i2)
         swap_col(self.linears[l].weight.data, i1, i2)
-
+
     def auto_swap_l(self, l):
 
         num = self.width[l]
        for i in range(num):
            ccs = []
            for j in range(num):
-                self.swap(l,i,j)
+                self.swap(l, i, j)
                self.get_act()
                self.attribute()
                cc = self.connection_cost.detach().clone()
                ccs.append(cc)
-                self.swap(l,i,j)
+                self.swap(l, i, j)
            j = torch.argmin(torch.tensor(ccs))
-            self.swap(l,i,j)
+            self.swap(l, i, j)
 
     def auto_swap(self):
         depth = self.depth
         for l in range(1, depth):
             self.auto_swap_l(l)
-
+
     def tree(self, x=None, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False):
         if x == None:
             x = self.cache_data
-        plot_tree(self, x, in_var=in_var, style=style, sym_th=sym_th, sep_th=sep_th, skip_sep_test=skip_sep_test, verbose=verbose)
+        plot_tree(self, x, in_var=in_var, style=style, sym_th=sym_th, sep_th=sep_th, skip_sep_test=skip_sep_test,
+                  verbose=verbose)