yms-kan 0.0.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yms_kan-0.0.7/LICENSE +21 -0
- yms_kan-0.0.7/PKG-INFO +18 -0
- yms_kan-0.0.7/README.md +1 -0
- yms_kan-0.0.7/kan/KANLayer.py +364 -0
- yms_kan-0.0.7/kan/LBFGS.py +492 -0
- yms_kan-0.0.7/kan/MLP.py +361 -0
- yms_kan-0.0.7/kan/MultKAN.py +3087 -0
- yms_kan-0.0.7/kan/Symbolic_KANLayer.py +270 -0
- yms_kan-0.0.7/kan/__init__.py +3 -0
- yms_kan-0.0.7/kan/compiler.py +498 -0
- yms_kan-0.0.7/kan/dataset.py +27 -0
- yms_kan-0.0.7/kan/experiment.py +50 -0
- yms_kan-0.0.7/kan/feynman.py +739 -0
- yms_kan-0.0.7/kan/hypothesis.py +695 -0
- yms_kan-0.0.7/kan/spline.py +144 -0
- yms_kan-0.0.7/kan/utils.py +661 -0
- yms_kan-0.0.7/setup.cfg +4 -0
- yms_kan-0.0.7/setup.py +96 -0
- yms_kan-0.0.7/yms_kan.egg-info/PKG-INFO +18 -0
- yms_kan-0.0.7/yms_kan.egg-info/SOURCES.txt +20 -0
- yms_kan-0.0.7/yms_kan.egg-info/dependency_links.txt +1 -0
- yms_kan-0.0.7/yms_kan.egg-info/top_level.txt +1 -0
yms_kan-0.0.7/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
1
|
+
MIT License
|
2
|
+
|
3
|
+
Copyright (c) 2024 Ziming Liu
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
13
|
+
copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
+
SOFTWARE.
|
yms_kan-0.0.7/PKG-INFO
ADDED
@@ -0,0 +1,18 @@
|
|
1
|
+
Metadata-Version: 2.4
|
2
|
+
Name: yms_kan
|
3
|
+
Version: 0.0.7
|
4
|
+
Summary: works
|
5
|
+
Author: yms
|
6
|
+
Author-email: 226000@qq.com
|
7
|
+
Requires-Python: >=3.6
|
8
|
+
Description-Content-Type: text/markdown
|
9
|
+
License-File: LICENSE
|
10
|
+
Dynamic: author
|
11
|
+
Dynamic: author-email
|
12
|
+
Dynamic: description
|
13
|
+
Dynamic: description-content-type
|
14
|
+
Dynamic: license-file
|
15
|
+
Dynamic: requires-python
|
16
|
+
Dynamic: summary
|
17
|
+
|
18
|
+
0.0.6
|
yms_kan-0.0.7/README.md
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
0.0.7
|
@@ -0,0 +1,364 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import numpy as np
|
4
|
+
from .spline import *
|
5
|
+
from .utils import sparse_mask
|
6
|
+
|
7
|
+
|
8
|
+
class KANLayer(nn.Module):
    """
    KANLayer class

    A layer that owns one learnable 1D activation per (input, output) pair.
    Each activation is ``scale_base * base_fun(x) + scale_sp * spline(x)``;
    the layer output for an output neuron is the sum of its activations over
    all inputs (see ``forward``).

    Attributes:
    -----------
        in_dim: int
            input dimension
        out_dim: int
            output dimension
        num: int
            the number of grid intervals
        k: int
            the piecewise polynomial order of splines
        grid: 2D torch.nn.Parameter (not trainable)
            spline knots, one row per input dimension; the rows are built
            from ``num + 1`` uniform knots and then extended by
            ``extend_grid(..., k_extend=k)``
        coef: torch.nn.Parameter
            coefficients of B-spline bases, as returned by ``curve2coef``;
            the first two axes are indexed by (input, output)
        scale_base: 2D torch.nn.Parameter, shape (in_dim, out_dim)
            magnitude of the residual function b(x)
        scale_sp: 2D torch.nn.Parameter, shape (in_dim, out_dim)
            magnitude of the spline function spline(x)
        base_fun: fun
            residual function b(x)
        mask: 2D torch.nn.Parameter (not trainable), shape (in_dim, out_dim)
            mask of activations. setting some element of the mask to zero means setting the corresponding activation to zero function.
        grid_eps: float in [0,1]
            a hyperparameter used in update_grid_from_samples. When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
        device: str
            device, as last passed to ``to`` (or to the constructor)
    """

    def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.5, scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, base_fun=torch.nn.SiLU(), grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, save_plot_data = True, device='cpu', sparse_init=False):
        '''
        initialize a KANLayer

        Args:
        -----
            in_dim : int
                input dimension. Default: 3.
            out_dim : int
                output dimension. Default: 2.
            num : int
                the number of grid intervals = G. Default: 5.
            k : int
                the order of piecewise polynomial. Default: 3.
            noise_scale : float
                the scale of noise injected at initialization. Default: 0.5.
            scale_base_mu : float
                scale_base is initialized to (scale_base_mu + scale_base_sigma * u) / sqrt(in_dim), with u uniform in [-1, 1). Default: 0.0.
            scale_base_sigma : float
                half-width of the uniform perturbation of scale_base (see scale_base_mu). Default: 1.0.
            scale_sp : float
                the scale of the spline function spline(x); stored as scale_sp / sqrt(in_dim) * mask. Default: 1.0.
            base_fun : function
                residual function b(x). Default: torch.nn.SiLU()
            grid_eps : float
                When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes. Default: 0.02.
            grid_range : list/np.array of shape (2,)
                setting the range of grids. Default: [-1,1].
            sp_trainable : bool
                If true, scale_sp is trainable
            sb_trainable : bool
                If true, scale_base is trainable
            save_plot_data : bool
                accepted for interface compatibility; not used by this layer
            device : str
                device
            sparse_init : bool
                if sparse_init = True, the mask is taken from sparse_mask(in_dim, out_dim) instead of all ones.

        Returns:
        --------
            self

        Example
        -------
        >>> from kan.KANLayer import *
        >>> model = KANLayer(in_dim=3, out_dim=5)
        >>> (model.in_dim, model.out_dim)
        (3, 5)
        '''
        super(KANLayer, self).__init__()
        # size
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.num = num
        self.k = k

        # num + 1 uniform knots on grid_range, replicated for every input
        # dimension, then extended by extend_grid so that grid[:, k:-k] is
        # treated as the interior part below.
        grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[None, :].expand(self.in_dim, num + 1)
        grid = extend_grid(grid, k_extend=k)
        self.grid = torch.nn.Parameter(grid).requires_grad_(False)

        # Start from small random curves sampled at the interior knots and
        # convert them to spline coefficients.
        noises = (torch.rand(self.num + 1, self.in_dim, self.out_dim) - 1 / 2) * noise_scale / num
        self.coef = torch.nn.Parameter(curve2coef(self.grid[:, k:-k].permute(1, 0), noises, self.grid, k))

        if sparse_init:
            self.mask = torch.nn.Parameter(sparse_mask(in_dim, out_dim)).requires_grad_(False)
        else:
            self.mask = torch.nn.Parameter(torch.ones(in_dim, out_dim)).requires_grad_(False)

        self.scale_base = torch.nn.Parameter(
            scale_base_mu * 1 / np.sqrt(in_dim)
            + scale_base_sigma * (torch.rand(in_dim, out_dim) * 2 - 1) * 1 / np.sqrt(in_dim)
        ).requires_grad_(sb_trainable)
        # masked-out activations start with a zero spline scale
        self.scale_sp = torch.nn.Parameter(
            torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim) * self.mask
        ).requires_grad_(sp_trainable)
        self.base_fun = base_fun

        self.grid_eps = grid_eps

        self.to(device)

    def to(self, device):
        '''
        move the layer to a device and remember it in self.device

        Args:
        -----
            device : str or torch.device
                target device

        Returns:
        --------
            self
        '''
        super(KANLayer, self).to(device)
        self.device = device
        return self

    def forward(self, x):
        '''
        KANLayer forward given input x

        Args:
        -----
            x : 2D torch.float
                inputs, shape (number of samples, input dimension)

        Returns:
        --------
            y : 2D torch.float
                outputs, shape (number of samples, output dimension)
            preacts : 3D torch.float
                fan out x into activations, shape (number of samples, output dimension, input dimension)
            postacts : 3D torch.float
                the outputs of activation functions with preacts as inputs
            postspline : 3D torch.float
                the outputs of spline functions with preacts as inputs

        Example
        -------
        >>> from kan.KANLayer import *
        >>> model = KANLayer(in_dim=3, out_dim=5)
        >>> x = torch.normal(0,1,size=(100,3))
        >>> y, preacts, postacts, postspline = model(x)
        >>> y.shape, preacts.shape, postacts.shape, postspline.shape
        '''
        batch = x.shape[0]
        preacts = x[:, None, :].clone().expand(batch, self.out_dim, self.in_dim)

        base = self.base_fun(x)  # (batch, in_dim)
        y = coef2curve(x_eval=x, grid=self.grid, coef=self.coef, k=self.k)

        # keep the raw spline outputs, laid out with the output axis first
        postspline = y.clone().permute(0, 2, 1)

        y = self.scale_base[None, :, :] * base[:, :, None] + self.scale_sp[None, :, :] * y
        y = self.mask[None, :, :] * y

        postacts = y.clone().permute(0, 2, 1)

        # each output neuron sums its activations over the input axis
        y = torch.sum(y, dim=1)
        return y, preacts, postacts, postspline

    def update_grid_from_samples(self, x, mode='sample'):
        '''
        update grid from samples

        The new grid blends sample percentiles with a uniform grid over the
        sample range (weighted by grid_eps). The coefficients are then refit
        so that the spline on the new grid reproduces the values of the
        current spline.

        Args:
        -----
            x : 2D torch.float
                inputs, shape (number of samples, input dimension)
            mode : str
                'sample' refits on the sorted samples themselves; 'grid'
                refits on a grid built like the new one but with twice as
                many intervals.

        Returns:
        --------
            None

        Example
        -------
        >>> model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
        >>> print(model.grid.data)
        >>> x = torch.linspace(-3,3,steps=100)[:,None]
        >>> model.update_grid_from_samples(x)
        >>> print(model.grid.data)
        '''
        batch = x.shape[0]
        x_sorted = torch.sort(x, dim=0)[0]
        num_interval = self.grid.shape[1] - 1 - 2 * self.k

        def get_grid(num_interval):
            # knots at evenly spaced ranks of the sorted samples; the last
            # knot is pinned to the largest sample
            ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
            grid_adaptive = x_sorted[ids, :].permute(1, 0)
            h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]]) / num_interval
            grid_uniform = grid_adaptive[:, [0]] + h * torch.arange(num_interval + 1, device=x.device)[None, :]
            return self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive

        grid = get_grid(num_interval)

        if mode == 'grid':
            x_pos = get_grid(2 * num_interval).permute(1, 0)
        else:
            x_pos = x_sorted

        # must be evaluated with the OLD grid, before it is overwritten below
        y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)

        self.grid.data = extend_grid(grid, k_extend=self.k)
        self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)

    def initialize_grid_from_parent(self, parent, x, mode='sample'):
        '''
        update grid from a parent KANLayer & samples

        The new grid is obtained by interpolating the parent's interior
        knots; the coefficients are fit so that this layer reproduces the
        parent's spline values.

        Args:
        -----
            parent : KANLayer
                a parent KANLayer (whose grid is usually coarser than the current model)
            x : 2D torch.float
                inputs, shape (number of samples, input dimension)
            mode : str
                'sample' fits on the sorted samples; 'grid' fits on an
                interpolated grid with twice as many intervals.

        Returns:
        --------
            None

        Example
        -------
        >>> batch = 100
        >>> parent_model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
        >>> print(parent_model.grid.data)
        >>> model = KANLayer(in_dim=1, out_dim=1, num=10, k=3)
        >>> x = torch.normal(0,1,size=(batch, 1))
        >>> model.initialize_grid_from_parent(parent_model, x)
        >>> print(model.grid.data)
        '''
        x_sorted = torch.sort(x, dim=0)[0]
        num_interval = self.grid.shape[1] - 1 - 2 * self.k

        # based on interpolating parent grid
        def get_grid(num_interval):
            parent_knots = parent.grid[:, parent.k:-parent.k]
            # sp2 is a k=1 layer on the default range [-1, 1] with a zero
            # residual term; its coefficients are fit so that it passes
            # through the parent's interior knots, i.e. it serves as an
            # interpolator of those knots.
            sp2 = KANLayer(in_dim=1, out_dim=self.in_dim, k=1, num=parent_knots.shape[1] - 1, scale_base_mu=0.0, scale_base_sigma=0.0).to(x.device)
            sp2_coef = curve2coef(sp2.grid[:, sp2.k:-sp2.k].permute(1, 0).expand(-1, self.in_dim), parent_knots.permute(1, 0).unsqueeze(dim=2), sp2.grid[:, :], k=1).permute(1, 0, 2)
            sp2.coef.data = sp2_coef
            # honor the requested resolution (previously always self.num + 1,
            # which made the mode='grid' sampling grid no denser than the
            # new grid); keep the query on the same device as sp2
            percentile = torch.linspace(-1, 1, num_interval + 1).to(x.device)
            return sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0)

        grid = get_grid(num_interval)

        if mode == 'grid':
            x_pos = get_grid(2 * num_interval).permute(1, 0)
        else:
            x_pos = x_sorted

        y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)

        grid = extend_grid(grid, k_extend=self.k)
        self.grid.data = grid
        self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)

    def get_subset(self, in_id, out_id):
        '''
        get a smaller KANLayer from a larger KANLayer (used for pruning)

        Args:
        -----
            in_id : list
                id of selected input neurons
            out_id : list
                id of selected output neurons

        Returns:
        --------
            spb : KANLayer

        Example
        -------
        >>> kanlayer_large = KANLayer(in_dim=10, out_dim=10, num=5, k=3)
        >>> kanlayer_small = kanlayer_large.get_subset([0,9],[1,2,3])
        >>> kanlayer_small.in_dim, kanlayer_small.out_dim
        (2, 3)
        '''
        # carry over this layer's settings instead of falling back to the
        # constructor defaults (grid_eps, trainable flags, device)
        spb = KANLayer(
            len(in_id),
            len(out_id),
            self.num,
            self.k,
            base_fun=self.base_fun,
            grid_eps=self.grid_eps,
            sp_trainable=self.scale_sp.requires_grad,
            sb_trainable=self.scale_base.requires_grad,
            device=self.device,
        )
        spb.grid.data = self.grid[in_id]
        spb.coef.data = self.coef[in_id][:, out_id]
        spb.scale_base.data = self.scale_base[in_id][:, out_id]
        spb.scale_sp.data = self.scale_sp[in_id][:, out_id]
        spb.mask.data = self.mask[in_id][:, out_id]

        spb.in_dim = len(in_id)
        spb.out_dim = len(out_id)
        return spb

    def swap(self, i1, i2, mode='in'):
        '''
        swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output (if mode == 'out')

        Args:
        -----
            i1 : int
            i2 : int
            mode : str
                mode = 'in' or 'out'

        Returns:
        --------
            None

        Raises:
        -------
            ValueError
                if mode is neither 'in' nor 'out'

        Example
        -------
        >>> from kan.KANLayer import *
        >>> model = KANLayer(in_dim=2, out_dim=2, num=5, k=3)
        >>> print(model.coef)
        >>> model.swap(0,1,mode='in')
        >>> print(model.coef)
        '''
        if mode not in ('in', 'out'):
            raise ValueError(f"mode must be 'in' or 'out', got {mode!r}")

        # inputs live on axis 0 of every per-activation tensor, outputs on axis 1
        axis = 0 if mode == 'in' else 1

        def swap_(data):
            first = data.select(axis, i1).clone()
            second = data.select(axis, i2).clone()
            data.select(axis, i1).copy_(second)
            data.select(axis, i2).copy_(first)

        with torch.no_grad():
            if mode == 'in':
                # the grid has one row per input and no output axis
                swap_(self.grid.data)
            for param in (self.coef, self.scale_base, self.scale_sp, self.mask):
                swap_(param.data)
|
364
|
+
|