yms-kan 0.0.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yms_kan-0.0.7/LICENSE +21 -0
- yms_kan-0.0.7/PKG-INFO +18 -0
- yms_kan-0.0.7/README.md +1 -0
- yms_kan-0.0.7/kan/KANLayer.py +364 -0
- yms_kan-0.0.7/kan/LBFGS.py +492 -0
- yms_kan-0.0.7/kan/MLP.py +361 -0
- yms_kan-0.0.7/kan/MultKAN.py +3087 -0
- yms_kan-0.0.7/kan/Symbolic_KANLayer.py +270 -0
- yms_kan-0.0.7/kan/__init__.py +3 -0
- yms_kan-0.0.7/kan/compiler.py +498 -0
- yms_kan-0.0.7/kan/dataset.py +27 -0
- yms_kan-0.0.7/kan/experiment.py +50 -0
- yms_kan-0.0.7/kan/feynman.py +739 -0
- yms_kan-0.0.7/kan/hypothesis.py +695 -0
- yms_kan-0.0.7/kan/spline.py +144 -0
- yms_kan-0.0.7/kan/utils.py +661 -0
- yms_kan-0.0.7/setup.cfg +4 -0
- yms_kan-0.0.7/setup.py +96 -0
- yms_kan-0.0.7/yms_kan.egg-info/PKG-INFO +18 -0
- yms_kan-0.0.7/yms_kan.egg-info/SOURCES.txt +20 -0
- yms_kan-0.0.7/yms_kan.egg-info/dependency_links.txt +1 -0
- yms_kan-0.0.7/yms_kan.egg-info/top_level.txt +1 -0
yms_kan-0.0.7/kan/utils.py
ADDED
@@ -0,0 +1,661 @@
import numpy as np
import torch
from sklearn.linear_model import LinearRegression
import sympy
import yaml
from sympy.utilities.lambdify import lambdify
import re

# sigmoid = sympy.Function('sigmoid')
# name: (torch implementation, sympy implementation, complexity, singularity-protected implementation)

# singularity protection functions: each maps (x, y_th) to a tuple
# (threshold bookkeeping, output clamped to about |y_th| near the singularity)
f_inv = lambda x, y_th: (
    (x_th := 1 / y_th), y_th / x_th * x * (torch.abs(x) < x_th) + torch.nan_to_num(1 / x) * (torch.abs(x) >= x_th))
f_inv2 = lambda x, y_th: (
    (x_th := 1 / y_th ** (1 / 2)), y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1 / x ** 2) * (torch.abs(x) >= x_th))
f_inv3 = lambda x, y_th: (
    (x_th := 1 / y_th ** (1 / 3)),
    y_th / x_th * x * (torch.abs(x) < x_th) + torch.nan_to_num(1 / x ** 3) * (torch.abs(x) >= x_th))
f_inv4 = lambda x, y_th: (
    (x_th := 1 / y_th ** (1 / 4)), y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1 / x ** 4) * (torch.abs(x) >= x_th))
f_inv5 = lambda x, y_th: (
    (x_th := 1 / y_th ** (1 / 5)),
    y_th / x_th * x * (torch.abs(x) < x_th) + torch.nan_to_num(1 / x ** 5) * (torch.abs(x) >= x_th))
f_sqrt = lambda x, y_th: (
    (x_th := 1 / y_th ** 2),
    x_th / y_th * x * (torch.abs(x) < x_th) + torch.nan_to_num(torch.sqrt(torch.abs(x)) * torch.sign(x)) * (torch.abs(x) >= x_th))
f_power1d5 = lambda x, y_th: torch.abs(x) ** 1.5
f_invsqrt = lambda x, y_th: (
    (x_th := 1 / y_th ** 2),
    y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1 / torch.sqrt(torch.abs(x))) * (torch.abs(x) >= x_th))
f_log = lambda x, y_th: (
    (x_th := torch.e ** (-y_th)),
    - y_th * (torch.abs(x) < x_th) + torch.nan_to_num(torch.log(torch.abs(x))) * (torch.abs(x) >= x_th))
f_tan = lambda x, y_th: (
    (clip := x % torch.pi), (delta := torch.pi / 2 - torch.arctan(y_th)),
    - y_th / delta * (clip - torch.pi / 2) * (torch.abs(clip - torch.pi / 2) < delta)
    + torch.nan_to_num(torch.tan(clip)) * (torch.abs(clip - torch.pi / 2) >= delta))
f_arctanh = lambda x, y_th: (
    (delta := 1 - torch.tanh(y_th) + 1e-4),
    y_th * torch.sign(x) * (torch.abs(x) > 1 - delta) + torch.nan_to_num(torch.arctanh(x)) * (torch.abs(x) <= 1 - delta))
f_arcsin = lambda x, y_th: (
    (), torch.pi / 2 * torch.sign(x) * (torch.abs(x) > 1) + torch.nan_to_num(torch.arcsin(x)) * (torch.abs(x) <= 1))
f_arccos = lambda x, y_th: (
    (), torch.pi / 2 * (1 - torch.sign(x)) * (torch.abs(x) > 1) + torch.nan_to_num(torch.arccos(x)) * (torch.abs(x) <= 1))
f_exp = lambda x, y_th: ((x_th := torch.log(y_th)), y_th * (x > x_th) + torch.exp(x) * (x <= x_th))

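Each protected function above returns a pair (threshold bookkeeping, clamped output). A minimal sanity check, a sketch assuming only torch and the definitions above (values traced by hand):

    x = torch.tensor([0.0, 0.01, 1.0])
    x_th, y = f_inv(x, 100.0)
    # x_th == 1/100: inside |x| < x_th the output is linear and bounded by +/-100,
    # outside it is exactly 1/x
    print(x_th, y)  # 0.01 tensor([  0., 100.,   1.])
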
SYMBOLIC_LIB = {'x': (lambda x: x, lambda x: x, 1, lambda x, y_th: ((), x)),
                'x^2': (lambda x: x ** 2, lambda x: x ** 2, 2, lambda x, y_th: ((), x ** 2)),
                'x^3': (lambda x: x ** 3, lambda x: x ** 3, 3, lambda x, y_th: ((), x ** 3)),
                'x^4': (lambda x: x ** 4, lambda x: x ** 4, 3, lambda x, y_th: ((), x ** 4)),
                'x^5': (lambda x: x ** 5, lambda x: x ** 5, 3, lambda x, y_th: ((), x ** 5)),
                '1/x': (lambda x: 1 / x, lambda x: 1 / x, 2, f_inv),
                '1/x^2': (lambda x: 1 / x ** 2, lambda x: 1 / x ** 2, 2, f_inv2),
                '1/x^3': (lambda x: 1 / x ** 3, lambda x: 1 / x ** 3, 3, f_inv3),
                '1/x^4': (lambda x: 1 / x ** 4, lambda x: 1 / x ** 4, 4, f_inv4),
                '1/x^5': (lambda x: 1 / x ** 5, lambda x: 1 / x ** 5, 5, f_inv5),
                'sqrt': (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x), 2, f_sqrt),
                'x^0.5': (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x), 2, f_sqrt),
                'x^1.5': (lambda x: torch.sqrt(x) ** 3, lambda x: sympy.sqrt(x) ** 3, 4, f_power1d5),
                '1/sqrt(x)': (lambda x: 1 / torch.sqrt(x), lambda x: 1 / sympy.sqrt(x), 2, f_invsqrt),
                '1/x^0.5': (lambda x: 1 / torch.sqrt(x), lambda x: 1 / sympy.sqrt(x), 2, f_invsqrt),
                'exp': (lambda x: torch.exp(x), lambda x: sympy.exp(x), 2, f_exp),
                'log': (lambda x: torch.log(x), lambda x: sympy.log(x), 2, f_log),
                'abs': (lambda x: torch.abs(x), lambda x: sympy.Abs(x), 3, lambda x, y_th: ((), torch.abs(x))),
                'sin': (lambda x: torch.sin(x), lambda x: sympy.sin(x), 2, lambda x, y_th: ((), torch.sin(x))),
                'cos': (lambda x: torch.cos(x), lambda x: sympy.cos(x), 2, lambda x, y_th: ((), torch.cos(x))),
                'tan': (lambda x: torch.tan(x), lambda x: sympy.tan(x), 3, f_tan),
                'tanh': (lambda x: torch.tanh(x), lambda x: sympy.tanh(x), 3, lambda x, y_th: ((), torch.tanh(x))),
                'sgn': (lambda x: torch.sign(x), lambda x: sympy.sign(x), 3, lambda x, y_th: ((), torch.sign(x))),
                'arcsin': (lambda x: torch.arcsin(x), lambda x: sympy.asin(x), 4, f_arcsin),
                'arccos': (lambda x: torch.arccos(x), lambda x: sympy.acos(x), 4, f_arccos),
                'arctan': (lambda x: torch.arctan(x), lambda x: sympy.atan(x), 4, lambda x, y_th: ((), torch.arctan(x))),
                'arctanh': (lambda x: torch.arctanh(x), lambda x: sympy.atanh(x), 4, f_arctanh),
                '0': (lambda x: x * 0, lambda x: x * 0, 0, lambda x, y_th: ((), x * 0)),
                'gaussian': (lambda x: torch.exp(-x ** 2), lambda x: sympy.exp(-x ** 2), 3,
                             lambda x, y_th: ((), torch.exp(-x ** 2))),
                # 'cosh': (lambda x: torch.cosh(x), lambda x: sympy.cosh(x), 5),
                # 'sigmoid': (lambda x: torch.sigmoid(x), sympy.Function('sigmoid'), 4),
                # 'relu': (lambda x: torch.relu(x), relu),
                }

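Each library entry is a 4-tuple; a short lookup sketch using only the names defined above:

    torch_fn, sympy_fn, complexity, protected_fn = SYMBOLIC_LIB['sin']
    torch_fn(torch.tensor([0.0]))              # tensor([0.])
    sympy_fn(sympy.Symbol('z'))                # sin(z)
    complexity                                 # 2
    protected_fn(torch.tensor([0.0]), 100.0)   # ((), tensor([0.])) -- sin needs no protection
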
def create_dataset(f,
                   n_var=2,
                   f_mode='col',
                   ranges=[-1, 1],
                   train_num=1000,
                   test_num=1000,
                   normalize_input=False,
                   normalize_label=False,
                   device='cpu',
                   seed=0):
    '''
    create dataset

    Args:
    -----
        f : function
            the symbolic formula used to create the synthetic dataset
        n_var : int
            the number of input variables. Default: 2.
        f_mode : str
            'col' if f expects inputs of shape (batch, n_var); 'row' if f expects (n_var, batch). Default: 'col'.
        ranges : list or np.array; shape (2,) or (n_var, 2)
            the range of input variables. Default: [-1, 1].
        train_num : int
            the number of training samples. Default: 1000.
        test_num : int
            the number of test samples. Default: 1000.
        normalize_input : bool
            If True, apply normalization to inputs. Default: False.
        normalize_label : bool
            If True, apply normalization to labels. Default: False.
        device : str
            device. Default: 'cpu'.
        seed : int
            random seed. Default: 0.

    Returns:
    --------
        dataset : dict
            Train/test inputs/labels are dataset['train_input'], dataset['train_label'],
            dataset['test_input'], dataset['test_label']

    Example
    -------
    >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
    >>> dataset = create_dataset(f, n_var=2, train_num=100)
    >>> dataset['train_input'].shape
    torch.Size([100, 2])
    '''
    np.random.seed(seed)
    torch.manual_seed(seed)

    # a single (lo, hi) pair is broadcast to all n_var input variables
    if len(np.array(ranges).shape) == 1:
        ranges = np.array(ranges * n_var).reshape(n_var, 2)
    else:
        ranges = np.array(ranges)

    train_input = torch.zeros(train_num, n_var)
    test_input = torch.zeros(test_num, n_var)
    for i in range(n_var):
        train_input[:, i] = torch.rand(train_num, ) * (ranges[i, 1] - ranges[i, 0]) + ranges[i, 0]
        test_input[:, i] = torch.rand(test_num, ) * (ranges[i, 1] - ranges[i, 0]) + ranges[i, 0]

    if f_mode == 'col':
        train_label = f(train_input)
        test_label = f(test_input)
    elif f_mode == 'row':
        train_label = f(train_input.T)
        test_label = f(test_input.T)
    else:
        raise ValueError(f'f_mode {f_mode} not recognized')

    # if the label has only 1 dimension, make it (num, 1)
    if len(train_label.shape) == 1:
        train_label = train_label.unsqueeze(dim=1)
        test_label = test_label.unsqueeze(dim=1)

    def normalize(data, mean, std):
        return (data - mean) / std

    if normalize_input:
        mean_input = torch.mean(train_input, dim=0, keepdim=True)
        std_input = torch.std(train_input, dim=0, keepdim=True)
        train_input = normalize(train_input, mean_input, std_input)
        test_input = normalize(test_input, mean_input, std_input)

    if normalize_label:
        mean_label = torch.mean(train_label, dim=0, keepdim=True)
        std_label = torch.std(train_label, dim=0, keepdim=True)
        train_label = normalize(train_label, mean_label, std_label)
        test_label = normalize(test_label, mean_label, std_label)

    dataset = {}
    dataset['train_input'] = train_input.to(device)
    dataset['test_input'] = test_input.to(device)
    dataset['train_label'] = train_label.to(device)
    dataset['test_label'] = test_label.to(device)

    return dataset

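Per-variable ranges are also accepted; a sketch with an (n_var, 2) ranges array (shapes are deterministic, values random):

    f = lambda x: x[:, [0]] * x[:, [1]]
    dataset = create_dataset(f, n_var=2, ranges=[[0, 1], [-5, 5]], train_num=100)
    # column 0 is sampled from [0, 1], column 1 from [-5, 5]
    dataset['train_label'].shape   # torch.Size([100, 1])
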
def fit_params(x, y, fun, a_range=(-10, 10), b_range=(-10, 10), grid_number=101, iteration=3, verbose=True,
               device='cpu'):
    '''
    fit a, b, c, d such that

    .. math::
        |y - (c f(ax + b) + d)|^2

    is minimized. Both x and y are 1D arrays. Sweep a and b, and find the best fitted model.

    Args:
    -----
        x : 1D array
            x values
        y : 1D array
            y values
        fun : function
            symbolic function
        a_range : tuple
            sweeping range of a
        b_range : tuple
            sweeping range of b
        grid_number : int
            number of steps along a and b
        iteration : int
            number of zooming in
        verbose : bool
            print extra information if True
        device : str
            device

    Returns:
    --------
        a_best : float
            best fitted a
        b_best : float
            best fitted b
        c_best : float
            best fitted c
        d_best : float
            best fitted d
        r2_best : float
            best r2 (coefficient of determination)

    Example
    -------
    >>> num = 100
    >>> x = torch.linspace(-1,1,steps=num)
    >>> noises = torch.normal(0,1,(num,)) * 0.02
    >>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises
    >>> fit_params(x, y, torch.sin)
    r2 is 0.9999727010726929
    (tensor([2.9982, 1.9996, 5.0053, 0.7011]), tensor(1.0000))
    '''
    # fit a, b, c, d such that y = c * fun(a * x + b) + d; both x and y are 1D arrays.
    # sweep a and b on a grid, then zoom into the best cell for `iteration` rounds.
    # r2 below is the squared correlation between fun(a*x+b) and y, which is
    # invariant to c and d, so only a and b need to be swept.
    for _ in range(iteration):
        a_ = torch.linspace(a_range[0], a_range[1], steps=grid_number, device=device)
        b_ = torch.linspace(b_range[0], b_range[1], steps=grid_number, device=device)
        a_grid, b_grid = torch.meshgrid(a_, b_, indexing='ij')
        post_fun = fun(a_grid[None, :, :] * x[:, None, None] + b_grid[None, :, :])
        x_mean = torch.mean(post_fun, dim=[0], keepdim=True)
        y_mean = torch.mean(y, dim=[0], keepdim=True)
        numerator = torch.sum((post_fun - x_mean) * (y - y_mean)[:, None, None], dim=0) ** 2
        denominator = torch.sum((post_fun - x_mean) ** 2, dim=0) * torch.sum((y - y_mean)[:, None, None] ** 2, dim=0)
        r2 = numerator / (denominator + 1e-4)
        r2 = torch.nan_to_num(r2)

        best_id = torch.argmax(r2)
        a_id, b_id = torch.div(best_id, grid_number, rounding_mode='floor'), best_id % grid_number

        if a_id == 0 or a_id == grid_number - 1 or b_id == 0 or b_id == grid_number - 1:
            if _ == 0 and verbose:
                print('Best value at boundary.')
            if a_id == 0:
                a_range = [a_[0], a_[1]]
            if a_id == grid_number - 1:
                a_range = [a_[-2], a_[-1]]
            if b_id == 0:
                b_range = [b_[0], b_[1]]
            if b_id == grid_number - 1:
                b_range = [b_[-2], b_[-1]]
        else:
            a_range = [a_[a_id - 1], a_[a_id + 1]]
            b_range = [b_[b_id - 1], b_[b_id + 1]]

    a_best = a_[a_id]
    b_best = b_[b_id]
    post_fun = fun(a_best * x + b_best)
    r2_best = r2[a_id, b_id]

    if verbose:
        print(f"r2 is {r2_best}")
        if r2_best < 0.9:
            print('r2 is not very high, please double check if you are choosing the correct symbolic function.')

    post_fun = torch.nan_to_num(post_fun)
    reg = LinearRegression().fit(post_fun[:, None].detach().cpu().numpy(), y.detach().cpu().numpy())
    c_best = torch.from_numpy(reg.coef_)[0].to(device)
    d_best = torch.from_numpy(np.array(reg.intercept_)).to(device)
    return torch.stack([a_best, b_best, c_best, d_best]), r2_best

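The sweep works for any scalar torch function; a hypothetical variation on the docstring example (noise-free, so r2 is close to 1):

    x = torch.linspace(-1, 1, steps=200)
    y = 2.0 * torch.tanh(4.0 * x - 1.0) + 0.5
    params, r2 = fit_params(x, y, torch.tanh, verbose=False)
    # because tanh(-u) = -tanh(u), the fit recovers (a, b, c, d) close to
    # (4.0, -1.0, 2.0, 0.5) up to a joint sign flip of (a, b, c)
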
def sparse_mask(in_dim, out_dim):
    '''
    get a sparse mask: an (in_dim, out_dim) 0/1 matrix connecting each input
    node to its nearest output node and vice versa (nodes are placed
    uniformly on [0, 1])
    '''
    in_coord = torch.arange(in_dim) * 1 / in_dim + 1 / (2 * in_dim)
    out_coord = torch.arange(out_dim) * 1 / out_dim + 1 / (2 * out_dim)

    dist_mat = torch.abs(out_coord[:, None] - in_coord[None, :])
    in_nearest = torch.argmin(dist_mat, dim=0)
    in_connection = torch.stack([torch.arange(in_dim), in_nearest]).permute(1, 0)
    out_nearest = torch.argmin(dist_mat, dim=1)
    out_connection = torch.stack([out_nearest, torch.arange(out_dim)]).permute(1, 0)
    all_connection = torch.cat([in_connection, out_connection], dim=0)
    mask = torch.zeros(in_dim, out_dim)
    mask[all_connection[:, 0], all_connection[:, 1]] = 1.

    return mask

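A quick connectivity check (the exact pattern depends on the dimensions):

    mask = sparse_mask(4, 2)
    mask.shape   # torch.Size([4, 2])
    # every row (input) and every column (output) contains at least one 1,
    # so no node is left disconnected
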
def add_symbolic(name, fun, c=1, fun_singularity=None):
    '''
    add a symbolic function to the library

    Args:
    -----
        name : str
            name of the function
        fun : fun
            torch function or lambda function
        c : int
            complexity of the function. Default: 1.
        fun_singularity : fun
            singularity-protected version of fun. Default: fun itself.

    Returns:
    --------
        None

    Example
    -------
    >>> print(SYMBOLIC_LIB['Bessel'])
    KeyError: 'Bessel'
    >>> add_symbolic('Bessel', torch.special.bessel_j0)
    >>> print(SYMBOLIC_LIB['Bessel'])
    (<built-in function special_bessel_j0>, Bessel, 1, <built-in function special_bessel_j0>)
    '''
    # register a sympy Function with the same name in the module globals
    globals()[name] = sympy.Function(name)
    if fun_singularity is None:
        fun_singularity = fun
    SYMBOLIC_LIB[name] = (fun, globals()[name], c, fun_singularity)

def ex_round(ex1, n_digit):
    '''
    round the floating-point numbers in an expression to n_digit decimal places

    Args:
    -----
        ex1 : sympy expression
        n_digit : int

    Returns:
    --------
        ex2 : sympy expression

    Example
    -------
    >>> from kan.utils import *
    >>> from sympy import *
    >>> input_vars = a, b = symbols('a b')
    >>> expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402
    >>> ex_round(expression, 2)
    '''
    ex2 = ex1
    for a in sympy.preorder_traversal(ex1):
        if isinstance(a, sympy.Float):
            ex2 = ex2.subs(a, round(a, n_digit))
    return ex2

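For the docstring expression above, the rounded result prints as 3.15*exp(b**2 + sin(pi*a)) - 2.32 or similar; the exact term ordering depends on the sympy printer.
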
def augment_input(orig_vars, aux_vars, x):
    '''
    augment inputs with auxiliary features; the auxiliary columns are
    prepended before the original columns

    Args:
    -----
        orig_vars : list of sympy symbols
        aux_vars : list of auxiliary symbols (expressions in orig_vars)
        x : inputs (2D torch tensor or a dataset dictionary)

    Returns:
    --------
        augmented inputs

    Example
    -------
    >>> from kan.utils import *
    >>> from sympy import *
    >>> orig_vars = a, b = symbols('a b')
    >>> aux_vars = [a + b, a * b]
    >>> x = torch.rand(100, 2)
    >>> augment_input(orig_vars, aux_vars, x).shape
    torch.Size([100, 4])
    '''
    # if x is a tensor
    if isinstance(x, torch.Tensor):
        aux_values = torch.tensor([]).to(x.device)

        for aux_var in aux_vars:
            func = lambdify(orig_vars, aux_var, 'numpy')  # returns a numpy-ready function
            aux_value = torch.from_numpy(func(*[x[:, [i]].numpy() for i in range(len(orig_vars))]))
            aux_values = torch.cat([aux_values, aux_value], dim=1)

        x = torch.cat([aux_values, x], dim=1)

    # if x is a dataset dictionary, augment train and test inputs in place
    elif isinstance(x, dict):
        x['train_input'] = augment_input(orig_vars, aux_vars, x['train_input'])
        x['test_input'] = augment_input(orig_vars, aux_vars, x['test_input'])

    return x

def batch_jacobian(func, x, create_graph=False, mode='scalar'):
    '''
    jacobian

    Args:
    -----
        func : function or model
        x : inputs
        create_graph : bool
        mode : str
            'scalar' if func returns a single output column; 'vector' if it returns several

    Returns:
    --------
        jacobian

    Example
    -------
    >>> from kan.utils import batch_jacobian
    >>> x = torch.normal(0,1,size=(100,2))
    >>> model = lambda x: x[:,[0]] + x[:,[1]]
    >>> batch_jacobian(model, x)
    '''
    # x in shape (Batch, Length). Summing over the batch lets one jacobian call
    # recover the per-sample jacobians, since each output row depends only on
    # the matching input row.
    def _func_sum(x):
        return func(x).sum(dim=0)

    if mode == 'scalar':
        return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0]
    elif mode == 'vector':
        return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1, 0, 2)

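For the docstring model, each row of the result is a per-sample gradient; a sketch of the expected values:

    x = torch.normal(0, 1, size=(100, 2))
    model = lambda x: x[:, [0]] + x[:, [1]]
    jac = batch_jacobian(model, x)
    jac.shape   # torch.Size([100, 2]); every row is [1., 1.] since d(x0 + x1)/dx = (1, 1)
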
def batch_hessian(model, x, create_graph=False):
    '''
    hessian

    Args:
    -----
        model : function or model
        x : inputs
        create_graph : bool

    Returns:
    --------
        hessian

    Example
    -------
    >>> from kan.utils import batch_hessian
    >>> x = torch.normal(0,1,size=(100,2))
    >>> model = lambda x: x[:,[0]]**2 + x[:,[1]]**2
    >>> batch_hessian(model, x)
    '''
    # x in shape (Batch, Length); differentiate the batch jacobian once more
    jac = lambda x: batch_jacobian(model, x, create_graph=True)

    def _jac_sum(x):
        return jac(x).sum(dim=0)

    return torch.autograd.functional.jacobian(_jac_sum, x, create_graph=create_graph).permute(1, 0, 2)

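Likewise for the docstring quadratic, whose per-sample hessian is constant:

    x = torch.normal(0, 1, size=(100, 2))
    model = lambda x: x[:, [0]]**2 + x[:, [1]]**2
    hess = batch_hessian(model, x)
    hess.shape   # torch.Size([100, 2, 2]); every slice is [[2., 0.], [0., 2.]]
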
def create_from_data(inputs, labels, train_ratio=0.8, device='cpu'):
    '''
    create a dataset dictionary from data using a stratified (per-class) train/test split
    '''
    from collections import defaultdict
    class_indices = defaultdict(list)
    for idx, label in enumerate(labels):
        class_indices[label.item()].append(idx)

    # initialize train/test index lists
    train_id = []
    test_id = []

    # stratified sampling
    for class_label, indices in class_indices.items():
        num_samples = len(indices)
        if num_samples == 0:
            continue

        # number of training samples for this class
        train_size = int(num_samples * train_ratio)
        if train_size == 0:
            train_size = 1  # make sure every class has at least one training sample

        # randomly pick the training samples
        np.random.shuffle(indices)
        train_subset = indices[:train_size]
        test_subset = indices[train_size:]

        train_id.extend(train_subset)
        test_id.extend(test_subset)

    # convert to numpy arrays and shuffle
    train_id = np.array(train_id)
    test_id = np.array(test_id)
    np.random.shuffle(train_id)  # shuffle training indices
    np.random.shuffle(test_id)  # shuffle test indices

    # build the dataset dictionary
    dataset = {
        'train_input': inputs[train_id].detach().to(device),
        'test_input': inputs[test_id].detach().to(device),
        'train_label': labels[train_id].detach().to(device),
        'test_label': labels[test_id].detach().to(device)
    }

    return dataset

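A usage sketch for the stratified split (integer class labels, hypothetical data):

    inputs = torch.randn(10, 3)
    labels = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    dataset = create_from_data(inputs, labels, train_ratio=0.8)
    dataset['train_input'].shape   # torch.Size([8, 3]) -- 4 samples from each class
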
def create_dataset_from_data(inputs, labels, train_ratio=0.8, device='cpu'):
    """
    create dataset from data

    Args:
    -----
        inputs : 2D torch.float
        labels : 2D torch.float
        train_ratio : float
            the fraction of samples used for training
        device : str

    Returns:
    --------
        dataset (dictionary)

    Example
    -------
    >>> from kan.utils import create_dataset_from_data
    >>> x = torch.normal(0,1,size=(100,2))
    >>> y = torch.normal(0,1,size=(100,1))
    >>> dataset = create_dataset_from_data(x, y)
    >>> dataset['train_input'].shape
    torch.Size([80, 2])
    """
    num = inputs.shape[0]
    train_id = np.random.choice(num, int(num * train_ratio), replace=False)
    test_id = list(set(np.arange(num)) - set(train_id))
    dataset = {'train_input': inputs[train_id].detach().to(device), 'test_input': inputs[test_id].detach().to(device),
               'train_label': labels[train_id].detach().to(device), 'test_label': labels[test_id].detach().to(device)}

    return dataset

def get_derivative(model, inputs, labels, derivative='hessian', loss_mode='pred', reg_metric='w', lamb=0., lamb_l1=1.,
                   lamb_entropy=0.):
    '''
    compute the jacobian/hessian of the loss w.r.t. the model parameters

    Args:
    -----
        model : MultKAN
        inputs : 2D torch.float
        labels : 2D torch.float
        derivative : str
            'jacobian' or 'hessian'
        loss_mode : str
            'pred' (prediction loss), 'reg' (regularization loss) or 'all' (pred + lamb * reg)
        reg_metric, lamb, lamb_l1, lamb_entropy :
            regularization settings, passed to model.get_reg

    Returns:
    --------
        jacobian or hessian
    '''

    def get_mapping(model):
        # map each state_dict key (e.g. 'act_fun.0.coef') to an attribute
        # expression on model1 (e.g. 'model1.act_fun[0].coef')
        mapping = {}
        name = 'model1'

        keys = list(model.state_dict().keys())
        for key in keys:
            y = re.findall(".[0-9]+", key)
            if len(y) > 0:
                y = y[0][1:]
                x = re.split(".[0-9]+", key)
                mapping[key] = name + '.' + x[0] + '[' + y + ']' + x[1]

            y = re.findall("_[0-9]+", key)
            if len(y) > 0:
                y = y[0][1:]
                x = re.split(".[0-9]+", key)
                mapping[key] = name + '.' + x[0] + '[' + y + ']'

        return mapping

    # model1 = copy.deepcopy(model)
    model1 = model.copy()
    mapping = get_mapping(model)

    # collect keys and shapes
    keys = list(model.state_dict().keys())
    shapes = []

    for params in model.parameters():
        shapes.append(params.shape)

    # turn a flattened vector into model params
    def param2statedict(p, keys, shapes):
        new_state_dict = {}

        start = 0
        n_group = len(keys)
        for i in range(n_group):
            shape = shapes[i]
            n_params = torch.prod(torch.tensor(shape))
            new_state_dict[keys[i]] = p[start:start + n_params].reshape(shape)
            start += n_params

        return new_state_dict

    def differentiable_load_state_dict(mapping, state_dict, model1):
        # delete the registered parameter first so a plain tensor can be
        # assigned in its place and the assignment stays on the autograd graph
        for key in keys:
            if mapping[key][-1] != ']':
                exec(f"del {mapping[key]}")
            exec(f"{mapping[key]} = state_dict[key]")

    # input: p (flattened parameters), output: loss
    def get_param2loss_fun(inputs, labels):

        def param2loss_fun(p):
            p = p[0]
            state_dict = param2statedict(p, keys, shapes)
            # model.load_state_dict(state_dict) would be non-differentiable
            differentiable_load_state_dict(mapping, state_dict, model1)
            if loss_mode == 'pred':
                pred_loss = torch.mean((model1(inputs) - labels) ** 2, dim=(0, 1), keepdim=True)
                loss = pred_loss
            elif loss_mode == 'reg':
                reg_loss = model1.get_reg(reg_metric=reg_metric, lamb_l1=lamb_l1,
                                          lamb_entropy=lamb_entropy) * torch.ones(1, 1)
                loss = reg_loss
            elif loss_mode == 'all':
                pred_loss = torch.mean((model1(inputs) - labels) ** 2, dim=(0, 1), keepdim=True)
                reg_loss = model1.get_reg(reg_metric=reg_metric, lamb_l1=lamb_l1,
                                          lamb_entropy=lamb_entropy) * torch.ones(1, 1)
                loss = pred_loss + lamb * reg_loss
            return loss

        return param2loss_fun

    fun = get_param2loss_fun(inputs, labels)
    p = model2param(model)[None, :]
    if derivative == 'hessian':
        result = batch_hessian(fun, p)
    elif derivative == 'jacobian':
        result = batch_jacobian(fun, p)
    return result

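A hypothetical call, assuming a MultKAN-style model that provides .copy(), .get_reg() and .device as used above:

    # hessian of the MSE loss at the current parameters; shape (1, n_params, n_params)
    hess = get_derivative(model, dataset['train_input'], dataset['train_label'],
                          derivative='hessian', loss_mode='pred')
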
def model2param(model):
    '''
    turn model parameters into a flattened vector
    '''
    p = torch.tensor([]).to(model.device)
    for params in model.parameters():
        p = torch.cat([p, params.reshape(-1, )], dim=0)
    return p
    
        yms_kan-0.0.7/setup.cfg
    ADDED