yms-kan 0.0.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,661 @@
+ import numpy as np
+ import torch
+ from sklearn.linear_model import LinearRegression
+ import sympy
+ import yaml
+ from sympy.utilities.lambdify import lambdify
+ import re
+
+ # sigmoid = sympy.Function('sigmoid')
+ # name: (torch implementation, sympy implementation, complexity, singularity protection)
+
+ # singularity protection functions: each suppresses the blow-up of the raw function
+ # near its singularity by clamping/linearizing it within a threshold of the singular
+ # point, so the protected output stays bounded by y_th
+ f_inv = lambda x, y_th: ((x_th := 1 / y_th),
+                          y_th / x_th * x * (torch.abs(x) < x_th)
+                          + torch.nan_to_num(1 / x) * (torch.abs(x) >= x_th))
+ f_inv2 = lambda x, y_th: ((x_th := 1 / y_th ** (1 / 2)),
+                           y_th * (torch.abs(x) < x_th)
+                           + torch.nan_to_num(1 / x ** 2) * (torch.abs(x) >= x_th))
+ f_inv3 = lambda x, y_th: ((x_th := 1 / y_th ** (1 / 3)),
+                           y_th / x_th * x * (torch.abs(x) < x_th)
+                           + torch.nan_to_num(1 / x ** 3) * (torch.abs(x) >= x_th))
+ f_inv4 = lambda x, y_th: ((x_th := 1 / y_th ** (1 / 4)),
+                           y_th * (torch.abs(x) < x_th)
+                           + torch.nan_to_num(1 / x ** 4) * (torch.abs(x) >= x_th))
+ f_inv5 = lambda x, y_th: ((x_th := 1 / y_th ** (1 / 5)),
+                           y_th / x_th * x * (torch.abs(x) < x_th)
+                           + torch.nan_to_num(1 / x ** 5) * (torch.abs(x) >= x_th))
+ f_sqrt = lambda x, y_th: ((x_th := 1 / y_th ** 2),
+                           x_th / y_th * x * (torch.abs(x) < x_th)
+                           + torch.nan_to_num(torch.sqrt(torch.abs(x)) * torch.sign(x)) * (torch.abs(x) >= x_th))
+ f_power1d5 = lambda x, y_th: torch.abs(x) ** 1.5
+ f_invsqrt = lambda x, y_th: ((x_th := 1 / y_th ** 2),
+                              y_th * (torch.abs(x) < x_th)
+                              + torch.nan_to_num(1 / torch.sqrt(torch.abs(x))) * (torch.abs(x) >= x_th))
+ f_log = lambda x, y_th: ((x_th := torch.e ** (-y_th)),
+                          - y_th * (torch.abs(x) < x_th)
+                          + torch.nan_to_num(torch.log(torch.abs(x))) * (torch.abs(x) >= x_th))
+ f_tan = lambda x, y_th: ((clip := x % torch.pi), (delta := torch.pi / 2 - torch.arctan(y_th)),
+                          - y_th / delta * (clip - torch.pi / 2) * (torch.abs(clip - torch.pi / 2) < delta)
+                          + torch.nan_to_num(torch.tan(clip)) * (torch.abs(clip - torch.pi / 2) >= delta))
+ f_arctanh = lambda x, y_th: ((delta := 1 - torch.tanh(y_th) + 1e-4),
+                              y_th * torch.sign(x) * (torch.abs(x) > 1 - delta)
+                              + torch.nan_to_num(torch.arctanh(x)) * (torch.abs(x) <= 1 - delta))
+ f_arcsin = lambda x, y_th: ((),
+                             torch.pi / 2 * torch.sign(x) * (torch.abs(x) > 1)
+                             + torch.nan_to_num(torch.arcsin(x)) * (torch.abs(x) <= 1))
+ f_arccos = lambda x, y_th: ((),
+                             torch.pi / 2 * (1 - torch.sign(x)) * (torch.abs(x) > 1)
+                             + torch.nan_to_num(torch.arccos(x)) * (torch.abs(x) <= 1))
+ f_exp = lambda x, y_th: ((x_th := torch.log(y_th)),
+                          y_th * (x > x_th) + torch.exp(x) * (x <= x_th))
+
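+ # Illustrative sanity check (not part of the released module): with y_th = 10,
+ # f_inv linearizes 1/x inside |x| < x_th = 1/y_th = 0.1 so the output is capped
+ # at |y| <= y_th, and evaluates 1/x exactly outside that band.
+ # >>> x_th, y = f_inv(torch.tensor([1e-6, 0.5, 2.0]), torch.tensor(10.0))
+ # >>> y  # approximately [1e-4, 2.0, 0.5]
+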
+ SYMBOLIC_LIB = {'x': (lambda x: x, lambda x: x, 1, lambda x, y_th: ((), x)),
+                 'x^2': (lambda x: x ** 2, lambda x: x ** 2, 2, lambda x, y_th: ((), x ** 2)),
+                 'x^3': (lambda x: x ** 3, lambda x: x ** 3, 3, lambda x, y_th: ((), x ** 3)),
+                 'x^4': (lambda x: x ** 4, lambda x: x ** 4, 3, lambda x, y_th: ((), x ** 4)),
+                 'x^5': (lambda x: x ** 5, lambda x: x ** 5, 3, lambda x, y_th: ((), x ** 5)),
+                 '1/x': (lambda x: 1 / x, lambda x: 1 / x, 2, f_inv),
+                 '1/x^2': (lambda x: 1 / x ** 2, lambda x: 1 / x ** 2, 2, f_inv2),
+                 '1/x^3': (lambda x: 1 / x ** 3, lambda x: 1 / x ** 3, 3, f_inv3),
+                 '1/x^4': (lambda x: 1 / x ** 4, lambda x: 1 / x ** 4, 4, f_inv4),
+                 '1/x^5': (lambda x: 1 / x ** 5, lambda x: 1 / x ** 5, 5, f_inv5),
+                 'sqrt': (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x), 2, f_sqrt),
+                 'x^0.5': (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x), 2, f_sqrt),
+                 'x^1.5': (lambda x: torch.sqrt(x) ** 3, lambda x: sympy.sqrt(x) ** 3, 4, f_power1d5),
+                 '1/sqrt(x)': (lambda x: 1 / torch.sqrt(x), lambda x: 1 / sympy.sqrt(x), 2, f_invsqrt),
+                 '1/x^0.5': (lambda x: 1 / torch.sqrt(x), lambda x: 1 / sympy.sqrt(x), 2, f_invsqrt),
+                 'exp': (lambda x: torch.exp(x), lambda x: sympy.exp(x), 2, f_exp),
+                 'log': (lambda x: torch.log(x), lambda x: sympy.log(x), 2, f_log),
+                 'abs': (lambda x: torch.abs(x), lambda x: sympy.Abs(x), 3, lambda x, y_th: ((), torch.abs(x))),
+                 'sin': (lambda x: torch.sin(x), lambda x: sympy.sin(x), 2, lambda x, y_th: ((), torch.sin(x))),
+                 'cos': (lambda x: torch.cos(x), lambda x: sympy.cos(x), 2, lambda x, y_th: ((), torch.cos(x))),
+                 'tan': (lambda x: torch.tan(x), lambda x: sympy.tan(x), 3, f_tan),
+                 'tanh': (lambda x: torch.tanh(x), lambda x: sympy.tanh(x), 3, lambda x, y_th: ((), torch.tanh(x))),
+                 'sgn': (lambda x: torch.sign(x), lambda x: sympy.sign(x), 3, lambda x, y_th: ((), torch.sign(x))),
+                 'arcsin': (lambda x: torch.arcsin(x), lambda x: sympy.asin(x), 4, f_arcsin),
+                 'arccos': (lambda x: torch.arccos(x), lambda x: sympy.acos(x), 4, f_arccos),
+                 'arctan': (lambda x: torch.arctan(x), lambda x: sympy.atan(x), 4, lambda x, y_th: ((), torch.arctan(x))),
+                 'arctanh': (lambda x: torch.arctanh(x), lambda x: sympy.atanh(x), 4, f_arctanh),
+                 '0': (lambda x: x * 0, lambda x: x * 0, 0, lambda x, y_th: ((), x * 0)),
+                 'gaussian': (lambda x: torch.exp(-x ** 2), lambda x: sympy.exp(-x ** 2), 3,
+                              lambda x, y_th: ((), torch.exp(-x ** 2))),
+                 # 'cosh': (lambda x: torch.cosh(x), lambda x: sympy.cosh(x), 5),
+                 # 'sigmoid': (lambda x: torch.sigmoid(x), sympy.Function('sigmoid'), 4),
+                 # 'relu': (lambda x: torch.relu(x), relu),
+                 }
+
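+ # Illustrative only: each entry unpacks into the four pieces named in the comment
+ # above the library.
+ # >>> torch_fn, sympy_fn, complexity, f_protected = SYMBOLIC_LIB['1/x']
+ # >>> torch_fn(torch.tensor(2.0))        # tensor(0.5000)
+ # >>> sympy_fn(sympy.Symbol('z'))        # 1/z
+
+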
+ def create_dataset(f,
+                    n_var=2,
+                    f_mode='col',
+                    ranges=[-1, 1],
+                    train_num=1000,
+                    test_num=1000,
+                    normalize_input=False,
+                    normalize_label=False,
+                    device='cpu',
+                    seed=0):
+     '''
+     create dataset
+
+     Args:
+     -----
+         f : function
+             the symbolic formula used to create the synthetic dataset
+         n_var : int
+             the number of input variables. Default: 2.
+         f_mode : str
+             if 'col', f receives inputs of shape (num, n_var); if 'row', f receives
+             the transposed inputs of shape (n_var, num). Default: 'col'.
+         ranges : list or np.array; shape (2,) or (n_var, 2)
+             the range of input variables. Default: [-1,1].
+         train_num : int
+             the number of training samples. Default: 1000.
+         test_num : int
+             the number of test samples. Default: 1000.
+         normalize_input : bool
+             If True, apply normalization to inputs. Default: False.
+         normalize_label : bool
+             If True, apply normalization to labels. Default: False.
+         device : str
+             device. Default: 'cpu'.
+         seed : int
+             random seed. Default: 0.
+
+     Returns:
+     --------
+         dataset : dict
+             Train/test inputs/labels are dataset['train_input'], dataset['train_label'],
+             dataset['test_input'], dataset['test_label']
+
+     Example
+     -------
+     >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
+     >>> dataset = create_dataset(f, n_var=2, train_num=100)
+     >>> dataset['train_input'].shape
+     torch.Size([100, 2])
+     '''
+
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+
+     if len(np.array(ranges).shape) == 1:
+         ranges = np.array(ranges * n_var).reshape(n_var, 2)
+     else:
+         ranges = np.array(ranges)
+
+     train_input = torch.zeros(train_num, n_var)
+     test_input = torch.zeros(test_num, n_var)
+     for i in range(n_var):
+         train_input[:, i] = torch.rand(train_num, ) * (ranges[i, 1] - ranges[i, 0]) + ranges[i, 0]
+         test_input[:, i] = torch.rand(test_num, ) * (ranges[i, 1] - ranges[i, 0]) + ranges[i, 0]
+
+     if f_mode == 'col':
+         train_label = f(train_input)
+         test_label = f(test_input)
+     elif f_mode == 'row':
+         train_label = f(train_input.T)
+         test_label = f(test_input.T)
+     else:
+         raise ValueError(f'f_mode {f_mode} not recognized')
+
+     # if labels have only 1 dimension, make them 2D
+     if len(train_label.shape) == 1:
+         train_label = train_label.unsqueeze(dim=1)
+         test_label = test_label.unsqueeze(dim=1)
+
+     def normalize(data, mean, std):
+         return (data - mean) / std
+
+     if normalize_input:
+         mean_input = torch.mean(train_input, dim=0, keepdim=True)
+         std_input = torch.std(train_input, dim=0, keepdim=True)
+         train_input = normalize(train_input, mean_input, std_input)
+         test_input = normalize(test_input, mean_input, std_input)
+
+     if normalize_label:
+         mean_label = torch.mean(train_label, dim=0, keepdim=True)
+         std_label = torch.std(train_label, dim=0, keepdim=True)
+         train_label = normalize(train_label, mean_label, std_label)
+         test_label = normalize(test_label, mean_label, std_label)
+
+     dataset = {}
+     dataset['train_input'] = train_input.to(device)
+     dataset['test_input'] = test_input.to(device)
+
+     dataset['train_label'] = train_label.to(device)
+     dataset['test_label'] = test_label.to(device)
+
+     return dataset
+
+
+ def fit_params(x, y, fun, a_range=(-10, 10), b_range=(-10, 10), grid_number=101, iteration=3, verbose=True,
+                device='cpu'):
+     '''
+     fit a, b, c, d such that
+
+     .. math::
+         |y - (c \\cdot fun(a x + b) + d)|^2
+
+     is minimized. Both x and y are 1D arrays. Sweep a and b on a grid, pick the best
+     (a, b) by R^2, then fit c and d by linear regression.
+
+     Args:
+     -----
+         x : 1D array
+             x values
+         y : 1D array
+             y values
+         fun : function
+             symbolic function
+         a_range : tuple
+             sweeping range of a
+         b_range : tuple
+             sweeping range of b
+         grid_number : int
+             number of steps along a and b
+         iteration : int
+             number of zooming in
+         verbose : bool
+             print extra information if True
+         device : str
+             device
+
+     Returns:
+     --------
+         a_best : float
+             best fitted a
+         b_best : float
+             best fitted b
+         c_best : float
+             best fitted c
+         d_best : float
+             best fitted d
+         r2_best : float
+             best r2 (coefficient of determination)
+
+     Example
+     -------
+     >>> num = 100
+     >>> x = torch.linspace(-1,1,steps=num)
+     >>> noises = torch.normal(0,1,(num,)) * 0.02
+     >>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises
+     >>> fit_params(x, y, torch.sin)
+     r2 is 0.9999727010726929
+     (tensor([2.9982, 1.9996, 5.0053, 0.7011]), tensor(1.0000))
+     '''
+     # fit a, b, c, d such that y = c * fun(a*x + b) + d; both x and y are 1D arrays.
+     # sweep a and b, choose the best fitted model
+     for _ in range(iteration):
+         a_ = torch.linspace(a_range[0], a_range[1], steps=grid_number, device=device)
+         b_ = torch.linspace(b_range[0], b_range[1], steps=grid_number, device=device)
+         a_grid, b_grid = torch.meshgrid(a_, b_, indexing='ij')
+         post_fun = fun(a_grid[None, :, :] * x[:, None, None] + b_grid[None, :, :])
+         x_mean = torch.mean(post_fun, dim=[0], keepdim=True)
+         y_mean = torch.mean(y, dim=[0], keepdim=True)
+         numerator = torch.sum((post_fun - x_mean) * (y - y_mean)[:, None, None], dim=0) ** 2
+         denominator = torch.sum((post_fun - x_mean) ** 2, dim=0) * torch.sum((y - y_mean)[:, None, None] ** 2, dim=0)
+         r2 = numerator / (denominator + 1e-4)
+         r2 = torch.nan_to_num(r2)
+
+         best_id = torch.argmax(r2)
+         a_id, b_id = torch.div(best_id, grid_number, rounding_mode='floor'), best_id % grid_number
+
+         # zoom into the neighborhood of the current optimum; widen at the boundary
+         if a_id == 0 or a_id == grid_number - 1 or b_id == 0 or b_id == grid_number - 1:
+             if _ == 0 and verbose:
+                 print('Best value at boundary.')
+             if a_id == 0:
+                 a_range = [a_[0], a_[1]]
+             if a_id == grid_number - 1:
+                 a_range = [a_[-2], a_[-1]]
+             if b_id == 0:
+                 b_range = [b_[0], b_[1]]
+             if b_id == grid_number - 1:
+                 b_range = [b_[-2], b_[-1]]
+
+         else:
+             a_range = [a_[a_id - 1], a_[a_id + 1]]
+             b_range = [b_[b_id - 1], b_[b_id + 1]]
+
+     a_best = a_[a_id]
+     b_best = b_[b_id]
+     post_fun = fun(a_best * x + b_best)
+     r2_best = r2[a_id, b_id]
+
+     if verbose:
+         print(f"r2 is {r2_best}")
+         if r2_best < 0.9:
+             print('r2 is not very high, please double check if you are choosing the correct symbolic function.')
+
+     post_fun = torch.nan_to_num(post_fun)
+     reg = LinearRegression().fit(post_fun[:, None].detach().cpu().numpy(), y.detach().cpu().numpy())
+     c_best = torch.from_numpy(reg.coef_)[0].to(device)
+     d_best = torch.from_numpy(np.array(reg.intercept_)).to(device)
+     return torch.stack([a_best, b_best, c_best, d_best]), r2_best
+
+
+ def sparse_mask(in_dim, out_dim):
+     '''
+     get a sparse connection mask of shape (in_dim, out_dim): every input node is
+     connected to its nearest output node and vice versa, so each node keeps at
+     least one connection
+     '''
+     in_coord = torch.arange(in_dim) * 1 / in_dim + 1 / (2 * in_dim)
+     out_coord = torch.arange(out_dim) * 1 / out_dim + 1 / (2 * out_dim)
+
+     dist_mat = torch.abs(out_coord[:, None] - in_coord[None, :])
+     in_nearest = torch.argmin(dist_mat, dim=0)
+     in_connection = torch.stack([torch.arange(in_dim), in_nearest]).permute(1, 0)
+     out_nearest = torch.argmin(dist_mat, dim=1)
+     out_connection = torch.stack([out_nearest, torch.arange(out_dim)]).permute(1, 0)
+     all_connection = torch.cat([in_connection, out_connection], dim=0)
+     mask = torch.zeros(in_dim, out_dim)
+     mask[all_connection[:, 0], all_connection[:, 1]] = 1.
+
+     return mask
+
+
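+ # Illustrative only: with 2 inputs and 4 outputs, each input links to its nearest
+ # output and each output to its nearest input, giving a sparse 0/1 mask.
+ # >>> sparse_mask(2, 4)
+ # tensor([[1., 1., 0., 0.],
+ #         [0., 0., 1., 1.]])
+
+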
+ def add_symbolic(name, fun, c=1, fun_singularity=None):
+     '''
+     add a symbolic function to the library
+
+     Args:
+     -----
+         name : str
+             name of the function
+         fun : fun
+             torch function or lambda function
+         c : int
+             complexity of the function. Default: 1.
+         fun_singularity : fun
+             singularity-protected version of fun. Defaults to fun itself.
+
+     Returns:
+     --------
+         None
+
+     Example
+     -------
+     >>> print(SYMBOLIC_LIB['Bessel'])
+     KeyError: 'Bessel'
+     >>> add_symbolic('Bessel', torch.special.bessel_j0)
+     >>> print(SYMBOLIC_LIB['Bessel'])
+     (<built-in function special_bessel_j0>, Bessel, 1, <built-in function special_bessel_j0>)
+     '''
+     globals()[name] = sympy.Function(name)
+     if fun_singularity is None:
+         fun_singularity = fun
+     SYMBOLIC_LIB[name] = (fun, globals()[name], c, fun_singularity)
+
+
+ def ex_round(ex1, n_digit):
+     '''
+     round the floats in a sympy expression to a given number of decimal places
+
+     Args:
+     -----
+         ex1 : sympy expression
+         n_digit : int
+             number of decimal places to keep
+
+     Returns:
+     --------
+         ex2 : sympy expression
+
+     Example
+     -------
+     >>> from kan.utils import *
+     >>> from sympy import *
+     >>> input_vars = a, b = symbols('a b')
+     >>> expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402
+     >>> ex_round(expression, 2)
+     '''
+     ex2 = ex1
+     for a in sympy.preorder_traversal(ex1):
+         if isinstance(a, sympy.Float):
+             ex2 = ex2.subs(a, round(a, n_digit))
+     return ex2
+
+
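+ # Illustrative only: the docstring example above rounds every float literal in the
+ # expression, yielding roughly 3.15*exp(b**2 + sin(pi*a)) - 2.32 (sympy may order
+ # the terms differently when printing).
+
+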
+ def augment_input(orig_vars, aux_vars, x):
+     '''
+     augment inputs with auxiliary variables computed from the original ones
+
+     Args:
+     -----
+         orig_vars : list of sympy symbols
+         aux_vars : list of auxiliary symbols (sympy expressions in orig_vars)
+         x : inputs (torch.Tensor) or dataset (dict)
+
+     Returns:
+     --------
+         augmented inputs
+
+     Example
+     -------
+     >>> from kan.utils import *
+     >>> from sympy import *
+     >>> orig_vars = a, b = symbols('a b')
+     >>> aux_vars = [a + b, a * b]
+     >>> x = torch.rand(100, 2)
+     >>> augment_input(orig_vars, aux_vars, x).shape
+     '''
+     # if x is a tensor
+     if isinstance(x, torch.Tensor):
+
+         aux_values = torch.tensor([]).to(x.device)
+
+         for aux_var in aux_vars:
+             func = lambdify(orig_vars, aux_var, 'numpy')  # returns a numpy-ready function
+             aux_value = torch.from_numpy(func(*[x[:, [i]].numpy() for i in range(len(orig_vars))]))
+             aux_values = torch.cat([aux_values, aux_value], dim=1)
+
+         x = torch.cat([aux_values, x], dim=1)
+
+     # if x is a dataset
+     elif isinstance(x, dict):
+         x['train_input'] = augment_input(orig_vars, aux_vars, x['train_input'])
+         x['test_input'] = augment_input(orig_vars, aux_vars, x['test_input'])
+
+     return x
+
+
+ def batch_jacobian(func, x, create_graph=False, mode='scalar'):
+     '''
+     jacobian of func with respect to batched inputs
+
+     Args:
+     -----
+         func : function or model
+         x : inputs, shape (Batch, Length)
+         create_graph : bool
+             keep the graph for higher-order derivatives
+         mode : str
+             'scalar' assumes func returns a single output column and returns a
+             (Batch, Length) jacobian; 'vector' keeps the output dimension and
+             returns shape (Batch, out_dim, Length)
+
+     Returns:
+     --------
+         jacobian
+
+     Example
+     -------
+     >>> from kan.utils import batch_jacobian
+     >>> x = torch.normal(0,1,size=(100,2))
+     >>> model = lambda x: x[:,[0]] + x[:,[1]]
+     >>> batch_jacobian(model, x)
+     '''
+
+     # x in shape (Batch, Length)
+     def _func_sum(x):
+         return func(x).sum(dim=0)
+
+     if mode == 'scalar':
+         return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0]
+     elif mode == 'vector':
+         return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1, 0, 2)
+
+
+ def batch_hessian(model, x, create_graph=False):
+     '''
+     hessian of model output with respect to batched inputs
+
+     Args:
+     -----
+         model : function or model
+             assumed to return a single output column
+         x : inputs, shape (Batch, Length)
+         create_graph : bool
+             keep the graph for higher-order derivatives
+
+     Returns:
+     --------
+         hessian, shape (Batch, Length, Length)
+
+     Example
+     -------
+     >>> from kan.utils import batch_hessian
+     >>> x = torch.normal(0,1,size=(100,2))
+     >>> model = lambda x: x[:,[0]]**2 + x[:,[1]]**2
+     >>> batch_hessian(model, x)
+     '''
+     # x in shape (Batch, Length)
+     jac = lambda x: batch_jacobian(model, x, create_graph=True)
+
+     def _jac_sum(x):
+         return jac(x).sum(dim=0)
+
+     return torch.autograd.functional.jacobian(_jac_sum, x, create_graph=create_graph).permute(1, 0, 2)
+
+
+ def create_from_data(inputs, labels, train_ratio=0.8, device='cpu'):
+     '''
+     create a dataset dictionary from data using a stratified train/test split:
+     the samples are split per class label so that each class keeps roughly
+     train_ratio of its samples in the training set
+     '''
+     from collections import defaultdict
+     class_indices = defaultdict(list)
+     for idx, label in enumerate(labels):
+         class_indices[label.item()].append(idx)
+     # initialize train/test index lists
+     train_id = []
+     test_id = []
+
+     # stratified sampling
+     for class_label, indices in class_indices.items():
+         num_samples = len(indices)
+         if num_samples == 0:
+             continue
+
+         # number of training samples for this class
+         train_size = int(num_samples * train_ratio)
+         if train_size == 0:
+             train_size = 1  # keep at least one training sample per class
+
+         # randomly pick the training samples
+         np.random.shuffle(indices)
+         train_subset = indices[:train_size]
+         test_subset = indices[train_size:]
+
+         train_id.extend(train_subset)
+         test_id.extend(test_subset)
+
+     # convert to numpy arrays and shuffle
+     train_id = np.array(train_id)
+     test_id = np.array(test_id)
+     np.random.shuffle(train_id)  # shuffle training indices
+     np.random.shuffle(test_id)  # shuffle test indices
+
+     # build the dataset dictionary
+     dataset = {
+         'train_input': inputs[train_id].detach().to(device),
+         'test_input': inputs[test_id].detach().to(device),
+         'train_label': labels[train_id].detach().to(device),
+         'test_label': labels[test_id].detach().to(device)
+     }
+
+     return dataset
+
+
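+ # Illustrative only: a stratified split keeps the class balance of `labels` in
+ # both subsets (assumes integer class labels in a 1D tensor).
+ # >>> inputs = torch.randn(10, 2)
+ # >>> labels = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
+ # >>> ds = create_from_data(inputs, labels, train_ratio=0.8)
+ # >>> ds['train_input'].shape   # torch.Size([8, 2]), 4 samples from each class
+
+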
+ def create_dataset_from_data(inputs, labels, train_ratio=0.8, device='cpu'):
+     """
+     create dataset from data
+
+     Args:
+     -----
+         inputs : 2D torch.float
+         labels : 2D torch.float
+         train_ratio : float
+             fraction of samples used for training
+         device : str
+
+     Returns:
+     --------
+         dataset (dictionary)
+
+     Example
+     -------
+     >>> from kan.utils import create_dataset_from_data
+     >>> x = torch.normal(0,1,size=(100,2))
+     >>> y = torch.normal(0,1,size=(100,1))
+     >>> dataset = create_dataset_from_data(x, y)
+     >>> dataset['train_input'].shape
+     """
+     num = inputs.shape[0]
+     train_id = np.random.choice(num, int(num * train_ratio), replace=False)
+     test_id = list(set(np.arange(num)) - set(train_id))
+     dataset = {'train_input': inputs[train_id].detach().to(device), 'test_input': inputs[test_id].detach().to(device),
+                'train_label': labels[train_id].detach().to(device), 'test_label': labels[test_id].detach().to(device)}
+
+     return dataset
+
+
+ def get_derivative(model, inputs, labels, derivative='hessian', loss_mode='pred', reg_metric='w', lamb=0., lamb_l1=1.,
+                    lamb_entropy=0.):
+     '''
+     compute the jacobian/hessian of the loss w.r.t. the model parameters
+
+     Args:
+     -----
+         model : the model (assumed to expose .copy(), .get_reg and .state_dict())
+         inputs : 2D torch.float
+         labels : 2D torch.float
+         derivative : str
+             'jacobian' or 'hessian'
+         loss_mode : str
+             'pred' (prediction loss), 'reg' (regularization loss) or 'all' (both)
+         reg_metric, lamb, lamb_l1, lamb_entropy :
+             regularization settings forwarded to model.get_reg
+
+     Returns:
+     --------
+         jacobian or hessian of the loss w.r.t. the flattened parameter vector
+     '''
+
+     def get_mapping(model):
+
+         mapping = {}
+         name = 'model1'
+
+         keys = list(model.state_dict().keys())
+         for key in keys:
+
+             y = re.findall(r"\.[0-9]+", key)
+             if len(y) > 0:
+                 y = y[0][1:]
+                 x = re.split(r"\.[0-9]+", key)
+                 mapping[key] = name + '.' + x[0] + '[' + y + ']' + x[1]
+
+             y = re.findall(r"_[0-9]+", key)
+             if len(y) > 0:
+                 y = y[0][1:]
+                 x = re.split(r"_[0-9]+", key)
+                 mapping[key] = name + '.' + x[0] + '[' + y + ']'
+
+         return mapping
+
+     # model1 = copy.deepcopy(model)
+     model1 = model.copy()
+     mapping = get_mapping(model)
+
+     # collect keys and shapes
+     keys = list(model.state_dict().keys())
+     shapes = []
+
+     for params in model.parameters():
+         shapes.append(params.shape)
+
+     # turn a flattened vector into model params
+     def param2statedict(p, keys, shapes):
+
+         new_state_dict = {}
+
+         start = 0
+         n_group = len(keys)
+         for i in range(n_group):
+             shape = shapes[i]
+             n_params = torch.prod(torch.tensor(shape))
+             new_state_dict[keys[i]] = p[start:start + n_params].reshape(shape)
+             start += n_params
+
+         return new_state_dict
+
+     def differentiable_load_state_dict(mapping, state_dict, model1):
+
+         for key in keys:
+             if mapping[key][-1] != ']':
+                 exec(f"del {mapping[key]}")
+             exec(f"{mapping[key]} = state_dict[key]")
+
+     # input: p, output: loss
+     def get_param2loss_fun(inputs, labels):
+
+         def param2loss_fun(p):
+
+             p = p[0]
+             state_dict = param2statedict(p, keys, shapes)
+             # model.load_state_dict(state_dict) would be non-differentiable,
+             # so the parameters are assigned as plain attributes instead
+             differentiable_load_state_dict(mapping, state_dict, model1)
+             if loss_mode == 'pred':
+                 pred_loss = torch.mean((model1(inputs) - labels) ** 2, dim=(0, 1), keepdim=True)
+                 loss = pred_loss
+             elif loss_mode == 'reg':
+                 reg_loss = model1.get_reg(reg_metric=reg_metric, lamb_l1=lamb_l1,
+                                           lamb_entropy=lamb_entropy) * torch.ones(1, 1)
+                 loss = reg_loss
+             elif loss_mode == 'all':
+                 pred_loss = torch.mean((model1(inputs) - labels) ** 2, dim=(0, 1), keepdim=True)
+                 reg_loss = model1.get_reg(reg_metric=reg_metric, lamb_l1=lamb_l1,
+                                           lamb_entropy=lamb_entropy) * torch.ones(1, 1)
+                 loss = pred_loss + lamb * reg_loss
+             return loss
+
+         return param2loss_fun
+
+     fun = get_param2loss_fun(inputs, labels)
+     p = model2param(model)[None, :]
+     if derivative == 'hessian':
+         result = batch_hessian(fun, p)
+     elif derivative == 'jacobian':
+         result = batch_jacobian(fun, p)
+     return result
+
+
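+ # Illustrative only (assumes a yms-kan model exposing .copy(), .get_reg and
+ # .device, as the code above requires):
+ # >>> H = get_derivative(model, dataset['train_input'], dataset['train_label'],
+ # ...                    derivative='hessian', loss_mode='pred')
+ # H is the hessian of the training loss w.r.t. the flattened parameter vector,
+ # square in the total number of model parameters.
+
+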
+ def model2param(model):
+     '''
+     turn model parameters into a flattened vector
+     '''
+     p = torch.tensor([]).to(model.device)
+     for params in model.parameters():
+         p = torch.cat([p, params.reshape(-1, )], dim=0)
+     return p
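+
+ # Illustrative only: the flattened vector has one entry per model parameter.
+ # >>> p = model2param(model)
+ # >>> p.shape[0] == sum(q.numel() for q in model.parameters())   # True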
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+