yms-kan 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yms_kan/KANLayer.py +364 -0
- yms_kan/LBFGS.py +492 -0
- yms_kan/MLP.py +361 -0
- yms_kan/MultKAN.py +3085 -0
- yms_kan/Symbolic_KANLayer.py +270 -0
- yms_kan/__init__.py +4 -0
- yms_kan/compiler.py +498 -0
- yms_kan/experiment.py +50 -0
- yms_kan/feynman.py +739 -0
- yms_kan/hypothesis.py +695 -0
- yms_kan/spline.py +144 -0
- yms_kan/tool.py +304 -0
- yms_kan/train_eval_utils.py +175 -0
- yms_kan/utils.py +661 -0
- yms_kan/version.py +1 -0
- yms_kan-0.0.1.dist-info/METADATA +11 -0
- yms_kan-0.0.1.dist-info/RECORD +20 -0
- yms_kan-0.0.1.dist-info/WHEEL +5 -0
- yms_kan-0.0.1.dist-info/licenses/LICENSE +21 -0
- yms_kan-0.0.1.dist-info/top_level.txt +1 -0
yms_kan/hypothesis.py
ADDED
@@ -0,0 +1,695 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import torch
|
3
|
+
from sklearn.linear_model import LinearRegression
|
4
|
+
from sympy.utilities.lambdify import lambdify
|
5
|
+
from sklearn.cluster import AgglomerativeClustering
|
6
|
+
from .utils import batch_jacobian, batch_hessian
|
7
|
+
from functools import reduce
|
8
|
+
from yms_kan.utils import batch_jacobian, batch_hessian
|
9
|
+
import copy
|
10
|
+
import matplotlib.pyplot as plt
|
11
|
+
import sympy
|
12
|
+
from sympy.printing import latex
|
13
|
+
|
14
|
+
|
15
|
+
def detect_separability(model, x, mode='add', score_th=1e-2, res_th=1e-2, n_clusters=None, bias=0., verbose=False):
    '''
    detect function separability

    The (log-|.|-transformed, for 'mul') Hessian of the model is normalized by
    the input standard deviations; the median over the batch of its absolute
    value is the score matrix. Variables are clustered on the thresholded
    score matrix, and the largest cluster count whose off-block residual ratio
    is below ``res_th`` is reported.

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        mode : str
            mode = 'add' or mode = 'mul'
        score_th : float
            threshold of score
        res_th : float
            threshold of residue
        n_clusters : None, int, or list/tuple of two ints
            cluster counts to try: an int tries exactly that count, ``[lo, hi]``
            tries every count in ``lo..hi`` (inclusive), None tries
            ``1..x.shape[1]``
        bias : float
            bias (for multiplicative separability)
        verbose : bool
            if True, print the residual ratio of every cluster count tried

    Returns:
    --------
        results (dictionary)
            'hessian' : the score matrix
            'n_groups', 'labels', 'groups' : the largest qualifying cluster
            count and its clustering; these keys are absent when no cluster
            count qualifies

    Raises:
    -------
        ValueError : if mode is neither 'add' nor 'mul'

    Example1
    --------
    >>> from yms_kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 + torch.exp(x[:,[1]]+x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> detect_separability(model, x, mode='add')

    Example2
    --------
    >>> from yms_kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> detect_separability(model, x, mode='mul')
    '''
    results = {}

    if mode == 'add':
        hessian = batch_hessian(model, x)
    elif mode == 'mul':
        # multiplicative separability of f == additive separability of log|f + bias|
        hessian = batch_hessian(lambda inp: torch.log(torch.abs(model(inp) + bias)), x)
    else:
        raise ValueError(f"mode must be 'add' or 'mul', got {mode!r}")

    std = torch.std(x, dim=0)
    hessian_normalized = hessian * std[None,:] * std[:,None]
    score_mat = torch.median(torch.abs(hessian_normalized), dim=0)[0]
    results['hessian'] = score_mat

    # distance 1 between variables whose interaction score is negligible
    dist_hard = (score_mat < score_th).float()

    if isinstance(n_clusters, int):
        lo, hi = n_clusters, n_clusters
    elif isinstance(n_clusters, (list, tuple)):
        lo, hi = n_clusters[0], n_clusters[1]
    else:
        lo, hi = 1, x.shape[1]

    total_sum = torch.sum(score_mat)

    for n_cluster in range(lo, hi + 1):

        clustering = AgglomerativeClustering(
            metric='precomputed',
            n_clusters=n_cluster,
            linkage='complete',
        ).fit(dist_hard)

        labels = clustering.labels_

        groups = [list(np.where(labels == i)[0]) for i in range(n_cluster)]
        blocks = [torch.sum(score_mat[groups[i]][:,groups[i]]) for i in range(n_cluster)]
        block_sum = torch.sum(torch.stack(blocks))
        residual_sum = total_sum - block_sum
        # An all-zero score matrix (e.g. a linear model) has no cross-variable
        # interaction at all: treat its residual ratio as 0 instead of 0/0 = nan,
        # which would otherwise never pass the threshold.
        if total_sum > 0:
            residual_ratio = residual_sum / total_sum
        else:
            residual_ratio = torch.zeros_like(total_sum)

        if verbose:
            print(f'n_group={n_cluster}, residual_ratio={residual_ratio}')

        # later (larger) qualifying counts overwrite earlier ones
        if residual_ratio < res_th:
            results['n_groups'] = n_cluster
            results['labels'] = list(labels)
            results['groups'] = groups

    # no qualifying cluster count means separability was not detected
    if results.get('n_groups', 0) > 1:
        print(f'{mode} separability detected')
    else:
        print(f'{mode} separability not detected')

    return results
|
109
|
+
|
110
|
+
|
111
|
+
def batch_grad_normgrad(model, x, group, create_graph=False):
    '''
    Per-sample derivative of the normalized group gradient w.r.t. the other inputs.

    The input gradient of ``model`` restricted to the columns in ``group`` is
    divided by its norm (plus 1e-6) and then differentiated with respect to
    every input column not in ``group``.

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float, shape (Batch, Length)
            inputs
        group : list of int
            column indices whose normalized gradient is differentiated
        create_graph : bool
            forwarded to torch.autograd.functional.jacobian

    Returns:
    --------
        torch.Tensor of shape (Batch, len(group), number of remaining columns)
    '''
    others = list(set(range(x.shape[1])) - set(group))

    def summed_unit_grad(inputs):
        grad_in_group = batch_jacobian(model, inputs, create_graph=True)[:, group]
        unit_grad = grad_in_group / (torch.norm(grad_in_group, dim=1, keepdim=True) + 1e-6)
        # Summing over the batch lets one jacobian call recover per-sample
        # derivatives, assuming samples do not interact inside the model.
        return unit_grad.sum(dim=0)

    jac_full = torch.autograd.functional.jacobian(summed_unit_grad, x, create_graph=create_graph)
    # (len(group), Batch, Length) -> (Batch, len(group), Length), keep other columns
    return jac_full.permute(1, 0, 2)[:, :, others]
|
127
|
+
|
128
|
+
|
129
|
+
def get_dependence(model, x, group):
    '''
    Scaled dependence of the normalized ``group`` gradient on the other inputs.

    Multiplies the output of ``batch_grad_normgrad`` by the standard deviation
    of the corresponding inputs, then takes the median over the batch of the
    absolute value.

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        group : list of int
            column indices of the group

    Returns:
    --------
        torch.Tensor of shape (len(group), number of remaining columns)
    '''
    others = list(set(range(x.shape[1])) - set(group))
    sigma = torch.std(x, dim=0)
    scaled = batch_grad_normgrad(model, x, group=group) * sigma[None, group, None] * sigma[None, None, others]
    return scaled.abs().median(dim=0).values
|
137
|
+
|
138
|
+
def test_symmetry(model, x, group, dependence_th=1e-3):
|
139
|
+
'''
|
140
|
+
detect function separability
|
141
|
+
|
142
|
+
Args:
|
143
|
+
-----
|
144
|
+
model : MultKAN, MLP or python function
|
145
|
+
x : 2D torch.float
|
146
|
+
inputs
|
147
|
+
group : a list of indices
|
148
|
+
dependence_th : float
|
149
|
+
threshold of dependence
|
150
|
+
|
151
|
+
Returns:
|
152
|
+
--------
|
153
|
+
bool
|
154
|
+
|
155
|
+
Example
|
156
|
+
-------
|
157
|
+
>>> from yms_kan.hypothesis import *
|
158
|
+
>>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
|
159
|
+
>>> x = torch.normal(0,1,size=(100,3))
|
160
|
+
>>> print(test_symmetry(model, x, [1,2])) # True
|
161
|
+
>>> print(test_symmetry(model, x, [0,2])) # False
|
162
|
+
'''
|
163
|
+
if len(group) == x.shape[1] or len(group) == 0:
|
164
|
+
return True
|
165
|
+
|
166
|
+
dependence = get_dependence(model, x, group)
|
167
|
+
max_dependence = torch.max(dependence)
|
168
|
+
return max_dependence < dependence_th
|
169
|
+
|
170
|
+
|
171
|
+
def test_separability(model, x, groups, mode='add', threshold=1e-2, bias=0):
|
172
|
+
'''
|
173
|
+
test function separability
|
174
|
+
|
175
|
+
Args:
|
176
|
+
-----
|
177
|
+
model : MultKAN, MLP or python function
|
178
|
+
x : 2D torch.float
|
179
|
+
inputs
|
180
|
+
mode : str
|
181
|
+
mode = 'add' or mode = 'mul'
|
182
|
+
score_th : float
|
183
|
+
threshold of score
|
184
|
+
res_th : float
|
185
|
+
threshold of residue
|
186
|
+
bias : float
|
187
|
+
bias (for multiplicative separability)
|
188
|
+
verbose : bool
|
189
|
+
|
190
|
+
Returns:
|
191
|
+
--------
|
192
|
+
bool
|
193
|
+
|
194
|
+
Example
|
195
|
+
-------
|
196
|
+
>>> from yms_kan.hypothesis import *
|
197
|
+
>>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
|
198
|
+
>>> x = torch.normal(0,1,size=(100,3))
|
199
|
+
>>> print(test_separability(model, x, [[0],[1,2]], mode='mul')) # True
|
200
|
+
>>> print(test_separability(model, x, [[0],[1,2]], mode='add')) # False
|
201
|
+
'''
|
202
|
+
if mode == 'add':
|
203
|
+
hessian = batch_hessian(model, x)
|
204
|
+
elif mode == 'mul':
|
205
|
+
compose = lambda *F: reduce(lambda f, g: lambda x: f(g(x)), F)
|
206
|
+
hessian = batch_hessian(compose(torch.log, torch.abs, lambda x: x+bias, model), x)
|
207
|
+
|
208
|
+
std = torch.std(x, dim=0)
|
209
|
+
hessian_normalized = hessian * std[None,:] * std[:,None]
|
210
|
+
score_mat = torch.median(torch.abs(hessian_normalized), dim=0)[0]
|
211
|
+
|
212
|
+
sep_bool = True
|
213
|
+
|
214
|
+
# internal test
|
215
|
+
n_groups = len(groups)
|
216
|
+
for i in range(n_groups):
|
217
|
+
for j in range(i+1, n_groups):
|
218
|
+
sep_bool *= torch.max(score_mat[groups[i]][:,groups[j]]) < threshold
|
219
|
+
|
220
|
+
# external test
|
221
|
+
group_id = [x for xs in groups for x in xs]
|
222
|
+
nongroup_id = list(set(range(x.shape[1])) - set(group_id))
|
223
|
+
if len(nongroup_id) > 0 and len(group_id) > 0:
|
224
|
+
sep_bool *= torch.max(score_mat[group_id][:,nongroup_id]) < threshold
|
225
|
+
|
226
|
+
return sep_bool
|
227
|
+
|
228
|
+
def test_general_separability(model, x, groups, threshold=1e-2):
|
229
|
+
'''
|
230
|
+
test function separability
|
231
|
+
|
232
|
+
Args:
|
233
|
+
-----
|
234
|
+
model : MultKAN, MLP or python function
|
235
|
+
x : 2D torch.float
|
236
|
+
inputs
|
237
|
+
mode : str
|
238
|
+
mode = 'add' or mode = 'mul'
|
239
|
+
score_th : float
|
240
|
+
threshold of score
|
241
|
+
res_th : float
|
242
|
+
threshold of residue
|
243
|
+
bias : float
|
244
|
+
bias (for multiplicative separability)
|
245
|
+
verbose : bool
|
246
|
+
|
247
|
+
Returns:
|
248
|
+
--------
|
249
|
+
bool
|
250
|
+
|
251
|
+
Example
|
252
|
+
-------
|
253
|
+
>>> from yms_kan.hypothesis import *
|
254
|
+
>>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]**2+x[:,[2]]**2)**2
|
255
|
+
>>> x = torch.normal(0,1,size=(100,3))
|
256
|
+
>>> print(test_general_separability(model, x, [[1],[0,2]])) # False
|
257
|
+
>>> print(test_general_separability(model, x, [[0],[1,2]])) # True
|
258
|
+
'''
|
259
|
+
grad = batch_jacobian(model, x)
|
260
|
+
|
261
|
+
gensep_bool = True
|
262
|
+
|
263
|
+
n_groups = len(groups)
|
264
|
+
for i in range(n_groups):
|
265
|
+
for j in range(i+1,n_groups):
|
266
|
+
group_A = groups[i]
|
267
|
+
group_B = groups[j]
|
268
|
+
for member_A in group_A:
|
269
|
+
for member_B in group_B:
|
270
|
+
def func(x):
|
271
|
+
grad = batch_jacobian(model, x, create_graph=True)
|
272
|
+
return grad[:,[member_B]]/grad[:,[member_A]]
|
273
|
+
# test if func is multiplicative separable
|
274
|
+
gensep_bool *= test_separability(func, x, groups, mode='mul', threshold=threshold)
|
275
|
+
return gensep_bool
|
276
|
+
|
277
|
+
|
278
|
+
def get_molecule(model, x, sym_th=1e-3, verbose=True):
    '''
    how variables are combined hierarchically

    Starting from one "atom" per input variable, each sweep walks over the
    remaining atoms and greedily merges an atom into the molecule being built
    whenever ``test_symmetry`` accepts the merged index set; each sweep yields
    one layer. The loop stops when a layer is a single molecule, or when a
    sweep merges nothing (then all molecules are lumped into one final
    molecule). Afterwards every lower layer is reordered so its molecules
    appear in the order in which they are concatenated in the layer above.

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        sym_th : float
            threshold of symmetry (passed to ``test_symmetry`` as ``dependence_th``)
        verbose : bool
            if True, print the current molecule and candidate atom at every step

    Returns:
    --------
        list
            list of layers; each layer is a list of molecules (lists of input
            indices). The first layer has one atom per input, the last layer
            is a single molecule.

    Example
    -------
    >>> from yms_kan.hypothesis import *
    >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
    >>> x = torch.normal(0,1,size=(100,8))
    >>> get_molecule(model, x, verbose=False)
    [[[0], [1], [2], [3], [4], [5], [6], [7]],
     [[0, 1], [2, 3], [4, 5], [6, 7]],
     [[0, 1, 2, 3], [4, 5, 6, 7]],
     [[0, 1, 2, 3, 4, 5, 6, 7]]]
    '''
    n = x.shape[1]
    atoms = [[i] for i in range(n)]
    molecules = []
    moleculess = [copy.deepcopy(atoms)]
    # the "full" shortcut below can fire at most once over the whole run
    already_full = False
    n_layer = 0
    last_n_molecule = n

    while True:


        pointer = 0
        current_molecule = []
        remove_atoms = []
        # NOTE(review): n_atom counts atoms merged during the whole sweep; it is
        # not reset when a new molecule is started within the sweep — confirm
        # this is intended.
        n_atom = 0

        while len(atoms) > 0:

            # assemble molecule
            atom = atoms[pointer]
            if verbose:
                print(current_molecule)
                print(atom)

            if len(current_molecule) == 0:
                # start a new molecule with the first remaining atom
                full = False
                current_molecule += atom
                remove_atoms.append(atom)
                n_atom += 1
            else:
                # try assemble the atom to the molecule
                # If adding this atom would make the molecule cover all input
                # variables (and the extra conditions hold), do not merge; the
                # current molecule and this atom are emitted as two separate
                # molecules below (presumably so the final merge appears as its
                # own layer).
                if len(current_molecule+atom) == x.shape[1] and already_full == False and n_atom > 1 and n_layer > 0:
                    full = True
                    already_full = True
                else:
                    full = False
                    if test_symmetry(model, x, current_molecule+atom, dependence_th=sym_th):
                        current_molecule += atom
                        remove_atoms.append(atom)
                        n_atom += 1

            pointer += 1

            # a pass over the remaining atoms is complete (or forced closed):
            # emit the molecule and restart from the first remaining atom
            if pointer == len(atoms) or full:
                molecules.append(current_molecule)
                if full:
                    molecules.append(atom)
                    remove_atoms.append(atom)
                # remove molecules from atoms
                for atom in remove_atoms:
                    atoms.remove(atom)
                current_molecule = []
                remove_atoms = []
                pointer = 0

        # if not making progress, terminate
        if len(molecules) == last_n_molecule:
            def flatten(xss):
                return [x for xs in xss for x in xs]
            # lump everything into one final molecule
            moleculess.append([flatten(molecules)])
            break
        else:
            moleculess.append(copy.deepcopy(molecules))

        last_n_molecule = len(molecules)

        if len(molecules) == 1:
            break

        # the molecules of this layer are the atoms of the next sweep
        atoms = molecules
        molecules = []

        n_layer += 1

        #print(n_layer, atoms)


    # sort
    # Walk from the top layer down, rebuilding layer l-1 in the order its
    # molecules are concatenated inside layer l (greedy shortest-prefix match).
    depth = len(moleculess) - 1

    for l in list(range(depth,0,-1)):

        molecules_sorted = []
        molecules_l = moleculess[l]
        molecules_lm1 = moleculess[l-1]


        for molecule_l in molecules_l:
            start = 0
            for i in range(1,len(molecule_l)+1):
                if molecule_l[start:i] in molecules_lm1:

                    molecules_sorted.append(molecule_l[start:i])
                    start = i

        moleculess[l-1] = molecules_sorted

    return moleculess
|
404
|
+
|
405
|
+
|
406
|
+
def get_tree_node(model, x, moleculess, sep_th=1e-2, skip_test=True):
    '''
    get tree nodes

    For every pair of consecutive layers of ``moleculess`` (as returned by
    ``get_molecule``), each molecule of the upper layer is split, by greedy
    shortest-prefix matching, into consecutive molecules of the lower layer.
    The number of pieces is the node's arity. Optionally each node is tested
    for a separability property.

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        moleculess : list of layers, each a list of molecules (lists of int)
        sep_th : float
            threshold of separability
        skip_test : bool
            if True, don't test the property of each module (to save time)

    Returns:
    --------
        arities : list of lists of int
            arities[l][k] is the arity of the k-th molecule of layer l+1
        properties : list of lists of str
            matching labels: 'Id' (arity 1), 'GS' (general separability),
            'Add' / 'Mul' (only tested on the last layer; 'Mul' wins if both
            pass), or '' (nothing detected, or tests skipped)

    Example
    -------
    >>> from yms_kan.hypothesis import *
    >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
    >>> x = torch.normal(0,1,size=(100,8))
    >>> moleculess = get_molecule(model, x, verbose=False)
    >>> get_tree_node(model, x, moleculess, skip_test=False)
    '''
    arities = []
    properties = []

    depth = len(moleculess) - 1

    for l in range(depth):
        # read-only access, so no defensive copies are needed
        lower = moleculess[l]
        arity_l = []
        property_l = []

        for molecule in moleculess[l+1]:
            groups = []
            start = 0
            for end in range(1, len(molecule)+1):
                # cut off the shortest prefix that is a lower-layer molecule
                if molecule[start:end] in lower:
                    groups.append(molecule[start:end])
                    start = end
            arity = len(groups)
            arity_l.append(arity)

            if arity == 1:
                node_property = 'Id'
            else:
                node_property = ''
                if not skip_test and test_general_separability(model, x, groups, threshold=sep_th):
                    node_property = 'GS'
                # Add/Mul are only tested on the last layer
                if l == depth - 1 and not skip_test:
                    add_bool = test_separability(model, x, groups, mode='add', threshold=sep_th)
                    mul_bool = test_separability(model, x, groups, mode='mul', threshold=sep_th)
                    if add_bool:
                        node_property = 'Add'
                    if mul_bool:
                        node_property = 'Mul'

            property_l.append(node_property)

        arities.append(arity_l)
        properties.append(property_l)

    return arities, properties
|
487
|
+
|
488
|
+
|
489
|
+
def plot_tree(model, x, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False):
    '''
    get tree graph

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        in_var : None, list of sympy symbols, or list of str
            input variables; None labels them x_1 ... x_n, strings are
            converted with ``sympy.symbols``
        style : str
            'tree' or 'box'
        sym_th : float
            threshold of symmetry
        sep_th : float
            threshold of separability
        skip_sep_test : bool
            if True, don't test the property of each module (to save time)
        verbose : bool
            forwarded to ``get_molecule``

    Returns:
    --------
        None; the tree graph is drawn with matplotlib and shown via plt.show()

    Example
    -------
    >>> from yms_kan.hypothesis import *
    >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
    >>> x = torch.normal(0,1,size=(100,8))
    >>> plot_tree(model, x)
    '''
    moleculess = get_molecule(model, x, sym_th=sym_th, verbose=verbose)
    arities, properties = get_tree_node(model, x, moleculess, sep_th=sep_th, skip_test=skip_sep_test)

    n = x.shape[1]

    # Input labels. Symbols are built directly rather than via exec(): exec()
    # writing into function locals is not guaranteed (PEP 667 / Python 3.13).
    # The user-supplied list is `in_var` (the old code mistakenly read an
    # always-None `var` and an unimported `Symbol`).
    if in_var is None:
        in_vars = [sympy.Symbol(f'x_{ii}') for ii in range(1, n + 1)]
    elif isinstance(in_var[0], sympy.Symbol):
        in_vars = in_var
    else:
        in_vars = [sympy.symbols(var_) for var_ in in_var]

    def flatten(xss):
        return [x for xs in xss for x in xs]

    def myrectangle(center_x, center_y, width_x, width_y):
        # black rectangle outline centered at (center_x, center_y)
        left, right = center_x - width_x/2, center_x + width_x/2
        bottom, top = center_y - width_y/2, center_y + width_y/2
        plt.plot([left, right], [top, top], color='k') # up
        plt.plot([left, right], [bottom, bottom], color='k') # down
        plt.plot([left, left], [bottom, top], color='k') # left
        plt.plot([right, right], [bottom, top], color='k') # right

    depth = len(moleculess)

    delta = 1/n   # horizontal spacing between input variables
    a = 0.3       # horizontal box padding, in units of delta
    b = 0.15      # half-height of a box
    y0 = 0.5      # vertical distance between layers
    width_y = 2*b

    last_centers = []
    # draw one row of nodes for every layer above the inputs
    for l in range(depth-1):
        molecules = moleculess[l+1]
        centers = []
        # number of lower-row nodes already consumed by nodes in this row
        acc_arity = 0

        for i in range(len(molecules)):
            start_id = len(flatten(molecules[:i]))
            end_id = len(flatten(molecules[:i+1]))

            center_x = (start_id + (end_id - 1 - start_id)/2) * delta + delta/2
            center_y = (l+1/2)*y0
            width_x = (end_id - start_id - 1 + 2*a)*delta

            arity = arities[l][i]
            node_property = properties[l][i]

            # add text (numbers) on rectangles
            if style == 'box':
                myrectangle(center_x, center_y, width_x, width_y)
                plt.text(center_x, center_y, node_property, fontsize=15, horizontalalignment='center',
                         verticalalignment='center')
            elif style == 'tree':
                # if 'GS', no rectangle, n=arity tilted lines
                # if 'Id', no rectangle, n=arity vertical lines
                # if 'Add' or 'Mul'. rectangle, "+" or "x"
                # if '', rectangle
                if node_property in ('GS', 'Add', 'Mul'):
                    color = 'blue'
                    for j in range(arity):
                        if l == 0:
                            child_x = (start_id + j) * delta + delta/2
                        else:
                            child_x = last_centers[acc_arity+j]
                        plt.plot([child_x, center_x], [center_y - b, center_y + b], color=color)

                    if node_property in ('Add', 'Mul'):
                        symbol = '+' if node_property == 'Add' else '*'
                        plt.text(center_x, center_y + b, symbol, horizontalalignment='center',
                                 verticalalignment='center', color='red', fontsize=40)
                if node_property == 'Id':
                    plt.plot([center_x, center_x], [center_y-width_y/2, center_y+width_y/2], color='black')

                if node_property == '':
                    myrectangle(center_x, center_y, width_x, width_y)

            # Every node consumes `arity` children of the row below, whatever
            # its property; advancing only for GS/Add/Mul nodes misaligned the
            # lines of later nodes in the same row.
            acc_arity += arity

            # connections to the next layer
            plt.plot([center_x, center_x], [center_y+width_y/2, center_y+y0-width_y/2], color='k')
            centers.append(center_x)
        last_centers = centers

    # connections from input variables to the first layer
    for i in range(n):
        x_ = (i + 1/2) * delta
        # connections to the next layer
        plt.plot([x_, x_], [0, y0/2-width_y/2], color='k')
        plt.text(x_, -0.05*(depth-1), f'${latex(in_vars[moleculess[0][i][0]])}$', fontsize=20, horizontalalignment='center')
    plt.xlim(0,1)
    plt.axis('off')
    plt.show()
|
632
|
+
|
633
|
+
|
634
|
+
def test_symmetry_var(model, x, input_vars, symmetry_var):
    '''
    test symmetry

    Compares, sample by sample, the gradient of ``model`` with the gradient of
    the candidate symmetry variable, restricted to the symbols the candidate
    depends on. Reports the fraction of samples whose absolute cosine
    similarity exceeds 0.9.

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        input_vars : list of sympy symbols
            one symbol per input column, in column order
        symmetry_var : sympy expression

    Returns:
    --------
        cosine similarity (1D torch tensor, one value per sample)

    Example
    -------
    >>> from yms_kan.hypothesis import *
    >>> from sympy import *
    >>> model = lambda x: x[:,[0]] * (x[:,[1]] + x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> input_vars = a, b, c = symbols('a b c')
    >>> symmetry_var = b + c
    >>> test_symmetry_var(model, x, input_vars, symmetry_var);
    >>> symmetry_var = b * c
    >>> test_symmetry_var(model, x, input_vars, symmetry_var);
    '''
    # gradient of the model with respect to every input
    model_grad = batch_jacobian(model, x)

    # gradient of the candidate symmetry variable on the same inputs
    sym_fn = lambdify(input_vars, symmetry_var,'numpy')
    sym_grad = batch_jacobian(lambda inp: sym_fn(*[inp[:, [k]] for k in range(len(input_vars))]), x)

    # columns of the symbols the candidate actually depends on
    idx = [j for sym in symmetry_var.free_symbols for j, var in enumerate(input_vars) if sym == var]

    g_model = model_grad[:, idx]
    g_sym = sym_grad[:, idx]
    dot = torch.sum(g_model * g_sym, dim=1)
    cossim = torch.abs(dot / (torch.norm(g_model, dim=1) * torch.norm(g_sym, dim=1)))

    ratio = torch.sum(cossim > 0.9)/len(cossim)

    print(f'{100*ratio}% data have more than 0.9 cosine similarity')
    print('suggesting symmetry' if ratio > 0.9 else 'not suggesting symmetry')

    return cossim
|