yms-kan 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,270 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import sympy
5
+ from .utils import *
6
+
7
+
8
+
9
class Symbolic_KANLayer(nn.Module):
    '''
    Symbolic KAN layer: every edge (output neuron j, input neuron i) carries a
    symbolic activation c * f(a * x + b) + d, scaled by mask[j, i].

    Attributes:
    -----------
        in_dim : int
            input dimension
        out_dim : int
            output dimension
        funs : 2D list (out_dim rows x in_dim columns) of torch functions (or lambda functions)
            symbolic functions (torch); funs[j][i] maps input i to output j
        funs_avoid_singularity : 2D list (out_dim x in_dim) of callables
            singularity-avoiding variants; each is called as f(x, y_th) and the
            second element of the returned tuple is used as the activation
        funs_name : 2D list (out_dim x in_dim) of str
            names of symbolic functions
        funs_sympy : 2D list (out_dim x in_dim) of sympy functions (or lambda functions)
            symbolic functions (sympy)
        affine : torch.nn.Parameter, shape (out_dim, in_dim, 4)
            affine parameters (a, b, c, d) of c*f(a*x+b)+d for every edge
        mask : torch.nn.Parameter, shape (out_dim, in_dim), requires_grad=False
            per-edge multiplier applied to the activation (all zeros at init)
        device : str
            device the layer was last moved to
    '''
    def __init__(self, in_dim=3, out_dim=2, device='cpu'):
        '''
        initialize a Symbolic_KANLayer (every activation is initialized to the
        zero function, named '0', with zero affine parameters and a zero mask)

        Args:
        -----
            in_dim : int
                input dimension
            out_dim : int
                output dimension
            device : str
                device

        Example
        -------
        >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=3)
        >>> len(sb.funs), len(sb.funs[0])
        (3, 3)
        '''
        super().__init__()
        self.out_dim = out_dim
        self.in_dim = in_dim
        # registered as a Parameter (so it is part of the module's parameters),
        # but with gradients disabled
        self.mask = torch.nn.Parameter(torch.zeros(out_dim, in_dim, device=device)).requires_grad_(False)
        # torch
        self.funs = [[lambda x: x*0. for _ in range(self.in_dim)] for _ in range(self.out_dim)]
        self.funs_avoid_singularity = [[lambda x, y_th: ((), x*0.) for _ in range(self.in_dim)] for _ in range(self.out_dim)]
        # name
        self.funs_name = [['0' for _ in range(self.in_dim)] for _ in range(self.out_dim)]
        # sympy
        self.funs_sympy = [[lambda x: x*0. for _ in range(self.in_dim)] for _ in range(self.out_dim)]
        # TODO (carried over): make funs_name the only source of truth and derive the others from it?

        # affine[j, i] = (a, b, c, d) for c*f(a*x+b)+d
        self.affine = torch.nn.Parameter(torch.zeros(out_dim, in_dim, 4, device=device))

        self.device = device
        self.to(device)

    def to(self, device):
        '''
        move the layer to device and remember it in self.device

        Returns:
        --------
            self
        '''
        super().to(device)
        self.device = device
        return self

    def forward(self, x, singularity_avoiding=False, y_th=10.):
        '''
        forward

        Args:
        -----
            x : 2D torch.Tensor
                inputs, shape (batch, input dimension)
            singularity_avoiding : bool
                if True, funs_avoid_singularity is used; if False, funs is used.
            y_th : float
                the singularity threshold

        Returns:
        --------
            y : 2D torch.Tensor
                outputs, shape (batch, output dimension)
            postacts : 3D torch.Tensor
                activations after activation functions but before being summed on nodes,
                shape (batch, output dimension, input dimension)

        Example
        -------
        >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=5)
        >>> x = torch.normal(0,1,size=(100,3))
        >>> y, postacts = sb(x)
        >>> y.shape, postacts.shape
        (torch.Size([100, 5]), torch.Size([100, 5, 3]))
        '''
        # build the threshold tensor once rather than once per edge
        y_th_tensor = torch.tensor(y_th)
        postacts = []

        for i in range(self.in_dim):
            # column slice keeps a trailing dim: shape (batch, 1)
            x_i = x[:, [i]]
            postacts_ = []
            for j in range(self.out_dim):
                pre = self.affine[j, i, 0] * x_i + self.affine[j, i, 1]
                if singularity_avoiding:
                    f_val = self.funs_avoid_singularity[j][i](pre, y_th_tensor)[1]
                else:
                    f_val = self.funs[j][i](pre)
                xij = self.affine[j, i, 2] * f_val + self.affine[j, i, 3]
                postacts_.append(self.mask[j][i] * xij)
            postacts.append(torch.stack(postacts_))

        # assuming every edge yields shape (batch, 1):
        # (in_dim, out_dim, batch, 1) -> (batch, out_dim, in_dim)
        postacts = torch.stack(postacts).permute(2, 1, 0, 3)[:, :, :, 0]
        y = torch.sum(postacts, dim=2)

        return y, postacts

    def get_subset(self, in_id, out_id):
        '''
        get a smaller Symbolic_KANLayer from a larger Symbolic_KANLayer (used for pruning)

        Args:
        -----
            in_id : list
                id of selected input neurons
            out_id : list
                id of selected output neurons

        Returns:
        --------
            spb : Symbolic_KANLayer

        Example
        -------
        >>> sb_large = Symbolic_KANLayer(in_dim=10, out_dim=10)
        >>> sb_small = sb_large.get_subset([0,9],[1,2,3])
        >>> sb_small.in_dim, sb_small.out_dim
        (2, 3)
        '''
        # build directly at the target size; every per-edge field is overwritten below
        sbb = Symbolic_KANLayer(len(in_id), len(out_id), device=self.device)
        sbb.mask.data = self.mask.data[out_id][:, in_id]
        sbb.funs = [[self.funs[j][i] for i in in_id] for j in out_id]
        sbb.funs_avoid_singularity = [[self.funs_avoid_singularity[j][i] for i in in_id] for j in out_id]
        sbb.funs_sympy = [[self.funs_sympy[j][i] for i in in_id] for j in out_id]
        sbb.funs_name = [[self.funs_name[j][i] for i in in_id] for j in out_id]
        sbb.affine.data = self.affine.data[out_id][:, in_id]
        return sbb

    def _reset_affine(self, j, i, random):
        '''set affine[j, i] to (a, b, c, d) = (1, 0, 1, 0), or to uniform draws in [-1, 1) if random'''
        if random:
            self.affine.data[j][i] = torch.rand(4, device=self.device) * 2 - 1
        else:
            self.affine.data[j][i] = torch.tensor([1., 0., 1., 0.], device=self.device)

    def fix_symbolic(self, i, j, fun_name, x=None, y=None, random=False, a_range=(-10,10), b_range=(-10,10), verbose=True):
        '''
        fix an activation function to be symbolic

        Args:
        -----
            i : int
                the id of input neuron
            j : int
                the id of output neuron
            fun_name : str or callable
                a str is looked up in SYMBOLIC_LIB; a callable is used as both the
                torch and the sympy function and is named "anonymous"
            x : 1D array or None
                preactivations (used together with y to fit a, b, c, d via fit_params)
            y : 1D array or None
                postactivations
            random : bool
                when no fit is performed: if True, draw a, b, c, d uniformly from
                [-1, 1); otherwise set a = 1, b = 0, c = 1, d = 0
            a_range : tuple
                sweeping range of a
            b_range : tuple
                sweeping range of b
            verbose : bool
                print more information if True (forwarded to fit_params)

        Returns:
        --------
            r2 (coefficient of determination) from fit_params when fun_name is a str
            and both x and y are given; None otherwise

        Example 1
        ---------
        >>> # when x & y are not provided. Affine parameters are set to a = 1, b = 0, c = 1, d = 0
        >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2)
        >>> sb.fix_symbolic(2,1,'sin')
        >>> print(sb.funs_name)
        >>> print(sb.affine)

        Example 2
        ---------
        >>> # when x & y are provided, fit_params() is called to find the best fit coefficients
        >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2)
        >>> batch = 100
        >>> x = torch.linspace(-1,1,steps=batch)
        >>> noises = torch.normal(0,1,(batch,)) * 0.02
        >>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises
        >>> sb.fix_symbolic(2,1,'sin',x,y)
        >>> print(sb.funs_name)
        >>> print(sb.affine[1,2,:].data)
        '''
        if isinstance(fun_name, str):
            # SYMBOLIC_LIB entry: [0] torch fn, [1] sympy fn, [3] singularity-avoiding fn
            fun = SYMBOLIC_LIB[fun_name][0]
            fun_sympy = SYMBOLIC_LIB[fun_name][1]
            fun_avoid_singularity = SYMBOLIC_LIB[fun_name][3]
            self.funs_sympy[j][i] = fun_sympy
            self.funs_name[j][i] = fun_name

            # `is None` rather than `== None`: arrays/tensors compare elementwise
            if x is None or y is None:
                # initialize from just fun
                self.funs[j][i] = fun
                self.funs_avoid_singularity[j][i] = fun_avoid_singularity
                self._reset_affine(j, i, random)
                return None

            # initialize from x & y and fun
            params, r2 = fit_params(x, y, fun, a_range=a_range, b_range=b_range, verbose=verbose, device=self.device)
            self.funs[j][i] = fun
            self.funs_avoid_singularity[j][i] = fun_avoid_singularity
            self.affine.data[j][i] = params
            return r2

        # fun_name itself is a function
        fun = fun_name
        self.funs_sympy[j][i] = fun
        self.funs_name[j][i] = "anonymous"
        self.funs[j][i] = fun
        # forward() calls singularity-avoiding functions as f(x, y_th)[1],
        # so wrap the plain 1-argument callable in that convention
        self.funs_avoid_singularity[j][i] = lambda x_, y_th: ((), fun(x_))
        self._reset_affine(j, i, random)
        return None

    def swap(self, i1, i2, mode='in'):
        '''
        swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output (if mode == 'out')

        All per-edge state is swapped together: funs, funs_avoid_singularity,
        funs_name, funs_sympy, affine and mask. Any other mode is a no-op.
        '''
        def swap_list_(data):
            # data is an out_dim x in_dim nested list
            if mode == 'in':
                for row in data:
                    row[i1], row[i2] = row[i2], row[i1]
            elif mode == 'out':
                data[i1], data[i2] = data[i2], data[i1]

        def swap_(data):
            # clone() so the right-hand side is not overwritten mid-swap
            if mode == 'in':
                data[:, i1], data[:, i2] = data[:, i2].clone(), data[:, i1].clone()
            elif mode == 'out':
                data[i1], data[i2] = data[i2].clone(), data[i1].clone()

        with torch.no_grad():
            for per_edge in (self.funs, self.funs_avoid_singularity, self.funs_name, self.funs_sympy):
                swap_list_(per_edge)
            swap_(self.affine.data)
            swap_(self.mask.data)
yms_kan/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .MultKAN import *
2
+ from .utils import *
3
+ # torch.use_deterministic_algorithms(True)
4
+ from .version import __version__