yms-kan 0.0.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
yms_kan-0.0.7/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Ziming Liu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
yms_kan-0.0.7/PKG-INFO ADDED
@@ -0,0 +1,18 @@
1
+ Metadata-Version: 2.4
2
+ Name: yms_kan
3
+ Version: 0.0.7
4
+ Summary: works
5
+ Author: yms
6
+ Author-email: 226000@qq.com
7
+ Requires-Python: >=3.6
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE
10
+ Dynamic: author
11
+ Dynamic: author-email
12
+ Dynamic: description
13
+ Dynamic: description-content-type
14
+ Dynamic: license-file
15
+ Dynamic: requires-python
16
+ Dynamic: summary
17
+
18
+ 0.0.6
@@ -0,0 +1 @@
1
+ 0.0.7
@@ -0,0 +1,364 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from .spline import *
5
+ from .utils import sparse_mask
6
+
7
+
8
class KANLayer(nn.Module):
    """
    KANLayer class: one layer of a Kolmogorov-Arnold Network holding an
    (in_dim x out_dim) array of activation functions

        phi_ij(x) = mask_ij * (scale_base_ij * b(x) + scale_sp_ij * spline_ij(x)),

    whose outputs are summed over the input dimension.

    Attributes:
    -----------
        in_dim: int
            input dimension
        out_dim: int
            output dimension
        num: int
            the number of grid intervals
        k: int
            the piecewise polynomial order of splines
        grid: torch.nn.Parameter (not trainable)
            spline knots, one row per input dimension; the num+1 interior knots
            are padded with k extra knots on each side (interior = grid[:, k:-k])
        coef: torch.nn.Parameter
            coefficients of B-spline bases, indexed as coef[in_id, out_id, ...]
        scale_base: torch.nn.Parameter, shape (in_dim, out_dim)
            magnitude of the residual function b(x); initialized uniformly in
            [(scale_base_mu - scale_base_sigma) / sqrt(in_dim), (scale_base_mu + scale_base_sigma) / sqrt(in_dim)]
        scale_sp: torch.nn.Parameter, shape (in_dim, out_dim)
            magnitude of the spline function spline(x); initialized to scale_sp / sqrt(in_dim) * mask
        base_fun: fun
            residual function b(x)
        mask: torch.nn.Parameter (not trainable), shape (in_dim, out_dim)
            mask of activation functions. Setting an element of the mask to zero sets the
            corresponding activation to the zero function.
        grid_eps: float in [0,1]
            a hyperparameter used in update_grid_from_samples. When grid_eps = 1, the grid is uniform;
            when grid_eps = 0, the grid is partitioned using percentiles of samples.
            0 < grid_eps < 1 interpolates between the two extremes.
        device: str or torch.device
            the device most recently passed to `to` (or to the constructor)
    """

    def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.5, scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, base_fun=torch.nn.SiLU(), grid_eps=0.02, grid_range=(-1, 1), sp_trainable=True, sb_trainable=True, save_plot_data=True, device='cpu', sparse_init=False):
        '''
        initialize a KANLayer

        Args:
        -----
            in_dim : int
                input dimension. Default: 3.
            out_dim : int
                output dimension. Default: 2.
            num : int
                the number of grid intervals = G. Default: 5.
            k : int
                the order of piecewise polynomial. Default: 3.
            noise_scale : float
                the scale of noise injected at initialization: the spline is fitted to
                uniform noise in [-noise_scale/(2*num), noise_scale/(2*num)]. Default: 0.5.
            scale_base_mu : float
                centre of the uniform initialization of scale_base (before the 1/sqrt(in_dim) factor). Default: 0.0.
            scale_base_sigma : float
                half-width of the uniform initialization of scale_base (before the 1/sqrt(in_dim) factor). Default: 1.0.
            scale_sp : float
                the initial scale of the spline function spline(x) (before the 1/sqrt(in_dim) factor). Default: 1.0.
            base_fun : function
                residual function b(x). Default: torch.nn.SiLU()
            grid_eps : float
                When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using
                percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes. Default: 0.02.
            grid_range : list/tuple/np.array of shape (2,)
                setting the range of grids. Default: (-1, 1).
            sp_trainable : bool
                If true, scale_sp is trainable
            sb_trainable : bool
                If true, scale_base is trainable
            save_plot_data : bool
                currently unused by this class
            device : str
                device
            sparse_init : bool
                if sparse_init = True, the mask is initialized with sparse_mask(in_dim, out_dim).

        Returns:
        --------
            self

        Example
        -------
        >>> from kan.KANLayer import *
        >>> model = KANLayer(in_dim=3, out_dim=5)
        >>> (model.in_dim, model.out_dim)
        (3, 5)
        '''
        super(KANLayer, self).__init__()
        # size
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.num = num
        self.k = k

        # num+1 evenly spaced knots per input dimension, then padded with k knots on each side
        grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[None, :].expand(self.in_dim, num + 1)
        grid = extend_grid(grid, k_extend=k)
        self.grid = torch.nn.Parameter(grid).requires_grad_(False)

        # start the splines close to zero: fit them to small uniform noise at the interior knots
        noises = (torch.rand(self.num + 1, self.in_dim, self.out_dim) - 1 / 2) * noise_scale / num
        self.coef = torch.nn.Parameter(curve2coef(self.grid[:, k:-k].permute(1, 0), noises, self.grid, k))

        if sparse_init:
            self.mask = torch.nn.Parameter(sparse_mask(in_dim, out_dim)).requires_grad_(False)
        else:
            self.mask = torch.nn.Parameter(torch.ones(in_dim, out_dim)).requires_grad_(False)

        # uniform in [(mu - sigma)/sqrt(in_dim), (mu + sigma)/sqrt(in_dim)]
        self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(in_dim) + \
                         scale_base_sigma * (torch.rand(in_dim, out_dim) * 2 - 1) * 1 / np.sqrt(in_dim)).requires_grad_(sb_trainable)
        # masked-out activations start with a zero spline scale
        self.scale_sp = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim) * self.mask).requires_grad_(sp_trainable)
        self.base_fun = base_fun

        self.grid_eps = grid_eps

        # record the requested device even if `to` cannot extract one from it
        self.device = device
        self.to(device)

    def to(self, *args, **kwargs):
        '''
        Move/cast the layer like nn.Module.to and remember the target device in self.device.

        Accepts every call form of nn.Module.to (device, dtype, tensor, keyword arguments).
        self.device is only updated when a device is actually given, so dtype-only calls
        such as `layer.to(torch.float64)` no longer overwrite it.

        Returns:
        --------
            self
        '''
        super(KANLayer, self).to(*args, **kwargs)
        device = kwargs.get('device')
        if device is None and args:
            first = args[0]
            if isinstance(first, torch.Tensor):
                device = first.device
            elif not isinstance(first, torch.dtype):
                # a device given positionally (str, torch.device, or index)
                device = first
        if device is not None:
            self.device = device
        return self

    def forward(self, x):
        '''
        KANLayer forward given input x

        Args:
        -----
            x : 2D torch.float
                inputs, shape (number of samples, input dimension)

        Returns:
        --------
            y : 2D torch.float
                outputs, shape (number of samples, output dimension)
            preacts : 3D torch.float
                fan out x into activations, shape (number of samples, output dimension, input dimension)
            postacts : 3D torch.float
                the outputs of activation functions with preacts as inputs
            postspline : 3D torch.float
                the outputs of spline functions with preacts as inputs

        Example
        -------
        >>> from kan.KANLayer import *
        >>> model = KANLayer(in_dim=3, out_dim=5)
        >>> x = torch.normal(0,1,size=(100,3))
        >>> y, preacts, postacts, postspline = model(x)
        >>> y.shape, preacts.shape, postacts.shape, postspline.shape
        '''
        batch = x.shape[0]
        preacts = x[:, None, :].clone().expand(batch, self.out_dim, self.in_dim)

        base = self.base_fun(x)  # (batch, in_dim)
        # spline values; broadcast below against (1, in_dim, out_dim) scales
        y = coef2curve(x_eval=x, grid=self.grid, coef=self.coef, k=self.k)

        postspline = y.clone().permute(0, 2, 1)

        y = self.scale_base[None, :, :] * base[:, :, None] + self.scale_sp[None, :, :] * y
        y = self.mask[None, :, :] * y

        postacts = y.clone().permute(0, 2, 1)

        # sum the activations over the input dimension
        y = torch.sum(y, dim=1)
        return y, preacts, postacts, postspline

    def update_grid_from_samples(self, x, mode='sample'):
        '''
        update grid from samples, then refit coef so the layer's spline outputs are preserved

        Args:
        -----
            x : 2D torch.float
                inputs, shape (number of samples, input dimension)
            mode : str
                'sample': refit coef at the sorted samples x.
                'grid': refit coef at 2*num_interval+1 points built the same way as the new grid.

        Returns:
        --------
            None

        Example
        -------
        >>> model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
        >>> print(model.grid.data)
        >>> x = torch.linspace(-3,3,steps=100)[:,None]
        >>> model.update_grid_from_samples(x)
        >>> print(model.grid.data)
        '''
        batch = x.shape[0]
        x_pos = torch.sort(x, dim=0)[0]
        # current spline outputs: the fitting target for the new coefficients
        y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)
        num_interval = self.grid.shape[1] - 1 - 2 * self.k

        def get_grid(num_interval):
            # num_interval+1 order statistics per input dim (first and last sample always included)
            ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
            grid_adaptive = x_pos[ids, :].permute(1, 0)
            margin = 0.00
            h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin) / num_interval
            grid_uniform = grid_adaptive[:, [0]] - margin + h * torch.arange(num_interval + 1, )[None, :].to(x.device)
            # blend uniform and percentile grids
            grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
            return grid

        grid = get_grid(num_interval)

        if mode == 'grid':
            sample_grid = get_grid(2 * num_interval)
            x_pos = sample_grid.permute(1, 0)
            y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)

        self.grid.data = extend_grid(grid, k_extend=self.k)
        self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)

    def initialize_grid_from_parent(self, parent, x, mode='sample'):
        '''
        update grid from a parent KANLayer & samples

        The new interior knots are obtained by piecewise-linearly interpolating the
        parent's interior knots; coef is then fitted to reproduce the parent's spline outputs.

        Args:
        -----
            parent : KANLayer
                a parent KANLayer (whose grid is usually coarser than the current model)
            x : 2D torch.float
                inputs, shape (number of samples, input dimension)
            mode : str
                'sample': fit coef at the sorted samples x.
                'grid': fit coef at 2*num_interval+1 points interpolated from the parent grid.

        Returns:
        --------
            None

        Example
        -------
        >>> batch = 100
        >>> parent_model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
        >>> print(parent_model.grid.data)
        >>> model = KANLayer(in_dim=1, out_dim=1, num=10, k=3)
        >>> x = torch.normal(0,1,size=(batch, 1))
        >>> model.initialize_grid_from_parent(parent_model, x)
        >>> print(model.grid.data)
        '''
        x_pos = torch.sort(x, dim=0)[0]
        # parent spline outputs: the fitting target for this layer's coefficients
        y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)
        num_interval = self.grid.shape[1] - 1 - 2 * self.k

        def get_grid(num_interval):
            # returns num_interval+1 knots per input dim, shape (in_dim, num_interval+1)
            parent_knots = parent.grid[:, parent.k:-parent.k]
            # piecewise-linear (k=1) map from [-1, 1] onto the parent's interior knots
            sp2 = KANLayer(in_dim=1, out_dim=self.in_dim, k=1, num=parent_knots.shape[1] - 1, scale_base_mu=0.0, scale_base_sigma=0.0).to(x.device)
            sp2_coef = curve2coef(sp2.grid[:, sp2.k:-sp2.k].permute(1, 0).expand(-1, self.in_dim), parent_knots.permute(1, 0).unsqueeze(dim=2), sp2.grid[:, :], k=1).permute(1, 0, 2)
            sp2.coef.data = sp2_coef
            # previously hard-coded to self.num+1, which made mode='grid' ignore its finer resolution
            percentile = torch.linspace(-1, 1, num_interval + 1).to(x.device)
            grid = sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0)
            return grid

        grid = get_grid(num_interval)

        if mode == 'grid':
            sample_grid = get_grid(2 * num_interval)
            x_pos = sample_grid.permute(1, 0)
            y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)

        self.grid.data = extend_grid(grid, k_extend=self.k)
        self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)

    def get_subset(self, in_id, out_id):
        '''
        get a smaller KANLayer from a larger KANLayer (used for pruning)

        The subset keeps this layer's grid_eps, device and the trainability
        (requires_grad) of coef, scale_base and scale_sp.

        Args:
        -----
            in_id : list
                id of selected input neurons
            out_id : list
                id of selected output neurons

        Returns:
        --------
            spb : KANLayer

        Example
        -------
        >>> kanlayer_large = KANLayer(in_dim=10, out_dim=10, num=5, k=3)
        >>> kanlayer_small = kanlayer_large.get_subset([0,9],[1,2,3])
        >>> kanlayer_small.in_dim, kanlayer_small.out_dim
        (2, 3)
        '''
        spb = KANLayer(len(in_id), len(out_id), self.num, self.k, base_fun=self.base_fun,
                       grid_eps=self.grid_eps, device=self.device)
        spb.grid.data = self.grid[in_id]
        spb.coef.data = self.coef[in_id][:, out_id]
        spb.scale_base.data = self.scale_base[in_id][:, out_id]
        spb.scale_sp.data = self.scale_sp[in_id][:, out_id]
        spb.mask.data = self.mask[in_id][:, out_id]

        # the constructor's defaults would otherwise make frozen parameters trainable again
        spb.coef.requires_grad_(self.coef.requires_grad)
        spb.scale_base.requires_grad_(self.scale_base.requires_grad)
        spb.scale_sp.requires_grad_(self.scale_sp.requires_grad)

        spb.in_dim = len(in_id)
        spb.out_dim = len(out_id)
        return spb

    def swap(self, i1, i2, mode='in'):
        '''
        swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output (if mode == 'out')

        Args:
        -----
            i1 : int
            i2 : int
            mode : str
                mode = 'in' or 'out'

        Returns:
        --------
            None

        Example
        -------
        >>> from kan.KANLayer import *
        >>> model = KANLayer(in_dim=2, out_dim=2, num=5, k=3)
        >>> print(model.coef)
        >>> model.swap(0,1,mode='in')
        >>> print(model.coef)
        '''
        with torch.no_grad():
            def swap_(data, i1, i2, mode='in'):
                # 'in' swaps along dim 0, 'out' along dim 1
                if mode == 'in':
                    data[i1], data[i2] = data[i2].clone(), data[i1].clone()
                elif mode == 'out':
                    data[:, i1], data[:, i2] = data[:, i2].clone(), data[:, i1].clone()

            # the grid is per input dimension only, so it moves only for input swaps
            if mode == 'in':
                swap_(self.grid.data, i1, i2, mode='in')
            swap_(self.coef.data, i1, i2, mode=mode)
            swap_(self.scale_base.data, i1, i2, mode=mode)
            swap_(self.scale_sp.data, i1, i2, mode=mode)
            swap_(self.mask.data, i1, i2, mode=mode)
364
+