yms-kan 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
yms_kan/MLP.py ADDED
@@ -0,0 +1,361 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from .LBFGS import LBFGS
7
+
8
# Module-wide default seed. Importing this module reseeds torch's global RNG
# as a side effect (kept for reproducibility of existing callers).
seed: int = 0
torch.manual_seed(seed)
10
+
11
class MLP(nn.Module):
    """Multi-layer perceptron with helpers for attribution, plotting and sparsity regularisation.

    The network is ``len(width) - 1`` ``nn.Linear`` layers; the activation
    ``act_fun`` is applied after every layer except the last. When
    ``save_act`` is True, ``forward`` records per-layer statistics
    (``acts``, ``acts_scale``, ``wa_forward``) that ``attribute``, ``plot``
    and ``reg`` consume.
    """

    # Activation names accepted by ``__init__`` (case-insensitive); each value
    # is an ``nn.Module`` class that is instantiated once.
    _ACTIVATIONS = {
        'silu': nn.SiLU,
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
        'sigmoid': nn.Sigmoid,
        'gelu': nn.GELU,
        'identity': nn.Identity,
    }

    def __init__(self, width, act='silu', save_act=True, seed=0, device='cpu'):
        """Build the linear layers and the activation.

        Args:
            width: layer sizes ``[n_in, hidden_1, ..., n_out]``.
            act: activation name (a key of ``_ACTIVATIONS``, case-insensitive),
                an ``nn.Module`` subclass, or a callable applied between layers.
            save_act: whether ``forward`` records activation statistics.
            seed: passed to ``torch.manual_seed`` before the layers are created.
            device: device the parameters are moved to.

        Raises:
            ValueError: if ``act`` is not callable and not a known name.
        """
        super().__init__()

        torch.manual_seed(seed)

        self.width = width
        self.depth = depth = len(width) - 1
        self.linears = nn.ModuleList(
            [nn.Linear(width[i], width[i + 1]) for i in range(depth)])

        self.act_fun = self._make_activation(act)
        self.save_act = save_act
        self.acts = None

        # Input of the most recent forward pass; reused by get_act() and tree().
        self.cache_data = None

        self.device = device
        self.to(device)

    @classmethod
    def _make_activation(cls, act):
        """Return the activation callable described by ``act``."""
        if isinstance(act, type) and issubclass(act, nn.Module):
            return act()
        if callable(act):
            return act
        if isinstance(act, str) and act.lower() in cls._ACTIVATIONS:
            return cls._ACTIVATIONS[act.lower()]()
        raise ValueError(
            "act = '{}' not recognized. Choices are {} or a callable.".format(
                act, sorted(cls._ACTIVATIONS)))

    def to(self, device):
        """Move the parameters to ``device``, record it in ``self.device`` and return self."""
        super().to(device)
        self.device = device

        return self

    def get_act(self, x=None):
        """Run one forward pass with activation saving forced on.

        Args:
            x: an input tensor, a dataset dict (its 'train_input' entry is
                used), or None to reuse the input cached by the last forward.

        Raises:
            Exception: if ``x`` is None and no input has been cached yet.
        """
        if isinstance(x, dict):
            x = x['train_input']
        if x is None:
            if self.cache_data is None:
                raise Exception("missing input data x")
            x = self.cache_data
        save_act = self.save_act
        self.save_act = True
        try:
            self.forward(x)
        finally:
            # Restore the caller's setting even if forward raises.
            self.save_act = save_act

    @property
    def w(self):
        """List of the weight matrices of all linear layers, first layer first."""
        return [self.linears[l].weight for l in range(self.depth)]

    def forward(self, x):
        """Apply the network to ``x`` and cache ``x`` in ``self.cache_data``.

        When ``save_act`` is True, also records for every layer i: a copy of
        its input (``acts``), the weight matrix scaled column-wise by the
        per-feature std (over dim 0) of that input (``wa_forward``), and the
        per-feature std of every hidden-layer input and of the final output
        (``acts_scale``). Presumably ``x`` is (batch, width[0]) — the std over
        dim 0 must broadcast against the weight columns.
        """
        # cache data
        self.cache_data = x

        self.acts = []
        self.acts_scale = []
        self.wa_forward = []
        self.a_forward = []

        for i in range(self.depth):

            if self.save_act:
                act = x.clone()
                act_scale = torch.std(x, dim=0)
                wa_forward = act_scale[None, :] * self.linears[i].weight
                self.acts.append(act)
                if i > 0:
                    self.acts_scale.append(act_scale)
                self.wa_forward.append(wa_forward)

            x = self.linears[i](x)
            if i < self.depth - 1:
                x = self.act_fun(x)
            else:
                if self.save_act:
                    act_scale = torch.std(x, dim=0)
                    self.acts_scale.append(act_scale)

        return x

    def attribute(self):
        """Back-propagate importance scores from the outputs to every node and edge.

        Sets ``node_scores`` (one vector per layer, input layer first),
        ``edge_scores`` (one matrix per linear layer, shaped like its weight)
        and ``wa_backward`` (an alias of ``edge_scores``). Runs ``get_act``
        first if no activations have been recorded yet.
        """
        if self.acts is None:
            self.get_act()

        node_scores = []
        edge_scores = []

        # back propagate from the last layer: every output node starts with score 1
        node_score = torch.ones(self.width[-1], device=self.device).requires_grad_(True)
        node_scores.append(node_score)

        for l in range(self.depth, 0, -1):

            # Edge (i, j) of layer l gets |w_ij * std_j| scaled by the score of
            # output node i relative to that node's std (1e-4 avoids division by 0).
            edge_score = torch.einsum('ij,i->ij', torch.abs(self.wa_forward[l - 1]), node_score / (self.acts_scale[l - 1] + 1e-4))
            edge_scores.append(edge_score)

            # this might be improper for MLPs (although reasonable for KANs)
            node_score = torch.sum(edge_score, dim=0) / torch.sqrt(torch.tensor(self.width[l - 1], device=self.device))
            node_scores.append(node_score)

        self.node_scores = list(reversed(node_scores))
        self.edge_scores = list(reversed(edge_scores))
        self.wa_backward = self.edge_scores

    def plot(self, beta=3, scale=1., metric='w'):
        """Draw the network: nodes as black dots, connections as lines.

        Line opacity is ``tanh(beta * |p|)``; colour is blue where p > 0 and
        red otherwise. ``p`` is selected by ``metric``: 'w' (raw weights),
        'act' (``wa_forward``) or 'fa' (``wa_backward``, recomputed via
        ``attribute``).

        Raises:
            Exception: if ``metric`` is not one of 'w', 'act', 'fa'.
        """
        if metric not in ('w', 'act', 'fa'):
            raise Exception('metric = \'{}\' not recognized. Choices are \'w\', \'act\', \'fa\'.'.format(metric))

        if metric == 'fa':
            self.attribute()
        elif metric == 'act' and not getattr(self, 'wa_forward', None):
            # wa_forward only exists after a forward pass with save_act on.
            self.get_act()

        depth = self.depth
        y0 = 0.5
        fig, ax = plt.subplots(figsize=(3 * scale, 3 * y0 * depth * scale))

        min_spacing = 1 / max(self.width)
        for j, N in enumerate(self.width):
            for i in range(N):
                plt.scatter(1 / (2 * N) + i / N, j * y0, s=min_spacing ** 2 * 5000 * scale ** 2, color='black')

        plt.ylim(-0.1 * y0, y0 * depth + 0.1 * y0)
        plt.xlim(-0.02, 1.02)

        for ii, linear in enumerate(self.linears):
            if metric == 'w':
                p = linear.weight
            elif metric == 'act':
                p = self.wa_forward[ii]
            else:
                p = self.wa_backward[ii]
            # Convert once per layer instead of once per element; rows are the
            # output nodes (layer ii + 1), columns the input nodes (layer ii).
            p = p.detach().cpu().numpy()
            n_out, n_in = p.shape
            for i in range(n_out):
                for j in range(n_in):
                    plt.plot([1 / (2 * n_out) + i / n_out, 1 / (2 * n_in) + j / n_in], [y0 * (ii + 1), y0 * ii], lw=0.5 * scale, alpha=np.tanh(beta * np.abs(p[i, j])), color="blue" if p[i, j] > 0 else "red")

        ax.axis('off')

    def reg(self, reg_metric, lamb_l1, lamb_entropy):
        """Return the L1 + entropy sparsity penalty of the chosen statistics.

        Args:
            reg_metric: 'w' (weights), 'act' (``wa_forward``), 'fa'
                (``wa_backward``) or 'a' (``acts_scale``).
            lamb_l1: weight of the L1 term.
            lamb_entropy: weight of the entropy term. For matrices the row-
                and column-normalised entropies are averaged and summed; for
                vectors a single entropy is used.

        Raises:
            ValueError: for an unknown ``reg_metric`` or a tensor that is
                neither 1-D nor 2-D.
        """
        if reg_metric == 'w':
            acts_scale = self.w
        elif reg_metric == 'act':
            acts_scale = self.wa_forward
        elif reg_metric == 'fa':
            acts_scale = self.wa_backward
        elif reg_metric == 'a':
            acts_scale = self.acts_scale
        else:
            raise ValueError("reg_metric = '{}' not recognized. Choices are 'w', 'act', 'fa', 'a'.".format(reg_metric))

        reg_ = 0.
        for vec in acts_scale:
            vec = torch.abs(vec)
            l1 = torch.sum(vec)
            if vec.dim() == 2:
                # The +1 in the denominators keeps p well-defined for all-zero rows/cols.
                p_row = vec / (torch.sum(vec, dim=1, keepdim=True) + 1)
                p_col = vec / (torch.sum(vec, dim=0, keepdim=True) + 1)
                entropy_row = - torch.mean(torch.sum(p_row * torch.log2(p_row + 1e-4), dim=1))
                entropy_col = - torch.mean(torch.sum(p_col * torch.log2(p_col + 1e-4), dim=0))
                entropy = entropy_row + entropy_col
            elif vec.dim() == 1:
                p = vec / (torch.sum(vec) + 1)
                entropy = - torch.sum(p * torch.log2(p + 1e-4))
            else:
                raise ValueError('reg expects 1-D or 2-D tensors, got shape {}'.format(tuple(vec.shape)))
            reg_ += lamb_l1 * l1 + lamb_entropy * entropy

        return reg_

    def get_reg(self, reg_metric, lamb_l1, lamb_entropy):
        """Alias of ``reg``."""
        return self.reg(reg_metric, lamb_l1, lamb_entropy)

    def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., loss_fn=None, lr=1., batch=-1,
            metrics=None, in_vars=None, out_vars=None, beta=3, device='cpu', reg_metric='w', display_metrics=None):
        """Train the network on ``dataset`` and return per-step metrics.

        Args:
            dataset: dict with 'train_input', 'train_label', 'test_input' and
                'test_label' tensors whose first dimension indexes samples.
            opt: "LBFGS" (project LBFGS with strong-Wolfe line search) or "Adam".
            steps: number of optimisation steps.
            log: refresh the progress-bar description every ``log`` steps.
            lamb: weight of the regulariser; it only has an effect while
                ``save_act`` is True (with lamb == 0 activation saving is
                switched off except on the last step).
            lamb_l1, lamb_entropy: forwarded to ``get_reg``.
            loss_fn: ``loss_fn(pred, label)``; defaults to mean squared error.
            lr: learning rate.
            batch: batch size; -1 (or more than the training-set size) uses
                the full training and test sets every step.
            metrics: optional zero-argument callables returning a tensor; their
                ``.item()`` is logged under ``metric.__name__`` every step.
            in_vars, out_vars, beta, device: accepted for API compatibility;
                unused here.
            reg_metric: statistics penalised by ``reg`` ('w', 'act', 'fa', 'a').
            display_metrics: results keys to show in the progress bar instead
                of the default train/test/reg summary.

        Returns:
            dict with lists 'train_loss' and 'test_loss' (square roots of the
            loss values), 'reg', and one list per entry of ``metrics``.

        Raises:
            ValueError: if ``opt`` is not "LBFGS" or "Adam".
            Exception: if a name in ``display_metrics`` is not a results key.
        """
        if opt not in ("LBFGS", "Adam"):
            raise ValueError("opt = '{}' not recognized. Choices are 'LBFGS', 'Adam'.".format(opt))

        if lamb > 0. and not self.save_act:
            print('setting lamb=0. If you want to set lamb > 0, set save_act=True')

        old_save_act = self.save_act
        if lamb == 0.:
            self.save_act = False

        pbar = tqdm(range(steps), desc='description', ncols=100)

        if loss_fn is None:
            def loss_fn(x, y):
                return torch.mean((x - y) ** 2)
        loss_fn_eval = loss_fn

        if opt == "Adam":
            optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        else:
            optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)

        results = {'train_loss': [], 'test_loss': [], 'reg': []}
        if metrics is not None:
            for metric_fn in metrics:
                results[metric_fn.__name__] = []

        n_train = dataset['train_input'].shape[0]
        n_test = dataset['test_input'].shape[0]
        if batch == -1 or batch > n_train:
            batch_size = n_train
            batch_size_test = n_test
        else:
            batch_size = batch
            # Sampling without replacement cannot draw more than n_test rows.
            batch_size_test = min(batch, n_test)

        # Latest training loss / regulariser, updated by the LBFGS closure
        # (possibly several times per step) or directly by the Adam branch.
        train_loss = torch.tensor(0.)
        reg_ = torch.tensor(0.)
        train_id = None

        def batch_objective():
            # Forward the current training batch; return (train loss, regulariser).
            pred = self.forward(dataset['train_input'][train_id].to(self.device))
            loss = loss_fn(pred, dataset['train_label'][train_id].to(self.device))
            if self.save_act:
                if reg_metric == 'fa':
                    self.attribute()
                reg_value = self.get_reg(reg_metric, lamb_l1, lamb_entropy)
            else:
                reg_value = torch.tensor(0.)
            return loss, reg_value

        def closure():
            nonlocal train_loss, reg_
            optimizer.zero_grad()
            train_loss, reg_ = batch_objective()
            objective = train_loss + lamb * reg_
            objective.backward()
            return objective

        try:
            for step in pbar:

                # Record activations on the final step so they are available afterwards.
                if step == steps - 1 and old_save_act:
                    self.save_act = True

                train_id = np.random.choice(n_train, batch_size, replace=False)
                test_id = np.random.choice(n_test, batch_size_test, replace=False)

                if opt == "LBFGS":
                    optimizer.step(closure)
                else:
                    train_loss, reg_ = batch_objective()
                    loss = train_loss + lamb * reg_
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                with torch.no_grad():
                    test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(self.device)), dataset['test_label'][test_id].to(self.device))

                if metrics is not None:
                    for metric_fn in metrics:
                        results[metric_fn.__name__].append(metric_fn().item())

                train_rmse = torch.sqrt(train_loss).cpu().detach().numpy()
                test_rmse = torch.sqrt(test_loss).cpu().detach().numpy()
                reg_value = reg_.cpu().detach().numpy()
                results['train_loss'].append(train_rmse)
                results['test_loss'].append(test_rmse)
                results['reg'].append(reg_value)

                if step % log == 0:
                    if display_metrics is None:
                        pbar.set_description("| train_loss: %.2e | test_loss: %.2e | reg: %.2e | " % (train_rmse, test_rmse, reg_value))
                    else:
                        string = ''
                        data = ()
                        for metric in display_metrics:
                            if metric not in results:
                                raise Exception(f'{metric} not recognized')
                            string += f' {metric}: %.2e |'
                            data += (results[metric][-1],)
                        pbar.set_description(string % data)
        finally:
            # Restore the caller's setting (also when steps == 0 or on error).
            self.save_act = old_save_act

        return results

    @property
    def connection_cost(self):
        """Sum over layers of |weight| times the horizontal distance it spans.

        Nodes of a layer of size n sit at ``(i + 0.5) / n`` in [0, 1] (the same
        positions ``plot`` uses). Computed without autograd.
        """
        def get_coordinate(n):
            return torch.linspace(0, 1, steps=n + 1, device=self.device)[:n] + 1 / (2 * n)

        with torch.no_grad():
            cc = 0.
            for linear in self.linears:
                t = torch.abs(linear.weight)
                # nn.Linear stores its weight as (out_features, in_features).
                out_dim, in_dim = t.shape
                x_out = get_coordinate(out_dim)
                x_in = get_coordinate(in_dim)

                dist = torch.abs(x_out[:, None] - x_in[None, :])
                cc += torch.sum(dist * t)

        return cc

    def swap(self, l, i1, i2):
        """Swap hidden nodes ``i1`` and ``i2`` of layer ``l`` without changing the function.

        Exchanges rows of the incoming weight/bias (layer ``l - 1``) and the
        matching columns of the outgoing weight (layer ``l``).
        """

        def swap_row(data, i1, i2):
            data[i1], data[i2] = data[i2].clone(), data[i1].clone()

        def swap_col(data, i1, i2):
            data[:, i1], data[:, i2] = data[:, i2].clone(), data[:, i1].clone()

        swap_row(self.linears[l - 1].weight.data, i1, i2)
        swap_row(self.linears[l - 1].bias.data, i1, i2)
        swap_col(self.linears[l].weight.data, i1, i2)

    def auto_swap_l(self, l):
        """Greedily permute the nodes of hidden layer ``l`` to lower ``connection_cost``.

        For each node i, every swap partner j (including i itself) is tried
        and the cheapest one is kept, so the cost never increases.
        """
        num = self.width[l]
        for i in range(num):
            costs = []
            for j in range(num):
                self.swap(l, i, j)
                # connection_cost depends only on the weights, so no forward pass is needed here.
                costs.append(float(self.connection_cost))
                self.swap(l, i, j)
            best_j = min(range(num), key=costs.__getitem__)
            self.swap(l, i, best_j)

        # Refresh saved statistics once so they match the permuted weights.
        if self.cache_data is not None:
            self.get_act()
            self.attribute()

    def auto_swap(self):
        """Run ``auto_swap_l`` on every hidden layer, from first to last."""
        depth = self.depth
        for l in range(1, depth):
            self.auto_swap_l(l)

    def tree(self, x=None, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False):
        """Forward to ``plot_tree`` using ``x`` or, if None, the cached input.

        NOTE(review): ``plot_tree`` is neither defined nor imported in this
        module, so this call raises NameError unless the name is provided
        elsewhere — TODO confirm the intended import.
        """
        if x is None:
            x = self.cache_data
        plot_tree(self, x, in_var=in_var, style=style, sym_th=sym_th, sep_th=sep_th, skip_sep_test=skip_sep_test, verbose=verbose)