yms-kan 0.0.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3087 @@
1
+ import math
2
+ import sys
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import numpy as np
7
+ from PIL import Image
8
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
9
+
10
+ from .KANLayer import KANLayer
11
+ #from .Symbolic_MultKANLayer import *
12
+ from .Symbolic_KANLayer import Symbolic_KANLayer
13
+ from .LBFGS import *
14
+ import os
15
+ import glob
16
+ import matplotlib.pyplot as plt
17
+ from tqdm import tqdm
18
+ import random
19
+ import copy
20
+ #from .MultKANLayer import MultKANLayer
21
+ import pandas as pd
22
+ from sympy.printing import latex
23
+ from sympy import *
24
+ import sympy
25
+ import yaml
26
+ from .spline import curve2coef
27
+ from .utils import SYMBOLIC_LIB
28
+ from .hypothesis import plot_tree
29
+ from contextlib import contextmanager
30
+ import shutil
31
+ import gc
32
+ from matplotlib.offsetbox import OffsetImage, AnnotationBbox
33
+
34
+
35
+ class MultKAN(nn.Module):
36
+ """
37
+ KAN class
38
+
39
+ Attributes:
40
+ -----------
41
+ grid : int
42
+ the number of grid intervals
43
+ k : int
44
+ spline order
45
+ act_fun : a list of KANLayers
46
+ symbolic_fun: a list of Symbolic_KANLayer
47
+ depth : int
48
+ depth of KAN
49
+ width : list
50
+ number of neurons in each layer.
51
+ Without multiplication nodes, [2,5,5,3] means 2D inputs, 3D outputs, with 2 layers of 5 hidden neurons.
52
+ With multiplication nodes, [2,[5,3],[5,1],3] means that, in addition to the [2,5,5,3] KAN, there are 3 multiplication nodes in layer 1 and 1 in layer 2.
53
+ mult_arity : int, or list of int lists
54
+ multiplication arity for each multiplication node (the number of numbers to be multiplied)
59
+ base_fun : fun
60
+ residual function b(x); the activation is phi(x) = sb_scale * b(x) + sp_scale * spline(x)
63
+ symbolic_enabled : bool
64
+ If False, the symbolic front is not computed (to save time). Default: True.
65
+ width_in : list
66
+ The number of input neurons for each layer
67
+ width_out : list
68
+ The number of output neurons for each layer
69
+ base_fun_name : str
70
+ The base function b(x)
71
+ grid_eps : float
72
+ The parameter that interpolates between uniform grid and adaptive grid (based on sample quantile)
73
+ node_bias : a list of 1D torch.float
74
+ node_scale : a list of 1D torch.float
75
+ subnode_bias : a list of 1D torch.float
76
+ subnode_scale : a list of 1D torch.float
79
+ affine_trainable : bool
80
+ indicate whether affine parameters are trainable (node_bias, node_scale, subnode_bias, subnode_scale)
81
+ sp_trainable : bool
82
+ indicate whether the overall magnitude of splines is trainable
83
+ sb_trainable : bool
84
+ indicate whether the overall magnitude of base function is trainable
85
+ save_act : bool
86
+ indicate whether intermediate activations are saved in forward pass
87
+ node_scores : None or list of 1D torch.float
88
+ node attribution score
89
+ edge_scores : None or list of 2D torch.float
90
+ edge attribution score
91
+ subnode_scores : None or list of 1D torch.float
92
+ subnode attribution score
93
+ cache_data : None or 2D torch.float
94
+ cached input data
95
+ acts : None or a list of 2D torch.float
96
+ activations on nodes
97
+ auto_save : bool
98
+ indicate whether to automatically save a checkpoint once the model is modified
99
+ state_id : int
100
+ the state of the model (used to save checkpoint)
101
+ ckpt_path : str
102
+ the folder to store checkpoints
103
+ round : int
104
+ the number of times rewind() has been called
105
+ device : str
+ the device the model is on, e.g. 'cpu' or 'cuda'
106
+ """
107
+
108
+ def __init__(self, width=None, grid=3, k=3, mult_arity=2, noise_scale=0.3, scale_base_mu=0.0, scale_base_sigma=1.0,
109
+ base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=0.02, grid_range=[-1, 1],
110
+ sp_trainable=True, sb_trainable=True, seed=1, save_act=True, sparse_init=False, auto_save=True,
111
+ first_init=True, ckpt_path='./model', state_id=0, round=0, device='cpu'):
112
+ """
113
+ initialize a KAN model
114
+
115
+ Args:
116
+ -----
117
+ width : list of int
118
+ Without multiplication nodes: :math:`[n_0, n_1, .., n_{L-1}]` specify the number of neurons in each layer (including inputs/outputs)
119
+ With multiplication nodes: :math:`[[n_0,m_0=0], [n_1,m_1], .., [n_{L-1},m_{L-1}]]` specify the number of addition/multiplication nodes in each layer (including inputs/outputs)
120
+ grid : int
121
+ number of grid intervals. Default: 3.
122
+ k : int
123
+ order of piecewise polynomial. Default: 3.
124
+ mult_arity : int, or list of int lists
125
+ multiplication arity for each multiplication node (the number of numbers to be multiplied)
126
+ noise_scale : float
127
+ initial injected noise to spline.
128
+ base_fun : str
129
+ the residual function b(x). Default: 'silu'
130
+ symbolic_enabled : bool
131
+ compute (True) or skip (False) symbolic computations (for efficiency). By default: True.
132
+ affine_trainable : bool
133
+ affine parameters are updated or not. Affine parameters include node_scale, node_bias, subnode_scale, subnode_bias
134
+ grid_eps : float
135
+ When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
136
+ grid_range : list/np.array of shape (2,)
137
+ the range of grids. Default: [-1,1]. This argument matters little if fit(update_grid=True) (the default), since the grid then adapts to the data
138
+ sp_trainable : bool
139
+ If true, scale_sp is trainable. Default: True.
140
+ sb_trainable : bool
141
+ If true, scale_base is trainable. Default: True.
142
+ device : str
143
+ computation device, e.g. 'cpu' or 'cuda'
144
+ seed : int
145
+ random seed
146
+ save_act : bool
147
+ indicate whether intermediate activations are saved in forward pass
148
+ sparse_init : bool
149
+ sparse initialization (True) or normal dense initialization. Default: False.
150
+ auto_save : bool
151
+ indicate whether to automatically save a checkpoint once the model is modified
152
+ state_id : int
153
+ the state of the model (used to save checkpoint)
154
+ ckpt_path : str
155
+ the folder to store checkpoints. Default: './model'
156
+ round : int
157
+ the number of times rewind() has been called
159
+
160
+ Returns:
161
+ --------
162
+ self
163
+
164
+ Example
165
+ -------
166
+ >>> from kan import *
167
+ >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
168
+ checkpoint directory created: ./model
169
+ saving model version 0.0
170
+ """
171
+ super(MultKAN, self).__init__()
172
+
173
+ torch.manual_seed(seed)
174
+ np.random.seed(seed)
175
+ random.seed(seed)
176
+
177
+ ### initializing the numerical front ###
178
+
179
+ self.act_fun = []
180
+ self.depth = len(width) - 1
181
+
183
+ for i in range(len(width)):
185
+ if type(width[i]) == int or type(width[i]) == np.int64:
186
+ width[i] = [width[i], 0]
187
+
189
+
190
+ self.width = width
191
+
192
+ # if mult_arity is just a scalar, we extend it to a list of lists
193
+ # e.g, mult_arity = [[2,3],[4]] means that in the first hidden layer, 2 mult ops have arity 2 and 3, respectively;
194
+ # in the second hidden layer, 1 mult op has arity 4.
195
+ if isinstance(mult_arity, int):
196
+ self.mult_homo = True # when homo is True, parallelization is possible
197
+ else:
198
+ self.mult_homo = False # when homo is False, a for loop is required.
199
+ self.mult_arity = mult_arity
200
+
201
+ width_in = self.width_in
202
+ width_out = self.width_out
203
+
204
+ self.base_fun_name = base_fun
205
+ if base_fun == 'silu':
206
+ base_fun = torch.nn.SiLU()
207
+ elif base_fun == 'identity':
208
+ base_fun = torch.nn.Identity()
209
+ elif base_fun == 'zero':
210
+ base_fun = lambda x: x * 0.
211
+
212
+ self.grid_eps = grid_eps
213
+ self.grid_range = grid_range
214
+
215
+ for l in range(self.depth):
216
+ # splines
217
+ if isinstance(grid, list):
218
+ grid_l = grid[l]
219
+ else:
220
+ grid_l = grid
221
+
222
+ if isinstance(k, list):
223
+ k_l = k[l]
224
+ else:
225
+ k_l = k
226
+
227
+ sp_batch = KANLayer(in_dim=width_in[l], out_dim=width_out[l + 1], num=grid_l, k=k_l,
228
+ noise_scale=noise_scale, scale_base_mu=scale_base_mu, scale_base_sigma=scale_base_sigma,
229
+ scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range,
230
+ sp_trainable=sp_trainable, sb_trainable=sb_trainable, sparse_init=sparse_init)
231
+ self.act_fun.append(sp_batch)
232
+
233
+ self.node_bias = []
234
+ self.node_scale = []
235
+ self.subnode_bias = []
236
+ self.subnode_scale = []
237
+
240
+
241
+ for l in range(self.depth):
242
+ exec(
243
+ f'self.node_bias_{l} = torch.nn.Parameter(torch.zeros(width_in[l+1])).requires_grad_(affine_trainable)')
244
+ exec(
245
+ f'self.node_scale_{l} = torch.nn.Parameter(torch.ones(width_in[l+1])).requires_grad_(affine_trainable)')
246
+ exec(
247
+ f'self.subnode_bias_{l} = torch.nn.Parameter(torch.zeros(width_out[l+1])).requires_grad_(affine_trainable)')
248
+ exec(
249
+ f'self.subnode_scale_{l} = torch.nn.Parameter(torch.ones(width_out[l+1])).requires_grad_(affine_trainable)')
250
+ exec(f'self.node_bias.append(self.node_bias_{l})')
251
+ exec(f'self.node_scale.append(self.node_scale_{l})')
252
+ exec(f'self.subnode_bias.append(self.subnode_bias_{l})')
253
+ exec(f'self.subnode_scale.append(self.subnode_scale_{l})')
254
+
255
+ self.act_fun = nn.ModuleList(self.act_fun)
256
+
257
+ self.grid = grid
258
+ self.k = k
259
+ self.base_fun = base_fun
260
+
261
+ ### initializing the symbolic front ###
262
+ self.symbolic_fun = []
263
+ for l in range(self.depth):
264
+ sb_batch = Symbolic_KANLayer(in_dim=width_in[l], out_dim=width_out[l + 1])
265
+ self.symbolic_fun.append(sb_batch)
266
+
267
+ self.symbolic_fun = nn.ModuleList(self.symbolic_fun)
268
+ self.symbolic_enabled = symbolic_enabled
269
+ self.affine_trainable = affine_trainable
270
+ self.sp_trainable = sp_trainable
271
+ self.sb_trainable = sb_trainable
272
+
273
+ self.save_act = save_act
274
+
275
+ self.node_scores = None
276
+ self.edge_scores = None
277
+ self.subnode_scores = None
278
+
279
+ self.cache_data = None
280
+ self.acts = None
281
+
282
+ self.auto_save = auto_save
283
+ self.state_id = 0
284
+ self.ckpt_path = ckpt_path
285
+ self.round = round
286
+
287
+ self.device = device
288
+ self.to(device)
289
+
290
+ if auto_save:
291
+ if first_init:
292
+ if not os.path.exists(ckpt_path):
293
+ # Create the directory
294
+ os.makedirs(ckpt_path)
295
+ print(f"checkpoint directory created: {ckpt_path}")
296
+ print('saving model version 0.0')
297
+
298
+ history_path = self.ckpt_path + '/history.txt'
299
+ with open(history_path, 'w') as file:
300
+ file.write(f'### Round {self.round} ###' + '\n')
301
+ file.write('init => 0.0' + '\n')
302
+ self.saveckpt(path=self.ckpt_path + '/' + '0.0')
303
+ else:
304
+ self.state_id = state_id
305
+
306
+ self.input_id = torch.arange(self.width_in[0])
307
+
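# A minimal construction sketch (illustrative, not from the released source); it
# assumes MultKAN is importable from this module and shows how a mixed width spec
# combines addition and multiplication nodes.
import torch
model = MultKAN(width=[2, [5, 2], 1], grid=3, k=3, mult_arity=2, auto_save=False)
x = torch.rand(100, 2)
y = model(x)               # output shape: (100, 1)
print(model.width_in)      # [2, 7, 1]  -> 5 sum nodes + 2 mult nodes in the hidden layer
print(model.width_out)     # [2, 9, 1]  -> each mult node consumes mult_arity=2 subnodes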
308
+ def to(self, device):
309
+ '''
310
+ move the model to device
311
+
312
+ Args:
313
+ -----
314
+ device : str or device
315
+
316
+ Returns:
317
+ --------
318
+ self
319
+
320
+ Example
321
+ -------
322
+ # >>> from kan import *
323
+ # >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
324
+ # >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
325
+ # >>> model.to(device)
326
+ '''
327
+ super(MultKAN, self).to(device)
328
+ self.device = device
329
+
330
+ for kanlayer in self.act_fun:
331
+ kanlayer.to(device)
332
+
333
+ for symbolic_kanlayer in self.symbolic_fun:
334
+ symbolic_kanlayer.to(device)
335
+
336
+ return self
337
+
338
+ @property
339
+ def width_in(self):
340
+ '''
341
+ The number of input nodes for each layer
342
+ '''
343
+ width = self.width
344
+ width_in = [width[l][0] + width[l][1] for l in range(len(width))]
345
+ return width_in
346
+
347
+ @property
348
+ def width_out(self):
349
+ '''
350
+ The number of output subnodes for each layer
351
+ '''
352
+ width = self.width
353
+ if self.mult_homo == True:
354
+ width_out = [width[l][0] + self.mult_arity * width[l][1] for l in range(len(width))]
355
+ else:
356
+ width_out = [width[l][0] + int(np.sum(self.mult_arity[l])) for l in range(len(width))]
357
+ return width_out
358
+
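# Illustrative worked example (not from the released source) of the two width views.
# For width=[2, [1, 2], 1] with per-node arities mult_arity=[[], [2, 3], []]:
#   width_out[1] = 1 + (2 + 3) = 6   # subnodes entering the multiplications
#   width_in[1]  = 1 + 2       = 3   # nodes left after the multiplications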
359
+ @property
360
+ def n_sum(self):
361
+ """
362
+ The number of addition nodes for each hidden layer
363
+ """
364
+ width = self.width
365
+ n_sum = [width[l][0] for l in range(1, len(width) - 1)]
366
+ return n_sum
367
+
368
+ @property
369
+ def n_mult(self):
370
+ """
371
+ The number of multiplication nodes for each hidden layer
372
+ """
373
+ width = self.width
374
+ n_mult = [width[l][1] for l in range(1, len(width) - 1)]
375
+ return n_mult
376
+
377
+ @property
378
+ def feature_score(self):
379
+ '''
380
+ attribution scores for inputs
381
+ '''
382
+ self.attribute()
383
+ if self.node_scores is None:
384
+ return None
385
+ else:
386
+ return self.node_scores[0]
387
+
388
+ def initialize_from_another_model(self, another_model, x):
389
+ '''
390
+ initialize from another model of the same width; their 'grid' parameters may differ.
390
+ Note: this is equivalent to refine() when another_model does not need to be kept.
392
+
393
+ Args:
394
+ -----
395
+ another_model : MultKAN
396
+ x : 2D torch.float
397
+
398
+ Returns:
399
+ --------
400
+ self
401
+
402
+ Example
403
+ -------
404
+ # >>> from kan import *
405
+ # >>> model1 = KAN(width=[2,5,1], grid=3)
406
+ # >>> model2 = KAN(width=[2,5,1], grid=10)
407
+ # >>> x = torch.rand(100,2)
408
+ # >>> model2.initialize_from_another_model(model1, x)
409
+ '''
410
+ another_model(x) # get activations
411
+ batch = x.shape[0]
412
+
413
+ self.initialize_grid_from_another_model(another_model, x)
414
+
415
+ for l in range(self.depth):
416
+ spb = self.act_fun[l]
417
+ #spb_parent = another_model.act_fun[l]
418
+
419
+ # spb = spb_parent
420
+ preacts = another_model.spline_preacts[l]
421
+ postsplines = another_model.spline_postsplines[l]
422
+ self.act_fun[l].coef.data = curve2coef(preacts[:, 0, :], postsplines.permute(0, 2, 1), spb.grid, k=spb.k)
423
+ self.act_fun[l].scale_base.data = another_model.act_fun[l].scale_base.data
424
+ self.act_fun[l].scale_sp.data = another_model.act_fun[l].scale_sp.data
425
+ self.act_fun[l].mask.data = another_model.act_fun[l].mask.data
426
+
427
+ for l in range(self.depth):
428
+ self.node_bias[l].data = another_model.node_bias[l].data
429
+ self.node_scale[l].data = another_model.node_scale[l].data
430
+
431
+ self.subnode_bias[l].data = another_model.subnode_bias[l].data
432
+ self.subnode_scale[l].data = another_model.subnode_scale[l].data
433
+
434
+ for l in range(self.depth):
435
+ self.symbolic_fun[l] = another_model.symbolic_fun[l]
436
+
437
+ return self.to(self.device)
438
+
439
+ def log_history(self, method_name):
440
+
441
+ if self.auto_save:
442
+ # save to log file
444
+ with open(self.ckpt_path + '/history.txt', 'a') as file:
445
+ file.write(str(self.round) + '.' + str(self.state_id) + ' => ' + method_name + ' => ' + str(
446
+ self.round) + '.' + str(self.state_id + 1) + '\n')
447
+
448
+ # update state_id
449
+ self.state_id += 1
450
+
451
+ # save to ckpt
452
+ self.saveckpt(path=self.ckpt_path + '/' + str(self.round) + '.' + str(self.state_id))
453
+ print('saving model version ' + str(self.round) + '.' + str(self.state_id))
454
+
455
+ def refine(self, new_grid):
456
+ '''
457
+ grid refinement
458
+
459
+ Args:
460
+ -----
461
+ new_grid : int
462
+ the number of grid intervals after refinement
463
+
464
+ Returns:
465
+ --------
466
+ a refined model : MultKAN
467
+
468
+ Example
469
+ -------
470
+ >>> from kan import *
471
+ >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
472
+ >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
473
+ >>> print(model.grid)
474
+ >>> x = torch.rand(100,2)
475
+ >>> model.get_act(x)
476
+ >>> model = model.refine(10)
477
+ >>> print(model.grid)
478
+ checkpoint directory created: ./model
479
+ saving model version 0.0
480
+ 5
481
+ saving model version 0.1
482
+ 10
483
+ '''
484
+
485
+ model_new = MultKAN(width=self.width,
486
+ grid=new_grid,
487
+ k=self.k,
488
+ mult_arity=self.mult_arity,
489
+ base_fun=self.base_fun_name,
490
+ symbolic_enabled=self.symbolic_enabled,
491
+ affine_trainable=self.affine_trainable,
492
+ grid_eps=self.grid_eps,
493
+ grid_range=self.grid_range,
494
+ sp_trainable=self.sp_trainable,
495
+ sb_trainable=self.sb_trainable,
496
+ ckpt_path=self.ckpt_path,
497
+ auto_save=True,
498
+ first_init=False,
499
+ state_id=self.state_id,
500
+ round=self.round,
501
+ device=self.device)
502
+
503
+ model_new.initialize_from_another_model(self, self.cache_data)
504
+ model_new.cache_data = self.cache_data
505
+ model_new.grid = new_grid
506
+
507
+ self.log_history('refine')
508
+ model_new.state_id += 1
509
+
510
+ return model_new.to(self.device)
511
+
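# A usage sketch (illustrative, not from the released source): refine() re-fits the
# splines on a finer grid using the cached data, so a forward pass must happen first
# to populate cache_data.
import torch
model = MultKAN(width=[2, 5, 1], grid=5, k=3, seed=0, auto_save=False)
model(torch.rand(100, 2))    # caches the data refine() needs
model = model.refine(10)     # same functions, now on 10 grid intervals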
512
+ def saveckpt(self, path='model'):
513
+ '''
514
+ save the current model to files (configuration file and state file)
515
+
516
+ Args:
517
+ -----
518
+ path : str
519
+ the path where checkpoints are saved
520
+
521
+ Returns:
522
+ --------
523
+ None
524
+
525
+ Example
526
+ -------
527
+ # >>> from kan import *
528
+ # >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
529
+ # >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
530
+ # >>> model.saveckpt('./mark')
531
+ # There will be three files appearing in the current folder: mark_cache_data, mark_config.yml, mark_state
532
+ '''
533
+
534
+ model = self
535
+
536
+ dic = dict(
537
+ width=model.width,
538
+ grid=model.grid,
539
+ k=model.k,
540
+ mult_arity=model.mult_arity,
541
+ base_fun_name=model.base_fun_name,
542
+ symbolic_enabled=model.symbolic_enabled,
543
+ affine_trainable=model.affine_trainable,
544
+ grid_eps=model.grid_eps,
545
+ grid_range=model.grid_range,
546
+ sp_trainable=model.sp_trainable,
547
+ sb_trainable=model.sb_trainable,
548
+ state_id=model.state_id,
549
+ auto_save=model.auto_save,
550
+ ckpt_path=model.ckpt_path,
551
+ round=model.round,
552
+ device=str(model.device)
553
+ )
554
+
555
+ if dic["device"].isdigit():
556
+ dic["device"] = int(model.device)
557
+
558
+ for i in range(model.depth):
559
+ dic[f'symbolic.funs_name.{i}'] = model.symbolic_fun[i].funs_name
560
+
561
+ with open(f'{path}_config.yml', 'w') as outfile:
562
+ yaml.dump(dic, outfile, default_flow_style=False)
563
+
564
+ torch.save(model.state_dict(), f'{path}_state')
565
+ torch.save(model.cache_data, f'{path}_cache_data')
566
+
567
+ @staticmethod
568
+ def loadckpt(path='model'):
569
+ '''
570
+ load checkpoint from path
571
+
572
+ Args:
573
+ -----
574
+ path : str
575
+ the path where checkpoints are saved
576
+
577
+ Returns:
578
+ --------
579
+ MultKAN
580
+
581
+ Example
582
+ -------
583
+ # >>> from kan import *
584
+ # >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
585
+ # >>> model.saveckpt('./mark')
586
+ # >>> KAN.loadckpt('./mark')
587
+ '''
588
+ with open(f'{path}_config.yml', 'r') as stream:
589
+ config = yaml.safe_load(stream)
590
+
591
+ state = torch.load(f'{path}_state', weights_only=False)
592
+
593
+ model_load = MultKAN(width=config['width'],
594
+ grid=config['grid'],
595
+ k=config['k'],
596
+ mult_arity=config['mult_arity'],
597
+ base_fun=config['base_fun_name'],
598
+ symbolic_enabled=config['symbolic_enabled'],
599
+ affine_trainable=config['affine_trainable'],
600
+ grid_eps=config['grid_eps'],
601
+ grid_range=config['grid_range'],
602
+ sp_trainable=config['sp_trainable'],
603
+ sb_trainable=config['sb_trainable'],
604
+ state_id=config['state_id'],
605
+ auto_save=config['auto_save'],
606
+ first_init=False,
607
+ ckpt_path=config['ckpt_path'],
608
+ round=config['round'] + 1,
609
+ device=config['device'])
610
+
611
+ model_load.load_state_dict(state)
612
+ model_load.cache_data = torch.load(f'{path}_cache_data')
613
+
614
+ depth = len(model_load.width) - 1
615
+ for l in range(depth):
616
+ out_dim = model_load.symbolic_fun[l].out_dim
617
+ in_dim = model_load.symbolic_fun[l].in_dim
618
+ funs_name = config[f'symbolic.funs_name.{l}']
619
+ for j in range(out_dim):
620
+ for i in range(in_dim):
621
+ fun_name = funs_name[j][i]
622
+ model_load.symbolic_fun[l].funs_name[j][i] = fun_name
623
+ model_load.symbolic_fun[l].funs[j][i] = SYMBOLIC_LIB[fun_name][0]
624
+ model_load.symbolic_fun[l].funs_sympy[j][i] = SYMBOLIC_LIB[fun_name][1]
625
+ model_load.symbolic_fun[l].funs_avoid_singularity[j][i] = SYMBOLIC_LIB[fun_name][3]
626
+ return model_load
627
+
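# A round-trip sketch (illustrative, not from the released source): saveckpt('./mark')
# writes mark_config.yml, mark_state and mark_cache_data; loadckpt rebuilds the model
# from those three files (with `round` bumped by 1).
import torch
model = MultKAN(width=[2, 5, 1], grid=5, k=3, seed=0, auto_save=False)
model(torch.rand(100, 2))            # populate cache_data so it gets saved
model.saveckpt('./mark')
model2 = MultKAN.loadckpt('./mark')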
628
+ def copy(self):
629
+ """
630
+ deepcopy
631
+
632
+ Args:
633
+ -----
634
+ None
636
+
637
+ Returns:
638
+ --------
639
+ MultKAN
640
+
641
+ Example
642
+ -------
643
+ # >>> from kan import *
644
+ # >>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
645
+ # >>> model2 = model.copy()
646
+ # >>> model2.act_fun[0].coef.data *= 2
647
+ # >>> print(model2.act_fun[0].coef.data)
648
+ # >>> print(model.act_fun[0].coef.data)
649
+ """
650
+ path = 'copy_temp'
651
+ self.saveckpt(path)
652
+ return MultKAN.loadckpt(path)
653
+
654
+ def rewind(self, model_id):
655
+ """
656
+ rewind to an old version
657
+
658
+ Args:
659
+ -----
660
+ model_id : str
661
+ in format '{a}.{b}' where a is the round number, b is the version number in that round
662
+
663
+ Returns:
664
+ --------
665
+ MultKAN
666
+
667
+ Example
668
+ -------
669
+ Please refer to the tutorials: "API 12: Checkpoint, save & load model"
670
+ """
671
+ self.round += 1
672
+ self.state_id = int(model_id.split('.')[-1])
673
+
674
+ history_path = self.ckpt_path + '/history.txt'
675
+ with open(history_path, 'a') as file:
676
+ file.write(f'### Round {self.round} ###' + '\n')
677
+
678
+ self.saveckpt(path=self.ckpt_path + '/' + f'{self.round}.{self.state_id}')
679
+
680
+ print(
681
+ 'rewind to model version ' + f'{self.round - 1}.{self.state_id}' + ', renamed as ' + f'{self.round}.{self.state_id}')
682
+
683
+ return MultKAN.loadckpt(path=self.ckpt_path + '/' + str(model_id))
684
+
685
+ def checkout(self, model_id):
686
+ """
687
+ check out an old version
688
+
689
+ Args:
690
+ -----
691
+ model_id : str
692
+ in format '{a}.{b}' where a is the round number, b is the version number in that round
693
+
694
+ Returns:
695
+ --------
696
+ MultKAN
697
+
698
+ Example
699
+ -------
700
+ Same usage as rewind(), except that checkout() does not start a new round or save a new version
701
+ """
702
+ return MultKAN.loadckpt(path=self.ckpt_path + '/' + str(model_id))
703
+
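# Illustrative note (not from the released source): version ids come from
# ckpt_path/history.txt, where each state-changing call appends a line like
# '0.1 => fit => 0.2'. checkout() reloads a version without touching history,
# while rewind() also starts a new round. The version id below is hypothetical:
# old = model.checkout('0.1')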
704
+ def update_grid_from_samples(self, x):
705
+ """
706
+ update grid from samples
707
+
708
+ Args:
709
+ -----
710
+ x : 2D torch.tensor
711
+ inputs
712
+
713
+ Returns:
714
+ --------
715
+ None
716
+
717
+ Example
718
+ -------
719
+ # >>> from kan import *
720
+ # >>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
721
+ # >>> print(model.act_fun[0].grid)
722
+ # >>> x = torch.linspace(-10,10,steps=101)[:,None]
723
+ # >>> model.update_grid_from_samples(x)
724
+ # >>> print(model.act_fun[0].grid)
725
+ """
726
+ for l in range(self.depth):
727
+ self.get_act(x)
728
+ self.act_fun[l].update_grid_from_samples(self.acts[l])
729
+
730
+ def update_grid(self, x):
731
+ """
732
+ call update_grid_from_samples. This seems unnecessary but we retain it for the sake of classes that might inherit from MultKAN
733
+ """
734
+ self.update_grid_from_samples(x)
735
+
736
+ def initialize_grid_from_another_model(self, model, x):
737
+ """
738
+ initialize grid from another model
739
+
740
+ Args:
741
+ -----
742
+ model : MultKAN
743
+ parent model
744
+ x : 2D torch.tensor
745
+ inputs
746
+
747
+ Returns:
748
+ --------
749
+ None
750
+
751
+ Example
752
+ -------
753
+ # >>> from kan import *
754
+ # >>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
755
+ # >>> print(model.act_fun[0].grid)
756
+ # >>> x = torch.linspace(-10,10,steps=101)[:,None]
757
+ # >>> model2 = KAN(width=[1,1], grid=10, k=3, seed=0)
758
+ # >>> model2.initialize_grid_from_another_model(model, x)
759
+ # >>> print(model2.act_fun[0].grid)
760
+ """
761
+ model(x)
762
+ for l in range(self.depth):
763
+ self.act_fun[l].initialize_grid_from_parent(model.act_fun[l], model.acts[l])
764
+
765
+ def forward(self, x, singularity_avoiding=False, y_th=10.):
766
+ '''
767
+ forward pass
768
+
769
+ Args:
770
+ -----
771
+ x : 2D torch.tensor
772
+ inputs
773
+ singularity_avoiding : bool
774
+ whether to avoid singularity for the symbolic branch
775
+ y_th : float
776
+ the threshold for singularity
777
+
778
+ Returns:
779
+ --------
780
+ 2D torch.float
+ the model output, of shape (batch, width_in[-1])
781
+
782
+ Example1
783
+ --------
784
+ # >>> from kan import *
785
+ # >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
786
+ # >>> x = torch.rand(100,2)
787
+ # >>> model(x).shape
788
+
789
+ Example2
790
+ --------
791
+ # >>> from kan import *
792
+ # >>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
793
+ # >>> x = torch.tensor([[1],[-0.01]])
794
+ # >>> model.fix_symbolic(0,0,0,'log',fit_params_bool=False)
795
+ # >>> print(model(x))
796
+ # >>> print(model(x, singularity_avoiding=True))
797
+ # >>> print(model(x, singularity_avoiding=True, y_th=1.))
798
+ '''
799
+ x = x[:, self.input_id.long()]
800
+ assert x.shape[1] == self.width_in[0]
801
+
802
+ # cache data
803
+ self.cache_data = x
804
+
805
+ self.acts = [] # shape ([batch, n0], [batch, n1], ..., [batch, n_L])
806
+ self.acts_premult = []
807
+ self.spline_preacts = []
808
+ self.spline_postsplines = []
809
+ self.spline_postacts = []
810
+ self.acts_scale = []
811
+ self.acts_scale_spline = []
812
+ self.subnode_actscale = []
813
+ self.edge_actscale = []
814
+ # self.neurons_scale = []
815
+
816
+ self.acts.append(x) # acts shape: (batch, width[l])
817
+
818
+ for l in range(self.depth):
819
+
820
+ x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
822
+
823
+ if self.symbolic_enabled == True:
824
+ x_symbolic, postacts_symbolic = self.symbolic_fun[l](x, singularity_avoiding=singularity_avoiding,
825
+ y_th=y_th)
826
+ else:
827
+ x_symbolic = 0.
828
+ postacts_symbolic = 0.
829
+
830
+ x = x_numerical + x_symbolic
831
+
832
+ if self.save_act:
833
+ # save subnode_scale
834
+ self.subnode_actscale.append(torch.std(x, dim=0).detach())
835
+
836
+ # subnode affine transform
837
+ x = self.subnode_scale[l][None, :] * x + self.subnode_bias[l][None, :]
838
+
839
+ if self.save_act:
840
+ postacts = postacts_numerical + postacts_symbolic
841
+
842
+ # self.neurons_scale.append(torch.mean(torch.abs(x), dim=0))
843
+ #grid_reshape = self.act_fun[l].grid.reshape(self.width_out[l + 1], self.width_in[l], -1)
844
+ input_range = torch.std(preacts, dim=0) + 0.1
845
+ output_range_spline = torch.std(postacts_numerical,
846
+ dim=0) # for training, only penalize the spline part
847
+ output_range = torch.std(postacts,
848
+ dim=0) # for visualization, include the contribution from both spline + symbolic
849
+ # save edge_scale
850
+ self.edge_actscale.append(output_range)
851
+
852
+ self.acts_scale.append((output_range / input_range).detach())
853
+ self.acts_scale_spline.append(output_range_spline / input_range)
854
+ self.spline_preacts.append(preacts.detach())
855
+ self.spline_postacts.append(postacts.detach())
856
+ self.spline_postsplines.append(postspline.detach())
857
+
858
+ self.acts_premult.append(x.detach())
859
+
860
+ # multiplication
861
+ dim_sum = self.width[l + 1][0]
862
+ dim_mult = self.width[l + 1][1]
863
+
864
+ if self.mult_homo:
865
+ for i in range(self.mult_arity - 1):
866
+ if i == 0:
867
+ x_mult = x[:, dim_sum::self.mult_arity] * x[:, dim_sum + 1::self.mult_arity]
868
+ else:
869
+ x_mult = x_mult * x[:, dim_sum + i + 1::self.mult_arity]
870
+
871
+ else:
872
+ for j in range(dim_mult):
873
+ acml_id = dim_sum + np.sum(self.mult_arity[l + 1][:j])
874
+ for i in range(self.mult_arity[l + 1][j] - 1):
875
+ if i == 0:
876
+ x_mult_j = x[:, [acml_id]] * x[:, [acml_id + 1]]
877
+ else:
878
+ x_mult_j = x_mult_j * x[:, [acml_id + i + 1]]
879
+
880
+ if j == 0:
881
+ x_mult = x_mult_j
882
+ else:
883
+ x_mult = torch.cat([x_mult, x_mult_j], dim=1)
884
+
885
+ if self.width[l + 1][1] > 0:
886
+ x = torch.cat([x[:, :dim_sum], x_mult], dim=1)
887
+
888
+ # x = x + self.biases[l].weight
889
+ # node affine transform
890
+ x = self.node_scale[l][None, :] * x + self.node_bias[l][None, :]
891
+
892
+ self.acts.append(x.detach())
893
+
894
+ return x
895
+
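# Illustrative note (not from the released source): the homogeneous-arity branch
# above multiplies all mult nodes in parallel via strided slicing. For dim_sum=2 and
# mult_arity=2, x[:, 2::2] * x[:, 3::2] multiplies subnode pairs (2,3), (4,5), ...
# in one shot; the heterogeneous branch falls back to a per-node loop.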
896
+ def set_mode(self, l, i, j, mode, mask_n=None):
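# Illustrative note (not from the released source): mode 's' keeps only the symbolic
# branch, 'n' only the numeric (spline) branch, 'sn'/'ns' keeps both, and any other
# string masks the (l,i,j) edge entirely.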
897
+ if mode == "s":
898
+ mask_n = 0.
899
+ mask_s = 1.
900
+ elif mode == "n":
901
+ mask_n = 1.
902
+ mask_s = 0.
903
+ elif mode == "sn" or mode == "ns":
904
+ if mask_n is None:
905
+ mask_n = 1.
908
+ mask_s = 1.
909
+ else:
910
+ mask_n = 0.
911
+ mask_s = 0.
912
+
913
+ self.act_fun[l].mask.data[i][j] = mask_n
914
+ self.symbolic_fun[l].mask.data[j, i] = mask_s
915
+
916
+ def fix_symbolic(self, l, i, j, fun_name, fit_params_bool=True, a_range=(-10, 10), b_range=(-10, 10), verbose=True,
917
+ random=False, log_history=True):
918
+ '''
919
+ set (l,i,j) activation to be symbolic (specified by fun_name)
920
+
921
+ Args:
922
+ -----
923
+ l : int
924
+ layer index
925
+ i : int
926
+ input neuron index
927
+ j : int
928
+ output neuron index
929
+ fun_name : str
930
+ function name
931
+ fit_params_bool : bool
932
+ obtaining affine parameters through fitting (True) or setting default values (False)
933
+ a_range : tuple
934
+ sweeping range of a
935
+ b_range : tuple
936
+ sweeping range of b
937
+ verbose : bool
938
+ If True, more information is printed.
939
+ random : bool
940
+ initialize affine parameters randomly (True) or as [1,0,1,0] (False)
941
+ log_history : bool
942
+ indicate whether to log history when the function is called
943
+
944
+ Returns:
945
+ --------
946
+ None or r2 (coefficient of determination)
947
+
948
+ Example 1
949
+ ---------
950
+ >>> # when fit_params_bool = False
951
+ >>> model = KAN(width=[2,5,1], grid=5, k=3)
952
+ >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=False)
953
+ >>> print(model.act_fun[0].mask.reshape(2,5))
954
+ >>> print(model.symbolic_fun[0].mask.reshape(2,5))
955
+
956
+ Example 2
957
+ ---------
958
+ >>> # when fit_params_bool = True
959
+ >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=1.)
960
+ >>> x = torch.normal(0,1,size=(100,2))
961
+ >>> model(x) # obtain activations (otherwise model does not have attributes acts)
962
+ >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=True)
963
+ >>> print(model.act_fun[0].mask.reshape(2,5))
964
+ >>> print(model.symbolic_fun[0].mask.reshape(2,5))
965
+ '''
966
+ if not fit_params_bool:
967
+ self.symbolic_fun[l].fix_symbolic(i, j, fun_name, verbose=verbose, random=random)
968
+ r2 = None
969
+ else:
970
+ x = self.acts[l][:, i]
971
+ mask = self.act_fun[l].mask
972
+ y = self.spline_postacts[l][:, j, i]
973
+ #y = self.postacts[l][:, j, i]
974
+ r2 = self.symbolic_fun[l].fix_symbolic(i, j, fun_name, x, y, a_range=a_range, b_range=b_range,
975
+ verbose=verbose)
976
+ if mask[i, j] == 0:
977
+ r2 = - 1e8
978
+ self.set_mode(l, i, j, mode="s")
979
+
980
+ if log_history:
981
+ self.log_history('fix_symbolic')
982
+ return r2
983
+
984
+ def unfix_symbolic(self, l, i, j, log_history=True):
985
+ '''
986
+ unfix the (l,i,j) activation function.
987
+ '''
988
+ self.set_mode(l, i, j, mode="n")
989
+ self.symbolic_fun[l].funs_name[j][i] = "0"
990
+ if log_history:
991
+ self.log_history('unfix_symbolic')
992
+
993
+ def unfix_symbolic_all(self, log_history=True):
994
+ '''
995
+ unfix all activation functions.
996
+ '''
997
+ for l in range(len(self.width) - 1):
998
+ for i in range(self.width_in[l]):
999
+ for j in range(self.width_out[l + 1]):
1000
+ self.unfix_symbolic(l, i, j, log_history)
1001
+
1002
+ def get_range(self, l, i, j, verbose=True):
1003
+ '''
1004
+ Get the input range and output range of the (l,i,j) activation
1005
+
1006
+ Args:
1007
+ -----
1008
+ l : int
1009
+ layer index
1010
+ i : int
1011
+ input neuron index
1012
+ j : int
1013
+ output neuron index
1014
+
1015
+ Returns:
1016
+ --------
1017
+ x_min : float
1018
+ minimum of input
1019
+ x_max : float
1020
+ maximum of input
1021
+ y_min : float
1022
+ minimum of output
1023
+ y_max : float
1024
+ maximum of output
1025
+
1026
+ Example
1027
+ -------
1028
+ >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
1029
+ >>> x = torch.normal(0,1,size=(100,2))
1030
+ >>> model(x) # do a forward pass to obtain model.acts
1031
+ >>> model.get_range(0,0,0)
1032
+ '''
1033
+ x = self.spline_preacts[l][:, j, i]
1034
+ y = self.spline_postacts[l][:, j, i]
1035
+ x_min = torch.min(x).cpu().detach().numpy()
1036
+ x_max = torch.max(x).cpu().detach().numpy()
1037
+ y_min = torch.min(y).cpu().detach().numpy()
1038
+ y_max = torch.max(y).cpu().detach().numpy()
1039
+ if verbose:
1040
+ print('x range: [' + '%.2f' % x_min, ',', '%.2f' % x_max, ']')
1041
+ print('y range: [' + '%.2f' % y_min, ',', '%.2f' % y_max, ']')
1042
+ return x_min, x_max, y_min, y_max
1043
+
1044
+ def _draw_operator_symbol(self, ax, fig, x, y, symbol_type, scale, size):
1045
+ """辅助函数绘制运算符符号"""
1046
+ path = os.path.join(os.path.dirname(__file__), 'assets', 'img', f'{symbol_type}_symbol.png')
1047
+ try:
1048
+ with Image.open(path) as img:
1049
+ trans = ax.transData.transform
1050
+ inv_trans = fig.transFigure.inverted().transform
1051
+ bbox = [x - size / 2, y - size / 2, size, size]
1052
+ bbox = inv_trans.transform_bbox(trans(bbox))
1053
+ ax.imshow(img, extent=bbox)
1054
+ except Exception as e:
1055
+ print(f"无法加载符号 {symbol_type}: {str(e)}")
1056
+
1057
+ def _draw_variable_labels(self, ax, variables, layer, scale, varscale, y_offset, y0, z0):
1058
+ """辅助函数绘制变量标签"""
1059
+ n = self.width_in[layer]
1060
+ for i, var in enumerate(variables):
1061
+ x = 1 / (2 * n) + i / n
1062
+ y = layer * (y0 + z0) + y_offset
1063
+ label = f'${latex(var)}$' if isinstance(var, sympy.Expr) else var
1064
+ ax.text(x, y, label, fontsize=15 * scale * varscale,
1065
+ ha='center', va='center', transform=ax.transData)
1066
+
1067
+ def plot(self, folder="./figures", beta=3, metric='backward', scale=0.5, tick=False, sample=False, in_vars=None,
1068
+ out_vars=None, title=None, varscale=1.0):
1069
+ '''
1070
+ plot KAN
1071
+
1072
+ Args:
1073
+ -----
1074
+ folder : str
1075
+ the folder to store pngs
1076
+ beta : float
1077
+ positive number. control the transparency of each activation. transparency = tanh(beta*l1).
1078
+ metric : str
+ 'backward' (edge attribution scores), 'forward_n' (normalized activation scale), or 'forward_u' (unnormalized activation scale)
+ tick : bool
+ If True, show input/output ticks on each activation plot
+ sample : bool
+ If True, overlay the data samples on each activation curve
1082
+ scale : float
1083
+ control the size of the diagram
1084
+ in_vars: None or list of str
1085
+ the name(s) of input variables
1086
+ out_vars: None or list of str
1087
+ the name(s) of output variables
1088
+ title: None or str
1089
+ title
1090
+ varscale : float
1091
+ the size of input variables
1092
+
1093
+ Returns:
1094
+ --------
1095
+ Figure
1096
+
1097
+ Example
1098
+ -------
1099
+ >>> # see more interactive examples in demos
1100
+ >>> model = KAN(width=[2,3,1], grid=3, k=3, noise_scale=1.0)
1101
+ >>> x = torch.normal(0,1,size=(100,2))
1102
+ >>> model(x) # do a forward pass to obtain model.acts
1103
+ >>> model.plot()
1104
+ '''
1105
+ global Symbol
1106
+
1107
+ if not self.save_act:
1108
+ print('cannot plot since data are not saved. Set save_act=True first.')
1109
+
1110
+ if self.acts is None:
1111
+ if self.cache_data is None:
1112
+ raise Exception('model hasn\'t seen any data yet.')
1113
+ self.forward(self.cache_data)
1114
+
1115
+ if metric == 'backward':
1116
+ self.attribute()
1117
+
1118
+ if not os.path.exists(folder):
1119
+ os.makedirs(folder)
1120
+
1121
+ depth = len(self.width) - 1
1122
+ for l in range(depth):
1123
+ w_large = 2.0
1124
+ for i in range(self.width_in[l]):
1125
+ for j in range(self.width_out[l + 1]):
1126
+ rank = torch.argsort(self.acts[l][:, i])
1127
+ fig, ax = plt.subplots(figsize=(w_large, w_large))
1128
+
1129
+ num = rank.shape[0]
1130
+ symbolic_mask = self.symbolic_fun[l].mask[j][i]
1131
+ numeric_mask = self.act_fun[l].mask[i][j]
1132
+ if symbolic_mask > 0. and numeric_mask > 0.:
1133
+ color = 'purple'
1134
+ alpha_mask = 1
1135
+ if symbolic_mask > 0. and numeric_mask == 0.:
1136
+ color = "red"
1137
+ alpha_mask = 1
1138
+ if symbolic_mask == 0. and numeric_mask > 0.:
1139
+ color = "black"
1140
+ alpha_mask = 1
1141
+ if symbolic_mask == 0. and numeric_mask == 0.:
1142
+ color = "white"
1143
+ alpha_mask = 0
1144
+
1145
+ if tick:
1146
+ ax.tick_params(axis="y", direction="in", pad=-22, labelsize=50)
1147
+ ax.tick_params(axis="x", direction="in", pad=-15, labelsize=50)
1148
+ x_min, x_max, y_min, y_max = self.get_range(l, i, j, verbose=False)
1149
+ plt.xticks([x_min, x_max], ['%.2f' % x_min, '%.2f' % x_max])
1150
+ plt.yticks([y_min, y_max], ['%.2f' % y_min, '%.2f' % y_max])
1151
+ else:
1152
+ plt.xticks([])
1153
+ plt.yticks([])
1154
+ if alpha_mask == 1:
1155
+ plt.gca().patch.set_edgecolor('black')
1156
+ else:
1157
+ plt.gca().patch.set_edgecolor('white')
1158
+ plt.gca().patch.set_linewidth(1.5)
1159
+ # plt.axis('off')
1160
+
1161
+ plt.plot(self.acts[l][:, i][rank].cpu().detach().numpy(),
1162
+ self.spline_postacts[l][:, j, i][rank].cpu().detach().numpy(), color=color, lw=5)
1163
+ if sample:
1164
+ plt.scatter(self.acts[l][:, i][rank].cpu().detach().numpy(),
1165
+ self.spline_postacts[l][:, j, i][rank].cpu().detach().numpy(), color=color,
1166
+ s=400 * scale ** 2)
1167
+ plt.gca().spines[:].set_color(color)
1168
+
1169
+ plt.savefig(f'{folder}/sp_{l}_{i}_{j}.png', bbox_inches="tight", dpi=400)
1170
+ plt.close()
1171
+
1172
+ def score2alpha(score):
1173
+ return np.tanh(beta * score)
1174
+
1175
+ if metric == 'forward_n':
1176
+ scores = self.acts_scale
1177
+ elif metric == 'forward_u':
1178
+ scores = self.edge_actscale
1179
+ elif metric == 'backward':
1180
+ scores = self.edge_scores
1181
+ else:
1182
+ raise Exception(f'metric = \'{metric}\' not recognized')
1183
+
1184
+ alpha = [score2alpha(score.cpu().detach().numpy()) for score in scores]
1185
+
1186
+ # draw skeleton
1187
+ width = np.array(self.width)
1188
+ width_in = np.array(self.width_in)
1189
+ width_out = np.array(self.width_out)
1190
+ A = 1
1191
+ y0 = 0.3 # height: from input to pre-mult
1192
+ z0 = 0.1 # height: from pre-mult to post-mult (input of next layer)
1193
+
1194
+ neuron_depth = len(width)
1195
+ min_spacing = A / np.maximum(np.max(width_out), 5)
1196
+
1197
+ max_neuron = np.max(width_out)
1198
+ max_num_weights = np.max(width_in[:-1] * width_out[1:])
1199
+ y1 = 0.4 / np.maximum(max_num_weights, 5) # size (height/width) of 1D function diagrams
1200
+ y2 = 0.15 / np.maximum(max_neuron, 5) # size (height/width) of operations (sum and mult)
1201
+
1202
+ fig, ax = plt.subplots(figsize=(10 * scale, 10 * scale * (neuron_depth - 1) * (y0 + z0)))
1203
+ # fig, ax = plt.subplots(figsize=(5,5*(neuron_depth-1)*y0))
1204
+
1205
+ # -- Transformation functions
1206
+ DC_to_FC = ax.transData.transform
1207
+ FC_to_NFC = fig.transFigure.inverted().transform
1208
+ # -- Take data coordinates and transform them to normalized figure coordinates
1209
+ DC_to_NFC = lambda x: FC_to_NFC(DC_to_FC(x))
1210
+
1211
+ # plot scatters and lines
1212
+ for l in range(neuron_depth):
1213
+
1214
+ n = width_in[l]
1215
+
1216
+ # scatters
1217
+ for i in range(n):
1218
+ plt.scatter(1 / (2 * n) + i / n, l * (y0 + z0), s=min_spacing ** 2 * 10000 * scale ** 2, color='black')
1219
+
1220
+ # plot connections (input to pre-mult)
1221
+ for i in range(n):
1222
+ if l < neuron_depth - 1:
1223
+ n_next = width_out[l + 1]
1224
+ N = n * n_next
1225
+ for j in range(n_next):
1226
+ id_ = i * n_next + j
1227
+
1228
+ symbol_mask = self.symbolic_fun[l].mask[j][i]
1229
+ numerical_mask = self.act_fun[l].mask[i][j]
1230
+ if symbol_mask == 1. and numerical_mask > 0.:
1231
+ color = 'purple'
1232
+ alpha_mask = 1.
1233
+ if symbol_mask == 1. and numerical_mask == 0.:
1234
+ color = "red"
1235
+ alpha_mask = 1.
1236
+ if symbol_mask == 0. and numerical_mask == 1.:
1237
+ color = "black"
1238
+ alpha_mask = 1.
1239
+ if symbol_mask == 0. and numerical_mask == 0.:
1240
+ color = "white"
1241
+ alpha_mask = 0.
1242
+
1243
+ plt.plot([1 / (2 * n) + i / n, 1 / (2 * N) + id_ / N],
1244
+ [l * (y0 + z0), l * (y0 + z0) + y0 / 2 - y1], color=color, lw=2 * scale,
1245
+ alpha=alpha[l][j][i] * alpha_mask)
1246
+ plt.plot([1 / (2 * N) + id_ / N, 1 / (2 * n_next) + j / n_next],
1247
+ [l * (y0 + z0) + y0 / 2 + y1, l * (y0 + z0) + y0], color=color, lw=2 * scale,
1248
+ alpha=alpha[l][j][i] * alpha_mask)
1249
+
1250
+ # plot connections (pre-mult to post-mult, post-mult = next-layer input)
1251
+ if l < neuron_depth - 1:
1252
+ n_in = width_out[l + 1]
1253
+ n_out = width_in[l + 1]
1254
+ mult_id = 0
1255
+ for i in range(n_in):
1256
+ if i < width[l + 1][0]:
1257
+ j = i
1258
+ else:
1259
+ if i == width[l + 1][0]:
1260
+ if isinstance(self.mult_arity, int):
1261
+ ma = self.mult_arity
1262
+ else:
1263
+ ma = self.mult_arity[l + 1][mult_id]
1264
+ current_mult_arity = ma
1265
+ if current_mult_arity == 0:
1266
+ mult_id += 1
1267
+ if isinstance(self.mult_arity, int):
1268
+ ma = self.mult_arity
1269
+ else:
1270
+ ma = self.mult_arity[l + 1][mult_id]
1271
+ current_mult_arity = ma
1272
+ j = width[l + 1][0] + mult_id
1273
+ current_mult_arity -= 1
1274
+ #j = (i-width[l+1][0])//self.mult_arity + width[l+1][0]
1275
+ plt.plot([1 / (2 * n_in) + i / n_in, 1 / (2 * n_out) + j / n_out],
1276
+ [l * (y0 + z0) + y0, (l + 1) * (y0 + z0)], color='black', lw=2 * scale)
1277
+
1278
+ plt.xlim(0, 1)
1279
+ plt.ylim(-0.1 * (y0 + z0), (neuron_depth - 1 + 0.1) * (y0 + z0))
1280
+
1281
+ plt.axis('off')
1282
+
1283
+ for l in range(neuron_depth - 1):
1284
+ # plot splines
1285
+ n = width_in[l]
1286
+ for i in range(n):
1287
+ n_next = width_out[l + 1]
1288
+ N = n * n_next
1289
+ for j in range(n_next):
1290
+ id_ = i * n_next + j
1291
+ im = plt.imread(f'{folder}/sp_{l}_{i}_{j}.png')
1292
+ left = DC_to_NFC([1 / (2 * N) + id_ / N - y1, 0])[0]
1293
+ right = DC_to_NFC([1 / (2 * N) + id_ / N + y1, 0])[0]
1294
+ bottom = DC_to_NFC([0, l * (y0 + z0) + y0 / 2 - y1])[1]
1295
+ up = DC_to_NFC([0, l * (y0 + z0) + y0 / 2 + y1])[1]
1296
+ newax = fig.add_axes([left, bottom, right - left, up - bottom])
1297
+ # newax = fig.add_axes([1/(2*N)+id_/N-y1, (l+1/2)*y0-y1, y1, y1], anchor='NE')
1298
+ newax.imshow(im, alpha=alpha[l][j][i])
1299
+ newax.axis('off')
1300
+
1301
+ # plot sum symbols
1302
+ N = n = width_out[l + 1]
1303
+ for j in range(n):
1304
+ id_ = j
1305
+ path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/sum_symbol.png"
1306
+ im = plt.imread(path)
1307
+ left = DC_to_NFC([1 / (2 * N) + id_ / N - y2, 0])[0]
1308
+ right = DC_to_NFC([1 / (2 * N) + id_ / N + y2, 0])[0]
1309
+ bottom = DC_to_NFC([0, l * (y0 + z0) + y0 - y2])[1]
1310
+ up = DC_to_NFC([0, l * (y0 + z0) + y0 + y2])[1]
1311
+ newax = fig.add_axes([left, bottom, right - left, up - bottom])
1312
+ newax.imshow(im)
1313
+ newax.axis('off')
1314
+
1315
+ # plot mult symbols
1316
+ N = n = width_in[l + 1]
1317
+ n_sum = width[l + 1][0]
1318
+ n_mult = width[l + 1][1]
1319
+ for j in range(n_mult):
1320
+ id_ = j + n_sum
1321
+ path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/mult_symbol.png"
1322
+ im = plt.imread(path)
1323
+ left = DC_to_NFC([1 / (2 * N) + id_ / N - y2, 0])[0]
1324
+ right = DC_to_NFC([1 / (2 * N) + id_ / N + y2, 0])[0]
1325
+ bottom = DC_to_NFC([0, (l + 1) * (y0 + z0) - y2])[1]
1326
+ up = DC_to_NFC([0, (l + 1) * (y0 + z0) + y2])[1]
1327
+ newax = fig.add_axes([left, bottom, right - left, up - bottom])
1328
+ newax.imshow(im)
1329
+ newax.axis('off')
1330
+
1331
+ if in_vars is not None:
1332
+ n = self.width_in[0]
1333
+ for i in range(n):
1334
+ if isinstance(in_vars[i], sympy.Expr):
1335
+ plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), -0.1, f'${latex(in_vars[i])}$',
1336
+ fontsize=40 * scale * varscale, horizontalalignment='center',
1337
+ verticalalignment='center')
1338
+ else:
1339
+ plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), -0.1, in_vars[i],
1340
+ fontsize=40 * scale * varscale, horizontalalignment='center',
1341
+ verticalalignment='center')
1342
+
1343
+ if out_vars is not None:
1344
+ n = self.width_in[-1]
1345
+ for i in range(n):
1346
+ if isinstance(out_vars[i], sympy.Expr):
1347
+ plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), (y0 + z0) * (len(self.width) - 1) + 0.15,
1348
+ f'${latex(out_vars[i])}$', fontsize=40 * scale * varscale,
1349
+ horizontalalignment='center', verticalalignment='center')
1350
+ else:
1351
+ plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), (y0 + z0) * (len(self.width) - 1) + 0.15,
1352
+ out_vars[i], fontsize=40 * scale * varscale,
1353
+ horizontalalignment='center', verticalalignment='center')
1354
+
1355
+ if title is not None:
1356
+ plt.gcf().get_axes()[0].text(0.5, (y0 + z0) * (len(self.width) - 1) + 0.3, title, fontsize=40 * scale,
1357
+ horizontalalignment='center', verticalalignment='center')
1358
+
1359
+ def reg(self, reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff):
1360
+ """
1361
+ Get regularization
1362
+
1363
+ Args:
1364
+ -----
1365
+ reg_metric : str
1366
+ 'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'
1367
+ lamb_l1 : float
1368
+ l1 penalty strength
1369
+ lamb_entropy : float
1370
+ entropy penalty strength
1371
+ lamb_coef : float
1372
+ coefficient penalty strength
1373
+ lamb_coefdiff : float
1374
+ coefficient smoothness strength
1375
+
1376
+ Returns:
1377
+ --------
1378
+ reg_ : torch.float
1379
+
1380
+ Example
1381
+ -------
1382
+ >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
1383
+ >>> x = torch.rand(100,2)
1384
+ >>> model.get_act(x)
1385
+ >>> model.reg('edge_forward_spline_n', 1.0, 2.0, 1.0, 1.0)
1386
+ """
1387
+ if reg_metric == 'edge_forward_spline_n':
1388
+ acts_scale = self.acts_scale_spline
1389
+
1390
+ elif reg_metric == 'edge_forward_sum':
1391
+ acts_scale = self.acts_scale
1392
+
1393
+ elif reg_metric == 'edge_forward_spline_u':
1394
+ acts_scale = self.edge_actscale
1395
+
1396
+ elif reg_metric == 'edge_backward':
1397
+ acts_scale = self.edge_scores
1398
+
1399
+ elif reg_metric == 'node_backward':
1400
+ acts_scale = self.node_attribute_scores
1401
+
1402
+ else:
1403
+ raise Exception(f'reg_metric = {reg_metric} not recognized!')
1404
+
1405
+ reg_ = 0.
1406
+ for i in range(len(acts_scale)):
1407
+ vec = acts_scale[i]
1408
+
1409
+ l1 = torch.sum(vec)
1410
+ p_row = vec / (torch.sum(vec, dim=1, keepdim=True) + 1)
1411
+ p_col = vec / (torch.sum(vec, dim=0, keepdim=True) + 1)
1412
+ entropy_row = - torch.mean(torch.sum(p_row * torch.log2(p_row + 1e-4), dim=1))
1413
+ entropy_col = - torch.mean(torch.sum(p_col * torch.log2(p_col + 1e-4), dim=0))
1414
+ reg_ += lamb_l1 * l1 + lamb_entropy * (entropy_row + entropy_col) # both l1 and entropy
1415
+
1416
+ # regularize coefficient to encourage spline to be zero
1417
+ for i in range(len(self.act_fun)):
1418
+ coeff_l1 = torch.sum(torch.mean(torch.abs(self.act_fun[i].coef), dim=1))
1419
+ coeff_diff_l1 = torch.sum(torch.mean(torch.abs(torch.diff(self.act_fun[i].coef)), dim=1))
1420
+ reg_ += lamb_coef * coeff_l1 + lamb_coefdiff * coeff_diff_l1
1421
+
1422
+ return reg_
1423
+
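# A standalone sketch (illustrative, not from the released source) of what reg()
# computes for one layer: given an (in_dim, out_dim) matrix of edge scales, the
# penalty is an L1 term plus row/column entropies that push each node to rely on
# few edges. lamb_l1=1.0 and lamb_entropy=2.0 mirror fit()'s defaults.
import torch
vec = torch.tensor([[0.9, 0.0], [0.1, 0.8]])
l1 = torch.sum(vec)
p_row = vec / (torch.sum(vec, dim=1, keepdim=True) + 1)
p_col = vec / (torch.sum(vec, dim=0, keepdim=True) + 1)
entropy_row = -torch.mean(torch.sum(p_row * torch.log2(p_row + 1e-4), dim=1))
entropy_col = -torch.mean(torch.sum(p_col * torch.log2(p_col + 1e-4), dim=0))
penalty = 1.0 * l1 + 2.0 * (entropy_row + entropy_col)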
1424
+ def get_reg(self, reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff):
1425
+ """
1426
+ Get regularization. Thin wrapper kept so that subclasses can override get_reg without touching reg.
1427
+ """
1428
+ return self.reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
1429
+
1430
+ def disable_symbolic_in_fit(self, lamb):
1431
+ """
1432
+ during fit(), disable what can be skipped: if lamb == 0, activations need not be saved; if no symbolic function is active, the symbolic branch is skipped. Returns the old flags so fit() can restore them.
1433
+ """
1434
+ old_save_act = self.save_act
1435
+ if lamb == 0.:
1436
+ self.save_act = False
1437
+
1438
+ # skip symbolic if no symbolic is turned on
1439
+ depth = len(self.symbolic_fun)
1440
+ no_symbolic = True
1441
+ for l in range(depth):
1442
+ no_symbolic *= torch.sum(torch.abs(self.symbolic_fun[l].mask)) == 0
1443
+
1444
+ old_symbolic_enabled = self.symbolic_enabled
1445
+
1446
+ if no_symbolic:
1447
+ self.symbolic_enabled = False
1448
+
1449
+ return old_save_act, old_symbolic_enabled
1450
+
1451
+ def get_params(self):
1452
+ """
1453
+ Get parameters
1454
+ """
1455
+ return self.parameters()
1456
+
1457
+ def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0.,
1458
+ lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
1459
+ stop_grid_update_step=50, batch=-1,
1460
+ metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
1461
+ singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n', display_metrics=None):
1462
+ '''
1463
+ training
1464
+
1465
+ Args:
1466
+ -----
1467
+ dataset : dict
1468
+ contains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']
1469
+ opt : str
1470
+ "LBFGS" or "Adam"
1471
+ steps : int
1472
+ training steps
1473
+ log : int
1474
+ logging frequency
1475
+ lamb : float
1476
+ overall penalty strength
1477
+ lamb_l1 : float
1478
+ l1 penalty strength
1479
+ lamb_entropy : float
1480
+ entropy penalty strength
1481
+ lamb_coef : float
1482
+ coefficient magnitude penalty strength
1483
+ lamb_coefdiff : float
1484
+ penalty strength on differences of nearby coefficients (smoothness)
1485
+ update_grid : bool
1486
+ If True, update grid regularly before stop_grid_update_step
1487
+ grid_update_num : int
1488
+ the number of grid updates before stop_grid_update_step
1489
+ start_grid_update_step : int
1490
+ no grid updates before this training step
1491
+ stop_grid_update_step : int
1492
+ no grid updates after this training step
1493
+ loss_fn : function
1494
+ loss function
1495
+ lr : float
1496
+ learning rate
1497
+ batch : int
1498
+ batch size, if -1 then full.
1499
+ save_fig_freq : int
1500
+ save figure every (save_fig_freq) steps
1501
+ singularity_avoiding : bool
1502
+ indicate whether to avoid singularity for the symbolic part
1503
+ y_th : float
1504
+ singularity threshold (anything above the threshold is considered singular and is softened in some ways)
1505
+ reg_metric : str
1506
+ regularization metric. Choose from {'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'}
1507
+ metrics : a list of metrics (as functions)
1508
+ the metrics to be computed in training
1509
+ display_metrics : a list of str
1510
+ the names of the metrics to be displayed in the tqdm progress bar
1511
+
1512
+ Returns:
1513
+ --------
1514
+ results : dict
1515
+ results['train_loss'], 1D array of training losses (RMSE)
1516
+ results['test_loss'], 1D array of test losses (RMSE)
1517
+ results['reg'], 1D array of regularization
1518
+ other metrics specified in metrics
1519
+
1520
+ Example
1521
+ -------
1522
+ # >>> from kan import *
1523
+ # >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
1524
+ # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
1525
+ # >>> dataset = create_dataset(f, n_var=2)
1526
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
1527
+ # >>> model.plot()
1528
+ # Most examples in the tutorials involve the fit() method. Please check them for usage.
1529
+ '''
1530
+
1531
+ if lamb > 0. and not self.save_act:
1532
+ print('regularization is skipped because save_act=False; set save_act=True to use lamb > 0')
1533
+
1534
+ old_save_act, old_symbolic_enabled = self.disable_symbolic_in_fit(lamb)
1535
+
1536
+ pbar = tqdm(range(steps), desc='description', ncols=100)
1537
+
1538
+ if loss_fn is None:
1539
+ loss_fn = loss_fn_eval = lambda x, y: torch.mean((x - y) ** 2)
1540
+ else:
1541
+ loss_fn = loss_fn_eval = loss_fn
1542
+
1543
+ grid_update_freq = int(stop_grid_update_step / grid_update_num)
1544
+
1545
+ if opt == "Adam":
1546
+ optimizer = torch.optim.Adam(self.get_params(), lr=lr)
1547
+ elif opt == "LBFGS":
1548
+ optimizer = LBFGS(self.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
1549
+ tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
1550
+
1551
+ results = {'train_loss': [], 'test_loss': [], 'reg': []}
1552
+ if metrics is not None:
1553
+ for i in range(len(metrics)):
1554
+ results[metrics[i].__name__] = []
1555
+
1556
+ if batch == -1 or batch > dataset['train_input'].shape[0]:
1557
+ batch_size = dataset['train_input'].shape[0]
1558
+ batch_size_test = dataset['test_input'].shape[0]
1559
+ else:
1560
+ batch_size = batch
1561
+ batch_size_test = batch
1562
+
1563
+ global train_loss, reg_
1564
+
1565
+ def closure():
1566
+ global train_loss, reg_
1567
+ optimizer.zero_grad()
1568
+ pred = self.forward(dataset['train_input'][train_id], singularity_avoiding=singularity_avoiding, y_th=y_th)
1569
+ train_loss = loss_fn(pred, dataset['train_label'][train_id])
1570
+ if self.save_act:
1571
+ if reg_metric == 'edge_backward':
1572
+ self.attribute()
1573
+ if reg_metric == 'node_backward':
1574
+ self.node_attribute()
1575
+ reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
1576
+ else:
1577
+ reg_ = torch.tensor(0.)
1578
+ objective = train_loss + lamb * reg_
1579
+ objective.backward()
1580
+ return objective
1581
+
1582
+ if save_fig:
1583
+ if not os.path.exists(img_folder):
1584
+ os.makedirs(img_folder)
1585
+
1586
+ for _ in pbar:
1587
+
1588
+ if _ == steps - 1 and old_save_act:
1589
+ self.save_act = True
1590
+
1591
+ if save_fig and _ % save_fig_freq == 0:
1592
+ save_act = self.save_act
1593
+ self.save_act = True
1594
+
1595
+ train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
1596
+ test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
1597
+
1598
+ if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid and _ >= start_grid_update_step:
1599
+ self.update_grid(dataset['train_input'][train_id])
1600
+
1601
+ if opt == "LBFGS":
1602
+ optimizer.step(closure)
1603
+
1604
+ if opt == "Adam":
1605
+ pred = self.forward(dataset['train_input'][train_id], singularity_avoiding=singularity_avoiding,
1606
+ y_th=y_th)
1607
+ train_loss = loss_fn(pred, dataset['train_label'][train_id])
1608
+ if self.save_act:
1609
+ if reg_metric == 'edge_backward':
1610
+ self.attribute()
1611
+ if reg_metric == 'node_backward':
1612
+ self.node_attribute()
1613
+ reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
1614
+ else:
1615
+ reg_ = torch.tensor(0.)
1616
+ loss = train_loss + lamb * reg_
1617
+ optimizer.zero_grad()
1618
+ loss.backward()
1619
+ optimizer.step()
1620
+
1621
+ test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id]), dataset['test_label'][test_id])
1622
+
1623
+ if metrics is not None:
1624
+ for i in range(len(metrics)):
1625
+ results[metrics[i].__name__].append(metrics[i]().item())
1626
+
1627
+ results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy())
1628
+ results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy())
1629
+ results['reg'].append(reg_.cpu().detach().numpy())
1630
+
1631
+ if _ % log == 0:
1632
+ if display_metrics is None:
1633
+ pbar.set_description("| train_loss: %.6f | test_loss: %.6f | reg: %.6f | " % (
1634
+ torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(),
1635
+ reg_.cpu().detach().numpy()))
1636
+ else:
1637
+ string = ''
1638
+ data = ()
1639
+ for metric in display_metrics:
1640
+ string += f' {metric}: %.6f |'
1641
+ try:
1642
+ results[metric]
1643
+ except KeyError:
1644
+ raise Exception(f'{metric} not recognized')
1645
+ data += (results[metric][-1],)
1646
+ pbar.set_description(string % data)
1647
+
1648
+ if save_fig and _ % save_fig_freq == 0:
1649
+ self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(_), beta=beta)
1650
+ plt.savefig(img_folder + '/' + str(_) + '.jpg', bbox_inches='tight', dpi=100)
1651
+ plt.close()
1652
+ self.save_act = save_act
1653
+
1654
+ self.log_history('fit')
1655
+ # revert to the original state
1656
+ self.symbolic_enabled = old_symbolic_enabled
1657
+ return results
1658
+
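A minimal usage sketch of the training loop above. The import path and hyperparameter values are illustrative assumptions (the KAN alias is defined at the bottom of this file); the dataset keys match those the code reads.

import torch
from yms_kan import KAN  # import path assumed from this package's name

torch.manual_seed(0)
f = lambda x: torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2)
x_train, x_test = torch.rand(1000, 2) * 2 - 1, torch.rand(200, 2) * 2 - 1
dataset = {'train_input': x_train, 'train_label': f(x_train),
           'test_input': x_test, 'test_label': f(x_test)}
model = KAN(width=[2, 5, 1], grid=5, k=3)
results = model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001)
print(results['train_loss'][-1])  # RMSE: note the torch.sqrt applied above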
1659
+ def fix(self, dataset: dict, batch_size, batch_size_test, labels, opt="LBFGS", epochs=100, log=1, lamb=0.,
1660
+ lamb_l1=1.,
1661
+ lamb_entropy=2., lamb_coef=0.,
1662
+ lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
1663
+ stop_grid_update_step=100,
1664
+ metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
1665
+ singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n', display_metrics=None):
1666
+ """
1667
+ Train with mini-batched optimization and track classification metrics (accuracy, precision, recall, F1) by snapping test outputs to the nearest admissible label.
1668
+ Args:
1669
+ labels: 1D tensor of all admissible label values; test outputs are snapped to the nearest one.
1670
+ batch_size: training batch size.
1671
+ batch_size_test: test batch size.
1672
+ lamb_coefdiff: coefficient-smoothness regularization strength.
1673
+ metrics: optional list of callables evaluated and recorded every epoch.
1674
+ stop_grid_update_step: no grid updates after this training step.
1675
+ save_fig: whether to save network plots during training.
1676
+ singularity_avoiding: whether to avoid singularities in the symbolic branch.
1677
+ y_th: singularity threshold (used when singularity_avoiding=True).
1678
+ in_vars: input variable names for plotting.
1680
+ update_grid: whether to update the spline grids from samples.
1681
+ grid_update_num: number of grid updates before stop_grid_update_step.
1682
+ loss_fn: loss function; defaults to MSE.
1683
+ lr: learning rate.
1684
+ beta: transparency parameter for plotting.
1685
+ out_vars: output variable names for plotting.
1686
+ save_fig_freq: save a figure every save_fig_freq epochs.
1687
+ reg_metric: regularization metric, e.g. 'edge_forward_spline_n'.
1688
+ display_metrics: metric names to show in the progress bar.
1689
+ img_folder: folder for saved figures.
1690
+ start_grid_update_step: no grid updates before this training step.
1691
+ lamb_coef: coefficient-magnitude regularization strength.
1692
+ lamb_entropy: entropy regularization strength.
1693
+ lamb_l1: l1 regularization strength.
1694
+ lamb: overall regularization strength (0. disables regularization).
1695
+ log: logging frequency (in epochs).
1696
+ epochs: number of training epochs.
1697
+ opt: optimizer, "LBFGS" or "Adam".
1698
+ dataset (dict): dict with 'train_input', 'train_label', 'test_input', 'test_label' tensors.
1699
+ """
1700
+
1701
+ def calculate_results(all_labels, all_predictions, classes=None, average='macro'):
1702
+ result = {
1703
+ 'accuracy': accuracy_score(y_true=all_labels, y_pred=all_predictions),
1704
+ 'precision': precision_score(y_true=all_labels, y_pred=all_predictions, average=average),
1705
+ 'recall': recall_score(y_true=all_labels, y_pred=all_predictions, average=average),
1706
+ 'f1_score': f1_score(y_true=all_labels, y_pred=all_predictions, average=average),
1707
+ # 'cm': confusion_matrix(y_true=all_labels, y_pred=all_predictions, labels=np.arange(len(classes)))
1708
+ }
1709
+ return result
1710
+ if lamb > 0. and not self.save_act:
1711
+ print('lamb > 0 has no effect while self.save_act=False; set self.save_act=True to enable regularization')
1712
+
1713
+ all_predictions = []
1714
+ all_labels = []
1715
+
1716
+ old_save_act, old_symbolic_enabled = self.disable_symbolic_in_fit(lamb)
1717
+
1718
+ pbar = tqdm(range(epochs), desc='description', ncols=100)
1719
+
1720
+ if loss_fn is None:
1721
+ loss_fn = loss_fn_eval = lambda x, y: torch.mean((x - y) ** 2)
1722
+ else:
1723
+ loss_fn = loss_fn_eval = loss_fn
1724
+
1725
+ grid_update_freq = int(stop_grid_update_step / grid_update_num)
1726
+
1727
+ if opt == "Adam":
1728
+ optimizer = torch.optim.Adam(self.get_params(), lr=lr)
1729
+ elif opt == "LBFGS":
1730
+ optimizer = LBFGS(self.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
1731
+ tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
1732
+
1733
+ results = {'train_loss': [], 'test_loss': [], 'reg': [], 'accuracy': [],
1734
+ 'precision': [], 'recall': [], 'f1_score': []}
1735
+ if metrics is not None:
1736
+ for i in range(len(metrics)):
1737
+ results[metrics[i].__name__] = []
1738
+
1739
+ steps = math.ceil(dataset['train_input'].shape[0] / batch_size)
1740
+ total_steps = steps * epochs
1741
+ # global train_loss, reg_
1742
+ train_loss = torch.zeros(1).to(self.device)
1743
+ reg_ = torch.zeros(1).to(self.device)
1744
+ labels = labels.to(self.device)
1745
+
1746
+ def closure():
1747
+ nonlocal train_loss, reg_
1748
+ optimizer.zero_grad()
1749
+ pred = self.forward(batch_train_input, singularity_avoiding=singularity_avoiding, y_th=y_th)
1750
+ loss = loss_fn(pred, batch_train_label)
1751
+ if self.save_act:
1752
+ if reg_metric == 'edge_backward':
1753
+ self.attribute()
1754
+ if reg_metric == 'node_backward':
1755
+ self.node_attribute()
1756
+ reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
1757
+ else:
1758
+ reg_ = torch.tensor(0.)
1759
+ objective = loss + lamb * reg_
1760
+ train_loss = (train_loss * batch_num + objective.detach()) / (batch_num + 1)
1761
+ objective.backward()
1762
+ return objective
1763
+
1764
+ if save_fig:
1765
+ if not os.path.exists(img_folder):
1766
+ os.makedirs(img_folder)
1767
+
1768
+ for epoch in pbar:
1769
+
1770
+ if epoch == epochs - 1 and old_save_act:
1771
+ self.save_act = True
1772
+
1773
+ if save_fig and epoch % save_fig_freq == 0:
1774
+ save_act = self.save_act
1775
+ self.save_act = True
1776
+
1777
+ train_indices = np.arange(dataset['train_input'].shape[0])
1778
+ np.random.shuffle(train_indices)
1779
+ for batch_num, i in enumerate(range(0, len(train_indices), batch_size), start=0):
1780
+ step = epoch * steps + batch_num + 1
1781
+ batch_train_id = train_indices[i:i + batch_size]
1782
+ batch_train_input = dataset['train_input'][batch_train_id].to(self.device)
1783
+ batch_train_label = dataset['train_label'][batch_train_id].to(self.device)
1784
+
1785
+ if step % grid_update_freq == 0 and step < stop_grid_update_step and update_grid and step >= start_grid_update_step:
1786
+ self.update_grid(batch_train_input)
1787
+
1788
+ if opt == "LBFGS":
1789
+ optimizer.step(closure)
1790
+
1791
+ if opt == "Adam":
1792
+ optimizer.zero_grad()
1793
+ pred = self.forward(batch_train_input, singularity_avoiding=singularity_avoiding,
1794
+ y_th=y_th)
1795
+ loss = loss_fn(pred, batch_train_label)
1796
+ if self.save_act:
1797
+ if reg_metric == 'edge_backward':
1798
+ self.attribute()
1799
+ if reg_metric == 'node_backward':
1800
+ self.node_attribute()
1801
+ reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
1802
+ else:
1803
+ reg_ = torch.tensor(0.)
1804
+ loss = loss + lamb * reg_
1805
+ train_loss = (train_loss * batch_num + loss.detach()) / (batch_num + 1)
1806
+ loss.backward()
1807
+ optimizer.step()
1808
+
1809
+ test_loss = torch.zeros(1).to(self.device)
+ all_predictions.clear()  # reset each epoch so the metrics below reflect this epoch only
+ all_labels.clear()
1810
+ with torch.no_grad():
1811
+ test_indices = np.arange(dataset['test_input'].shape[0])
1812
+ np.random.shuffle(test_indices)
1813
+ for batch_num, i in enumerate(range(0, len(test_indices), batch_size_test)):
1814
+ batch_test_id = test_indices[i:i + batch_size_test]
1815
+ batch_test_input = dataset['test_input'][batch_test_id].to(self.device)
1816
+ batch_test_label = dataset['test_label'][batch_test_id].to(self.device)
1817
+
1818
+ outputs = self.forward(batch_test_input)
1819
+
1820
+ loss = loss_fn(outputs, batch_test_label)
1821
+
1822
+ test_loss = (test_loss * batch_num + loss.detach()) / (batch_num + 1)
1823
+ diffs = torch.abs(outputs - labels)
1824
+ closest_indices = torch.argmin(diffs, dim=1)
1825
+ predict = labels[closest_indices]
1826
+ all_predictions.extend(predict.detach().cpu().numpy())
1827
+ all_labels.extend(batch_test_label.detach().cpu().numpy())
1828
+
1829
+ if metrics is not None:
1830
+ for i in range(len(metrics)):
1831
+ results[metrics[i].__name__].append(metrics[i]().item())
1832
+
1833
+ results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy())
1834
+ results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy())
1835
+ results['reg'].append(reg_.cpu().detach().numpy())
1836
+ res = calculate_results(all_labels, all_predictions)
1837
+ for key, value in res.items():
1838
+ results[key].append(value)
1839
+
1840
+ if epoch % log == 0:
1841
+ if display_metrics is None:
1842
+ pbar.set_description(f"| train_loss: %.6f | test_loss: %.6f | reg: %.6f |step:{step} " % (
1843
+ train_loss.item(), test_loss.item(), reg_.item()))
1844
+ else:
1845
+ string = ''
1846
+ data = ()
1847
+ for metric in display_metrics:
1848
+ string += f' {metric}: %.6f |'
1849
+ try:
1850
+ results[metric]
1851
+ except KeyError:
1852
+ raise Exception(f'{metric} not recognized')
1853
+ data += (results[metric][-1],)
1854
+ pbar.set_description(string % data)
1855
+
1856
+ if save_fig and epoch % save_fig_freq == 0:
1857
+ self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
1858
+ beta=beta)
1859
+ plt.savefig(img_folder + '/' + str(epoch) + '.jpg', bbox_inches='tight', dpi=100)
1860
+ plt.close()
1861
+ self.save_act = save_act
1862
+
1863
+ self.log_history('fit')
1864
+ self.symbolic_enabled = old_symbolic_enabled
1865
+ return results
1866
+
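The evaluation loop above turns regression outputs into class predictions by snapping each output to the nearest entry of labels. That decoding step in isolation:

import torch

labels = torch.tensor([0., 1., 2.])            # admissible label values, as passed to fix()
outputs = torch.tensor([[0.2], [1.7], [0.9]])  # model outputs, shape (batch, 1)
diffs = torch.abs(outputs - labels)            # broadcasts to (batch, n_labels)
predict = labels[torch.argmin(diffs, dim=1)]   # nearest label per sample
print(predict)                                 # tensor([0., 2., 1.])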
1867
+ def prune_node(self, threshold=1e-2, mode="auto", active_neurons_id=None, log_history=True):
1868
+ """
1869
+ pruning nodes
1870
+
1871
+ Args:
1872
+ -----
1873
+ threshold : float
1874
+ if the attribution score of a neuron is below the threshold, it is considered dead and will be removed
1875
+ mode : str
1876
+ 'auto' or 'manual'. with 'auto', nodes are automatically pruned using threshold. with 'manual', active_neurons_id should be passed in.
1877
+
1878
+ Returns:
1879
+ --------
1880
+ pruned network : MultKAN
1881
+
1882
+ Example
1883
+ -------
1884
+ # >>> from kan import *
1885
+ # >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
1886
+ # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
1887
+ # >>> dataset = create_dataset(f, n_var=2)
1888
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
1889
+ # >>> model = model.prune_node()
1890
+ # >>> model.plot()
1891
+ """
1892
+ if self.acts is None:
1893
+ self.get_act()
1894
+
1895
+ mask_up = [torch.ones(self.width_in[0], device=self.device)]
1896
+ mask_down = []
1897
+ active_neurons_up = [list(range(self.width_in[0]))]
1898
+ active_neurons_down = []
1899
+ num_sums = []
1900
+ num_mults = []
1901
+ mult_arities = [[]]
1902
+
1903
+ if active_neurons_id is not None:
1904
+ mode = "manual"
1905
+
1906
+ for i in range(len(self.acts_scale) - 1):
1907
+
1908
+ mult_arity = []
1909
+
1910
+ if mode == "auto":
1911
+ self.attribute()
1912
+ overall_important_up = self.node_scores[i + 1] > threshold
1913
+
1914
+ elif mode == "manual":
1915
+ overall_important_up = torch.zeros(self.width_in[i + 1], dtype=torch.bool, device=self.device)
1916
+ overall_important_up[active_neurons_id[i]] = True
1917
+
1918
+ num_sum = torch.sum(overall_important_up[:self.width[i + 1][0]])
1919
+ num_mult = torch.sum(overall_important_up[self.width[i + 1][0]:])
1920
+ if self.mult_homo:
1921
+ overall_important_down = torch.cat([overall_important_up[:self.width[i + 1][0]], (
1922
+ overall_important_up[self.width[i + 1][0]:][None, :].expand(self.mult_arity, -1)).T.reshape(-1, )],
1923
+ dim=0)
1924
+ else:
1925
+ overall_important_down = overall_important_up[:self.width[i + 1][0]]
1926
+ for j in range(overall_important_up[self.width[i + 1][0]:].shape[0]):
1927
+ active_bool = overall_important_up[self.width[i + 1][0] + j]
1928
+ arity = self.mult_arity[i + 1][j]
1929
+ overall_important_down = torch.cat(
1930
+ [overall_important_down, torch.tensor([active_bool] * arity).to(self.device)])
1931
+ if active_bool:
1932
+ mult_arity.append(arity)
1933
+
1934
+ num_sums.append(num_sum.item())
1935
+ num_mults.append(num_mult.item())
1936
+
1937
+ mask_up.append(overall_important_up.float())
1938
+ mask_down.append(overall_important_down.float())
1939
+
1940
+ active_neurons_up.append(torch.where(overall_important_up)[0])
1941
+ active_neurons_down.append(torch.where(overall_important_down)[0])
1942
+
1943
+ mult_arities.append(mult_arity)
1944
+
1945
+ active_neurons_down.append(list(range(self.width_out[-1])))
1946
+ mask_down.append(torch.ones(self.width_out[-1], device=self.device))
1947
+
1948
+ if not self.mult_homo:
1949
+ mult_arities.append(self.mult_arity[-1])
1950
+
1951
+ self.mask_up = mask_up
1952
+ self.mask_down = mask_down
1953
+
1954
+ # update act_fun[l].mask up
1955
+ for l in range(len(self.acts_scale) - 1):
1956
+ for i in range(self.width_in[l + 1]):
1957
+ if i not in active_neurons_up[l + 1]:
1958
+ self.remove_node(l + 1, i, mode='up', log_history=False)
1959
+
1960
+ for i in range(self.width_out[l + 1]):
1961
+ if i not in active_neurons_down[l]:
1962
+ self.remove_node(l + 1, i, mode='down', log_history=False)
1963
+
1964
+ model2 = MultKAN(copy.deepcopy(self.width), grid=self.grid, k=self.k, base_fun=self.base_fun_name,
1965
+ mult_arity=self.mult_arity, ckpt_path=self.ckpt_path, auto_save=True, first_init=False,
1966
+ state_id=self.state_id, round=self.round).to(self.device)
1967
+ model2.load_state_dict(self.state_dict())
1968
+
1969
+ width_new = [self.width[0]]
1970
+
1971
+ for i in range(len(self.acts_scale)):
1972
+
1973
+ if i < len(self.acts_scale) - 1:
1974
+ num_sum = num_sums[i]
1975
+ num_mult = num_mults[i]
1976
+ model2.node_bias[i].data = model2.node_bias[i].data[active_neurons_up[i + 1]]
1977
+ model2.node_scale[i].data = model2.node_scale[i].data[active_neurons_up[i + 1]]
1978
+ model2.subnode_bias[i].data = model2.subnode_bias[i].data[active_neurons_down[i]]
1979
+ model2.subnode_scale[i].data = model2.subnode_scale[i].data[active_neurons_down[i]]
1980
+ model2.width[i + 1] = [num_sum, num_mult]
1981
+
1982
+ model2.act_fun[i].out_dim_sum = num_sum
1983
+ model2.act_fun[i].out_dim_mult = num_mult
1984
+
1985
+ model2.symbolic_fun[i].out_dim_sum = num_sum
1986
+ model2.symbolic_fun[i].out_dim_mult = num_mult
1987
+
1988
+ width_new.append([num_sum, num_mult])
1989
+
1990
+ model2.act_fun[i] = model2.act_fun[i].get_subset(active_neurons_up[i], active_neurons_down[i])
1991
+ model2.symbolic_fun[i] = self.symbolic_fun[i].get_subset(active_neurons_up[i], active_neurons_down[i])
1992
+
1993
+ model2.cache_data = self.cache_data
1994
+ model2.acts = None
1995
+
1996
+ width_new.append(self.width[-1])
1997
+ model2.width = width_new
1998
+
1999
+ if not self.mult_homo:
2000
+ model2.mult_arity = mult_arities
2001
+
2002
+ if log_history:
2003
+ self.log_history('prune_node')
2004
+ model2.state_id += 1
2005
+
2006
+ return model2
2007
+
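A sketch of both pruning modes, continuing the fit() sketch above (thresholds and neuron indices are illustrative):

pruned = model.prune_node(threshold=1e-2)                             # automatic, by attribution score
pruned = model.prune_node(mode='manual', active_neurons_id=[[0, 2]])  # keep neurons 0 and 2 of layer 1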
2008
+ def prune_edge(self, threshold=3e-2, log_history=True):
2009
+ '''
2010
+ pruning edges
2011
+
2012
+ Args:
2013
+ -----
2014
+ threshold : float
2015
+ if the attribution score of an edge is below the threshold, it is considered dead and will be set to zero.
2016
+
2017
+ Returns:
2018
+ --------
2019
+ pruned network : MultKAN
2020
+
2021
+ Example
2022
+ -------
2023
+ # >>> from kan import *
2024
+ # >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
2025
+ # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
2026
+ # >>> dataset = create_dataset(f, n_var=2)
2027
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
2028
+ # >>> model = model.prune_edge()
2029
+ # >>> model.plot()
2030
+ '''
2031
+ if self.acts is None:
2032
+ self.get_act()
2033
+
2034
+ for i in range(len(self.width) - 1):
2035
+ #self.act_fun[i].mask.data = ((self.acts_scale[i] > threshold).permute(1,0)).float()
2036
+ old_mask = self.act_fun[i].mask.data
2037
+ self.act_fun[i].mask.data = ((self.edge_scores[i] > threshold).permute(1, 0) * old_mask).float()
2038
+
2039
+ if log_history:
2040
+ self.log_history('fix_symbolic')
2041
+
2042
+ def prune(self, node_th=1e-2, edge_th=3e-2):
2043
+ '''
2044
+ prune (both nodes and edges)
2045
+
2046
+ Args:
2047
+ -----
2048
+ node_th : float
2049
+ if the attribution score of a node is below node_th, it is considered dead and will be set to zero.
2050
+ edge_th : float
2051
+ if the attribution score of an edge is below node_th, it is considered dead and will be set to zero.
2052
+
2053
+ Returns:
2054
+ --------
2055
+ pruned network : MultKAN
2056
+
2057
+ Example
2058
+ -------
2059
+ # >>> from kan import *
2060
+ # >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
2061
+ # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
2062
+ # >>> dataset = create_dataset(f, n_var=2)
2063
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
2064
+ # >>> model = model.prune()
2065
+ # >>> model.plot()
2066
+ '''
2067
+ if self.acts is None:
2068
+ self.get_act()
2069
+
2070
+ self = self.prune_node(node_th, log_history=False)
2071
+ #self.prune_node(node_th, log_history=False)
2072
+ self.forward(self.cache_data)
2073
+ self.attribute()
2074
+ self.prune_edge(edge_th, log_history=False)
2075
+ self.log_history('prune')
2076
+ return self
2077
+
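A typical sparsify-prune-refine loop built on the three methods above (step counts and thresholds are illustrative):

model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001)  # l1/entropy regularization encourages sparsity
model = model.prune(node_th=1e-2, edge_th=3e-2)        # drop dead nodes, zero dead edges
model.fit(dataset, opt='LBFGS', steps=20)              # fine-tune the pruned network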
2078
+ def prune_input(self, threshold=1e-2, active_inputs=None, log_history=True):
2079
+ """
2080
+ prune inputs
2081
+
2082
+ Args:
2083
+ -----
2084
+ threshold : float
2085
+ if the attribution score of the input feature is below threshold, it is considered irrelevant.
2086
+ active_inputs : None or list
2087
+ if a list is passed, the manual mode will disregard attribution score and prune as instructed.
2088
+
2089
+ Returns:
2090
+ --------
2091
+ pruned network : MultKAN
2092
+
2093
+ Example1
2094
+ --------
2095
+ >>> # automatic
2096
+ # >>> from kan import *
2097
+ # >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
2098
+ # >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
2099
+ # >>> dataset = create_dataset(f, n_var=3)
2100
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
2101
+ # >>> model.plot()
2102
+ # >>> model = model.prune_input()
2103
+ # >>> model.plot()
2104
+
2105
+ Example2
2106
+ --------
2107
+ >>> # automatic
2108
+ # >>> from kan import *
2109
+ # >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
2110
+ # >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
2111
+ # >>> dataset = create_dataset(f, n_var=3)
2112
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
2113
+ # >>> model.plot()
2114
+ # >>> model = model.prune_input(active_inputs=[0,1])
2115
+ # >>> model.plot()
2116
+ """
2117
+ if active_inputs is None:
2118
+ self.attribute()
2119
+ input_score = self.node_scores[0]
2120
+ input_mask = input_score > threshold
2121
+ print('keep:', input_mask.tolist())
2122
+ input_id = torch.where(input_mask)[0]
2123
+
2124
+ else:
2125
+ input_id = torch.tensor(active_inputs, dtype=torch.long).to(self.device)
2126
+
2127
+ model2 = MultKAN(copy.deepcopy(self.width), grid=self.grid, k=self.k, base_fun=self.base_fun_name,
2128
+ mult_arity=self.mult_arity, ckpt_path=self.ckpt_path, auto_save=True, first_init=False,
2129
+ state_id=self.state_id, round=self.round).to(self.device)
2130
+ model2.load_state_dict(self.state_dict())
2131
+
2132
+ model2.act_fun[0] = model2.act_fun[0].get_subset(input_id, torch.arange(self.width_out[1]))
2133
+ model2.symbolic_fun[0] = self.symbolic_fun[0].get_subset(input_id, torch.arange(self.width_out[1]))
2134
+
2135
+ model2.cache_data = self.cache_data
2136
+ model2.acts = None
2137
+
2138
+ model2.width[0] = [len(input_id), 0]
2139
+ model2.input_id = input_id
2140
+
2141
+ if log_history:
2142
+ self.log_history('prune_input')
2143
+ model2.state_id += 1
2144
+
2145
+ return model2
2146
+
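Both input-pruning modes in one sketch (threshold and feature indices are illustrative):

model = model.prune_input(threshold=1e-2)        # keep features whose attribution exceeds the threshold
model = model.prune_input(active_inputs=[0, 1])  # or keep an explicit list of features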
2147
+ def remove_edge(self, l, i, j, log_history=True):
2148
+ """
2149
+ remove activation phi(l,i,j) (set its mask to zero)
2150
+ """
2151
+ self.act_fun[l].mask[i][j] = 0.
2152
+ if log_history:
2153
+ self.log_history('remove_edge')
2154
+
2155
+ def remove_node(self, l, i, mode='all', log_history=True):
2156
+ """
2157
+ remove neuron (l,i) (set the masks of all incoming and outgoing activation functions to zero)
2158
+ """
2159
+ if mode == 'down':
2160
+ self.act_fun[l - 1].mask[:, i] = 0.
2161
+ self.symbolic_fun[l - 1].mask[i, :] *= 0.
2162
+
2163
+ elif mode == 'up':
2164
+ self.act_fun[l].mask[i, :] = 0.
2165
+ self.symbolic_fun[l].mask[:, i] *= 0.
2166
+
2167
+ else:
2168
+ self.remove_node(l, i, mode='up')
2169
+ self.remove_node(l, i, mode='down')
2170
+
2171
+ if log_history:
2172
+ self.log_history('remove_node')
2173
+
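The two helpers above operate purely on masks, so manual surgery is reversible by retraining. A sketch (indices illustrative):

model.remove_edge(1, 0, 2)  # zero the mask of activation phi(l=1, i=0, j=2)
model.remove_node(1, 3)     # silence neuron (l=1, i=3) in both directions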
2174
+ def attribute(self, l=None, i=None, out_score=None, plot=True):
2175
+ """
2176
+ get attribution scores
2177
+
2178
+ Args:
2179
+ -----
2180
+ l : None or int
2181
+ layer index
2182
+ i : None or int
2183
+ neuron index
2184
+ out_score : None or 1D torch.float
2185
+ specify output scores
2186
+ plot : bool
2187
+ when plot = True, display a bar chart of the attribution scores
2188
+
2189
+ Returns:
2190
+ --------
2191
+ attribution scores
2192
+
2193
+ Example
2194
+ -------
2195
+ # >>> from kan import *
2196
+ # >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
2197
+ # >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
2198
+ # >>> dataset = create_dataset(f, n_var=3)
2199
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
2200
+ # >>> model.attribute()
2201
+ # >>> model.feature_score
2202
+ """
2203
+ # output (out_dim, in_dim)
2204
+
2205
+ if l is not None:
2206
+ self.attribute()
2207
+ out_score = self.node_scores[l]
2208
+
2209
+ if self.acts is None:
2210
+ self.get_act()
2211
+
2212
+ def score_node2subnode(node_score, width, mult_arity, out_dim):
2213
+
2214
+ assert np.sum(width) == node_score.shape[1]
2215
+ if isinstance(mult_arity, int):
2216
+ n_subnode = width[0] + mult_arity * width[1]
2217
+ else:
2218
+ n_subnode = width[0] + int(np.sum(mult_arity))
2219
+
2220
+ #subnode_score_leaf = torch.zeros(out_dim, n_subnode).requires_grad_(True)
2221
+ #subnode_score = subnode_score_leaf.clone()
2222
+ #subnode_score[:,:width[0]] = node_score[:,:width[0]]
2223
+ subnode_score = node_score[:, :width[0]]
2224
+ if isinstance(mult_arity, int):
2225
+ #subnode_score[:,width[0]:] = node_score[:,width[0]:][:,:,None].expand(out_dim, node_score[width[0]:].shape[0], mult_arity).reshape(out_dim,-1)
2226
+ expanded = node_score[:, width[0]:][:, :, None].expand(out_dim, node_score[:, width[0]:].shape[1], mult_arity)
2227
+ subnode_score = torch.cat([subnode_score, expanded.reshape(out_dim, -1)], dim=1)
2232
+ else:
2233
+ acml = width[0]
2234
+ for i in range(len(mult_arity)):
2235
+ #subnode_score[:, acml:acml+mult_arity[i]] = node_score[:, width[0]+i]
2236
+ subnode_score = torch.cat(
2237
+ [subnode_score, node_score[:, width[0] + i].expand(out_dim, mult_arity[i])], dim=1)
2238
+ acml += mult_arity[i]
2239
+ return subnode_score
2240
+
2241
+ node_scores = []
2242
+ subnode_scores = []
2243
+ edge_scores = []
2244
+
2245
+ l_query = l
2246
+ if l is None:
2247
+ l_end = self.depth
2248
+ else:
2249
+ l_end = l
2250
+
2251
+ # back propagate from the queried layer
2252
+ out_dim = self.width_in[l_end]
2253
+ if out_score is None:
2254
+ node_score = torch.eye(out_dim).requires_grad_(True)
2255
+ else:
2256
+ node_score = torch.diag(out_score).requires_grad_(True)
2257
+ node_scores.append(node_score)
2258
+
2259
+ device = self.act_fun[0].grid.device
2260
+
2261
+ for l in range(l_end, 0, -1):
2262
+
2263
+ # node to subnode
2264
+ if isinstance(self.mult_arity, int):
2265
+ subnode_score = score_node2subnode(node_score, self.width[l], self.mult_arity, out_dim=out_dim)
2266
+ else:
2267
+ mult_arity = self.mult_arity[l]
2268
+ #subnode_score = score_node2subnode(node_score, self.width[l], mult_arity)
2269
+ subnode_score = score_node2subnode(node_score, self.width[l], mult_arity, out_dim=out_dim)
2270
+
2271
+ subnode_scores.append(subnode_score)
2272
+ # subnode to edge
2273
+ #print(self.edge_actscale[l-1].device, subnode_score.device, self.subnode_actscale[l-1].device)
2274
+ edge_score = torch.einsum('ij,ki,i->kij', self.edge_actscale[l - 1], subnode_score.to(device),
2275
+ 1 / (self.subnode_actscale[l - 1] + 1e-4))
2276
+ edge_scores.append(edge_score)
2277
+
2278
+ # edge to node
2279
+ node_score = torch.sum(edge_score, dim=1)
2280
+ node_scores.append(node_score)
2281
+
2282
+ self.node_scores_all = list(reversed(node_scores))
2283
+ self.edge_scores_all = list(reversed(edge_scores))
2284
+ self.subnode_scores_all = list(reversed(subnode_scores))
2285
+
2286
+ self.node_scores = [torch.mean(l, dim=0) for l in self.node_scores_all]
2287
+ self.edge_scores = [torch.mean(l, dim=0) for l in self.edge_scores_all]
2288
+ self.subnode_scores = [torch.mean(l, dim=0) for l in self.subnode_scores_all]
2289
+
2290
+ # return
2291
+ if l_query is not None:
2292
+ if i is None:
2293
+ return self.node_scores_all[0]
2294
+ else:
2295
+
2296
+ # plot
2297
+ if plot:
2298
+ in_dim = self.width_in[0]
2299
+ plt.figure(figsize=(1 * in_dim, 3))
2300
+ plt.bar(range(in_dim), self.node_scores_all[0][i].cpu().detach().numpy())
2301
+ plt.xticks(range(in_dim))
2302
+
2303
+ return self.node_scores_all[0][i]
2304
+
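A sketch of querying attribution scores with the method above (layer and neuron indices are illustrative):

model.attribute()                  # populate node/edge/subnode scores for the whole network
print(model.node_scores[0])        # per-input attribution, averaged over outputs
score = model.attribute(l=1, i=0)  # input attributions w.r.t. neuron (1, 0); also bar-plots them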
2305
+ def node_attribute(self):
2306
+ self.node_attribute_scores = []
2307
+ for l in range(1, self.depth + 1):
2308
+ node_attr = self.attribute(l)
2309
+ self.node_attribute_scores.append(node_attr)
2310
+
2311
+ def feature_interaction(self, l, neuron_th=1e-2, feature_th=1e-2):
2312
+ """
2313
+ get feature interaction
2314
+
2315
+ Args:
2316
+ -----
2317
+ l : int
2318
+ layer index
2319
+ neuron_th : float
2320
+ threshold to determine whether a neuron is active
2321
+ feature_th : float
2322
+ threshold to determine whether a feature is active
2323
+
2324
+ Returns:
2325
+ --------
2326
+ dictionary
2327
+
2328
+ Example
2329
+ -------
2330
+ # >>> from kan import *
2331
+ # >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
2332
+ # >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
2333
+ # >>> dataset = create_dataset(f, n_var=3)
2334
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
2335
+ # >>> model.attribute()
2336
+ # >>> model.feature_interaction(1)
2337
+ """
2338
+ dic = {}
2339
+ width = self.width_in[l]
2340
+
2341
+ for i in range(width):
2342
+ score = self.attribute(l, i, plot=False)
2343
+
2344
+ if torch.max(score) > neuron_th:
2345
+ features = tuple(torch.where(score > torch.max(score) * feature_th)[0].detach().numpy())
2346
+ if features in dic.keys():
2347
+ dic[features] += 1
2348
+ else:
2349
+ dic[features] = 1
2350
+
2351
+ return dic
2352
+
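The returned dictionary maps tuples of input indices to counts of layer-l neurons that depend on exactly that feature set:

interactions = model.feature_interaction(1)
# e.g. {(0, 1): 2, (2,): 1} - two neurons read inputs 0 and 1 jointly, one reads input 2 alone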
2353
+ def suggest_symbolic(self, l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=None, topk=5, verbose=True,
2354
+ r2_loss_fun=lambda x: np.log2(1 + 1e-5 - x), c_loss_fun=lambda x: x, weight_simple=0.8):
2355
+ """
2356
+ suggest symbolic function
2357
+
2358
+ Args:
2359
+ -----
2360
+ l : int
2361
+ layer index
2362
+ i : int
2363
+ neuron index in layer l
2364
+ j : int
2365
+ neuron index in layer l+1
2366
+ a_range : tuple
2367
+ search range of a
2368
+ b_range : tuple
2369
+ search range of b
2370
+ lib : list of str
2371
+ library of candidate symbolic functions
2372
+ topk : int
2373
+ the number of top functions displayed
2374
+ verbose : bool
2375
+ if verbose = True, print more information
2376
+ r2_loss_fun : function
2377
+ function : r2 -> "bits"
2378
+ c_loss_fun : fun
2379
+ function : c -> 'bits'
2380
+ weight_simple : float
2381
+ the simplicity weight: the higher, the more simplicity is preferred over performance
2382
+
2383
+
2384
+ Returns:
2385
+ --------
2386
+ best_name (str), best_fun (function), best_r2 (float), best_c (float)
2387
+
2388
+ Example
2389
+ -------
2390
+ # >>> from kan import *
2391
+ # >>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
2392
+ # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
2393
+ # >>> dataset = create_dataset(f, n_var=3)
2394
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
2395
+ # >>> model.suggest_symbolic(0,1,0)
2396
+ """
2397
+ r2s = []
2398
+ cs = []
2399
+
2400
+ if lib is None:
2401
+ symbolic_lib = SYMBOLIC_LIB
2402
+ else:
2403
+ symbolic_lib = {}
2404
+ for item in lib:
2405
+ symbolic_lib[item] = SYMBOLIC_LIB[item]
2406
+
2407
+ # getting r2 and complexities
2408
+ for (name, content) in symbolic_lib.items():
2409
+ r2 = self.fix_symbolic(l, i, j, name, a_range=a_range, b_range=b_range, verbose=False, log_history=False)
2410
+ if r2 == -1e8: # zero function
2411
+ r2s.append(-1e8)
2412
+ else:
2413
+ r2s.append(r2.item())
2414
+ self.unfix_symbolic(l, i, j, log_history=False)
2415
+ c = content[2]
2416
+ cs.append(c)
2417
+
2418
+ r2s = np.array(r2s)
2419
+ cs = np.array(cs)
2420
+ r2_loss = r2_loss_fun(r2s).astype('float')
2421
+ cs_loss = c_loss_fun(cs)
2422
+
2423
+ loss = weight_simple * cs_loss + (1 - weight_simple) * r2_loss
2424
+
2425
+ topk = np.minimum(topk, len(symbolic_lib))  # clamp before slicing so the indexing below stays in range
2426
+ sorted_ids = np.argsort(loss)[:topk]
2427
+ r2s = r2s[sorted_ids]
2428
+ cs = cs[sorted_ids]
2429
+ r2_loss = r2_loss[sorted_ids]
2430
+ cs_loss = cs_loss[sorted_ids]
2431
+ loss = loss[sorted_ids]
2432
+
2433
+
2434
+ if verbose:
2435
+ # print results in a dataframe
2436
+ results = {}
2437
+ results['function'] = [list(symbolic_lib.items())[sorted_ids[i]][0] for i in range(topk)]
2438
+ results['fitting r2'] = r2s[:topk]
2439
+ results['r2 loss'] = r2_loss[:topk]
2440
+ results['complexity'] = cs[:topk]
2441
+ results['complexity loss'] = cs_loss[:topk]
2442
+ results['total loss'] = loss[:topk]
2443
+
2444
+ df = pd.DataFrame(results)
2445
+ print(df)
2446
+
2447
+ best_name = list(symbolic_lib.items())[sorted_ids[0]][0]
2448
+ best_fun = list(symbolic_lib.items())[sorted_ids[0]][1]
2449
+ best_r2 = r2s[0]
2450
+ best_c = cs[0]
2451
+
2452
+ return best_name, best_fun, best_r2, best_c
2453
+
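The ranking above combines fit and complexity as loss = weight_simple * c_loss + (1 - weight_simple) * log2(1 + 1e-5 - r2) under the default loss functions. A small numeric check of the trade-off (candidate r2 and complexity values are illustrative):

import numpy as np

r2s = np.array([0.999, 0.92])      # candidate A fits better...
cs = np.array([4, 2])              # ...but candidate B is simpler
r2_loss = np.log2(1 + 1e-5 - r2s)  # "bits": more negative = better fit
loss = 0.8 * cs + 0.2 * r2_loss    # weight_simple = 0.8
print(loss)                        # approx [1.21, 0.87]: the simpler candidate wins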
2454
+ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1, weight_simple=0.8,
2455
+ r2_threshold=0.0):
2456
+ """
2457
+ automatic symbolic regression for all edges
2458
+
2459
+ Args:
2460
+ -----
2461
+ a_range : tuple
2462
+ search range of a
2463
+ b_range : tuple
2464
+ search range of b
2465
+ lib : list of str
2466
+ library of candidate symbolic functions
2467
+ verbose : int
2468
+ larger values => more verbose output
2469
+ weight_simple : float
2470
+ a weight that prioritizes simplicity (low complexity) over performance (high r2) - set to 0.0 to ignore complexity
2471
+ r2_threshold : float
2472
+ If r2 is below this threshold, the edge will not be fixed with any symbolic function - set to 0.0 to ignore this threshold
2473
+ Returns:
2474
+ --------
2475
+ None
2476
+
2477
+ Example
2478
+ -------
2479
+ # >>> from kan import *
2480
+ # >>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
2481
+ # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
2482
+ # >>> dataset = create_dataset(f, n_var=3)
2483
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
2484
+ # >>> model.auto_symbolic()
2485
+ """
2486
+ for l in range(len(self.width_in) - 1):
2487
+ for i in range(self.width_in[l]):
2488
+ for j in range(self.width_out[l + 1]):
2489
+ if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.:
2490
+ print(f'skipping ({l},{i},{j}) since already symbolic')
2491
+ elif self.symbolic_fun[l].mask[j, i] == 0. and self.act_fun[l].mask[i][j] == 0.:
2492
+ self.fix_symbolic(l, i, j, '0', verbose=verbose > 1, log_history=False)
2493
+ print(f'fixing ({l},{i},{j}) with 0')
2494
+ else:
2495
+ name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib,
2496
+ verbose=False, weight_simple=weight_simple)
2497
+ if r2 >= r2_threshold:
2498
+ self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False)
2499
+ if verbose >= 1:
2500
+ print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}')
2501
+ else:
2502
+ print(
2503
+ f'For ({l},{i},{j}) the best fit was {name}, but r^2 = {r2} and this is lower than {r2_threshold}. This edge was omitted, keep training or try a different threshold.')
2504
+
2505
+ self.log_history('auto_symbolic')
2506
+
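A sketch of the full symbolification pipeline (library entries must be keys of SYMBOLIC_LIB; the list here is illustrative):

model.auto_symbolic(lib=['x', 'x^2', 'exp', 'sin'], weight_simple=0.8)
formula, variables = model.symbolic_formula()
print(formula[0])  # sympy expression for the first output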
2507
+ def symbolic_formula(self, var=None, normalizer=None, output_normalizer=None, simplify=False):
2508
+ """
2509
+ get symbolic formula
2510
+
2511
+ Args:
2512
+ -----
2513
+ var : None or a list of sympy expression
2514
+ input variables
2515
+ normalizer : [mean, std]
2516
+ output_normalizer : [mean, std]
+ simplify : bool
+ if True, simplify each node expression with sympy.simplify (slower)
2517
+
2518
+ Returns:
2519
+ --------
2520
+ None
2521
+
2522
+ Example
2523
+ -------
2524
+ # >>> from kan import *
2525
+ # >>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
2526
+ # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
2527
+ # >>> dataset = create_dataset(f, n_var=3)
2528
+ # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
2529
+ # >>> model.auto_symbolic()
2530
+ # >>> model.symbolic_formula()[0][0]
2531
+ """
2532
+
2533
+ symbolic_acts = []
2534
+ symbolic_acts_premult = []
2535
+ x = []
2536
+
2537
+ def ex_round(ex1, n_digit):
2538
+ ex2 = ex1
2539
+ for a in sympy.preorder_traversal(ex1):
2540
+ if isinstance(a, sympy.Float):
2541
+ ex2 = ex2.subs(a, round(a, n_digit))
2542
+ return ex2
2543
+
2544
+ # define variables
2545
+ if var is None:
2546
+ x = [sympy.Symbol(f'x_{ii}') for ii in range(1, self.width[0][0] + 1)]
2549
+ elif isinstance(var[0], sympy.Expr):
2550
+ x = var
2551
+ else:
2552
+ x = [sympy.symbols(var_) for var_ in var]
2553
+
2554
+ x0 = x
2555
+
2556
+ if normalizer is not None:
2557
+ mean = normalizer[0]
2558
+ std = normalizer[1]
2559
+ x = [(x[i] - mean[i]) / std[i] for i in range(len(x))]
2560
+
2561
+ symbolic_acts.append(x)
2562
+
2563
+ for l in range(len(self.width_in) - 1):
2564
+ num_sum = self.width[l + 1][0]
2565
+ num_mult = self.width[l + 1][1]
2566
+ y = []
2567
+ for j in range(self.width_out[l + 1]):
2568
+ yj = 0.
2569
+ for i in range(self.width_in[l]):
2570
+ a, b, c, d = self.symbolic_fun[l].affine[j, i]
2571
+ sympy_fun = self.symbolic_fun[l].funs_sympy[j][i]
2572
+ try:
2573
+ yj += c * sympy_fun(a * x[i] + b) + d
2574
+ except:
2575
+ print('make sure all activations have been converted to symbolic formulas first!')
2576
+ return
2577
+ yj = self.subnode_scale[l][j] * yj + self.subnode_bias[l][j]
2578
+ if simplify:
2579
+ y.append(sympy.simplify(yj))
2580
+ else:
2581
+ y.append(yj)
2582
+
2583
+ symbolic_acts_premult.append(y)
2584
+
2585
+ mult = []
2586
+ for k in range(num_mult):
2587
+ if isinstance(self.mult_arity, int):
2588
+ mult_arity = self.mult_arity
2589
+ else:
2590
+ mult_arity = self.mult_arity[l + 1][k]
2591
+ for i in range(mult_arity - 1):
2592
+ if i == 0:
2593
+ mult_k = y[num_sum + 2 * k] * y[num_sum + 2 * k + 1]
2594
+ else:
2595
+ mult_k = mult_k * y[num_sum + 2 * k + i + 1]
2596
+ mult.append(mult_k)
2597
+
2598
+ y = y[:num_sum] + mult
2599
+
2600
+ for j in range(self.width_in[l + 1]):
2601
+ y[j] = self.node_scale[l][j] * y[j] + self.node_bias[l][j]
2602
+
2603
+ x = y
2604
+ symbolic_acts.append(x)
2605
+
2606
+ if output_normalizer is not None:
2607
+ output_layer = symbolic_acts[-1]
2608
+ means = output_normalizer[0]
2609
+ stds = output_normalizer[1]
2610
+
2611
+ assert len(output_layer) == len(means), 'output_normalizer does not match the output layer'
2612
+ assert len(output_layer) == len(stds), 'output_normalizer does not match the output layer'
2613
+
2614
+ output_layer = [(output_layer[i] * stds[i] + means[i]) for i in range(len(output_layer))]
2615
+ symbolic_acts[-1] = output_layer
2616
+
2617
+ self.symbolic_acts = [[symbolic_acts[l][i] for i in range(len(symbolic_acts[l]))] for l in
2618
+ range(len(symbolic_acts))]
2619
+ self.symbolic_acts_premult = [[symbolic_acts_premult[l][i] for i in range(len(symbolic_acts_premult[l]))] for l
2620
+ in range(len(symbolic_acts_premult))]
2621
+
2622
+ out_dim = len(symbolic_acts[-1])
2624
+
2625
+ return [symbolic_acts[-1][i] for i in range(len(symbolic_acts[-1]))], x0
2629
+
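The internal ex_round helper above rounds every float inside a sympy expression; the same idea as a standalone snippet for tidying a returned formula:

import sympy

def round_floats(expr, n_digit=3):
    # substitute each Float with a rounded copy; traversal runs over the original tree
    for a in sympy.preorder_traversal(expr):
        if isinstance(a, sympy.Float):
            expr = expr.subs(a, round(a, n_digit))
    return expr

x = sympy.Symbol('x')
print(round_floats(1.23456789 * sympy.sin(x) + 0.000123))  # 1.235*sin(x)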
2630
+ def expand_depth(self):
2631
+ '''
2632
+ expand network depth, add an identity layer at the end. For usage, please refer to tutorials interp_3_KAN_compiler.ipynb.
2633
+
2634
+ Args:
2635
+ -----
2636
+ None (this method takes no arguments)
2640
+
2641
+ Returns:
2642
+ --------
2643
+ None
2644
+ '''
2645
+ self.depth += 1
2646
+
2647
+ # add kanlayer, set mask to zero
2648
+ dim_out = self.width_in[-1]
2649
+ layer = KANLayer(dim_out, dim_out, num=self.grid, k=self.k)
2650
+ layer.mask *= 0.
2651
+ self.act_fun.append(layer)
2652
+
2653
+ self.width.append([dim_out, 0])
2654
+ self.mult_arity.append([])
2655
+
2656
+ # add symbolic_kanlayer set mask to one. fun = identity on diagonal and zero for off-diagonal
2657
+ layer = Symbolic_KANLayer(dim_out, dim_out)
2658
+ layer.mask += 1.
2659
+
2660
+ for j in range(dim_out):
2661
+ for i in range(dim_out):
2662
+ if i == j:
2663
+ layer.fix_symbolic(i, j, 'x')
2664
+ else:
2665
+ layer.fix_symbolic(i, j, '0')
2666
+
2667
+ self.symbolic_fun.append(layer)
2668
+
2669
+ self.node_bias.append(
2670
+ torch.nn.Parameter(torch.zeros(dim_out, device=self.device)).requires_grad_(self.affine_trainable))
2671
+ self.node_scale.append(
2672
+ torch.nn.Parameter(torch.ones(dim_out, device=self.device)).requires_grad_(self.affine_trainable))
2673
+ self.subnode_bias.append(
2674
+ torch.nn.Parameter(torch.zeros(dim_out, device=self.device)).requires_grad_(self.affine_trainable))
2675
+ self.subnode_scale.append(
2676
+ torch.nn.Parameter(torch.ones(dim_out, device=self.device)).requires_grad_(self.affine_trainable))
2677
+
2678
+ def expand_width(self, layer_id, n_added_nodes, sum_bool=True, mult_arity=2):
2679
+ '''
2680
+ expand network width. For usage, please refer to tutorials interp_3_KAN_compiler.ipynb.
2681
+
2682
+ Args:
2683
+ -----
2684
+ layer_id : int
2685
+ layer index
2686
+ n_added_nodes : int
2687
+ the number of added nodes
2688
+ sum_bool : bool
2689
+ if sum_bool == True, added nodes are addition nodes; otherwise multiplication nodes
2690
+ mult_arity : int
2691
+ multiplication arity (the number of numbers to be multiplied)
2692
+
2693
+ Returns:
2694
+ --------
2695
+ None
2696
+ '''
2697
+
2698
+ def _expand(layer_id, n_added_nodes, sum_bool=True, mult_arity=2, added_dim='out'):
2699
+ l = layer_id
2700
+ in_dim = self.symbolic_fun[l].in_dim
2701
+ out_dim = self.symbolic_fun[l].out_dim
2702
+ if sum_bool:
2703
+
2704
+ if added_dim == 'out':
2705
+ new = Symbolic_KANLayer(in_dim, out_dim + n_added_nodes)
2706
+ old = self.symbolic_fun[l]
2707
+ in_id = np.arange(in_dim)
2708
+ out_id = np.arange(out_dim + n_added_nodes)
2709
+
2710
+ for j in out_id:
2711
+ for i in in_id:
2712
+ new.fix_symbolic(i, j, '0')
2713
+ new.mask += 1.
2714
+
2715
+ for j in out_id:
2716
+ for i in in_id:
2717
+ if j > n_added_nodes - 1:
2718
+ new.funs[j][i] = old.funs[j - n_added_nodes][i]
2719
+ new.funs_avoid_singularity[j][i] = old.funs_avoid_singularity[j - n_added_nodes][i]
2720
+ new.funs_sympy[j][i] = old.funs_sympy[j - n_added_nodes][i]
2721
+ new.funs_name[j][i] = old.funs_name[j - n_added_nodes][i]
2722
+ new.affine.data[j][i] = old.affine.data[j - n_added_nodes][i]
2723
+
2724
+ self.symbolic_fun[l] = new
2725
+ self.act_fun[l] = KANLayer(in_dim, out_dim + n_added_nodes, num=self.grid, k=self.k)
2726
+ self.act_fun[l].mask *= 0.
2727
+
2728
+ self.node_scale[l].data = torch.cat(
2729
+ [torch.ones(n_added_nodes, device=self.device), self.node_scale[l].data])
2730
+ self.node_bias[l].data = torch.cat(
2731
+ [torch.zeros(n_added_nodes, device=self.device), self.node_bias[l].data])
2732
+ self.subnode_scale[l].data = torch.cat(
2733
+ [torch.ones(n_added_nodes, device=self.device), self.subnode_scale[l].data])
2734
+ self.subnode_bias[l].data = torch.cat(
2735
+ [torch.zeros(n_added_nodes, device=self.device), self.subnode_bias[l].data])
2736
+
2737
+ if added_dim == 'in':
2738
+ new = Symbolic_KANLayer(in_dim + n_added_nodes, out_dim)
2739
+ old = self.symbolic_fun[l]
2740
+ in_id = np.arange(in_dim + n_added_nodes)
2741
+ out_id = np.arange(out_dim)
2742
+
2743
+ for j in out_id:
2744
+ for i in in_id:
2745
+ new.fix_symbolic(i, j, '0')
2746
+ new.mask += 1.
2747
+
2748
+ for j in out_id:
2749
+ for i in in_id:
2750
+ if i > n_added_nodes - 1:
2751
+ new.funs[j][i] = old.funs[j][i - n_added_nodes]
2752
+ new.funs_avoid_singularity[j][i] = old.funs_avoid_singularity[j][i - n_added_nodes]
2753
+ new.funs_sympy[j][i] = old.funs_sympy[j][i - n_added_nodes]
2754
+ new.funs_name[j][i] = old.funs_name[j][i - n_added_nodes]
2755
+ new.affine.data[j][i] = old.affine.data[j][i - n_added_nodes]
2756
+
2757
+ self.symbolic_fun[l] = new
2758
+ self.act_fun[l] = KANLayer(in_dim + n_added_nodes, out_dim, num=self.grid, k=self.k)
2759
+ self.act_fun[l].mask *= 0.
2760
+
2761
+
2762
+ else:
2763
+
2764
+ if isinstance(mult_arity, int):
2765
+ mult_arity = [mult_arity] * n_added_nodes
2766
+
2767
+ if added_dim == 'out':
2768
+ n_added_subnodes = np.sum(mult_arity)
2769
+ new = Symbolic_KANLayer(in_dim, out_dim + n_added_subnodes)
2770
+ old = self.symbolic_fun[l]
2771
+ in_id = np.arange(in_dim)
2772
+ out_id = np.arange(out_dim + n_added_nodes)
2773
+
2774
+ for j in out_id:
2775
+ for i in in_id:
2776
+ new.fix_symbolic(i, j, '0')
2777
+ new.mask += 1.
2778
+
2779
+ for j in out_id:
2780
+ for i in in_id:
2781
+ if j < out_dim:
2782
+ new.funs[j][i] = old.funs[j][i]
2783
+ new.funs_avoid_singularity[j][i] = old.funs_avoid_singularity[j][i]
2784
+ new.funs_sympy[j][i] = old.funs_sympy[j][i]
2785
+ new.funs_name[j][i] = old.funs_name[j][i]
2786
+ new.affine.data[j][i] = old.affine.data[j][i]
2787
+
2788
+ self.symbolic_fun[l] = new
2789
+ self.act_fun[l] = KANLayer(in_dim, out_dim + n_added_subnodes, num=self.grid, k=self.k)
2790
+ self.act_fun[l].mask *= 0.
2791
+
2792
+ self.node_scale[l].data = torch.cat(
2793
+ [self.node_scale[l].data, torch.ones(n_added_nodes, device=self.device)])
2794
+ self.node_bias[l].data = torch.cat(
2795
+ [self.node_bias[l].data, torch.zeros(n_added_nodes, device=self.device)])
2796
+ self.subnode_scale[l].data = torch.cat(
2797
+ [self.subnode_scale[l].data, torch.ones(n_added_subnodes, device=self.device)])
2798
+ self.subnode_bias[l].data = torch.cat(
2799
+ [self.subnode_bias[l].data, torch.zeros(n_added_subnodes, device=self.device)])
2800
+
2801
+ if added_dim == 'in':
2802
+ new = Symbolic_KANLayer(in_dim + n_added_nodes, out_dim)
2803
+ old = self.symbolic_fun[l]
2804
+ in_id = np.arange(in_dim + n_added_nodes)
2805
+ out_id = np.arange(out_dim)
2806
+
2807
+ for j in out_id:
2808
+ for i in in_id:
2809
+ new.fix_symbolic(i, j, '0')
2810
+ new.mask += 1.
2811
+
2812
+ for j in out_id:
2813
+ for i in in_id:
2814
+ if i < in_dim:
2815
+ new.funs[j][i] = old.funs[j][i]
2816
+ new.funs_avoid_singularity[j][i] = old.funs_avoid_singularity[j][i]
2817
+ new.funs_sympy[j][i] = old.funs_sympy[j][i]
2818
+ new.funs_name[j][i] = old.funs_name[j][i]
2819
+ new.affine.data[j][i] = old.affine.data[j][i]
2820
+
2821
+ self.symbolic_fun[l] = new
2822
+ self.act_fun[l] = KANLayer(in_dim + n_added_nodes, out_dim, num=self.grid, k=self.k)
2823
+ self.act_fun[l].mask *= 0.
2824
+
2825
+ _expand(layer_id - 1, n_added_nodes, sum_bool, mult_arity, added_dim='out')
2826
+ _expand(layer_id, n_added_nodes, sum_bool, mult_arity, added_dim='in')
2827
+ if sum_bool:
2828
+ self.width[layer_id][0] += n_added_nodes
2829
+ else:
2830
+ if isinstance(mult_arity, int):
2831
+ mult_arity = [mult_arity] * n_added_nodes
2832
+
2833
+ self.width[layer_id][1] += n_added_nodes
2834
+ self.mult_arity[layer_id] += mult_arity
2835
+
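A sketch of the two expansion helpers (see the class docstring for the width format; values are illustrative):

model.expand_depth()      # append an identity layer of the same output width
model.expand_width(1, 2)  # add 2 addition nodes to layer 1
model.expand_width(1, 1, sum_bool=False, mult_arity=2)  # add 1 multiplication node of arity 2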
2836
+ def perturb(self, mag=1.0, mode='non-intrusive'):
2837
+ """
2838
+ perturb a network. For usage, please refer to tutorials interp_3_KAN_compiler.ipynb.
2839
+
2840
+ Args:
2841
+ -----
2842
+ mag : float
2843
+ perturbation magnitude
2844
+ mode : str
2845
+ perturbation mode, choices = {'non-intrusive', 'all', 'minimal'}
2846
+
2847
+ Returns:
2848
+ --------
2849
+ None
2850
+ """
2851
+ perturb_bool = {}
2852
+
2853
+ if mode == 'all':
2854
+ perturb_bool['aa_a'] = True
2855
+ perturb_bool['aa_i'] = True
2856
+ perturb_bool['ai'] = True
2857
+ perturb_bool['ia'] = True
2858
+ perturb_bool['ii'] = True
2859
+ elif mode == 'non-intrusive':
2860
+ perturb_bool['aa_a'] = False
2861
+ perturb_bool['aa_i'] = False
2862
+ perturb_bool['ai'] = True
2863
+ perturb_bool['ia'] = False
2864
+ perturb_bool['ii'] = True
2865
+ elif mode == 'minimal':
2866
+ perturb_bool['aa_a'] = True
2867
+ perturb_bool['aa_i'] = False
2868
+ perturb_bool['ai'] = False
2869
+ perturb_bool['ia'] = False
2870
+ perturb_bool['ii'] = False
2871
+ else:
2872
+ raise Exception('mode not recognized, valid modes are \'all\', \'non-intrusive\', \'minimal\'.')
2873
+
2874
+ for l in range(self.depth):
2875
+ funs_name = self.symbolic_fun[l].funs_name
2876
+ for j in range(self.width_out[l + 1]):
2877
+ for i in range(self.width_in[l]):
2878
+ out_array = list(np.array(self.symbolic_fun[l].funs_name)[j])
2879
+ in_array = list(np.array(self.symbolic_fun[l].funs_name)[:, i])
2880
+ out_active = len([i for i, x in enumerate(out_array) if x != "0"]) > 0
2881
+ in_active = len([i for i, x in enumerate(in_array) if x != "0"]) > 0
2882
+ dic = {True: 'a', False: 'i'}
2883
+ edge_type = dic[in_active] + dic[out_active]
2884
+
2885
+ if l < self.depth - 1 or mode != 'non-intrusive':
2886
+
2887
+ if edge_type == 'aa':
2888
+ if self.symbolic_fun[l].funs_name[j][i] == '0':
2889
+ edge_type += '_i'
2890
+ else:
2891
+ edge_type += '_a'
2892
+
2893
+ if perturb_bool[edge_type]:
2894
+ self.act_fun[l].mask.data[i][j] = mag
2895
+
2896
+ if l == self.depth - 1 and mode == 'non-intrusive':
2897
+ self.act_fun[l].mask.data[i][j] = torch.tensor(1.)
2898
+ self.act_fun[l].scale_base.data[i][j] = torch.tensor(0.)
2899
+ self.act_fun[l].scale_sp.data[i][j] = torch.tensor(0.)
2900
+
2901
+ self.get_act(self.cache_data)
2902
+
2903
+ self.log_history('perturb')
2904
+
2905
+ def module(self, start_layer, chain):
2906
+ """
2907
+ specify network modules
2908
+
2909
+ Args:
2910
+ -----
2911
+ start_layer : int
2912
+ the earliest layer of the module
2913
+ chain : str
2914
+ specify neurons in the module
2915
+
2916
+ Returns:
2917
+ --------
2918
+ None
2919
+ """
2920
+ #chain = '[-1]->[-1,-2]->[-1]->[-1]'
2921
+ groups = chain.split('->')
2922
+ n_total_layers = len(groups) // 2
2923
+ #start_layer = 0
2924
+
2925
+ for l in range(n_total_layers):
2926
+ current_layer = cl = start_layer + l
2927
+ id_in = [int(i) for i in groups[2 * l][1:-1].split(',')]
2928
+ id_out = [int(i) for i in groups[2 * l + 1][1:-1].split(',')]
2929
+
2930
+ in_dim = self.width_in[cl]
2931
+ out_dim = self.width_out[cl + 1]
2932
+ id_in_other = list(set(range(in_dim)) - set(id_in))
2933
+ id_out_other = list(set(range(out_dim)) - set(id_out))
2934
+ self.act_fun[cl].mask.data[np.ix_(id_in_other, id_out)] = 0.
2935
+ self.act_fun[cl].mask.data[np.ix_(id_in, id_out_other)] = 0.
2936
+ self.symbolic_fun[cl].mask.data[np.ix_(id_out, id_in_other)] = 0.
2937
+ self.symbolic_fun[cl].mask.data[np.ix_(id_out_other, id_in)] = 0.
2938
+
2939
+ self.log_history('module')
2940
+
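The chain string lists alternating input and output neuron groups separated by '->' (negative indices count from the end), matching the commented example in the body above:

# Restrict layers 0 and 1 so the module's neurons connect only to each other:
model.module(0, '[-1]->[-1,-2]->[-1]->[-1]')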
2941
+ def tree(self, x=None, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False):
2942
+ """
2943
+ turn KAN into a tree
2944
+ """
2945
+ if x is None:
2946
+ x = self.cache_data
2947
+ plot_tree(self, x, in_var=in_var, style=style, sym_th=sym_th, sep_th=sep_th, skip_sep_test=skip_sep_test,
2948
+ verbose=verbose)
2949
+
2950
+ def speed(self, compile=False):
2951
+ """
2952
+ turn on KAN's speed mode
2953
+ """
2954
+ self.symbolic_enabled = False
2955
+ self.save_act = False
2956
+ self.auto_save = False
2957
+ if compile:
2958
+ return torch.compile(self)
2959
+ else:
2960
+ return self
2961
+
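speed() trades introspection for throughput; compile=True additionally wraps the model with torch.compile (PyTorch 2.x). A sketch:

fast = model.speed()                # flag changes only; returns self
# fast = model.speed(compile=True)  # optionally JIT-compile the forward pass
y = fast(dataset['test_input'])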
2962
+ def get_act(self, x=None):
2963
+ """
2964
+ collect intermediate activations
2965
+ """
2966
+ if isinstance(x, dict):
2967
+ x = x['train_input']
2968
+ if x is None:
2969
+ if self.cache_data is not None:
2970
+ x = self.cache_data
2971
+ else:
2972
+ raise Exception("missing input data x")
2973
+ save_act = self.save_act
2974
+ self.save_act = True
2975
+ self.forward(x)
2976
+ self.save_act = save_act
2977
+
2978
+ def get_fun(self, l, i, j):
2979
+ """
2980
+ get function (l,i,j)
2981
+ """
2982
+ inputs = self.spline_preacts[l][:, j, i].cpu().detach().numpy()
2983
+ outputs = self.spline_postacts[l][:, j, i].cpu().detach().numpy()
2984
+ # they are not ordered yet
2985
+ rank = np.argsort(inputs)
2986
+ inputs = inputs[rank]
2987
+ outputs = outputs[rank]
2988
+ plt.figure(figsize=(3, 3))
2989
+ plt.plot(inputs, outputs, marker="o")
2990
+ return inputs, outputs
2991
+
2992
+ def history(self, k='all'):
2993
+ """
2994
+ get history
2995
+ """
2996
+ with open(self.ckpt_path + '/history.txt', 'r') as f:
2997
+ data = f.readlines()
2998
+ n_line = len(data)
2999
+ if k == 'all':
3000
+ k = n_line
3001
+
3002
+ data = data[-k:]
3003
+ for line in data:
3004
+ print(line[:-1])
3005
+
3006
+ @property
3007
+ def n_edge(self):
3008
+ """
3009
+ the number of active edges
3010
+ """
3011
+ depth = len(self.act_fun)
3012
+ complexity = 0
3013
+ for l in range(depth):
3014
+ complexity += torch.sum(self.act_fun[l].mask > 0.)
3015
+ return complexity.item()
3016
+
3017
+ def evaluate(self, dataset):
3018
+ evaluation = {'test_loss': torch.sqrt(
3019
+ torch.mean((self.forward(dataset['test_input']) - dataset['test_label']) ** 2)).item(),
3020
+ 'n_edge': self.n_edge, 'n_grid': self.grid}
3021
+ # add other metrics (maybe accuracy)
3022
+ return evaluation
3023
+
3024
+ def swap(self, l, i1, i2, log_history=True):
3025
+
3026
+ self.act_fun[l - 1].swap(i1, i2, mode='out')
3027
+ self.symbolic_fun[l - 1].swap(i1, i2, mode='out')
3028
+ self.act_fun[l].swap(i1, i2, mode='in')
3029
+ self.symbolic_fun[l].swap(i1, i2, mode='in')
3030
+
3031
+ def swap_(data, i1, i2):
3032
+ data[i1], data[i2] = data[i2], data[i1]
3033
+
3034
+ swap_(self.node_scale[l - 1].data, i1, i2)
3035
+ swap_(self.node_bias[l - 1].data, i1, i2)
3036
+ swap_(self.subnode_scale[l - 1].data, i1, i2)
3037
+ swap_(self.subnode_bias[l - 1].data, i1, i2)
3038
+
3039
+ if log_history:
3040
+ self.log_history('swap')
3041
+
3042
+ @property
3043
+ def connection_cost(self):
3044
+
3045
+ cc = 0.
3046
+ for t in self.edge_scores:
3047
+ def get_coordinate(n):
3048
+ return torch.linspace(0, 1, steps=n + 1, device=self.device)[:n] + 1 / (2 * n)
3049
+
3050
+ in_dim = t.shape[0]
3051
+ x_in = get_coordinate(in_dim)
3052
+
3053
+ out_dim = t.shape[1]
3054
+ x_out = get_coordinate(out_dim)
3055
+
3056
+ dist = torch.abs(x_in[:, None] - x_out[None, :])
3057
+ cc += torch.sum(dist * t)
3058
+
3059
+ return cc
3060
+
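The property above places the n neurons of each layer at evenly spaced midpoints in (0, 1) and weights every edge score by the distance it spans; the coordinate helper in isolation:

import torch

def get_coordinate(n):
    # midpoints of n equal bins in (0, 1)
    return torch.linspace(0, 1, steps=n + 1)[:n] + 1 / (2 * n)

print(get_coordinate(2))  # tensor([0.2500, 0.7500])
print(get_coordinate(4))  # tensor([0.1250, 0.3750, 0.6250, 0.8750])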
3061
+ def auto_swap_l(self, l):
3062
+
3063
+ num = self.width_in[l]
3064
+ for i in range(num):
3065
+ ccs = []
3066
+ for j in range(num):
3067
+ self.swap(l, i, j, log_history=False)
3068
+ self.get_act()
3069
+ self.attribute()
3070
+ cc = self.connection_cost.detach().clone()
3071
+ ccs.append(cc)
3072
+ self.swap(l, i, j, log_history=False)
3073
+ j = torch.argmin(torch.tensor(ccs))
3074
+ self.swap(l, i, j, log_history=False)
3075
+
3076
+ def auto_swap(self):
3077
+ """
3078
+ automatically swap neurons such that the connection cost is minimized
3079
+ """
3080
+ depth = self.depth
3081
+ for l in range(1, depth):
3082
+ self.auto_swap_l(l)
3083
+
3084
+ self.log_history('auto_swap')
3085
+
3086
+
3087
+ KAN = MultKAN