yms-kan 0.0.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,498 @@
+ from sympy import *
+ import sympy
+ import numpy as np
+ from kan.MultKAN import MultKAN
+ import torch
+
+ def next_nontrivial_operation(expr, scale=1, bias=0):
+     '''
+     remove the affine part of an expression
+
+     Args:
+     -----
+         expr : sympy expression
+         scale : float
+         bias : float
+
+     Returns:
+     --------
+         expr : sympy expression
+         scale : float
+         bias : float
+
+     Example
+     -------
+     >>> from kan.compiler import *
+     >>> from sympy import *
+     >>> input_vars = a, b = symbols('a b')
+     >>> expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402
+     >>> next_nontrivial_operation(expression)
+     '''
+     if expr.func == Add or expr.func == Mul:
+         n_arg = len(expr.args)
+         n_num = 0
+         n_var_id = []
+         n_num_id = []
+         var_args = []
+         for i in range(n_arg):
+             is_number = expr.args[i].is_number
+             n_num += is_number
+             if not is_number:
+                 n_var_id.append(i)
+                 var_args.append(expr.args[i])
+             else:
+                 n_num_id.append(i)
+         if n_num > 0:
+             # trivial
+             if expr.func == Add:
+                 for i in range(n_num):
+                     if i == 0:
+                         bias = expr.args[n_num_id[i]]
+                     else:
+                         bias += expr.args[n_num_id[i]]
+             if expr.func == Mul:
+                 for i in range(n_num):
+                     if i == 0:
+                         scale = expr.args[n_num_id[i]]
+                     else:
+                         scale *= expr.args[n_num_id[i]]
+
+             return next_nontrivial_operation(expr.func(*var_args), scale, bias)
+         else:
+             return expr, scale, bias
+     else:
+         return expr, scale, bias
+
+
+ def expr2kan(input_variables, expr, grid=5, k=3, auto_save=False):
+     '''
+     compile a symbolic formula to a MultKAN
+
+     Args:
+     -----
+         input_variables : a list of sympy symbols
+         expr : sympy expression
+         grid : int
+             the number of grid intervals
+         k : int
+             spline order
+         auto_save : bool
+             if True, models are automatically saved
+
+     Returns:
+     --------
+         MultKAN
+
+     Example
+     -------
+     >>> from kan.compiler import *
+     >>> from sympy import *
+     >>> import torch
+     >>> input_vars = a, b = symbols('a b')
+     >>> expression = exp(sin(pi*a) + b**2)
+     >>> model = kanpiler(input_vars, expression)
+     >>> x = torch.rand(100,2) * 2 - 1
+     >>> model(x)
+     >>> model.plot()
+     '''
+     class Node:
+         def __init__(self, expr, mult_bool, depth, scale, bias, parent=None, mult_arity=None):
+             self.expr = expr
+             self.mult_bool = mult_bool
+             if self.mult_bool:
+                 self.mult_arity = mult_arity
+             self.depth = depth
+
+             if len(Nodes) <= depth:
+                 Nodes.append([])
+                 index = 0
+             else:
+                 index = len(Nodes[depth])
+
+             Nodes[depth].append(self)
+
+             self.index = index
+             if parent == None:
+                 self.parent_index = None
+             else:
+                 self.parent_index = parent.index
+             self.child_index = []
+
+             # update parent's child_index
+             if parent != None:
+                 parent.child_index.append(self.index)
+
+
+             self.scale = scale
+             self.bias = bias
+
+
+     class SubNode:
+         def __init__(self, expr, depth, scale, bias, parent=None):
+             self.expr = expr
+             self.depth = depth
+
+             if len(SubNodes) <= depth:
+                 SubNodes.append([])
+                 index = 0
+             else:
+                 index = len(SubNodes[depth])
+
+             SubNodes[depth].append(self)
+
+             self.index = index
+             self.parent_index = None # shape: (2,)
+             self.child_index = [] # shape: (n, 2)
+
+             # update parent's child_index
+             parent.child_index.append(self.index)
+
+             self.scale = scale
+             self.bias = bias
+
+
+     class Connection:
+         def __init__(self, affine, fun, fun_name, parent=None, child=None, power_exponent=None):
+             # connection = activation function that connects a subnode to a node in the next layer
+             self.affine = affine #[1,0,1,0] # (a,b,c,d)
+             self.fun = fun # y = c*fun(a*x+b)+d
+             self.fun_name = fun_name
+             self.parent_index = parent.index
+             self.depth = parent.depth
+             self.child_index = child.index
+             self.power_exponent = power_exponent # if fun == Pow
+             Connections[(self.depth,self.parent_index,self.child_index)] = self
+
+     def create_node(expr, parent=None, n_layer=None):
+         #print('before', expr)
+         expr, scale, bias = next_nontrivial_operation(expr)
+         #print('after', expr)
+         if parent == None:
+             depth = 0
+         else:
+             depth = parent.depth
+
+
+         if expr.func == Mul:
+             mult_arity = len(expr.args)
+             node = Node(expr, True, depth, scale, bias, parent=parent, mult_arity=mult_arity)
+             # create mult_arity SubNodes, + 1
+             for i in range(mult_arity):
+                 # create SubNode
+                 expr_i, scale, bias = next_nontrivial_operation(expr.args[i])
+                 subnode = SubNode(expr_i, node.depth+1, scale, bias, parent=node)
+                 if expr_i.func == Add:
+                     for j in range(len(expr_i.args)):
+                         expr_ij, scale, bias = next_nontrivial_operation(expr_i.args[j])
+                         # expr_ij is impossible to be Add, should be Mul or 1D
+                         if expr_ij.func == Mul:
+                             #print(expr_ij)
+                             # create a node with expr_ij
+                             new_node = create_node(expr_ij, parent=subnode, n_layer=n_layer)
+                             # create a connection which is a linear function
+                             c = Connection([1,0,float(scale),float(bias)], lambda x: x, 'x', parent=subnode, child=new_node)
+
+                         elif expr_ij.func == Symbol:
+                             #print(expr_ij)
+                             new_node = create_node(expr_ij, parent=subnode, n_layer=n_layer)
+                             c = Connection([1,0,float(scale),float(bias)], lambda x: x, fun_name = 'x', parent=subnode, child=new_node)
+
+                         else:
+                             # 1D function case
+                             # create a node with expr_ij.args[0]
+                             new_node = create_node(expr_ij.args[0], parent=subnode, n_layer=n_layer)
+                             # create 1D function expr_ij.func
+                             if expr_ij.func == Pow:
+                                 power_exponent = expr_ij.args[1]
+                             else:
+                                 power_exponent = None
+                             Connection([1,0,float(scale),float(bias)], expr_ij.func, fun_name = expr_ij.func, parent=subnode, child=new_node, power_exponent=power_exponent)
+
+
+                 elif expr_i.func == Mul:
+                     # create a node with expr_i
+                     new_node = create_node(expr_i, parent=subnode, n_layer=n_layer)
+                     # create 1D function, linear
+                     Connection([1,0,1,0], lambda x: x, fun_name = 'x', parent=subnode, child=new_node)
+
+                 elif expr_i.func == Symbol:
+                     new_node = create_node(expr_i, parent=subnode, n_layer=n_layer)
+                     Connection([1,0,1,0], lambda x: x, fun_name = 'x', parent=subnode, child=new_node)
+
+                 else:
+                     # 1D functions
+                     # create a node with expr_i.args[0]
+                     new_node = create_node(expr_i.args[0], parent=subnode, n_layer=n_layer)
+                     # create 1D function expr_i.func
+                     if expr_i.func == Pow:
+                         power_exponent = expr_i.args[1]
+                     else:
+                         power_exponent = None
+                     Connection([1,0,1,0], expr_i.func, fun_name = expr_i.func, parent=subnode, child=new_node, power_exponent=power_exponent)
+
+         elif expr.func == Add:
+
+             node = Node(expr, False, depth, scale, bias, parent=parent)
+             subnode = SubNode(expr, node.depth+1, 1, 0, parent=node)
+
+             for i in range(len(expr.args)):
+                 expr_i, scale, bias = next_nontrivial_operation(expr.args[i])
+                 if expr_i.func == Mul:
+                     # create a node with expr_i
+                     new_node = create_node(expr_i, parent=subnode, n_layer=n_layer)
+                     # create a connection which is a linear function
+                     Connection([1,0,float(scale),float(bias)], lambda x: x, fun_name = 'x', parent=subnode, child=new_node)
+
+                 elif expr_i.func == Symbol:
+                     new_node = create_node(expr_i, parent=subnode, n_layer=n_layer)
+                     Connection([1,0,float(scale),float(bias)], lambda x: x, fun_name = 'x', parent=subnode, child=new_node)
+
+                 else:
+                     # 1D function case
+                     # create a node with expr_i.args[0]
+                     new_node = create_node(expr_i.args[0], parent=subnode, n_layer=n_layer)
+                     # create 1D function expr_i.func
+                     if expr_i.func == Pow:
+                         power_exponent = expr_i.args[1]
+                     else:
+                         power_exponent = None
+                     Connection([1,0,float(scale),float(bias)], expr_i.func, fun_name = expr_i.func, parent=subnode, child=new_node, power_exponent=power_exponent)
+
+         elif expr.func == Symbol:
+             # expr.func is a symbol (one of input variables)
+             if n_layer == None:
+                 node = Node(expr, False, depth, scale, bias, parent=parent)
+             else:
+                 node = Node(expr, False, depth, scale, bias, parent=parent)
+                 return_node = node
+                 for i in range(n_layer - depth):
+                     subnode = SubNode(expr, node.depth+1, 1, 0, parent=node)
+                     node = Node(expr, False, subnode.depth, 1, 0, parent=subnode)
+                     Connection([1,0,1,0], lambda x: x, fun_name = 'x', parent=subnode, child=node)
+                 node = return_node
+
+             Start_Nodes.append(node)
+
+         else:
+             # expr.func is 1D function
+             #print(expr, scale, bias)
+             node = Node(expr, False, depth, scale, bias, parent=parent)
+             expr_i, scale, bias = next_nontrivial_operation(expr.args[0])
+             subnode = SubNode(expr_i, node.depth+1, 1, 0, parent=node)
+             # create a node with expr.args[0]
+             new_node = create_node(expr.args[0], parent=subnode, n_layer=n_layer)
+             # create 1D function expr.func
+             if expr.func == Pow:
+                 power_exponent = expr.args[1]
+             else:
+                 power_exponent = None
+             Connection([1,0,1,0], expr.func, fun_name = expr.func, parent=subnode, child=new_node, power_exponent=power_exponent)
+
+         return node
+
+     Nodes = [[]]
+     SubNodes = [[]]
+     Connections = {}
+     Start_Nodes = []
+
+     create_node(expr, n_layer=None)
+
+     n_layer = len(Nodes) - 1
+
+     Nodes = [[]]
+     SubNodes = [[]]
+     Connections = {}
+     Start_Nodes = []
+
+     create_node(expr, n_layer=n_layer)
+
+     # move affine parameters in leaf nodes to connections
+     for node in Start_Nodes:
+         c = Connections[(node.depth,node.parent_index,node.index)]
+         c.affine[0] = float(node.scale)
+         c.affine[1] = float(node.bias)
+         node.scale = 1.
+         node.bias = 0.
+
+     #input_variables = symbol
+     node2var = []
+     for node in Start_Nodes:
+         for i in range(len(input_variables)):
+             if node.expr == input_variables[i]:
+                 node2var.append(i)
+
+     # Nodes
+     n_mult = []
+     n_sum = []
+     for layer in Nodes:
+         n_mult.append(0)
+         n_sum.append(0)
+         for node in layer:
+             if node.mult_bool == True:
+                 n_mult[-1] += 1
+             else:
+                 n_sum[-1] += 1
+
+     # depth
+     n_layer = len(Nodes) - 1
+
+     # converter
+     # input tree node id, output kan node id (distinguish sum and mult node)
+     # input tree subnode id, output tree subnode id
+     # node id
+     subnode_index_convert = {}
+     node_index_convert = {}
+     connection_index_convert = {}
+     mult_arities = []
+     for layer_id in range(n_layer+1):
+         mult_arity = []
+         i_sum = 0
+         i_mult = 0
+         for i in range(len(Nodes[layer_id])):
+             node = Nodes[layer_id][i]
+             if node.mult_bool == True:
+                 kan_node_id = n_sum[layer_id] + i_mult
+                 arity = len(node.child_index)
+                 for i in range(arity):
+                     subnode = SubNodes[node.depth+1][node.child_index[i]]
+                     kan_subnode_id = n_sum[layer_id] + np.sum(mult_arity) + i
+                     subnode_index_convert[(subnode.depth,subnode.index)] = (int(n_layer-subnode.depth),int(kan_subnode_id))
+                 i_mult += 1
+                 mult_arity.append(arity)
+             else:
+                 kan_node_id = i_sum
+                 if len(node.child_index) > 0:
+                     subnode = SubNodes[node.depth+1][node.child_index[0]]
+                     kan_subnode_id = i_sum
+                     subnode_index_convert[(subnode.depth,subnode.index)] = (int(n_layer-subnode.depth),int(kan_subnode_id))
+                 i_sum += 1
+
+             if layer_id == n_layer:
+                 # input layer
+                 node_index_convert[(node.depth,node.index)] = (int(n_layer-node.depth),int(node2var[kan_node_id]))
+             else:
+                 node_index_convert[(node.depth,node.index)] = (int(n_layer-node.depth),int(kan_node_id))
+
+         # node: depth (node.depth -> n_layer - node.depth)
+         #       width (node.index -> kan_node_id)
+         # subnode: depth (subnode.depth -> n_layer - subnode.depth)
+         #          width (subnode.index -> kan_subnode_id)
+         mult_arities.append(mult_arity)
+
+     for index in list(Connections.keys()):
+         depth, subnode_id, node_id = index
+         # to int(n_layer-depth),
+         _, kan_subnode_id = subnode_index_convert[(depth, subnode_id)]
+         _, kan_node_id = node_index_convert[(depth, node_id)]
+         connection_index_convert[(depth, subnode_id, node_id)] = (n_layer-depth, kan_subnode_id, kan_node_id)
+
+
+     n_sum.reverse()
+     n_mult.reverse()
+     mult_arities.reverse()
+
+     width = [[n_sum[i], n_mult[i]] for i in range(len(n_sum))]
+     width[0][0] = len(input_variables)
+
+     # TODO: allow passing other parameters (probably as a dictionary) to sf2kan, including grid, k, etc.
+     model = MultKAN(width=width, mult_arity=mult_arities, grid=grid, k=k, auto_save=auto_save)
+
+     # clean the graph
+     for l in range(model.depth):
+         for i in range(model.width_in[l]):
+             for j in range(model.width_out[l+1]):
+                 model.fix_symbolic(l,i,j,'0',fit_params_bool=False)
+
+     # Nodes
+     Nodes_flat = [x for xs in Nodes for x in xs]
+
+     # alias the model as `self` so the parameter-copying code below reads like code inside MultKAN
+     self = model
+
+     for node in Nodes_flat:
+         node_depth = node.depth
+         node_index = node.index
+         kan_node_depth, kan_node_index = node_index_convert[(node_depth,node_index)]
+         #print(kan_node_depth, kan_node_index)
+         if kan_node_depth > 0:
+             self.node_scale[kan_node_depth-1].data[kan_node_index] = float(node.scale)
+             self.node_bias[kan_node_depth-1].data[kan_node_index] = float(node.bias)
+
+
+     # SubNodes
+     SubNodes_flat = [x for xs in SubNodes for x in xs]
+
+     for subnode in SubNodes_flat:
+         subnode_depth = subnode.depth
+         subnode_index = subnode.index
+         kan_subnode_depth, kan_subnode_index = subnode_index_convert[(subnode_depth,subnode_index)]
+         #print(kan_subnode_depth, kan_subnode_index)
+         self.subnode_scale[kan_subnode_depth].data[kan_subnode_index] = float(subnode.scale)
+         self.subnode_bias[kan_subnode_depth].data[kan_subnode_index] = float(subnode.bias)
+
+     # Connections
+     Connections_flat = list(Connections.values())
+
+     for connection in Connections_flat:
+         c_depth = connection.depth
+         c_j = connection.parent_index
+         c_i = connection.child_index
+         kc_depth, kc_j, kc_i = connection_index_convert[(c_depth, c_j, c_i)]
+
+         # get symbolic fun_name
+         fun_name = connection.fun_name
+         #if fun_name == Pow:
+         #    print(connection.power_exponent)
+
+         if fun_name == 'x':
+             kfun_name = 'x'
+         elif fun_name == exp:
+             kfun_name = 'exp'
+         elif fun_name == sin:
+             kfun_name = 'sin'
+         elif fun_name == cos:
+             kfun_name = 'cos'
+         elif fun_name == tan:
+             kfun_name = 'tan'
+         elif fun_name == sqrt:
+             kfun_name = 'sqrt'
+         elif fun_name == log:
+             kfun_name = 'log'
+         elif fun_name == tanh:
+             kfun_name = 'tanh'
+         elif fun_name == asin:
+             kfun_name = 'arcsin'
+         elif fun_name == acos:
+             kfun_name = 'arccos'
+         elif fun_name == atan:
+             kfun_name = 'arctan'
+         elif fun_name == atanh:
+             kfun_name = 'arctanh'
+         elif fun_name == sign:
+             kfun_name = 'sgn'
+         elif fun_name == Pow:
+             alpha = connection.power_exponent
+             if alpha == Rational(1,2):
+                 kfun_name = 'x^0.5'
+             elif alpha == - Rational(1,2):
+                 kfun_name = '1/x^0.5'
+             elif alpha == Rational(3,2):
+                 kfun_name = 'x^1.5'
+             else:
+                 alpha = int(connection.power_exponent)
+                 if alpha > 0:
+                     if alpha == 1:
+                         kfun_name = 'x'
+                     else:
+                         kfun_name = f'x^{alpha}'
+                 else:
+                     if alpha == -1:
+                         kfun_name = '1/x'
+                     else:
+                         kfun_name = f'1/x^{-alpha}'
+
+         model.fix_symbolic(kc_depth, kc_i, kc_j, kfun_name, fit_params_bool=False)
+         model.symbolic_fun[kc_depth].affine.data.reshape(self.width_out[kc_depth+1], self.width_in[kc_depth], 4)[kc_j][kc_i] = torch.tensor(connection.affine)
+
+     return model
+
+
+ sf2kan = kanpiler = expr2kan
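
A minimal usage sketch of the compiler above (illustrative, not part of the package): peel the outer affine transform off a formula with next_nontrivial_operation, compile the core expression into a MultKAN via the kanpiler alias, and check the network against the original formula. The `from kan.compiler import ...` path follows the module's own docstrings; adjust it to however this package actually exposes the module.

from sympy import symbols, exp, sin, pi, lambdify
import torch
from kan.compiler import kanpiler, next_nontrivial_operation   # import path assumed from the docstrings

input_vars = a, b = symbols('a b')
expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402

# strip the outer affine part: core = exp(sin(pi*a) + b**2), scale ~ 3.145, bias ~ -2.323
core, scale, bias = next_nontrivial_operation(expression)

# compile the core formula into a symbolic MultKAN and evaluate it on random inputs in [-1, 1]^2
model = kanpiler(input_vars, core)
x = torch.rand(100, 2) * 2 - 1
y_kan = model(x)

# the compiled network should agree with the symbolic formula up to numerical precision
f = lambdify(input_vars, core, 'numpy')
y_ref = torch.tensor(f(x[:, 0].numpy(), x[:, 1].numpy())).float()
print((y_kan.squeeze() - y_ref).abs().max())
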
@@ -0,0 +1,27 @@
+ import scipy.io as sio
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+
+
+ class KanDataset(Dataset):
+     """Dataset read from a .mat file; expects 'features' and 'labels' entries."""
+
+     def __init__(self, data_path):
+         self.path = data_path
+         self.data = sio.loadmat(data_path)
+         self.features = torch.tensor(self.data['features']).double()
+         self.labels = torch.tensor(self.data['labels'].squeeze()).double()
+
+     def __len__(self):
+         return len(self.labels)
+
+     def __getitem__(self, idx):
+         feature = self.features[idx]
+         label = self.labels[idx]
+         return feature, label
+
+
+ if __name__ == '__main__':
+     # hard-coded local test path ("新建文件夹" is Chinese for "New Folder")
+     dataset = KanDataset(r'D:\Code\2-ZSL\Zero-Shot-Learning\data\新建文件夹\1\val.mat')
+     print(dataset.__len__())
+     data_loader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)
+     for data, label in data_loader:
+         print(label.shape)
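
The loader above expects a MATLAB .mat file with a 'features' matrix and a 'labels' vector. A self-contained sketch with synthetic data (file name and shapes are illustrative, and KanDataset is the class defined above):

import numpy as np
import scipy.io as sio
from torch.utils.data import DataLoader

# write a toy .mat file in the layout KanDataset reads: 'features' (N, D) and 'labels' (N,)
N, D = 10, 4
sio.savemat('toy_val.mat', {'features': np.random.rand(N, D), 'labels': np.arange(N)})

dataset = KanDataset('toy_val.mat')        # KanDataset from the module above
print(len(dataset))                        # 10
loader = DataLoader(dataset, batch_size=2, shuffle=True)
for features, labels in loader:
    print(features.shape, labels.shape)    # torch.Size([2, 4]) torch.Size([2])
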
@@ -0,0 +1,50 @@
+ import torch
+ import numpy as np
+ from .MultKAN import *
+
+
+ def runner1(width, dataset, grids=[5, 10, 20], steps=20, lamb=0.001, prune_round=3, refine_round=3, edge_th=1e-2,
+             node_th=1e-2, metrics=None, seed=1):
+     result = {'test_loss': [], 'c': [], 'G': [], 'id': []}
+     if metrics is not None:
+         for i in range(len(metrics)):
+             result[metrics[i].__name__] = []
+
+     def collect(evaluation):
+         result["test_loss"].append(evaluation['test_loss'])
+         result['c'].append(evaluation['n_edge'])
+         result['G'].append(evaluation['n_grid'])
+         result['id'].append(f'{model.round}.{model.state_id}')
+         if metrics is not None:
+             for i in range(len(metrics)):
+                 result[metrics[i].__name__].append(metrics[i](model, dataset).item())
+
+     for i in range(prune_round):
+         # train and prune
+         if i == 0:
+             model = KAN(width=width, grid=grids[0], seed=seed)
+         else:
+             model = model.rewind(f'{i - 1}.{2 * i}')
+
+         model.fit(dataset, steps=steps, lamb=lamb)
+         model = model.prune(edge_th=edge_th, node_th=node_th)
+         evaluation = model.evaluate(dataset)
+         collect(evaluation)
+
+         for j in range(refine_round):
+             model = model.refine(grids[j])
+             model.fit(dataset, steps=steps)
+             evaluation = model.evaluate(dataset)
+             collect(evaluation)
+
+     for key in list(result.keys()):
+         result[key] = np.array(result[key])
+
+     return result
+
+
+ def pareto_frontier(x, y):
+     pf_id = np.where(np.sum((x[:, None] <= x[None, :]) * (y[:, None] <= y[None, :]), axis=0) == 1)[0]
+     x_pf = x[pf_id]
+     y_pf = y[pf_id]
+
+     return x_pf, y_pf, pf_id
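
A sketch of how these two helpers combine: train with pruning and grid refinement, then read off the accuracy-versus-size Pareto frontier from the collected arrays. The dataset dict layout below follows pykan's train_input/train_label/test_input/test_label convention, which is an assumption here; runner1 itself only passes the dataset through to fit/evaluate.

import numpy as np
import torch

def make_dataset(n=1000):
    # assumed pykan-style dataset dict; replace with however this package builds datasets
    x = torch.rand(n, 2) * 2 - 1
    y = torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2)
    return {'train_input': x[:800], 'train_label': y[:800],
            'test_input': x[800:], 'test_label': y[800:]}

result = runner1(width=[2, 5, 1], dataset=make_dataset(), steps=20)

# result['c'] holds edge counts, result['test_loss'] the matching test losses;
# pareto_frontier keeps the points that are not dominated in both coordinates
c_pf, loss_pf, pf_id = pareto_frontier(result['c'], result['test_loss'])
print(result['id'][pf_id])    # checkpoint ids of the Pareto-optimal models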