yms-kan 0.0.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yms_kan-0.0.7/LICENSE +21 -0
- yms_kan-0.0.7/PKG-INFO +18 -0
- yms_kan-0.0.7/README.md +1 -0
- yms_kan-0.0.7/kan/KANLayer.py +364 -0
- yms_kan-0.0.7/kan/LBFGS.py +492 -0
- yms_kan-0.0.7/kan/MLP.py +361 -0
- yms_kan-0.0.7/kan/MultKAN.py +3087 -0
- yms_kan-0.0.7/kan/Symbolic_KANLayer.py +270 -0
- yms_kan-0.0.7/kan/__init__.py +3 -0
- yms_kan-0.0.7/kan/compiler.py +498 -0
- yms_kan-0.0.7/kan/dataset.py +27 -0
- yms_kan-0.0.7/kan/experiment.py +50 -0
- yms_kan-0.0.7/kan/feynman.py +739 -0
- yms_kan-0.0.7/kan/hypothesis.py +695 -0
- yms_kan-0.0.7/kan/spline.py +144 -0
- yms_kan-0.0.7/kan/utils.py +661 -0
- yms_kan-0.0.7/setup.cfg +4 -0
- yms_kan-0.0.7/setup.py +96 -0
- yms_kan-0.0.7/yms_kan.egg-info/PKG-INFO +18 -0
- yms_kan-0.0.7/yms_kan.egg-info/SOURCES.txt +20 -0
- yms_kan-0.0.7/yms_kan.egg-info/dependency_links.txt +1 -0
- yms_kan-0.0.7/yms_kan.egg-info/top_level.txt +1 -0
@@ -0,0 +1,498 @@
|
|
1
|
+
from sympy import *
|
2
|
+
import sympy
|
3
|
+
import numpy as np
|
4
|
+
from kan.MultKAN import MultKAN
|
5
|
+
import torch
|
6
|
+
|
7
|
+
def next_nontrivial_operation(expr, scale=1, bias=0):
    '''
    remove the affine part of an expression

    Numeric terms are peeled off a top-level ``Add`` and numeric factors off a
    top-level ``Mul``, repeatedly, until the remaining expression is neither.
    The peeled constants are composed with the incoming ``scale``/``bias`` so
    that

        scale_in * expr_in + bias_in == scale_out * expr_out + bias_out

    Args:
    -----
        expr : sympy expression
        scale : float
            multiplicative factor already applied to ``expr`` by the caller
        bias : float
            additive offset already applied to ``expr`` by the caller

    Returns:
    --------
        expr : sympy expression
            the remaining (non-affine) sub-expression
        scale : float
        bias : float

    Example
    -------
    >>> from kan.compiler import *
    >>> from sympy import *
    >>> input_vars = a, b = symbols('a b')
    >>> expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402
    >>> next_nontrivial_operation(expression)
    '''
    is_add = expr.func == Add
    if not (is_add or expr.func == Mul):
        return expr, scale, bias

    numeric_args = [arg for arg in expr.args if arg.is_number]
    if not numeric_args:
        return expr, scale, bias
    symbolic_args = [arg for arg in expr.args if not arg.is_number]

    # sum (Add) or product (Mul) of the numeric arguments
    constant = expr.func(*numeric_args)
    if is_add:
        # scale * (rest + c) + bias == scale * rest + (bias + scale * c)
        bias = bias + scale * constant
    else:
        # scale * (c * rest) + bias == (scale * c) * rest + bias
        scale = scale * constant

    return next_nontrivial_operation(expr.func(*symbolic_args), scale, bias)
|
65
|
+
|
66
|
+
|
67
|
+
def _kan_power_name(alpha):
    '''
    map the exponent of a sympy ``Pow`` to a MultKAN symbolic-function name

    Args:
    -----
        alpha : sympy number
            exponent of the power expression

    Returns:
    --------
        str

    Raises:
    -------
        NotImplementedError
            if alpha is not 1/2, -1/2, 3/2 or an integer value
    '''
    # compare numerically so both Rational and Float exponents are recognised
    value = float(alpha)
    if value == 0.5:
        return 'x^0.5'
    if value == -0.5:
        return '1/x^0.5'
    if value == 1.5:
        return 'x^1.5'
    n = int(value)
    if n != value:
        # int() truncates, so e.g. x**2.5 would otherwise silently become x^2
        raise NotImplementedError(f'power exponent {alpha} is not supported by the compiler')
    if n > 0:
        return 'x' if n == 1 else f'x^{n}'
    return '1/x' if n == -1 else f'1/x^{-n}'


def _kan_fun_name(fun_name, power_exponent=None):
    '''
    translate the function stored on a compiler connection into the name
    passed to MultKAN.fix_symbolic

    Args:
    -----
        fun_name : str or sympy function class
            'x' for the identity, otherwise the ``func`` of a sympy expression
        power_exponent : sympy number or None
            exponent, only used when fun_name is Pow

    Returns:
    --------
        str

    Raises:
    -------
        NotImplementedError
            if fun_name has no MultKAN counterpart
    '''
    if fun_name == 'x':
        return 'x'
    if fun_name == Pow:
        return _kan_power_name(power_exponent)
    # NOTE: sympy's sqrt() normally builds Pow(arg, 1/2), so square roots are
    # usually handled by the Pow branch; the sqrt entry is kept as a fallback
    for sympy_fun, kan_name in (
        (exp, 'exp'),
        (sin, 'sin'),
        (cos, 'cos'),
        (tan, 'tan'),
        (sqrt, 'sqrt'),
        (log, 'log'),
        (tanh, 'tanh'),
        (asin, 'arcsin'),
        (acos, 'arccos'),
        (atan, 'arctan'),
        (atanh, 'arctanh'),
        (sign, 'sgn'),
    ):
        if fun_name == sympy_fun:
            return kan_name
    # fail loudly instead of compiling a wrong activation
    raise NotImplementedError(f'function {fun_name} is not supported by the compiler')


def expr2kan(input_variables, expr, grid=5, k=3, auto_save=False):
    '''
    compile a symbolic formula to a MultKAN

    Args:
    -----
        input_variables : a list of sympy symbols
            every leaf symbol of ``expr`` must appear here; its position
            determines which KAN input it is wired to
        expr : sympy expression
        grid : int
            the number of grid intervals
        k : int
            spline order
        auto_save : bool
            if auto_save = True, models are automatically saved

    Returns:
    --------
        MultKAN

    Raises:
    -------
        ValueError
            if a leaf symbol of ``expr`` is not in ``input_variables``
        NotImplementedError
            if ``expr`` uses a function or power exponent that has no MultKAN
            symbolic counterpart

    Example
    -------
    >>> from kan.compiler import *
    >>> from sympy import *
    >>> input_vars = a, b = symbols('a b')
    >>> expression = exp(sin(pi*a) + b**2)
    >>> model = kanpiler(input_vars, expression)
    >>> x = torch.rand(100,2) * 2 - 1
    >>> model(x)
    >>> model.plot()
    '''
    class Node:
        # A (sum or mult) node of the expression tree. It registers itself in
        # Nodes[depth]; its position there becomes ``index``. ``scale``/``bias``
        # hold the affine wrapper peeled off by next_nontrivial_operation.
        def __init__(self, expr, mult_bool, depth, scale, bias, parent=None, mult_arity=None):
            self.expr = expr
            self.mult_bool = mult_bool
            if self.mult_bool:
                self.mult_arity = mult_arity
            self.depth = depth

            # assumes depth grows by at most one per level
            if len(Nodes) <= depth:
                Nodes.append([])
                index = 0
            else:
                index = len(Nodes[depth])

            Nodes[depth].append(self)

            self.index = index
            if parent is None:
                self.parent_index = None
            else:
                self.parent_index = parent.index
            self.child_index = []

            # update parent's child_index
            if parent is not None:
                parent.child_index.append(self.index)

            self.scale = scale
            self.bias = bias

    class SubNode:
        # An input slot of a Node. It registers itself in SubNodes[depth] and
        # appends its index to the parent Node's child_index.
        def __init__(self, expr, depth, scale, bias, parent=None):
            self.expr = expr
            self.depth = depth

            if len(SubNodes) <= depth:
                SubNodes.append([])
                index = 0
            else:
                index = len(SubNodes[depth])

            SubNodes[depth].append(self)

            self.index = index
            self.parent_index = None  # shape: (2,)
            self.child_index = []  # shape: (n, 2)

            # update parent's child_index
            parent.child_index.append(self.index)

            self.scale = scale
            self.bias = bias

    class Connection:
        # connection = activation function that connects a subnode to a node in the next layer node
        # registered in Connections under (depth, subnode index, node index)
        def __init__(self, affine, fun, fun_name, parent=None, child=None, power_exponent=None):
            self.affine = affine  # [1,0,1,0] # (a,b,c,d)
            self.fun = fun  # y = c*fun(a*x+b)+d
            self.fun_name = fun_name
            self.parent_index = parent.index
            self.depth = parent.depth
            self.child_index = child.index
            self.power_exponent = power_exponent  # if fun == Pow
            Connections[(self.depth, self.parent_index, self.child_index)] = self

    def create_node(expr, parent=None, n_layer=None):
        # Recursively build the tree for ``expr`` below ``parent``. When n_layer
        # is given, each leaf symbol is padded with identity layers until it
        # reaches depth n_layer.
        expr, scale, bias = next_nontrivial_operation(expr)
        if parent is None:
            depth = 0
        else:
            depth = parent.depth

        if expr.func == Mul:
            mult_arity = len(expr.args)
            node = Node(expr, True, depth, scale, bias, parent=parent, mult_arity=mult_arity)
            # one SubNode per factor
            for factor in expr.args:
                expr_i, scale, bias = next_nontrivial_operation(factor)
                subnode = SubNode(expr_i, node.depth+1, scale, bias, parent=node)
                if expr_i.func == Add:
                    for term in expr_i.args:
                        expr_ij, scale, bias = next_nontrivial_operation(term)
                        # expr_ij is impossible to be Add, should be Mul or 1D
                        if expr_ij.func == Mul:
                            # a linear connection into a new (mult) node
                            new_node = create_node(expr_ij, parent=subnode, n_layer=n_layer)
                            Connection([1,0,float(scale),float(bias)], lambda x: x, 'x', parent=subnode, child=new_node)

                        elif expr_ij.func == Symbol:
                            new_node = create_node(expr_ij, parent=subnode, n_layer=n_layer)
                            Connection([1,0,float(scale),float(bias)], lambda x: x, fun_name='x', parent=subnode, child=new_node)

                        else:
                            # 1D function case: new node holds the function's argument
                            new_node = create_node(expr_ij.args[0], parent=subnode, n_layer=n_layer)
                            power_exponent = expr_ij.args[1] if expr_ij.func == Pow else None
                            Connection([1,0,float(scale),float(bias)], expr_ij.func, fun_name=expr_ij.func, parent=subnode, child=new_node, power_exponent=power_exponent)

                elif expr_i.func == Mul:
                    # a linear connection into a new (mult) node
                    new_node = create_node(expr_i, parent=subnode, n_layer=n_layer)
                    Connection([1,0,1,0], lambda x: x, fun_name='x', parent=subnode, child=new_node)

                elif expr_i.func == Symbol:
                    new_node = create_node(expr_i, parent=subnode, n_layer=n_layer)
                    Connection([1,0,1,0], lambda x: x, fun_name='x', parent=subnode, child=new_node)

                else:
                    # 1D function case: new node holds the function's argument
                    new_node = create_node(expr_i.args[0], parent=subnode, n_layer=n_layer)
                    power_exponent = expr_i.args[1] if expr_i.func == Pow else None
                    Connection([1,0,1,0], expr_i.func, fun_name=expr_i.func, parent=subnode, child=new_node, power_exponent=power_exponent)

        elif expr.func == Add:
            node = Node(expr, False, depth, scale, bias, parent=parent)
            subnode = SubNode(expr, node.depth+1, 1, 0, parent=node)

            for term in expr.args:
                expr_i, scale, bias = next_nontrivial_operation(term)
                if expr_i.func == Mul:
                    # a linear connection into a new (mult) node
                    new_node = create_node(expr_i, parent=subnode, n_layer=n_layer)
                    Connection([1,0,float(scale),float(bias)], lambda x: x, fun_name='x', parent=subnode, child=new_node)

                elif expr_i.func == Symbol:
                    new_node = create_node(expr_i, parent=subnode, n_layer=n_layer)
                    Connection([1,0,float(scale),float(bias)], lambda x: x, fun_name='x', parent=subnode, child=new_node)

                else:
                    # 1D function case: new node holds the function's argument
                    new_node = create_node(expr_i.args[0], parent=subnode, n_layer=n_layer)
                    power_exponent = expr_i.args[1] if expr_i.func == Pow else None
                    Connection([1,0,float(scale),float(bias)], expr_i.func, fun_name=expr_i.func, parent=subnode, child=new_node, power_exponent=power_exponent)

        elif expr.func == Symbol:
            # expr is a symbol (one of input variables)
            node = Node(expr, False, depth, scale, bias, parent=parent)
            if n_layer is not None:
                # chain identity sub-layers so the leaf copy ends at depth n_layer
                leaf = node
                for _ in range(n_layer - depth):
                    subnode = SubNode(expr, leaf.depth+1, 1, 0, parent=leaf)
                    leaf = Node(expr, False, subnode.depth, 1, 0, parent=subnode)
                    Connection([1,0,1,0], lambda x: x, fun_name='x', parent=subnode, child=leaf)

            Start_Nodes.append(node)

        else:
            # expr is a 1D function f(arg) (or Pow(arg, exponent))
            node = Node(expr, False, depth, scale, bias, parent=parent)
            expr_i, _, _ = next_nontrivial_operation(expr.args[0])
            subnode = SubNode(expr_i, node.depth+1, 1, 0, parent=node)
            new_node = create_node(expr.args[0], parent=subnode, n_layer=n_layer)
            power_exponent = expr.args[1] if expr.func == Pow else None
            Connection([1,0,1,0], expr.func, fun_name=expr.func, parent=subnode, child=new_node, power_exponent=power_exponent)

        return node

    # First pass: discover the tree depth. Second pass: rebuild the tree with
    # every input symbol padded down to that depth. The nested classes read
    # these names through the closure, so rebinding them resets the registries.
    Nodes = [[]]
    SubNodes = [[]]
    Connections = {}
    Start_Nodes = []

    create_node(expr, n_layer=None)

    n_layer = len(Nodes) - 1

    Nodes = [[]]
    SubNodes = [[]]
    Connections = {}
    Start_Nodes = []

    create_node(expr, n_layer=n_layer)

    # move affine parameters in leaf nodes to connections
    for node in Start_Nodes:
        c = Connections[(node.depth, node.parent_index, node.index)]
        c.affine[0] = float(node.scale)
        c.affine[1] = float(node.bias)
        node.scale = 1.
        node.bias = 0.

    # position of every leaf symbol in input_variables
    node2var = []
    for node in Start_Nodes:
        matches = [i for i, var in enumerate(input_variables) if node.expr == var]
        if not matches:
            raise ValueError(f'symbol {node.expr} does not appear in input_variables')
        node2var.extend(matches)

    # number of mult and sum nodes per tree depth
    n_mult = []
    n_sum = []
    for layer in Nodes:
        n_mult.append(sum(1 for node in layer if node.mult_bool))
        n_sum.append(sum(1 for node in layer if not node.mult_bool))

    # depth
    n_layer = len(Nodes) - 1

    # converter
    # node: depth (node.depth -> n_layer - node.depth)
    #       width (node.index -> kan_node_id; sum nodes first, then mult nodes)
    # subnode: depth (subnode.depth -> n_layer - subnode.depth)
    #          width (subnode.index -> kan_subnode_id)
    subnode_index_convert = {}
    node_index_convert = {}
    connection_index_convert = {}
    mult_arities = []
    for layer_id in range(n_layer+1):
        mult_arity = []
        i_sum = 0
        i_mult = 0
        for node in Nodes[layer_id]:
            if node.mult_bool:
                kan_node_id = n_sum[layer_id] + i_mult
                arity = len(node.child_index)
                for a in range(arity):
                    subnode = SubNodes[node.depth+1][node.child_index[a]]
                    kan_subnode_id = n_sum[layer_id] + sum(mult_arity) + a
                    subnode_index_convert[(subnode.depth, subnode.index)] = (int(n_layer-subnode.depth), int(kan_subnode_id))
                i_mult += 1
                mult_arity.append(arity)
            else:
                kan_node_id = i_sum
                if len(node.child_index) > 0:
                    subnode = SubNodes[node.depth+1][node.child_index[0]]
                    kan_subnode_id = i_sum
                    subnode_index_convert[(subnode.depth, subnode.index)] = (int(n_layer-subnode.depth), int(kan_subnode_id))
                i_sum += 1

            if layer_id == n_layer:
                # input layer
                node_index_convert[(node.depth, node.index)] = (int(n_layer-node.depth), int(node2var[kan_node_id]))
            else:
                node_index_convert[(node.depth, node.index)] = (int(n_layer-node.depth), int(kan_node_id))

        mult_arities.append(mult_arity)

    for depth, subnode_id, node_id in Connections:
        _, kan_subnode_id = subnode_index_convert[(depth, subnode_id)]
        _, kan_node_id = node_index_convert[(depth, node_id)]
        connection_index_convert[(depth, subnode_id, node_id)] = (n_layer-depth, kan_subnode_id, kan_node_id)

    # tree depth 0 is the output; KAN layer 0 is the input
    n_sum.reverse()
    n_mult.reverse()
    mult_arities.reverse()

    width = [[n_sum[i], n_mult[i]] for i in range(len(n_sum))]
    width[0][0] = len(input_variables)

    # allow pass in other parameters (probably as a dictionary) in sf2kan, including grid k etc.
    model = MultKAN(width=width, mult_arity=mult_arities, grid=grid, k=k, auto_save=auto_save)

    # clean the graph: every activation starts as the zero function
    for l in range(model.depth):
        for i in range(model.width_in[l]):
            for j in range(model.width_out[l+1]):
                model.fix_symbolic(l, i, j, '0', fit_params_bool=False)

    # Nodes: copy affine parameters (the input layer has none)
    for node in [x for xs in Nodes for x in xs]:
        kan_node_depth, kan_node_index = node_index_convert[(node.depth, node.index)]
        if kan_node_depth > 0:
            model.node_scale[kan_node_depth-1].data[kan_node_index] = float(node.scale)
            model.node_bias[kan_node_depth-1].data[kan_node_index] = float(node.bias)

    # SubNodes: copy affine parameters
    for subnode in [x for xs in SubNodes for x in xs]:
        kan_subnode_depth, kan_subnode_index = subnode_index_convert[(subnode.depth, subnode.index)]
        model.subnode_scale[kan_subnode_depth].data[kan_subnode_index] = float(subnode.scale)
        model.subnode_bias[kan_subnode_depth].data[kan_subnode_index] = float(subnode.bias)

    # Connections: fix each activation to its symbolic function and affine params
    for connection in Connections.values():
        kc_depth, kc_j, kc_i = connection_index_convert[(connection.depth, connection.parent_index, connection.child_index)]
        kfun_name = _kan_fun_name(connection.fun_name, connection.power_exponent)
        model.fix_symbolic(kc_depth, kc_i, kc_j, kfun_name, fit_params_bool=False)
        model.symbolic_fun[kc_depth].affine.data.reshape(model.width_out[kc_depth+1], model.width_in[kc_depth], 4)[kc_j][kc_i] = torch.tensor(connection.affine)

    return model
|
496
|
+
|
497
|
+
|
498
|
+
# public aliases of expr2kan kept for callers using the older names
kanpiler = expr2kan
sf2kan = expr2kan
|
@@ -0,0 +1,27 @@
|
|
1
|
+
import scipy.io as sio
|
2
|
+
import torch
|
3
|
+
from torch.utils.data import Dataset, DataLoader
|
4
|
+
|
5
|
+
|
6
|
+
class KanDataset(Dataset):
    '''
    map-style dataset loaded from a MATLAB ``.mat`` file

    The file must contain a ``features`` array indexed by sample along its
    first axis and a ``labels`` array; both are stored as float64 tensors.
    '''

    def __init__(self, data_path):
        self.path = data_path
        self.data = sio.loadmat(data_path)
        self.features = torch.tensor(self.data['features']).double()
        # squeeze drops MATLAB's singleton axis ((n, 1) or (1, n) -> (n,));
        # atleast_1d keeps a single-sample file at shape (1,) instead of a 0-d
        # tensor, which len() would reject
        self.labels = torch.atleast_1d(torch.tensor(self.data['labels'].squeeze())).double()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]
|
20
|
+
|
21
|
+
|
22
|
+
if __name__ == '__main__':
    # Manual smoke test: load a .mat file and iterate it in mini-batches.
    # The path can be passed as the first command-line argument; without one
    # the original hard-coded local path is used.
    import sys

    data_path = sys.argv[1] if len(sys.argv) > 1 else r'D:\Code\2-ZSL\Zero-Shot-Learning\data\新建文件夹\1\val.mat'
    dataset = KanDataset(data_path)
    print(len(dataset))
    data_loader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)
    for data, label in data_loader:
        print(label.shape)
|
@@ -0,0 +1,50 @@
|
|
1
|
+
import torch
|
2
|
+
from .MultKAN import *
|
3
|
+
|
4
|
+
|
5
|
+
def runner1(width, dataset, grids=(5, 10, 20), steps=20, lamb=0.001, prune_round=3, refine_round=3, edge_th=1e-2,
            node_th=1e-2, metrics=None, seed=1):
    '''
    train / prune / refine experiment, recording results after every stage

    Each of ``prune_round`` rounds trains with regularisation ``lamb`` and
    prunes. The first round builds a fresh KAN; later rounds rewind to the
    checkpoint ``f'{i-1}.{2*i}'``. Each round then refines ``refine_round``
    times, to ``grids[0]``, ``grids[1]``, ..., re-fitting after each refine.

    Args:
    -----
        width : list
            KAN width passed to ``KAN``
        dataset : dict
            dataset passed to ``fit``/``evaluate`` and to every metric
        grids : sequence of int
            grid sizes; ``grids[0]`` is the initial grid, and ``grids[j]`` is
            used for the j-th refine
        metrics : list of callables or None
            each is called as ``metric(model, dataset)`` and must return an
            object with ``.item()``; results are stored under ``metric.__name__``

    Returns:
    --------
        dict of numpy arrays with keys 'test_loss', 'c' (edge count),
        'G' (grid size), 'id' (checkpoint id) plus one key per metric

    Raises:
    -------
        ValueError
            if ``refine_round`` exceeds ``len(grids)``; checked before any training
    '''
    if prune_round > 0 and refine_round > len(grids):
        raise ValueError(f'refine_round={refine_round} needs at least that many grids, got {len(grids)}')

    result = {'test_loss': [], 'c': [], 'G': [], 'id': []}
    if metrics is not None:
        for metric in metrics:
            result[metric.__name__] = []

    def collect(evaluation):
        # reads ``model`` from the enclosing scope at call time
        result['test_loss'].append(evaluation['test_loss'])
        result['c'].append(evaluation['n_edge'])
        result['G'].append(evaluation['n_grid'])
        result['id'].append(f'{model.round}.{model.state_id}')
        if metrics is not None:
            for metric in metrics:
                result[metric.__name__].append(metric(model, dataset).item())

    for i in range(prune_round):
        # train and prune
        if i == 0:
            model = KAN(width=width, grid=grids[0], seed=seed)
        else:
            model = model.rewind(f'{i - 1}.{2 * i}')

        model.fit(dataset, steps=steps, lamb=lamb)
        model = model.prune(edge_th=edge_th, node_th=node_th)
        collect(model.evaluate(dataset))

        for j in range(refine_round):
            model = model.refine(grids[j])
            model.fit(dataset, steps=steps)
            collect(model.evaluate(dataset))

    return {key: np.array(values) for key, values in result.items()}
|
43
|
+
|
44
|
+
|
45
|
+
def pareto_frontier(x, y):
    '''
    Pareto-optimal points when minimising both ``x`` and ``y``

    A point is kept unless some other point is no worse in both coordinates
    and strictly better in at least one, so exact duplicates of a frontier
    point are all kept.

    Args:
    -----
        x, y : 1D array-like of equal length

    Returns:
    --------
        x_pf, y_pf : numpy arrays with the frontier coordinates
        pf_id : numpy array of frontier indices, ascending
    '''
    x = np.asarray(x)
    y = np.asarray(y)
    # [i, j]: point i is <= point j in both coordinates
    weakly_better = (x[:, None] <= x[None, :]) & (y[:, None] <= y[None, :])
    # [i, j]: point i is < point j in at least one coordinate
    strictly_somewhere = (x[:, None] < x[None, :]) | (y[:, None] < y[None, :])
    dominated = np.any(weakly_better & strictly_somewhere, axis=0)
    pf_id = np.where(~dominated)[0]
    return x[pf_id], y[pf_id], pf_id
|