yms-kan 0.0.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yms_kan-0.0.7/LICENSE +21 -0
- yms_kan-0.0.7/PKG-INFO +18 -0
- yms_kan-0.0.7/README.md +1 -0
- yms_kan-0.0.7/kan/KANLayer.py +364 -0
- yms_kan-0.0.7/kan/LBFGS.py +492 -0
- yms_kan-0.0.7/kan/MLP.py +361 -0
- yms_kan-0.0.7/kan/MultKAN.py +3087 -0
- yms_kan-0.0.7/kan/Symbolic_KANLayer.py +270 -0
- yms_kan-0.0.7/kan/__init__.py +3 -0
- yms_kan-0.0.7/kan/compiler.py +498 -0
- yms_kan-0.0.7/kan/dataset.py +27 -0
- yms_kan-0.0.7/kan/experiment.py +50 -0
- yms_kan-0.0.7/kan/feynman.py +739 -0
- yms_kan-0.0.7/kan/hypothesis.py +695 -0
- yms_kan-0.0.7/kan/spline.py +144 -0
- yms_kan-0.0.7/kan/utils.py +661 -0
- yms_kan-0.0.7/setup.cfg +4 -0
- yms_kan-0.0.7/setup.py +96 -0
- yms_kan-0.0.7/yms_kan.egg-info/PKG-INFO +18 -0
- yms_kan-0.0.7/yms_kan.egg-info/SOURCES.txt +20 -0
- yms_kan-0.0.7/yms_kan.egg-info/dependency_links.txt +1 -0
- yms_kan-0.0.7/yms_kan.egg-info/top_level.txt +1 -0
@@ -0,0 +1,3087 @@
|
|
1
|
+
import math
|
2
|
+
import sys
|
3
|
+
|
4
|
+
import torch
|
5
|
+
import torch.nn as nn
|
6
|
+
import numpy as np
|
7
|
+
from PIL import Image
|
8
|
+
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
|
9
|
+
|
10
|
+
from .KANLayer import KANLayer
|
11
|
+
#from .Symbolic_MultKANLayer import *
|
12
|
+
from .Symbolic_KANLayer import Symbolic_KANLayer
|
13
|
+
from .LBFGS import *
|
14
|
+
import os
|
15
|
+
import glob
|
16
|
+
import matplotlib.pyplot as plt
|
17
|
+
from tqdm import tqdm
|
18
|
+
import random
|
19
|
+
import copy
|
20
|
+
#from .MultKANLayer import MultKANLayer
|
21
|
+
import pandas as pd
|
22
|
+
from sympy.printing import latex
|
23
|
+
from sympy import *
|
24
|
+
import sympy
|
25
|
+
import yaml
|
26
|
+
from .spline import curve2coef
|
27
|
+
from .utils import SYMBOLIC_LIB
|
28
|
+
from .hypothesis import plot_tree
|
29
|
+
from contextlib import contextmanager
|
30
|
+
import shutil
|
31
|
+
import gc
|
32
|
+
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
|
33
|
+
|
34
|
+
|
35
|
+
class MultKAN(nn.Module):
|
36
|
+
"""
|
37
|
+
KAN class
|
38
|
+
|
39
|
+
Attributes:
|
40
|
+
-----------
|
41
|
+
grid : int
|
42
|
+
the number of grid intervals
|
43
|
+
k : int
|
44
|
+
spline order
|
45
|
+
act_fun : a list of KANLayers
|
46
|
+
symbolic_fun: a list of Symbolic_KANLayer
|
47
|
+
depth : int
|
48
|
+
depth of KAN
|
49
|
+
width : list
|
50
|
+
number of neurons in each layer.
|
51
|
+
Without multiplication nodes, [2,5,5,3] means 2D inputs, 3D outputs, with 2 layers of 5 hidden neurons.
|
52
|
+
With multiplication nodes, [2,[5,3],[5,1],3] means besides the [2,5,53] KAN, there are 3 (1) mul nodes in layer 1 (2).
|
53
|
+
mult_arity : int, or list of int lists
|
54
|
+
multiplication arity for each multiplication node (the number of numbers to be multiplied)
|
55
|
+
grid : int
|
56
|
+
the number of grid intervals
|
57
|
+
k : int
|
58
|
+
the order of piecewise polynomial
|
59
|
+
base_fun : fun
|
60
|
+
residual function b(x). an activation function phi(x) = sb_scale * b(x) + sp_scale * spline(x)
|
61
|
+
symbolic_fun : a list of Symbolic_KANLayer
|
62
|
+
Symbolic_KANLayers
|
63
|
+
symbolic_enabled : bool
|
64
|
+
If False, the symbolic front is not computed (to save time). Default: True.
|
65
|
+
width_in : list
|
66
|
+
The number of input neurons for each layer
|
67
|
+
width_out : list
|
68
|
+
The number of output neurons for each layer
|
69
|
+
base_fun_name : str
|
70
|
+
The base function b(x)
|
71
|
+
grip_eps : float
|
72
|
+
The parameter that interpolates between uniform grid and adaptive grid (based on sample quantile)
|
73
|
+
node_bias : a list of 1D torch.float
|
74
|
+
node_scale : a list of 1D torch.float
|
75
|
+
subnode_bias : a list of 1D torch.float
|
76
|
+
subnode_scale : a list of 1D torch.float
|
77
|
+
symbolic_enabled : bool
|
78
|
+
when symbolic_enabled = False, the symbolic branch (symbolic_fun) will be ignored in computation (set to zero)
|
79
|
+
affine_trainable : bool
|
80
|
+
indicate whether affine parameters are trainable (node_bias, node_scale, subnode_bias, subnode_scale)
|
81
|
+
sp_trainable : bool
|
82
|
+
indicate whether the overall magnitude of splines is trainable
|
83
|
+
sb_trainable : bool
|
84
|
+
indicate whether the overall magnitude of base function is trainable
|
85
|
+
save_act : bool
|
86
|
+
indicate whether intermediate activations are saved in forward pass
|
87
|
+
node_scores : None or list of 1D torch.float
|
88
|
+
node attribution score
|
89
|
+
edge_scores : None or list of 2D torch.float
|
90
|
+
edge attribution score
|
91
|
+
subnode_scores : None or list of 1D torch.float
|
92
|
+
subnode attribution score
|
93
|
+
cache_data : None or 2D torch.float
|
94
|
+
cached input data
|
95
|
+
acts : None or a list of 2D torch.float
|
96
|
+
activations on nodes
|
97
|
+
auto_save : bool
|
98
|
+
indicate whether to automatically save a checkpoint once the model is modified
|
99
|
+
state_id : int
|
100
|
+
the state of the model (used to save checkpoint)
|
101
|
+
ckpt_path : str
|
102
|
+
the folder to store checkpoints
|
103
|
+
round : int
|
104
|
+
the number of times rewind() has been called
|
105
|
+
device : str
|
106
|
+
"""
|
107
|
+
|
108
|
+
def __init__(self, width=None, grid=3, k=3, mult_arity=2, noise_scale=0.3, scale_base_mu=0.0, scale_base_sigma=1.0,
             base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=0.02, grid_range=[-1, 1],
             sp_trainable=True, sb_trainable=True, seed=1, save_act=True, sparse_init=False, auto_save=True,
             first_init=True, ckpt_path='./model', state_id=0, round=0, device='cpu'):
    """
    initialize a KAN model

    Args:
    -----
        width : list of int, or list of [int, int]
            Without multiplication nodes: :math:`[n_0, n_1, .., n_{L-1}]` specify the number of neurons in each layer (including inputs/outputs)
            With multiplication nodes: :math:`[[n_0,m_0=0], [n_1,m_1], .., [n_{L-1},m_{L-1}]]` specify the number of addition/multiplication nodes in each layer (including inputs/outputs)
            Required. The list passed in is not modified; a normalized copy is stored in self.width.
        grid : int, or list of int (one entry per layer)
            number of grid intervals. Default: 3.
        k : int, or list of int (one entry per layer)
            order of piecewise polynomial. Default: 3.
        mult_arity : int, or list of int lists
            multiplication arity for each multiplication node (the number of numbers to be multiplied)
        noise_scale : float
            initial injected noise to spline.
        scale_base_mu, scale_base_sigma : float
            forwarded unchanged to every KANLayer.
        base_fun : str
            the residual function b(x): 'silu', 'identity' or 'zero'. Default: 'silu'.
            (A non-string value is passed through to KANLayer unchanged.)
        symbolic_enabled : bool
            compute (True) or skip (False) symbolic computations (for efficiency). By default: True.
        affine_trainable : bool
            affine parameters are updated or not. Affine parameters include node_scale, node_bias, subnode_scale, subnode_bias
        grid_eps : float
            When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
        grid_range : list/np.array of shape (2,)
            setting the range of grids. Default: [-1,1]. This argument is not important if fit(update_grid=True) (by default update_grid=True)
        sp_trainable : bool
            If true, scale_sp is trainable. Default: True.
        sb_trainable : bool
            If true, scale_base is trainable. Default: True.
        seed : int
            random seed (applied to torch, numpy and random)
        save_act : bool
            indicate whether intermediate activations are saved in forward pass
        sparse_init : bool
            sparse initialization (True) or normal dense initialization. Default: False.
        auto_save : bool
            indicate whether to automatically save a checkpoint once the model is modified
        first_init : bool
            True for a brand-new model (starts at version 0.0 and, with auto_save, writes the
            initial checkpoint); False when re-creating a model (refine / loadckpt), which keeps state_id.
        ckpt_path : str
            the folder to store checkpoints. Default: './model'
        state_id : int
            the state of the model (used to save checkpoint); only used when first_init is False
        round : int
            the number of times rewind() has been called
        device : str
            device the model is moved to

    Returns:
    --------
        self

    Raises:
    -------
        ValueError : if width is None, or base_fun is an unrecognized string.

    Example
    -------
    # >>> from kan import *
    # >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
    checkpoint directory created: ./model
    saving model version 0.0
    """
    super(MultKAN, self).__init__()

    if width is None:
        raise ValueError('width must be given, e.g. width=[2, 5, 1]')

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    ### initializing the numerical front ###

    self.act_fun = []
    self.depth = len(width) - 1

    # Normalize every layer to [n_addition, n_multiplication]. A new list is built so the caller's
    # `width` is not mutated, and numpy integers become plain ints (saveckpt writes width to YAML).
    self.width = [[int(w), 0] if isinstance(w, (int, np.integer)) else [int(n) for n in w] for w in width]

    # mult_arity as a scalar: every multiplication node multiplies that many subnodes, which forward
    # handles with strided slicing ("homogeneous"). As a list: one list of per-node arities per layer of
    # self.width (indexed like width, see width_out / forward); layers without mult nodes use [].
    if isinstance(mult_arity, (int, np.integer)):
        mult_arity = int(mult_arity)
        self.mult_homo = True  # when homo is True, parallelization is possible
    else:
        self.mult_homo = False  # when homo is False, a for loop is required
    self.mult_arity = mult_arity

    width_in = self.width_in
    width_out = self.width_out

    self.base_fun_name = base_fun
    if base_fun == 'silu':
        base_fun = torch.nn.SiLU()
    elif base_fun == 'identity':
        base_fun = torch.nn.Identity()
    elif base_fun == 'zero':
        base_fun = lambda x: x * 0.
    elif isinstance(base_fun, str):
        raise ValueError(f"unknown base_fun '{base_fun}'; expected 'silu', 'identity' or 'zero'")

    self.grid_eps = grid_eps
    self.grid_range = grid_range

    for l in range(self.depth):
        # per-layer grid size / spline order when lists are given
        grid_l = grid[l] if isinstance(grid, list) else grid
        k_l = k[l] if isinstance(k, list) else k

        sp_batch = KANLayer(in_dim=width_in[l], out_dim=width_out[l + 1], num=grid_l, k=k_l,
                            noise_scale=noise_scale, scale_base_mu=scale_base_mu, scale_base_sigma=scale_base_sigma,
                            scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range,
                            sp_trainable=sp_trainable, sb_trainable=sb_trainable, sparse_init=sparse_init)
        self.act_fun.append(sp_batch)

    self.node_bias = []
    self.node_scale = []
    self.subnode_bias = []
    self.subnode_scale = []

    # Register the affine parameters as attributes node_bias_0, node_scale_0, ... (the names that
    # appear in state_dict / saved checkpoints) and also collect them in per-kind lists for indexing.
    for l in range(self.depth):
        affine_init = {
            'node_bias': torch.zeros(width_in[l + 1]),
            'node_scale': torch.ones(width_in[l + 1]),
            'subnode_bias': torch.zeros(width_out[l + 1]),
            'subnode_scale': torch.ones(width_out[l + 1]),
        }
        for kind, init in affine_init.items():
            param = torch.nn.Parameter(init).requires_grad_(affine_trainable)
            setattr(self, f'{kind}_{l}', param)
            getattr(self, kind).append(param)

    self.act_fun = nn.ModuleList(self.act_fun)

    self.grid = grid
    self.k = k
    self.base_fun = base_fun

    ### initializing the symbolic front ###
    self.symbolic_fun = nn.ModuleList(
        [Symbolic_KANLayer(in_dim=width_in[l], out_dim=width_out[l + 1]) for l in range(self.depth)])

    self.symbolic_enabled = symbolic_enabled
    self.affine_trainable = affine_trainable
    self.sp_trainable = sp_trainable
    self.sb_trainable = sb_trainable

    self.save_act = save_act

    self.node_scores = None
    self.edge_scores = None
    self.subnode_scores = None

    self.cache_data = None
    self.acts = None

    self.auto_save = auto_save
    # A fresh model starts at version 0; a re-created model (first_init=False) keeps the given id,
    # whether or not auto_save is on.
    self.state_id = 0 if first_init else state_id
    self.ckpt_path = ckpt_path
    self.round = round

    self.device = device
    self.to(device)

    if auto_save and first_init:
        if not os.path.exists(ckpt_path):
            # Create the directory
            os.makedirs(ckpt_path, exist_ok=True)
            print(f"checkpoint directory created: {ckpt_path}")
        print('saving model version 0.0')

        # a new model starts a fresh history file
        history_path = self.ckpt_path + '/history.txt'
        with open(history_path, 'w') as file:
            file.write(f'### Round {self.round} ###' + '\n')
            file.write('init => 0.0' + '\n')
        self.saveckpt(path=self.ckpt_path + '/' + '0.0')

    # indices of the input columns that forward() selects from x
    self.input_id = torch.arange(self.width_in[0], )
|
307
|
+
|
308
|
+
def to(self, device):
    '''
    Move the model to `device` and remember it in `self.device`.

    Args:
    -----
        device : str or device

    Returns:
    --------
        self

    Example
    -------
    # >>> from kan import *
    # >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
    # >>> model.to(device)
    '''
    super(MultKAN, self).to(device)
    self.device = device

    # nn.Module.to() moves tensors through _apply() and never calls the children's own to(),
    # so invoke it explicitly on every numerical layer, then every symbolic layer, in case
    # those classes override it.
    all_layers = list(self.act_fun) + list(self.symbolic_fun)
    for layer in all_layers:
        layer.to(device)

    return self
|
337
|
+
|
338
|
+
@property
def width_in(self):
    '''
    Number of input nodes of each layer: addition nodes plus multiplication nodes,
    read from the [n_add, n_mult] pairs in self.width.
    '''
    return [entry[0] + entry[1] for entry in self.width]
|
346
|
+
|
347
|
+
@property
def width_out(self):
    '''
    Number of output subnodes of each layer: one per addition node, plus one per factor
    of every multiplication node (scalar arity x node count when mult_homo, otherwise the
    sum of that layer's arity list).
    '''
    homogeneous = self.mult_homo == True
    subnode_counts = []
    for idx, entry in enumerate(self.width):
        if homogeneous:
            mult_subnodes = self.mult_arity * entry[1]
        else:
            mult_subnodes = int(np.sum(self.mult_arity[idx]))
        subnode_counts.append(entry[0] + mult_subnodes)
    return subnode_counts
|
358
|
+
|
359
|
+
@property
def n_sum(self):
    """
    Number of addition nodes in each hidden layer (the input and output layers are excluded).
    """
    hidden_layers = self.width[1:-1]
    return [layer[0] for layer in hidden_layers]
|
367
|
+
|
368
|
+
@property
def n_mult(self):
    """
    Number of multiplication nodes in each hidden layer (the input and output layers are excluded).
    """
    hidden_layers = self.width[1:-1]
    return [layer[1] for layer in hidden_layers]
|
376
|
+
|
377
|
+
@property
def feature_score(self):
    '''
    attribution scores for inputs

    Calls self.attribute() (which presumably refreshes self.node_scores — TODO confirm), then
    returns the layer-0 (input) entry of self.node_scores, or None when no scores are available.
    '''
    self.attribute()
    # identity test: `== None` is ambiguous (raises) for array-valued scores
    if self.node_scores is None:
        return None
    return self.node_scores[0]
|
387
|
+
|
388
|
+
def initialize_from_another_model(self, another_model, x):
    '''
    initialize from another model of the same width, but their 'grid' parameter can be different.
    Note this is equivalent to refine() when we don't want to keep another_model

    The spline coefficients of every layer are refit (on this model's own grid) to the parent's
    spline outputs on `x`; scales, masks and affine parameters are copied; symbolic layers are
    deep-copied. Copies are independent, so later training of either model leaves the other intact.

    Args:
    -----
        another_model : MultKAN
            parent model; it must have save_act enabled, because its spline_preacts /
            spline_postsplines are only recorded by forward() when save_act is True.
        x : 2D torch.float

    Returns:
    --------
        self

    Example
    -------
    # >>> from kan import *
    # >>> model1 = KAN(width=[2,5,1], grid=3)
    # >>> model2 = KAN(width=[2,5,1], grid=10)
    # >>> x = torch.rand(100,2)
    # >>> model2.initialize_from_another_model(model1, x)
    '''
    another_model(x)  # get activations (fills spline_preacts / spline_postsplines)

    self.initialize_grid_from_another_model(another_model, x)

    for l in range(self.depth):
        layer = self.act_fun[l]
        parent_layer = another_model.act_fun[l]
        preacts = another_model.spline_preacts[l]
        postsplines = another_model.spline_postsplines[l]
        layer.coef.data = curve2coef(preacts[:, 0, :], postsplines.permute(0, 2, 1), layer.grid, k=layer.k)
        # clone: assigning .data directly would make both models share (and co-train) the same storage
        layer.scale_base.data = parent_layer.scale_base.data.clone()
        layer.scale_sp.data = parent_layer.scale_sp.data.clone()
        layer.mask.data = parent_layer.mask.data.clone()

    for l in range(self.depth):
        self.node_bias[l].data = another_model.node_bias[l].data.clone()
        self.node_scale[l].data = another_model.node_scale[l].data.clone()

        self.subnode_bias[l].data = another_model.subnode_bias[l].data.clone()
        self.subnode_scale[l].data = another_model.subnode_scale[l].data.clone()

    for l in range(self.depth):
        # deep copy so fixing a symbolic function on one model does not change the other
        self.symbolic_fun[l] = copy.deepcopy(another_model.symbolic_fun[l])

    return self.to(self.device)
|
438
|
+
|
439
|
+
def log_history(self, method_name):
    '''
    Record a model modification (only when auto_save is on).

    Appends "<round>.<id> => <method_name> => <round>.<id+1>" to <ckpt_path>/history.txt,
    increments state_id, then saves this model as checkpoint <ckpt_path>/<round>.<new id>.
    Does nothing when auto_save is False.
    '''
    if not self.auto_save:
        return

    with open(self.ckpt_path + '/history.txt', 'a') as history:
        entry = ' => '.join([f'{self.round}.{self.state_id}', method_name,
                             f'{self.round}.{self.state_id + 1}'])
        history.write(entry + '\n')

    self.state_id += 1

    version = f'{self.round}.{self.state_id}'
    self.saveckpt(path=self.ckpt_path + '/' + version)
    print('saving model version ' + version)
|
454
|
+
|
455
|
+
def refine(self, new_grid):
    '''
    grid refinement

    Builds a new model with `new_grid` grid intervals whose splines are fit to reproduce this
    model's activations on the cached data (self.cache_data).

    Args:
    -----
        new_grid : int
            the number of grid intervals after refinement

    Returns:
    --------
        a refined model : MultKAN

    Raises:
    -------
        ValueError : if the model has no cached data (it was never run on inputs).

    Example
    -------
    # >>> from kan import *
    # >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
    # >>> print(model.grid)
    # >>> x = torch.rand(100,2)
    # >>> model.get_act(x)
    # >>> model = model.refine(10)
    # >>> print(model.grid)
    checkpoint directory created: ./model
    saving model version 0.0
    5
    saving model version 0.1
    10
    '''
    if self.cache_data is None:
        raise ValueError('refine() needs cached input data; run the model on a batch first, e.g. model(x).')

    model_new = MultKAN(width=[list(w) for w in self.width],  # copy: do not share the width list
                        grid=new_grid,
                        k=self.k,
                        mult_arity=self.mult_arity,
                        base_fun=self.base_fun_name,
                        symbolic_enabled=self.symbolic_enabled,
                        affine_trainable=self.affine_trainable,
                        grid_eps=self.grid_eps,
                        grid_range=self.grid_range,
                        sp_trainable=self.sp_trainable,
                        sb_trainable=self.sb_trainable,
                        ckpt_path=self.ckpt_path,
                        auto_save=self.auto_save,  # respect the user's choice instead of forcing True
                        first_init=False,
                        state_id=self.state_id,
                        round=self.round,
                        device=self.device)

    model_new.initialize_from_another_model(self, self.cache_data)
    model_new.cache_data = self.cache_data
    model_new.grid = new_grid

    # Log the refinement on the refined model, so checkpoint <round>.<id+1> contains the refined
    # weights (logging on self would save the pre-refinement model under the new version id).
    model_new.state_id = self.state_id
    model_new.log_history('refine')
    # keep the parent's version counter in step with the child's, as before
    self.state_id = model_new.state_id

    return model_new.to(self.device)
|
511
|
+
|
512
|
+
def saveckpt(self, path='model'):
    '''
    save the current model to files (configuration file and state file)

    Writes three files using `path` as a prefix:
        <path>_config.yml : constructor arguments, round/state bookkeeping and symbolic function names (YAML)
        <path>_state      : model.state_dict() (torch.save)
        <path>_cache_data : the cached input data (torch.save)

    Args:
    -----
        path : str
            the path prefix where checkpoints are saved

    Returns:
    --------
        None

    Example
    -------
    # >>> from kan import *
    # >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
    # >>> model.saveckpt('./mark')
    # There will be three files appearing in the current folder: mark_cache_data, mark_config.yml, mark_state
    '''

    def to_plain(value):
        # loadckpt reads the config with yaml.safe_load, which rejects the python-specific tags that
        # yaml.dump writes for tuples and numpy scalars/arrays; convert those to lists / Python numbers.
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, (list, tuple)):
            return [to_plain(v) for v in value]
        return value

    model = self

    dic = dict(
        width=model.width,
        grid=model.grid,
        k=model.k,
        mult_arity=model.mult_arity,
        base_fun_name=model.base_fun_name,
        symbolic_enabled=model.symbolic_enabled,
        affine_trainable=model.affine_trainable,
        grid_eps=model.grid_eps,
        grid_range=model.grid_range,
        sp_trainable=model.sp_trainable,
        sb_trainable=model.sb_trainable,
        state_id=model.state_id,
        auto_save=model.auto_save,
        ckpt_path=model.ckpt_path,
        round=model.round,
        device=str(model.device)
    )

    # a bare GPU index (e.g. device=0) is stored as an int rather than the string '0'
    if dic["device"].isdigit():
        dic["device"] = int(model.device)

    for i in range(model.depth):
        dic[f'symbolic.funs_name.{i}'] = model.symbolic_fun[i].funs_name

    dic = {key: to_plain(value) for key, value in dic.items()}

    with open(f'{path}_config.yml', 'w') as outfile:
        yaml.dump(dic, outfile, default_flow_style=False)

    torch.save(model.state_dict(), f'{path}_state')
    torch.save(model.cache_data, f'{path}_cache_data')
|
566
|
+
|
567
|
+
@staticmethod
def loadckpt(path='model'):
    '''
    load checkpoint from path

    Rebuilds a MultKAN from the files written by saveckpt(path): constructor arguments come from
    `<path>_config.yml`, weights from `<path>_state`, cached inputs from `<path>_cache_data`, and
    each symbolic edge is re-bound to its SYMBOLIC_LIB entry by name. The returned model's round
    is the saved round + 1.

    Args:
    -----
        path : str
            the path prefix where checkpoints are saved

    Returns:
    --------
        MultKAN

    Example
    -------
    # >>> from kan import *
    # >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
    # >>> model.saveckpt('./mark')
    # >>> KAN.loadckpt('./mark')
    '''
    config_file = f'{path}_config.yml'
    with open(config_file, 'r') as stream:
        config = yaml.safe_load(stream)

    # NOTE(review): torch.load with weights_only=False unpickles arbitrary objects; only load
    # checkpoints from trusted locations.
    state = torch.load(f'{path}_state', weights_only=False)

    passthrough = ('width', 'grid', 'k', 'mult_arity')
    ctor_kwargs = {name: config[name] for name in passthrough}
    ctor_kwargs['base_fun'] = config['base_fun_name']
    for name in ('symbolic_enabled', 'affine_trainable', 'grid_eps', 'grid_range',
                 'sp_trainable', 'sb_trainable', 'state_id', 'auto_save'):
        ctor_kwargs[name] = config[name]
    ctor_kwargs['first_init'] = False
    ctor_kwargs['ckpt_path'] = config['ckpt_path']
    ctor_kwargs['round'] = config['round'] + 1
    ctor_kwargs['device'] = config['device']
    model_load = MultKAN(**ctor_kwargs)

    model_load.load_state_dict(state)
    model_load.cache_data = torch.load(f'{path}_cache_data')

    n_layers = len(model_load.width) - 1
    for layer_idx in range(n_layers):
        sym_layer = model_load.symbolic_fun[layer_idx]
        n_out = sym_layer.out_dim
        n_in = sym_layer.in_dim
        saved_names = config[f'symbolic.funs_name.{layer_idx}']
        for row in range(n_out):
            for col in range(n_in):
                name = saved_names[row][col]
                sym_layer.funs_name[row][col] = name
                # SYMBOLIC_LIB entry: [0] torch function, [1] sympy function, [3] singularity-avoiding variant
                sym_layer.funs[row][col] = SYMBOLIC_LIB[name][0]
                sym_layer.funs_sympy[row][col] = SYMBOLIC_LIB[name][1]
                sym_layer.funs_avoid_singularity[row][col] = SYMBOLIC_LIB[name][3]
    return model_load
|
627
|
+
|
628
|
+
def copy(self):
    """
    deepcopy

    Round-trips the model through saveckpt/loadckpt inside a private temporary directory that
    is deleted afterwards, so no files are left behind and concurrent copies cannot collide.
    Note that loadckpt increments `round`, so the copy's round is this model's round + 1.

    Returns:
    --------
        MultKAN

    Example
    -------
    # >>> from kan import *
    # >>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
    # >>> model2 = model.copy()
    # >>> model2.act_fun[0].coef.data *= 2
    # >>> print(model2.act_fun[0].coef.data)
    # >>> print(model.act_fun[0].coef.data)
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, 'copy_temp')
        self.saveckpt(path)
        # loadckpt reads everything into memory before the directory is removed
        return MultKAN.loadckpt(path)
|
653
|
+
|
654
|
+
def rewind(self, model_id):
    """
    rewind to an old version

    Loads checkpoint `model_id`, starts a new round (logged in history.txt), and saves the loaded
    model as version '<new round>.<b>', where b is the version part of model_id. This model's
    round/state_id are advanced to match, so it can keep logging history.

    Args:
    -----
        model_id : str
            in format '{a}.{b}' where a is the round number, b is the version number in that round

    Returns:
    --------
        MultKAN
            the model stored under model_id, with round = new round and state_id = b

    Example
    -------
    Please refer to tutorials. API 12: Checkpoint, save & load model
    """
    model_id = str(model_id)
    # Load first, so a bad model_id raises before this model's round/state_id are touched.
    model = MultKAN.loadckpt(path=self.ckpt_path + '/' + model_id)

    self.round += 1
    # int, not str: log_history later does `state_id += 1`
    self.state_id = int(model_id.split('.')[-1])

    model.round = self.round
    model.state_id = self.state_id

    history_path = self.ckpt_path + '/history.txt'
    with open(history_path, 'a') as file:
        file.write(f'### Round {self.round} ###' + '\n')

    # save the rewound-to model (not the current one) under its new name
    new_id = f'{self.round}.{self.state_id}'
    model.saveckpt(path=self.ckpt_path + '/' + new_id)

    print('rewind to model version ' + model_id + ', renamed as ' + new_id)

    return model
|
684
|
+
|
685
|
+
def checkout(self, model_id):
    """
    check out an old version

    Loads and returns checkpoint `model_id` from this model's ckpt_path; unlike rewind, this
    model's round and state_id are left untouched and no history is written.

    Args:
    -----
        model_id : str
            in format '{a}.{b}' where a is the round number, b is the version number in that round

    Returns:
    --------
        MultKAN

    Example
    -------
    Same use as rewind, although checkout doesn't change states
    """
    checkpoint_prefix = self.ckpt_path + '/' + str(model_id)
    return MultKAN.loadckpt(path=checkpoint_prefix)
|
703
|
+
|
704
|
+
def update_grid_from_samples(self, x):
    """
    update grid from samples

    Layer by layer, re-runs the model on `x` (self.get_act) and adapts that layer's grid to the
    activations recorded at its input, so each layer sees inputs produced after the earlier
    layers' grids were updated.

    Args:
    -----
        x : 2D torch.tensor
            inputs

    Returns:
    --------
        None

    Example
    -------
    # >>> from kan import *
    # >>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
    # >>> print(model.act_fun[0].grid)
    # >>> x = torch.linspace(-10,10,steps=101)[:,None]
    # >>> model.update_grid_from_samples(x)
    # >>> print(model.act_fun[0].grid)
    """
    n_layers = self.depth
    for idx in range(n_layers):
        self.get_act(x)
        layer_inputs = self.acts[idx]
        self.act_fun[idx].update_grid_from_samples(layer_inputs)
|
729
|
+
|
730
|
+
def update_grid(self, x):
    """
    Thin alias for update_grid_from_samples(x); kept as its own method so subclasses of
    MultKAN can override grid updating. Always returns None.
    """
    updater = self.update_grid_from_samples
    updater(x)
|
735
|
+
|
736
|
+
def initialize_grid_from_another_model(self, model, x):
    """
    initialize grid from another model

    Runs the parent `model` on `x` (which records its per-layer activations in model.acts), then
    lets each of this model's layers initialize its grid from the matching parent layer and the
    activations at that layer's input.

    Args:
    -----
        model : MultKAN
            parent model
        x : 2D torch.tensor
            inputs

    Returns:
    --------
        None

    Example
    -------
    # >>> from kan import *
    # >>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
    # >>> print(model.act_fun[0].grid)
    # >>> x = torch.linspace(-10,10,steps=101)[:,None]
    # >>> model2 = KAN(width=[1,1], grid=10, k=3, seed=0)
    # >>> model2.initialize_grid_from_another_model(model, x)
    # >>> print(model2.act_fun[0].grid)
    """
    model(x)
    for idx in range(self.depth):
        parent_layer = model.act_fun[idx]
        parent_inputs = model.acts[idx]
        self.act_fun[idx].initialize_grid_from_parent(parent_layer, parent_inputs)
|
764
|
+
|
765
|
+
def forward(self, x, singularity_avoiding=False, y_th=10.):
|
766
|
+
'''
|
767
|
+
forward pass
|
768
|
+
|
769
|
+
Args:
|
770
|
+
-----
|
771
|
+
x : 2D torch.tensor
|
772
|
+
inputs
|
773
|
+
singularity_avoiding : bool
|
774
|
+
whether to avoid singularity for the symbolic branch
|
775
|
+
y_th : float
|
776
|
+
the threshold for singularity
|
777
|
+
|
778
|
+
Returns:
|
779
|
+
--------
|
780
|
+
None
|
781
|
+
|
782
|
+
Example1
|
783
|
+
--------
|
784
|
+
# >>> from kan import *
|
785
|
+
# >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
|
786
|
+
# >>> x = torch.rand(100,2)
|
787
|
+
# >>> model(x).shape
|
788
|
+
|
789
|
+
Example2
|
790
|
+
--------
|
791
|
+
# >>> from kan import *
|
792
|
+
# >>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
|
793
|
+
# >>> x = torch.tensor([[1],[-0.01]])
|
794
|
+
# >>> model.fix_symbolic(0,0,0,'log',fit_params_bool=False)
|
795
|
+
# >>> print(model(x))
|
796
|
+
# >>> print(model(x, singularity_avoiding=True))
|
797
|
+
# >>> print(model(x, singularity_avoiding=True, y_th=1.))
|
798
|
+
'''
|
799
|
+
x = x[:, self.input_id.long()]
|
800
|
+
assert x.shape[1] == self.width_in[0]
|
801
|
+
|
802
|
+
# cache data
|
803
|
+
self.cache_data = x
|
804
|
+
|
805
|
+
self.acts = [] # shape ([batch, n0], [batch, n1], ..., [batch, n_L])
|
806
|
+
self.acts_premult = []
|
807
|
+
self.spline_preacts = []
|
808
|
+
self.spline_postsplines = []
|
809
|
+
self.spline_postacts = []
|
810
|
+
self.acts_scale = []
|
811
|
+
self.acts_scale_spline = []
|
812
|
+
self.subnode_actscale = []
|
813
|
+
self.edge_actscale = []
|
814
|
+
# self.neurons_scale = []
|
815
|
+
|
816
|
+
self.acts.append(x) # acts shape: (batch, width[l])
|
817
|
+
|
818
|
+
for l in range(self.depth):
|
819
|
+
|
820
|
+
x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
|
821
|
+
#print(preacts, postacts_numerical, postspline)
|
822
|
+
|
823
|
+
if self.symbolic_enabled == True:
|
824
|
+
x_symbolic, postacts_symbolic = self.symbolic_fun[l](x, singularity_avoiding=singularity_avoiding,
|
825
|
+
y_th=y_th)
|
826
|
+
else:
|
827
|
+
x_symbolic = 0.
|
828
|
+
postacts_symbolic = 0.
|
829
|
+
|
830
|
+
x = x_numerical + x_symbolic
|
831
|
+
|
832
|
+
if self.save_act:
|
833
|
+
# save subnode_scale
|
834
|
+
self.subnode_actscale.append(torch.std(x, dim=0).detach())
|
835
|
+
|
836
|
+
# subnode affine transform
|
837
|
+
x = self.subnode_scale[l][None, :] * x + self.subnode_bias[l][None, :]
|
838
|
+
|
839
|
+
if self.save_act:
|
840
|
+
postacts = postacts_numerical + postacts_symbolic
|
841
|
+
|
842
|
+
# self.neurons_scale.append(torch.mean(torch.abs(x), dim=0))
|
843
|
+
#grid_reshape = self.act_fun[l].grid.reshape(self.width_out[l + 1], self.width_in[l], -1)
|
844
|
+
input_range = torch.std(preacts, dim=0) + 0.1
|
845
|
+
output_range_spline = torch.std(postacts_numerical,
|
846
|
+
dim=0) # for training, only penalize the spline part
|
847
|
+
output_range = torch.std(postacts,
|
848
|
+
dim=0) # for visualization, include the contribution from both spline + symbolic
|
849
|
+
# save edge_scale
|
850
|
+
self.edge_actscale.append(output_range)
|
851
|
+
|
852
|
+
self.acts_scale.append((output_range / input_range).detach())
|
853
|
+
self.acts_scale_spline.append(output_range_spline / input_range)
|
854
|
+
self.spline_preacts.append(preacts.detach())
|
855
|
+
self.spline_postacts.append(postacts.detach())
|
856
|
+
self.spline_postsplines.append(postspline.detach())
|
857
|
+
|
858
|
+
self.acts_premult.append(x.detach())
|
859
|
+
|
860
|
+
# multiplication
|
861
|
+
dim_sum = self.width[l + 1][0]
|
862
|
+
dim_mult = self.width[l + 1][1]
|
863
|
+
|
864
|
+
if self.mult_homo:
|
865
|
+
for i in range(self.mult_arity - 1):
|
866
|
+
if i == 0:
|
867
|
+
x_mult = x[:, dim_sum::self.mult_arity] * x[:, dim_sum + 1::self.mult_arity]
|
868
|
+
else:
|
869
|
+
x_mult = x_mult * x[:, dim_sum + i + 1::self.mult_arity]
|
870
|
+
|
871
|
+
else:
|
872
|
+
for j in range(dim_mult):
|
873
|
+
acml_id = dim_sum + np.sum(self.mult_arity[l + 1][:j])
|
874
|
+
for i in range(self.mult_arity[l + 1][j] - 1):
|
875
|
+
if i == 0:
|
876
|
+
x_mult_j = x[:, [acml_id]] * x[:, [acml_id + 1]]
|
877
|
+
else:
|
878
|
+
x_mult_j = x_mult_j * x[:, [acml_id + i + 1]]
|
879
|
+
|
880
|
+
if j == 0:
|
881
|
+
x_mult = x_mult_j
|
882
|
+
else:
|
883
|
+
x_mult = torch.cat([x_mult, x_mult_j], dim=1)
|
884
|
+
|
885
|
+
if self.width[l + 1][1] > 0:
|
886
|
+
x = torch.cat([x[:, :dim_sum], x_mult], dim=1)
|
887
|
+
|
888
|
+
# x = x + self.biases[l].weight
|
889
|
+
# node affine transform
|
890
|
+
x = self.node_scale[l][None, :] * x + self.node_bias[l][None, :]
|
891
|
+
|
892
|
+
self.acts.append(x.detach())
|
893
|
+
|
894
|
+
return x
|
895
|
+
|
896
|
+
def set_mode(self, l, i, j, mode, mask_n=None):
    '''
    Switch the (l,i,j) activation between its numerical (spline) and symbolic branches.

    The spline branch is gated by ``self.act_fun[l].mask[i][j]`` and the symbolic branch
    by ``self.symbolic_fun[l].mask[j, i]`` (note the transposed [out, in] indexing).

    Args:
    -----
        l : int
            layer index
        i : int
            input neuron index
        j : int
            output neuron index
        mode : str
            "s"         -> symbolic only  (spline mask 0, symbolic mask 1)
            "n"         -> numerical only (spline mask 1, symbolic mask 0)
            "sn" / "ns" -> both active; spline mask is ``mask_n`` (1. if None)
            other       -> both branches switched off (masks 0, 0)
        mask_n : float or None
            spline mask value, only used for modes "sn"/"ns"

    Returns:
    --------
        None
    '''
    if mode == "s":
        mask_n, mask_s = 0., 1.
    elif mode == "n":
        mask_n, mask_s = 1., 0.
    elif mode in ("sn", "ns"):
        # keep a caller-supplied spline weight; default to fully on
        mask_n = 1. if mask_n is None else mask_n
        mask_s = 1.
    else:
        # unrecognized modes disable the edge entirely
        mask_n, mask_s = 0., 0.

    self.act_fun[l].mask.data[i][j] = mask_n
    self.symbolic_fun[l].mask.data[j, i] = mask_s
|
915
|
+
|
916
|
+
def fix_symbolic(self, l, i, j, fun_name, fit_params_bool=True, a_range=(-10, 10), b_range=(-10, 10), verbose=True,
                 random=False, log_history=True):
    '''
    Replace the (l,i,j) activation with a symbolic function and switch the edge to mode "s".

    Args:
    -----
        l : int
            layer index
        i : int
            input neuron index
        j : int
            output neuron index
        fun_name : str
            name of the symbolic function
        fit_params_bool : bool
            True: fit the affine parameters to the cached activations (needs a prior forward pass);
            False: use default parameters without looking at data
        a_range : tuple
            sweeping range of a (only used when fitting)
        b_range : tuple
            sweeping range of b (only used when fitting)
        verbose : bool
            forwarded to self.symbolic_fun[l].fix_symbolic
        random : bool
            only forwarded when fit_params_bool is False: random initialisation vs. [1,0,1,0]
        log_history : bool
            if True, call self.log_history('fix_symbolic') afterwards

    Returns:
    --------
        None when fit_params_bool is False; otherwise the r2 returned by the symbolic layer,
        replaced by -1e8 when the spline mask of this edge is 0.

    Example
    -------
    >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=1.)
    >>> x = torch.normal(0,1,size=(100,2))
    >>> model(x)  # populate model.acts / model.spline_postacts first
    >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=True)
    '''
    symbolic_layer = self.symbolic_fun[l]

    if fit_params_bool:
        # samples of this edge: column i of the layer-l input vs. the cached post-activation
        xs = self.acts[l][:, i]
        ys = self.spline_postacts[l][:, j, i]
        spline_mask = self.act_fun[l].mask
        r2 = symbolic_layer.fix_symbolic(i, j, fun_name, xs, ys, a_range=a_range, b_range=b_range,
                                         verbose=verbose)
        # sentinel score for an edge whose spline mask is already off
        if spline_mask[i, j] == 0:
            r2 = - 1e8
    else:
        symbolic_layer.fix_symbolic(i, j, fun_name, verbose=verbose, random=random)
        r2 = None

    self.set_mode(l, i, j, mode="s")

    if log_history:
        self.log_history('fix_symbolic')
    return r2
|
983
|
+
|
984
|
+
def unfix_symbolic(self, l, i, j, log_history=True):
    '''
    Return the (l,i,j) activation to its numerical (spline) form.

    The edge is switched to mode "n" and its recorded symbolic function name is reset to "0".

    Args:
    -----
        l : int
            layer index
        i : int
            input neuron index
        j : int
            output neuron index
        log_history : bool
            if True, call self.log_history('unfix_symbolic') afterwards
    '''
    self.set_mode(l, i, j, mode="n")
    # funs_name is indexed [output][input]
    names = self.symbolic_fun[l].funs_name
    names[j][i] = "0"
    if not log_history:
        return
    self.log_history('unfix_symbolic')
|
992
|
+
|
993
|
+
def unfix_symbolic_all(self, log_history=True):
    '''
    Return every activation function of the network to its numerical (spline) form.

    Each edge is reset with ``self.unfix_symbolic(l, i, j, log_history=False)``. When requested,
    history is logged once for the whole operation rather than once per edge, which
    previously produced one history entry per activation.

    Args:
    -----
        log_history : bool
            if True, call self.log_history('unfix_symbolic_all') once after all edges are reset
    '''
    for l in range(len(self.width) - 1):
        for i in range(self.width_in[l]):
            for j in range(self.width_out[l + 1]):
                # suppress per-edge logging; a single entry is written below
                self.unfix_symbolic(l, i, j, log_history=False)
    if log_history:
        self.log_history('unfix_symbolic_all')
|
1001
|
+
|
1002
|
+
def get_range(self, l, i, j, verbose=True):
    '''
    Report the input and output range of the (l,i,j) activation over the cached batch.

    Reads ``self.spline_preacts[l]`` and ``self.spline_postacts[l]`` (indexed [batch, j, i])
    stored by the last forward pass.

    Args:
    -----
        l : int
            layer index
        i : int
            input neuron index
        j : int
            output neuron index
        verbose : bool
            if True, print both ranges with two decimals

    Returns:
    --------
        x_min, x_max, y_min, y_max : 0-d numpy arrays

    Example
    -------
    >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
    >>> x = torch.normal(0,1,size=(100,2))
    >>> model(x)  # forward pass populates the cached activations
    >>> model.get_range(0,0,0)
    '''
    def _extremes(samples):
        # detach and move to host so callers receive plain numpy values
        lowest = torch.min(samples).cpu().detach().numpy()
        highest = torch.max(samples).cpu().detach().numpy()
        return lowest, highest

    x_min, x_max = _extremes(self.spline_preacts[l][:, j, i])
    y_min, y_max = _extremes(self.spline_postacts[l][:, j, i])

    if verbose:
        print('x range: [' + '%.2f' % x_min, ',', '%.2f' % x_max, ']')
        print('y range: [' + '%.2f' % y_min, ',', '%.2f' % y_max, ']')
    return x_min, x_max, y_min, y_max
|
1043
|
+
|
1044
|
+
def _draw_operator_symbol(self, ax, fig, x, y, symbol_type, scale, size):
    """Draw the ``assets/img/{symbol_type}_symbol.png`` icon centred on data point (x, y) of `ax`.

    The icon goes into a new axes added to `fig`. That axes covers the square of side `size`
    (in data units) around (x, y). This is the same placement technique ``plot`` uses for its
    sum/mult symbols.

    Args:
        ax : matplotlib Axes whose data coordinates x, y and size refer to
        fig : the Figure containing `ax`; the icon axes is added to it
        x, y : float, icon centre in data coordinates
        symbol_type : str, selects the png file (e.g. 'sum' or 'mult')
        scale : unused; kept for interface compatibility
        size : float, icon side length in data units

    Returns:
        The new Axes holding the icon. Returns None if the image could not be loaded;
        in that case a message is printed.
    """
    path = os.path.join(os.path.dirname(__file__), 'assets', 'img', f'{symbol_type}_symbol.png')
    try:
        with Image.open(path) as img:
            # materialize the pixels before the file handle is closed
            pixels = np.asarray(img)
    except OSError as e:  # missing file or unreadable image
        # message reads: "failed to load symbol <symbol_type>: <error>"
        print(f"无法加载符号 {symbol_type}: {str(e)}")
        return None

    def to_figure_fraction(point):
        # data coords -> display pixels -> normalized figure coords
        return fig.transFigure.inverted().transform(ax.transData.transform(point))

    left, bottom = to_figure_fraction((x - size / 2, y - size / 2))
    right, top = to_figure_fraction((x + size / 2, y + size / 2))
    icon_ax = fig.add_axes([left, bottom, right - left, top - bottom])
    icon_ax.imshow(pixels)
    icon_ax.axis('off')
    return icon_ax
|
1056
|
+
|
1057
|
+
def _draw_variable_labels(self, ax, variables, layer, scale, varscale, y_offset, y0, z0):
    """Write one text label per variable along the node row of `layer`.

    Label k is centred at x = 1/(2n) + k/n, where n = self.width_in[layer]. The height is
    layer * (y0 + z0) + y_offset in data coordinates. sympy expressions are rendered as
    LaTeX math; any other value is passed to ``ax.text`` unchanged.
    """
    n_nodes = self.width_in[layer]
    row_height = layer * (y0 + z0) + y_offset
    font_size = 15 * scale * varscale
    for idx, var in enumerate(variables):
        if isinstance(var, sympy.Expr):
            text = f'${latex(var)}$'
        else:
            text = var
        ax.text(1 / (2 * n_nodes) + idx / n_nodes, row_height, text, fontsize=font_size,
                ha='center', va='center', transform=ax.transData)
|
1066
|
+
|
1067
|
+
def plot(self, folder="./figures", beta=3, metric='backward', scale=0.5, tick=False, sample=False, in_vars=None,
         out_vars=None, title=None, varscale=1.0):
    '''
    plot KAN

    First renders every activation into ``{folder}/sp_{l}_{i}_{j}.png``. Then draws the
    whole network (nodes, edges, activation thumbnails, sum/mult symbols) on a new
    matplotlib figure. Edge colour encodes which branches are active:
    purple = spline + symbolic, red = symbolic only, black = spline only, white = none.

    Args:
    -----
        folder : str
            the folder to store pngs
        beta : float
            positive number. control the transparency of each activation. transparency = tanh(beta*score).
        metric : str
            score used for transparency: 'forward_n' (self.acts_scale), 'forward_u'
            (self.edge_actscale) or 'backward' (self.edge_scores, refreshed via self.attribute())
        scale : float
            control the size of the diagram
        tick : bool
            if True, draw min/max ticks on every activation thumbnail
        sample : bool
            if True, also scatter the individual samples on every thumbnail
        in_vars: None or list of str / sympy expressions
            the name(s) of input variables
        out_vars: None or list of str / sympy expressions
            the name(s) of output variables
        title: None or str
            title
        varscale : float
            the size of input variables

    Returns:
    --------
        None (the diagram is left on the current matplotlib figure)

    Raises:
    -------
        Exception
            if metric is not recognized, or if the model has not seen any data yet

    Example
    -------
    >>> # see more interactive examples in demos
    >>> model = KAN(width=[2,3,1], grid=3, k=3, noise_scale=1.0)
    >>> x = torch.normal(0,1,size=(100,2))
    >>> model(x) # do a forward pass to obtain model.acts
    >>> model.plot()
    '''
    if not self.save_act:
        print('cannot plot since data are not saved. Set save_act=True first.')

    # validate before any png is rendered
    if metric not in ('forward_n', 'forward_u', 'backward'):
        raise Exception(f'metric = \'{metric}\' not recognized')

    if self.acts is None:
        if self.cache_data is None:
            raise Exception('model hasn\'t seen any data yet.')
        self.forward(self.cache_data)

    if metric == 'backward':
        self.attribute()

    if not os.path.exists(folder):
        os.makedirs(folder)

    def edge_style(l, i, j):
        # (colour, alpha multiplier) of edge (l,i,j) from which branches are switched on;
        # "> 0." everywhere so fractional spline masks always get a colour
        symbolic_on = self.symbolic_fun[l].mask[j][i] > 0.
        numeric_on = self.act_fun[l].mask[i][j] > 0.
        if symbolic_on and numeric_on:
            return 'purple', 1.
        if symbolic_on:
            return 'red', 1.
        if numeric_on:
            return 'black', 1.
        return 'white', 0.

    def mult_arity_of(l, mult_id):
        # mult_arity is either one int shared by all mult nodes or a per-layer list
        if isinstance(self.mult_arity, int):
            return self.mult_arity
        return self.mult_arity[l + 1][mult_id]

    symbol_cache = {}

    def symbol_image(kind):
        # read each operator icon at most once per call
        if kind not in symbol_cache:
            path = os.path.dirname(os.path.abspath(__file__)) + f"/assets/img/{kind}_symbol.png"
            symbol_cache[kind] = plt.imread(path)
        return symbol_cache[kind]

    # ---- 1. one png per activation function
    depth = len(self.width) - 1
    w_large = 2.0
    for l in range(depth):
        for i in range(self.width_in[l]):
            for j in range(self.width_out[l + 1]):
                rank = torch.argsort(self.acts[l][:, i])
                fig, ax = plt.subplots(figsize=(w_large, w_large))
                color, alpha_mask = edge_style(l, i, j)

                if tick:
                    ax.tick_params(axis="y", direction="in", pad=-22, labelsize=50)
                    ax.tick_params(axis="x", direction="in", pad=-15, labelsize=50)
                    x_min, x_max, y_min, y_max = self.get_range(l, i, j, verbose=False)
                    plt.xticks([x_min, x_max], ['%2.f' % x_min, '%2.f' % x_max])
                    plt.yticks([y_min, y_max], ['%2.f' % y_min, '%2.f' % y_max])
                else:
                    plt.xticks([])
                    plt.yticks([])
                if alpha_mask == 1:
                    plt.gca().patch.set_edgecolor('black')
                else:
                    plt.gca().patch.set_edgecolor('white')
                plt.gca().patch.set_linewidth(1.5)

                xs = self.acts[l][:, i][rank].cpu().detach().numpy()
                ys = self.spline_postacts[l][:, j, i][rank].cpu().detach().numpy()
                plt.plot(xs, ys, color=color, lw=5)
                if sample:
                    plt.scatter(xs, ys, color=color, s=400 * scale ** 2)
                plt.gca().spines[:].set_color(color)

                plt.savefig(f'{folder}/sp_{l}_{i}_{j}.png', bbox_inches="tight", dpi=400)
                plt.close()

    # ---- 2. per-edge transparency
    def score2alpha(score):
        return np.tanh(beta * score)

    if metric == 'forward_n':
        scores = self.acts_scale
    elif metric == 'forward_u':
        scores = self.edge_actscale
    else:  # 'backward' (validated above)
        scores = self.edge_scores

    alpha = [score2alpha(score.cpu().detach().numpy()) for score in scores]

    # ---- 3. skeleton
    width = np.array(self.width)
    width_in = np.array(self.width_in)
    width_out = np.array(self.width_out)
    A = 1
    y0 = 0.3  # height: from input to pre-mult
    z0 = 0.1  # height: from pre-mult to post-mult (input of next layer)

    neuron_depth = len(width)
    min_spacing = A / np.maximum(np.max(width_out), 5)

    max_neuron = np.max(width_out)
    max_num_weights = np.max(width_in[:-1] * width_out[1:])
    y1 = 0.4 / np.maximum(max_num_weights, 5)  # size (height/width) of 1D function diagrams
    y2 = 0.15 / np.maximum(max_neuron, 5)  # size (height/width) of operations (sum and mult)

    fig, ax = plt.subplots(figsize=(10 * scale, 10 * scale * (neuron_depth - 1) * (y0 + z0)))

    # -- Transformation functions
    DC_to_FC = ax.transData.transform
    FC_to_NFC = fig.transFigure.inverted().transform
    # -- Take data coordinates and transform them to normalized figure coordinates
    DC_to_NFC = lambda x: FC_to_NFC(DC_to_FC(x))

    def add_image(im, x_center, y_center, half, alpha=None):
        # show `im` in a new axes covering the data-space square centred on (x_center, y_center)
        left = DC_to_NFC([x_center - half, 0])[0]
        right = DC_to_NFC([x_center + half, 0])[0]
        bottom = DC_to_NFC([0, y_center - half])[1]
        up = DC_to_NFC([0, y_center + half])[1]
        newax = fig.add_axes([left, bottom, right - left, up - bottom])
        newax.imshow(im, alpha=alpha)
        newax.axis('off')

    # plot scatters and lines
    for l in range(neuron_depth):
        n = width_in[l]

        # scatters
        for i in range(n):
            plt.scatter(1 / (2 * n) + i / n, l * (y0 + z0), s=min_spacing ** 2 * 10000 * scale ** 2, color='black')

        if l < neuron_depth - 1:
            # plot connections (input to pre-mult)
            n_next = width_out[l + 1]
            N = n * n_next
            for i in range(n):
                for j in range(n_next):
                    id_ = i * n_next + j
                    color, alpha_mask = edge_style(l, i, j)
                    edge_alpha = alpha[l][j][i] * alpha_mask
                    # input node -> bottom of the activation thumbnail
                    plt.plot([1 / (2 * n) + i / n, 1 / (2 * N) + id_ / N],
                             [l * (y0 + z0), l * (y0 + z0) + y0 / 2 - y1], color=color, lw=2 * scale,
                             alpha=edge_alpha)
                    # top of the thumbnail -> pre-mult node
                    plt.plot([1 / (2 * N) + id_ / N, 1 / (2 * n_next) + j / n_next],
                             [l * (y0 + z0) + y0 / 2 + y1, l * (y0 + z0) + y0], color=color, lw=2 * scale,
                             alpha=edge_alpha)

            # plot connections (pre-mult to post-mult, post-mult = next-layer input)
            n_in = width_out[l + 1]
            n_out = width_in[l + 1]
            mult_id = 0
            for i in range(n_in):
                if i < width[l + 1][0]:
                    j = i
                else:
                    if i == width[l + 1][0]:
                        current_mult_arity = mult_arity_of(l, mult_id)
                    if current_mult_arity == 0:
                        # current mult node has all its inputs; move on to the next one
                        mult_id += 1
                        current_mult_arity = mult_arity_of(l, mult_id)
                    j = width[l + 1][0] + mult_id
                    current_mult_arity -= 1
                plt.plot([1 / (2 * n_in) + i / n_in, 1 / (2 * n_out) + j / n_out],
                         [l * (y0 + z0) + y0, (l + 1) * (y0 + z0)], color='black', lw=2 * scale)

    plt.xlim(0, 1)
    plt.ylim(-0.1 * (y0 + z0), (neuron_depth - 1 + 0.1) * (y0 + z0))

    plt.axis('off')

    for l in range(neuron_depth - 1):
        # plot splines
        n = width_in[l]
        n_next = width_out[l + 1]
        N = n * n_next
        for i in range(n):
            for j in range(n_next):
                id_ = i * n_next + j
                im = plt.imread(f'{folder}/sp_{l}_{i}_{j}.png')
                add_image(im, 1 / (2 * N) + id_ / N, l * (y0 + z0) + y0 / 2, y1, alpha=alpha[l][j][i])

        # plot sum symbols
        N = width_out[l + 1]
        for j in range(N):
            add_image(symbol_image('sum'), 1 / (2 * N) + j / N, l * (y0 + z0) + y0, y2)

        # plot mult symbols
        N = width_in[l + 1]
        n_sum = width[l + 1][0]
        n_mult = width[l + 1][1]
        for j in range(n_mult):
            id_ = j + n_sum
            add_image(symbol_image('mult'), 1 / (2 * N) + id_ / N, (l + 1) * (y0 + z0), y2)

    def put_label(x, y, var):
        # sympy expressions are typeset as LaTeX; anything else is written as-is
        text = f'${latex(var)}$' if isinstance(var, sympy.Expr) else var
        plt.gcf().get_axes()[0].text(x, y, text, fontsize=40 * scale * varscale,
                                     horizontalalignment='center', verticalalignment='center')

    if in_vars is not None:
        n = self.width_in[0]
        for i in range(n):
            put_label(1 / (2 * n) + i / n, -0.1, in_vars[i])

    if out_vars is not None:
        n = self.width_in[-1]
        for i in range(n):
            put_label(1 / (2 * n) + i / n, (y0 + z0) * (len(self.width) - 1) + 0.15, out_vars[i])

    if title is not None:
        plt.gcf().get_axes()[0].text(0.5, (y0 + z0) * (len(self.width) - 1) + 0.3, title, fontsize=40 * scale,
                                     horizontalalignment='center', verticalalignment='center')
|
1358
|
+
|
1359
|
+
def reg(self, reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff):
    """
    Compute the regularization term.

    For every score matrix selected by ``reg_metric``, adds an L1 term and a row- plus
    column-entropy term. For every ``act_fun`` layer, adds the mean absolute spline
    coefficient and the mean absolute difference between neighbouring coefficients.

    Args:
    -----
        reg_metric : str
            'edge_forward_spline_n' (self.acts_scale_spline), 'edge_forward_spline_u'
            (self.edge_actscale), 'edge_forward_sum' (self.acts_scale), 'edge_backward'
            (self.edge_scores) or 'node_backward' (self.node_attribute_scores)
        lamb_l1 : float
            l1 penalty strength
        lamb_entropy : float
            entropy penalty strength
        lamb_coef : float
            coefficient penalty strength
        lamb_coefdiff : float
            coefficient smoothness strength

    Returns:
    --------
        reg_ : torch.float (the Python float 0. if there is nothing to regularize)

    Raises:
    -------
        Exception if reg_metric is not one of the names above

    Example
    -------
    >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
    >>> x = torch.rand(100,2)
    >>> model.get_act(x)
    >>> model.reg('edge_forward_spline_n', 1.0, 2.0, 1.0, 1.0)
    """
    metric_sources = (
        ('edge_forward_spline_n', 'acts_scale_spline'),
        ('edge_forward_sum', 'acts_scale'),
        ('edge_forward_spline_u', 'edge_actscale'),
        ('edge_backward', 'edge_scores'),
        ('node_backward', 'node_attribute_scores'),
    )
    attr_name = next((attr for name, attr in metric_sources if name == reg_metric), None)
    if attr_name is None:
        raise Exception(f'reg_metric = {reg_metric} not recognized!')
    # only the selected attribute is read, so the others need not exist yet
    score_list = getattr(self, attr_name)

    total = 0.
    for vec in score_list:
        # normalize rows/columns (+1 keeps the denominator away from 0) before the entropy
        row_prob = vec / (torch.sum(vec, dim=1, keepdim=True) + 1)
        col_prob = vec / (torch.sum(vec, dim=0, keepdim=True) + 1)
        row_entropy = - torch.mean(torch.sum(row_prob * torch.log2(row_prob + 1e-4), dim=1))
        col_entropy = - torch.mean(torch.sum(col_prob * torch.log2(col_prob + 1e-4), dim=0))
        total += lamb_l1 * torch.sum(vec) + lamb_entropy * (row_entropy + col_entropy)

    # push spline coefficients towards zero and towards smoothness
    for layer in self.act_fun:
        coef = layer.coef
        coef_l1 = torch.sum(torch.mean(torch.abs(coef), dim=1))
        coef_smoothness = torch.sum(torch.mean(torch.abs(torch.diff(coef)), dim=1))
        total += lamb_coef * coef_l1 + lamb_coefdiff * coef_smoothness

    return total
|
1423
|
+
|
1424
|
+
def get_reg(self, reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff):
    """
    Overridable hook returning the regularization term used during fit().

    Delegates unchanged to ``self.reg``. A subclass can override this method to change the
    regularizer seen by training without redefining ``reg`` itself.
    """
    regularizer = self.reg
    return regularizer(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
|
1429
|
+
|
1430
|
+
def disable_symbolic_in_fit(self, lamb):
    """
    Turn off work that training will not need, and return the flags required to restore it.

    * If ``lamb == 0`` (no regularization), ``self.save_act`` is set to False.
    * If every symbolic mask is entirely zero, ``self.symbolic_enabled`` is set to False.

    Returns:
    --------
        (old_save_act, old_symbolic_enabled) : both flags as they were before this call
    """
    previous_save_act = self.save_act
    if lamb == 0.:
        self.save_act = False

    previous_symbolic_enabled = self.symbolic_enabled
    no_active_symbolic = all(torch.sum(torch.abs(layer.mask)) == 0 for layer in self.symbolic_fun)
    if no_active_symbolic:
        self.symbolic_enabled = False

    return previous_save_act, previous_symbolic_enabled
|
1450
|
+
|
1451
|
+
def get_params(self):
    """
    Return the parameters handed to the optimizer in fit().

    A thin hook around ``self.parameters()``, so that subclasses can expose a different
    parameter set without touching fit().
    """
    params = self.parameters()
    return params
|
1456
|
+
|
1457
|
+
def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0.,
        lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
        stop_grid_update_step=50, batch=-1,
        metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
        singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n', display_metrics=None):
    '''
    training

    Args:
    -----
        dataset : dic
            contains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']
        opt : str
            "LBFGS" or "Adam"; anything else raises an Exception before any state is changed
        steps : int
            training steps
        log : int
            progress-bar refresh frequency (in steps)
        lamb : float
            overall penalty strength
        lamb_l1 : float
            l1 penalty strength
        lamb_entropy : float
            entropy penalty strength
        lamb_coef : float
            coefficient magnitude penalty strength
        lamb_coefdiff : float
            difference of nearby coefficients (smoothness) penalty strength
        update_grid : bool
            If True, update grid regularly before stop_grid_update_step
        grid_update_num : int
            the number of grid updates before stop_grid_update_step; 0 disables grid updates.
            If it exceeds stop_grid_update_step, the grid is updated every step until then.
        start_grid_update_step : int
            no grid updates before this training step
        stop_grid_update_step : int
            no grid updates after this training step
        loss_fn : function
            loss function (default: mean squared error)
        lr : float
            learning rate
        batch : int
            batch size, if -1 then full. The test batch is capped at the test-set size.
        metrics : a list of metrics (as functions)
            computed every step; each is called with no arguments and its result must support .item()
        save_fig : bool
            if True, call self.plot() and save a jpg to img_folder every save_fig_freq steps
        in_vars, out_vars : None or list
            variable names forwarded to self.plot() when save_fig is True
        beta : float
            transparency parameter forwarded to self.plot()
        save_fig_freq : int
            save figure every (save_fig_freq) steps
        img_folder : str
            folder receiving the figures when save_fig is True
        singularity_avoiding : bool
            indicate whether to avoid singularity for the symbolic part
        y_th : float
            singularity threshold (anything above the threshold is considered singular and is softened in some ways)
        reg_metric : str
            regularization metric. Choose from {'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'}
        display_metrics : None or list of str
            keys of `results` whose latest value is shown in the tqdm progress bar

    Returns:
    --------
        results : dic
            results['train_loss'], 1D array of training losses (RMSE)
            results['test_loss'], 1D array of test losses (RMSE)
            results['reg'], 1D array of regularization
            other metrics specified in metrics (keyed by function __name__)

    Example
    -------
    # >>> from kan import *
    # >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
    # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
    # >>> dataset = create_dataset(f, n_var=2)
    # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
    # >>> model.plot()
    # Most examples in the tutorials involve the fit() method. Please check them for usage.
    '''
    if opt not in ("Adam", "LBFGS"):
        raise Exception(f'opt = \'{opt}\' not recognized')

    if lamb > 0. and not self.save_act:
        print('setting lamb=0. If you want to set lamb > 0, set self.save_act=True')

    old_save_act, old_symbolic_enabled = self.disable_symbolic_in_fit(lamb)

    pbar = tqdm(range(steps), desc='description', ncols=100)

    if loss_fn is None:
        loss_fn = loss_fn_eval = lambda x, y: torch.mean((x - y) ** 2)
    else:
        loss_fn = loss_fn_eval = loss_fn

    if grid_update_num == 0:
        update_grid = False
        grid_update_freq = 1
    else:
        # "or 1": a frequency of 0 would make `step % grid_update_freq` divide by zero
        grid_update_freq = int(stop_grid_update_step / grid_update_num) or 1

    if opt == "Adam":
        optimizer = torch.optim.Adam(self.get_params(), lr=lr)
    else:
        optimizer = LBFGS(self.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
                          tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)

    results = {'train_loss': [], 'test_loss': [], 'reg': []}
    if metrics is not None:
        for metric_fn in metrics:
            results[metric_fn.__name__] = []

    n_train = dataset['train_input'].shape[0]
    n_test = dataset['test_input'].shape[0]
    if batch == -1 or batch > n_train:
        batch_size = n_train
        batch_size_test = n_test
    else:
        batch_size = batch
        # sampling is without replacement, so never ask for more test points than exist
        batch_size_test = min(batch, n_test)

    train_loss = reg_ = None  # updated by closure() / the Adam branch each step
    train_id = None

    def compute_reg():
        # the regularizer needs activations cached by forward(); without them it is 0
        if not self.save_act:
            return torch.tensor(0.)
        if reg_metric == 'edge_backward':
            self.attribute()
        if reg_metric == 'node_backward':
            self.node_attribute()
        return self.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)

    def closure():
        nonlocal train_loss, reg_
        optimizer.zero_grad()
        pred = self.forward(dataset['train_input'][train_id], singularity_avoiding=singularity_avoiding, y_th=y_th)
        train_loss = loss_fn(pred, dataset['train_label'][train_id])
        reg_ = compute_reg()
        objective = train_loss + lamb * reg_
        objective.backward()
        return objective

    if save_fig:
        if not os.path.exists(img_folder):
            os.makedirs(img_folder)

    for step in pbar:

        # cache activations on the final step so the model can be plotted/inspected afterwards
        if step == steps - 1 and old_save_act:
            self.save_act = True

        if save_fig and step % save_fig_freq == 0:
            save_act = self.save_act
            self.save_act = True

        train_id = np.random.choice(n_train, batch_size, replace=False)
        test_id = np.random.choice(n_test, batch_size_test, replace=False)

        if step % grid_update_freq == 0 and step < stop_grid_update_step and update_grid and step >= start_grid_update_step:
            self.update_grid(dataset['train_input'][train_id])

        if opt == "LBFGS":
            optimizer.step(closure)
        else:  # "Adam"
            pred = self.forward(dataset['train_input'][train_id], singularity_avoiding=singularity_avoiding,
                                y_th=y_th)
            train_loss = loss_fn(pred, dataset['train_label'][train_id])
            reg_ = compute_reg()
            loss = train_loss + lamb * reg_
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # evaluation only: no autograd graph is needed
        with torch.no_grad():
            test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id]), dataset['test_label'][test_id])

        if metrics is not None:
            for metric_fn in metrics:
                results[metric_fn.__name__].append(metric_fn().item())

        results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy())
        results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy())
        results['reg'].append(reg_.cpu().detach().numpy())

        if step % log == 0:
            if display_metrics is None:
                pbar.set_description("| train_loss: %.6f | test_loss: %.6f | reg: %.6f | " % (
                    torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(),
                    reg_.cpu().detach().numpy()))
            else:
                string = ''
                data = ()
                for metric in display_metrics:
                    string += f' {metric}: %.6f |'
                    if metric not in results:
                        raise Exception(f'{metric} not recognized')
                    data += (results[metric][-1],)
                pbar.set_description(string % data)

        if save_fig and step % save_fig_freq == 0:
            self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(step), beta=beta)
            plt.savefig(img_folder + '/' + str(step) + '.jpg', bbox_inches='tight', dpi=100)
            plt.close()
            self.save_act = save_act

    # also covers steps == 0, where the loop never re-enabled save_act
    self.save_act = old_save_act
    self.log_history('fit')
    # revert back to original state
    self.symbolic_enabled = old_symbolic_enabled
    return results
|
1658
|
+
|
1659
|
+
def fix(self, dataset: dict, batch_size, batch_size_test, labels, opt="LBFGS", epochs=100, log=1, lamb=0.,
        lamb_l1=1.,
        lamb_entropy=2., lamb_coef=0.,
        lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
        stop_grid_update_step=100,
        metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
        singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n', display_metrics=None):
    """
    Train with mini-batches and evaluate the model as a classifier after every epoch.

    Each test output is snapped to the closest entry of ``labels`` (argmin of
    ``|outputs - labels|`` along dim 1); accuracy / precision / recall / F1 of those
    predictions against ``dataset['test_label']`` are recorded once per epoch.

    Args:
    -----
        dataset : dict
            must contain 'train_input', 'train_label', 'test_input', 'test_label'
            (each indexed here with numpy index arrays)
        batch_size : int
            training mini-batch size
        batch_size_test : int
            evaluation mini-batch size
        labels : torch.Tensor
            candidate class values that outputs are snapped to
        opt : str
            "LBFGS" or "Adam"
        epochs : int
            number of passes over the training set
        log : int
            the progress-bar description is refreshed every ``log`` epochs
        lamb : float
            overall regularization strength
        lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff : float
            forwarded to ``self.get_reg``
        update_grid : bool
            whether to call ``self.update_grid`` during training
        grid_update_num : int
            number of grid updates performed before ``stop_grid_update_step``
        loss_fn : callable or None
            ``loss_fn(pred, target)``; defaults to mean squared error
        lr : float
            learning rate
        start_grid_update_step, stop_grid_update_step : int
            grid updates only happen at global steps in [start, stop) that are
            multiples of the update frequency
        metrics : list of callables or None
            zero-argument callables evaluated every epoch; stored under their ``__name__``
        save_fig, save_fig_freq, img_folder, in_vars, out_vars, beta :
            when ``save_fig`` is True, ``self.plot`` is saved to ``img_folder``
            every ``save_fig_freq`` epochs
        singularity_avoiding : bool, y_th : float
            forwarded to ``self.forward`` during training
        reg_metric : str
            regularization metric forwarded to ``self.get_reg``
        display_metrics : list of str or None
            keys of the results dict shown in the progress bar

    Returns:
    --------
        dict of per-epoch lists: 'train_loss' and 'test_loss' (square roots of the
        batch-averaged training objective / test loss), 'reg', 'accuracy', 'precision',
        'recall', 'f1_score', plus one entry per custom metric.

    Raises:
    -------
        ValueError : if ``opt`` is neither "LBFGS" nor "Adam".
    """
    # Validate before any state (save_act / symbolic_enabled) is modified.
    if opt not in ("LBFGS", "Adam"):
        raise ValueError(f"opt must be 'LBFGS' or 'Adam', got {opt!r}")

    def calculate_results(all_labels, all_predictions, classes=None, average='macro'):
        # `classes` is accepted for compatibility but unused.
        return {
            'accuracy': accuracy_score(y_true=all_labels, y_pred=all_predictions),
            'precision': precision_score(y_true=all_labels, y_pred=all_predictions, average=average),
            'recall': recall_score(y_true=all_labels, y_pred=all_predictions, average=average),
            'f1_score': f1_score(y_true=all_labels, y_pred=all_predictions, average=average),
        }

    def regularization():
        # Shared by the LBFGS closure and the Adam branch; returns 0 when activations are not saved.
        if not self.save_act:
            return torch.tensor(0.)
        if reg_metric == 'edge_backward':
            self.attribute()
        if reg_metric == 'node_backward':
            self.node_attribute()
        return self.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)

    if lamb > 0. and not self.save_act:
        print('setting lamb=0. If you want to set lamb > 0, set self.save_act=True')

    old_save_act, old_symbolic_enabled = self.disable_symbolic_in_fit(lamb)

    pbar = tqdm(range(epochs), desc='description', ncols=100)

    if loss_fn is None:
        loss_fn = lambda x, y: torch.mean((x - y) ** 2)

    # max(1, ...) avoids a ZeroDivisionError in `step % grid_update_freq`
    # when grid_update_num exceeds stop_grid_update_step.
    grid_update_freq = max(1, int(stop_grid_update_step / grid_update_num))

    if opt == "Adam":
        optimizer = torch.optim.Adam(self.get_params(), lr=lr)
    else:
        optimizer = LBFGS(self.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
                          tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)

    results = {'train_loss': [], 'test_loss': [], 'reg': [], 'accuracy': [],
               'precision': [], 'recall': [], 'f1_score': []}
    if metrics is not None:
        for metric in metrics:
            results[metric.__name__] = []

    steps = math.ceil(dataset['train_input'].shape[0] / batch_size)
    train_loss = torch.zeros(1).to(self.device)
    reg_ = torch.zeros(1).to(self.device)
    labels = labels.to(self.device)

    def closure():
        # Reads batch_num / batch_train_input / batch_train_label from the training loop below.
        nonlocal train_loss, reg_
        optimizer.zero_grad()
        pred = self.forward(batch_train_input, singularity_avoiding=singularity_avoiding, y_th=y_th)
        loss = loss_fn(pred, batch_train_label)
        reg_ = regularization()
        objective = loss + lamb * reg_
        # running mean of the objective over the batches of the current epoch
        train_loss = (train_loss * batch_num + objective.detach()) / (batch_num + 1)
        objective.backward()
        return objective

    if save_fig and not os.path.exists(img_folder):
        os.makedirs(img_folder)

    for epoch in pbar:

        if epoch == epochs - 1 and old_save_act:
            self.save_act = True

        if save_fig and epoch % save_fig_freq == 0:
            save_act = self.save_act
            self.save_act = True

        # ---- training ----
        train_indices = np.arange(dataset['train_input'].shape[0])
        np.random.shuffle(train_indices)
        for batch_num, start in enumerate(range(0, len(train_indices), batch_size)):
            step = epoch * steps + batch_num + 1
            batch_train_id = train_indices[start:start + batch_size]
            batch_train_input = dataset['train_input'][batch_train_id].to(self.device)
            batch_train_label = dataset['train_label'][batch_train_id].to(self.device)

            if (update_grid and step % grid_update_freq == 0
                    and start_grid_update_step <= step < stop_grid_update_step):
                self.update_grid(batch_train_input)

            if opt == "LBFGS":
                optimizer.step(closure)
            else:  # Adam
                optimizer.zero_grad()
                pred = self.forward(batch_train_input, singularity_avoiding=singularity_avoiding,
                                    y_th=y_th)
                loss = loss_fn(pred, batch_train_label)
                reg_ = regularization()
                loss = loss + lamb * reg_
                train_loss = (train_loss * batch_num + loss.detach()) / (batch_num + 1)
                loss.backward()
                optimizer.step()

        # ---- evaluation ----
        # Reset every epoch so the classification metrics describe this epoch only.
        all_predictions = []
        all_labels = []
        test_loss = torch.zeros(1).to(self.device)
        with torch.no_grad():
            test_indices = np.arange(dataset['test_input'].shape[0])
            np.random.shuffle(test_indices)
            for test_batch_num, start in enumerate(range(0, len(test_indices), batch_size_test)):
                batch_test_id = test_indices[start:start + batch_size_test]
                batch_test_input = dataset['test_input'][batch_test_id].to(self.device)
                batch_test_label = dataset['test_label'][batch_test_id].to(self.device)

                outputs = self.forward(batch_test_input)
                loss = loss_fn(outputs, batch_test_label)
                test_loss = (test_loss * test_batch_num + loss.detach()) / (test_batch_num + 1)

                # Snap each output to the nearest candidate label.
                # NOTE(review): assumes outputs broadcast against `labels` to (batch, n_labels),
                # e.g. outputs of shape (batch, 1) — confirm for multi-output models.
                closest_indices = torch.argmin(torch.abs(outputs - labels), dim=1)
                predict = labels[closest_indices]
                all_predictions.extend(predict.detach().cpu().numpy())
                all_labels.extend(batch_test_label.detach().cpu().numpy())

        if metrics is not None:
            for metric in metrics:
                results[metric.__name__].append(metric().item())

        results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy())
        results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy())
        results['reg'].append(reg_.cpu().detach().numpy())
        # Argument order matters: precision and recall are not symmetric in (y_true, y_pred).
        for key, value in calculate_results(all_labels, all_predictions).items():
            results[key].append(value)

        if epoch % log == 0:
            if display_metrics is None:
                pbar.set_description(f"| train_loss: %.6f | test_loss: %.6f | reg: %.6f |step:{step} " % (
                    train_loss.item(), test_loss.item(), reg_.item()))
            else:
                string = ''
                data = ()
                for metric in display_metrics:
                    if metric not in results:
                        raise Exception(f'{metric} not recognized')
                    string += f' {metric}: %.6f |'
                    data += (results[metric][-1],)
                pbar.set_description(string % data)

        if save_fig and epoch % save_fig_freq == 0:
            self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
                      beta=beta)
            plt.savefig(img_folder + '/' + str(epoch) + '.jpg', bbox_inches='tight', dpi=100)
            plt.close()
            self.save_act = save_act

    self.log_history('fit')
    # revert back to original state
    self.symbolic_enabled = old_symbolic_enabled
    return results
|
1866
|
+
|
1867
|
+
def prune_node(self, threshold=1e-2, mode="auto", active_neurons_id=None, log_history=True):
    """
    pruning nodes

    Args:
    -----
        threshold : float
            if the attribution score of a neuron is below the threshold, it is considered dead and will be removed
        mode : str
            'auto' or 'manual'. with 'auto', nodes are automatically pruned using threshold. with 'manual', active_neurons_id should be passed in.
        active_neurons_id : None or list
            for each hidden layer, the indices of the neurons to keep; passing it forces mode='manual'.
        log_history : bool
            if True, record 'prune_node' via self.log_history and increment the new model's state_id.

    Returns:
    --------
        pruned network : MultKAN

    Raises:
    -------
        ValueError : if mode is neither 'auto' nor 'manual' (and active_neurons_id is None).

    Example
    -------
    # >>> from kan import *
    # >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
    # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
    # >>> dataset = create_dataset(f, n_var=2)
    # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
    # >>> model = model.prune_node()
    # >>> model.plot()
    """
    if self.acts is None:
        self.get_act()

    if active_neurons_id is not None:
        mode = "manual"
    if mode not in ("auto", "manual"):
        raise ValueError(f"mode must be 'auto' or 'manual', got {mode!r}")

    n_hidden = len(self.acts_scale) - 1

    if mode == "auto" and n_hidden > 0:
        # The attribution scores only depend on model state that the loop below does not
        # modify (nodes are removed after the loop), so compute them once, not once per layer.
        self.attribute()

    mask_up = [torch.ones(self.width_in[0], device=self.device)]
    mask_down = []
    active_neurons_up = [list(range(self.width_in[0]))]
    active_neurons_down = []
    num_sums = []
    num_mults = []
    mult_arities = [[]]

    for i in range(n_hidden):

        mult_arity = []
        # within layer i+1 the sum nodes come first, followed by the multiplication nodes
        n_sum = self.width[i + 1][0]

        if mode == "auto":
            overall_important_up = self.node_scores[i + 1] > threshold
        else:
            overall_important_up = torch.zeros(self.width_in[i + 1], dtype=torch.bool, device=self.device)
            overall_important_up[active_neurons_id[i]] = True

        num_sum = torch.sum(overall_important_up[:n_sum])
        num_mult = torch.sum(overall_important_up[n_sum:])
        if self.mult_homo:
            # every multiplication node owns `self.mult_arity` consecutive subnodes
            overall_important_down = torch.cat([overall_important_up[:n_sum], (
                overall_important_up[n_sum:][None, :].expand(self.mult_arity, -1)).T.reshape(-1, )],
                                               dim=0)
        else:
            overall_important_down = overall_important_up[:n_sum]
            for j in range(overall_important_up[n_sum:].shape[0]):
                active_bool = overall_important_up[n_sum + j]
                arity = self.mult_arity[i + 1][j]
                overall_important_down = torch.cat(
                    [overall_important_down, torch.tensor([active_bool] * arity).to(self.device)])
                if active_bool:
                    mult_arity.append(arity)

        num_sums.append(num_sum.item())
        num_mults.append(num_mult.item())

        mask_up.append(overall_important_up.float())
        mask_down.append(overall_important_down.float())

        active_neurons_up.append(torch.where(overall_important_up == True)[0])
        active_neurons_down.append(torch.where(overall_important_down == True)[0])

        mult_arities.append(mult_arity)

    active_neurons_down.append(list(range(self.width_out[-1])))
    mask_down.append(torch.ones(self.width_out[-1], device=self.device))

    if not self.mult_homo:
        mult_arities.append(self.mult_arity[-1])

    self.mask_up = mask_up
    self.mask_down = mask_down

    # zero the masks of pruned nodes on this model before its state is copied
    for l in range(n_hidden):
        for i in range(self.width_in[l + 1]):
            if i not in active_neurons_up[l + 1]:
                self.remove_node(l + 1, i, mode='up', log_history=False)

        for i in range(self.width_out[l + 1]):
            if i not in active_neurons_down[l]:
                self.remove_node(l + 1, i, mode='down', log_history=False)

    model2 = MultKAN(copy.deepcopy(self.width), grid=self.grid, k=self.k, base_fun=self.base_fun_name,
                     mult_arity=self.mult_arity, ckpt_path=self.ckpt_path, auto_save=True, first_init=False,
                     state_id=self.state_id, round=self.round).to(self.device)
    model2.load_state_dict(self.state_dict())

    width_new = [self.width[0]]

    for i in range(len(self.acts_scale)):

        if i < n_hidden:
            num_sum = num_sums[i]
            num_mult = num_mults[i]
            model2.node_bias[i].data = model2.node_bias[i].data[active_neurons_up[i + 1]]
            model2.node_scale[i].data = model2.node_scale[i].data[active_neurons_up[i + 1]]
            model2.subnode_bias[i].data = model2.subnode_bias[i].data[active_neurons_down[i]]
            model2.subnode_scale[i].data = model2.subnode_scale[i].data[active_neurons_down[i]]
            model2.width[i + 1] = [num_sum, num_mult]

            model2.act_fun[i].out_dim_sum = num_sum
            model2.act_fun[i].out_dim_mult = num_mult

            model2.symbolic_fun[i].out_dim_sum = num_sum
            model2.symbolic_fun[i].out_dim_mult = num_mult

            width_new.append([num_sum, num_mult])

        model2.act_fun[i] = model2.act_fun[i].get_subset(active_neurons_up[i], active_neurons_down[i])
        model2.symbolic_fun[i] = self.symbolic_fun[i].get_subset(active_neurons_up[i], active_neurons_down[i])

    model2.cache_data = self.cache_data
    model2.acts = None

    width_new.append(self.width[-1])
    model2.width = width_new

    if not self.mult_homo:
        model2.mult_arity = mult_arities

    if log_history:
        self.log_history('prune_node')
        model2.state_id += 1

    return model2
|
2007
|
+
|
2008
|
+
def prune_edge(self, threshold=3e-2, log_history=True):
    '''
    pruning edges

    Args:
    -----
        threshold : float
            if the attribution score of an edge is below the threshold, it is considered dead and will be set to zero.
        log_history : bool
            if True, record 'prune_edge' via self.log_history.

    Returns:
    --------
        None; the masks of self.act_fun are updated in place.

    Example
    -------
    # >>> from kan import *
    # >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
    # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
    # >>> dataset = create_dataset(f, n_var=2)
    # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
    # >>> model.prune_edge()
    # >>> model.plot()
    '''
    if self.acts is None:
        self.get_act()

    for i in range(len(self.width) - 1):
        old_mask = self.act_fun[i].mask.data
        # edge_scores[i] is (out, in) while the mask is (in, out), hence the permute;
        # multiplying by the old mask keeps previously removed edges removed.
        self.act_fun[i].mask.data = ((self.edge_scores[i] > threshold).permute(1, 0) * old_mask).float()

    if log_history:
        # label the history entry after this operation (cf. 'prune_node', 'prune_input')
        self.log_history('prune_edge')
|
2041
|
+
|
2042
|
+
def prune(self, node_th=1e-2, edge_th=3e-2):
    '''
    prune (both nodes and edges)

    Nodes are pruned first (producing a new model); that model is then re-run on its
    cached data, re-attributed, and its edges are pruned in place.

    Args:
    -----
        node_th : float
            if the attribution score of a node is below node_th, it is considered dead and will be removed.
        edge_th : float
            if the attribution score of an edge is below edge_th, it is considered dead and will be set to zero.

    Returns:
    --------
        pruned network : MultKAN

    Example
    -------
    # >>> from kan import *
    # >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
    # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
    # >>> dataset = create_dataset(f, n_var=2)
    # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
    # >>> model = model.prune()
    # >>> model.plot()
    '''
    if self.acts is None:
        self.get_act()

    # prune_node returns a new model; everything below operates on that copy
    pruned = self.prune_node(node_th, log_history=False)

    # refresh activations and attribution scores of the pruned model before edge pruning
    pruned.forward(pruned.cache_data)
    pruned.attribute()
    pruned.prune_edge(edge_th, log_history=False)

    pruned.log_history('prune')
    return pruned
|
2077
|
+
|
2078
|
+
def prune_input(self, threshold=1e-2, active_inputs=None, log_history=True):
    """
    prune inputs

    Args:
    -----
        threshold : float
            if the attribution score of the input feature is below threshold, it is considered irrelevant.
        active_inputs : None or list
            if a list is passed, the manual mode will disregard attribution score and prune as instructed.
        log_history : bool
            if True, record 'prune_input' via self.log_history and increment the new model's state_id.

    Returns:
    --------
        pruned network : MultKAN

    Example1
    --------
    # >>> # automatic
    # >>> from kan import *
    # >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
    # >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
    # >>> dataset = create_dataset(f, n_var=3)
    # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
    # >>> model.plot()
    # >>> model = model.prune_input()
    # >>> model.plot()

    Example2
    --------
    # >>> # manual
    # >>> from kan import *
    # >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
    # >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
    # >>> dataset = create_dataset(f, n_var=3)
    # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
    # >>> model.plot()
    # >>> model = model.prune_input(active_inputs=[0,1])
    # >>> model.plot()
    """
    if active_inputs is None:
        self.attribute()
        input_score = self.node_scores[0]
        input_mask = input_score > threshold
        print('keep:', input_mask.tolist())
        input_id = torch.where(input_mask)[0]
    else:
        input_id = torch.tensor(active_inputs, dtype=torch.long).to(self.device)

    # Pass `self.base_fun_name` (as prune_node does) rather than `self.base_fun`, so both
    # model-copying paths build the copy the same way.
    # NOTE(review): assumes base_fun_name holds the value originally given as `base_fun` — confirm in __init__.
    model2 = MultKAN(copy.deepcopy(self.width), grid=self.grid, k=self.k, base_fun=self.base_fun_name,
                     mult_arity=self.mult_arity, ckpt_path=self.ckpt_path, auto_save=True, first_init=False,
                     state_id=self.state_id, round=self.round).to(self.device)
    model2.load_state_dict(self.state_dict())

    # keep only the selected input rows of the first layer (all its outputs are kept)
    model2.act_fun[0] = model2.act_fun[0].get_subset(input_id, torch.arange(self.width_out[1]))
    model2.symbolic_fun[0] = self.symbolic_fun[0].get_subset(input_id, torch.arange(self.width_out[1]))

    model2.cache_data = self.cache_data
    model2.acts = None

    model2.width[0] = [len(input_id), 0]
    model2.input_id = input_id

    if log_history:
        self.log_history('prune_input')
        model2.state_id += 1

    return model2
|
2146
|
+
|
2147
|
+
def remove_edge(self, l, i, j, log_history=True):
    """
    Remove activation phi(l, i, j) by setting its entry in ``self.act_fun[l].mask`` to zero.

    Only the spline mask is touched; ``self.symbolic_fun[l].mask`` is left as is.
    When ``log_history`` is True a 'remove_edge' entry is recorded.
    """
    mask = self.act_fun[l].mask
    mask[i, j] = 0.
    if not log_history:
        return
    self.log_history('remove_edge')
|
2154
|
+
|
2155
|
+
def remove_node(self, l, i, mode='all', log_history=True):
    """
    Remove neuron (l, i) by zeroing the masks of the activations attached to it.

    Args:
    -----
        l : int
            layer index of the neuron
        i : int
            neuron index within layer l
        mode : str
            'down' zeroes the edges of layer l-1 that end at the neuron
            (act_fun[l-1].mask[:, i] and symbolic_fun[l-1].mask[i, :]);
            'up' zeroes the edges of layer l that start at it
            (act_fun[l].mask[i, :] and symbolic_fun[l].mask[:, i]);
            any other value (default 'all') does both.
        log_history : bool
            if True, record exactly one 'remove_node' entry via self.log_history.
    """
    if mode == 'down':
        self.act_fun[l - 1].mask[:, i] = 0.
        self.symbolic_fun[l - 1].mask[i, :] *= 0.

    elif mode == 'up':
        self.act_fun[l].mask[i, :] = 0.
        self.symbolic_fun[l].mask[:, i] *= 0.

    else:
        # the sub-calls must not log, otherwise one removal is recorded three times
        self.remove_node(l, i, mode='up', log_history=False)
        self.remove_node(l, i, mode='down', log_history=False)

    if log_history:
        self.log_history('remove_node')
|
2173
|
+
|
2174
|
+
def attribute(self, l=None, i=None, out_score=None, plot=True):
    """
    get attribution scores

    Scores are propagated backwards from layer ``l`` (or the last layer) to the inputs:
    node -> subnode -> edge -> node of the previous layer.

    Args:
    -----
        l : None or int
            layer index
        i : None or int
            neuron index
        out_score : None or 1D torch.float
            specify output scores
        plot : bool
            when plot = True (and both l and i are given), display the scores as a bar chart

    Returns:
    --------
        None when l is None; otherwise the input-layer score matrix (one row per neuron
        of layer l), or its row i when i is given.

    Side effects:
    -------------
        sets self.node_scores_all / edge_scores_all / subnode_scores_all and their
        batch-averaged versions self.node_scores / edge_scores / subnode_scores.
        NOTE(review): when l is given these attributes are overwritten with the query's
        scores (layers 0..l, seeded by out_score), not the full-network scores.

    Example
    -------
    # >>> from kan import *
    # >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
    # >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
    # >>> dataset = create_dataset(f, n_var=3)
    # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
    # >>> model.attribute()
    # >>> model.node_scores[0]
    """
    # output (out_dim, in_dim)

    if l is not None:
        # seed the query with layer l's scores from a full attribution pass
        self.attribute()
        out_score = self.node_scores[l]

    if self.acts is None:
        self.get_act()

    def score_node2subnode(node_score, width, mult_arity, out_dim):
        # Expand node scores (out_dim, n_nodes) to subnode scores (out_dim, n_subnodes):
        # sum nodes map 1:1, each multiplication node is repeated once per subnode it owns.
        assert np.sum(width) == node_score.shape[1]
        n_sum = width[0]
        sum_scores = node_score[:, :n_sum]
        mult_scores = node_score[:, n_sum:]
        if isinstance(mult_arity, int):
            expanded = mult_scores[:, :, None].expand(out_dim, mult_scores.shape[1], mult_arity)
            return torch.cat([sum_scores, expanded.reshape(out_dim, -1)], dim=1)
        pieces = [sum_scores]
        for k, arity in enumerate(mult_arity):
            # [:, None] turns the column into (out_dim, 1) so it broadcasts along the arity axis;
            # expanding the 1-D column directly would align it with the wrong dimension.
            pieces.append(mult_scores[:, k][:, None].expand(out_dim, arity))
        return torch.cat(pieces, dim=1)

    node_scores = []
    subnode_scores = []
    edge_scores = []

    l_query = l
    l_end = self.depth if l is None else l

    # back propagate from the queried layer
    out_dim = self.width_in[l_end]
    if out_score is None:
        node_score = torch.eye(out_dim).requires_grad_(True)
    else:
        node_score = torch.diag(out_score).requires_grad_(True)
    node_scores.append(node_score)

    device = self.act_fun[0].grid.device

    for layer in range(l_end, 0, -1):

        # node to subnode
        mult_arity = self.mult_arity if isinstance(self.mult_arity, int) else self.mult_arity[layer]
        subnode_score = score_node2subnode(node_score, self.width[layer], mult_arity, out_dim=out_dim)
        subnode_scores.append(subnode_score)

        # subnode to edge
        edge_score = torch.einsum('ij,ki,i->kij', self.edge_actscale[layer - 1], subnode_score.to(device),
                                  1 / (self.subnode_actscale[layer - 1] + 1e-4))
        edge_scores.append(edge_score)

        # edge to node
        node_score = torch.sum(edge_score, dim=1)
        node_scores.append(node_score)

    self.node_scores_all = list(reversed(node_scores))
    self.edge_scores_all = list(reversed(edge_scores))
    self.subnode_scores_all = list(reversed(subnode_scores))

    self.node_scores = [torch.mean(s, dim=0) for s in self.node_scores_all]
    self.edge_scores = [torch.mean(s, dim=0) for s in self.edge_scores_all]
    self.subnode_scores = [torch.mean(s, dim=0) for s in self.subnode_scores_all]

    if l_query is None:
        return None
    if i is None:
        return self.node_scores_all[0]

    if plot:
        in_dim = self.width_in[0]
        plt.figure(figsize=(1 * in_dim, 3))
        plt.bar(range(in_dim), self.node_scores_all[0][i].cpu().detach().numpy())
        plt.xticks(range(in_dim))

    return self.node_scores_all[0][i]
|
2304
|
+
|
2305
|
+
def node_attribute(self):
    """
    Collect ``self.attribute(layer)`` for every layer from 1 to ``self.depth``.

    The results are stored, in increasing layer order, in ``self.node_attribute_scores``.
    """
    collected = self.node_attribute_scores = []
    for layer in range(1, self.depth + 1):
        collected.append(self.attribute(layer))
|
2310
|
+
|
2311
|
+
def feature_interaction(self, l, neuron_th=1e-2, feature_th=1e-2):
    """
    get feature interaction

    For every neuron of layer l whose maximum input attribution exceeds neuron_th, the
    set of input features scoring above ``feature_th * max`` is collected; the result
    counts how many neurons share each feature set.

    Args:
    -----
        l : int
            layer index
        neuron_th : float
            threshold to determine whether a neuron is active
        feature_th : float
            relative threshold (fraction of the neuron's max score) to determine whether a feature is active

    Returns:
    --------
        dictionary mapping a tuple of input-feature indices to the number of neurons using exactly that set

    Example
    -------
    # >>> from kan import *
    # >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
    # >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
    # >>> dataset = create_dataset(f, n_var=3)
    # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
    # >>> model.attribute()
    # >>> model.feature_interaction(1)
    """
    dic = {}
    n_neurons = self.width_in[l]
    if n_neurons == 0:
        return dic

    # attribute(l, i) is row i of attribute(l), and every such call recomputes the same
    # matrix, so compute it once and index its rows.
    scores = self.attribute(l, plot=False)

    for i in range(n_neurons):
        score = scores[i]
        max_score = torch.max(score)
        if max_score > neuron_th:
            active = torch.where(score > max_score * feature_th)[0]
            # .cpu() so that this also works when the model lives on a GPU
            features = tuple(active.detach().cpu().numpy())
            dic[features] = dic.get(features, 0) + 1

    return dic
|
2352
|
+
|
2353
|
+
def suggest_symbolic(self, l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=None, topk=5, verbose=True,
                     r2_loss_fun=lambda x: np.log2(1 + 1e-5 - x), c_loss_fun=lambda x: x, weight_simple=0.8):
    """
    suggest symbolic function

    Every candidate in the library is fitted to activation (l, i, j) with
    ``self.fix_symbolic`` (and unfixed again unless the fit returned the -1e8 sentinel);
    candidates are ranked by
    ``weight_simple * c_loss_fun(c) + (1 - weight_simple) * r2_loss_fun(r2)``, lowest first.

    Args:
    -----
        l : int
            layer index
        i : int
            input neuron index in layer l
        j : int
            output neuron index (in layer l+1)
        a_range : tuple
            search range of a
        b_range : tuple
            search range of b
        lib : list of str
            library of candidate symbolic functions (defaults to all of SYMBOLIC_LIB)
        topk : int
            the number of top functions displayed
        verbose : bool
            if verbose == True, print a table of the top candidates
        r2_loss_fun : function
            function : r2 -> "bits"
        c_loss_fun : function
            function : c -> "bits"
        weight_simple : float
            the simplicity weight: the higher, the more simplicity is preferred over performance

    Returns:
    --------
        best_name (str), best_fun (library entry), best_r2 (float), best_c (float)

    Example
    -------
    # >>> from kan import *
    # >>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
    # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
    # >>> dataset = create_dataset(f, n_var=2)
    # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
    # >>> model.suggest_symbolic(0,1,0)
    """
    if lib is None:
        symbolic_lib = SYMBOLIC_LIB
    else:
        symbolic_lib = {item: SYMBOLIC_LIB[item] for item in lib}

    # fit every candidate, recording its r2 and complexity
    fitted_r2 = []
    complexities = []
    for name, content in symbolic_lib.items():
        r2 = self.fix_symbolic(l, i, j, name, a_range=a_range, b_range=b_range, verbose=False, log_history=False)
        if r2 == -1e8:  # zero function
            fitted_r2.append(-1e8)
        else:
            fitted_r2.append(r2.item())
            self.unfix_symbolic(l, i, j, log_history=False)
        complexities.append(content[2])

    r2s = np.array(fitted_r2)
    cs = np.array(complexities)
    r2_loss = r2_loss_fun(r2s).astype('float')
    cs_loss = c_loss_fun(cs)
    loss = weight_simple * cs_loss + (1 - weight_simple) * r2_loss

    # order every array by total loss (lowest first) and keep the top-k
    sorted_ids = np.argsort(loss)[:topk]
    r2s, cs, r2_loss, cs_loss, loss = [arr[sorted_ids][:topk] for arr in (r2s, cs, r2_loss, cs_loss, loss)]

    topk = np.minimum(topk, len(symbolic_lib))
    candidates = list(symbolic_lib.items())

    if verbose == True:
        # print results in a dataframe
        table = {
            'function': [candidates[sorted_ids[k]][0] for k in range(topk)],
            'fitting r2': r2s[:topk],
            'r2 loss': r2_loss[:topk],
            'complexity': cs[:topk],
            'complexity loss': cs_loss[:topk],
            'total loss': loss[:topk],
        }
        print(pd.DataFrame(table))

    best_name, best_fun = candidates[sorted_ids[0]]
    return best_name, best_fun, r2s[0], cs[0]
|
2453
|
+
|
2454
|
+
def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1, weight_simple=0.8,
                  r2_threshold=0.0):
    """
    Automatic symbolic regression for all edges.

    Every edge (l, i, j) is visited once:

    * symbolic mask on and spline mask off -> already symbolic, left untouched;
    * both masks off -> fixed to the constant function '0';
    * otherwise -> the best candidate returned by ``suggest_symbolic`` is fixed,
      provided its r2 reaches ``r2_threshold``.

    Args:
    -----
        a_range : tuple
            search range of a
        b_range : tuple
            search range of b
        lib : list of str
            library of candidate symbolic functions
        verbose : int
            0 only reports edges omitted because of ``r2_threshold``;
            >= 1 also reports every skipped / fixed edge;
            > 1 additionally forwards verbosity to ``fix_symbolic``.
        weight_simple : float
            a weight that prioritizes simplicity (low complexity) over performance (high r2) - set to 0.0 to ignore complexity
        r2_threshold : float
            If r2 is below this threshold, the edge will not be fixed with any symbolic function - set to 0.0 to ignore this threshold

    Returns:
    --------
        None

    Example
    -------
    # >>> from kan import *
    # >>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
    # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
    # >>> dataset = create_dataset(f, n_var=3)
    # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
    # >>> model.auto_symbolic()
    """
    for l in range(len(self.width_in) - 1):
        for i in range(self.width_in[l]):
            for j in range(self.width_out[l + 1]):
                # symbolic_fun masks are indexed [out, in]; act_fun (spline) masks [in][out]
                sym_mask = self.symbolic_fun[l].mask[j, i]
                spline_mask = self.act_fun[l].mask[i][j]
                if sym_mask > 0. and spline_mask == 0.:
                    if verbose >= 1:
                        print(f'skipping ({l},{i},{j}) since already symbolic')
                elif sym_mask == 0. and spline_mask == 0.:
                    # both masks are off: pin the edge to the zero function
                    self.fix_symbolic(l, i, j, '0', verbose=verbose > 1, log_history=False)
                    if verbose >= 1:
                        print(f'fixing ({l},{i},{j}) with 0')
                else:
                    name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib,
                                                             verbose=False, weight_simple=weight_simple)
                    if r2 >= r2_threshold:
                        self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False)
                        if verbose >= 1:
                            print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}')
                    else:
                        # always reported: the edge is left non-symbolic, which the caller must know
                        print(
                            f'For ({l},{i},{j}) the best fit was {name}, but r^2 = {r2} and this is lower than {r2_threshold}. This edge was omitted, keep training or try a different threshold.')

    self.log_history('auto_symbolic')
|
2506
|
+
|
2507
|
+
def symbolic_formula(self, var=None, normalizer=None, output_normalizer=None, simplify=False):
    """
    Get the symbolic formula of the network outputs.

    Every edge must already be symbolic (see ``auto_symbolic`` / ``fix_symbolic``);
    if building an edge expression fails, a message is printed and ``None`` is returned.

    Args:
    -----
        var : None or a list of sympy expression (or of str)
            input variables; None creates the symbols x_1, ..., x_n
        normalizer : [mean, std]
            input normalization, applied as (x - mean) / std
        output_normalizer : [mean, std]
            output de-normalization, applied as y * std + mean
        simplify : bool
            if True, run ``sympy.simplify`` on every subnode expression (can be slow). Default False.

    Returns:
    --------
        (list of sympy expressions, one per output node; list of input variables)
        Per-layer expressions are also stored in ``self.symbolic_acts`` and
        ``self.symbolic_acts_premult``.

    Example
    -------
    # >>> from kan import *
    # >>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
    # >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
    # >>> dataset = create_dataset(f, n_var=3)
    # >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
    # >>> model.auto_symbolic()
    # >>> model.symbolic_formula()[0][0]
    """
    symbolic_acts = []
    symbolic_acts_premult = []

    # define input variables
    if var is None:
        # build the symbols directly: assigning function locals through exec() is not
        # reliable (PEP 667, Python 3.13+ gives each exec call a fresh locals snapshot)
        x = [sympy.Symbol(f'x_{ii}') for ii in range(1, self.width[0][0] + 1)]
    elif isinstance(var[0], sympy.Expr):
        x = var
    else:
        x = [sympy.symbols(var_) for var_ in var]

    x0 = x

    if normalizer is not None:
        mean = normalizer[0]
        std = normalizer[1]
        x = [(x[i] - mean[i]) / std[i] for i in range(len(x))]

    symbolic_acts.append(x)

    for l in range(len(self.width_in) - 1):
        num_sum = self.width[l + 1][0]
        num_mult = self.width[l + 1][1]

        # subnode j = subnode_scale * sum_i (c * f(a * x_i + b) + d) + subnode_bias
        y = []
        for j in range(self.width_out[l + 1]):
            yj = 0.
            for i in range(self.width_in[l]):
                a, b, c, d = self.symbolic_fun[l].affine[j, i]
                sympy_fun = self.symbolic_fun[l].funs_sympy[j][i]
                try:
                    yj += c * sympy_fun(a * x[i] + b) + d
                except Exception:
                    print('make sure all activations need to be converted to symbolic formulas first!')
                    return
            yj = self.subnode_scale[l][j] * yj + self.subnode_bias[l][j]
            y.append(sympy.simplify(yj) if simplify else yj)

        symbolic_acts_premult.append(y)

        # multiplication node k multiplies the next `arity_k` consecutive subnodes that
        # follow the addition nodes; the offset accumulates the arities of earlier nodes
        mult = []
        offset = num_sum
        for k in range(num_mult):
            if isinstance(self.mult_arity, int):
                mult_arity = self.mult_arity
            else:
                mult_arity = self.mult_arity[l + 1][k]
            mult_k = y[offset]
            for factor in y[offset + 1:offset + mult_arity]:
                mult_k = mult_k * factor
            mult.append(mult_k)
            offset += mult_arity

        y = y[:num_sum] + mult

        for j in range(self.width_in[l + 1]):
            y[j] = self.node_scale[l][j] * y[j] + self.node_bias[l][j]

        x = y
        symbolic_acts.append(x)

    if output_normalizer is not None:
        output_layer = symbolic_acts[-1]
        means = output_normalizer[0]
        stds = output_normalizer[1]

        assert len(output_layer) == len(means), 'output_normalizer does not match the output layer'
        assert len(output_layer) == len(stds), 'output_normalizer does not match the output layer'

        output_layer = [(output_layer[i] * stds[i] + means[i]) for i in range(len(output_layer))]
        symbolic_acts[-1] = output_layer

    # store shallow copies so later edits of the returned lists do not alter the cache
    self.symbolic_acts = [list(layer) for layer in symbolic_acts]
    self.symbolic_acts_premult = [list(layer) for layer in symbolic_acts_premult]

    return list(symbolic_acts[-1]), x0
|
2629
|
+
|
2630
|
+
def expand_depth(self):
    '''
    Expand network depth by appending an identity layer at the end. For usage, please refer to tutorials interp_3_KAN_compiler.ipynb.

    The appended layer maps the current last node layer (``self.width_in[-1]`` nodes)
    to the same number of nodes:

    * its spline part (``KANLayer``) is fully masked out;
    * its symbolic part has every mask entry enabled, with 'x' on the diagonal and
      '0' off the diagonal;
    * the new node/subnode scales are 1 and the biases are 0.

    Returns:
    --------
        None
    '''
    self.depth += 1

    # add kanlayer, set mask to zero
    dim_out = self.width_in[-1]
    layer = KANLayer(dim_out, dim_out, num=self.grid, k=self.k)
    layer.mask *= 0.
    self.act_fun.append(layer)

    # the new node layer only has addition nodes
    self.width.append([dim_out, 0])
    # an int mult_arity is shared by every layer, so there is no per-layer list to extend
    if not isinstance(self.mult_arity, int):
        self.mult_arity.append([])

    # add symbolic_kanlayer set mask to one. fun = identity on diagonal and zero for off-diagonal
    layer = Symbolic_KANLayer(dim_out, dim_out)
    layer.mask += 1.

    for j in range(dim_out):
        for i in range(dim_out):
            layer.fix_symbolic(i, j, 'x' if i == j else '0')

    self.symbolic_fun.append(layer)

    self.node_bias.append(
        torch.nn.Parameter(torch.zeros(dim_out, device=self.device)).requires_grad_(self.affine_trainable))
    self.node_scale.append(
        torch.nn.Parameter(torch.ones(dim_out, device=self.device)).requires_grad_(self.affine_trainable))
    self.subnode_bias.append(
        torch.nn.Parameter(torch.zeros(dim_out, device=self.device)).requires_grad_(self.affine_trainable))
    self.subnode_scale.append(
        torch.nn.Parameter(torch.ones(dim_out, device=self.device)).requires_grad_(self.affine_trainable))
|
2677
|
+
|
2678
|
+
def expand_width(self, layer_id, n_added_nodes, sum_bool=True, mult_arity=2):
    '''
    Expand network width. For usage, please refer to tutorials interp_3_KAN_compiler.ipynb.

    Adds ``n_added_nodes`` nodes to hidden node layer ``layer_id``. This rebuilds
    layer ``layer_id - 1`` (new outputs) and layer ``layer_id`` (new inputs):

    * the symbolic layers keep every existing edge and get '0' on the new ones;
    * the spline layers (``act_fun``) are replaced by fresh, fully masked ``KANLayer``s.

    Addition nodes are inserted in front of the existing nodes. Multiplication
    nodes (and their subnodes) are appended at the end.

    Args:
    -----
        layer_id : int
            index of a hidden node layer, 1 <= layer_id <= len(self.width) - 2
        n_added_nodes : int
            the number of added nodes
        sum_bool : bool
            if sum_bool == True, added nodes are addition nodes; otherwise multiplication nodes
        mult_arity : int or list of int
            multiplication arity (the number of numbers to be multiplied); an int
            applies to every added node, a list gives one arity per added node

    Returns:
    --------
        None

    Raises:
    -------
        ValueError
            if ``layer_id`` is not a hidden node layer (checked before anything is modified)
    '''
    n_node_layers = len(self.width)
    # layer_id - 1 and layer_id must both index existing layers. Without this check,
    # layer_id=0 silently wraps to the last layer, and the output layer fails only after
    # layer layer_id - 1 has already been rebuilt.
    if not 1 <= layer_id <= n_node_layers - 2:
        raise ValueError(f'layer_id must be a hidden node layer in [1, {n_node_layers - 2}], got {layer_id}')

    if not sum_bool and isinstance(mult_arity, int):
        mult_arity = [mult_arity] * n_added_nodes

    def _new_symbolic_layer(n_in, n_out):
        # Fresh Symbolic_KANLayer with every edge fixed to '0' and its mask fully enabled.
        new = Symbolic_KANLayer(n_in, n_out)
        for j in range(n_out):
            for i in range(n_in):
                new.fix_symbolic(i, j, '0')
        new.mask += 1.
        return new

    def _copy_edge(new, old, j_new, i_new, j_old, i_old):
        # Copy one symbolic edge (function handles, name, affine parameters) from `old` into `new`.
        new.funs[j_new][i_new] = old.funs[j_old][i_old]
        new.funs_avoid_singularity[j_new][i_new] = old.funs_avoid_singularity[j_old][i_old]
        new.funs_sympy[j_new][i_new] = old.funs_sympy[j_old][i_old]
        new.funs_name[j_new][i_new] = old.funs_name[j_old][i_old]
        new.affine.data[j_new][i_new] = old.affine.data[j_old][i_old]

    def _expand(l, added_dim):
        # Grow layer l on its output side ('out') or input side ('in').
        old = self.symbolic_fun[l]
        in_dim = old.in_dim
        out_dim = old.out_dim

        if sum_bool:
            if added_dim == 'out':
                new = _new_symbolic_layer(in_dim, out_dim + n_added_nodes)
                # new addition nodes come first: old output j becomes j + n_added_nodes
                for j in range(out_dim):
                    for i in range(in_dim):
                        _copy_edge(new, old, j + n_added_nodes, i, j, i)

                self.symbolic_fun[l] = new
                self.act_fun[l] = KANLayer(in_dim, out_dim + n_added_nodes, num=self.grid, k=self.k)
                self.act_fun[l].mask *= 0.

                self.node_scale[l].data = torch.cat(
                    [torch.ones(n_added_nodes, device=self.device), self.node_scale[l].data])
                self.node_bias[l].data = torch.cat(
                    [torch.zeros(n_added_nodes, device=self.device), self.node_bias[l].data])
                self.subnode_scale[l].data = torch.cat(
                    [torch.ones(n_added_nodes, device=self.device), self.subnode_scale[l].data])
                self.subnode_bias[l].data = torch.cat(
                    [torch.zeros(n_added_nodes, device=self.device), self.subnode_bias[l].data])

            if added_dim == 'in':
                new = _new_symbolic_layer(in_dim + n_added_nodes, out_dim)
                # new addition nodes come first: old input i becomes i + n_added_nodes
                for j in range(out_dim):
                    for i in range(in_dim):
                        _copy_edge(new, old, j, i + n_added_nodes, j, i)

                self.symbolic_fun[l] = new
                self.act_fun[l] = KANLayer(in_dim + n_added_nodes, out_dim, num=self.grid, k=self.k)
                self.act_fun[l].mask *= 0.

        else:
            if added_dim == 'out':
                # each multiplication node needs `arity` subnodes on the output side
                n_added_subnodes = int(sum(mult_arity))
                new = _new_symbolic_layer(in_dim, out_dim + n_added_subnodes)
                # new subnodes are appended: existing outputs keep their index
                for j in range(out_dim):
                    for i in range(in_dim):
                        _copy_edge(new, old, j, i, j, i)

                self.symbolic_fun[l] = new
                self.act_fun[l] = KANLayer(in_dim, out_dim + n_added_subnodes, num=self.grid, k=self.k)
                self.act_fun[l].mask *= 0.

                self.node_scale[l].data = torch.cat(
                    [self.node_scale[l].data, torch.ones(n_added_nodes, device=self.device)])
                self.node_bias[l].data = torch.cat(
                    [self.node_bias[l].data, torch.zeros(n_added_nodes, device=self.device)])
                self.subnode_scale[l].data = torch.cat(
                    [self.subnode_scale[l].data, torch.ones(n_added_subnodes, device=self.device)])
                self.subnode_bias[l].data = torch.cat(
                    [self.subnode_bias[l].data, torch.zeros(n_added_subnodes, device=self.device)])

            if added_dim == 'in':
                new = _new_symbolic_layer(in_dim + n_added_nodes, out_dim)
                # new multiplication nodes are appended: existing inputs keep their index
                for j in range(out_dim):
                    for i in range(in_dim):
                        _copy_edge(new, old, j, i, j, i)

                self.symbolic_fun[l] = new
                self.act_fun[l] = KANLayer(in_dim + n_added_nodes, out_dim, num=self.grid, k=self.k)
                self.act_fun[l].mask *= 0.

    _expand(layer_id - 1, added_dim='out')
    _expand(layer_id, added_dim='in')
    if sum_bool:
        self.width[layer_id][0] += n_added_nodes
    else:
        self.width[layer_id][1] += n_added_nodes
        self.mult_arity[layer_id] += mult_arity
|
2835
|
+
|
2836
|
+
def perturb(self, mag=1.0, mode='non-intrusive'):
    """
    Perturb a network. For usage, please refer to tutorials interp_3_KAN_compiler.ipynb.

    Edge classification:

    * Each edge (l, i, j) is classified by whether its input node i and its output
      node j touch any non-'0' symbolic function ('a' = active, 'i' = inactive).
      This gives the edge types 'aa', 'ai', 'ia', 'ii' (input state first).
    * An 'aa' edge is further split into 'aa_a' / 'aa_i' depending on whether the
      edge itself is non-'0'.

    ``mode`` selects which edge types get their spline mask set to ``mag``. In
    'non-intrusive' mode the last layer is not perturbed this way. Instead, every
    last-layer spline edge gets mask 1 with zero base and spline scales.

    Args:
    -----
        mag : float
            perturbation magnitude
        mode : str
            perturbation mode, choices = {'non-intrusive', 'all', 'minimal'}

    Returns:
    --------
        None
    """
    modes = {
        'all': {'aa_a': True, 'aa_i': True, 'ai': True, 'ia': True, 'ii': True},
        'non-intrusive': {'aa_a': False, 'aa_i': False, 'ai': True, 'ia': False, 'ii': True},
        'minimal': {'aa_a': True, 'aa_i': False, 'ai': False, 'ia': False, 'ii': False},
    }
    if mode not in modes:
        raise Exception('mode not recognized, valid modes are \'all\', \'non-intrusive\', \'minimal\'.')
    perturb_bool = modes[mode]

    state = {True: 'a', False: 'i'}
    last = self.depth - 1
    for l in range(self.depth):
        names = self.symbolic_fun[l].funs_name  # indexed [out][in]
        # node activity only depends on this layer's names: compute it once per layer
        # instead of rebuilding the name matrix for every edge
        out_active = [any(name != "0" for name in names[j]) for j in range(self.width_out[l + 1])]
        in_active = [any(row[i] != "0" for row in names) for i in range(self.width_in[l])]

        for j in range(self.width_out[l + 1]):
            for i in range(self.width_in[l]):
                if l < last or mode != 'non-intrusive':
                    edge_type = state[in_active[i]] + state[out_active[j]]
                    if edge_type == 'aa':
                        edge_type += '_i' if names[j][i] == '0' else '_a'

                    if perturb_bool[edge_type]:
                        self.act_fun[l].mask.data[i][j] = mag

                if l == last and mode == 'non-intrusive':
                    self.act_fun[l].mask.data[i][j] = torch.tensor(1.)
                    self.act_fun[l].scale_base.data[i][j] = torch.tensor(0.)
                    self.act_fun[l].scale_sp.data[i][j] = torch.tensor(0.)

    self.get_act(self.cache_data)

    self.log_history('perturb')
|
2904
|
+
|
2905
|
+
def module(self, start_layer, chain):
    """
    Specify network modules by masking out edges that cross the module boundary.

    ``chain`` is a '->'-separated list of node groups such as
    '[-1]->[-1,-2]->[-1]->[-1]'. Groups are consumed in pairs (input nodes,
    output nodes), one pair per layer, starting at ``start_layer``. For each
    layer, every edge between a module node and a non-module node is masked
    (both the spline and the symbolic mask). Negative node ids count from the
    end, like Python indices.

    Args:
    -----
        start_layer : int
            the earliest layer of the module
        chain : str
            specify neurons in the module

    Returns:
    --------
        None

    Raises:
    -------
        IndexError
            if a node id is out of range for its layer
    """

    def _parse(group, dim):
        # '[a, b, ...]' -> list of node ids in [0, dim); negative ids count from the end
        ids = []
        for token in group.strip()[1:-1].split(','):
            idx = int(token)
            if not -dim <= idx < dim:
                raise IndexError(f'node id {idx} out of range for a layer with {dim} nodes')
            ids.append(idx % dim)
        return ids

    groups = chain.split('->')
    n_total_layers = len(groups) // 2

    for l in range(n_total_layers):
        cl = start_layer + l
        in_dim = self.width_in[cl]
        out_dim = self.width_out[cl + 1]
        id_in = _parse(groups[2 * l], in_dim)
        id_out = _parse(groups[2 * l + 1], out_dim)

        id_in_other = list(set(range(in_dim)) - set(id_in))
        id_out_other = list(set(range(out_dim)) - set(id_out))
        # act_fun masks are indexed [in, out]; symbolic_fun masks [out, in]
        self.act_fun[cl].mask.data[np.ix_(id_in_other, id_out)] = 0.
        self.act_fun[cl].mask.data[np.ix_(id_in, id_out_other)] = 0.
        self.symbolic_fun[cl].mask.data[np.ix_(id_out, id_in_other)] = 0.
        self.symbolic_fun[cl].mask.data[np.ix_(id_out_other, id_in)] = 0.

    self.log_history('module')
|
2940
|
+
|
2941
|
+
def tree(self, x=None, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False):
    """
    Turn KAN into a tree.

    Delegates to ``plot_tree`` with the given options. When ``x`` is not given,
    the cached input data (``self.cache_data``) is used.
    """
    # `is None` rather than `== None`: `==` on an array input is elementwise and ambiguous
    if x is None:
        x = self.cache_data
    plot_tree(self, x, in_var=in_var, style=style, sym_th=sym_th, sep_th=sep_th, skip_sep_test=skip_sep_test,
              verbose=verbose)
|
2949
|
+
|
2950
|
+
def speed(self, compile=False):
    """
    Turn on KAN's speed mode.

    Switches off symbolic evaluation, activation saving and auto-saving. Returns
    ``torch.compile(self)`` when ``compile`` is True, otherwise the model itself.
    """
    for flag in ('symbolic_enabled', 'save_act', 'auto_save'):
        setattr(self, flag, False)
    return torch.compile(self) if compile == True else self
|
2961
|
+
|
2962
|
+
def get_act(self, x=None):
    """
    Collect intermediate activations.

    Runs one forward pass with ``self.save_act`` temporarily set to True, then
    restores the previous setting (also when ``forward`` raises).

    Args:
    -----
        x : tensor, dict or None
            input data; a dict is read through its 'train_input' key, and None
            falls back to ``self.cache_data``

    Raises:
    -------
        Exception
            if ``x`` is None and there is no cached data
    """
    if isinstance(x, dict):
        x = x['train_input']
    if x is None:
        # `is None`: `!= None` is elementwise (and ambiguous) for array data
        if self.cache_data is None:
            raise Exception("missing input data x")
        x = self.cache_data
    save_act = self.save_act
    self.save_act = True
    try:
        self.forward(x)
    finally:
        self.save_act = save_act
|
2977
|
+
|
2978
|
+
def get_fun(self, l, i, j):
    """
    Get function (l, i, j): the learned activation from input node i to output node j of layer l.

    Reads the cached spline pre-activations (inputs) and post-activations
    (outputs) for that edge, sorts the samples by input value, plots them in a
    new 3x3 figure, and returns the sorted (inputs, outputs) numpy arrays.
    """
    pre = self.spline_preacts[l][:, j, i].cpu().detach().numpy()
    post = self.spline_postacts[l][:, j, i].cpu().detach().numpy()
    # cached samples are not ordered; sort by input so the curve is drawn left to right
    order = np.argsort(pre)
    xs, ys = pre[order], post[order]
    plt.figure(figsize=(3, 3))
    plt.plot(xs, ys, marker="o")
    return xs, ys
|
2991
|
+
|
2992
|
+
def history(self, k='all'):
    """
    Print the checkpoint history log.

    Args:
    -----
        k : int or 'all'
            print only the last ``k`` lines of ``<ckpt_path>/history.txt``;
            'all' prints every line, and k <= 0 prints nothing

    Returns:
    --------
        None
    """
    with open(os.path.join(self.ckpt_path, 'history.txt'), 'r') as f:
        data = f.readlines()
    if k == 'all':
        k = len(data)
    # data[-0:] would be the whole list, so k <= 0 must be handled explicitly
    if k <= 0:
        return
    for line in data[-k:]:
        # rstrip instead of [:-1]: the last line may not end with a newline
        print(line.rstrip('\n'))
|
3005
|
+
|
3006
|
+
@property
def n_edge(self):
    """
    The number of active edges: spline-mask entries greater than zero, summed over all layers.

    Always returns a Python int (0 when there are no layers).
    """
    return int(sum(torch.sum(layer.mask > 0.).item() for layer in self.act_fun))
|
3016
|
+
|
3017
|
+
def evaluate(self, dataset):
    """
    Evaluate the model on the test split.

    Returns a dict with the test RMSE ('test_loss'), the number of active edges
    ('n_edge') and the grid size ('n_grid').
    """
    residual = self.forward(dataset['test_input']) - dataset['test_label']
    rmse = torch.sqrt(torch.mean(residual ** 2)).item()
    # add other metrics (maybe accuracy)
    return {'test_loss': rmse, 'n_edge': self.n_edge, 'n_grid': self.grid}
|
3023
|
+
|
3024
|
+
def swap(self, l, i1, i2, log_history=True):
    """
    Swap nodes i1 and i2 of node layer l.

    Swaps the corresponding outputs of layer l-1 and inputs of layer l (spline
    and symbolic parts). It also swaps entries i1 and i2 of the node/subnode
    scale and bias parameters of layer l-1.

    Args:
    -----
        l : int
            node layer index
        i1, i2 : int
            indices of the nodes to swap
        log_history : bool
            whether to log the 'swap' operation
    """
    self.act_fun[l - 1].swap(i1, i2, mode='out')
    self.symbolic_fun[l - 1].swap(i1, i2, mode='out')
    self.act_fun[l].swap(i1, i2, mode='in')
    self.symbolic_fun[l].swap(i1, i2, mode='in')

    def swap_(data, i1, i2):
        # `data[i1], data[i2] = data[i2], data[i1]` is wrong for tensors: basic indexing
        # returns views, so the second assignment reads the already-overwritten value.
        # Advanced indexing copies the right-hand side first.
        idx = [int(i1), int(i2)]
        data[idx] = data[idx[::-1]]

    swap_(self.node_scale[l - 1].data, i1, i2)
    swap_(self.node_bias[l - 1].data, i1, i2)
    swap_(self.subnode_scale[l - 1].data, i1, i2)
    swap_(self.subnode_bias[l - 1].data, i1, i2)

    if log_history:
        self.log_history('swap')
|
3041
|
+
|
3042
|
+
@property
def connection_cost(self):
    """
    Connection cost: the sum over layers of |position_in - position_out| weighted by ``self.edge_scores``.

    The nodes of a layer with n nodes are placed at the centres of n equal bins on [0, 1].
    """

    def centres(n):
        # left bin edges k/n shifted by half a bin
        return torch.linspace(0, 1, steps=n + 1, device=self.device)[:n] + 1 / (2 * n)

    total = 0.
    for scores in self.edge_scores:
        pos_in = centres(scores.shape[0])
        pos_out = centres(scores.shape[1])
        total += torch.sum(torch.abs(pos_in[:, None] - pos_out[None, :]) * scores)
    return total
|
3060
|
+
|
3061
|
+
def auto_swap_l(self, l):
    """
    Greedily reorder the nodes of node layer l to reduce the connection cost.

    For each node i, every swap partner j in the same layer is tried: the
    activations and attributions are refreshed, ``connection_cost`` is measured,
    and the swap is undone. The swap with the lowest cost is then applied.
    """
    # nodes of node layer l are the inputs of act_fun[l] (the ones swap(l, ...) permutes)
    num = self.width_in[l]
    for i in range(num):
        ccs = []
        for j in range(num):
            self.swap(l, i, j, log_history=False)
            self.get_act()
            self.attribute()
            ccs.append(self.connection_cost.detach().clone())
            self.swap(l, i, j, log_history=False)  # undo the trial swap
        # plain int so swap() receives an ordinary index, not a tensor
        best = int(torch.argmin(torch.stack(ccs)))
        self.swap(l, i, best, log_history=False)
|
3075
|
+
|
3076
|
+
def auto_swap(self):
    """
    Automatically swap neurons such that connection costs are minimized.

    Runs ``auto_swap_l`` on node layers 1 .. depth-1 in order, then logs 'auto_swap'.
    """
    for layer_idx in range(1, self.depth):
        self.auto_swap_l(layer_idx)
    self.log_history('auto_swap')
|
3085
|
+
|
3086
|
+
|
3087
|
+
# Alias: `KAN` refers to the same class object as `MultKAN`.
KAN = MultKAN
|