yms-kan 0.0.7__tar.gz → 0.0.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. yms_kan-0.0.10/MANIFEST.in +1 -0
  2. yms_kan-0.0.10/PKG-INFO +11 -0
  3. yms_kan-0.0.10/README.md +1 -0
  4. yms_kan-0.0.10/pyproject.toml +31 -0
  5. {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/KANLayer.py +3 -3
  6. {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/MLP.py +81 -71
  7. {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/MultKAN.py +138 -135
  8. yms_kan-0.0.10/yms_kan/__init__.py +4 -0
  9. yms_kan-0.0.10/yms_kan/assets/img/mult_symbol.png +0 -0
  10. yms_kan-0.0.10/yms_kan/assets/img/sum_symbol.png +0 -0
  11. {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/compiler.py +4 -4
  12. {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/hypothesis.py +10 -10
  13. yms_kan-0.0.10/yms_kan/plotting.py +468 -0
  14. {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/spline.py +1 -1
  15. yms_kan-0.0.10/yms_kan/tool.py +569 -0
  16. yms_kan-0.0.10/yms_kan/train_eval_utils.py +353 -0
  17. {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/utils.py +90 -18
  18. yms_kan-0.0.10/yms_kan/version.py +1 -0
  19. yms_kan-0.0.10/yms_kan.egg-info/PKG-INFO +11 -0
  20. yms_kan-0.0.10/yms_kan.egg-info/SOURCES.txt +26 -0
  21. yms_kan-0.0.10/yms_kan.egg-info/top_level.txt +1 -0
  22. yms_kan-0.0.7/PKG-INFO +0 -18
  23. yms_kan-0.0.7/README.md +0 -1
  24. yms_kan-0.0.7/kan/__init__.py +0 -3
  25. yms_kan-0.0.7/kan/dataset.py +0 -27
  26. yms_kan-0.0.7/setup.py +0 -96
  27. yms_kan-0.0.7/yms_kan.egg-info/PKG-INFO +0 -18
  28. yms_kan-0.0.7/yms_kan.egg-info/SOURCES.txt +0 -20
  29. yms_kan-0.0.7/yms_kan.egg-info/top_level.txt +0 -1
  30. {yms_kan-0.0.7 → yms_kan-0.0.10}/LICENSE +0 -0
  31. {yms_kan-0.0.7 → yms_kan-0.0.10}/setup.cfg +0 -0
  32. {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/LBFGS.py +0 -0
  33. {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/Symbolic_KANLayer.py +0 -0
  34. {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/experiment.py +0 -0
  35. {yms_kan-0.0.7/kan → yms_kan-0.0.10/yms_kan}/feynman.py +0 -0
  36. {yms_kan-0.0.7 → yms_kan-0.0.10}/yms_kan.egg-info/dependency_links.txt +0 -0
@@ -0,0 +1 @@
1
+ recursive-include yms_kan/assets/img *.png
@@ -0,0 +1,11 @@
1
+ Metadata-Version: 2.4
2
+ Name: yms_kan
3
+ Version: 0.0.10
4
+ Summary: My awesome package
5
+ Author-email: yms <11@qq.com>
6
+ License-Expression: MIT
7
+ Description-Content-Type: text/markdown
8
+ License-File: LICENSE
9
+ Dynamic: license-file
10
+
11
+ ### 测试
@@ -0,0 +1 @@
1
+ ### 测试
@@ -0,0 +1,31 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "yms_kan"
7
+ dynamic = ["version"]
8
+
9
+ description = "My awesome package"
10
+ authors = [{name = "yms", email = "11@qq.com"}]
11
+ readme = "README.md"
12
+ license = "MIT" # 直接使用字符串格式
13
+
14
+ # 可选:指定许可证文件
15
+ license-files = ["LICENSE"]
16
+
17
+ [tool.setuptools.dynamic]
18
+ # 明确指定版本号来源
19
+ version = {attr = "yms_kan.version.__version__"}
20
+
21
+ [tool.setuptools]
22
+ # 包含非代码文件
23
+ include-package-data = true
24
+
25
+ [tool.setuptools.package-data]
26
+ # 指定包内资源文件的匹配规则
27
+ yms_kan = [
28
+ "assets/img/*.png", # 包含所有png文件
29
+ "assets/img/*.svg", # 可扩展其他格式
30
+ "assets/**/*" # 递归包含子目录
31
+ ]
@@ -84,7 +84,7 @@ class KANLayer(nn.Module):
84
84
 
85
85
  Example
86
86
  -------
87
- >>> from kan.KANLayer import *
87
+ >>> from yms_kan.KANLayer import *
88
88
  >>> model = KANLayer(in_dim=3, out_dim=5)
89
89
  >>> (model.in_dim, model.out_dim)
90
90
  '''
@@ -144,7 +144,7 @@ class KANLayer(nn.Module):
144
144
 
145
145
  Example
146
146
  -------
147
- >>> from kan.KANLayer import *
147
+ >>> from yms_kan.KANLayer import *
148
148
  >>> model = KANLayer(in_dim=3, out_dim=5)
149
149
  >>> x = torch.normal(0,1,size=(100,3))
150
150
  >>> y, preacts, postacts, postspline = model(x)
@@ -342,7 +342,7 @@ class KANLayer(nn.Module):
342
342
 
343
343
  Example
344
344
  -------
345
- >>> from kan.KANLayer import *
345
+ >>> from yms_kan.KANLayer import *
346
346
  >>> model = KANLayer(in_dim=2, out_dim=2, num=5, k=3)
347
347
  >>> print(model.coef)
348
348
  >>> model.swap(0,1,mode='in')
@@ -3,43 +3,44 @@ import torch.nn as nn
3
3
  import matplotlib.pyplot as plt
4
4
  import numpy as np
5
5
  from tqdm import tqdm
6
+
7
+ from . import plot_tree
6
8
  from .LBFGS import LBFGS
7
9
 
8
10
  seed = 0
9
11
  torch.manual_seed(seed)
10
12
 
13
+
11
14
  class MLP(nn.Module):
12
-
15
+
13
16
  def __init__(self, width, act='silu', save_act=True, seed=0, device='cpu'):
14
17
  super(MLP, self).__init__()
15
-
18
+
16
19
  torch.manual_seed(seed)
17
-
20
+
18
21
  linears = []
19
22
  self.width = width
20
23
  self.depth = depth = len(width) - 1
21
24
  for i in range(depth):
22
- linears.append(nn.Linear(width[i], width[i+1]))
25
+ linears.append(nn.Linear(width[i], width[i + 1]))
23
26
  self.linears = nn.ModuleList(linears)
24
-
27
+
25
28
  #if activation == 'silu':
26
29
  self.act_fun = torch.nn.SiLU()
27
30
  self.save_act = save_act
28
31
  self.acts = None
29
-
32
+
30
33
  self.cache_data = None
31
-
34
+
32
35
  self.device = device
33
36
  self.to(device)
34
-
35
-
37
+
36
38
  def to(self, device):
37
39
  super(MLP, self).to(device)
38
40
  self.device = device
39
-
41
+
40
42
  return self
41
-
42
-
43
+
43
44
  def get_act(self, x=None):
44
45
  if isinstance(x, dict):
45
46
  x = x['train_input']
@@ -52,23 +53,23 @@ class MLP(nn.Module):
52
53
  self.save_act = True
53
54
  self.forward(x)
54
55
  self.save_act = save_act
55
-
56
+
56
57
  @property
57
58
  def w(self):
58
59
  return [self.linears[l].weight for l in range(self.depth)]
59
-
60
+
60
61
  def forward(self, x):
61
-
62
+
62
63
  # cache data
63
64
  self.cache_data = x
64
-
65
+
65
66
  self.acts = []
66
67
  self.acts_scale = []
67
68
  self.wa_forward = []
68
69
  self.a_forward = []
69
-
70
+
70
71
  for i in range(self.depth):
71
-
72
+
72
73
  if self.save_act:
73
74
  act = x.clone()
74
75
  act_scale = torch.std(x, dim=0)
@@ -77,7 +78,7 @@ class MLP(nn.Module):
77
78
  if i > 0:
78
79
  self.acts_scale.append(act_scale)
79
80
  self.wa_forward.append(wa_forward)
80
-
81
+
81
82
  x = self.linears[i](x)
82
83
  if i < self.depth - 1:
83
84
  x = self.act_fun(x)
@@ -85,9 +86,9 @@ class MLP(nn.Module):
85
86
  if self.save_act:
86
87
  act_scale = torch.std(x, dim=0)
87
88
  self.acts_scale.append(act_scale)
88
-
89
+
89
90
  return x
90
-
91
+
91
92
  def attribute(self):
92
93
  if self.acts == None:
93
94
  self.get_act()
@@ -99,47 +100,47 @@ class MLP(nn.Module):
99
100
  node_score = torch.ones(self.width[-1]).requires_grad_(True).to(self.device)
100
101
  node_scores.append(node_score)
101
102
 
102
- for l in range(self.depth,0,-1):
103
-
104
- edge_score = torch.einsum('ij,i->ij', torch.abs(self.wa_forward[l-1]), node_score/(self.acts_scale[l-1]+1e-4))
103
+ for l in range(self.depth, 0, -1):
104
+ edge_score = torch.einsum('ij,i->ij', torch.abs(self.wa_forward[l - 1]),
105
+ node_score / (self.acts_scale[l - 1] + 1e-4))
105
106
  edge_scores.append(edge_score)
106
107
 
107
108
  # this might be improper for MLPs (although reasonable for KANs)
108
- node_score = torch.sum(edge_score, dim=0)/torch.sqrt(torch.tensor(self.width[l-1], device=self.device))
109
+ node_score = torch.sum(edge_score, dim=0) / torch.sqrt(torch.tensor(self.width[l - 1], device=self.device))
109
110
  #print(self.width[l])
110
111
  node_scores.append(node_score)
111
112
 
112
113
  self.node_scores = list(reversed(node_scores))
113
114
  self.edge_scores = list(reversed(edge_scores))
114
115
  self.wa_backward = self.edge_scores
115
-
116
+
116
117
  def plot(self, beta=3, scale=1., metric='w'):
117
118
  # metric = 'w', 'act' or 'fa'
118
-
119
+
119
120
  if metric == 'fa':
120
121
  self.attribute()
121
-
122
+
122
123
  depth = self.depth
123
124
  y0 = 0.5
124
- fig, ax = plt.subplots(figsize=(3*scale,3*y0*depth*scale))
125
+ fig, ax = plt.subplots(figsize=(3 * scale, 3 * y0 * depth * scale))
125
126
  shp = self.width
126
-
127
- min_spacing = 1/max(self.width)
127
+
128
+ min_spacing = 1 / max(self.width)
128
129
  for j in range(len(shp)):
129
130
  N = shp[j]
130
131
  for i in range(N):
131
132
  plt.scatter(1 / (2 * N) + i / N, j * y0, s=min_spacing ** 2 * 5000 * scale ** 2, color='black')
132
-
133
- plt.ylim(-0.1*y0,y0*depth+0.1*y0)
134
- plt.xlim(-0.02,1.02)
133
+
134
+ plt.ylim(-0.1 * y0, y0 * depth + 0.1 * y0)
135
+ plt.xlim(-0.02, 1.02)
135
136
 
136
137
  linears = self.linears
137
-
138
+
138
139
  for ii in range(len(linears)):
139
140
  linear = linears[ii]
140
141
  p = linear.weight
141
142
  p_shp = p.shape
142
-
143
+
143
144
  if metric == 'w':
144
145
  pass
145
146
  elif metric == 'act':
@@ -150,12 +151,15 @@ class MLP(nn.Module):
150
151
  raise Exception('metric = \'{}\' not recognized. Choices are \'w\', \'act\', \'fa\'.'.format(metric))
151
152
  for i in range(p_shp[0]):
152
153
  for j in range(p_shp[1]):
153
- plt.plot([1/(2*p_shp[0])+i/p_shp[0], 1/(2*p_shp[1])+j/p_shp[1]], [y0*(ii+1),y0*ii], lw=0.5*scale, alpha=np.tanh(beta*np.abs(p[i,j].cpu().detach().numpy())), color="blue" if p[i,j]>0 else "red")
154
-
154
+ plt.plot([1 / (2 * p_shp[0]) + i / p_shp[0], 1 / (2 * p_shp[1]) + j / p_shp[1]],
155
+ [y0 * (ii + 1), y0 * ii], lw=0.5 * scale,
156
+ alpha=np.tanh(beta * np.abs(p[i, j].cpu().detach().numpy())),
157
+ color="blue" if p[i, j] > 0 else "red")
158
+
155
159
  ax.axis('off')
156
-
160
+
157
161
  def reg(self, reg_metric, lamb_l1, lamb_entropy):
158
-
162
+
159
163
  if reg_metric == 'w':
160
164
  acts_scale = self.w
161
165
  if reg_metric == 'act':
@@ -164,7 +168,7 @@ class MLP(nn.Module):
164
168
  acts_scale = self.wa_backward
165
169
  if reg_metric == 'a':
166
170
  acts_scale = self.acts_scale
167
-
171
+
168
172
  if len(acts_scale[0].shape) == 2:
169
173
  reg_ = 0.
170
174
 
@@ -178,9 +182,9 @@ class MLP(nn.Module):
178
182
  entropy_row = - torch.mean(torch.sum(p_row * torch.log2(p_row + 1e-4), dim=1))
179
183
  entropy_col = - torch.mean(torch.sum(p_col * torch.log2(p_col + 1e-4), dim=0))
180
184
  reg_ += lamb_l1 * l1 + lamb_entropy * (entropy_row + entropy_col)
181
-
185
+
182
186
  elif len(acts_scale[0].shape) == 1:
183
-
187
+
184
188
  reg_ = 0.
185
189
 
186
190
  for i in range(len(acts_scale)):
@@ -193,20 +197,21 @@ class MLP(nn.Module):
193
197
  reg_ += lamb_l1 * l1 + lamb_entropy * entropy
194
198
 
195
199
  return reg_
196
-
200
+
197
201
  def get_reg(self, reg_metric, lamb_l1, lamb_entropy):
198
202
  return self.reg(reg_metric, lamb_l1, lamb_entropy)
199
-
200
- def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., loss_fn=None, lr=1., batch=-1,
201
- metrics=None, in_vars=None, out_vars=None, beta=3, device='cpu', reg_metric='w', display_metrics=None):
203
+
204
+ def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., loss_fn=None, lr=1.,
205
+ batch=-1,
206
+ metrics=None, device='cpu', reg_metric='w', display_metrics=None):
202
207
 
203
208
  if lamb > 0. and not self.save_act:
204
209
  print('setting lamb=0. If you want to set lamb > 0, set =True')
205
-
210
+
206
211
  old_save_act = self.save_act
207
212
  if lamb == 0.:
208
213
  self.save_act = False
209
-
214
+
210
215
  pbar = tqdm(range(steps), desc='description', ncols=100)
211
216
 
212
217
  if loss_fn == None:
@@ -217,7 +222,8 @@ class MLP(nn.Module):
217
222
  if opt == "Adam":
218
223
  optimizer = torch.optim.Adam(self.parameters(), lr=lr)
219
224
  elif opt == "LBFGS":
220
- optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
225
+ optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
226
+ tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
221
227
 
222
228
  results = {}
223
229
  results['train_loss'] = []
@@ -252,10 +258,10 @@ class MLP(nn.Module):
252
258
  return objective
253
259
 
254
260
  for _ in pbar:
255
-
256
- if _ == steps-1 and old_save_act:
261
+
262
+ if _ == steps - 1 and old_save_act:
257
263
  self.save_act = True
258
-
264
+
259
265
  train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
260
266
  test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
261
267
 
@@ -274,9 +280,9 @@ class MLP(nn.Module):
274
280
  loss.backward()
275
281
  optimizer.step()
276
282
 
277
- test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(self.device)), dataset['test_label'][test_id].to(self.device))
278
-
279
-
283
+ test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(self.device)),
284
+ dataset['test_label'][test_id].to(self.device))
285
+
280
286
  if metrics != None:
281
287
  for i in range(len(metrics)):
282
288
  results[metrics[i].__name__].append(metrics[i]().item())
@@ -287,7 +293,9 @@ class MLP(nn.Module):
287
293
 
288
294
  if _ % log == 0:
289
295
  if display_metrics == None:
290
- pbar.set_description("| train_loss: %.2e | test_loss: %.2e | reg: %.2e | " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy()))
296
+ pbar.set_description("| train_loss: %.2e | test_loss: %.2e | reg: %.2e | " % (
297
+ torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(),
298
+ reg_.cpu().detach().numpy()))
291
299
  else:
292
300
  string = ''
293
301
  data = ()
@@ -299,9 +307,9 @@ class MLP(nn.Module):
299
307
  raise Exception(f'{metric} not recognized')
300
308
  data += (results[metric][-1],)
301
309
  pbar.set_description(string % data)
302
-
310
+
303
311
  return results
304
-
312
+
305
313
  @property
306
314
  def connection_cost(self):
307
315
 
@@ -309,8 +317,9 @@ class MLP(nn.Module):
309
317
  cc = 0.
310
318
  for linear in self.linears:
311
319
  t = torch.abs(linear.weight)
320
+
312
321
  def get_coordinate(n):
313
- return torch.linspace(0,1,steps=n+1, device=self.device)[:n] + 1/(2*n)
322
+ return torch.linspace(0, 1, steps=n + 1, device=self.device)[:n] + 1 / (2 * n)
314
323
 
315
324
  in_dim = t.shape[0]
316
325
  x_in = get_coordinate(in_dim)
@@ -318,44 +327,45 @@ class MLP(nn.Module):
318
327
  out_dim = t.shape[1]
319
328
  x_out = get_coordinate(out_dim)
320
329
 
321
- dist = torch.abs(x_in[:,None] - x_out[None,:])
330
+ dist = torch.abs(x_in[:, None] - x_out[None, :])
322
331
  cc += torch.sum(dist * t)
323
332
 
324
333
  return cc
325
-
334
+
326
335
  def swap(self, l, i1, i2):
327
336
 
328
337
  def swap_row(data, i1, i2):
329
338
  data[i1], data[i2] = data[i2].clone(), data[i1].clone()
330
339
 
331
340
  def swap_col(data, i1, i2):
332
- data[:,i1], data[:,i2] = data[:,i2].clone(), data[:,i1].clone()
341
+ data[:, i1], data[:, i2] = data[:, i2].clone(), data[:, i1].clone()
333
342
 
334
- swap_row(self.linears[l-1].weight.data, i1, i2)
335
- swap_row(self.linears[l-1].bias.data, i1, i2)
343
+ swap_row(self.linears[l - 1].weight.data, i1, i2)
344
+ swap_row(self.linears[l - 1].bias.data, i1, i2)
336
345
  swap_col(self.linears[l].weight.data, i1, i2)
337
-
346
+
338
347
  def auto_swap_l(self, l):
339
348
 
340
349
  num = self.width[l]
341
350
  for i in range(num):
342
351
  ccs = []
343
352
  for j in range(num):
344
- self.swap(l,i,j)
353
+ self.swap(l, i, j)
345
354
  self.get_act()
346
355
  self.attribute()
347
356
  cc = self.connection_cost.detach().clone()
348
357
  ccs.append(cc)
349
- self.swap(l,i,j)
358
+ self.swap(l, i, j)
350
359
  j = torch.argmin(torch.tensor(ccs))
351
- self.swap(l,i,j)
360
+ self.swap(l, i, j)
352
361
 
353
362
  def auto_swap(self):
354
363
  depth = self.depth
355
364
  for l in range(1, depth):
356
365
  self.auto_swap_l(l)
357
-
366
+
358
367
  def tree(self, x=None, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False):
359
368
  if x == None:
360
369
  x = self.cache_data
361
- plot_tree(self, x, in_var=in_var, style=style, sym_th=sym_th, sep_th=sep_th, skip_sep_test=skip_sep_test, verbose=verbose)
370
+ plot_tree(self, x, in_var=in_var, style=style, sym_th=sym_th, sep_th=sep_th, skip_sep_test=skip_sep_test,
371
+ verbose=verbose)