yms_kan-0.0.7.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yms_kan-0.0.7/LICENSE +21 -0
- yms_kan-0.0.7/PKG-INFO +18 -0
- yms_kan-0.0.7/README.md +1 -0
- yms_kan-0.0.7/kan/KANLayer.py +364 -0
- yms_kan-0.0.7/kan/LBFGS.py +492 -0
- yms_kan-0.0.7/kan/MLP.py +361 -0
- yms_kan-0.0.7/kan/MultKAN.py +3087 -0
- yms_kan-0.0.7/kan/Symbolic_KANLayer.py +270 -0
- yms_kan-0.0.7/kan/__init__.py +3 -0
- yms_kan-0.0.7/kan/compiler.py +498 -0
- yms_kan-0.0.7/kan/dataset.py +27 -0
- yms_kan-0.0.7/kan/experiment.py +50 -0
- yms_kan-0.0.7/kan/feynman.py +739 -0
- yms_kan-0.0.7/kan/hypothesis.py +695 -0
- yms_kan-0.0.7/kan/spline.py +144 -0
- yms_kan-0.0.7/kan/utils.py +661 -0
- yms_kan-0.0.7/setup.cfg +4 -0
- yms_kan-0.0.7/setup.py +96 -0
- yms_kan-0.0.7/yms_kan.egg-info/PKG-INFO +18 -0
- yms_kan-0.0.7/yms_kan.egg-info/SOURCES.txt +20 -0
- yms_kan-0.0.7/yms_kan.egg-info/dependency_links.txt +1 -0
- yms_kan-0.0.7/yms_kan.egg-info/top_level.txt +1 -0
yms_kan-0.0.7/kan/Symbolic_KANLayer.py
@@ -0,0 +1,270 @@
import torch
import torch.nn as nn
import numpy as np
import sympy
from .utils import *


class Symbolic_KANLayer(nn.Module):
    '''
    Symbolic_KANLayer class

    Attributes:
    -----------
        in_dim : int
            input dimension
        out_dim : int
            output dimension
        funs : 2D array of torch functions (or lambda functions)
            symbolic functions (torch)
        funs_avoid_singularity : 2D array of torch functions (or lambda functions) with singularity avoidance
            singularity-protected versions of the symbolic functions
        funs_name : 2D array of str
            names of symbolic functions
        funs_sympy : 2D array of sympy functions (or lambda functions)
            symbolic functions (sympy)
        affine : 3D array of floats, shape (out_dim, in_dim, 4)
            affine transformations (a, b, c, d) of inputs and outputs; each edge computes c*f(a*x+b)+d
    '''
    def __init__(self, in_dim=3, out_dim=2, device='cpu'):
        '''
        initialize a Symbolic_KANLayer (activation functions are initialized to be identity functions)

        Args:
        -----
            in_dim : int
                input dimension
            out_dim : int
                output dimension
            device : str
                device

        Returns:
        --------
            self

        Example
        -------
        >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=3)
        >>> len(sb.funs), len(sb.funs[0])
        (3, 3)
        '''
        super(Symbolic_KANLayer, self).__init__()
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.mask = torch.nn.Parameter(torch.zeros(out_dim, in_dim, device=device)).requires_grad_(False)
        # torch functions; every edge starts as the zero function
        self.funs = [[lambda x: x * 0. for i in range(self.in_dim)] for j in range(self.out_dim)]
        self.funs_avoid_singularity = [[lambda x, y_th: ((), x * 0.) for i in range(self.in_dim)] for j in range(self.out_dim)]
        # names of the symbolic functions
        self.funs_name = [['0' for i in range(self.in_dim)] for j in range(self.out_dim)]
        # sympy counterparts
        self.funs_sympy = [[lambda x: x * 0. for i in range(self.in_dim)] for j in range(self.out_dim)]
        ### make funs_name the only parameter, and make the others properties of funs_name?

        # per-edge affine parameters (a, b, c, d): each edge computes c*f(a*x+b)+d
        self.affine = torch.nn.Parameter(torch.zeros(out_dim, in_dim, 4, device=device))

        self.device = device
        self.to(device)
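
    # --- Usage sketch (editor's illustration, not part of the packaged file) ---
    # The affine convention above means edge (j, i) computes c*f(a*x+b)+d with
    # (a, b, c, d) = affine[j, i, :]. A minimal check of that convention,
    # assuming SYMBOLIC_LIB (from .utils) maps 'sin' to torch.sin:
    #
    #   sb = Symbolic_KANLayer(in_dim=1, out_dim=1)
    #   sb.fix_symbolic(0, 0, 'sin')        # sets affine[0, 0] = (1, 0, 1, 0)
    #   sb.mask.data[0, 0] = 1.             # unmask the edge
    #   x = torch.linspace(-1, 1, steps=5)[:, None]
    #   y, _ = sb(x)
    #   assert torch.allclose(y, torch.sin(x))
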
    def to(self, device):
        '''
        move the layer to device
        '''
        super(Symbolic_KANLayer, self).to(device)
        self.device = device
        return self
    def forward(self, x, singularity_avoiding=False, y_th=10.):
        '''
        forward

        Args:
        -----
            x : 2D array
                inputs, shape (batch, input dimension)
            singularity_avoiding : bool
                if True, funs_avoid_singularity is used; if False, funs is used.
            y_th : float
                the singularity threshold

        Returns:
        --------
            y : 2D array
                outputs, shape (batch, output dimension)
            postacts : 3D array
                activations after activation functions but before being summed on nodes

        Example
        -------
        >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=5)
        >>> x = torch.normal(0,1,size=(100,3))
        >>> y, postacts = sb(x)
        >>> y.shape, postacts.shape
        (torch.Size([100, 5]), torch.Size([100, 5, 3]))
        '''

        batch = x.shape[0]
        postacts = []

        for i in range(self.in_dim):
            postacts_ = []
            for j in range(self.out_dim):
                # each edge computes c*f(a*x+b)+d with (a, b, c, d) = affine[j, i]
                if singularity_avoiding:
                    xij = self.affine[j, i, 2] * self.funs_avoid_singularity[j][i](self.affine[j, i, 0] * x[:, [i]] + self.affine[j, i, 1], torch.tensor(y_th))[1] + self.affine[j, i, 3]
                else:
                    xij = self.affine[j, i, 2] * self.funs[j][i](self.affine[j, i, 0] * x[:, [i]] + self.affine[j, i, 1]) + self.affine[j, i, 3]
                postacts_.append(self.mask[j][i] * xij)
            postacts.append(torch.stack(postacts_))

        # (in_dim, out_dim, batch, 1) -> (batch, out_dim, in_dim)
        postacts = torch.stack(postacts)
        postacts = postacts.permute(2, 1, 0, 3)[:, :, :, 0]
        # sum over the input dimension to get node outputs
        y = torch.sum(postacts, dim=2)

        return y, postacts
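
    # --- Usage sketch (editor's illustration, not part of the packaged file) ---
    # postacts keeps the per-edge activations, so y is exactly their sum over
    # the input dimension:
    #
    #   sb = Symbolic_KANLayer(in_dim=3, out_dim=5)
    #   x = torch.normal(0, 1, size=(100, 3))
    #   y, postacts = sb(x)
    #   assert torch.allclose(y, postacts.sum(dim=2))
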
    def get_subset(self, in_id, out_id):
        '''
        get a smaller Symbolic_KANLayer from a larger Symbolic_KANLayer (used for pruning)

        Args:
        -----
            in_id : list
                ids of selected input neurons
            out_id : list
                ids of selected output neurons

        Returns:
        --------
            sbb : Symbolic_KANLayer

        Example
        -------
        >>> sb_large = Symbolic_KANLayer(in_dim=10, out_dim=10)
        >>> sb_small = sb_large.get_subset([0,9],[1,2,3])
        >>> sb_small.in_dim, sb_small.out_dim
        (2, 3)
        '''
        sbb = Symbolic_KANLayer(self.in_dim, self.out_dim, device=self.device)
        sbb.in_dim = len(in_id)
        sbb.out_dim = len(out_id)
        sbb.mask.data = self.mask.data[out_id][:, in_id]
        sbb.funs = [[self.funs[j][i] for i in in_id] for j in out_id]
        sbb.funs_avoid_singularity = [[self.funs_avoid_singularity[j][i] for i in in_id] for j in out_id]
        sbb.funs_sympy = [[self.funs_sympy[j][i] for i in in_id] for j in out_id]
        sbb.funs_name = [[self.funs_name[j][i] for i in in_id] for j in out_id]
        sbb.affine.data = self.affine.data[out_id][:, in_id]
        return sbb
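
    # --- Usage sketch (editor's illustration, not part of the packaged file) ---
    # Edges are independent, so a subset layer's post-activations equal the
    # corresponding slice of the parent layer's post-activations (assuming
    # SYMBOLIC_LIB provides 'sin'):
    #
    #   sb_large = Symbolic_KANLayer(in_dim=10, out_dim=10)
    #   sb_large.fix_symbolic(9, 2, 'sin')
    #   sb_large.mask.data[2, 9] = 1.
    #   sb_small = sb_large.get_subset([0, 9], [1, 2, 3])
    #   x = torch.normal(0, 1, size=(7, 10))
    #   _, pa_large = sb_large(x)
    #   _, pa_small = sb_small(x[:, [0, 9]])
    #   assert torch.allclose(pa_small, pa_large[:, [1, 2, 3]][:, :, [0, 9]])
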
    def fix_symbolic(self, i, j, fun_name, x=None, y=None, random=False, a_range=(-10,10), b_range=(-10,10), verbose=True):
        '''
        fix an activation function to be symbolic

        Args:
        -----
            i : int
                the id of the input neuron
            j : int
                the id of the output neuron
            fun_name : str or function
                the name of the symbolic function (or the function itself)
            x : 1D array
                preactivations
            y : 1D array
                postactivations
            random : bool
                if True, initialize the affine parameters randomly; otherwise to (1, 0, 1, 0)
            a_range : tuple
                sweeping range of a
            b_range : tuple
                sweeping range of b
            verbose : bool
                print more information if True

        Returns:
        --------
            r2 : float
                coefficient of determination (None if x & y are not provided)

        Example 1
        ---------
        >>> # when x & y are not provided, affine parameters are set to a = 1, b = 0, c = 1, d = 0
        >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2)
        >>> sb.fix_symbolic(2,1,'sin')
        >>> print(sb.funs_name)
        >>> print(sb.affine)

        Example 2
        ---------
        >>> # when x & y are provided, fit_params() is called to find the best-fit coefficients
        >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2)
        >>> batch = 100
        >>> x = torch.linspace(-1,1,steps=batch)
        >>> noises = torch.normal(0,1,(batch,)) * 0.02
        >>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises
        >>> sb.fix_symbolic(2,1,'sin',x,y)
        >>> print(sb.funs_name)
        >>> print(sb.affine[1,2,:].data)
        '''
        if isinstance(fun_name, str):
            fun = SYMBOLIC_LIB[fun_name][0]
            fun_sympy = SYMBOLIC_LIB[fun_name][1]
            fun_avoid_singularity = SYMBOLIC_LIB[fun_name][3]
            self.funs_sympy[j][i] = fun_sympy
            self.funs_name[j][i] = fun_name

            if x is None or y is None:
                # initialize from just fun
                self.funs[j][i] = fun
                self.funs_avoid_singularity[j][i] = fun_avoid_singularity
                if not random:
                    self.affine.data[j][i] = torch.tensor([1., 0., 1., 0.], device=self.device)
                else:
                    self.affine.data[j][i] = torch.rand(4, device=self.device) * 2 - 1
                return None
            else:
                # initialize from x & y and fun
                params, r2 = fit_params(x, y, fun, a_range=a_range, b_range=b_range, verbose=verbose, device=self.device)
                self.funs[j][i] = fun
                self.funs_avoid_singularity[j][i] = fun_avoid_singularity
                self.affine.data[j][i] = params
                return r2
        else:
            # fun_name itself is a function
            fun = fun_name
            fun_sympy = fun_name
            self.funs_sympy[j][i] = fun_sympy
            self.funs_name[j][i] = "anonymous"

            self.funs[j][i] = fun
            self.funs_avoid_singularity[j][i] = fun
            if not random:
                self.affine.data[j][i] = torch.tensor([1., 0., 1., 0.], device=self.device)
            else:
                self.affine.data[j][i] = torch.rand(4, device=self.device) * 2 - 1
            return None
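
    # --- Usage sketch (editor's illustration, not part of the packaged file) ---
    # Fitting c*sin(a*x+b)+d to clean data; fit_params (from .utils) sweeps
    # a_range x b_range for the best-fit (a, b, c, d) and returns it together
    # with r2, which should be close to 1 here:
    #
    #   sb = Symbolic_KANLayer(in_dim=3, out_dim=2)
    #   x = torch.linspace(-1, 1, steps=100)
    #   y = 5.0 * torch.sin(3.0 * x + 2.0) + 0.7
    #   r2 = sb.fix_symbolic(2, 1, 'sin', x, y, verbose=False)
    #   print(r2, sb.affine.data[1, 2])
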
    def swap(self, i1, i2, mode='in'):
        '''
        swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output (if mode == 'out')
        '''
        with torch.no_grad():
            def swap_list_(data, i1, i2, mode='in'):
                # swap entries of a nested Python list (funs, funs_name, ...)
                if mode == 'in':
                    for j in range(self.out_dim):
                        data[j][i1], data[j][i2] = data[j][i2], data[j][i1]
                elif mode == 'out':
                    data[i1], data[i2] = data[i2], data[i1]

            def swap_(data, i1, i2, mode='in'):
                # swap slices of a tensor (affine, mask)
                if mode == 'in':
                    data[:, i1], data[:, i2] = data[:, i2].clone(), data[:, i1].clone()
                elif mode == 'out':
                    data[i1], data[i2] = data[i2].clone(), data[i1].clone()

            swap_list_(self.funs, i1, i2, mode)  # bug fix: swap self.funs too, or forward() uses stale functions
            swap_list_(self.funs_name, i1, i2, mode)
            swap_list_(self.funs_sympy, i1, i2, mode)
            swap_list_(self.funs_avoid_singularity, i1, i2, mode)
            swap_(self.affine.data, i1, i2, mode)
            swap_(self.mask.data, i1, i2, mode)
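
# --- Usage sketch (editor's illustration, not part of the packaged file) ---
# Swapping two input neurons should commute with permuting the corresponding
# input columns; a minimal check, relying on the self.funs swap added above
# and on SYMBOLIC_LIB providing 'sin':
#
#   sb = Symbolic_KANLayer(in_dim=3, out_dim=2)
#   sb.fix_symbolic(0, 0, 'sin')
#   sb.mask.data[0, 0] = 1.
#   x = torch.normal(0, 1, size=(10, 3))
#   y_ref, _ = sb(x)
#   sb.swap(0, 2, mode='in')
#   y_new, _ = sb(x[:, [2, 1, 0]])
#   assert torch.allclose(y_ref, y_new)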