yms-kan 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yms_kan/KANLayer.py +364 -0
- yms_kan/LBFGS.py +492 -0
- yms_kan/MLP.py +361 -0
- yms_kan/MultKAN.py +3085 -0
- yms_kan/Symbolic_KANLayer.py +270 -0
- yms_kan/__init__.py +4 -0
- yms_kan/compiler.py +498 -0
- yms_kan/experiment.py +50 -0
- yms_kan/feynman.py +739 -0
- yms_kan/hypothesis.py +695 -0
- yms_kan/spline.py +144 -0
- yms_kan/tool.py +304 -0
- yms_kan/train_eval_utils.py +175 -0
- yms_kan/utils.py +661 -0
- yms_kan/version.py +1 -0
- yms_kan-0.0.1.dist-info/METADATA +11 -0
- yms_kan-0.0.1.dist-info/RECORD +20 -0
- yms_kan-0.0.1.dist-info/WHEEL +5 -0
- yms_kan-0.0.1.dist-info/licenses/LICENSE +21 -0
- yms_kan-0.0.1.dist-info/top_level.txt +1 -0
yms_kan/KANLayer.py
ADDED
@@ -0,0 +1,364 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import numpy as np
|
4
|
+
from .spline import *
|
5
|
+
from .utils import sparse_mask
|
6
|
+
|
7
|
+
|
8
|
+
class KANLayer(nn.Module):
    """
    KANLayer class

    A single Kolmogorov-Arnold layer. Every (input i, output j) pair owns an
    activation
        phi_ij(x) = mask_ij * (scale_base_ij * b(x) + scale_sp_ij * spline_ij(x)),
    and output j is the sum over i of phi_ij(x_i).

    Attributes:
    -----------
        in_dim: int
            input dimension
        out_dim: int
            output dimension
        num: int
            the number of grid intervals
        k: int
            the piecewise polynomial order of splines
        grid: 2D torch.tensor
            non-trainable spline knots, shape (in_dim, num + 1 + 2*k): the num + 1
            interior knots plus k extended knots on each side
        coef: 3D torch.tensor
            coefficients of B-spline bases, indexed [input, output, basis]
        scale_base: 2D torch.tensor
            magnitude of the residual function b(x), shape (in_dim, out_dim).
            Initialized to (scale_base_mu + scale_base_sigma * U(-1, 1)) / sqrt(in_dim)
            (uniform noise, not Gaussian).
        scale_sp: 2D torch.tensor
            magnitude of the spline function spline(x), shape (in_dim, out_dim)
        base_fun: fun
            residual function b(x)
        mask: 2D torch.tensor
            non-trainable mask of shape (in_dim, out_dim). Setting an element to zero
            sets the corresponding activation to the zero function.
        grid_eps: float in [0,1]
            a hyperparameter used in update_grid_from_samples. When grid_eps = 1, the grid is
            uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples.
            0 < grid_eps < 1 interpolates between the two extremes.
        device: str or torch.device
            the device last requested through __init__ or `to`
    """

    def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.5, scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, base_fun=torch.nn.SiLU(), grid_eps=0.02, grid_range=(-1, 1), sp_trainable=True, sb_trainable=True, save_plot_data=True, device='cpu', sparse_init=False):
        '''
        initialize a KANLayer

        Args:
        -----
            in_dim : int
                input dimension. Default: 3.
            out_dim : int
                output dimension. Default: 2.
            num : int
                the number of grid intervals = G. Default: 5.
            k : int
                the order of piecewise polynomial. Default: 3.
            noise_scale : float
                the scale of noise injected at initialization. The initial spline coefficients are
                obtained (via curve2coef) from values drawn uniformly from
                [-noise_scale/(2*num), noise_scale/(2*num)] at the interior knots. Default: 0.5.
            scale_base_mu : float
                the scale of the residual function b(x) is initialized to
                (scale_base_mu + scale_base_sigma * U(-1, 1)) / sqrt(in_dim). Default: 0.0.
            scale_base_sigma : float
                half-width of the uniform noise in scale_base (see scale_base_mu). Default: 1.0.
            scale_sp : float
                the scale of the spline function spline(x) is initialized to
                scale_sp / sqrt(in_dim), multiplied by the mask. Default: 1.0.
            base_fun : function
                residual function b(x). Default: torch.nn.SiLU() (one instance shared by every
                layer constructed with the default).
            grid_eps : float
                When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned
                using percentiles of samples. 0 < grid_eps < 1 interpolates between the two
                extremes. Default: 0.02.
            grid_range : sequence of two floats
                setting the range of grids. Default: (-1, 1).
            sp_trainable : bool
                If true, scale_sp is trainable. Default: True.
            sb_trainable : bool
                If true, scale_base is trainable. Default: True.
            save_plot_data : bool
                accepted for API compatibility; not used by this layer.
            device : str or torch.device
                device. Default: 'cpu'.
            sparse_init : bool
                if sparse_init = True, the mask is initialized with sparse_mask(in_dim, out_dim)
                instead of all ones. Default: False.

        Returns:
        --------
            self

        Example
        -------
        >>> from yms_kan.KANLayer import *
        >>> model = KANLayer(in_dim=3, out_dim=5)
        >>> (model.in_dim, model.out_dim)
        (3, 5)
        '''
        super(KANLayer, self).__init__()
        # size
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.num = num
        self.k = k

        # Uniform knots on grid_range (one row per input), then k extra knots on each
        # side so the order-k B-spline bases cover the whole range.
        grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[None, :].expand(self.in_dim, num + 1)
        grid = extend_grid(grid, k_extend=k)
        self.grid = torch.nn.Parameter(grid).requires_grad_(False)

        # Small uniform noise at the interior knots, converted into spline coefficients.
        noises = (torch.rand(self.num + 1, self.in_dim, self.out_dim) - 1 / 2) * noise_scale / num
        self.coef = torch.nn.Parameter(curve2coef(self.grid[:, k:-k].permute(1, 0), noises, self.grid, k))

        if sparse_init:
            self.mask = torch.nn.Parameter(sparse_mask(in_dim, out_dim)).requires_grad_(False)
        else:
            self.mask = torch.nn.Parameter(torch.ones(in_dim, out_dim)).requires_grad_(False)

        # (mu + sigma * U(-1, 1)) / sqrt(in_dim) -- the noise term is uniform.
        self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(in_dim) +
                                             scale_base_sigma * (torch.rand(in_dim, out_dim) * 2 - 1) * 1 / np.sqrt(in_dim)).requires_grad_(sb_trainable)
        self.scale_sp = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim) * self.mask).requires_grad_(sp_trainable)  # make scale trainable
        self.base_fun = base_fun

        self.grid_eps = grid_eps

        # Record the requested device even when `to` cannot infer one from the argument.
        self.device = device
        self.to(device)

    def to(self, *args, **kwargs):
        '''
        Move and/or cast the layer exactly like nn.Module.to, and keep self.device in sync.

        Accepts every nn.Module.to call form: to(device), to(dtype), to(tensor),
        to(device, dtype), and keyword forms. self.device is only updated when a device
        (str, torch.device, int, or a tensor's device) is actually supplied, so a pure dtype
        cast no longer overwrites it with a dtype.

        Returns:
        --------
            self
        '''
        super(KANLayer, self).to(*args, **kwargs)
        target = kwargs.get('device', args[0] if args else None)
        if isinstance(target, torch.Tensor):
            target = target.device
        if isinstance(target, (str, torch.device, int)):
            self.device = target
        return self

    def forward(self, x):
        '''
        KANLayer forward given input x

        Args:
        -----
            x : 2D torch.float
                inputs, shape (number of samples, input dimension)

        Returns:
        --------
            y : 2D torch.float
                outputs, shape (number of samples, output dimension)
            preacts : 3D torch.float
                fan out x into activations, shape (number of samples, output dimension, input dimension)
            postacts : 3D torch.float
                the outputs of activation functions with preacts as inputs,
                shape (number of samples, output dimension, input dimension)
            postspline : 3D torch.float
                the outputs of spline functions with preacts as inputs,
                shape (number of samples, output dimension, input dimension)

        Example
        -------
        >>> from yms_kan.KANLayer import *
        >>> model = KANLayer(in_dim=3, out_dim=5)
        >>> x = torch.normal(0,1,size=(100,3))
        >>> y, preacts, postacts, postspline = model(x)
        >>> y.shape, preacts.shape, postacts.shape, postspline.shape
        '''
        batch = x.shape[0]
        preacts = x[:, None, :].clone().expand(batch, self.out_dim, self.in_dim)

        base = self.base_fun(x)  # (batch, in_dim)
        # Spline values per (sample, input, output); broadcast against the (in, out) scales below.
        y = coef2curve(x_eval=x, grid=self.grid, coef=self.coef, k=self.k)

        postspline = y.clone().permute(0, 2, 1)

        y = self.scale_base[None, :, :] * base[:, :, None] + self.scale_sp[None, :, :] * y
        y = self.mask[None, :, :] * y

        postacts = y.clone().permute(0, 2, 1)

        # Sum the activations over the input dimension.
        y = torch.sum(y, dim=1)
        return y, preacts, postacts, postspline

    def update_grid_from_samples(self, x, mode='sample'):
        '''
        update grid from samples

        The new knots blend uniformly spaced knots and sample percentiles (weighted by
        grid_eps). The spline coefficients are then refit so that the layer reproduces,
        on the new grid, the curve it represented on the old grid.

        Args:
        -----
            x : 2D torch.float
                inputs, shape (number of samples, input dimension)
            mode : str
                'sample' refits the coefficients on the sorted samples themselves;
                'grid' refits them on a blended grid with twice as many intervals.
                Any other value behaves like 'sample'.

        Returns:
        --------
            None

        Example
        -------
        >>> model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
        >>> print(model.grid.data)
        >>> x = torch.linspace(-3,3,steps=100)[:,None]
        >>> model.update_grid_from_samples(x)
        >>> print(model.grid.data)
        '''
        batch = x.shape[0]
        x_sorted = torch.sort(x, dim=0)[0]
        num_interval = self.grid.shape[1] - 1 - 2 * self.k

        def get_grid(num_interval):
            # Knots at evenly spaced sample percentiles (the last sample closes the range) ...
            ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
            grid_adaptive = x_sorted[ids, :].permute(1, 0)
            margin = 0.00
            # ... blended with evenly spaced knots over the same span.
            h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin) / num_interval
            grid_uniform = grid_adaptive[:, [0]] - margin + h * torch.arange(num_interval + 1)[None, :].to(x.device)
            return self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive

        grid = get_grid(num_interval)

        if mode == 'grid':
            x_pos = get_grid(2 * num_interval).permute(1, 0)
        else:
            x_pos = x_sorted
        # Evaluate the current curve with the OLD grid before the grid is overwritten.
        y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)

        self.grid.data = extend_grid(grid, k_extend=self.k)
        self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)

    def initialize_grid_from_parent(self, parent, x, mode='sample'):
        '''
        update grid from a parent KANLayer & samples

        The new knots are obtained by piecewise-linear interpolation of the parent's interior
        knots. The coefficients are fitted so that this layer reproduces the parent's curve.

        Args:
        -----
            parent : KANLayer
                a parent KANLayer (whose grid is usually coarser than the current model)
            x : 2D torch.float
                inputs, shape (number of samples, input dimension)
            mode : str
                'sample' fits the coefficients on the sorted samples; 'grid' fits them on
                2 * num_interval + 1 points interpolated from the parent grid. Any other value
                behaves like 'sample'.

        Returns:
        --------
            None

        Example
        -------
        >>> batch = 100
        >>> parent_model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
        >>> print(parent_model.grid.data)
        >>> model = KANLayer(in_dim=1, out_dim=1, num=10, k=3)
        >>> x = torch.normal(0,1,size=(batch, 1))
        >>> model.initialize_grid_from_parent(parent_model, x)
        >>> print(model.grid.data)
        '''
        num_interval = self.grid.shape[1] - 1 - 2 * self.k

        def get_grid(num_interval):
            # Interior (non-extended) knots of the parent, shape (in_dim, parent.num + 1).
            parent_knots = parent.grid[:, parent.k:-parent.k]
            # Auxiliary piecewise-linear (k=1) layer on the default range [-1, 1]. Its
            # coefficients are fitted so its own interior knots map onto the parent knots
            # (scale_base is zero, so its output is the spline alone).
            sp2 = KANLayer(in_dim=1, out_dim=self.in_dim, k=1, num=parent_knots.shape[1] - 1,
                           scale_base_mu=0.0, scale_base_sigma=0.0).to(x.device)
            sp2.coef.data = curve2coef(sp2.grid[:, sp2.k:-sp2.k].permute(1, 0).expand(-1, self.in_dim),
                                       parent_knots.permute(1, 0).unsqueeze(dim=2),
                                       sp2.grid[:, :], k=1).permute(1, 0, 2)
            # Honour the requested number of intervals (the old code always used self.num).
            percentile = torch.linspace(-1, 1, num_interval + 1).to(x.device)
            return sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0)

        grid = get_grid(num_interval)

        if mode == 'grid':
            x_pos = get_grid(2 * num_interval).permute(1, 0)
        else:
            x_pos = torch.sort(x, dim=0)[0]
        y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)

        self.grid.data = extend_grid(grid, k_extend=self.k)
        self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)

    def get_subset(self, in_id, out_id):
        '''
        get a smaller KANLayer from a larger KANLayer (used for pruning)

        The subset keeps this layer's num, k, base_fun, grid_eps, device, and the
        trainability of scale_sp / scale_base.

        Args:
        -----
            in_id : list
                id of selected input neurons
            out_id : list
                id of selected output neurons

        Returns:
        --------
            spb : KANLayer

        Example
        -------
        >>> kanlayer_large = KANLayer(in_dim=10, out_dim=10, num=5, k=3)
        >>> kanlayer_small = kanlayer_large.get_subset([0,9],[1,2,3])
        >>> kanlayer_small.in_dim, kanlayer_small.out_dim
        (2, 3)
        '''
        spb = KANLayer(len(in_id), len(out_id), self.num, self.k, base_fun=self.base_fun,
                       grid_eps=self.grid_eps,
                       sp_trainable=self.scale_sp.requires_grad,
                       sb_trainable=self.scale_base.requires_grad)
        spb.grid.data = self.grid[in_id]
        spb.coef.data = self.coef[in_id][:, out_id]
        spb.scale_base.data = self.scale_base[in_id][:, out_id]
        spb.scale_sp.data = self.scale_sp[in_id][:, out_id]
        spb.mask.data = self.mask[in_id][:, out_id]
        # The copied tensors live wherever this layer's tensors live.
        spb.device = self.device
        return spb

    def swap(self, i1, i2, mode='in'):
        '''
        swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output (if mode == 'out')

        Args:
        -----
            i1 : int
            i2 : int
            mode : str
                mode = 'in' or 'out'

        Returns:
        --------
            None

        Raises:
        -------
            ValueError
                if mode is neither 'in' nor 'out'.

        Example
        -------
        >>> from yms_kan.KANLayer import *
        >>> model = KANLayer(in_dim=2, out_dim=2, num=5, k=3)
        >>> print(model.coef)
        >>> model.swap(0,1,mode='in')
        >>> print(model.coef)
        '''
        if mode not in ('in', 'out'):
            raise ValueError(f"mode must be 'in' or 'out', got {mode!r}")

        with torch.no_grad():
            def swap_(data, i1, i2, mode='in'):
                # Inputs index dim 0, outputs index dim 1 of every per-edge tensor.
                if mode == 'in':
                    data[i1], data[i2] = data[i2].clone(), data[i1].clone()
                elif mode == 'out':
                    data[:, i1], data[:, i2] = data[:, i2].clone(), data[:, i1].clone()

            # The grid is per-input only, so it is swapped only for input swaps.
            if mode == 'in':
                swap_(self.grid.data, i1, i2, mode='in')
            swap_(self.coef.data, i1, i2, mode=mode)
            swap_(self.scale_base.data, i1, i2, mode=mode)
            swap_(self.scale_sp.data, i1, i2, mode=mode)
            swap_(self.mask.data, i1, i2, mode=mode)