yms-kan 0.0.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yms_kan-0.0.7/LICENSE +21 -0
- yms_kan-0.0.7/PKG-INFO +18 -0
- yms_kan-0.0.7/README.md +1 -0
- yms_kan-0.0.7/kan/KANLayer.py +364 -0
- yms_kan-0.0.7/kan/LBFGS.py +492 -0
- yms_kan-0.0.7/kan/MLP.py +361 -0
- yms_kan-0.0.7/kan/MultKAN.py +3087 -0
- yms_kan-0.0.7/kan/Symbolic_KANLayer.py +270 -0
- yms_kan-0.0.7/kan/__init__.py +3 -0
- yms_kan-0.0.7/kan/compiler.py +498 -0
- yms_kan-0.0.7/kan/dataset.py +27 -0
- yms_kan-0.0.7/kan/experiment.py +50 -0
- yms_kan-0.0.7/kan/feynman.py +739 -0
- yms_kan-0.0.7/kan/hypothesis.py +695 -0
- yms_kan-0.0.7/kan/spline.py +144 -0
- yms_kan-0.0.7/kan/utils.py +661 -0
- yms_kan-0.0.7/setup.cfg +4 -0
- yms_kan-0.0.7/setup.py +96 -0
- yms_kan-0.0.7/yms_kan.egg-info/PKG-INFO +18 -0
- yms_kan-0.0.7/yms_kan.egg-info/SOURCES.txt +20 -0
- yms_kan-0.0.7/yms_kan.egg-info/dependency_links.txt +1 -0
- yms_kan-0.0.7/yms_kan.egg-info/top_level.txt +1 -0
yms_kan-0.0.7/kan/utils.py
@@ -0,0 +1,661 @@

```python
import numpy as np
import torch
from sklearn.linear_model import LinearRegression
import sympy
import yaml
from sympy.utilities.lambdify import lambdify
import re

# sigmoid = sympy.Function('sigmoid')
# name: (torch implementation, sympy implementation)

# singularity protection functions: each lambda takes (x, y_th) and (except for
# f_power1d5) returns a tuple whose last element is the protected output; near
# the singular point the function is replaced by a bounded ramp or constant,
# and torch.nan_to_num suppresses inf/nan elsewhere
f_inv = lambda x, y_th: (
    (x_th := 1 / y_th),
    y_th / x_th * x * (torch.abs(x) < x_th) + torch.nan_to_num(1 / x) * (torch.abs(x) >= x_th))
f_inv2 = lambda x, y_th: (
    (x_th := 1 / y_th ** (1 / 2)),
    y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1 / x ** 2) * (torch.abs(x) >= x_th))
f_inv3 = lambda x, y_th: (
    (x_th := 1 / y_th ** (1 / 3)),
    y_th / x_th * x * (torch.abs(x) < x_th) + torch.nan_to_num(1 / x ** 3) * (torch.abs(x) >= x_th))
f_inv4 = lambda x, y_th: (
    (x_th := 1 / y_th ** (1 / 4)),
    y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1 / x ** 4) * (torch.abs(x) >= x_th))
f_inv5 = lambda x, y_th: (
    (x_th := 1 / y_th ** (1 / 5)),
    y_th / x_th * x * (torch.abs(x) < x_th) + torch.nan_to_num(1 / x ** 5) * (torch.abs(x) >= x_th))
f_sqrt = lambda x, y_th: (
    (x_th := 1 / y_th ** 2),
    x_th / y_th * x * (torch.abs(x) < x_th) + torch.nan_to_num(
        torch.sqrt(torch.abs(x)) * torch.sign(x)) * (torch.abs(x) >= x_th))
f_power1d5 = lambda x, y_th: torch.abs(x) ** 1.5
f_invsqrt = lambda x, y_th: (
    (x_th := 1 / y_th ** 2),
    y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1 / torch.sqrt(torch.abs(x))) * (torch.abs(x) >= x_th))
f_log = lambda x, y_th: (
    (x_th := torch.e ** (-y_th)),
    - y_th * (torch.abs(x) < x_th) + torch.nan_to_num(torch.log(torch.abs(x))) * (torch.abs(x) >= x_th))
f_tan = lambda x, y_th: (
    (clip := x % torch.pi), (delta := torch.pi / 2 - torch.arctan(y_th)),
    - y_th / delta * (clip - torch.pi / 2) * (torch.abs(clip - torch.pi / 2) < delta)
    + torch.nan_to_num(torch.tan(clip)) * (torch.abs(clip - torch.pi / 2) >= delta))
f_arctanh = lambda x, y_th: (
    (delta := 1 - torch.tanh(y_th) + 1e-4),
    y_th * torch.sign(x) * (torch.abs(x) > 1 - delta) + torch.nan_to_num(torch.arctanh(x)) * (torch.abs(x) <= 1 - delta))
f_arcsin = lambda x, y_th: (
    (), torch.pi / 2 * torch.sign(x) * (torch.abs(x) > 1) + torch.nan_to_num(torch.arcsin(x)) * (torch.abs(x) <= 1))
f_arccos = lambda x, y_th: (
    (), torch.pi / 2 * (1 - torch.sign(x)) * (torch.abs(x) > 1) + torch.nan_to_num(torch.arccos(x)) * (torch.abs(x) <= 1))
f_exp = lambda x, y_th: ((x_th := torch.log(y_th)), y_th * (x > x_th) + torch.exp(x) * (x <= x_th))
```
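The protectors can be called directly. A minimal sketch (not part of the package), assuming the module above is in scope, showing how `f_log` caps the blow-up of `log` near zero at `-y_th` while leaving regular points untouched:

```python
import torch

x = torch.tensor([-0.5, 1e-9, 0.5])   # the middle point sits next to log's singularity
y_th = torch.tensor(10.0)              # bound on the output magnitude near the singularity

protected = f_log(x, y_th)[-1]         # the protected values are the last tuple element
raw = torch.log(torch.abs(x))

print(protected)                       # tensor([ -0.6931, -10.0000,  -0.6931])
print(raw)                             # tensor([ -0.6931, -20.7233,  -0.6931])
```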
```python
# each SYMBOLIC_LIB entry is a 4-tuple:
# (torch implementation, sympy implementation, complexity c, singularity-protected implementation)
SYMBOLIC_LIB = {'x': (lambda x: x, lambda x: x, 1, lambda x, y_th: ((), x)),
                'x^2': (lambda x: x ** 2, lambda x: x ** 2, 2, lambda x, y_th: ((), x ** 2)),
                'x^3': (lambda x: x ** 3, lambda x: x ** 3, 3, lambda x, y_th: ((), x ** 3)),
                'x^4': (lambda x: x ** 4, lambda x: x ** 4, 3, lambda x, y_th: ((), x ** 4)),
                'x^5': (lambda x: x ** 5, lambda x: x ** 5, 3, lambda x, y_th: ((), x ** 5)),
                '1/x': (lambda x: 1 / x, lambda x: 1 / x, 2, f_inv),
                '1/x^2': (lambda x: 1 / x ** 2, lambda x: 1 / x ** 2, 2, f_inv2),
                '1/x^3': (lambda x: 1 / x ** 3, lambda x: 1 / x ** 3, 3, f_inv3),
                '1/x^4': (lambda x: 1 / x ** 4, lambda x: 1 / x ** 4, 4, f_inv4),
                '1/x^5': (lambda x: 1 / x ** 5, lambda x: 1 / x ** 5, 5, f_inv5),
                'sqrt': (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x), 2, f_sqrt),
                'x^0.5': (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x), 2, f_sqrt),
                'x^1.5': (lambda x: torch.sqrt(x) ** 3, lambda x: sympy.sqrt(x) ** 3, 4, f_power1d5),
                '1/sqrt(x)': (lambda x: 1 / torch.sqrt(x), lambda x: 1 / sympy.sqrt(x), 2, f_invsqrt),
                '1/x^0.5': (lambda x: 1 / torch.sqrt(x), lambda x: 1 / sympy.sqrt(x), 2, f_invsqrt),
                'exp': (lambda x: torch.exp(x), lambda x: sympy.exp(x), 2, f_exp),
                'log': (lambda x: torch.log(x), lambda x: sympy.log(x), 2, f_log),
                'abs': (lambda x: torch.abs(x), lambda x: sympy.Abs(x), 3, lambda x, y_th: ((), torch.abs(x))),
                'sin': (lambda x: torch.sin(x), lambda x: sympy.sin(x), 2, lambda x, y_th: ((), torch.sin(x))),
                'cos': (lambda x: torch.cos(x), lambda x: sympy.cos(x), 2, lambda x, y_th: ((), torch.cos(x))),
                'tan': (lambda x: torch.tan(x), lambda x: sympy.tan(x), 3, f_tan),
                'tanh': (lambda x: torch.tanh(x), lambda x: sympy.tanh(x), 3, lambda x, y_th: ((), torch.tanh(x))),
                'sgn': (lambda x: torch.sign(x), lambda x: sympy.sign(x), 3, lambda x, y_th: ((), torch.sign(x))),
                'arcsin': (lambda x: torch.arcsin(x), lambda x: sympy.asin(x), 4, f_arcsin),
                'arccos': (lambda x: torch.arccos(x), lambda x: sympy.acos(x), 4, f_arccos),
                'arctan': (lambda x: torch.arctan(x), lambda x: sympy.atan(x), 4, lambda x, y_th: ((), torch.arctan(x))),
                'arctanh': (lambda x: torch.arctanh(x), lambda x: sympy.atanh(x), 4, f_arctanh),
                '0': (lambda x: x * 0, lambda x: x * 0, 0, lambda x, y_th: ((), x * 0)),
                'gaussian': (lambda x: torch.exp(-x ** 2), lambda x: sympy.exp(-x ** 2), 3,
                             lambda x, y_th: ((), torch.exp(-x ** 2))),
                # 'cosh': (lambda x: torch.cosh(x), lambda x: sympy.cosh(x), 5),
                # 'sigmoid': (lambda x: torch.sigmoid(x), sympy.Function('sigmoid'), 4),
                # 'relu': (lambda x: torch.relu(x), relu),
                }
```
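A minimal sketch (not part of the package) unpacking the `'1/x'` entry of `SYMBOLIC_LIB` into its four slots:

```python
import torch

fun, fun_sympy, c, fun_avoid_singularity = SYMBOLIC_LIB['1/x']

x = torch.tensor([1e-6, 2.0])
print(fun(x))                                              # tensor([1.0000e+06, 5.0000e-01])
print(fun_avoid_singularity(x, torch.tensor(100.0))[-1])   # tensor([0.0100, 0.5000]), ramped near 0
print(c)                                                   # 2 (complexity score)
```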
```python
def create_dataset(f,
                   n_var=2,
                   f_mode='col',
                   ranges=[-1, 1],
                   train_num=1000,
                   test_num=1000,
                   normalize_input=False,
                   normalize_label=False,
                   device='cpu',
                   seed=0):
    '''
    create dataset

    Args:
    -----
        f : function
            the symbolic formula used to create the synthetic dataset
        n_var : int
            the number of input variables. Default: 2.
        f_mode : str
            'col' if f takes column-wise inputs x[:, [i]], 'row' if it takes row-wise inputs x[i, :]. Default: 'col'.
        ranges : list or np.array; shape (2,) or (n_var, 2)
            the range of input variables. Default: [-1, 1].
        train_num : int
            the number of training samples. Default: 1000.
        test_num : int
            the number of test samples. Default: 1000.
        normalize_input : bool
            If True, apply normalization to inputs. Default: False.
        normalize_label : bool
            If True, apply normalization to labels. Default: False.
        device : str
            device. Default: 'cpu'.
        seed : int
            random seed. Default: 0.

    Returns:
    --------
        dataset : dict
            Train/test inputs/labels are dataset['train_input'], dataset['train_label'],
            dataset['test_input'], dataset['test_label']

    Example
    -------
    >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
    >>> dataset = create_dataset(f, n_var=2, train_num=100)
    >>> dataset['train_input'].shape
    torch.Size([100, 2])
    '''
    np.random.seed(seed)
    torch.manual_seed(seed)

    if len(np.array(ranges).shape) == 1:
        ranges = np.array(ranges * n_var).reshape(n_var, 2)
    else:
        ranges = np.array(ranges)

    train_input = torch.zeros(train_num, n_var)
    test_input = torch.zeros(test_num, n_var)
    for i in range(n_var):
        train_input[:, i] = torch.rand(train_num, ) * (ranges[i, 1] - ranges[i, 0]) + ranges[i, 0]
        test_input[:, i] = torch.rand(test_num, ) * (ranges[i, 1] - ranges[i, 0]) + ranges[i, 0]

    if f_mode == 'col':
        train_label = f(train_input)
        test_label = f(test_input)
    elif f_mode == 'row':
        train_label = f(train_input.T)
        test_label = f(test_input.T)
    else:
        # fail loudly: continuing without labels would raise a NameError below
        raise ValueError(f'f_mode {f_mode} not recognized')

    # if the labels have only 1 dimension, add a feature axis
    if len(train_label.shape) == 1:
        train_label = train_label.unsqueeze(dim=1)
        test_label = test_label.unsqueeze(dim=1)

    def normalize(data, mean, std):
        return (data - mean) / std

    if normalize_input:
        mean_input = torch.mean(train_input, dim=0, keepdim=True)
        std_input = torch.std(train_input, dim=0, keepdim=True)
        train_input = normalize(train_input, mean_input, std_input)
        test_input = normalize(test_input, mean_input, std_input)

    if normalize_label:
        mean_label = torch.mean(train_label, dim=0, keepdim=True)
        std_label = torch.std(train_label, dim=0, keepdim=True)
        train_label = normalize(train_label, mean_label, std_label)
        test_label = normalize(test_label, mean_label, std_label)

    dataset = {}
    dataset['train_input'] = train_input.to(device)
    dataset['test_input'] = test_input.to(device)
    dataset['train_label'] = train_label.to(device)
    dataset['test_label'] = test_label.to(device)

    return dataset


def fit_params(x, y, fun, a_range=(-10, 10), b_range=(-10, 10), grid_number=101, iteration=3, verbose=True,
               device='cpu'):
    '''
    fit a, b, c, d such that

    .. math::
        |y - (c f(ax + b) + d)|^2

    is minimized. Both x and y are 1D arrays. Sweep a and b, and find the best fitted model.

    Args:
    -----
        x : 1D array
            x values
        y : 1D array
            y values
        fun : function
            symbolic function
        a_range : tuple
            sweeping range of a
        b_range : tuple
            sweeping range of b
        grid_number : int
            number of steps along a and b
        iteration : int
            number of zooming in
        verbose : bool
            print extra information if True
        device : str
            device

    Returns:
    --------
        a_best : float
            best fitted a
        b_best : float
            best fitted b
        c_best : float
            best fitted c
        d_best : float
            best fitted d
        r2_best : float
            best r2 (coefficient of determination)

    Example
    -------
    >>> num = 100
    >>> x = torch.linspace(-1,1,steps=num)
    >>> noises = torch.normal(0,1,(num,)) * 0.02
    >>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises
    >>> fit_params(x, y, torch.sin)
    r2 is 0.9999727010726929
    (tensor([2.9982, 1.9996, 5.0053, 0.7011]), tensor(1.0000))
    '''
    # fit a, b, c, d such that y = c * fun(a * x + b) + d; both x and y are 1D arrays.
    # sweep a and b, choose the best fitted model
    for _ in range(iteration):
        a_ = torch.linspace(a_range[0], a_range[1], steps=grid_number, device=device)
        b_ = torch.linspace(b_range[0], b_range[1], steps=grid_number, device=device)
        a_grid, b_grid = torch.meshgrid(a_, b_, indexing='ij')
        post_fun = fun(a_grid[None, :, :] * x[:, None, None] + b_grid[None, :, :])
        x_mean = torch.mean(post_fun, dim=[0], keepdim=True)
        y_mean = torch.mean(y, dim=[0], keepdim=True)
        numerator = torch.sum((post_fun - x_mean) * (y - y_mean)[:, None, None], dim=0) ** 2
        denominator = torch.sum((post_fun - x_mean) ** 2, dim=0) * torch.sum((y - y_mean)[:, None, None] ** 2, dim=0)
        r2 = numerator / (denominator + 1e-4)
        r2 = torch.nan_to_num(r2)

        best_id = torch.argmax(r2)
        a_id, b_id = torch.div(best_id, grid_number, rounding_mode='floor'), best_id % grid_number

        if a_id == 0 or a_id == grid_number - 1 or b_id == 0 or b_id == grid_number - 1:
            if _ == 0 and verbose:
                print('Best value at boundary.')
            if a_id == 0:
                a_range = [a_[0], a_[1]]
            if a_id == grid_number - 1:
                a_range = [a_[-2], a_[-1]]
            if b_id == 0:
                b_range = [b_[0], b_[1]]
            if b_id == grid_number - 1:
                b_range = [b_[-2], b_[-1]]
        else:
            a_range = [a_[a_id - 1], a_[a_id + 1]]
            b_range = [b_[b_id - 1], b_[b_id + 1]]

    a_best = a_[a_id]
    b_best = b_[b_id]
    post_fun = fun(a_best * x + b_best)
    r2_best = r2[a_id, b_id]

    if verbose:
        print(f"r2 is {r2_best}")
        if r2_best < 0.9:
            print(f'r2 is not very high, please double check if you are choosing the correct symbolic function.')

    post_fun = torch.nan_to_num(post_fun)
    reg = LinearRegression().fit(post_fun[:, None].detach().cpu().numpy(), y.detach().cpu().numpy())
    c_best = torch.from_numpy(reg.coef_)[0].to(device)
    d_best = torch.from_numpy(np.array(reg.intercept_)).to(device)
    return torch.stack([a_best, b_best, c_best, d_best]), r2_best
```
```python
def sparse_mask(in_dim, out_dim):
    '''
    get a sparse connection mask of shape (in_dim, out_dim): each input node is
    connected to its nearest output node and vice versa, so every node keeps at
    least one connection
    '''
    in_coord = torch.arange(in_dim) * 1 / in_dim + 1 / (2 * in_dim)
    out_coord = torch.arange(out_dim) * 1 / out_dim + 1 / (2 * out_dim)

    dist_mat = torch.abs(out_coord[:, None] - in_coord[None, :])
    in_nearest = torch.argmin(dist_mat, dim=0)
    in_connection = torch.stack([torch.arange(in_dim), in_nearest]).permute(1, 0)
    out_nearest = torch.argmin(dist_mat, dim=1)
    out_connection = torch.stack([out_nearest, torch.arange(out_dim)]).permute(1, 0)
    all_connection = torch.cat([in_connection, out_connection], dim=0)
    mask = torch.zeros(in_dim, out_dim)
    mask[all_connection[:, 0], all_connection[:, 1]] = 1.

    return mask
```
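A minimal sketch (not part of the package) of the resulting mask for a 5-input, 3-output layer; every row (input) and every column (output) has at least one connection:

```python
mask = sparse_mask(5, 3)
print(mask)
# tensor([[1., 0., 0.],
#         [1., 0., 0.],
#         [0., 1., 0.],
#         [0., 0., 1.],
#         [0., 0., 1.]])
```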
```python
def add_symbolic(name, fun, c=1, fun_singularity=None):
    '''
    add a symbolic function to the library

    Args:
    -----
        name : str
            name of the function
        fun : fun
            torch function or lambda function
        c : int
            complexity of the function. Default: 1.
        fun_singularity : fun
            singularity-protected version of fun. Default: None (falls back to fun).

    Returns:
    --------
        None

    Example
    -------
    >>> print(SYMBOLIC_LIB['Bessel'])
    KeyError: 'Bessel'
    >>> add_symbolic('Bessel', torch.special.bessel_j0)
    >>> print(SYMBOLIC_LIB['Bessel'])
    (<built-in function special_bessel_j0>, Bessel)
    '''
    # register a sympy symbol of the same name in the module globals
    exec(f"globals()['{name}'] = sympy.Function('{name}')")
    if fun_singularity is None:
        fun_singularity = fun
    SYMBOLIC_LIB[name] = (fun, globals()[name], c, fun_singularity)


def ex_round(ex1, n_digit):
    '''
    round the numbers in an expression to a certain number of floating-point digits

    Args:
    -----
        ex1 : sympy expression
        n_digit : int

    Returns:
    --------
        ex2 : sympy expression

    Example
    -------
    >>> from kan.utils import *
    >>> from sympy import *
    >>> input_vars = a, b = symbols('a b')
    >>> expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402
    >>> ex_round(expression, 2)
    '''
    ex2 = ex1
    for a in sympy.preorder_traversal(ex1):
        if isinstance(a, sympy.Float):
            ex2 = ex2.subs(a, round(a, n_digit))
    return ex2


def augment_input(orig_vars, aux_vars, x):
    '''
    augment inputs

    Args:
    -----
        orig_vars : list of sympy symbols
        aux_vars : list of auxiliary symbols
        x : inputs

    Returns:
    --------
        augmented inputs

    Example
    -------
    >>> from kan.utils import *
    >>> from sympy import *
    >>> orig_vars = a, b = symbols('a b')
    >>> aux_vars = [a + b, a * b]
    >>> x = torch.rand(100, 2)
    >>> augment_input(orig_vars, aux_vars, x).shape
    '''
    # if x is a tensor
    if isinstance(x, torch.Tensor):
        aux_values = torch.tensor([]).to(x.device)

        for aux_var in aux_vars:
            func = lambdify(orig_vars, aux_var, 'numpy')  # returns a numpy-ready function
            aux_value = torch.from_numpy(func(*[x[:, [i]].numpy() for i in range(len(orig_vars))]))
            aux_values = torch.cat([aux_values, aux_value], dim=1)

        x = torch.cat([aux_values, x], dim=1)

    # if x is a dataset
    elif isinstance(x, dict):
        x['train_input'] = augment_input(orig_vars, aux_vars, x['train_input'])
        x['test_input'] = augment_input(orig_vars, aux_vars, x['test_input'])

    return x
```
```python
def batch_jacobian(func, x, create_graph=False, mode='scalar'):
    '''
    jacobian

    Args:
    -----
        func : function or model
        x : inputs
        create_graph : bool
        mode : str
            'scalar' for models with a single output column, 'vector' for vector-valued outputs

    Returns:
    --------
        jacobian

    Example
    -------
    >>> from kan.utils import batch_jacobian
    >>> x = torch.normal(0,1,size=(100,2))
    >>> model = lambda x: x[:,[0]] + x[:,[1]]
    >>> batch_jacobian(model, x)
    '''
    # x in shape (Batch, Length)
    def _func_sum(x):
        return func(x).sum(dim=0)

    if mode == 'scalar':
        # jacobian has shape (out_dim, batch, in_dim); [0] keeps the single output
        return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0]
    elif mode == 'vector':
        return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1, 0, 2)
```
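A minimal sketch (not part of the package) of the two modes; `'scalar'` assumes a single output column, `'vector'` keeps the output axis:

```python
import torch

x = torch.normal(0, 1, size=(100, 2))

scalar_model = lambda x: x[:, [0]] + x[:, [1]]                # output shape (100, 1)
print(batch_jacobian(scalar_model, x).shape)                  # torch.Size([100, 2])

vector_model = lambda x: torch.stack(
    [x[:, 0] + x[:, 1], x[:, 0] * x[:, 1]], dim=1)            # output shape (100, 2)
print(batch_jacobian(vector_model, x, mode='vector').shape)   # torch.Size([100, 2, 2])
```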
```python
def batch_hessian(model, x, create_graph=False):
    '''
    hessian

    Args:
    -----
        model : function or model
        x : inputs
        create_graph : bool

    Returns:
    --------
        hessian

    Example
    -------
    >>> from kan.utils import batch_hessian
    >>> x = torch.normal(0,1,size=(100,2))
    >>> model = lambda x: x[:,[0]]**2 + x[:,[1]]**2
    >>> batch_hessian(model, x)
    '''
    # x in shape (Batch, Length)
    jac = lambda x: batch_jacobian(model, x, create_graph=True)

    def _jac_sum(x):
        return jac(x).sum(dim=0)

    return torch.autograd.functional.jacobian(_jac_sum, x, create_graph=create_graph).permute(1, 0, 2)


def create_from_data(inputs, labels, train_ratio=0.8, device='cpu'):
    '''
    create a dataset from data via stratified sampling: the train/test split
    preserves the class proportions of the integer labels
    '''
    from collections import defaultdict
    class_indices = defaultdict(list)
    for idx, label in enumerate(labels):
        class_indices[label.item()].append(idx)
    # initialize the train/test index lists
    train_id = []
    test_id = []

    # stratified sampling
    for class_label, indices in class_indices.items():
        num_samples = len(indices)
        if num_samples == 0:
            continue

        # number of training samples for this class
        train_size = int(num_samples * train_ratio)
        if train_size == 0:
            train_size = 1  # make sure there is at least one training sample

        # randomly pick the training samples
        np.random.shuffle(indices)
        train_subset = indices[:train_size]
        test_subset = indices[train_size:]

        train_id.extend(train_subset)
        test_id.extend(test_subset)

    # convert to numpy arrays and shuffle
    train_id = np.array(train_id)
    test_id = np.array(test_id)
    np.random.shuffle(train_id)  # shuffle the training indices
    np.random.shuffle(test_id)  # shuffle the test indices

    # build the dataset
    dataset = {
        'train_input': inputs[train_id].detach().to(device),
        'test_input': inputs[test_id].detach().to(device),
        'train_label': labels[train_id].detach().to(device),
        'test_label': labels[test_id].detach().to(device)
    }

    return dataset
```
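Unlike `create_dataset_from_data` below, which splits uniformly at random, `create_from_data` splits per class. A minimal sketch (not part of the package) with an imbalanced binary labelling:

```python
import torch

inputs = torch.normal(0, 1, size=(100, 4))
labels = torch.cat([torch.zeros(90, dtype=torch.long), torch.ones(10, dtype=torch.long)])

dataset = create_from_data(inputs, labels, train_ratio=0.8)
print(dataset['train_input'].shape)          # torch.Size([80, 4])
# each class keeps the 80/20 split: 72 + 8 training samples from the two classes
print((dataset['train_label'] == 1).sum())   # tensor(8)
```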
```python
def create_dataset_from_data(inputs, labels, train_ratio=0.8, device='cpu'):
    """
    create dataset from data

    Args:
    -----
        inputs : 2D torch.float
        labels : 2D torch.float
        train_ratio : float
            the ratio of the training fraction
        device : str

    Returns:
    --------
        dataset (dictionary)

    Example
    -------
    >>> from kan.utils import create_dataset_from_data
    >>> x = torch.normal(0,1,size=(100,2))
    >>> y = torch.normal(0,1,size=(100,1))
    >>> dataset = create_dataset_from_data(x, y)
    >>> dataset['train_input'].shape
    """
    num = inputs.shape[0]
    train_id = np.random.choice(num, int(num * train_ratio), replace=False)
    test_id = list(set(np.arange(num)) - set(train_id))
    dataset = {'train_input': inputs[train_id].detach().to(device), 'test_input': inputs[test_id].detach().to(device),
               'train_label': labels[train_id].detach().to(device), 'test_label': labels[test_id].detach().to(device)}

    return dataset


def get_derivative(model, inputs, labels, derivative='hessian', loss_mode='pred', reg_metric='w', lamb=0., lamb_l1=1.,
                   lamb_entropy=0.):
    '''
    compute the jacobian/hessian of the loss w.r.t. the model parameters

    Args:
    -----
        model : the model
        inputs : 2D torch.float
        labels : 2D torch.float
        derivative : str
            'jacobian' or 'hessian'
        loss_mode : str
            'pred' (prediction loss), 'reg' (regularization loss) or 'all' (pred + lamb * reg)

    Returns:
    --------
        jacobian or hessian
    '''

    def get_mapping(model):
        # map state_dict keys (e.g. 'layer.0.weight') to attribute paths on a
        # model instance (e.g. 'model1.layer[0].weight')
        mapping = {}
        name = 'model1'

        keys = list(model.state_dict().keys())
        for key in keys:
            y = re.findall(".[0-9]+", key)
            if len(y) > 0:
                y = y[0][1:]
                x = re.split(".[0-9]+", key)
                mapping[key] = name + '.' + x[0] + '[' + y + ']' + x[1]

            y = re.findall("_[0-9]+", key)
            if len(y) > 0:
                y = y[0][1:]
                x = re.split(".[0-9]+", key)
                mapping[key] = name + '.' + x[0] + '[' + y + ']'

        return mapping

    # model1 = copy.deepcopy(model)
    model1 = model.copy()
    mapping = get_mapping(model)

    # collect keys and shapes
    keys = list(model.state_dict().keys())
    shapes = []

    for params in model.parameters():
        shapes.append(params.shape)

    # turn a flattened vector into model params
    def param2statedict(p, keys, shapes):
        new_state_dict = {}

        start = 0
        n_group = len(keys)
        for i in range(n_group):
            shape = shapes[i]
            n_params = torch.prod(torch.tensor(shape))
            new_state_dict[keys[i]] = p[start:start + n_params].reshape(shape)
            start += n_params

        return new_state_dict

    def differentiable_load_state_dict(mapping, state_dict, model1):
        for key in keys:
            if mapping[key][-1] != ']':
                exec(f"del {mapping[key]}")
            exec(f"{mapping[key]} = state_dict[key]")

    # input: p, output: loss
    def get_param2loss_fun(inputs, labels):

        def param2loss_fun(p):
            p = p[0]
            state_dict = param2statedict(p, keys, shapes)
            # model.load_state_dict(state_dict) is non-differentiable, so the
            # parameters are re-attached via differentiable_load_state_dict
            differentiable_load_state_dict(mapping, state_dict, model1)
            if loss_mode == 'pred':
                pred_loss = torch.mean((model1(inputs) - labels) ** 2, dim=(0, 1), keepdim=True)
                loss = pred_loss
            elif loss_mode == 'reg':
                reg_loss = model1.get_reg(reg_metric=reg_metric, lamb_l1=lamb_l1,
                                          lamb_entropy=lamb_entropy) * torch.ones(1, 1)
                loss = reg_loss
            elif loss_mode == 'all':
                pred_loss = torch.mean((model1(inputs) - labels) ** 2, dim=(0, 1), keepdim=True)
                reg_loss = model1.get_reg(reg_metric=reg_metric, lamb_l1=lamb_l1,
                                          lamb_entropy=lamb_entropy) * torch.ones(1, 1)
                loss = pred_loss + lamb * reg_loss
            return loss

        return param2loss_fun

    fun = get_param2loss_fun(inputs, labels)
    p = model2param(model)[None, :]
    if derivative == 'hessian':
        result = batch_hessian(fun, p)
    elif derivative == 'jacobian':
        result = batch_jacobian(fun, p)
    return result
```
```python
def model2param(model):
    '''
    turn model parameters into a flattened vector
    '''
    p = torch.tensor([]).to(model.device)
    for params in model.parameters():
        p = torch.cat([p, params.reshape(-1, )], dim=0)
    return p
```
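`model2param` reads `model.device`, an attribute that plain `nn.Module`s do not define (the models in this package are assumed to set it). A minimal sketch (not part of the package) with a toy module that sets it explicitly:

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 3)   # 2*3 weights + 3 biases = 9 parameters
        self.device = 'cpu'             # attribute expected by model2param

model = TinyModel()
p = model2param(model)
print(p.shape)   # torch.Size([9])
```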
yms_kan-0.0.7/setup.cfg (ADDED)