yms_kan-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
yms_kan/MultKAN.py ADDED
@@ -0,0 +1,3085 @@
+ import copy
+ import math
+ import os
+ import random
+ import sys
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ # from .MultKANLayer import MultKANLayer
+ import pandas as pd
+ import sympy
+ import torch
+ import torch.nn as nn
+ import yaml
+ from PIL import Image
+ from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
+ from sympy import *
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
+ from tqdm import tqdm
+
+ from .KANLayer import KANLayer
+ from .LBFGS import *
+ # from .Symbolic_MultKANLayer import *
+ from .Symbolic_KANLayer import Symbolic_KANLayer
+ from .hypothesis import plot_tree
+ from .spline import curve2coef
+ from .utils import SYMBOLIC_LIB
+
+
+ class MultKAN(nn.Module):
+     """
+     KAN class
+
+     Attributes:
+     -----------
+         grid : int
+             the number of grid intervals
+         k : int
+             spline order
+         act_fun : a list of KANLayers
+         symbolic_fun : a list of Symbolic_KANLayer
+         depth : int
+             depth of KAN
+         width : list
+             number of neurons in each layer.
+             Without multiplication nodes, [2,5,5,3] means 2D inputs, 3D outputs, with 2 layers of 5 hidden neurons.
+             With multiplication nodes, [2,[5,3],[5,1],3] means that, besides the [2,5,5,3] KAN, there are 3 (1) mul nodes in layer 1 (2).
+         mult_arity : int, or list of int lists
+             multiplication arity for each multiplication node (the number of numbers to be multiplied)
+         grid : int
+             the number of grid intervals
+         k : int
+             the order of the piecewise polynomial
+         base_fun : fun
+             residual function b(x). an activation function phi(x) = sb_scale * b(x) + sp_scale * spline(x)
+         symbolic_fun : a list of Symbolic_KANLayer
+             Symbolic_KANLayers
+         symbolic_enabled : bool
+             If False, the symbolic front is not computed (to save time). Default: True.
+         width_in : list
+             The number of input neurons for each layer
+         width_out : list
+             The number of output neurons for each layer
+         base_fun_name : str
+             The base function b(x)
+         grid_eps : float
+             The parameter that interpolates between a uniform grid and an adaptive grid (based on sample quantiles)
+         node_bias : a list of 1D torch.float
+         node_scale : a list of 1D torch.float
+         subnode_bias : a list of 1D torch.float
+         subnode_scale : a list of 1D torch.float
+         symbolic_enabled : bool
+             when symbolic_enabled = False, the symbolic branch (symbolic_fun) will be ignored in computation (set to zero)
+         affine_trainable : bool
+             indicate whether affine parameters are trainable (node_bias, node_scale, subnode_bias, subnode_scale)
+         sp_trainable : bool
+             indicate whether the overall magnitude of splines is trainable
+         sb_trainable : bool
+             indicate whether the overall magnitude of the base function is trainable
+         save_act : bool
+             indicate whether intermediate activations are saved in the forward pass
+         node_scores : None or list of 1D torch.float
+             node attribution scores
+         edge_scores : None or list of 2D torch.float
+             edge attribution scores
+         subnode_scores : None or list of 1D torch.float
+             subnode attribution scores
+         cache_data : None or 2D torch.float
+             cached input data
+         acts : None or a list of 2D torch.float
+             activations on nodes
+         auto_save : bool
+             indicate whether to automatically save a checkpoint once the model is modified
+         state_id : int
+             the state of the model (used to save checkpoints)
+         ckpt_path : str
+             the folder to store checkpoints
+         round : int
+             the number of times rewind() has been called
+         device : str
+     """
+
+     def __init__(self, width=None, grid=3, k=3, mult_arity=2, noise_scale=0.3, scale_base_mu=0.0, scale_base_sigma=1.0,
+                  base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=0.02, grid_range=[-1, 1],
+                  sp_trainable=True, sb_trainable=True, seed=1, save_act=True, sparse_init=False, auto_save=True,
+                  first_init=True, ckpt_path='./model', state_id=0, round=0, device='cpu'):
+         """
+         initialize a KAN model
+
+         Args:
+         -----
+             width : list of int
+                 Without multiplication nodes: :math:`[n_0, n_1, .., n_{L-1}]` specify the number of neurons in each layer (including inputs/outputs)
+                 With multiplication nodes: :math:`[[n_0,m_0=0], [n_1,m_1], .., [n_{L-1},m_{L-1}]]` specify the number of addition/multiplication nodes in each layer (including inputs/outputs)
+             grid : int
+                 number of grid intervals. Default: 3.
+             k : int
+                 order of piecewise polynomial. Default: 3.
+             mult_arity : int, or list of int lists
+                 multiplication arity for each multiplication node (the number of numbers to be multiplied)
+             noise_scale : float
+                 initial injected noise to spline.
+             scale_base_mu : float
+                 the scale of the residual function b(x) is initialized from N(scale_base_mu, scale_base_sigma^2)
+             scale_base_sigma : float
+                 see scale_base_mu
+             base_fun : str
+                 the residual function b(x). Default: 'silu'
+             symbolic_enabled : bool
+                 compute (True) or skip (False) symbolic computations (for efficiency). Default: True.
+             affine_trainable : bool
+                 whether affine parameters are updated. Affine parameters include node_scale, node_bias, subnode_scale, subnode_bias
+             grid_eps : float
+                 When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
+             grid_range : list/np.array of shape (2,)
+                 setting the range of grids. Default: [-1,1]. This argument is not important if fit(update_grid=True) (by default update_grid=True).
+             sp_trainable : bool
+                 If true, scale_sp is trainable. Default: True.
+             sb_trainable : bool
+                 If true, scale_base is trainable. Default: True.
+             device : str
+                 device
+             seed : int
+                 random seed
+             save_act : bool
+                 indicate whether intermediate activations are saved in the forward pass
+             sparse_init : bool
+                 sparse initialization (True) or normal dense initialization (False). Default: False.
+             auto_save : bool
+                 indicate whether to automatically save a checkpoint once the model is modified
+             first_init : bool
+                 True on the very first construction; False when the model is rebuilt (e.g., from a checkpoint)
+             state_id : int
+                 the state of the model (used to save checkpoints)
+             ckpt_path : str
+                 the folder to store checkpoints. Default: './model'
+             round : int
+                 the number of times rewind() has been called
+
+         Returns:
+         --------
+             self
+
+         Example
+         -------
+         # >>> from yms_kan import *
+         # >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
+         checkpoint directory created: ./model
+         saving model version 0.0
+         """
+         super(MultKAN, self).__init__()
+
+         torch.manual_seed(seed)
+         np.random.seed(seed)
+         random.seed(seed)
+
+         ### initializing the numerical front ###
+
+         self.act_fun = []
+         self.depth = len(width) - 1
+
+         # normalize the width spec: plain ints become [n_add, n_mult=0] pairs
+         for i in range(len(width)):
+             if type(width[i]) == int or type(width[i]) == np.int64:
+                 width[i] = [width[i], 0]
+
+         self.width = width
+
+         # if mult_arity is just a scalar, we extend it to a list of lists
+         # e.g., mult_arity = [[2,3],[4]] means that in the first hidden layer, 2 mult ops have arity 2 and 3, respectively;
+         # in the second hidden layer, 1 mult op has arity 4.
+         if isinstance(mult_arity, int):
+             self.mult_homo = True  # when homo is True, parallelization is possible
+         else:
+             self.mult_homo = False  # when homo is False, a for loop is required
+         self.mult_arity = mult_arity
+
+         width_in = self.width_in
+         width_out = self.width_out
+
+         self.base_fun_name = base_fun
+         if base_fun == 'silu':
+             base_fun = torch.nn.SiLU()
+         elif base_fun == 'identity':
+             base_fun = torch.nn.Identity()
+         elif base_fun == 'zero':
+             base_fun = lambda x: x * 0.
+
+         self.grid_eps = grid_eps
+         self.grid_range = grid_range
+
+         for l in range(self.depth):
+             # splines
+             if isinstance(grid, list):
+                 grid_l = grid[l]
+             else:
+                 grid_l = grid
+
+             if isinstance(k, list):
+                 k_l = k[l]
+             else:
+                 k_l = k
+
+             sp_batch = KANLayer(in_dim=width_in[l], out_dim=width_out[l + 1], num=grid_l, k=k_l,
+                                 noise_scale=noise_scale, scale_base_mu=scale_base_mu, scale_base_sigma=scale_base_sigma,
+                                 scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range,
+                                 sp_trainable=sp_trainable, sb_trainable=sb_trainable, sparse_init=sparse_init)
+             self.act_fun.append(sp_batch)
+
+         self.node_bias = []
+         self.node_scale = []
+         self.subnode_bias = []
+         self.subnode_scale = []
+
+         globals()['self.node_bias_0'] = torch.nn.Parameter(torch.zeros(3, 1)).requires_grad_(False)
+         exec('self.node_bias_0' + " = torch.nn.Parameter(torch.zeros(3,1)).requires_grad_(False)")
+
+         for l in range(self.depth):
+             exec(
+                 f'self.node_bias_{l} = torch.nn.Parameter(torch.zeros(width_in[l+1])).requires_grad_(affine_trainable)')
+             exec(
+                 f'self.node_scale_{l} = torch.nn.Parameter(torch.ones(width_in[l+1])).requires_grad_(affine_trainable)')
+             exec(
+                 f'self.subnode_bias_{l} = torch.nn.Parameter(torch.zeros(width_out[l+1])).requires_grad_(affine_trainable)')
+             exec(
+                 f'self.subnode_scale_{l} = torch.nn.Parameter(torch.ones(width_out[l+1])).requires_grad_(affine_trainable)')
+             exec(f'self.node_bias.append(self.node_bias_{l})')
+             exec(f'self.node_scale.append(self.node_scale_{l})')
+             exec(f'self.subnode_bias.append(self.subnode_bias_{l})')
+             exec(f'self.subnode_scale.append(self.subnode_scale_{l})')
+
+         self.act_fun = nn.ModuleList(self.act_fun)
+
+         self.grid = grid
+         self.k = k
+         self.base_fun = base_fun
+
+         ### initializing the symbolic front ###
+         self.symbolic_fun = []
+         for l in range(self.depth):
+             sb_batch = Symbolic_KANLayer(in_dim=width_in[l], out_dim=width_out[l + 1])
+             self.symbolic_fun.append(sb_batch)
+
+         self.symbolic_fun = nn.ModuleList(self.symbolic_fun)
+         self.symbolic_enabled = symbolic_enabled
+         self.affine_trainable = affine_trainable
+         self.sp_trainable = sp_trainable
+         self.sb_trainable = sb_trainable
+
+         self.save_act = save_act
+
+         self.node_scores = None
+         self.edge_scores = None
+         self.subnode_scores = None
+
+         self.cache_data = None
+         self.acts = None
+
+         self.auto_save = auto_save
+         self.state_id = 0
+         self.ckpt_path = ckpt_path
+         self.round = round
+
+         self.device = device
+         self.to(device)
+
+         if auto_save:
+             if first_init:
+                 if not os.path.exists(ckpt_path):
+                     # Create the directory
+                     os.makedirs(ckpt_path)
+                     print(f"checkpoint directory created: {ckpt_path}")
+                     print('saving model version 0.0')
+
+                 history_path = self.ckpt_path + '/history.txt'
+                 with open(history_path, 'w') as file:
+                     file.write(f'### Round {self.round} ###' + '\n')
+                     file.write('init => 0.0' + '\n')
+                 self.saveckpt(path=self.ckpt_path + '/' + '0.0')
+             else:
+                 self.state_id = state_id
+
+         self.input_id = torch.arange(self.width_in[0], )
+
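+     # A minimal construction sketch (the widths and arity here are arbitrary
+     # illustrations, not defaults): a plain-int width list is normalized to
+     # [n_add, n_mult] pairs by __init__, and with auto_save on, a fresh
+     # checkpoint directory gets version 0.0 written immediately:
+     # >>> model = MultKAN(width=[2, [3, 2], 1], mult_arity=2, seed=0)
+     # checkpoint directory created: ./model
+     # saving model version 0.0
+     # >>> model.width
+     # [[2, 0], [3, 2], [1, 0]]
+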
+     def to(self, device):
+         """
+         move the model to device
+
+         Args:
+         -----
+             device : str or device
+
+         Returns:
+         --------
+             self
+
+         Example
+         -------
+         # >>> from yms_kan import *
+         # >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         # >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
+         # >>> model.to(device)
+         """
+         super(MultKAN, self).to(device)
+         self.device = device
+
+         for kanlayer in self.act_fun:
+             kanlayer.to(device)
+
+         for symbolic_kanlayer in self.symbolic_fun:
+             symbolic_kanlayer.to(device)
+
+         return self
+
+     @property
+     def width_in(self):
+         """
+         The number of input nodes for each layer
+         """
+         width = self.width
+         width_in = [width[l][0] + width[l][1] for l in range(len(width))]
+         return width_in
+
+     @property
+     def width_out(self):
+         """
+         The number of output subnodes for each layer
+         """
+         width = self.width
+         if self.mult_homo == True:
+             width_out = [width[l][0] + self.mult_arity * width[l][1] for l in range(len(width))]
+         else:
+             width_out = [width[l][0] + int(np.sum(self.mult_arity[l])) for l in range(len(width))]
+         return width_out
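+
+     # A small sketch of the width accounting (values follow directly from the two
+     # properties above, for the hypothetical width spec [[2,0],[3,2],[1,0]] with
+     # homogeneous mult_arity=2):
+     # >>> model.width_in   # nodes per layer: n_add + n_mult
+     # [2, 5, 1]
+     # >>> model.width_out  # subnodes per layer: n_add + arity * n_mult
+     # [2, 7, 1]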
+
+     @property
+     def n_sum(self):
+         """
+         The number of addition nodes for each layer
+         """
+         width = self.width
+         n_sum = [width[l][0] for l in range(1, len(width) - 1)]
+         return n_sum
+
+     @property
+     def n_mult(self):
+         """
+         The number of multiplication nodes for each layer
+         """
+         width = self.width
+         n_mult = [width[l][1] for l in range(1, len(width) - 1)]
+         return n_mult
+
+     @property
+     def feature_score(self):
+         """
+         attribution scores for inputs
+         """
+         self.attribute()
+         if self.node_scores is None:
+             return None
+         else:
+             return self.node_scores[0]
+
+     def initialize_from_another_model(self, another_model, x):
+         """
+         initialize from another model of the same width, but whose 'grid' parameter can be different.
+         Note this is equivalent to refine() when we don't want to keep another_model.
+
+         Args:
+         -----
+             another_model : MultKAN
+             x : 2D torch.float
+
+         Returns:
+         --------
+             self
+
+         Example
+         -------
+         # >>> from yms_kan import *
+         # >>> model1 = KAN(width=[2,5,1], grid=3)
+         # >>> model2 = KAN(width=[2,5,1], grid=10)
+         # >>> x = torch.rand(100,2)
+         # >>> model2.initialize_from_another_model(model1, x)
+         """
+         another_model(x)  # get activations
+         batch = x.shape[0]
+
+         self.initialize_grid_from_another_model(another_model, x)
+
+         for l in range(self.depth):
+             spb = self.act_fun[l]
+
+             preacts = another_model.spline_preacts[l]
+             postsplines = another_model.spline_postsplines[l]
+             self.act_fun[l].coef.data = curve2coef(preacts[:, 0, :], postsplines.permute(0, 2, 1), spb.grid, k=spb.k)
+             self.act_fun[l].scale_base.data = another_model.act_fun[l].scale_base.data
+             self.act_fun[l].scale_sp.data = another_model.act_fun[l].scale_sp.data
+             self.act_fun[l].mask.data = another_model.act_fun[l].mask.data
+
+         for l in range(self.depth):
+             self.node_bias[l].data = another_model.node_bias[l].data
+             self.node_scale[l].data = another_model.node_scale[l].data
+
+             self.subnode_bias[l].data = another_model.subnode_bias[l].data
+             self.subnode_scale[l].data = another_model.subnode_scale[l].data
+
+         for l in range(self.depth):
+             self.symbolic_fun[l] = another_model.symbolic_fun[l]
+
+         return self.to(self.device)
+
+     def log_history(self, method_name):
+         """
+         log a state transition ('{round}.{state_id} => method => {round}.{state_id+1}') to history.txt
+         and save a checkpoint for the new state
+         """
+         if self.auto_save:
+             # save to log file
+             with open(self.ckpt_path + '/history.txt', 'a') as file:
+                 file.write(str(self.round) + '.' + str(self.state_id) + ' => ' + method_name + ' => ' + str(
+                     self.round) + '.' + str(self.state_id + 1) + '\n')
+
+             # update state_id
+             self.state_id += 1
+
+             # save to ckpt
+             self.saveckpt(path=self.ckpt_path + '/' + str(self.round) + '.' + str(self.state_id))
+             print('saving model version ' + str(self.round) + '.' + str(self.state_id))
+
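+     # The resulting history.txt is a plain-text chain of state transitions, e.g.
+     # (hypothetical contents after an init and two logged calls):
+     # ### Round 0 ###
+     # init => 0.0
+     # 0.0 => fit => 0.1
+     # 0.1 => fix_symbolic => 0.2
+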
+     def refine(self, new_grid):
+         """
+         grid refinement
+
+         Args:
+         -----
+             new_grid : int
+                 the number of grid intervals after refinement
+
+         Returns:
+         --------
+             a refined model : MultKAN
+
+         Example
+         -------
+         # >>> from yms_kan import *
+         # >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         # >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
+         # >>> print(model.grid)
+         # >>> x = torch.rand(100,2)
+         # >>> model.get_act(x)
+         # >>> model = model.refine(10)
+         # >>> print(model.grid)
+         checkpoint directory created: ./model
+         saving model version 0.0
+         5
+         saving model version 0.1
+         10
+         """
+
+         model_new = MultKAN(width=self.width,
+                             grid=new_grid,
+                             k=self.k,
+                             mult_arity=self.mult_arity,
+                             base_fun=self.base_fun_name,
+                             symbolic_enabled=self.symbolic_enabled,
+                             affine_trainable=self.affine_trainable,
+                             grid_eps=self.grid_eps,
+                             grid_range=self.grid_range,
+                             sp_trainable=self.sp_trainable,
+                             sb_trainable=self.sb_trainable,
+                             ckpt_path=self.ckpt_path,
+                             auto_save=True,
+                             first_init=False,
+                             state_id=self.state_id,
+                             round=self.round,
+                             device=self.device)
+
+         model_new.initialize_from_another_model(self, self.cache_data)
+         model_new.cache_data = self.cache_data
+         model_new.grid = new_grid
+
+         self.log_history('refine')
+         model_new.state_id += 1
+
+         return model_new.to(self.device)
+
+     def saveckpt(self, path='model'):
+         """
+         save the current model to files (configuration file and state file)
+
+         Args:
+         -----
+             path : str
+                 the path where checkpoints are saved
+
+         Returns:
+         --------
+             None
+
+         Example
+         -------
+         # >>> from yms_kan import *
+         # >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         # >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
+         # >>> model.saveckpt('./mark')
+         # There will be three files appearing in the current folder: mark_cache_data, mark_config.yml, mark_state
+         """
+
+         model = self
+
+         dic = dict(
+             width=model.width,
+             grid=model.grid,
+             k=model.k,
+             mult_arity=model.mult_arity,
+             base_fun_name=model.base_fun_name,
+             symbolic_enabled=model.symbolic_enabled,
+             affine_trainable=model.affine_trainable,
+             grid_eps=model.grid_eps,
+             grid_range=model.grid_range,
+             sp_trainable=model.sp_trainable,
+             sb_trainable=model.sb_trainable,
+             state_id=model.state_id,
+             auto_save=model.auto_save,
+             ckpt_path=model.ckpt_path,
+             round=model.round,
+             device=str(model.device)
+         )
+
+         if dic["device"].isdigit():
+             dic["device"] = int(model.device)
+
+         for i in range(model.depth):
+             dic[f'symbolic.funs_name.{i}'] = model.symbolic_fun[i].funs_name
+
+         with open(f'{path}_config.yml', 'w') as outfile:
+             yaml.dump(dic, outfile, default_flow_style=False)
+
+         torch.save(model.state_dict(), f'{path}_state')
+         torch.save(model.cache_data, f'{path}_cache_data')
+
+     @staticmethod
+     def loadckpt(path='model'):
+         """
+         load checkpoint from path
+
+         Args:
+         -----
+             path : str
+                 the path where checkpoints are saved
+
+         Returns:
+         --------
+             MultKAN
+
+         Example
+         -------
+         # >>> from yms_kan import *
+         # >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
+         # >>> model.saveckpt('./mark')
+         # >>> KAN.loadckpt('./mark')
+         """
+         with open(f'{path}_config.yml', 'r') as stream:
+             config = yaml.safe_load(stream)
+
+         state = torch.load(f'{path}_state', map_location=torch.device('cpu'), weights_only=False)
+
+         model_load = MultKAN(width=config['width'],
+                              grid=config['grid'],
+                              k=config['k'],
+                              mult_arity=config['mult_arity'],
+                              base_fun=config['base_fun_name'],
+                              symbolic_enabled=config['symbolic_enabled'],
+                              affine_trainable=config['affine_trainable'],
+                              grid_eps=config['grid_eps'],
+                              grid_range=config['grid_range'],
+                              sp_trainable=config['sp_trainable'],
+                              sb_trainable=config['sb_trainable'],
+                              state_id=config['state_id'],
+                              auto_save=config['auto_save'],
+                              first_init=False,
+                              ckpt_path=config['ckpt_path'],
+                              round=config['round'] + 1,
+                              device='cpu')
+
+         model_load.load_state_dict(state)
+         model_load.cache_data = torch.load(f'{path}_cache_data', map_location=torch.device('cpu'),
+                                            weights_only=False)
+
+         depth = len(model_load.width) - 1
+         for l in range(depth):
+             out_dim = model_load.symbolic_fun[l].out_dim
+             in_dim = model_load.symbolic_fun[l].in_dim
+             funs_name = config[f'symbolic.funs_name.{l}']
+             for j in range(out_dim):
+                 for i in range(in_dim):
+                     fun_name = funs_name[j][i]
+                     model_load.symbolic_fun[l].funs_name[j][i] = fun_name
+                     model_load.symbolic_fun[l].funs[j][i] = SYMBOLIC_LIB[fun_name][0]
+                     model_load.symbolic_fun[l].funs_sympy[j][i] = SYMBOLIC_LIB[fun_name][1]
+                     model_load.symbolic_fun[l].funs_avoid_singularity[j][i] = SYMBOLIC_LIB[fun_name][3]
+         model_load.device = config['device']
+         return model_load
+
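+     # A save/load round-trip sketch (the './mark' prefix is an arbitrary example):
+     # loadckpt rebuilds the model from '{path}_config.yml', '{path}_state' and
+     # '{path}_cache_data', bumping the round counter by one:
+     # >>> model.saveckpt('./mark')
+     # >>> model2 = MultKAN.loadckpt('./mark')
+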
+     def copy(self):
+         """
+         deepcopy the model via a temporary checkpoint
+
+         Args:
+         -----
+             None
+
+         Returns:
+         --------
+             MultKAN
+
+         Example
+         -------
+         # >>> from yms_kan import *
+         # >>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
+         # >>> model2 = model.copy()
+         # >>> model2.act_fun[0].coef.data *= 2
+         # >>> print(model2.act_fun[0].coef.data)
+         # >>> print(model.act_fun[0].coef.data)
+         """
+         path = 'copy_temp'
+         self.saveckpt(path)
+         return MultKAN.loadckpt(path)
+
+     def rewind(self, model_id):
+         """
+         rewind to an old version
+
+         Args:
+         -----
+             model_id : str
+                 in format '{a}.{b}' where a is the round number and b is the version number in that round
+
+         Returns:
+         --------
+             MultKAN
+
+         Example
+         -------
+         Please refer to the tutorials. API 12: Checkpoint, save & load model
+         """
+         self.round += 1
+         self.state_id = model_id.split('.')[-1]
+
+         history_path = self.ckpt_path + '/history.txt'
+         with open(history_path, 'a') as file:
+             file.write(f'### Round {self.round} ###' + '\n')
+
+         self.saveckpt(path=self.ckpt_path + '/' + f'{self.round}.{self.state_id}')
+
+         print(
+             'rewind to model version ' + f'{self.round - 1}.{self.state_id}' + ', renamed as ' + f'{self.round}.{self.state_id}')
+
+         return MultKAN.loadckpt(path=self.ckpt_path + '/' + str(model_id))
+
+     def checkout(self, model_id):
+         """
+         check out an old version
+
+         Args:
+         -----
+             model_id : str
+                 in format '{a}.{b}' where a is the round number and b is the version number in that round
+
+         Returns:
+         --------
+             MultKAN
+
+         Example
+         -------
+         Same usage as rewind, except that checkout does not change the model's state
+         """
+         return MultKAN.loadckpt(path=self.ckpt_path + '/' + str(model_id))
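+
+     # A version-control sketch (the model ids are examples): checkout restores a
+     # version without touching the history, while rewind starts a new round from it:
+     # >>> model = model.checkout('0.1')  # read-only restore of version 0.1
+     # >>> model = model.rewind('0.1')    # restore 0.1 and begin round 1 from it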
+
+     def update_grid_from_samples(self, x):
+         """
+         update grid from samples
+
+         Args:
+         -----
+             x : 2D torch.tensor
+                 inputs
+
+         Returns:
+         --------
+             None
+
+         Example
+         -------
+         # >>> from yms_kan import *
+         # >>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
+         # >>> print(model.act_fun[0].grid)
+         # >>> x = torch.linspace(-10,10,steps=101)[:,None]
+         # >>> model.update_grid_from_samples(x)
+         # >>> print(model.act_fun[0].grid)
+         """
+         for l in range(self.depth):
+             self.get_act(x)
+             self.act_fun[l].update_grid_from_samples(self.acts[l])
+
+     def update_grid(self, x):
+         """
+         call update_grid_from_samples. This seems unnecessary, but we retain it for the sake of classes that might inherit from MultKAN
+         """
+         self.update_grid_from_samples(x)
+
+     def initialize_grid_from_another_model(self, model, x):
+         """
+         initialize grid from another model
+
+         Args:
+         -----
+             model : MultKAN
+                 parent model
+             x : 2D torch.tensor
+                 inputs
+
+         Returns:
+         --------
+             None
+
+         Example
+         -------
+         # >>> from yms_kan import *
+         # >>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
+         # >>> print(model.act_fun[0].grid)
+         # >>> x = torch.linspace(-10,10,steps=101)[:,None]
+         # >>> model2 = KAN(width=[1,1], grid=10, k=3, seed=0)
+         # >>> model2.initialize_grid_from_another_model(model, x)
+         # >>> print(model2.act_fun[0].grid)
+         """
+         model(x)
+         for l in range(self.depth):
+             self.act_fun[l].initialize_grid_from_parent(model.act_fun[l], model.acts[l])
+
+     def forward(self, x, singularity_avoiding=False, y_th=10.):
+         """
+         forward pass
+
+         Args:
+         -----
+             x : 2D torch.tensor
+                 inputs
+             singularity_avoiding : bool
+                 whether to avoid singularity for the symbolic branch
+             y_th : float
+                 the threshold for singularity
+
+         Returns:
+         --------
+             2D torch.tensor
+                 outputs
+
+         Example1
+         --------
+         # >>> from yms_kan import *
+         # >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
+         # >>> x = torch.rand(100,2)
+         # >>> model(x).shape
+
+         Example2
+         --------
+         # >>> from yms_kan import *
+         # >>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
+         # >>> x = torch.tensor([[1],[-0.01]])
+         # >>> model.fix_symbolic(0,0,0,'log',fit_params_bool=False)
+         # >>> print(model(x))
+         # >>> print(model(x, singularity_avoiding=True))
+         # >>> print(model(x, singularity_avoiding=True, y_th=1.))
+         """
+         x = x[:, self.input_id.long()]
+         assert x.shape[1] == self.width_in[0]
+
+         # cache data
+         self.cache_data = x
+
+         self.acts = []  # shape ([batch, n0], [batch, n1], ..., [batch, n_L])
+         self.acts_premult = []
+         self.spline_preacts = []
+         self.spline_postsplines = []
+         self.spline_postacts = []
+         self.acts_scale = []
+         self.acts_scale_spline = []
+         self.subnode_actscale = []
+         self.edge_actscale = []
+         # self.neurons_scale = []
+
+         self.acts.append(x)  # acts shape: (batch, width[l])
+
+         for l in range(self.depth):
+
+             x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
+             # print(preacts, postacts_numerical, postspline)
+
+             if self.symbolic_enabled:
+                 x_symbolic, postacts_symbolic = self.symbolic_fun[l](x, singularity_avoiding=singularity_avoiding,
+                                                                      y_th=y_th)
+             else:
+                 x_symbolic = 0.
+                 postacts_symbolic = 0.
+
+             x = x_numerical + x_symbolic
+
+             if self.save_act:
+                 # save subnode_scale
+                 self.subnode_actscale.append(torch.std(x, dim=0).detach())
+
+             # subnode affine transform
+             x = self.subnode_scale[l][None, :] * x + self.subnode_bias[l][None, :]
+
+             if self.save_act:
+                 postacts = postacts_numerical + postacts_symbolic
+
+                 # self.neurons_scale.append(torch.mean(torch.abs(x), dim=0))
+                 # grid_reshape = self.act_fun[l].grid.reshape(self.width_out[l + 1], self.width_in[l], -1)
+                 input_range = torch.std(preacts, dim=0) + 0.1
+                 output_range_spline = torch.std(postacts_numerical,
+                                                 dim=0)  # for training, only penalize the spline part
+                 output_range = torch.std(postacts,
+                                          dim=0)  # for visualization, include the contribution from both spline + symbolic
+                 # save edge_scale
+                 self.edge_actscale.append(output_range)
+
+                 self.acts_scale.append((output_range / input_range).detach())
+                 self.acts_scale_spline.append(output_range_spline / input_range)
+                 self.spline_preacts.append(preacts.detach())
+                 self.spline_postacts.append(postacts.detach())
+                 self.spline_postsplines.append(postspline.detach())
+
+                 self.acts_premult.append(x.detach())
+
+             # multiplication
+             dim_sum = self.width[l + 1][0]
+             dim_mult = self.width[l + 1][1]
+
+             if self.mult_homo:
+                 for i in range(self.mult_arity - 1):
+                     if i == 0:
+                         x_mult = x[:, dim_sum::self.mult_arity] * x[:, dim_sum + 1::self.mult_arity]
+                     else:
+                         x_mult = x_mult * x[:, dim_sum + i + 1::self.mult_arity]
+
+             else:
+                 for j in range(dim_mult):
+                     acml_id = dim_sum + np.sum(self.mult_arity[l + 1][:j])
+                     for i in range(self.mult_arity[l + 1][j] - 1):
+                         if i == 0:
+                             x_mult_j = x[:, [acml_id]] * x[:, [acml_id + 1]]
+                         else:
+                             x_mult_j = x_mult_j * x[:, [acml_id + i + 1]]
+
+                     if j == 0:
+                         x_mult = x_mult_j
+                     else:
+                         x_mult = torch.cat([x_mult, x_mult_j], dim=1)
+
+             if self.width[l + 1][1] > 0:
+                 x = torch.cat([x[:, :dim_sum], x_mult], dim=1)
+
+             # x = x + self.biases[l].weight
+             # node affine transform
+             x = self.node_scale[l][None, :] * x + self.node_bias[l][None, :]
+
+             self.acts.append(x.detach())
+
+         return x
+
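+     # A sketch of the homogeneous multiplication step above, for one layer with
+     # dim_sum=2 addition subnodes, dim_mult=2 multiplication nodes and mult_arity=2
+     # (so the subnode layout is [s0, s1, m0a, m0b, m1a, m1b]):
+     # >>> x_mult = x[:, 2::2] * x[:, 3::2]          # -> [m0a*m0b, m1a*m1b]
+     # >>> x = torch.cat([x[:, :2], x_mult], dim=1)  # -> [s0, s1, m0, m1]
+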
+     def set_mode(self, l, i, j, mode, mask_n=None):
+         """
+         set the mode of the (l,i,j) activation: 's' for symbolic only, 'n' for numerical (spline) only,
+         'sn'/'ns' for both branches; anything else masks the activation out entirely.
+         """
+         if mode == "s":
+             mask_n = 0.
+             mask_s = 1.
+         elif mode == "n":
+             mask_n = 1.
+             mask_s = 0.
+         elif mode == "sn" or mode == "ns":
+             if mask_n is None:
+                 mask_n = 1.
+             else:
+                 mask_n = mask_n
+             mask_s = 1.
+         else:
+             mask_n = 0.
+             mask_s = 0.
+
+         self.act_fun[l].mask.data[i][j] = mask_n
+         self.symbolic_fun[l].mask.data[j, i] = mask_s
+
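+     # A mask sketch: mode 'n' routes edge (l,i,j) through the numerical (spline)
+     # branch only (note the transposed indexing of the two masks):
+     # >>> model.set_mode(0, 0, 1, 'n')
+     # >>> model.act_fun[0].mask[0, 1], model.symbolic_fun[0].mask[1, 0]
+     # (tensor(1.), tensor(0.))
+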
+     def fix_symbolic(self, l, i, j, fun_name, fit_params_bool=True, a_range=(-10, 10), b_range=(-10, 10), verbose=True,
+                      random=False, log_history=True):
+         '''
+         set the (l,i,j) activation to be symbolic (specified by fun_name)
+
+         Args:
+         -----
+             l : int
+                 layer index
+             i : int
+                 input neuron index
+             j : int
+                 output neuron index
+             fun_name : str
+                 function name
+             fit_params_bool : bool
+                 obtain affine parameters through fitting (True) or set default values (False)
+             a_range : tuple
+                 sweeping range of a
+             b_range : tuple
+                 sweeping range of b
+             verbose : bool
+                 If True, more information is printed.
+             random : bool
+                 initialize affine parameters randomly or as [1,0,1,0]
+             log_history : bool
+                 indicate whether to log history when the function is called
+
+         Returns:
+         --------
+             None or r2 (coefficient of determination)
+
+         Example 1
+         ---------
+         >>> # when fit_params_bool = False
+         >>> model = KAN(width=[2,5,1], grid=5, k=3)
+         >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=False)
+         >>> print(model.act_fun[0].mask.reshape(2,5))
+         >>> print(model.symbolic_fun[0].mask.reshape(2,5))
+
+         Example 2
+         ---------
+         >>> # when fit_params_bool = True
+         >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=1.)
+         >>> x = torch.normal(0,1,size=(100,2))
+         >>> model(x)  # obtain activations (otherwise the model does not have the attribute acts)
+         >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=True)
+         >>> print(model.act_fun[0].mask.reshape(2,5))
+         >>> print(model.symbolic_fun[0].mask.reshape(2,5))
+         '''
+         if not fit_params_bool:
+             self.symbolic_fun[l].fix_symbolic(i, j, fun_name, verbose=verbose, random=random)
+             r2 = None
+         else:
+             x = self.acts[l][:, i]
+             mask = self.act_fun[l].mask
+             y = self.spline_postacts[l][:, j, i]
+             # y = self.postacts[l][:, j, i]
+             r2 = self.symbolic_fun[l].fix_symbolic(i, j, fun_name, x, y, a_range=a_range, b_range=b_range,
+                                                    verbose=verbose)
+             if mask[i, j] == 0:
+                 r2 = - 1e8
+         self.set_mode(l, i, j, mode="s")
+
+         if log_history:
+             self.log_history('fix_symbolic')
+         return r2
+
+     def unfix_symbolic(self, l, i, j, log_history=True):
+         """
+         unfix the (l,i,j) activation function.
+         """
+         self.set_mode(l, i, j, mode="n")
+         self.symbolic_fun[l].funs_name[j][i] = "0"
+         if log_history:
+             self.log_history('unfix_symbolic')
+
+     def unfix_symbolic_all(self, log_history=True):
+         '''
+         unfix all activation functions.
+         '''
+         for l in range(len(self.width) - 1):
+             for i in range(self.width_in[l]):
+                 for j in range(self.width_out[l + 1]):
+                     self.unfix_symbolic(l, i, j, log_history)
+
+     def get_range(self, l, i, j, verbose=True):
+         '''
+         Get the input range and output range of the (l,i,j) activation
+
+         Args:
+         -----
+             l : int
+                 layer index
+             i : int
+                 input neuron index
+             j : int
+                 output neuron index
+
+         Returns:
+         --------
+             x_min : float
+                 minimum of input
+             x_max : float
+                 maximum of input
+             y_min : float
+                 minimum of output
+             y_max : float
+                 maximum of output
+
+         Example
+         -------
+         >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
+         >>> x = torch.normal(0,1,size=(100,2))
+         >>> model(x)  # do a forward pass to obtain model.acts
+         >>> model.get_range(0,0,0)
+         '''
+         x = self.spline_preacts[l][:, j, i]
+         y = self.spline_postacts[l][:, j, i]
+         x_min = torch.min(x).cpu().detach().numpy()
+         x_max = torch.max(x).cpu().detach().numpy()
+         y_min = torch.min(y).cpu().detach().numpy()
+         y_max = torch.max(y).cpu().detach().numpy()
+         if verbose:
+             print('x range: [' + '%.2f' % x_min, ',', '%.2f' % x_max, ']')
+             print('y range: [' + '%.2f' % y_min, ',', '%.2f' % y_max, ']')
+         return x_min, x_max, y_min, y_max
+
+     def _draw_operator_symbol(self, ax, fig, x, y, symbol_type, scale, size):
+         """helper that draws an operator symbol (sum/mult) on the figure"""
+         path = os.path.join(os.path.dirname(__file__), 'assets', 'img', f'{symbol_type}_symbol.png')
+         try:
+             with Image.open(path) as img:
+                 trans = ax.transData.transform
+                 inv_trans = fig.transFigure.inverted().transform
+                 bbox = [x - size / 2, y - size / 2, size, size]
+                 bbox = inv_trans.transform_bbox(trans(bbox))
+                 ax.imshow(img, extent=bbox)
+         except Exception as e:
+             print(f"failed to load symbol {symbol_type}: {str(e)}")
+
+     def _draw_variable_labels(self, ax, variables, layer, scale, varscale, y_offset, y0, z0):
+         """helper that draws variable labels"""
+         n = self.width_in[layer]
+         for i, var in enumerate(variables):
+             x = 1 / (2 * n) + i / n
+             y = layer * (y0 + z0) + y_offset
+             label = f'${latex(var)}$' if isinstance(var, sympy.Expr) else var
+             ax.text(x, y, label, fontsize=15 * scale * varscale,
+                     ha='center', va='center', transform=ax.transData)
+
+     def plot(self, folder="./figures", beta=3, metric='backward', scale=0.5, tick=False, sample=False, in_vars=None,
+              out_vars=None, title=None, varscale=1.0):
+         '''
+         plot KAN
+
+         Args:
+         -----
+             folder : str
+                 the folder to store pngs
+             beta : float
+                 positive number. control the transparency of each activation. transparency = tanh(beta*l1).
+             metric : str
+                 'forward_n', 'forward_u', or 'backward'. the score used to set edge transparency. Default: 'backward' (attribution scores).
+             scale : float
+                 control the size of the diagram
+             tick : bool
+                 If True, show ticks for the input/output range of each activation.
+             sample : bool
+                 If True, scatter the cached samples on top of each activation curve.
+             in_vars : None or list of str
+                 the name(s) of input variables
+             out_vars : None or list of str
+                 the name(s) of output variables
+             title : None or str
+                 title
+             varscale : float
+                 the size of input variables
+
+         Returns:
+         --------
+             Figure
+
+         Example
+         -------
+         >>> # see more interactive examples in demos
+         >>> model = KAN(width=[2,3,1], grid=3, k=3, noise_scale=1.0)
+         >>> x = torch.normal(0,1,size=(100,2))
+         >>> model(x)  # do a forward pass to obtain model.acts
+         >>> model.plot()
+         '''
+         global Symbol
+
+         if not self.save_act:
+             print('cannot plot since data are not saved. Set save_act=True first.')
+
+         if self.acts is None:
+             if self.cache_data is None:
+                 raise Exception('model hasn\'t seen any data yet.')
+             self.forward(self.cache_data)
+
+         if metric == 'backward':
+             self.attribute()
+
+         if not os.path.exists(folder):
+             os.makedirs(folder)
+
+         depth = len(self.width) - 1
+         for l in range(depth):
+             w_large = 2.0
+             for i in range(self.width_in[l]):
+                 for j in range(self.width_out[l + 1]):
+                     rank = torch.argsort(self.acts[l][:, i])
+                     fig, ax = plt.subplots(figsize=(w_large, w_large))
+
+                     num = rank.shape[0]
+                     symbolic_mask = self.symbolic_fun[l].mask[j][i]
+                     numeric_mask = self.act_fun[l].mask[i][j]
+                     if symbolic_mask > 0. and numeric_mask > 0.:
+                         color = 'purple'
+                         alpha_mask = 1
+                     if symbolic_mask > 0. and numeric_mask == 0.:
+                         color = "red"
+                         alpha_mask = 1
+                     if symbolic_mask == 0. and numeric_mask > 0.:
+                         color = "black"
+                         alpha_mask = 1
+                     if symbolic_mask == 0. and numeric_mask == 0.:
+                         color = "white"
+                         alpha_mask = 0
+
+                     if tick:
+                         ax.tick_params(axis="y", direction="in", pad=-22, labelsize=50)
+                         ax.tick_params(axis="x", direction="in", pad=-15, labelsize=50)
+                         x_min, x_max, y_min, y_max = self.get_range(l, i, j, verbose=False)
+                         plt.xticks([x_min, x_max], ['%2.f' % x_min, '%2.f' % x_max])
+                         plt.yticks([y_min, y_max], ['%2.f' % y_min, '%2.f' % y_max])
+                     else:
+                         plt.xticks([])
+                         plt.yticks([])
+                     if alpha_mask == 1:
+                         plt.gca().patch.set_edgecolor('black')
+                     else:
+                         plt.gca().patch.set_edgecolor('white')
+                     plt.gca().patch.set_linewidth(1.5)
+                     # plt.axis('off')
+
+                     plt.plot(self.acts[l][:, i][rank].cpu().detach().numpy(),
+                              self.spline_postacts[l][:, j, i][rank].cpu().detach().numpy(), color=color, lw=5)
+                     if sample:
+                         plt.scatter(self.acts[l][:, i][rank].cpu().detach().numpy(),
+                                     self.spline_postacts[l][:, j, i][rank].cpu().detach().numpy(), color=color,
+                                     s=400 * scale ** 2)
+                     plt.gca().spines[:].set_color(color)
+
+                     plt.savefig(f'{folder}/sp_{l}_{i}_{j}.png', bbox_inches="tight", dpi=400)
+                     plt.close()
+
+         def score2alpha(score):
+             return np.tanh(beta * score)
+
+         if metric == 'forward_n':
+             scores = self.acts_scale
+         elif metric == 'forward_u':
+             scores = self.edge_actscale
+         elif metric == 'backward':
+             scores = self.edge_scores
+         else:
+             raise Exception(f'metric = \'{metric}\' not recognized')
+
+         alpha = [score2alpha(score.cpu().detach().numpy()) for score in scores]
+
+         # draw skeleton
+         width = np.array(self.width)
+         width_in = np.array(self.width_in)
+         width_out = np.array(self.width_out)
+         A = 1
+         y0 = 0.3  # height: from input to pre-mult
+         z0 = 0.1  # height: from pre-mult to post-mult (input of next layer)
+
+         neuron_depth = len(width)
+         min_spacing = A / np.maximum(np.max(width_out), 5)
+
+         max_neuron = np.max(width_out)
+         max_num_weights = np.max(width_in[:-1] * width_out[1:])
+         y1 = 0.4 / np.maximum(max_num_weights, 5)  # size (height/width) of 1D function diagrams
+         y2 = 0.15 / np.maximum(max_neuron, 5)  # size (height/width) of operations (sum and mult)
+
+         fig, ax = plt.subplots(figsize=(10 * scale, 10 * scale * (neuron_depth - 1) * (y0 + z0)))
+         # fig, ax = plt.subplots(figsize=(5,5*(neuron_depth-1)*y0))
+
+         # -- Transformation functions
+         DC_to_FC = ax.transData.transform
+         FC_to_NFC = fig.transFigure.inverted().transform
+         # -- Take data coordinates and transform them to normalized figure coordinates
+         DC_to_NFC = lambda x: FC_to_NFC(DC_to_FC(x))
+
+         # plot scatters and lines
+         for l in range(neuron_depth):
+
+             n = width_in[l]
+
+             # scatters
+             for i in range(n):
+                 plt.scatter(1 / (2 * n) + i / n, l * (y0 + z0), s=min_spacing ** 2 * 10000 * scale ** 2, color='black')
+
+             # plot connections (input to pre-mult)
+             for i in range(n):
+                 if l < neuron_depth - 1:
+                     n_next = width_out[l + 1]
+                     N = n * n_next
+                     for j in range(n_next):
+                         id_ = i * n_next + j
+
+                         symbol_mask = self.symbolic_fun[l].mask[j][i]
+                         numerical_mask = self.act_fun[l].mask[i][j]
+                         if symbol_mask == 1. and numerical_mask > 0.:
+                             color = 'purple'
+                             alpha_mask = 1.
+                         if symbol_mask == 1. and numerical_mask == 0.:
+                             color = "red"
+                             alpha_mask = 1.
+                         if symbol_mask == 0. and numerical_mask == 1.:
+                             color = "black"
+                             alpha_mask = 1.
+                         if symbol_mask == 0. and numerical_mask == 0.:
+                             color = "white"
+                             alpha_mask = 0.
+
+                         plt.plot([1 / (2 * n) + i / n, 1 / (2 * N) + id_ / N],
+                                  [l * (y0 + z0), l * (y0 + z0) + y0 / 2 - y1], color=color, lw=2 * scale,
+                                  alpha=alpha[l][j][i] * alpha_mask)
+                         plt.plot([1 / (2 * N) + id_ / N, 1 / (2 * n_next) + j / n_next],
+                                  [l * (y0 + z0) + y0 / 2 + y1, l * (y0 + z0) + y0], color=color, lw=2 * scale,
+                                  alpha=alpha[l][j][i] * alpha_mask)
+
+             # plot connections (pre-mult to post-mult, post-mult = next-layer input)
+             if l < neuron_depth - 1:
+                 n_in = width_out[l + 1]
+                 n_out = width_in[l + 1]
+                 mult_id = 0
+                 for i in range(n_in):
+                     if i < width[l + 1][0]:
+                         j = i
+                     else:
+                         if i == width[l + 1][0]:
+                             if isinstance(self.mult_arity, int):
+                                 ma = self.mult_arity
+                             else:
+                                 ma = self.mult_arity[l + 1][mult_id]
+                             current_mult_arity = ma
+                         if current_mult_arity == 0:
+                             mult_id += 1
+                             if isinstance(self.mult_arity, int):
+                                 ma = self.mult_arity
+                             else:
+                                 ma = self.mult_arity[l + 1][mult_id]
+                             current_mult_arity = ma
+                         j = width[l + 1][0] + mult_id
+                         current_mult_arity -= 1
+                         # j = (i-width[l+1][0])//self.mult_arity + width[l+1][0]
+                     plt.plot([1 / (2 * n_in) + i / n_in, 1 / (2 * n_out) + j / n_out],
+                              [l * (y0 + z0) + y0, (l + 1) * (y0 + z0)], color='black', lw=2 * scale)
+
+         plt.xlim(0, 1)
+         plt.ylim(-0.1 * (y0 + z0), (neuron_depth - 1 + 0.1) * (y0 + z0))
+
+         plt.axis('off')
+
+         for l in range(neuron_depth - 1):
+             # plot splines
+             n = width_in[l]
+             for i in range(n):
+                 n_next = width_out[l + 1]
+                 N = n * n_next
+                 for j in range(n_next):
+                     id_ = i * n_next + j
+                     im = plt.imread(f'{folder}/sp_{l}_{i}_{j}.png')
+                     left = DC_to_NFC([1 / (2 * N) + id_ / N - y1, 0])[0]
+                     right = DC_to_NFC([1 / (2 * N) + id_ / N + y1, 0])[0]
+                     bottom = DC_to_NFC([0, l * (y0 + z0) + y0 / 2 - y1])[1]
+                     up = DC_to_NFC([0, l * (y0 + z0) + y0 / 2 + y1])[1]
+                     newax = fig.add_axes([left, bottom, right - left, up - bottom])
+                     # newax = fig.add_axes([1/(2*N)+id_/N-y1, (l+1/2)*y0-y1, y1, y1], anchor='NE')
+                     newax.imshow(im, alpha=alpha[l][j][i])
+                     newax.axis('off')
+
+             # plot sum symbols
+             N = n = width_out[l + 1]
+             for j in range(n):
+                 id_ = j
+                 path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/sum_symbol.png"
+                 im = plt.imread(path)
+                 left = DC_to_NFC([1 / (2 * N) + id_ / N - y2, 0])[0]
+                 right = DC_to_NFC([1 / (2 * N) + id_ / N + y2, 0])[0]
+                 bottom = DC_to_NFC([0, l * (y0 + z0) + y0 - y2])[1]
+                 up = DC_to_NFC([0, l * (y0 + z0) + y0 + y2])[1]
+                 newax = fig.add_axes([left, bottom, right - left, up - bottom])
+                 newax.imshow(im)
+                 newax.axis('off')
+
+             # plot mult symbols
+             N = n = width_in[l + 1]
+             n_sum = width[l + 1][0]
+             n_mult = width[l + 1][1]
+             for j in range(n_mult):
+                 id_ = j + n_sum
+                 path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/mult_symbol.png"
+                 im = plt.imread(path)
+                 left = DC_to_NFC([1 / (2 * N) + id_ / N - y2, 0])[0]
+                 right = DC_to_NFC([1 / (2 * N) + id_ / N + y2, 0])[0]
+                 bottom = DC_to_NFC([0, (l + 1) * (y0 + z0) - y2])[1]
+                 up = DC_to_NFC([0, (l + 1) * (y0 + z0) + y2])[1]
+                 newax = fig.add_axes([left, bottom, right - left, up - bottom])
+                 newax.imshow(im)
+                 newax.axis('off')
+
+         if in_vars is not None:
+             n = self.width_in[0]
+             for i in range(n):
+                 if isinstance(in_vars[i], sympy.Expr):
+                     plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), -0.1, f'${latex(in_vars[i])}$',
+                                                  fontsize=40 * scale * varscale, horizontalalignment='center',
+                                                  verticalalignment='center')
+                 else:
+                     plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), -0.1, in_vars[i],
+                                                  fontsize=40 * scale * varscale, horizontalalignment='center',
+                                                  verticalalignment='center')
+
+         if out_vars is not None:
+             n = self.width_in[-1]
+             for i in range(n):
+                 if isinstance(out_vars[i], sympy.Expr):
+                     plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), (y0 + z0) * (len(self.width) - 1) + 0.15,
+                                                  f'${latex(out_vars[i])}$', fontsize=40 * scale * varscale,
+                                                  horizontalalignment='center', verticalalignment='center')
+                 else:
+                     plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), (y0 + z0) * (len(self.width) - 1) + 0.15,
+                                                  out_vars[i], fontsize=40 * scale * varscale,
+                                                  horizontalalignment='center', verticalalignment='center')
+
+         if title is not None:
+             plt.gcf().get_axes()[0].text(0.5, (y0 + z0) * (len(self.width) - 1) + 0.3, title, fontsize=40 * scale,
+                                          horizontalalignment='center', verticalalignment='center')
+
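+     # A usage sketch for the metric argument (assumes activations were cached by a
+     # forward pass): 'backward' uses attribution scores, while the 'forward_*'
+     # variants use the activation scales collected during forward:
+     # >>> model.plot(metric='forward_n', beta=10, scale=0.8)
+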
+     def reg(self, reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff):
+         """
+         Get regularization
+
+         Args:
+         -----
+             reg_metric : str
+                 the regularization metric: 'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'
+             lamb_l1 : float
+                 l1 penalty strength
+             lamb_entropy : float
+                 entropy penalty strength
+             lamb_coef : float
+                 coefficient penalty strength
+             lamb_coefdiff : float
+                 coefficient smoothness strength
+
+         Returns:
+         --------
+             reg_ : torch.float
+
+         Example
+         -------
+         >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
+         >>> x = torch.rand(100,2)
+         >>> model.get_act(x)
+         >>> model.reg('edge_forward_spline_n', 1.0, 2.0, 1.0, 1.0)
+         """
+         if reg_metric == 'edge_forward_spline_n':
+             acts_scale = self.acts_scale_spline
+
+         elif reg_metric == 'edge_forward_sum':
+             acts_scale = self.acts_scale
+
+         elif reg_metric == 'edge_forward_spline_u':
+             acts_scale = self.edge_actscale
+
+         elif reg_metric == 'edge_backward':
+             acts_scale = self.edge_scores
+
+         elif reg_metric == 'node_backward':
+             acts_scale = self.node_attribute_scores
+
+         else:
+             raise Exception(f'reg_metric = {reg_metric} not recognized!')
+
+         reg_ = 0.
+         for i in range(len(acts_scale)):
+             vec = acts_scale[i]
+
+             l1 = torch.sum(vec)
+             p_row = vec / (torch.sum(vec, dim=1, keepdim=True) + 1)
+             p_col = vec / (torch.sum(vec, dim=0, keepdim=True) + 1)
+             entropy_row = - torch.mean(torch.sum(p_row * torch.log2(p_row + 1e-4), dim=1))
+             entropy_col = - torch.mean(torch.sum(p_col * torch.log2(p_col + 1e-4), dim=0))
+             reg_ += lamb_l1 * l1 + lamb_entropy * (entropy_row + entropy_col)  # both l1 and entropy
+
+         # regularize coefficients to encourage splines to be zero
+         for i in range(len(self.act_fun)):
+             coeff_l1 = torch.sum(torch.mean(torch.abs(self.act_fun[i].coef), dim=1))
+             coeff_diff_l1 = torch.sum(torch.mean(torch.abs(torch.diff(self.act_fun[i].coef)), dim=1))
+             reg_ += lamb_coef * coeff_l1 + lamb_coefdiff * coeff_diff_l1
+
+         return reg_
+
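+     # A toy sketch of the edge penalty above: for one layer with activation scales
+     # vec = [[1., 0.], [0., 1.]], the l1 term is sum(vec) = 2; the row/column
+     # entropies stay small because each row/column is dominated by a single edge,
+     # whereas a diffuse matrix like [[.5, .5], [.5, .5]] has the same l1 but larger
+     # entropy, so lamb_entropy pushes each neuron toward using fewer edges:
+     # >>> vec = torch.tensor([[1., 0.], [0., 1.]])
+     # >>> p_row = vec / (torch.sum(vec, dim=1, keepdim=True) + 1)
+     # >>> -torch.mean(torch.sum(p_row * torch.log2(p_row + 1e-4), dim=1))
+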
+     def get_reg(self, reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff):
+         """
+         Get regularization. This wrapper seems unnecessary, but a class inheriting from MultKAN may want to override get_reg while leaving reg untouched.
+         """
+         return self.reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
+
+     def disable_symbolic_in_fit(self, lamb):
+         """
+         during fitting, turn off save_act when lamb == 0 and disable the symbolic branch when
+         none of the symbolic functions is active; returns the old flags (save_act, symbolic_enabled)
+         so they can be restored after fitting
+         """
+         old_save_act = self.save_act
+         if lamb == 0.:
+             self.save_act = False
+
+         # skip symbolic if no symbolic is turned on
+         depth = len(self.symbolic_fun)
+         no_symbolic = True
+         for l in range(depth):
+             no_symbolic *= torch.sum(torch.abs(self.symbolic_fun[l].mask)) == 0
+
+         old_symbolic_enabled = self.symbolic_enabled
+
+         if no_symbolic:
+             self.symbolic_enabled = False
+
+         return old_save_act, old_symbolic_enabled
+
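+     # The save/restore pattern used by fit() (a sketch): capture the old flags
+     # before training, then restore symbolic_enabled afterwards:
+     # >>> old_save_act, old_symbolic_enabled = model.disable_symbolic_in_fit(lamb=0.)
+     # >>> ...  # training loop
+     # >>> model.symbolic_enabled = old_symbolic_enabled
+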
+     def get_params(self):
+         """
+         Get parameters
+         """
+         return self.parameters()
+
+ def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0.,
1455
+ lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
1456
+ stop_grid_update_step=50, batch=-1,
1457
+ metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
1458
+ singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n', display_metrics=None):
1459
+ '''
1460
+ training
1461
+
1462
+ Args:
1463
+ -----
1464
+ dataset : dic
1465
+ contains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']
1466
+ opt : str
1467
+ "LBFGS" or "Adam"
1468
+ steps : int
1469
+ training steps
1470
+ log : int
1471
+ logging frequency
1472
+ lamb : float
1473
+ overall penalty strength
1474
+ lamb_l1 : float
1475
+ l1 penalty strength
1476
+ lamb_entropy : float
1477
+ entropy penalty strength
1478
+ lamb_coef : float
1479
+ coefficient magnitude penalty strength
1480
+ lamb_coefdiff : float
1481
+ difference of nearby coefficits (smoothness) penalty strength
1482
+ update_grid : bool
1483
+ If True, update grid regularly before stop_grid_update_step
1484
+ grid_update_num : int
1485
+ the number of grid updates before stop_grid_update_step
1486
+ start_grid_update_step : int
1487
+ no grid updates before this training step
1488
+ stop_grid_update_step : int
1489
+ no grid updates after this training step
1490
+ loss_fn : function
1491
+ loss function
1492
+ lr : float
1493
+ learning rate
1494
+ batch : int
1495
+ batch size, if -1 then full.
1496
+ save_fig_freq : int
1497
+ save figure every (save_fig_freq) steps
1498
+ singularity_avoiding : bool
1499
+ indicate whether to avoid singularity for the symbolic part
1500
+ y_th : float
1501
+ singularity threshold (anything above the threshold is considered singular and is softened in some ways)
1502
+ reg_metric : str
1503
+ regularization metric. Choose from {'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'}
1504
+ metrics : a list of metrics (as functions)
1505
+ the metrics to be computed in training
1506
+ display_metrics : a list of functions
1507
+ the metric to be displayed in tqdm progress bar
1508
+
1509
+ Returns:
1510
+ --------
1511
+ results : dic
1512
+ results['train_loss'], 1D array of training losses (RMSE)
1513
+ results['test_loss'], 1D array of test losses (RMSE)
1514
+ results['reg'], 1D array of regularization
1515
+ other metrics specified in metrics
1516
+
1517
+ Example
1518
+ -------
1519
+ # >>> from yms_kan import *
1520
+ # >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
1521
+ # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
1522
+ # >>> dataset = create_dataset(f, n_var=2)
1523
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
1524
+ # >>> model.plot()
1525
+ # Most examples in toturals involve the fit() method. Please check them for useness.
1526
+ '''
1527
+
1528
+ if lamb > 0. and not self.save_act:
1529
+ print('setting lamb=0. If you want to set lamb > 0, set self.save_act=True')
1530
+
1531
+ old_save_act, old_symbolic_enabled = self.disable_symbolic_in_fit(lamb)
1532
+
1533
+ pbar = tqdm(range(steps), desc='description', ncols=100)
1534
+
1535
+ if loss_fn is None:
1536
+ loss_fn = loss_fn_eval = lambda x, y: torch.mean((x - y) ** 2)
1537
+ else:
1538
+ loss_fn = loss_fn_eval = loss_fn
1539
+
1540
+ grid_update_freq = int(stop_grid_update_step / grid_update_num)
1541
+
1542
+ if opt == "Adam":
1543
+ optimizer = torch.optim.Adam(self.get_params(), lr=lr)
1544
+ elif opt == "LBFGS":
1545
+ optimizer = LBFGS(self.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
1546
+ tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
1547
+
1548
+ results = {'train_loss': [], 'test_loss': [], 'reg': []}
1549
+ if metrics is not None:
1550
+ for i in range(len(metrics)):
1551
+ results[metrics[i].__name__] = []
1552
+
1553
+ if batch == -1 or batch > dataset['train_input'].shape[0]:
1554
+ batch_size = dataset['train_input'].shape[0]
1555
+ batch_size_test = dataset['test_input'].shape[0]
1556
+ else:
1557
+ batch_size = batch
1558
+ batch_size_test = batch
1559
+
1560
+ global train_loss, reg_
1561
+
1562
+ def closure():
1563
+ global train_loss, reg_
1564
+ optimizer.zero_grad()
1565
+ pred = self.forward(dataset['train_input'][train_id], singularity_avoiding=singularity_avoiding, y_th=y_th)
1566
+ train_loss = loss_fn(pred, dataset['train_label'][train_id])
1567
+ if self.save_act:
1568
+ if reg_metric == 'edge_backward':
1569
+ self.attribute()
1570
+ if reg_metric == 'node_backward':
1571
+ self.node_attribute()
1572
+ reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
1573
+ else:
1574
+ reg_ = torch.tensor(0.)
1575
+ objective = train_loss + lamb * reg_
1576
+ objective.backward()
1577
+ return objective
1578
+
1579
+ if save_fig:
1580
+ if not os.path.exists(img_folder):
1581
+ os.makedirs(img_folder)
1582
+
1583
+ for _ in pbar:
1584
+
1585
+ if _ == steps - 1 and old_save_act:
1586
+ self.save_act = True
1587
+
1588
+ if save_fig and _ % save_fig_freq == 0:
1589
+ save_act = self.save_act
1590
+ self.save_act = True
1591
+
1592
+ train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
1593
+ test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
1594
+
1595
+ if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid and _ >= start_grid_update_step:
1596
+ self.update_grid(dataset['train_input'][train_id])
1597
+
1598
+ if opt == "LBFGS":
1599
+ optimizer.step(closure)
1600
+
1601
+ if opt == "Adam":
1602
+ pred = self.forward(dataset['train_input'][train_id], singularity_avoiding=singularity_avoiding,
1603
+ y_th=y_th)
1604
+ train_loss = loss_fn(pred, dataset['train_label'][train_id])
1605
+ if self.save_act:
1606
+ if reg_metric == 'edge_backward':
1607
+ self.attribute()
1608
+ if reg_metric == 'node_backward':
1609
+ self.node_attribute()
1610
+ reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
1611
+ else:
1612
+ reg_ = torch.tensor(0.)
1613
+ loss = train_loss + lamb * reg_
1614
+ optimizer.zero_grad()
1615
+ loss.backward()
1616
+ optimizer.step()
1617
+
1618
+ test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id]), dataset['test_label'][test_id])
1619
+
1620
+ if metrics is not None:
1621
+ for i in range(len(metrics)):
1622
+ results[metrics[i].__name__].append(metrics[i]().item())
1623
+
1624
+ results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy())
1625
+ results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy())
1626
+ results['reg'].append(reg_.cpu().detach().numpy())
1627
+
1628
+ if _ % log == 0:
1629
+ if display_metrics is None:
1630
+ pbar.set_description("| train_loss: %.6f | test_loss: %.6f | reg: %.6f | " % (
1631
+ torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(),
1632
+ reg_.cpu().detach().numpy()))
1633
+ else:
1634
+ string = ''
1635
+ data = ()
1636
+ for metric in display_metrics:
1637
+ string += f' {metric}: %.6f |'
1638
+ if metric not in results:
1639
+ raise Exception(f'{metric} not recognized')
1642
+ data += (results[metric][-1],)
1643
+ pbar.set_description(string % data)
1644
+
1645
+ if save_fig and _ % save_fig_freq == 0:
1646
+ self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(_), beta=beta)
1647
+ plt.savefig(img_folder + '/' + str(_) + '.jpg', bbox_inches='tight', dpi=100)
1648
+ plt.close()
1649
+ self.save_act = save_act
1650
+
1651
+ self.log_history('fit')
1652
+ # revert back to original state
1653
+ self.symbolic_enabled = old_symbolic_enabled
1654
+ return results
1655
+
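+ # Illustrative usage sketch (values are examples, not defaults; create_dataset is
+ # assumed available from this package's utils, as in the docstring examples in this file):
+ # >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
+ # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
+ # >>> dataset = create_dataset(f, n_var=2)
+ # >>> results = model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001)
+ # >>> results['train_loss'][-1]  # losses are stored as square roots (RMSE for MSE loss)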
1656
+ def fix(self, dataset: dict, batch_size, batch_size_test, opt="LBFGS", epochs=100, log=1, lamb=0.,
1657
+ lamb_l1=1., label=None,
1658
+ lamb_entropy=2., lamb_coef=0.,
1659
+ lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
1660
+ stop_grid_update_step=100,
1661
+ metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
1662
+ singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n', display_metrics=None):
1663
+ """
1664
+ Args:
1665
+ label:
1666
+ batch_size:
1667
+ batch_size_test:
1668
+ lamb_coefdiff:
1669
+ metrics:
1670
+ stop_grid_update_step:
1671
+ save_fig:
1672
+ singularity_avoiding:
1673
+ y_th:
1674
+ in_vars:
1675
+ update_grid:
1676
+ grid_update_num:
1677
+ loss_fn:
1678
+ lr:
1679
+ beta:
1680
+ out_vars:
1681
+ save_fig_freq:
1682
+ reg_metric:
1683
+ display_metrics:
1684
+ img_folder:
1685
+ start_grid_update_step:
1686
+ lamb_coef:
1687
+ lamb_entropy:
1688
+ lamb_l1:
1689
+ lamb:
1690
+ log:
1691
+ epochs:
1692
+ opt:
1693
+ dataset (dict):
1694
+ """
1695
+
1696
+ def calculate_results(all_labels, all_predictions, average='macro'):
1697
+ result = {
1698
+ 'accuracy': accuracy_score(y_true=all_labels, y_pred=all_predictions),
1699
+ 'precision': precision_score(y_true=all_labels, y_pred=all_predictions, average=average),
1700
+ 'recall': recall_score(y_true=all_labels, y_pred=all_predictions, average=average),
1701
+ 'f1_score': f1_score(y_true=all_labels, y_pred=all_predictions, average=average),
1702
+ # 'cm': confusion_matrix(y_true=all_labels, y_pred=all_predictions, labels=np.arange(len(classes)))
1703
+ }
1704
+ return result
1705
+
1706
+ all_predictions = []
1707
+ all_labels = []
1708
+ if lamb > 0. and not self.save_act:
1709
+ print('regularization is skipped (effectively lamb=0) because save_act=False; to use lamb > 0, set self.save_act=True')
1710
+
1711
+ old_save_act, old_symbolic_enabled = self.disable_symbolic_in_fit(lamb)
1712
+ if label is not None:
1713
+ label = label.to(self.device)
1714
+ pbar = tqdm(range(epochs), desc='description', file=sys.stdout, colour='red')
1715
+
1716
+ if loss_fn is None:
1717
+ loss_fn = lambda x, y: torch.mean((x - y) ** 2)
1718
+ # a user-provided loss_fn is used as-is
1720
+
1721
+ grid_update_freq = int(stop_grid_update_step / grid_update_num)
1722
+
1723
+ if opt == "Adam":
1724
+ optimizer = torch.optim.Adam(self.get_params(), lr=lr)
1725
+ elif opt == "LBFGS":
1726
+ optimizer = LBFGS(self.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
1727
+ tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
1728
+ else:
1729
+ optimizer = torch.optim.SGD(self.get_params(), lr=lr)
1730
+ lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)
1731
+ results = {'train_loss': [], 'test_loss': [], 'reg': [], 'accuracy': [],
1732
+ 'precision': [], 'recall': [], 'f1_score': []}
1733
+ if metrics is not None:
1734
+ for i in range(len(metrics)):
1735
+ results[metrics[i].__name__] = []
1736
+ steps = math.ceil(dataset['train_input'].shape[0] / batch_size)
1737
+
1738
+ train_loss = torch.zeros(1).to(self.device)
1739
+ reg_ = torch.zeros(1).to(self.device)
1740
+
1741
+ def closure():
1742
+ nonlocal train_loss, reg_
1743
+ optimizer.zero_grad()
1744
+ pred = self.forward(batch_train_input, singularity_avoiding=singularity_avoiding, y_th=y_th)
1745
+ loss = loss_fn(pred, batch_train_label)
1746
+ if self.save_act:
1747
+ if reg_metric == 'edge_backward':
1748
+ self.attribute()
1749
+ if reg_metric == 'node_backward':
1750
+ self.node_attribute()
1751
+ reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
1752
+ else:
1753
+ reg_ = torch.tensor(0.)
1754
+ objective = loss + lamb * reg_
1755
+ train_loss = (train_loss * batch_num + objective.detach()) / (batch_num + 1)
1756
+ objective.backward()
1757
+ return objective
1758
+
1759
+ if save_fig:
1760
+ if not os.path.exists(img_folder):
1761
+ os.makedirs(img_folder)
1762
+
1763
+ for epoch in pbar:
1764
+
1765
+ if epoch == epochs - 1 and old_save_act:
1766
+ self.save_act = True
1767
+
1768
+ if save_fig and epoch % save_fig_freq == 0:
1769
+ save_act = self.save_act
1770
+ self.save_act = True
1771
+
1772
+ train_indices = np.arange(dataset['train_input'].shape[0])
1773
+ np.random.shuffle(train_indices)
1774
+ for batch_num, i in enumerate(range(0, len(train_indices), batch_size), start=0):
1775
+ step = epoch * steps + batch_num + 1
1776
+ batch_train_id = train_indices[i:i + batch_size]
1777
+ batch_train_input = dataset['train_input'][batch_train_id].to(self.device)
1778
+ batch_train_label = dataset['train_label'][batch_train_id].to(self.device)
1779
+
1780
+ if step % grid_update_freq == 0 and step < stop_grid_update_step and update_grid and step >= start_grid_update_step:
1781
+ self.update_grid(batch_train_input)
1782
+
1783
+ if opt == "LBFGS":
1784
+ optimizer.step(closure)
1785
+
1786
+ # if opt == "Adam":
1787
+ else:
1788
+ optimizer.zero_grad()
1789
+ pred = self.forward(batch_train_input, singularity_avoiding=singularity_avoiding,
1790
+ y_th=y_th)
1791
+ loss = loss_fn(pred, batch_train_label)
1792
+ if self.save_act:
1793
+ if reg_metric == 'edge_backward':
1794
+ self.attribute()
1795
+ if reg_metric == 'node_backward':
1796
+ self.node_attribute()
1797
+ reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
1798
+ else:
1799
+ reg_ = torch.tensor(0.)
1800
+ loss = loss + lamb * reg_
1801
+ train_loss = (train_loss * batch_num + loss.detach()) / (batch_num + 1)
1802
+ loss.backward()
1803
+ optimizer.step()
1804
+
1805
+ test_loss = torch.zeros(1).to(self.device)
1806
+ with torch.no_grad():
1807
+ test_indices = np.arange(dataset['test_input'].shape[0])
1808
+ np.random.shuffle(test_indices)
1809
+ for batch_num, i in enumerate(range(0, len(test_indices), batch_size_test)):
1810
+ batch_test_id = test_indices[i:i + batch_size_test]
1811
+ batch_test_input = dataset['test_input'][batch_test_id].to(self.device)
1812
+ batch_test_label = dataset['test_label'][batch_test_id].to(self.device)
1813
+
1814
+ outputs = self.forward(batch_test_input)
1815
+
1816
+ loss = loss_fn(outputs, batch_test_label)
1817
+
1818
+ test_loss = (test_loss * batch_num + loss.detach()) / (batch_num + 1)
1819
+ if label is not None:
1820
+ diffs = torch.abs(outputs - label)
1821
+ closest_indices = torch.argmin(diffs, dim=1)
1822
+ closest_values = label[closest_indices]
1823
+ all_predictions.extend(closest_values.detach().cpu().numpy())
1824
+ all_labels.extend(batch_test_label.detach().cpu().numpy())
1825
+
1826
+ lr_scheduler.step(test_loss)
1827
+ if metrics is not None:
1828
+ for i in range(len(metrics)):
1829
+ results[metrics[i].__name__].append(metrics[i]().item())
1830
+
1831
+ results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy())
1832
+ results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy())
1833
+ results['reg'].append(reg_.cpu().detach().numpy())
1834
+ if label is not None:
1835
+ res = calculate_results(all_labels=all_labels, all_predictions=all_predictions)
1836
+ results.update({key: results[key] + [value] for key, value in res.items() if key in results})
1837
+
1838
+ if epoch % log == 0:
1839
+ if display_metrics is None:
1840
+ pbar.set_description(f"|train_loss:%.6f|test_loss:%.6f|reg:%.6f|" % (
1841
+ train_loss.item(), test_loss.item(), reg_.item()))
1842
+ else:
1843
+ string = ''
1844
+ data = ()
1845
+ for metric in display_metrics:
1846
+ string += f' {metric}: %.6f |'
1847
+ if metric not in results:
1848
+ raise Exception(f'{metric} not recognized')
1851
+ data += (results[metric][-1],)
1852
+ pbar.set_description(string % data)
1853
+
1854
+ if save_fig and epoch % save_fig_freq == 0:
1855
+ self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
1856
+ beta=beta)
1857
+ plt.savefig(img_folder + '/' + str(epoch) + '.jpg', bbox_inches='tight', dpi=100)
1858
+ plt.close()
1859
+ self.save_act = save_act
1860
+
1861
+ self.log_history('fit')
1862
+ self.symbolic_enabled = old_symbolic_enabled
1863
+ return results
1864
+
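+ # Illustrative sketch for the batched fix() variant on a classification-style dataset
+ # (the label vector and hyperparameters are assumptions for the example):
+ # >>> label = torch.arange(n_classes).float()  # hypothetical: one value per output column
+ # >>> results = model.fix(dataset, batch_size=64, batch_size_test=64, opt='Adam',
+ # ...                     epochs=50, lr=1e-3, label=label)
+ # >>> results['accuracy'][-1]  # classification metrics are filled only when label is given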
1865
+ def prune_node(self, threshold=1e-2, mode="auto", active_neurons_id=None, log_history=True):
1866
+ """
1867
+ pruning nodes
1868
+
1869
+ Args:
1870
+ -----
1871
+ threshold : float
1872
+ if the attribution score of a neuron is below the threshold, it is considered dead and will be removed
1873
+ mode : str
1874
+ 'auto' or 'manual'. with 'auto', nodes are automatically pruned using threshold. with 'manual', active_neurons_id should be passed in.
1875
+
1876
+ Returns:
1877
+ --------
1878
+ pruned network : MultKAN
1879
+
1880
+ Example
1881
+ -------
1882
+ # >>> from yms_kan import *
1883
+ # >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
1884
+ # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
1885
+ # >>> dataset = create_dataset(f, n_var=2)
1886
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
1887
+ # >>> model = model.prune_node()
1888
+ # >>> model.plot()
1889
+ """
1890
+ if self.acts is None:
1891
+ self.get_act()
1892
+
1893
+ mask_up = [torch.ones(self.width_in[0], device=self.device)]
1894
+ mask_down = []
1895
+ active_neurons_up = [list(range(self.width_in[0]))]
1896
+ active_neurons_down = []
1897
+ num_sums = []
1898
+ num_mults = []
1899
+ mult_arities = [[]]
1900
+
1901
+ if active_neurons_id is not None:
1902
+ mode = "manual"
1903
+
1904
+ for i in range(len(self.acts_scale) - 1):
1905
+
1906
+ mult_arity = []
1907
+
1908
+ if mode == "auto":
1909
+ self.attribute()
1910
+ overall_important_up = self.node_scores[i + 1] > threshold
1911
+
1912
+ elif mode == "manual":
1913
+ overall_important_up = torch.zeros(self.width_in[i + 1], dtype=torch.bool, device=self.device)
1914
+ overall_important_up[active_neurons_id[i]] = True
1915
+
1916
+ num_sum = torch.sum(overall_important_up[:self.width[i + 1][0]])
1917
+ num_mult = torch.sum(overall_important_up[self.width[i + 1][0]:])
1918
+ if self.mult_homo == True:
1919
+ overall_important_down = torch.cat([overall_important_up[:self.width[i + 1][0]], (
1920
+ overall_important_up[self.width[i + 1][0]:][None, :].expand(self.mult_arity, -1)).T.reshape(-1, )],
1921
+ dim=0)
1922
+ else:
1923
+ overall_important_down = overall_important_up[:self.width[i + 1][0]]
1924
+ for j in range(overall_important_up[self.width[i + 1][0]:].shape[0]):
1925
+ active_bool = overall_important_up[self.width[i + 1][0] + j]
1926
+ arity = self.mult_arity[i + 1][j]
1927
+ overall_important_down = torch.cat(
1928
+ [overall_important_down, torch.tensor([active_bool] * arity).to(self.device)])
1929
+ if active_bool:
1930
+ mult_arity.append(arity)
1931
+
1932
+ num_sums.append(num_sum.item())
1933
+ num_mults.append(num_mult.item())
1934
+
1935
+ mask_up.append(overall_important_up.float())
1936
+ mask_down.append(overall_important_down.float())
1937
+
1938
+ active_neurons_up.append(torch.where(overall_important_up == True)[0])
1939
+ active_neurons_down.append(torch.where(overall_important_down == True)[0])
1940
+
1941
+ mult_arities.append(mult_arity)
1942
+
1943
+ active_neurons_down.append(list(range(self.width_out[-1])))
1944
+ mask_down.append(torch.ones(self.width_out[-1], device=self.device))
1945
+
1946
+ if not self.mult_homo:
1947
+ mult_arities.append(self.mult_arity[-1])
1948
+
1949
+ self.mask_up = mask_up
1950
+ self.mask_down = mask_down
1951
+
1952
+ # update act_fun[l].mask up
1953
+ for l in range(len(self.acts_scale) - 1):
1954
+ for i in range(self.width_in[l + 1]):
1955
+ if i not in active_neurons_up[l + 1]:
1956
+ self.remove_node(l + 1, i, mode='up', log_history=False)
1957
+
1958
+ for i in range(self.width_out[l + 1]):
1959
+ if i not in active_neurons_down[l]:
1960
+ self.remove_node(l + 1, i, mode='down', log_history=False)
1961
+
1962
+ model2 = MultKAN(copy.deepcopy(self.width), grid=self.grid, k=self.k, base_fun=self.base_fun_name,
1963
+ mult_arity=self.mult_arity, ckpt_path=self.ckpt_path, auto_save=True, first_init=False,
1964
+ state_id=self.state_id, round=self.round).to(self.device)
1965
+ model2.load_state_dict(self.state_dict())
1966
+
1967
+ width_new = [self.width[0]]
1968
+
1969
+ for i in range(len(self.acts_scale)):
1970
+
1971
+ if i < len(self.acts_scale) - 1:
1972
+ num_sum = num_sums[i]
1973
+ num_mult = num_mults[i]
1974
+ model2.node_bias[i].data = model2.node_bias[i].data[active_neurons_up[i + 1]]
1975
+ model2.node_scale[i].data = model2.node_scale[i].data[active_neurons_up[i + 1]]
1976
+ model2.subnode_bias[i].data = model2.subnode_bias[i].data[active_neurons_down[i]]
1977
+ model2.subnode_scale[i].data = model2.subnode_scale[i].data[active_neurons_down[i]]
1978
+ model2.width[i + 1] = [num_sum, num_mult]
1979
+
1980
+ model2.act_fun[i].out_dim_sum = num_sum
1981
+ model2.act_fun[i].out_dim_mult = num_mult
1982
+
1983
+ model2.symbolic_fun[i].out_dim_sum = num_sum
1984
+ model2.symbolic_fun[i].out_dim_mult = num_mult
1985
+
1986
+ width_new.append([num_sum, num_mult])
1987
+
1988
+ model2.act_fun[i] = model2.act_fun[i].get_subset(active_neurons_up[i], active_neurons_down[i])
1989
+ model2.symbolic_fun[i] = self.symbolic_fun[i].get_subset(active_neurons_up[i], active_neurons_down[i])
1990
+
1991
+ model2.cache_data = self.cache_data
1992
+ model2.acts = None
1993
+
1994
+ width_new.append(self.width[-1])
1995
+ model2.width = width_new
1996
+
1997
+ if not self.mult_homo:
1998
+ model2.mult_arity = mult_arities
1999
+
2000
+ if log_history:
2001
+ self.log_history('prune_node')
2002
+ model2.state_id += 1
2003
+
2004
+ return model2
2005
+
2006
+ def prune_edge(self, threshold=3e-2, log_history=True):
2007
+ """
2008
+ pruning edges
2009
+
2010
+ Args:
2011
+ -----
2012
+ threshold : float
2013
+ if the attribution score of an edge is below the threshold, it is considered dead and will be set to zero.
2014
+
2015
+ Returns:
2016
+ --------
2017
+ pruned network : MultKAN
2018
+
2019
+ Example
2020
+ -------
2021
+ # >>> from yms_kan import *
2022
+ # >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
2023
+ # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
2024
+ # >>> dataset = create_dataset(f, n_var=2)
2025
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
2026
+ # >>> model = model.prune_edge()
2027
+ # >>> model.plot()
2028
+ """
2029
+ if self.acts is None:
2030
+ self.get_act()
2031
+
2032
+ for i in range(len(self.width) - 1):
2033
+ # self.act_fun[i].mask.data = ((self.acts_scale[i] > threshold).permute(1,0)).float()
2034
+ old_mask = self.act_fun[i].mask.data
2035
+ self.act_fun[i].mask.data = ((self.edge_scores[i] > threshold).permute(1, 0) * old_mask).float()
2036
+
2037
+ if log_history:
2038
+ self.log_history('prune_edge')
2039
+
2040
+ def prune(self, node_th=1e-2, edge_th=3e-2):
2041
+ """
2042
+ prune (both nodes and edges)
2043
+
2044
+ Args:
2045
+ -----
2046
+ node_th : float
2047
+ if the attribution score of a node is below node_th, it is considered dead and will be set to zero.
2048
+ edge_th : float
2049
+ if the attribution score of an edge is below edge_th, it is considered dead and will be set to zero.
2050
+
2051
+ Returns:
2052
+ --------
2053
+ pruned network : MultKAN
2054
+
2055
+ Example
2056
+ -------
2057
+ # >>> from yms_kan import *
2058
+ # >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
2059
+ # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
2060
+ # >>> dataset = create_dataset(f, n_var=2)
2061
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
2062
+ # >>> model = model.prune()
2063
+ # >>> model.plot()
2064
+ """
2065
+ if self.acts is None:
2066
+ self.get_act()
2067
+
2068
+ self = self.prune_node(node_th, log_history=False)
2069
+ # self.prune_node(node_th, log_history=False)
2070
+ self.forward(self.cache_data)
2071
+ self.attribute()
2072
+ self.prune_edge(edge_th, log_history=False)
2073
+ self.log_history('prune')
2074
+ return self
2075
+
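+ # Note: prune() composes the two methods above: it first prunes nodes (returning a smaller
+ # copy), re-runs the forward pass on cache_data to refresh attribution scores, then prunes edges.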
2076
+ def prune_input(self, threshold=1e-2, active_inputs=None, log_history=True):
2077
+ """
2078
+ prune inputs
2079
+
2080
+ Args:
2081
+ -----
2082
+ threshold : float
2083
+ if the attribution score of the input feature is below threshold, it is considered irrelevant.
2084
+ active_inputs : None or list
2085
+ if a list is passed, the manual mode will disregard attribution score and prune as instructed.
2086
+
2087
+ Returns:
2088
+ --------
2089
+ pruned network : MultKAN
2090
+
2091
+ Example1
2092
+ --------
2093
+ >>> # automatic
2094
+ # >>> from yms_kan import *
2095
+ # >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
2096
+ # >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
2097
+ # >>> dataset = create_dataset(f, n_var=3)
2098
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
2099
+ # >>> model.plot()
2100
+ # >>> model = model.prune_input()
2101
+ # >>> model.plot()
2102
+
2103
+ Example2
2104
+ --------
2105
+ >>> # automatic
2106
+ # >>> from yms_kan import *
2107
+ # >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
2108
+ # >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
2109
+ # >>> dataset = create_dataset(f, n_var=3)
2110
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
2111
+ # >>> model.plot()
2112
+ # >>> model = model.prune_input(active_inputs=[0,1])
2113
+ # >>> model.plot()
2114
+ """
2115
+ if active_inputs is None:
2116
+ self.attribute()
2117
+ input_score = self.node_scores[0]
2118
+ input_mask = input_score > threshold
2119
+ print('keep:', input_mask.tolist())
2120
+ input_id = torch.where(input_mask == True)[0]
2121
+
2122
+ else:
2123
+ input_id = torch.tensor(active_inputs, dtype=torch.long).to(self.device)
2124
+
2125
+ model2 = MultKAN(copy.deepcopy(self.width), grid=self.grid, k=self.k, base_fun=self.base_fun,
2126
+ mult_arity=self.mult_arity, ckpt_path=self.ckpt_path, auto_save=True, first_init=False,
2127
+ state_id=self.state_id, round=self.round).to(self.device)
2128
+ model2.load_state_dict(self.state_dict())
2129
+
2130
+ model2.act_fun[0] = model2.act_fun[0].get_subset(input_id, torch.arange(self.width_out[1]))
2131
+ model2.symbolic_fun[0] = self.symbolic_fun[0].get_subset(input_id, torch.arange(self.width_out[1]))
2132
+
2133
+ model2.cache_data = self.cache_data
2134
+ model2.acts = None
2135
+
2136
+ model2.width[0] = [len(input_id), 0]
2137
+ model2.input_id = input_id
2138
+
2139
+ if log_history:
2140
+ self.log_history('prune_input')
2141
+ model2.state_id += 1
2142
+
2143
+ return model2
2144
+
2145
+ def remove_edge(self, l, i, j, log_history=True):
2146
+ """
2147
+ remove activation phi(l,i,j) (set its mask to zero)
2148
+ """
2149
+ self.act_fun[l].mask[i][j] = 0.
2150
+ if log_history:
2151
+ self.log_history('remove_edge')
2152
+
2153
+ def remove_node(self, l, i, mode='all', log_history=True):
2154
+ """
2155
+ remove neuron (l,i) (set the masks of all incoming and outgoing activation functions to zero)
2156
+ """
2157
+ if mode == 'down':
2158
+ self.act_fun[l - 1].mask[:, i] = 0.
2159
+ self.symbolic_fun[l - 1].mask[i, :] *= 0.
2160
+
2161
+ elif mode == 'up':
2162
+ self.act_fun[l].mask[i, :] = 0.
2163
+ self.symbolic_fun[l].mask[:, i] *= 0.
2164
+
2165
+ else:
2166
+ self.remove_node(l, i, mode='up')
2167
+ self.remove_node(l, i, mode='down')
2168
+
2169
+ if log_history:
2170
+ self.log_history('remove_node')
2171
+
2172
+ def attribute(self, l=None, i=None, out_score=None, plot=True):
2173
+ """
2174
+ get attribution scores
2175
+
2176
+ Args:
2177
+ -----
2178
+ l : None or int
2179
+ layer index
2180
+ i : None or int
2181
+ neuron index
2182
+ out_score : None or 1D torch.float
2183
+ specify output scores
2184
+ plot : bool
2185
+ when plot = True, display a bar chart of the scores
2186
+
2187
+ Returns:
2188
+ --------
2189
+ attribution scores
2190
+
2191
+ Example
2192
+ -------
2193
+ # >>> from yms_kan import *
2194
+ # >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
2195
+ # >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
2196
+ # >>> dataset = create_dataset(f, n_var=3)
2197
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
2198
+ # >>> model.attribute()
2199
+ # >>> model.feature_score
2200
+ """
2201
+ # output (out_dim, in_dim)
2202
+
2203
+ if l is not None:
2204
+ self.attribute()
2205
+ out_score = self.node_scores[l]
2206
+
2207
+ if self.acts is None:
2208
+ self.get_act()
2209
+
2210
+ def score_node2subnode(node_score, width, mult_arity, out_dim):
2211
+
2212
+ assert np.sum(width) == node_score.shape[1]
2213
+ if isinstance(mult_arity, int):
2214
+ n_subnode = width[0] + mult_arity * width[1]
2215
+ else:
2216
+ n_subnode = width[0] + int(np.sum(mult_arity))
2217
+
2218
+ #subnode_score_leaf = torch.zeros(out_dim, n_subnode).requires_grad_(True)
2219
+ #subnode_score = subnode_score_leaf.clone()
2220
+ #subnode_score[:,:width[0]] = node_score[:,:width[0]]
2221
+ subnode_score = node_score[:, :width[0]]
2222
+ if isinstance(mult_arity, int):
2223
+ #subnode_score[:,width[0]:] = node_score[:,width[0]:][:,:,None].expand(out_dim, node_score[width[0]:].shape[0], mult_arity).reshape(out_dim,-1)
2224
+ n_mult = node_score[:, width[0]:].shape[1]
2225
+ subnode_score = torch.cat(
2226
+ [subnode_score,
2227
+ node_score[:, width[0]:][:, :, None].expand(out_dim, n_mult, mult_arity).reshape(out_dim, -1)],
2228
+ dim=1)
2230
+ else:
2231
+ acml = width[0]
2232
+ for i in range(len(mult_arity)):
2233
+ #subnode_score[:, acml:acml+mult_arity[i]] = node_score[:, width[0]+i]
2234
+ subnode_score = torch.cat(
2235
+ [subnode_score, node_score[:, width[0] + i].expand(out_dim, mult_arity[i])], dim=1)
2236
+ acml += mult_arity[i]
2237
+ return subnode_score
2238
+
2239
+ node_scores = []
2240
+ subnode_scores = []
2241
+ edge_scores = []
2242
+
2243
+ l_query = l
2244
+ if l is None:
2245
+ l_end = self.depth
2246
+ else:
2247
+ l_end = l
2248
+
2249
+ # back propagate from the queried layer
2250
+ out_dim = self.width_in[l_end]
2251
+ if out_score is None:
2252
+ node_score = torch.eye(out_dim).requires_grad_(True)
2253
+ else:
2254
+ node_score = torch.diag(out_score).requires_grad_(True)
2255
+ node_scores.append(node_score)
2256
+
2257
+ device = self.act_fun[0].grid.device
2258
+
2259
+ for l in range(l_end, 0, -1):
2260
+
2261
+ # node to subnode
2262
+ if isinstance(self.mult_arity, int):
2263
+ subnode_score = score_node2subnode(node_score, self.width[l], self.mult_arity, out_dim=out_dim)
2264
+ else:
2265
+ mult_arity = self.mult_arity[l]
2266
+ #subnode_score = score_node2subnode(node_score, self.width[l], mult_arity)
2267
+ subnode_score = score_node2subnode(node_score, self.width[l], mult_arity, out_dim=out_dim)
2268
+
2269
+ subnode_scores.append(subnode_score)
2270
+ # subnode to edge
2271
+ #print(self.edge_actscale[l-1].device, subnode_score.device, self.subnode_actscale[l-1].device)
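+ # edge_score[k,i,j] = edge_actscale[i,j] * subnode_score[k,i] / (subnode_actscale[i] + 1e-4):
+ # an edge inherits the score of the subnode it feeds, weighted by its share of that
+ # subnode's activation scale (the 1e-4 guards against division by zero)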
2272
+ edge_score = torch.einsum('ij,ki,i->kij', self.edge_actscale[l - 1], subnode_score.to(device),
2273
+ 1 / (self.subnode_actscale[l - 1] + 1e-4))
2274
+ edge_scores.append(edge_score)
2275
+
2276
+ # edge to node
2277
+ node_score = torch.sum(edge_score, dim=1)
2278
+ node_scores.append(node_score)
2279
+
2280
+ self.node_scores_all = list(reversed(node_scores))
2281
+ self.edge_scores_all = list(reversed(edge_scores))
2282
+ self.subnode_scores_all = list(reversed(subnode_scores))
2283
+
2284
+ self.node_scores = [torch.mean(l, dim=0) for l in self.node_scores_all]
2285
+ self.edge_scores = [torch.mean(l, dim=0) for l in self.edge_scores_all]
2286
+ self.subnode_scores = [torch.mean(l, dim=0) for l in self.subnode_scores_all]
2287
+
2288
+ # return
2289
+ if l_query is not None:
2290
+ if i is None:
2291
+ return self.node_scores_all[0]
2292
+ else:
2293
+
2294
+ # plot
2295
+ if plot:
2296
+ in_dim = self.width_in[0]
2297
+ plt.figure(figsize=(1 * in_dim, 3))
2298
+ plt.bar(range(in_dim), self.node_scores_all[0][i].cpu().detach().numpy())
2299
+ plt.xticks(range(in_dim));
2300
+
2301
+ return self.node_scores_all[0][i]
2302
+
2303
+ def node_attribute(self):
2304
+ self.node_attribute_scores = []
2305
+ for l in range(1, self.depth + 1):
2306
+ node_attr = self.attribute(l)
2307
+ self.node_attribute_scores.append(node_attr)
2308
+
2309
+ def feature_interaction(self, l, neuron_th=1e-2, feature_th=1e-2):
2310
+ """
2311
+ get feature interaction
2312
+
2313
+ Args:
2314
+ -----
2315
+ l : int
2316
+ layer index
2317
+ neuron_th : float
2318
+ threshold to determine whether a neuron is active
2319
+ feature_th : float
2320
+ threshold to determine whether a feature is active
2321
+
2322
+ Returns:
2323
+ --------
2324
+ dictionary
2325
+
2326
+ Example
2327
+ -------
2328
+ # >>> from yms_kan import *
2329
+ # >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
2330
+ # >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
2331
+ # >>> dataset = create_dataset(f, n_var=3)
2332
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
2333
+ # >>> model.attribute()
2334
+ # >>> model.feature_interaction(1)
2335
+ """
2336
+ dic = {}
2337
+ width = self.width_in[l]
2338
+
2339
+ for i in range(width):
2340
+ score = self.attribute(l, i, plot=False)
2341
+
2342
+ if torch.max(score) > neuron_th:
2343
+ features = tuple(torch.where(score > torch.max(score) * feature_th)[0].cpu().detach().numpy())
2344
+ if features in dic.keys():
2345
+ dic[features] += 1
2346
+ else:
2347
+ dic[features] = 1
2348
+
2349
+ return dic
2350
+
2351
+ def suggest_symbolic(self, l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=None, topk=5, verbose=True,
2352
+ r2_loss_fun=lambda x: np.log2(1 + 1e-5 - x), c_loss_fun=lambda x: x, weight_simple=0.8):
2353
+ """
2354
+ suggest symbolic function
2355
+
2356
+ Args:
2357
+ -----
2358
+ l : int
2359
+ layer index
2360
+ i : int
2361
+ neuron index in layer l
2362
+ j : int
2363
+ output neuron index in layer l+1
2364
+ a_range : tuple
2365
+ search range of a
2366
+ b_range : tuple
2367
+ search range of b
2368
+ lib : list of str
2369
+ library of candidate symbolic functions
2370
+ topk : int
2371
+ the number of top functions displayed
2372
+ verbose : bool
2373
+ if verbose = True, print more information
2374
+ r2_loss_fun : function
2375
+ function : r2 -> "bits"
2376
+ c_loss_fun : fun
2377
+ function : c -> 'bits'
2378
+ weight_simple : float
2379
+ the simplicity weight: the higher it is, the more simplicity is preferred over performance
2380
+
2381
+
2382
+ Returns:
2383
+ --------
2384
+ best_name (str), best_fun (function), best_r2 (float), best_c (float)
2385
+
2386
+ Example
2387
+ -------
2388
+ # >>> from yms_kan import *
2389
+ # >>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
2390
+ # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
2391
+ # >>> dataset = create_dataset(f, n_var=3)
2392
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
2393
+ # >>> model.suggest_symbolic(0,1,0)
2394
+ """
2395
+ r2s = []
2396
+ cs = []
2397
+
2398
+ if lib is None:
2399
+ symbolic_lib = SYMBOLIC_LIB
2400
+ else:
2401
+ symbolic_lib = {}
2402
+ for item in lib:
2403
+ symbolic_lib[item] = SYMBOLIC_LIB[item]
2404
+
2405
+ # getting r2 and complexities
2406
+ for (name, content) in symbolic_lib.items():
2407
+ r2 = self.fix_symbolic(l, i, j, name, a_range=a_range, b_range=b_range, verbose=False, log_history=False)
2408
+ if r2 == -1e8: # zero function
2409
+ r2s.append(-1e8)
2410
+ else:
2411
+ r2s.append(r2.item())
2412
+ self.unfix_symbolic(l, i, j, log_history=False)
2413
+ c = content[2]
2414
+ cs.append(c)
2415
+
2416
+ r2s = np.array(r2s)
2417
+ cs = np.array(cs)
2418
+ r2_loss = r2_loss_fun(r2s).astype('float')
2419
+ cs_loss = c_loss_fun(cs)
2420
+
2421
+ loss = weight_simple * cs_loss + (1 - weight_simple) * r2_loss
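+ # e.g. with weight_simple=0.8 the objective is 0.8*cs_loss + 0.2*r2_loss, so saving 0.1
+ # bits of complexity exactly offsets losing 0.4 bits of fit quality (0.8*0.1 == 0.2*0.4)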
2422
+
2423
+ sorted_ids = np.argsort(loss)[:topk]
2424
+ r2s = r2s[sorted_ids][:topk]
2425
+ cs = cs[sorted_ids][:topk]
2426
+ r2_loss = r2_loss[sorted_ids][:topk]
2427
+ cs_loss = cs_loss[sorted_ids][:topk]
2428
+ loss = loss[sorted_ids][:topk]
2429
+
2430
+ topk = np.minimum(topk, len(symbolic_lib))
2431
+
2432
+ if verbose:
2433
+ # print results in a dataframe
2434
+ results = {}
2435
+ results['function'] = [list(symbolic_lib.items())[sorted_ids[i]][0] for i in range(topk)]
2436
+ results['fitting r2'] = r2s[:topk]
2437
+ results['r2 loss'] = r2_loss[:topk]
2438
+ results['complexity'] = cs[:topk]
2439
+ results['complexity loss'] = cs_loss[:topk]
2440
+ results['total loss'] = loss[:topk]
2441
+
2442
+ df = pd.DataFrame(results)
2443
+ print(df)
2444
+
2445
+ best_name = list(symbolic_lib.items())[sorted_ids[0]][0]
2446
+ best_fun = list(symbolic_lib.items())[sorted_ids[0]][1]
2447
+ best_r2 = r2s[0]
2448
+ best_c = cs[0]
2449
+
2450
+ return best_name, best_fun, best_r2, best_c
2451
+
2452
+ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1, weight_simple=0.8,
2453
+ r2_threshold=0.0):
2454
+ """
2455
+ automatic symbolic regression for all edges
2456
+
2457
+ Args:
2458
+ -----
2459
+ a_range : tuple
2460
+ search range of a
2461
+ b_range : tuple
2462
+ search range of b
2463
+ lib : list of str
2464
+ library of candidate symbolic functions
2465
+ verbose : int
2466
+ larger values print more information
2467
+ weight_simple : float
2468
+ a weight that prioritizes simplicity (low complexity) over performance (high r2) - set to 0.0 to ignore complexity
2469
+ r2_threshold : float
2470
+ If r2 is below this threshold, the edge will not be fixed with any symbolic function - set to 0.0 to ignore this threshold
2471
+ Returns:
2472
+ --------
2473
+ None
2474
+
2475
+ Example
2476
+ -------
2477
+ # >>> from yms_kan import *
2478
+ # >>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
2479
+ # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
2480
+ # >>> dataset = create_dataset(f, n_var=3)
2481
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
2482
+ # >>> model.auto_symbolic()
2483
+ """
2484
+ for l in range(len(self.width_in) - 1):
2485
+ for i in range(self.width_in[l]):
2486
+ for j in range(self.width_out[l + 1]):
2487
+ if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.:
2488
+ print(f'skipping ({l},{i},{j}) since already symbolic')
2489
+ elif self.symbolic_fun[l].mask[j, i] == 0. and self.act_fun[l].mask[i][j] == 0.:
2490
+ self.fix_symbolic(l, i, j, '0', verbose=verbose > 1, log_history=False)
2491
+ print(f'fixing ({l},{i},{j}) with 0')
2492
+ else:
2493
+ name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib,
2494
+ verbose=False, weight_simple=weight_simple)
2495
+ if r2 >= r2_threshold:
2496
+ self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False)
2497
+ if verbose >= 1:
2498
+ print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}')
2499
+ else:
2500
+ print(
2501
+ f'For ({l},{i},{j}) the best fit was {name}, but its r^2 = {r2} is below the threshold {r2_threshold}. This edge was left unfixed; keep training or try a different threshold.')
2502
+
2503
+ self.log_history('auto_symbolic')
2504
+
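+ # Typical end-to-end pipeline sketch (step counts and the candidate library are
+ # assumptions; the names are keys of SYMBOLIC_LIB):
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001)
+ # >>> model = model.prune()
+ # >>> model.auto_symbolic(lib=['x', 'x^2', 'exp', 'sin'])
+ # >>> model.symbolic_formula()[0][0]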
2505
+ def symbolic_formula(self, var=None, normalizer=None, output_normalizer=None):
2506
+ """
2507
+ get symbolic formula
2508
+
2509
+ Args:
2510
+ -----
2511
+ var : None or a list of sympy expression
2512
+ input variables
2513
+ normalizer : [mean, std]
2514
+ output_normalizer : [mean, std]
2515
+
2516
+ Returns:
2517
+ --------
2518
+ None
2519
+
2520
+ Example
2521
+ -------
2522
+ # >>> from yms_kan import *
2523
+ # >>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
2524
+ # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
2525
+ # >>> dataset = create_dataset(f, n_var=3)
2526
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
2527
+ # >>> model.auto_symbolic()
2528
+ # >>> model.symbolic_formula()[0][0]
2529
+ """
2530
+
2531
+ symbolic_acts = []
2532
+ symbolic_acts_premult = []
2533
+ x = []
2534
+
2535
+ def ex_round(ex1, n_digit):
2536
+ ex2 = ex1
2537
+ for a in sympy.preorder_traversal(ex1):
2538
+ if isinstance(a, sympy.Float):
2539
+ ex2 = ex2.subs(a, round(a, n_digit))
2540
+ return ex2
2541
+
2542
+ # define variables
2543
+ if var is None:
2544
+ for ii in range(1, self.width[0][0] + 1):
2545
+ exec(f"x{ii} = sympy.Symbol('x_{ii}')")
2546
+ exec(f"x.append(x{ii})")
2547
+ elif isinstance(var[0], sympy.Expr):
2548
+ x = var
2549
+ else:
2550
+ x = [sympy.symbols(var_) for var_ in var]
2551
+
2552
+ x0 = x
2553
+
2554
+ if normalizer is not None:
2555
+ mean = normalizer[0]
2556
+ std = normalizer[1]
2557
+ x = [(x[i] - mean[i]) / std[i] for i in range(len(x))]
2558
+
2559
+ symbolic_acts.append(x)
2560
+
2561
+ for l in range(len(self.width_in) - 1):
2562
+ num_sum = self.width[l + 1][0]
2563
+ num_mult = self.width[l + 1][1]
2564
+ y = []
2565
+ for j in range(self.width_out[l + 1]):
2566
+ yj = 0.
2567
+ for i in range(self.width_in[l]):
2568
+ a, b, c, d = self.symbolic_fun[l].affine[j, i]
2569
+ sympy_fun = self.symbolic_fun[l].funs_sympy[j][i]
2570
+ try:
2571
+ yj += c * sympy_fun(a * x[i] + b) + d
2572
+ except Exception:
2573
+ print('make sure all activations have been converted to symbolic formulas first!')
2574
+ return
2575
+ yj = self.subnode_scale[l][j] * yj + self.subnode_bias[l][j]
2576
+ # `simplify` here resolves to sympy's function (via `from sympy import *`), so the
+ # former `simplify == True` test was always False; the unsimplified expression is kept
2578
+ y.append(yj)
2580
+
2581
+ symbolic_acts_premult.append(y)
2582
+
2583
+ mult = []
2584
+ for k in range(num_mult):
2585
+ if isinstance(self.mult_arity, int):
2586
+ mult_arity = self.mult_arity
2587
+ else:
2588
+ mult_arity = self.mult_arity[l + 1][k]
2589
+ for i in range(mult_arity - 1):
2590
+ if i == 0:
2591
+ mult_k = y[num_sum + 2 * k] * y[num_sum + 2 * k + 1]
2592
+ else:
2593
+ mult_k = mult_k * y[num_sum + 2 * k + i + 1]
2594
+ mult.append(mult_k)
2595
+
2596
+ y = y[:num_sum] + mult
2597
+
2598
+ for j in range(self.width_in[l + 1]):
2599
+ y[j] = self.node_scale[l][j] * y[j] + self.node_bias[l][j]
2600
+
2601
+ x = y
2602
+ symbolic_acts.append(x)
2603
+
2604
+ if output_normalizer is not None:
2605
+ output_layer = symbolic_acts[-1]
2606
+ means = output_normalizer[0]
2607
+ stds = output_normalizer[1]
2608
+
2609
+ assert len(output_layer) == len(means), 'output_normalizer does not match the output layer'
2610
+ assert len(output_layer) == len(stds), 'output_normalizer does not match the output layer'
2611
+
2612
+ output_layer = [(output_layer[i] * stds[i] + means[i]) for i in range(len(output_layer))]
2613
+ symbolic_acts[-1] = output_layer
2614
+
2615
+ self.symbolic_acts = [[symbolic_acts[l][i] for i in range(len(symbolic_acts[l]))] for l in
2616
+ range(len(symbolic_acts))]
2617
+ self.symbolic_acts_premult = [[symbolic_acts_premult[l][i] for i in range(len(symbolic_acts_premult[l]))] for l
2618
+ in range(len(symbolic_acts_premult))]
2619
+
2620
+ out_dim = len(symbolic_acts[-1])
2621
+ #return [symbolic_acts[-1][i] for i in range(len(symbolic_acts[-1]))], x0
2622
+
2623
+ if simplify:
2624
+ return [symbolic_acts[-1][i] for i in range(len(symbolic_acts[-1]))], x0
2625
+ else:
2626
+ return [symbolic_acts[-1][i] for i in range(len(symbolic_acts[-1]))], x0
2627
+
2628
+ def expand_depth(self):
2629
+ '''
2630
+ expand network depth: append an identity layer to the end. For usage, please refer to tutorials interp_3_KAN_compiler.ipynb.
2631
+
2632
+ Args:
2633
+ -----
2634
+ None (the new layer is built from the current output width)
2635
+
2636
+ Returns:
2637
+ --------
2638
+ None
2642
+ '''
2643
+ self.depth += 1
2644
+
2645
+ # add kanlayer, set mask to zero
2646
+ dim_out = self.width_in[-1]
2647
+ layer = KANLayer(dim_out, dim_out, num=self.grid, k=self.k)
2648
+ layer.mask *= 0.
2649
+ self.act_fun.append(layer)
2650
+
2651
+ self.width.append([dim_out, 0])
2652
+ self.mult_arity.append([])
2653
+
2654
+ # add symbolic_kanlayer set mask to one. fun = identity on diagonal and zero for off-diagonal
2655
+ layer = Symbolic_KANLayer(dim_out, dim_out)
2656
+ layer.mask += 1.
2657
+
2658
+ for j in range(dim_out):
2659
+ for i in range(dim_out):
2660
+ if i == j:
2661
+ layer.fix_symbolic(i, j, 'x')
2662
+ else:
2663
+ layer.fix_symbolic(i, j, '0')
2664
+
2665
+ self.symbolic_fun.append(layer)
2666
+
2667
+ self.node_bias.append(
2668
+ torch.nn.Parameter(torch.zeros(dim_out, device=self.device)).requires_grad_(self.affine_trainable))
2669
+ self.node_scale.append(
2670
+ torch.nn.Parameter(torch.ones(dim_out, device=self.device)).requires_grad_(self.affine_trainable))
2671
+ self.subnode_bias.append(
2672
+ torch.nn.Parameter(torch.zeros(dim_out, device=self.device)).requires_grad_(self.affine_trainable))
2673
+ self.subnode_scale.append(
2674
+ torch.nn.Parameter(torch.ones(dim_out, device=self.device)).requires_grad_(self.affine_trainable))
2675
+
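+ # Note: the appended layer is the identity (symbolic 'x' on the diagonal, '0' off-diagonal,
+ # spline mask zeroed), so expand_depth() leaves the network's input-output map unchanged.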
2676
+ def expand_width(self, layer_id, n_added_nodes, sum_bool=True, mult_arity=2):
2677
+ '''
2678
+ expand network width. For usage, please refer to tutorials interp_3_KAN_compiler.ipynb.
2679
+
2680
+ Args:
2681
+ -----
2682
+ layer_id : int
2683
+ layer index
2684
+ n_added_nodes : int
2685
+ the number of added nodes
2686
+ sum_bool : bool
2687
+ if sum_bool == True, added nodes are addition nodes; otherwise multiplication nodes
2688
+ mult_arity : int
2689
+ multiplication arity (the number of numbers to be multiplied)
2690
+
2691
+ Returns:
2692
+ --------
2693
+ None
2694
+ '''
2695
+
2696
+ def _expand(layer_id, n_added_nodes, sum_bool=True, mult_arity=2, added_dim='out'):
2697
+ l = layer_id
2698
+ in_dim = self.symbolic_fun[l].in_dim
2699
+ out_dim = self.symbolic_fun[l].out_dim
2700
+ if sum_bool:
2701
+
2702
+ if added_dim == 'out':
2703
+ new = Symbolic_KANLayer(in_dim, out_dim + n_added_nodes)
2704
+ old = self.symbolic_fun[l]
2705
+ in_id = np.arange(in_dim)
2706
+ out_id = np.arange(out_dim + n_added_nodes)
2707
+
2708
+ for j in out_id:
2709
+ for i in in_id:
2710
+ new.fix_symbolic(i, j, '0')
2711
+ new.mask += 1.
2712
+
2713
+ for j in out_id:
2714
+ for i in in_id:
2715
+ if j > n_added_nodes - 1:
2716
+ new.funs[j][i] = old.funs[j - n_added_nodes][i]
2717
+ new.funs_avoid_singularity[j][i] = old.funs_avoid_singularity[j - n_added_nodes][i]
2718
+ new.funs_sympy[j][i] = old.funs_sympy[j - n_added_nodes][i]
2719
+ new.funs_name[j][i] = old.funs_name[j - n_added_nodes][i]
2720
+ new.affine.data[j][i] = old.affine.data[j - n_added_nodes][i]
2721
+
2722
+ self.symbolic_fun[l] = new
2723
+ self.act_fun[l] = KANLayer(in_dim, out_dim + n_added_nodes, num=self.grid, k=self.k)
2724
+ self.act_fun[l].mask *= 0.
2725
+
2726
+ self.node_scale[l].data = torch.cat(
2727
+ [torch.ones(n_added_nodes, device=self.device), self.node_scale[l].data])
2728
+ self.node_bias[l].data = torch.cat(
2729
+ [torch.zeros(n_added_nodes, device=self.device), self.node_bias[l].data])
2730
+ self.subnode_scale[l].data = torch.cat(
2731
+ [torch.ones(n_added_nodes, device=self.device), self.subnode_scale[l].data])
2732
+ self.subnode_bias[l].data = torch.cat(
2733
+ [torch.zeros(n_added_nodes, device=self.device), self.subnode_bias[l].data])
2734
+
2735
+ if added_dim == 'in':
2736
+ new = Symbolic_KANLayer(in_dim + n_added_nodes, out_dim)
2737
+ old = self.symbolic_fun[l]
2738
+ in_id = np.arange(in_dim + n_added_nodes)
2739
+ out_id = np.arange(out_dim)
2740
+
2741
+ for j in out_id:
2742
+ for i in in_id:
2743
+ new.fix_symbolic(i, j, '0')
2744
+ new.mask += 1.
2745
+
2746
+ for j in out_id:
2747
+ for i in in_id:
2748
+ if i > n_added_nodes - 1:
2749
+ new.funs[j][i] = old.funs[j][i - n_added_nodes]
2750
+ new.funs_avoid_singularity[j][i] = old.funs_avoid_singularity[j][i - n_added_nodes]
2751
+ new.funs_sympy[j][i] = old.funs_sympy[j][i - n_added_nodes]
2752
+ new.funs_name[j][i] = old.funs_name[j][i - n_added_nodes]
2753
+ new.affine.data[j][i] = old.affine.data[j][i - n_added_nodes]
2754
+
2755
+ self.symbolic_fun[l] = new
2756
+ self.act_fun[l] = KANLayer(in_dim + n_added_nodes, out_dim, num=self.grid, k=self.k)
2757
+ self.act_fun[l].mask *= 0.
2758
+
2759
+
2760
+ else:
2761
+
2762
+ if isinstance(mult_arity, int):
2763
+ mult_arity = [mult_arity] * n_added_nodes
2764
+
2765
+ if added_dim == 'out':
2766
+ n_added_subnodes = np.sum(mult_arity)
2767
+ new = Symbolic_KANLayer(in_dim, out_dim + n_added_subnodes)
2768
+ old = self.symbolic_fun[l]
2769
+ in_id = np.arange(in_dim)
2770
+ out_id = np.arange(out_dim + n_added_nodes)
2771
+
2772
+ for j in out_id:
2773
+ for i in in_id:
2774
+ new.fix_symbolic(i, j, '0')
2775
+ new.mask += 1.
2776
+
2777
+ for j in out_id:
2778
+ for i in in_id:
2779
+ if j < out_dim:
2780
+ new.funs[j][i] = old.funs[j][i]
2781
+ new.funs_avoid_singularity[j][i] = old.funs_avoid_singularity[j][i]
2782
+ new.funs_sympy[j][i] = old.funs_sympy[j][i]
2783
+ new.funs_name[j][i] = old.funs_name[j][i]
2784
+ new.affine.data[j][i] = old.affine.data[j][i]
2785
+
2786
+ self.symbolic_fun[l] = new
2787
+ self.act_fun[l] = KANLayer(in_dim, out_dim + n_added_subnodes, num=self.grid, k=self.k)
2788
+ self.act_fun[l].mask *= 0.
2789
+
2790
+ self.node_scale[l].data = torch.cat(
2791
+ [self.node_scale[l].data, torch.ones(n_added_nodes, device=self.device)])
2792
+ self.node_bias[l].data = torch.cat(
2793
+ [self.node_bias[l].data, torch.zeros(n_added_nodes, device=self.device)])
2794
+ self.subnode_scale[l].data = torch.cat(
2795
+ [self.subnode_scale[l].data, torch.ones(n_added_subnodes, device=self.device)])
2796
+ self.subnode_bias[l].data = torch.cat(
2797
+ [self.subnode_bias[l].data, torch.zeros(n_added_subnodes, device=self.device)])
2798
+
2799
+ if added_dim == 'in':
2800
+ new = Symbolic_KANLayer(in_dim + n_added_nodes, out_dim)
2801
+ old = self.symbolic_fun[l]
2802
+ in_id = np.arange(in_dim + n_added_nodes)
2803
+ out_id = np.arange(out_dim)
2804
+
2805
+ for j in out_id:
2806
+ for i in in_id:
2807
+ new.fix_symbolic(i, j, '0')
2808
+ new.mask += 1.
2809
+
2810
+ for j in out_id:
2811
+ for i in in_id:
2812
+ if i < in_dim:
2813
+ new.funs[j][i] = old.funs[j][i]
2814
+ new.funs_avoid_singularity[j][i] = old.funs_avoid_singularity[j][i]
2815
+ new.funs_sympy[j][i] = old.funs_sympy[j][i]
2816
+ new.funs_name[j][i] = old.funs_name[j][i]
2817
+ new.affine.data[j][i] = old.affine.data[j][i]
2818
+
2819
+ self.symbolic_fun[l] = new
2820
+ self.act_fun[l] = KANLayer(in_dim + n_added_nodes, out_dim, num=self.grid, k=self.k)
2821
+ self.act_fun[l].mask *= 0.
2822
+
2823
+ _expand(layer_id - 1, n_added_nodes, sum_bool, mult_arity, added_dim='out')
2824
+ _expand(layer_id, n_added_nodes, sum_bool, mult_arity, added_dim='in')
2825
+ if sum_bool:
2826
+ self.width[layer_id][0] += n_added_nodes
2827
+ else:
2828
+ if isinstance(mult_arity, int):
2829
+ mult_arity = [mult_arity] * n_added_nodes
2830
+
2831
+ self.width[layer_id][1] += n_added_nodes
2832
+ self.mult_arity[layer_id] += mult_arity
2833
+
2834
+ def perturb(self, mag=1.0, mode='non-intrusive'):
2835
+ """
2836
+ perturb a network. For usage, please refer to tutorials interp_3_KAN_compiler.ipynb.
2837
+
2838
+ Args:
2839
+ -----
2840
+ mag : float
2841
+ perturbation magnitude
2842
+ mode : str
2843
+ perturbation mode, choices = {'non-intrusive', 'all', 'minimal'}
2844
+
2845
+ Returns:
2846
+ --------
2847
+ None
2848
+ """
2849
+ perturb_bool = {}
2850
+
2851
+ if mode == 'all':
2852
+ perturb_bool['aa_a'] = True
2853
+ perturb_bool['aa_i'] = True
2854
+ perturb_bool['ai'] = True
2855
+ perturb_bool['ia'] = True
2856
+ perturb_bool['ii'] = True
2857
+ elif mode == 'non-intrusive':
2858
+ perturb_bool['aa_a'] = False
2859
+ perturb_bool['aa_i'] = False
2860
+ perturb_bool['ai'] = True
2861
+ perturb_bool['ia'] = False
2862
+ perturb_bool['ii'] = True
2863
+ elif mode == 'minimal':
2864
+ perturb_bool['aa_a'] = True
2865
+ perturb_bool['aa_i'] = False
2866
+ perturb_bool['ai'] = False
2867
+ perturb_bool['ia'] = False
2868
+ perturb_bool['ii'] = False
2869
+ else:
2870
+ raise Exception('mode not recognized, valid modes are \'all\', \'non-intrusive\', \'minimal\'.')
2871
+
2872
+ for l in range(self.depth):
2873
+ funs_name = self.symbolic_fun[l].funs_name
2874
+ for j in range(self.width_out[l + 1]):
2875
+ for i in range(self.width_in[l]):
2876
+ out_array = list(np.array(self.symbolic_fun[l].funs_name)[j])
2877
+ in_array = list(np.array(self.symbolic_fun[l].funs_name)[:, i])
2878
+ out_active = len([i for i, x in enumerate(out_array) if x != "0"]) > 0
2879
+ in_active = len([i for i, x in enumerate(in_array) if x != "0"]) > 0
2880
+ dic = {True: 'a', False: 'i'}
2881
+ edge_type = dic[in_active] + dic[out_active]
2882
+
2883
+ if l < self.depth - 1 or mode != 'non-intrusive':
2884
+
2885
+ if edge_type == 'aa':
2886
+ if self.symbolic_fun[l].funs_name[j][i] == '0':
2887
+ edge_type += '_i'
2888
+ else:
2889
+ edge_type += '_a'
2890
+
2891
+ if perturb_bool[edge_type]:
2892
+ self.act_fun[l].mask.data[i][j] = mag
2893
+
2894
+ if l == self.depth - 1 and mode == 'non-intrusive':
2895
+ self.act_fun[l].mask.data[i][j] = torch.tensor(1.)
2896
+ self.act_fun[l].scale_base.data[i][j] = torch.tensor(0.)
2897
+ self.act_fun[l].scale_sp.data[i][j] = torch.tensor(0.)
2898
+
2899
+ self.get_act(self.cache_data)
2900
+
2901
+ self.log_history('perturb')
2902
+
2903
+ def module(self, start_layer, chain):
2904
+ """
2905
+ specify network modules
2906
+
2907
+ Args:
2908
+ -----
2909
+ start_layer : int
2910
+ the earliest layer of the module
2911
+ chain : str
2912
+ specify neurons in the module
2913
+
2914
+ Returns:
2915
+ --------
2916
+ None
2917
+ """
2918
+ #chain = '[-1]->[-1,-2]->[-1]->[-1]'
2919
+ groups = chain.split('->')
2920
+ n_total_layers = len(groups) // 2
2921
+ #start_layer = 0
2922
+
2923
+ for l in range(n_total_layers):
2924
+ current_layer = cl = start_layer + l
2925
+ id_in = [int(i) for i in groups[2 * l][1:-1].split(',')]
2926
+ id_out = [int(i) for i in groups[2 * l + 1][1:-1].split(',')]
2927
+
2928
+ in_dim = self.width_in[cl]
2929
+ out_dim = self.width_out[cl + 1]
2930
+ id_in_other = list(set(range(in_dim)) - set(id_in))
2931
+ id_out_other = list(set(range(out_dim)) - set(id_out))
2932
+ self.act_fun[cl].mask.data[np.ix_(id_in_other, id_out)] = 0.
2933
+ self.act_fun[cl].mask.data[np.ix_(id_in, id_out_other)] = 0.
2934
+ self.symbolic_fun[cl].mask.data[np.ix_(id_out, id_in_other)] = 0.
2935
+ self.symbolic_fun[cl].mask.data[np.ix_(id_out_other, id_in)] = 0.
2936
+
2937
+ self.log_history('module')
2938
+
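+ # Illustrative sketch: the chain alternates input-node ids and output-subnode ids for each
+ # layer, so '[-1]->[-1,-2]->[-1]->[-1]' spans two layers starting at start_layer; edges
+ # between module and non-module neurons are masked to zero in both the spline and
+ # symbolic branches:
+ # >>> model.module(0, '[-1]->[-1,-2]->[-1]->[-1]')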
2939
+ def tree(self, x=None, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False):
2940
+ """
2941
+ turn KAN into a tree
2942
+ """
2943
+ if x is None:
2944
+ x = self.cache_data
2945
+ plot_tree(self, x, in_var=in_var, style=style, sym_th=sym_th, sep_th=sep_th, skip_sep_test=skip_sep_test,
2946
+ verbose=verbose)
2947
+
2948
+ def speed(self, compile=False):
2949
+ """
2950
+ turn on KAN's speed mode
2951
+ """
2952
+ self.symbolic_enabled = False
2953
+ self.save_act = False
2954
+ self.auto_save = False
2955
+ if compile:
2956
+ return torch.compile(self)
2957
+ else:
2958
+ return self
2959
+
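+ # Sketch: call speed() once interpretability hooks are no longer needed; it disables the
+ # symbolic branch, activation caching and auto-save, and can optionally return a compiled model:
+ # >>> fast_model = model.speed(compile=True)  # wraps the model with torch.compile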
2960
+ def get_act(self, x=None):
2961
+ """
2962
+ collect intermediate activations
2963
+ """
2964
+ if isinstance(x, dict):
2965
+ x = x['train_input']
2966
+ if x is None:
2967
+ if self.cache_data is not None:
2968
+ x = self.cache_data
2969
+ else:
2970
+ raise Exception("missing input data x")
2971
+ save_act = self.save_act
2972
+ self.save_act = True
2973
+ self.forward(x)
2974
+ self.save_act = save_act
2975
+
2976
+ def get_fun(self, l, i, j):
2977
+ """
2978
+ get function (l,i,j)
2979
+ """
2980
+ inputs = self.spline_preacts[l][:, j, i].cpu().detach().numpy()
2981
+ outputs = self.spline_postacts[l][:, j, i].cpu().detach().numpy()
2982
+ # they are not ordered yet
2983
+ rank = np.argsort(inputs)
2984
+ inputs = inputs[rank]
2985
+ outputs = outputs[rank]
2986
+ plt.figure(figsize=(3, 3))
2987
+ plt.plot(inputs, outputs, marker="o")
2988
+ return inputs, outputs
2989
+
2990
+ def history(self, k='all'):
2991
+ """
2992
+ get history
2993
+ """
2994
+ with open(self.ckpt_path + '/history.txt', 'r') as f:
2995
+ data = f.readlines()
2996
+ n_line = len(data)
2997
+ if k == 'all':
2998
+ k = n_line
2999
+
3000
+ data = data[-k:]
3001
+ for line in data:
3002
+ print(line[:-1])
3003
+
3004
+ @property
3005
+ def n_edge(self):
3006
+ """
3007
+ the number of active edges
3008
+ """
3009
+ depth = len(self.act_fun)
3010
+ complexity = 0
3011
+ for l in range(depth):
3012
+ complexity += torch.sum(self.act_fun[l].mask > 0.)
3013
+ return complexity.item()
3014
+
3015
+ def evaluate(self, dataset):
3016
+ evaluation = {'test_loss': torch.sqrt(
3017
+ torch.mean((self.forward(dataset['test_input']) - dataset['test_label']) ** 2)).item(),
3018
+ 'n_edge': self.n_edge, 'n_grid': self.grid}
3019
+ # add other metrics (maybe accuracy)
3020
+ return evaluation
3021
+
3022
+ def swap(self, l, i1, i2, log_history=True):
3023
+
3024
+ self.act_fun[l - 1].swap(i1, i2, mode='out')
3025
+ self.symbolic_fun[l - 1].swap(i1, i2, mode='out')
3026
+ self.act_fun[l].swap(i1, i2, mode='in')
3027
+ self.symbolic_fun[l].swap(i1, i2, mode='in')
3028
+
3029
+ def swap_(data, i1, i2):
3030
+ data[i1], data[i2] = data[i2], data[i1]
3031
+
3032
+ swap_(self.node_scale[l - 1].data, i1, i2)
3033
+ swap_(self.node_bias[l - 1].data, i1, i2)
3034
+ swap_(self.subnode_scale[l - 1].data, i1, i2)
3035
+ swap_(self.subnode_bias[l - 1].data, i1, i2)
3036
+
3037
+ if log_history:
3038
+ self.log_history('swap')
3039
+
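+ # connection_cost (below) places the neurons of adjacent layers at evenly spaced
+ # coordinates in [0, 1] and sums |x_in - x_out| weighted by edge attribution scores, so
+ # strongly attributed long-range connections are expensive; auto_swap reorders neurons
+ # to reduce this cost.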
3040
+ @property
3041
+ def connection_cost(self):
3042
+
3043
+ cc = 0.
3044
+ for t in self.edge_scores:
3045
+ def get_coordinate(n):
3046
+ return torch.linspace(0, 1, steps=n + 1, device=self.device)[:n] + 1 / (2 * n)
3047
+
3048
+ in_dim = t.shape[0]
3049
+ x_in = get_coordinate(in_dim)
3050
+
3051
+ out_dim = t.shape[1]
3052
+ x_out = get_coordinate(out_dim)
3053
+
3054
+ dist = torch.abs(x_in[:, None] - x_out[None, :])
3055
+ cc += torch.sum(dist * t)
3056
+
3057
+ return cc
3058
+
3059
+ def auto_swap_l(self, l):
3060
+
3061
+ num = self.width_in[l]
3062
+ for i in range(num):
3063
+ ccs = []
3064
+ for j in range(num):
3065
+ self.swap(l, i, j, log_history=False)
3066
+ self.get_act()
3067
+ self.attribute()
3068
+ cc = self.connection_cost.detach().clone()
3069
+ ccs.append(cc)
3070
+ self.swap(l, i, j, log_history=False)
3071
+ j = torch.argmin(torch.tensor(ccs))
3072
+ self.swap(l, i, j, log_history=False)
3073
+
3074
+ def auto_swap(self):
3075
+ """
3076
+ automatically swap neurons such that connection cost is minimized
3077
+ """
3078
+ depth = self.depth
3079
+ for l in range(1, depth):
3080
+ self.auto_swap_l(l)
3081
+
3082
+ self.log_history('auto_swap')
3083
+
3084
+
3085
+ KAN = MultKAN
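+ # public alias: `KAN` constructs a MultKAN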