yms-kan 0.0.9__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
yms_kan/MLP.py CHANGED
@@ -3,43 +3,44 @@ import torch.nn as nn
 import matplotlib.pyplot as plt
 import numpy as np
 from tqdm import tqdm
+
+from . import plot_tree
 from .LBFGS import LBFGS
 
 seed = 0
 torch.manual_seed(seed)
 
+
 class MLP(nn.Module):
-
+
     def __init__(self, width, act='silu', save_act=True, seed=0, device='cpu'):
         super(MLP, self).__init__()
-
+
         torch.manual_seed(seed)
-
+
         linears = []
         self.width = width
         self.depth = depth = len(width) - 1
         for i in range(depth):
-            linears.append(nn.Linear(width[i], width[i+1]))
+            linears.append(nn.Linear(width[i], width[i + 1]))
         self.linears = nn.ModuleList(linears)
-
+
         #if activation == 'silu':
         self.act_fun = torch.nn.SiLU()
         self.save_act = save_act
         self.acts = None
-
+
         self.cache_data = None
-
+
         self.device = device
         self.to(device)
-
-
+
     def to(self, device):
         super(MLP, self).to(device)
         self.device = device
-
+
         return self
-
-
+
     def get_act(self, x=None):
         if isinstance(x, dict):
             x = x['train_input']
@@ -52,23 +53,23 @@ class MLP(nn.Module):
         self.save_act = True
         self.forward(x)
         self.save_act = save_act
-
+
     @property
     def w(self):
         return [self.linears[l].weight for l in range(self.depth)]
-
+
     def forward(self, x):
-
+
         # cache data
         self.cache_data = x
-
+
         self.acts = []
         self.acts_scale = []
         self.wa_forward = []
         self.a_forward = []
-
+
         for i in range(self.depth):
-
+
             if self.save_act:
                 act = x.clone()
                 act_scale = torch.std(x, dim=0)
@@ -77,7 +78,7 @@ class MLP(nn.Module):
                 if i > 0:
                     self.acts_scale.append(act_scale)
                 self.wa_forward.append(wa_forward)
-
+
             x = self.linears[i](x)
             if i < self.depth - 1:
                 x = self.act_fun(x)
@@ -85,9 +86,9 @@ class MLP(nn.Module):
                 if self.save_act:
                     act_scale = torch.std(x, dim=0)
                     self.acts_scale.append(act_scale)
-
+
         return x
-
+
     def attribute(self):
         if self.acts == None:
             self.get_act()
@@ -99,47 +100,47 @@ class MLP(nn.Module):
         node_score = torch.ones(self.width[-1]).requires_grad_(True).to(self.device)
         node_scores.append(node_score)
 
-        for l in range(self.depth,0,-1):
-
-            edge_score = torch.einsum('ij,i->ij', torch.abs(self.wa_forward[l-1]), node_score/(self.acts_scale[l-1]+1e-4))
+        for l in range(self.depth, 0, -1):
+            edge_score = torch.einsum('ij,i->ij', torch.abs(self.wa_forward[l - 1]),
+                                      node_score / (self.acts_scale[l - 1] + 1e-4))
             edge_scores.append(edge_score)
 
             # this might be improper for MLPs (although reasonable for KANs)
-            node_score = torch.sum(edge_score, dim=0)/torch.sqrt(torch.tensor(self.width[l-1], device=self.device))
+            node_score = torch.sum(edge_score, dim=0) / torch.sqrt(torch.tensor(self.width[l - 1], device=self.device))
            #print(self.width[l])
             node_scores.append(node_score)
 
         self.node_scores = list(reversed(node_scores))
         self.edge_scores = list(reversed(edge_scores))
         self.wa_backward = self.edge_scores
-
+
     def plot(self, beta=3, scale=1., metric='w'):
         # metric = 'w', 'act' or 'fa'
-
+
         if metric == 'fa':
             self.attribute()
-
+
         depth = self.depth
         y0 = 0.5
-        fig, ax = plt.subplots(figsize=(3*scale,3*y0*depth*scale))
+        fig, ax = plt.subplots(figsize=(3 * scale, 3 * y0 * depth * scale))
         shp = self.width
-
-        min_spacing = 1/max(self.width)
+
+        min_spacing = 1 / max(self.width)
         for j in range(len(shp)):
             N = shp[j]
             for i in range(N):
                 plt.scatter(1 / (2 * N) + i / N, j * y0, s=min_spacing ** 2 * 5000 * scale ** 2, color='black')
-
-        plt.ylim(-0.1*y0,y0*depth+0.1*y0)
-        plt.xlim(-0.02,1.02)
+
+        plt.ylim(-0.1 * y0, y0 * depth + 0.1 * y0)
+        plt.xlim(-0.02, 1.02)
 
         linears = self.linears
-
+
         for ii in range(len(linears)):
             linear = linears[ii]
             p = linear.weight
             p_shp = p.shape
-
+
             if metric == 'w':
                 pass
             elif metric == 'act':
@@ -150,12 +151,15 @@ class MLP(nn.Module):
                 raise Exception('metric = \'{}\' not recognized. Choices are \'w\', \'act\', \'fa\'.'.format(metric))
             for i in range(p_shp[0]):
                 for j in range(p_shp[1]):
-                    plt.plot([1/(2*p_shp[0])+i/p_shp[0], 1/(2*p_shp[1])+j/p_shp[1]], [y0*(ii+1),y0*ii], lw=0.5*scale, alpha=np.tanh(beta*np.abs(p[i,j].cpu().detach().numpy())), color="blue" if p[i,j]>0 else "red")
-
+                    plt.plot([1 / (2 * p_shp[0]) + i / p_shp[0], 1 / (2 * p_shp[1]) + j / p_shp[1]],
+                             [y0 * (ii + 1), y0 * ii], lw=0.5 * scale,
+                             alpha=np.tanh(beta * np.abs(p[i, j].cpu().detach().numpy())),
+                             color="blue" if p[i, j] > 0 else "red")
+
         ax.axis('off')
-
+
     def reg(self, reg_metric, lamb_l1, lamb_entropy):
-
+
         if reg_metric == 'w':
             acts_scale = self.w
         if reg_metric == 'act':
@@ -164,7 +168,7 @@ class MLP(nn.Module):
             acts_scale = self.wa_backward
         if reg_metric == 'a':
             acts_scale = self.acts_scale
-
+
         if len(acts_scale[0].shape) == 2:
             reg_ = 0.
 
@@ -178,9 +182,9 @@ class MLP(nn.Module):
             entropy_row = - torch.mean(torch.sum(p_row * torch.log2(p_row + 1e-4), dim=1))
             entropy_col = - torch.mean(torch.sum(p_col * torch.log2(p_col + 1e-4), dim=0))
             reg_ += lamb_l1 * l1 + lamb_entropy * (entropy_row + entropy_col)
-
+
         elif len(acts_scale[0].shape) == 1:
-
+
             reg_ = 0.
 
             for i in range(len(acts_scale)):
@@ -193,20 +197,21 @@ class MLP(nn.Module):
                 reg_ += lamb_l1 * l1 + lamb_entropy * entropy
 
         return reg_
-
+
     def get_reg(self, reg_metric, lamb_l1, lamb_entropy):
         return self.reg(reg_metric, lamb_l1, lamb_entropy)
-
-    def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., loss_fn=None, lr=1., batch=-1,
-            metrics=None, in_vars=None, out_vars=None, beta=3, device='cpu', reg_metric='w', display_metrics=None):
+
+    def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., loss_fn=None, lr=1.,
+            batch=-1,
+            metrics=None, device='cpu', reg_metric='w', display_metrics=None):
 
         if lamb > 0. and not self.save_act:
             print('setting lamb=0. If you want to set lamb > 0, set =True')
-
+
         old_save_act = self.save_act
         if lamb == 0.:
             self.save_act = False
-
+
         pbar = tqdm(range(steps), desc='description', ncols=100)
 
         if loss_fn == None:
@@ -217,7 +222,8 @@ class MLP(nn.Module):
         if opt == "Adam":
             optimizer = torch.optim.Adam(self.parameters(), lr=lr)
         elif opt == "LBFGS":
-            optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
+            optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
+                              tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
 
         results = {}
         results['train_loss'] = []
@@ -252,10 +258,10 @@ class MLP(nn.Module):
             return objective
 
         for _ in pbar:
-
-            if _ == steps-1 and old_save_act:
+
+            if _ == steps - 1 and old_save_act:
                 self.save_act = True
-
+
             train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
             test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
 
@@ -274,9 +280,9 @@ class MLP(nn.Module):
                 loss.backward()
                 optimizer.step()
 
-            test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(self.device)), dataset['test_label'][test_id].to(self.device))
-
-
+            test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(self.device)),
+                                     dataset['test_label'][test_id].to(self.device))
+
             if metrics != None:
                 for i in range(len(metrics)):
                     results[metrics[i].__name__].append(metrics[i]().item())
@@ -287,7 +293,9 @@ class MLP(nn.Module):
 
             if _ % log == 0:
                 if display_metrics == None:
-                    pbar.set_description("| train_loss: %.2e | test_loss: %.2e | reg: %.2e | " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy()))
+                    pbar.set_description("| train_loss: %.2e | test_loss: %.2e | reg: %.2e | " % (
+                        torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(),
+                        reg_.cpu().detach().numpy()))
                 else:
                     string = ''
                     data = ()
@@ -299,9 +307,9 @@ class MLP(nn.Module):
                             raise Exception(f'{metric} not recognized')
                         data += (results[metric][-1],)
                     pbar.set_description(string % data)
-
+
         return results
-
+
     @property
     def connection_cost(self):
 
@@ -309,8 +317,9 @@ class MLP(nn.Module):
             cc = 0.
             for linear in self.linears:
                 t = torch.abs(linear.weight)
+
                 def get_coordinate(n):
-                    return torch.linspace(0,1,steps=n+1, device=self.device)[:n] + 1/(2*n)
+                    return torch.linspace(0, 1, steps=n + 1, device=self.device)[:n] + 1 / (2 * n)
 
                 in_dim = t.shape[0]
                 x_in = get_coordinate(in_dim)
@@ -318,44 +327,45 @@ class MLP(nn.Module):
                 out_dim = t.shape[1]
                 x_out = get_coordinate(out_dim)
 
-                dist = torch.abs(x_in[:,None] - x_out[None,:])
+                dist = torch.abs(x_in[:, None] - x_out[None, :])
                 cc += torch.sum(dist * t)
 
         return cc
-
+
     def swap(self, l, i1, i2):
 
         def swap_row(data, i1, i2):
             data[i1], data[i2] = data[i2].clone(), data[i1].clone()
 
         def swap_col(data, i1, i2):
-            data[:,i1], data[:,i2] = data[:,i2].clone(), data[:,i1].clone()
+            data[:, i1], data[:, i2] = data[:, i2].clone(), data[:, i1].clone()
 
-        swap_row(self.linears[l-1].weight.data, i1, i2)
-        swap_row(self.linears[l-1].bias.data, i1, i2)
+        swap_row(self.linears[l - 1].weight.data, i1, i2)
+        swap_row(self.linears[l - 1].bias.data, i1, i2)
         swap_col(self.linears[l].weight.data, i1, i2)
-
+
     def auto_swap_l(self, l):
 
         num = self.width[l]
         for i in range(num):
            ccs = []
            for j in range(num):
-                self.swap(l,i,j)
+                self.swap(l, i, j)
                self.get_act()
                self.attribute()
                cc = self.connection_cost.detach().clone()
                ccs.append(cc)
-                self.swap(l,i,j)
+                self.swap(l, i, j)
            j = torch.argmin(torch.tensor(ccs))
-            self.swap(l,i,j)
+            self.swap(l, i, j)
 
     def auto_swap(self):
         depth = self.depth
         for l in range(1, depth):
             self.auto_swap_l(l)
-
+
     def tree(self, x=None, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False):
         if x == None:
             x = self.cache_data
-        plot_tree(self, x, in_var=in_var, style=style, sym_th=sym_th, sep_th=sep_th, skip_sep_test=skip_sep_test, verbose=verbose)
+        plot_tree(self, x, in_var=in_var, style=style, sym_th=sym_th, sep_th=sep_th, skip_sep_test=skip_sep_test,
+                  verbose=verbose)
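Note on the MLP.py changes above: `fit` no longer accepts the unused `in_vars`, `out_vars`, and `beta` arguments, and `tree` now depends on the newly imported `plot_tree`. A minimal call sketch against the new signature, assuming a dataset dict of tensors (the shapes below are illustrative, not from the package):

    import torch
    from yms_kan.MLP import MLP

    dataset = {
        'train_input': torch.randn(100, 2), 'train_label': torch.randn(100, 1),
        'test_input': torch.randn(20, 2), 'test_label': torch.randn(20, 1),
    }
    model = MLP(width=[2, 5, 1], device='cpu')
    # in_vars/out_vars/beta would now raise TypeError; pass only the remaining kwargs
    results = model.fit(dataset, opt='Adam', steps=20, lr=1e-3)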
yms_kan/MultKAN.py CHANGED
@@ -794,6 +794,8 @@ class MultKAN(nn.Module):
         # >>> print(model(x, singularity_avoiding=True))
         # >>> print(model(x, singularity_avoiding=True, y_th=1.))
         """
+        # x = abs(torch.fft(x, dim=-1,norm='forward'))
+        # _,x = x.chunk(2,dim=-1)
         x = x[:, self.input_id.long()]
         assert x.shape[1] == self.width_in[0]
 
@@ -1063,7 +1065,7 @@ class MultKAN(nn.Module):
                              ha='center', va='center', transform=ax.transData)
 
     def plot(self, folder="./figures", beta=3, metric='backward', scale=0.5, tick=False, sample=False, in_vars=None,
-             out_vars=None, title=None, varscale=1.0):
+             out_vars=None, title=None, varscale=1.0, dpi=100):
         '''
         plot KAN
 
@@ -1164,7 +1166,7 @@ class MultKAN(nn.Module):
                                 s=400 * scale ** 2)
                     plt.gca().spines[:].set_color(color)
 
-                    plt.savefig(f'{folder}/sp_{l}_{i}_{j}.png', bbox_inches="tight", dpi=400)
+                    plt.savefig(f'{folder}/sp_{l}_{i}_{j}.png', bbox_inches="tight", dpi=dpi)
                     plt.close()
 
         def score2alpha(score):
@@ -1647,7 +1649,7 @@ class MultKAN(nn.Module):
 
             if save_fig and _ % save_fig_freq == 0:
                 self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(_), beta=beta)
-                plt.savefig(img_folder + '/' + str(_) + '.jpg', bbox_inches='tight', dpi=100)
+                plt.savefig(img_folder + '/' + str(_) + '.jpg', bbox_inches='tight')
                 plt.close()
                 self.save_act = save_act
 
@@ -1857,7 +1859,7 @@ class MultKAN(nn.Module):
             if save_fig and epoch % save_fig_freq == 0:
                 self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
                           beta=beta)
-                plt.savefig(img_folder + '/' + str(epoch) + '.jpg', bbox_inches='tight', dpi=100)
+                plt.savefig(img_folder + '/' + str(epoch) + '.jpg', bbox_inches='tight')
                 plt.close()
                 self.save_act = save_act
 
yms_kan/train_eval_utils.py CHANGED
@@ -13,11 +13,11 @@ from yms_kan.plotting import plot_confusion_matrix
 from yms_kan.tool import initialize_results_file, append_to_results_file, calculate_metric
 
 
-def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_file=None, opt="LBFGS", epochs=100,
+def train_val(model, dataset: dict, batch_size, batch_size_test, save_path=None, txt_file=None, opt="LBFGS", epochs=100,
               lamb=0.,
               lamb_l1=1., label=None, class_dict=None, lamb_entropy=2., lamb_coef=0.,
               lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
-              stop_grid_update_step=100,
+              stop_grid_update_epoch=100,
               save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
               singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n'):
     best = -1
@@ -42,7 +42,6 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_
     else:
         loss_fn = loss_fn
 
-    grid_update_freq = int(stop_grid_update_step / grid_update_num)
 
     if opt == "Adam":
         optimizer = torch.optim.Adam(model.get_params(), lr=lr)
@@ -55,9 +54,11 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_
     lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)
 
     results = {'train_losses': [], 'val_losses': [], 'accuracies': [], 'precisions': [], 'recalls': [], 'f1-scores': [],
-               'lrs': [], 'all_predictions': [], 'all_labels': [], 'regularize': []}
+               'lrs': [], 'regularize': []}  # , 'all_predictions': [], 'all_labels': []
 
     steps = math.ceil(dataset['train_input'].shape[0] / batch_size)
+    stop_grid_update_step = stop_grid_update_epoch * steps
+    grid_update_freq = int(stop_grid_update_step / grid_update_num)
 
     train_loss = torch.zeros(1).to(model.device)
     reg_ = torch.zeros(1).to(model.device)
@@ -85,7 +86,6 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_
             os.makedirs(img_folder)
 
     for epoch in range(epochs):
-
         if epoch == epochs - 1 and old_save_act:
             model.save_act = True
 
@@ -172,7 +172,9 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_
                 best = m["f1-score"]
                 results['all_predictions'] = all_predictions
                 results['all_labels'] = all_labels
-                plot_confusion_matrix(all_labels, all_predictions, class_dict, save_path)
+
+                # if save_path is not None:
+                #     plot_confusion_matrix(all_labels, all_predictions, class_dict, save_path)
         # if save_path is not None:
         #     model.saveckpt(path=(os.path.join(save_path, 'save_model') + '/' + 'model'))
         if txt_file is not None:
@@ -196,156 +198,156 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, save_path, txt_
     model.symbolic_enabled = old_symbolic_enabled
     return results
 
-# def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", epochs=100, lamb=0.,
-#               lamb_l1=1., label=None, lamb_entropy=2., lamb_coef=0.,
-#               lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
-#               stop_grid_update_step=100,
-#               save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
-#               singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n'):
-#     # result_info = ['epoch','train_losses', 'val_losses', 'regularize', 'accuracies',
-#     #                'precisions', 'recalls', 'f1-scores']
-#     # initialize_results_file(results_file, result_info)
-#     all_predictions = []
-#     all_labels = []
-#     if lamb > 0. and not model.save_act:
-#         print('setting lamb=0. If you want to set lamb > 0, set model.save_act=True')
-#
-#     old_save_act, old_symbolic_enabled = model.disable_symbolic_in_fit(lamb)
-#     if label is not None:
-#         label = label.to(model.device)
-#
-#     if loss_fn is None:
-#         loss_fn = lambda x, y: torch.mean((x - y) ** 2)
-#     else:
-#         loss_fn = loss_fn
-#
-#     grid_update_freq = int(stop_grid_update_step / grid_update_num)
-#
-#     if opt == "Adam":
-#         optimizer = torch.optim.Adam(model.get_params(), lr=lr)
-#     elif opt == "LBFGS":
-#         optimizer = LBFGS(model.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
-#                           tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
-#     else:
-#         optimizer = torch.optim.SGD(model.get_params(), lr=lr)
-#
-#     lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)
-#
-#     results = {'train_losses': [], 'val_losses': [], 'regularize': [], 'accuracies': [],
-#                'precisions': [], 'recalls': [], 'f1-scores': []}
-#
-#     steps = math.ceil(dataset['train_input'].shape[0] / batch_size)
-#
-#     train_loss = torch.zeros(1).to(model.device)
-#     reg_ = torch.zeros(1).to(model.device)
-#
-#     def closure():
-#         nonlocal train_loss, reg_
-#         optimizer.zero_grad()
-#         pred = model.forward(batch_train_input, singularity_avoiding=singularity_avoiding, y_th=y_th)
-#         loss = loss_fn(pred, batch_train_label)
-#         if model.save_act:
-#             if reg_metric == 'edge_backward':
-#                 model.attribute()
-#             if reg_metric == 'node_backward':
-#                 model.node_attribute()
-#             reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
-#         else:
-#             reg_ = torch.tensor(0.)
-#         objective = loss + lamb * reg_
-#         train_loss = (train_loss * batch_num + objective.detach()) / (batch_num + 1)
-#         objective.backward()
-#         return objective
-#
-#     if save_fig:
-#         if not os.path.exists(img_folder):
-#             os.makedirs(img_folder)
-#
-#     for epoch in range(epochs):
-#
-#         if epoch == epochs - 1 and old_save_act:
-#             model.save_act = True
-#
-#         if save_fig and epoch % save_fig_freq == 0:
-#             save_act = model.save_act
-#             model.save_act = True
-#
-#         train_indices = np.arange(dataset['train_input'].shape[0])
-#         np.random.shuffle(train_indices)
-#         train_pbar = tqdm(range(steps), desc=f'Epoch {epoch + 1}/{epochs} Training', file=sys.stdout)
-#         for batch_num in train_pbar:
-#             step = epoch * steps + batch_num + 1
-#             i = batch_num * batch_size
-#             batch_train_id = train_indices[i:i + batch_size]
-#             batch_train_input = dataset['train_input'][batch_train_id].to(model.device)
-#             batch_train_label = dataset['train_label'][batch_train_id].to(model.device)
-#
-#             if step % grid_update_freq == 0 and step < stop_grid_update_step and update_grid and step >= start_grid_update_step:
-#                 model.update_grid(batch_train_input)
-#
-#             if opt == "LBFGS":
-#                 optimizer.step(closure)
-#
-#             else:
-#                 optimizer.zero_grad()
-#                 pred = model.forward(batch_train_input, singularity_avoiding=singularity_avoiding,
-#                                      y_th=y_th)
-#                 loss = loss_fn(pred, batch_train_label)
-#                 if model.save_act:
-#                     if reg_metric == 'edge_backward':
-#                         model.attribute()
-#                     if reg_metric == 'node_backward':
-#                         model.node_attribute()
-#                     reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
-#                 else:
-#                     reg_ = torch.tensor(0.)
-#                 loss = loss + lamb * reg_
-#                 train_loss = (train_loss * batch_num + loss.detach()) / (batch_num + 1)
-#                 loss.backward()
-#                 optimizer.step()
-#             train_pbar.set_postfix(loss=train_loss.item())
-#
-#         print(f'{epoch}/{epochs}:train_loss:{train_loss.item()}')
-#         val_loss = torch.zeros(1).to(model.device)
-#         with torch.no_grad():
-#             test_indices = np.arange(dataset['test_input'].shape[0])
-#             np.random.shuffle(test_indices)
-#             test_steps = math.ceil(dataset['test_input'].shape[0] / batch_size_test)
-#             test_pbar = tqdm(range(test_steps), desc=f'Epoch {epoch + 1}/{epochs} Validation', file=sys.stdout)
-#             for batch_num in test_pbar:
-#                 i = batch_num * batch_size_test
-#                 batch_test_id = test_indices[i:i + batch_size_test]
-#                 batch_test_input = dataset['test_input'][batch_test_id].to(model.device)
-#                 batch_test_label = dataset['test_label'][batch_test_id].to(model.device)
-#
-#                 outputs = model.forward(batch_test_input, singularity_avoiding=singularity_avoiding,
-#                                         y_th=y_th)
-#
-#                 loss = loss_fn(outputs, batch_test_label)
-#
-#                 val_loss = (val_loss * batch_num + loss.detach()) / (batch_num + 1)
-#                 test_pbar.set_postfix(loss=loss.item(), val_loss=val_loss.item())
-#                 if label is not None:
-#                     diffs = torch.abs(outputs - label)
-#                     closest_indices = torch.argmin(diffs, dim=1)
-#                     closest_values = label[closest_indices]
-#                     all_predictions.extend(closest_values.detach().cpu().numpy())
-#                     all_labels.extend(batch_test_label.detach().cpu().numpy())
-#
-#         lr_scheduler.step(val_loss)
-#
-#         results['train_losses'].append(train_loss.cpu().item())
-#         results['val_losses'].append(val_loss.cpu().item())
-#         results['regularize'].append(reg_.cpu().item())
-#
-#         if save_fig and epoch % save_fig_freq == 0:
-#             model.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
-#                        beta=beta)
-#             plt.savefig(img_folder + '/' + str(epoch) + '.jpg', bbox_inches='tight', dpi=100)
-#             plt.close()
-#             model.save_act = save_act
-#
-#     # append_to_results_file(results_file, results, result_info)
-#     model.log_history('fit')
-#     model.symbolic_enabled = old_symbolic_enabled
-#     return results
+
+def fit(model, dataset, batch_size, opt="LBFGS", epochs=100, lamb=0., lamb_l1=1., lamb_entropy=2., loss_fn=None,
+        lr=1., label=None, class_dict=None,
+        txt_file=None,
+        reg_metric='w'):
+    best = -1
+    column_order = ['epoch', 'train_losses', 'val_losses', 'accuracies', 'precisions', 'recalls',
+                    'f1-scores', 'lrs']
+    custom_column_widths = {'epoch': 5, 'train_loss': 12, 'val_loss': 10, 'accuracy': 10, 'precision': 9,
+                            'recall': 7,
+                            'f1-score': 8,
+                            'lr': 3}
+    if txt_file is not None:
+        initialize_results_file(txt_file, column_order)
+    if lamb > 0. and not model.save_act:
+        print('setting lamb=0. If you want to set lamb > 0, set =True')
+
+    old_save_act = model.save_act
+    if lamb == 0.:
+        model.save_act = False
+
+    # pbar = tqdm(range(steps), desc='description', ncols=100)
+
+    if loss_fn == None:
+        loss_fn = lambda x, y: torch.mean((x - y) ** 2)
+    else:
+        loss_fn = loss_fn
+
+    if opt == "Adam":
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+    elif opt == "LBFGS":
+        optimizer = LBFGS(model.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
+                          tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
+    else:
+        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True)
+
+    lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)
+
+    results = {'train_losses': [], 'val_losses': [], 'accuracies': [], 'precisions': [], 'recalls': [], 'f1-scores': [],
+               'lrs': [], 'regularize': []}
+
+    steps = math.ceil(dataset['train_input'].shape[0] / batch_size)
+
+    train_loss = torch.zeros(1).to(model.device)
+    reg_ = torch.zeros(1).to(model.device)
+
+    def closure():
+        nonlocal train_loss, reg_
+        optimizer.zero_grad()
+        pred = model.forward(batch_train_input)
+        loss = loss_fn(pred, batch_train_label)
+        if model.save_act:
+            if reg_metric == 'edge_backward':
+                model.attribute()
+            if reg_metric == 'node_backward':
+                model.node_attribute()
+            reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy)
+        else:
+            reg_ = torch.tensor(0.)
+        objective = loss + lamb * reg_
+        train_loss = (train_loss * batch_num + objective.detach()) / (batch_num + 1)
+        objective.backward()
+        return objective
+
+    for epoch in range(epochs):
+
+        if epoch == steps - 1 and old_save_act:
+            model.save_act = True
+
+        train_indices = np.arange(dataset['train_input'].shape[0])
+        np.random.shuffle(train_indices)
+        train_pbar = tqdm(range(steps), desc=f'Epoch {epoch + 1}/{epochs} Training', file=sys.stdout)
+
+        for batch_num in train_pbar:
+            step = epoch * steps + batch_num + 1
+            i = batch_num * batch_size
+            batch_train_id = train_indices[i:i + batch_size]
+            batch_train_input = dataset['train_input'][batch_train_id].to(model.device)
+            batch_train_label = dataset['train_label'][batch_train_id].to(model.device)
+
+            if opt == "LBFGS":
+                optimizer.step(closure)
+
+            if opt == "Adam":
+                optimizer.zero_grad()
+                pred = model.forward(batch_train_input)
+                train_loss = loss_fn(pred, batch_train_input)
+                if model.save_act:
+                    reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy)
+                else:
+                    reg_ = torch.tensor(0.)
+                loss = train_loss + lamb * reg_
+                train_loss = (train_loss * batch_num + loss.detach()) / (batch_num + 1)
+
+                loss.backward()
+                optimizer.step()
+
+            train_pbar.set_postfix(loss=train_loss.item())
+
+        val_loss = torch.zeros(1).to(model.device)
+        with torch.no_grad():
+            all_predictions = []
+            all_labels = []
+            test_indices = np.arange(dataset['test_input'].shape[0])
+            np.random.shuffle(test_indices)
+            test_steps = math.ceil(dataset['test_input'].shape[0] / batch_size)
+            test_pbar = tqdm(range(test_steps), desc=f'Epoch {epoch + 1}/{epochs} Validation', file=sys.stdout)
+            for batch_num in test_pbar:
+                i = batch_num * batch_size
+                batch_test_id = test_indices[i:i + batch_size]
+                batch_test_input = dataset['test_input'][batch_test_id].to(model.device)
+                batch_test_label = dataset['test_label'][batch_test_id].to(model.device)
+                outputs = model.forward(batch_test_input)
+                loss = loss_fn(outputs, batch_test_label)
+                val_loss = (val_loss * batch_num + loss.detach()) / (batch_num + 1)
+                test_pbar.set_postfix(val_loss=val_loss.item())
+                if label is not None:
+                    diffs = torch.abs(outputs - label)
+                    closest_indices = torch.argmin(diffs, dim=1)
+                    closest_values = label[closest_indices]
+                    all_predictions.extend(closest_values.detach().cpu().numpy())
+                    all_labels.extend(batch_test_label.detach().cpu().numpy())
+
+        train_lr = lr_scheduler.get_last_lr()[0]
+        lr_scheduler.step(val_loss)
+
+        if label is not None:
+            m = calculate_metric(all_labels, all_predictions, class_dict)
+            print(m)
+            results["accuracies"].append(m["accuracy"])
+            results["precisions"].append(m["precision"])
+            results["recalls"].append(m["recall"])
+            results["f1-scores"].append(m["f1-score"])
+            results["lrs"].append(train_lr)
+            if best < m["f1-score"]:
+                best = m["f1-score"]
+                results['all_predictions'] = all_predictions
+                results['all_labels'] = all_labels
+
+            # if save_path is not None:
+            #     plot_confusion_matrix(all_labels, all_predictions, class_dict, save_path)
+            # if save_path is not None:
+            #     model.saveckpt(path=(os.path.join(save_path, 'save_model') + '/' + 'model'))
+            if txt_file is not None:
+                m.update({'lr': train_lr, 'epoch': epoch, 'train_loss': train_loss.item(), 'val_loss': val_loss.item()})
+                append_to_results_file(txt_file, m, column_order,
+                                       custom_column_widths=custom_column_widths)
+
+        results["train_losses"].append(train_loss.item())
+        results["val_losses"].append(val_loss.item())
+        results["regularize"].append(reg_.item())
+
+    return results
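Note on the train_eval_utils.py changes above: `save_path` became optional, confusion-matrix plotting was commented out, and the grid-update schedule is now expressed in epochs: `stop_grid_update_step = stop_grid_update_epoch * steps` with `steps = ceil(n_train / batch_size)`, so e.g. 100 epochs at 10 steps/epoch gives 1000 steps and, with `grid_update_num=10`, a grid update every 100 steps. The long commented-out legacy `train_val` was replaced by an active `fit` helper; a minimal call sketch, assuming the same dataset dict as `train_val` plus optional label/class metadata (`label_tensor` and `class_dict` here are illustrative placeholders):

    results = fit(model, dataset, batch_size=64, opt='Adam', epochs=10, lr=1e-3,
                  label=label_tensor, class_dict=class_dict, txt_file='results.txt')
    print(results['train_losses'][-1], results['val_losses'][-1])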
yms_kan/utils.py CHANGED
@@ -75,9 +75,9 @@ SYMBOLIC_LIB = {'x': (lambda x: x, lambda x: x, 1, lambda x, y_th: ((), x)),
                 '0': (lambda x: x * 0, lambda x: x * 0, 0, lambda x, y_th: ((), x * 0)),
                 'gaussian': (lambda x: torch.exp(-x ** 2), lambda x: sympy.exp(-x ** 2), 3,
                              lambda x, y_th: ((), torch.exp(-x ** 2))),
-                #'cosh': (lambda x: torch.cosh(x), lambda x: sympy.cosh(x), 5),
-                #'sigmoid': (lambda x: torch.sigmoid(x), sympy.Function('sigmoid'), 4),
-                #'relu': (lambda x: torch.relu(x), relu),
+                # 'cosh': (lambda x: torch.cosh(x), lambda x: sympy.cosh(x), 5),
+                # 'sigmoid': (lambda x: torch.sigmoid(x), sympy.Function('sigmoid'), 4),
+                # 'relu': (lambda x: torch.relu(x), relu),
                 }
 
 
@@ -465,6 +465,78 @@ def batch_hessian(model, x, create_graph=False):
     return torch.autograd.functional.jacobian(_jac_sum, x, create_graph=create_graph).permute(1, 0, 2)
 
 
+# def create_from_data(inputs, labels, ratios=[0.8, 0.2], device='cpu'):
+#     # Validate arguments
+#     if not (2 <= len(ratios) <= 3):
+#         raise ValueError("ratios must be a list of length 2 or 3")
+#     if not np.isclose(sum(ratios), 1.0):
+#         raise ValueError("the elements of ratios must sum to 1")
+#     for r in ratios:
+#         if r < 0 or r > 1:
+#             raise ValueError("each element of ratios must be between 0 and 1")
+#
+#     from collections import defaultdict
+#     class_indices = defaultdict(list)
+#     for idx, label in enumerate(labels):
+#         class_indices[label.item()].append(idx)
+#
+#     # Initialize the index lists for each split
+#     split_indices = defaultdict(list)
+#
+#     # Stratified sampling
+#     for class_label, indices in class_indices.items():
+#         if not indices:
+#             continue
+#
+#         num_samples = len(indices)
+#         np.random.shuffle(indices)  # shuffle samples within the class first
+#
+#         # Compute the split points for each subset
+#         split_points = np.cumsum(ratios).tolist()[:-1]  # split points (excluding the final 1)
+#         splits = np.split(indices, [int(num_samples * p) for p in split_points])
+#
+#         # Assign to the corresponding split
+#         for i, subset in enumerate(splits):
+#             if i == 0:
+#                 split_indices['train'].extend(subset)
+#             elif i == 1:
+#                 split_indices['test'].extend(subset)
+#             elif i == 2:
+#                 split_indices['val'].extend(subset)
+#
+#     # Merge splits (the validation set is optional)
+#     train_val = {}
+#     if 'test' in split_indices:
+#         # Merge train + val and shuffle
+#         train_val_idx = np.concatenate([
+#             np.array(split_indices['train']),
+#             np.array(split_indices['val'])
+#         ])
+#         np.random.shuffle(train_val_idx)
+#         train_val = {
+#             'train_input': inputs[train_val_idx[:len(split_indices['train'])]].detach().to(device),
+#             'train_label': labels[train_val_idx[:len(split_indices['train'])]].detach().to(device),
+#             'test_input': inputs[train_val_idx[len(split_indices['train']):]].detach().to(device),
+#             'test_label': labels[train_val_idx[len(split_indices['train']):]].detach().to(device)
+#         }
+#     else:
+#         # Training set only
+#         train_idx = np.array(split_indices['train'])
+#         np.random.shuffle(train_idx)
+#         train_val = {
+#             'train_input': inputs[train_idx].detach().to(device),
+#             'train_label': labels[train_idx].detach().to(device)
+#         }
+#
+#     # Handle the test set
+#     test_idx = np.array(split_indices.get('val', []))
+#     np.random.shuffle(test_idx)
+#     test_set = {
+#         'val_input': inputs[test_idx].detach().to(device) if test_idx.size else None,
+#         'val_label': labels[test_idx].detach().to(device) if test_idx.size else None
+#     }
+#
+#     return train_val, test_set
 def create_from_data(inputs, labels, train_ratio=0.8, device='cpu'):
     from collections import defaultdict
     class_indices = defaultdict(list)
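Note on the utils.py changes above: only commented-out code was touched (three SYMBOLIC_LIB entries re-spaced, and a stratified multi-ratio splitter added as a comment block), so the active API is unchanged. The live `create_from_data` still takes a single `train_ratio`; a minimal call sketch, assuming integer class labels (the returned structure follows the function body, which this hunk truncates):

    inputs = torch.randn(200, 4)
    labels = torch.randint(0, 3, (200,))
    dataset = create_from_data(inputs, labels, train_ratio=0.8, device='cpu')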
yms_kan/version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.0.9"  # initial version
+__version__ = "0.0.10"  # initial version
{yms_kan-0.0.9.dist-info → yms_kan-0.0.10.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: yms_kan
-Version: 0.0.9
+Version: 0.0.10
 Summary: My awesome package
 Author-email: yms <11@qq.com>
 License-Expression: MIT
{yms_kan-0.0.9.dist-info → yms_kan-0.0.10.dist-info}/RECORD RENAMED
@@ -1,7 +1,7 @@
 yms_kan/KANLayer.py,sha256=-V2Fh5wvPYvfF1tmQVxJKWvvaAHiwo2EiFpd8VDgB1c,14149
 yms_kan/LBFGS.py,sha256=OPeRPDp40jaVH4qPoBDMEub7TPhyvw7pbqwQar3OZ1A,17620
-yms_kan/MLP.py,sha256=ryLzSuBrsGlSHRLwnQZCNj-Ru9BwXJYoHNkAwX14N64,12804
-yms_kan/MultKAN.py,sha256=n58W6tORBDuh_-pUVp-ER-R9KNJKdYjFaisDx8wJpWw,122113
+yms_kan/MLP.py,sha256=JFLogd1EPFVCrBJJtvNMNu68ejdUJ2O6qYc7l4pfFFI,12728
+yms_kan/MultKAN.py,sha256=eFh5jCGRPUzrpeWkvonGGzNWK4w0aODN8Q6M9wZ5IaY,122193
 yms_kan/Symbolic_KANLayer.py,sha256=WhJzC5IMIpXI_K7aYamOrWTK7uckxVdsM9N4oLZMO3I,9897
 yms_kan/__init__.py,sha256=O2c6DIG4PHavXF2v7K9jNqMbJXWr4-gTN3Vs1YSlc64,120
 yms_kan/compiler.py,sha256=7bVwDNX0xmLAjQ8V1FdmkIIIibmy_W5eaeSKBlYL0Vc,18632
@@ -11,13 +11,13 @@ yms_kan/hypothesis.py,sha256=Ec20xadfgOSSWeZHQaGn-h9F2PY7LWFU3iniNI2Zd_4,23165
 yms_kan/plotting.py,sha256=Moi6QTJQxHjutGMgxR9oSsqZSzYY3TP-7WNapdCIqzw,18097
 yms_kan/spline.py,sha256=ZXyGwl2Sc-UrnrcuUXeUQkBOMnetaWcHrbpZaqatCvs,4345
 yms_kan/tool.py,sha256=rkRpqF3EcsAq7a3k1F1zKlxfJ4U9n-FzHyNCJgN4URY,21159
-yms_kan/train_eval_utils.py,sha256=ZY_2GbSjVNAeaGQ24tq1NdjHKrAKKhpTGyamrL98Ap4,16713
-yms_kan/utils.py,sha256=J07L-tgmc1OfU6Tl6mGwHJRizjFN75EJK8BxejaZLUc,23860
-yms_kan/version.py,sha256=BAglq1pSrHfsSqSyAI9RhpMrPmvDRjt0yW5FHNU8gT0,39
+yms_kan/train_eval_utils.py,sha256=3WPtCKLcrotU92s4S0uuIa1rXOAHxyAfDwFJGwUxvy0,16210
+yms_kan/utils.py,sha256=k1fZvv9P6vBBV7LMysoTL2j-bglkBWO0l31dNkWI_Jo,26763
+yms_kan/version.py,sha256=ts9Xi3n2P07g5eVEUSK46avv_nYOOTJ_EeHV2X6IfhM,40
 yms_kan/assets/img/mult_symbol.png,sha256=2f4xUKdweft-qUbHjFI5h9-smnEtc0FWq8hNYZhPAXY,6392
 yms_kan/assets/img/sum_symbol.png,sha256=94QkMUzmEjlCq_yf14nMEQmettaq86FmlGfdl22b4XE,6210
-yms_kan-0.0.9.dist-info/licenses/LICENSE,sha256=BJXDWyF4Groqtnp4Gi9puH4aLg7A2IC3MpHmC-cSxwc,1067
-yms_kan-0.0.9.dist-info/METADATA,sha256=5CcbAeKN87Gh4nrcbkOndaEq734S7Ac4Y4ewkK5D_LU,240
-yms_kan-0.0.9.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-yms_kan-0.0.9.dist-info/top_level.txt,sha256=Z_JDh6yZf-EiW1eKgL6ADsN2yqEMRMspi-o29JZ1WPo,8
-yms_kan-0.0.9.dist-info/RECORD,,
+yms_kan-0.0.10.dist-info/licenses/LICENSE,sha256=BJXDWyF4Groqtnp4Gi9puH4aLg7A2IC3MpHmC-cSxwc,1067
+yms_kan-0.0.10.dist-info/METADATA,sha256=pjZJ0E5OIBv93IoXD10lXpQCPN9ixDIRDuNUwov6Fls,241
+yms_kan-0.0.10.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
+yms_kan-0.0.10.dist-info/top_level.txt,sha256=Z_JDh6yZf-EiW1eKgL6ADsN2yqEMRMspi-o29JZ1WPo,8
+yms_kan-0.0.10.dist-info/RECORD,,
{yms_kan-0.0.9.dist-info → yms_kan-0.0.10.dist-info}/WHEEL RENAMED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (78.1.0)
+Generator: setuptools (79.0.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 