yms-kan 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
yms_kan/hypothesis.py ADDED
@@ -0,0 +1,695 @@
1
+ import numpy as np
2
+ import torch
3
+ from sklearn.linear_model import LinearRegression
4
+ from sympy.utilities.lambdify import lambdify
5
+ from sklearn.cluster import AgglomerativeClustering
6
+ from .utils import batch_jacobian, batch_hessian
7
+ from functools import reduce
8
+ from yms_kan.utils import batch_jacobian, batch_hessian
9
+ import copy
10
+ import matplotlib.pyplot as plt
11
+ import sympy
12
+ from sympy.printing import latex
13
+
14
+
15
def detect_separability(model, x, mode='add', score_th=1e-2, res_th=1e-2, n_clusters=None, bias=0., verbose=False):
    '''
    detect function separability

    Input variables are clustered according to how strongly they interact,
    measured by the batch-median of the absolute, std-normalized Hessian of the
    function. For mode='mul' the Hessian of log|model(x) + bias| is used, since
    the logarithm turns a product of factors into a sum.

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs, one sample per row
        mode : str
            mode = 'add' or mode = 'mul'
        score_th : float
            threshold of score; pairs whose score is below it count as
            non-interacting
        res_th : float
            threshold of residue; a partition is accepted when the fraction of
            the total score lying outside its diagonal blocks is below it
        n_clusters : None, int, or 2-element list/tuple
            the number of clusters to try. None tries 1..x.shape[1]; an int
            tries exactly that number; [lo, hi] tries lo..hi inclusive
        bias : float
            bias (for multiplicative separability)
        verbose : bool
            if True, print the residual ratio of every cluster count tried

    Returns:
    --------
        results (dictionary)
            'hessian' : the interaction score matrix
            'n_groups', 'labels', 'groups' : the accepted partition with the
            largest number of groups; these keys are absent when no tried
            cluster count satisfies res_th

    Raises:
    -------
        ValueError
            if mode is neither 'add' nor 'mul'

    Example1
    --------
    >>> from yms_kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 + torch.exp(x[:,[1]]+x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> detect_separability(model, x, mode='add')

    Example2
    --------
    >>> from yms_kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> detect_separability(model, x, mode='mul')
    '''
    results = {}

    if mode == 'add':
        hessian = batch_hessian(model, x)
    elif mode == 'mul':
        # compose(f, g, h)(x) == f(g(h(x))): here log(|model(x) + bias|)
        compose = lambda *F: reduce(lambda f, g: lambda x: f(g(x)), F)
        hessian = batch_hessian(compose(torch.log, torch.abs, lambda x: x + bias, model), x)
    else:
        raise ValueError(f"mode must be 'add' or 'mul', got {mode!r}")

    # scale every Hessian entry by the spreads of both inputs so scores are comparable
    std = torch.std(x, dim=0)
    hessian_normalized = hessian * std[None, :] * std[:, None]
    score_mat = torch.median(torch.abs(hessian_normalized), dim=0)[0]
    results['hessian'] = score_mat

    # precomputed distance: 1 for non-interacting pairs, 0 for interacting ones
    dist_hard = (score_mat < score_th).float()
    # sklearn needs a host numpy array (the scores may live on another device)
    dist_np = dist_hard.detach().cpu().numpy()

    if isinstance(n_clusters, int):
        n_cluster_try = [n_clusters, n_clusters]
    elif isinstance(n_clusters, (list, tuple)):
        n_cluster_try = n_clusters
    else:
        n_cluster_try = [1, x.shape[1]]

    n_cluster_try = list(range(n_cluster_try[0], n_cluster_try[1] + 1))

    total_sum = torch.sum(score_mat)

    for n_cluster in n_cluster_try:

        clustering = AgglomerativeClustering(
            metric='precomputed',
            n_clusters=n_cluster,
            linkage='complete',
        ).fit(dist_np)

        labels = clustering.labels_

        groups = [list(np.where(labels == i)[0]) for i in range(n_cluster)]
        blocks = [torch.sum(score_mat[groups[i]][:, groups[i]]) for i in range(n_cluster)]
        block_sum = torch.sum(torch.stack(blocks))
        residual_sum = total_sum - block_sum
        if total_sum == 0:
            # no interaction anywhere: every partition is exact. Avoid 0/0 = NaN,
            # which would reject every partition.
            residual_ratio = torch.zeros_like(total_sum)
        else:
            residual_ratio = residual_sum / total_sum

        if verbose:
            print(f'n_group={n_cluster}, residual_ratio={residual_ratio}')

        # later (larger) accepted cluster counts overwrite earlier ones
        if residual_ratio < res_th:
            results['n_groups'] = n_cluster
            results['labels'] = list(labels)
            results['groups'] = groups

    if results.get('n_groups', 0) > 1:
        print(f'{mode} separability detected')
    else:
        print(f'{mode} separability not detected')

    return results
109
+
110
+
111
def batch_grad_normgrad(model, x, group, create_graph=False):
    '''
    Batched derivative of the normalized in-group gradient w.r.t. the out-group inputs.

    For every sample, the model gradient restricted to the columns in `group`
    is normalized to (almost) unit length. That normalized vector is then
    differentiated with respect to the inputs that are NOT in `group`.

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs, one sample per row
        group : list of int
            column indices whose gradient direction is examined
        create_graph : bool
            forwarded to torch.autograd.functional.jacobian for the outer derivative

    Returns:
    --------
        torch.Tensor indexed as [sample, in-group column, out-group column]
    '''
    idx_in = group
    idx_out = list(set(range(x.shape[1])) - set(group))

    def summed_unit_grad(inp):
        # the inner graph is always kept so the outer jacobian can differentiate it
        g = batch_jacobian(model, inp, create_graph=True)[:, idx_in]
        unit = g / (torch.norm(g, dim=1, keepdim=True) + 1e-6)
        # summing over samples gives one Jacobian that holds every sample's own
        # derivative (assumes the model treats rows independently -- TODO confirm)
        return unit.sum(dim=0)

    outer = torch.autograd.functional.jacobian(summed_unit_grad, x, create_graph=create_graph)
    return outer.permute(1, 0, 2)[:, :, idx_out]
127
+
128
+
129
def get_dependence(model, x, group):
    '''
    Score how much the gradient direction over `group` depends on the other inputs.

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs, one sample per row
        group : list of int
            column indices whose normalized gradient is examined

    Returns:
    --------
        torch.Tensor
            batch-median of |d(normalized grad over group) / d(other inputs)|,
            scaled by the standard deviations of both inputs involved
            (presumably shaped [len(group), n_inputs - len(group)])
    '''
    rest = list(set(range(x.shape[1])) - set(group))
    mixed = batch_grad_normgrad(model, x, group=group)
    spread = torch.std(x, dim=0)
    weighted = mixed * spread[None, group, None] * spread[None, None, rest]
    return torch.abs(weighted).median(dim=0)[0]
137
+
138
+ def test_symmetry(model, x, group, dependence_th=1e-3):
139
+ '''
140
+ detect function separability
141
+
142
+ Args:
143
+ -----
144
+ model : MultKAN, MLP or python function
145
+ x : 2D torch.float
146
+ inputs
147
+ group : a list of indices
148
+ dependence_th : float
149
+ threshold of dependence
150
+
151
+ Returns:
152
+ --------
153
+ bool
154
+
155
+ Example
156
+ -------
157
+ >>> from yms_kan.hypothesis import *
158
+ >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
159
+ >>> x = torch.normal(0,1,size=(100,3))
160
+ >>> print(test_symmetry(model, x, [1,2])) # True
161
+ >>> print(test_symmetry(model, x, [0,2])) # False
162
+ '''
163
+ if len(group) == x.shape[1] or len(group) == 0:
164
+ return True
165
+
166
+ dependence = get_dependence(model, x, group)
167
+ max_dependence = torch.max(dependence)
168
+ return max_dependence < dependence_th
169
+
170
+
171
def test_separability(model, x, groups, mode='add', threshold=1e-2, bias=0):
    '''
    test function separability against a given partition

    The interaction score is the batch-median of the absolute, std-normalized
    Hessian of the function (of log|model(x) + bias| for mode='mul'). The test
    passes when the score between every pair of distinct groups, and between
    the grouped variables and all remaining variables, stays below `threshold`.

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        groups : list of lists of int
            candidate partition of (some of) the input indices
        mode : str
            mode = 'add' or mode = 'mul'
        threshold : float
            threshold of the interaction score
        bias : float
            bias (for multiplicative separability)

    Returns:
    --------
        bool

    Raises:
    -------
        ValueError
            if mode is neither 'add' nor 'mul'

    Example
    -------
    >>> from yms_kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> print(test_separability(model, x, [[0],[1,2]], mode='mul')) # True
    >>> print(test_separability(model, x, [[0],[1,2]], mode='add')) # False
    '''
    if mode == 'add':
        hessian = batch_hessian(model, x)
    elif mode == 'mul':
        # compose(f, g, h)(x) == f(g(h(x))): here log(|model(x) + bias|)
        compose = lambda *F: reduce(lambda f, g: lambda x: f(g(x)), F)
        hessian = batch_hessian(compose(torch.log, torch.abs, lambda x: x + bias, model), x)
    else:
        raise ValueError(f"mode must be 'add' or 'mul', got {mode!r}")

    std = torch.std(x, dim=0)
    hessian_normalized = hessian * std[None, :] * std[:, None]
    score_mat = torch.median(torch.abs(hessian_normalized), dim=0)[0]

    def _decoupled(rows, cols):
        # an empty block imposes no constraint (and torch.max would raise on it)
        if len(rows) == 0 or len(cols) == 0:
            return True
        return bool(torch.max(score_mat[rows][:, cols]) < threshold)

    # internal test: distinct groups must not interact with each other
    n_groups = len(groups)
    for i in range(n_groups):
        for j in range(i + 1, n_groups):
            if not _decoupled(groups[i], groups[j]):
                return False

    # external test: grouped variables must not interact with the ungrouped ones
    group_id = [idx for grp in groups for idx in grp]
    nongroup_id = sorted(set(range(x.shape[1])) - set(group_id))
    return _decoupled(group_id, nongroup_id)
227
+
228
def test_general_separability(model, x, groups, threshold=1e-2):
    '''
    test generalized separability

    For every pair of variables (member_A, member_B) drawn from two different
    groups, the ratio of partial derivatives
        (d model / d x[member_B]) / (d model / d x[member_A])
    must itself pass test_separability(..., mode='mul') with the same groups.

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        groups : list of lists of int
            candidate partition of the input indices
        threshold : float
            threshold forwarded to test_separability

    Returns:
    --------
        bool

    Example
    -------
    >>> from yms_kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]**2+x[:,[2]]**2)**2
    >>> x = torch.normal(0,1,size=(100,3))
    >>> print(test_general_separability(model, x, [[1],[0,2]])) # False
    >>> print(test_general_separability(model, x, [[0],[1,2]])) # True
    '''
    n_groups = len(groups)
    for i in range(n_groups):
        for j in range(i + 1, n_groups):
            for member_A in groups[i]:
                for member_B in groups[j]:
                    # func is consumed immediately below, so capturing the
                    # loop variables by closure is safe here
                    def func(x):
                        grad = batch_jacobian(model, x, create_graph=True)
                        return grad[:, [member_B]] / grad[:, [member_A]]
                    # test if func is multiplicative separable; a single
                    # failing pair is enough to reject
                    if not test_separability(func, x, groups, mode='mul', threshold=threshold):
                        return False
    return True
276
+
277
+
278
def get_molecule(model, x, sym_th=1e-3, verbose=True):
    '''
    how variables are combined hierarchically

    Starting from one singleton "atom" per input column, atoms are greedily
    merged into "molecules" whenever test_symmetry accepts the merged index
    set. The molecules of one layer become the atoms of the next layer. This
    repeats until a single molecule remains or a layer makes no progress (in
    which case everything is collapsed into one final molecule).

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        sym_th : float
            threshold of symmetry (passed to test_symmetry as dependence_th)
        verbose : bool
            if True, print the current molecule and candidate atom at every step

    Returns:
    --------
        list
            one list of molecules (lists of input indices) per layer, from the
            singleton atoms up to a last layer holding a single molecule. Each
            lower layer is reordered to follow the order in which its molecules
            appear as consecutive slices of the layer above.

    Example
    -------
    >>> from yms_kan.hypothesis import *
    >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
    >>> x = torch.normal(0,1,size=(100,8))
    >>> get_molecule(model, x, verbose=False)
    [[[0], [1], [2], [3], [4], [5], [6], [7]],
     [[0, 1], [2, 3], [4, 5], [6, 7]],
     [[0, 1, 2, 3], [4, 5, 6, 7]],
     [[0, 1, 2, 3, 4, 5, 6, 7]]]
    '''
    n = x.shape[1]
    # every input column starts as its own atom
    atoms = [[i] for i in range(n)]
    molecules = []
    # moleculess[k] holds the molecules of layer k; layer 0 is the atoms
    moleculess = [copy.deepcopy(atoms)]
    # the "full" shortcut below can fire at most once over the whole run
    already_full = False
    n_layer = 0
    last_n_molecule = n

    while True:

        pointer = 0
        current_molecule = []
        remove_atoms = []
        # NOTE(review): n_atom is not reset when a molecule is closed inside the
        # loop below, so it counts atoms taken in the whole layer -- confirm intent
        n_atom = 0

        # greedily grow molecules from the atoms still unassigned in this layer
        while len(atoms) > 0:

            # assemble molecule
            atom = atoms[pointer]
            if verbose:
                print(current_molecule)
                print(atom)

            if len(current_molecule) == 0:
                # the first atom seeds a new molecule unconditionally
                full = False
                current_molecule += atom
                remove_atoms.append(atom)
                n_atom += 1
            else:
                # try assemble the atom to the molecule
                # if adding this atom would cover every input, close the molecule
                # instead and keep the atom as a separate molecule (presumably so
                # the top node keeps more than one child)
                if len(current_molecule+atom) == x.shape[1] and already_full == False and n_atom > 1 and n_layer > 0:
                    full = True
                    already_full = True
                else:
                    full = False
                    if test_symmetry(model, x, current_molecule+atom, dependence_th=sym_th):
                        current_molecule += atom
                        remove_atoms.append(atom)
                        n_atom += 1

            pointer += 1

            # close the molecule once every remaining atom has been tried
            if pointer == len(atoms) or full:
                molecules.append(current_molecule)
                if full:
                    molecules.append(atom)
                    remove_atoms.append(atom)
                # remove molecules from atoms
                for atom in remove_atoms:
                    atoms.remove(atom)
                current_molecule = []
                remove_atoms = []
                pointer = 0

        # if not making progress, terminate
        if len(molecules) == last_n_molecule:
            def flatten(xss):
                return [x for xs in xss for x in xs]
            # collapse everything into one final molecule
            moleculess.append([flatten(molecules)])
            break
        else:
            moleculess.append(copy.deepcopy(molecules))

        last_n_molecule = len(molecules)

        if len(molecules) == 1:
            break

        # this layer's molecules become the next layer's atoms
        atoms = molecules
        molecules = []

        n_layer += 1

    # sort
    # walk from the top layer down, rewriting layer l-1 so its molecules appear
    # in the order they occur as consecutive slices of layer l's molecules
    depth = len(moleculess) - 1

    for l in list(range(depth,0,-1)):

        molecules_sorted = []
        molecules_l = moleculess[l]
        molecules_lm1 = moleculess[l-1]

        for molecule_l in molecules_l:
            start = 0
            for i in range(1,len(molecule_l)+1):
                if molecule_l[start:i] in molecules_lm1:

                    molecules_sorted.append(molecule_l[start:i])
                    start = i

        moleculess[l-1] = molecules_sorted

    return moleculess
404
+
405
+
406
def get_tree_node(model, x, moleculess, sep_th=1e-2, skip_test=True):
    '''
    get tree nodes

    For each pair of consecutive layers in `moleculess`, every molecule of the
    upper layer is split greedily, left to right, into the shortest prefixes
    that are molecules of the lower layer. The number of pieces is that node's
    arity, and the node is labelled:
      'Id'  -- a single piece
      'GS'  -- general separability detected (only when tests run)
      'Add' / 'Mul' -- additive / multiplicative separability, checked only in
                       the last transition; 'Mul' wins when both hold
      ''    -- otherwise

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        moleculess : list
            output of get_molecule
        sep_th : float
            threshold of separability
        skip_test : bool
            if True, don't test the property of each module (to save time)

    Returns:
    --------
        arities : list (one per layer transition) of lists of ints
        properties : list (one per layer transition) of lists of strings

    Example
    -------
    >>> from yms_kan.hypothesis import *
    >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
    >>> x = torch.normal(0,1,size=(100,8))
    >>> moleculess = get_molecule(model, x, verbose=False)
    >>> get_tree_node(model, x, moleculess, skip_test=False)
    '''
    depth = len(moleculess) - 1
    arities, properties = [], []

    for layer in range(depth):
        children = moleculess[layer]
        parents = moleculess[layer + 1]
        layer_arities, layer_props = [], []

        for parent in parents:
            # cut the parent into consecutive slices that are child molecules
            pieces = []
            cut = 0
            for stop in range(1, len(parent) + 1):
                piece = parent[cut:stop]
                if piece in children:
                    pieces.append(piece)
                    cut = stop
            n_pieces = len(pieces)
            layer_arities.append(n_pieces)

            if n_pieces == 1:
                label = 'Id'
            else:
                label = ''
                if not skip_test and test_general_separability(model, x, pieces, threshold=sep_th):
                    label = 'GS'
                if layer == depth - 1 and not skip_test:
                    is_add = test_separability(model, x, pieces, mode='add', threshold=sep_th)
                    is_mul = test_separability(model, x, pieces, mode='mul', threshold=sep_th)
                    if is_mul:
                        label = 'Mul'
                    elif is_add:
                        label = 'Add'

            layer_props.append(label)

        arities.append(layer_arities)
        properties.append(layer_props)

    return arities, properties
487
+
488
+
489
def plot_tree(model, x, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False):
    '''
    get tree graph

    Builds the variable hierarchy with get_molecule, labels every node with
    get_tree_node, and draws the result with matplotlib (ending with plt.show()).

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        in_var : None, list of sympy symbols, or list of strings
            input variables used as leaf labels; defaults to x_1, ..., x_n
        style : str
            'tree' or 'box'
        sym_th : float
            threshold of symmetry
        sep_th : float
            threshold of separability
        skip_sep_test : bool
            if True, don't test the property of each module (to save time)
        verbose : bool
            forwarded to get_molecule

    Returns:
    --------
        None; the tree graph is drawn on the current matplotlib figure

    Example
    -------
    >>> from yms_kan.hypothesis import *
    >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
    >>> x = torch.normal(0,1,size=(100,8))
    >>> plot_tree(model, x)
    '''
    moleculess = get_molecule(model, x, sym_th=sym_th, verbose=verbose)
    arities, properties = get_tree_node(model, x, moleculess, sep_th=sep_th, skip_test=skip_sep_test)

    n = x.shape[1]

    # leaf labels: build symbols directly (exec-created locals are not reliably
    # visible to later code in a function, notably under PEP 667 / Python 3.13+)
    if in_var is None:
        in_vars = [sympy.Symbol(f'x_{ii}') for ii in range(1, n + 1)]
    elif isinstance(in_var[0], sympy.Symbol):
        in_vars = list(in_var)
    else:
        in_vars = [sympy.symbols(var_) for var_ in in_var]

    def flatten(xss):
        return [item for xs in xss for item in xs]

    def myrectangle(center_x, center_y, width_x, width_y):
        # outline of an axis-aligned rectangle given its center and size
        left, right = center_x - width_x/2, center_x + width_x/2
        bottom, top = center_y - width_y/2, center_y + width_y/2
        plt.plot([left, right], [top, top], color='k')  # up
        plt.plot([left, right], [bottom, bottom], color='k')  # down
        plt.plot([left, left], [bottom, top], color='k')  # left
        plt.plot([right, right], [bottom, top], color='k')  # right

    depth = len(moleculess)

    delta = 1/n       # horizontal spacing between leaves
    a = 0.3           # horizontal padding of a box, in units of delta
    b = 0.15          # half-height of a node
    y0 = 0.5          # vertical distance between layers
    width_y = 2*b     # node height (also used for the leaf connections below)

    last_centers = []

    # draw rectangles
    for l in range(depth-1):
        molecules = moleculess[l+1]
        centers = []
        # index into last_centers of the first child of the current molecule
        acc_arity = 0

        for i in range(len(molecules)):
            start_id = len(flatten(molecules[:i]))
            end_id = len(flatten(molecules[:i+1]))

            center_x = (start_id + (end_id - 1 - start_id)/2) * delta + delta/2
            center_y = (l+1/2)*y0
            width_x = (end_id - start_id - 1 + 2*a)*delta

            prop = properties[l][i]
            arity = arities[l][i]

            # add text (numbers) on rectangles
            if style == 'box':
                myrectangle(center_x, center_y, width_x, width_y)
                plt.text(center_x, center_y, prop, fontsize=15, horizontalalignment='center',
                         verticalalignment='center')
            elif style == 'tree':
                # if 'GS', no rectangle, n=arity tilted lines
                # if 'Id', no rectangle, one vertical line
                # if 'Add' or 'Mul', tilted lines plus "+" or "*"
                # if '', rectangle
                if prop in ('GS', 'Add', 'Mul'):
                    color = 'blue'
                    for j in range(arity):
                        if l == 0:
                            # children are single input variables
                            child_x = (start_id + j) * delta + delta/2
                        else:
                            child_x = last_centers[acc_arity + j]
                        plt.plot([child_x, center_x], [center_y - b, center_y + b], color=color)

                    if prop in ('Add', 'Mul'):
                        symbol = '+' if prop == 'Add' else '*'
                        plt.text(center_x, center_y + b, symbol, horizontalalignment='center',
                                 verticalalignment='center', color='red', fontsize=40)
                if prop == 'Id':
                    plt.plot([center_x, center_x], [center_y-width_y/2, center_y+width_y/2], color='black')

                if prop == '':
                    myrectangle(center_x, center_y, width_x, width_y)

            # every molecule consumes `arity` children, whatever its property,
            # so later siblings must index past them
            acc_arity += arity

            # connections to the next layer
            plt.plot([center_x, center_x], [center_y+width_y/2, center_y+y0-width_y/2], color='k')
            centers.append(center_x)
        last_centers = centers

    # connections from input variables to the first layer
    for i in range(n):
        x_ = (i + 1/2) * delta
        plt.plot([x_, x_], [0, y0/2-width_y/2], color='k')
        plt.text(x_, -0.05*(depth-1), f'${latex(in_vars[moleculess[0][i][0]])}$', fontsize=20, horizontalalignment='center')
    plt.xlim(0,1)
    plt.axis('off')
    plt.show()
632
+
633
+
634
def test_symmetry_var(model, x, input_vars, symmetry_var):
    '''
    test symmetry

    Compares, sample by sample, the model's gradient with the gradient of a
    candidate symmetry variable, restricted to the input columns whose symbols
    appear in `symmetry_var`. It prints the percentage of samples whose
    |cosine similarity| exceeds 0.9 and reports "suggesting symmetry" when that
    fraction is above 0.9.

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs; column i is fed to input_vars[i]
        input_vars : list of sympy symbols
        symmetry_var : sympy expression
            expression in (a subset of) input_vars

    Returns:
    --------
        cosine similarity
            per-sample absolute cosine similarity (torch tensor)

    Example
    -------
    >>> from yms_kan.hypothesis import *
    >>> from sympy import *
    >>> model = lambda x: x[:,[0]] * (x[:,[1]] + x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,8))
    >>> input_vars = a, b, c = symbols('a b c')
    >>> symmetry_var = b + c
    >>> test_symmetry_var(model, x, input_vars, symmetry_var);
    >>> symmetry_var = b * c
    >>> test_symmetry_var(model, x, input_vars, symmetry_var);
    '''
    orig_vars = input_vars
    sym_var = symmetry_var

    # gradients wrt to input (model)
    input_grad = batch_jacobian(model, x)

    # gradients wrt to input (symmetry var)
    # NOTE(review): lambdify targets numpy; arithmetic-only expressions work on
    # torch tensors, but numpy ufuncs (e.g. numpy.sin) may reject tensors that
    # require grad -- confirm for non-polynomial expressions.
    func = lambdify(orig_vars, sym_var, 'numpy')

    func2 = lambda x: func(*[x[:, [i]] for i in range(len(orig_vars))])
    sym_grad = batch_jacobian(func2, x)

    # columns of the input variables that appear in the symmetry expression
    idx = []
    for sym_symbol in sym_var.free_symbols:
        for j, orig_var in enumerate(orig_vars):
            if sym_symbol == orig_var:
                idx.append(j)

    input_grad_part = input_grad[:, idx]
    sym_grad_part = sym_grad[:, idx]

    dot = torch.sum(input_grad_part * sym_grad_part, dim=1)
    norms = torch.norm(input_grad_part, dim=1) * torch.norm(sym_grad_part, dim=1)
    cossim = torch.abs(dot / norms)

    # a plain float so the report prints a number rather than a tensor repr
    ratio = (torch.sum(cossim > 0.9) / len(cossim)).item()

    print(f'{100*ratio}% data have more than 0.9 cosine similarity')
    if ratio > 0.9:
        print('suggesting symmetry')
    else:
        print('not suggesting symmetry')

    return cossim