yms-kan 0.0.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,144 @@
1
+ import torch
2
+
3
+
4
def B_batch(x, grid, k=0, extend=True, device='cpu'):
    '''
    evaluate x on B-spline bases

    Args:
    -----
        x : 2D torch.tensor
            inputs, shape (batch, in_dim)
        grid : 2D torch.tensor
            knots, shape (in_dim, number of grid points). The grid is used
            exactly as given; if boundary knots are needed, extend the grid
            beforehand (e.g. with ``extend_grid``).
        k : int
            the piecewise polynomial order of splines.
        extend : bool
            currently ignored (no extension is performed inside this
            function); kept for backward compatibility.
        device : str
            currently ignored; the result lives on the device of ``x`` / ``grid``.

    Returns:
    --------
        spline values : 3D torch.tensor
            shape (batch, in_dim, number of grid points - k - 1). For a grid of
            G+2k+1 points (G intervals extended by k knots per side) this is G+k.
            Order-0 bases use half-open intervals [grid_j, grid_{j+1}), so a
            point equal to the last knot gets zero in every basis. The dtype is
            the promoted dtype of ``x`` and ``grid`` (never bool).

    Example
    -------
    >>> from kan.spline import B_batch
    >>> x = torch.rand(100,2)
    >>> grid = torch.linspace(-1,1,steps=11)[None, :].expand(2, 11)
    >>> B_batch(x, grid, k=3).shape
    torch.Size([100, 2, 7])
    '''

    x = x.unsqueeze(dim=2)        # (batch, in_dim, 1)
    grid = grid.unsqueeze(dim=0)  # (1, in_dim, n_grid)

    if k == 0:
        # Indicator of the knot interval [grid_j, grid_{j+1}). Cast away from
        # bool so the order-0 result can be used in arithmetic / einsum
        # (e.g. coef2curve with k=0) just like higher orders.
        in_interval = (x >= grid[:, :, :-1]) & (x < grid[:, :, 1:])
        value = in_interval.to(torch.result_type(x, grid))
    else:
        # Cox-de Boor recursion: blend two neighbouring order-(k-1) bases.
        B_km1 = B_batch(x[:, :, 0], grid=grid[0], k=k - 1)

        value = (x - grid[:, :, :-(k + 1)]) / (grid[:, :, k:-1] - grid[:, :, :-(k + 1)]) * B_km1[:, :, :-1] + (
            grid[:, :, k + 1:] - x) / (grid[:, :, k + 1:] - grid[:, :, 1:(-k)]) * B_km1[:, :, 1:]

    # in case grid is degenerate (repeated knots make a denominator zero)
    value = torch.nan_to_num(value)
    return value
48
+
49
+
50
+
51
def coef2curve(x_eval, grid, coef, k, device="cpu"):
    '''
    Evaluate B-spline curves from their coefficients: compute the B-spline
    basis at x_eval (via B_batch) and take the coefficient-weighted sum over
    the basis dimension.

    Args:
    -----
        x_eval : 2D torch.tensor
            shape (batch, in_dim)
        grid : 2D torch.tensor
            shape (in_dim, G+2k+1). G: the number of grid intervals; k: spline order.
        coef : 3D torch.tensor
            shape (in_dim, out_dim, G+k); the last dimension must match the
            number of bases B_batch produces (grid points - k - 1).
        k : int
            the piecewise polynomial order of splines.
        device : str
            currently ignored; coef is moved to the device of the basis values.

    Returns:
    --------
        y_eval : 3D torch.tensor
            shape (batch, in_dim, out_dim)
    '''
    # basis values at every sample: (batch, in_dim, n_coef)
    basis = B_batch(x_eval, grid, k=k)
    # bring the coefficients onto the same device before contracting
    coef_on_device = coef.to(basis.device)
    # y[b, i, o] = sum_c basis[b, i, c] * coef[i, o, c]
    return torch.einsum('ijk,jlk->ijl', basis, coef_on_device)
79
+
80
+
81
def curve2coef(x_eval, y_eval, grid, k, lamb=1e-8):
    '''
    converting B-spline curves to B-spline coefficients using least squares.

    Args:
    -----
        x_eval : 2D torch.tensor
            shape (batch, in_dim)
        y_eval : 3D torch.tensor
            shape (batch, in_dim, out_dim)
        grid : 2D torch.tensor
            shape (in_dim, G+2*k+1)
        k : int
            spline order
        lamb : float
            ridge regularization strength used only by the fallback solver,
            which runs when torch.linalg.lstsq raises. Default: 1e-8

    Returns:
    --------
        coef : 3D torch.tensor
            shape (in_dim, out_dim, G+k)

    Raises:
    -------
        Any error raised by the fallback solver itself (e.g. shape / dtype
        mismatches) propagates to the caller.
    '''
    batch = x_eval.shape[0]
    in_dim = x_eval.shape[1]
    out_dim = y_eval.shape[2]
    n_coef = grid.shape[1] - k - 1

    # One design matrix per (input, output) pair: (in_dim, out_dim, batch, n_coef).
    mat = B_batch(x_eval, grid, k)
    mat = mat.permute(1, 0, 2)[:, None, :, :].expand(in_dim, out_dim, batch, n_coef)
    # Matching right-hand sides: (in_dim, out_dim, batch, 1).
    y_eval = y_eval.permute(1, 2, 0).unsqueeze(dim=3)

    try:
        coef = torch.linalg.lstsq(mat, y_eval).solution[:, :, :, 0]
    except RuntimeError:
        # torch.linalg.LinAlgError is a RuntimeError subclass. Previously a
        # failure here left `coef` unbound; instead solve the ridge-regularized
        # normal equations (X^T X + lamb I) c = X^T y with a pseudo-inverse.
        print('lstsq failed')
        mat_t = mat.transpose(-2, -1)
        XtX = mat_t @ mat
        Xty = mat_t @ y_eval
        identity = torch.eye(n_coef, dtype=XtX.dtype, device=XtX.device)
        coef = (torch.linalg.pinv(XtX + lamb * identity) @ Xty)[:, :, :, 0]

    return coef
132
+
133
+
134
def extend_grid(grid, k_extend=0):
    '''
    extend grid

    Pad every row of ``grid`` with ``k_extend`` extra points on each end.
    The spacing of the new points is the row's mean spacing,
    (last - first) / (number of points - 1).

    Args:
    -----
        grid : 2D torch.tensor
            shape (n_rows, n_points)
        k_extend : int
            number of points to add on each side. Default: 0

    Returns:
    --------
        grid : 2D torch.tensor
            shape (n_rows, n_points + 2*k_extend); the input object itself
            when no points are added.
    '''
    first_col = grid[:, [0]]
    last_col = grid[:, [-1]]
    step = (last_col - first_col) / (grid.shape[1] - 1)

    # Accumulate the new knots step by step (outward from each end).
    left_pad, right_pad = [], []
    lo, hi = first_col, last_col
    for _ in range(k_extend):
        lo = lo - step
        hi = hi + step
        left_pad.append(lo)
        right_pad.append(hi)

    if not left_pad:
        return grid
    # Outermost left knot comes first, so reverse the left padding.
    return torch.cat(left_pad[::-1] + [grid] + right_pad, dim=1)