gemseo-multi-fidelity 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gemseo_multi_fidelity/__init__.py +17 -0
- gemseo_multi_fidelity/core/MFMapperAdapter_input.json +22 -0
- gemseo_multi_fidelity/core/MFMapperAdapter_output.json +22 -0
- gemseo_multi_fidelity/core/MFMapperLinker_input.json +22 -0
- gemseo_multi_fidelity/core/MFMapperLinker_output.json +22 -0
- gemseo_multi_fidelity/core/MFScenarioAdapter_input.json +39 -0
- gemseo_multi_fidelity/core/MFScenarioAdapter_output.json +23 -0
- gemseo_multi_fidelity/core/__init__.py +16 -0
- gemseo_multi_fidelity/core/boxed_domain.py +242 -0
- gemseo_multi_fidelity/core/corr_function.py +411 -0
- gemseo_multi_fidelity/core/criticality.py +124 -0
- gemseo_multi_fidelity/core/ds_mapper.py +307 -0
- gemseo_multi_fidelity/core/errors.py +42 -0
- gemseo_multi_fidelity/core/eval_mapper.py +188 -0
- gemseo_multi_fidelity/core/id_mapper_adapter.py +61 -0
- gemseo_multi_fidelity/core/mapper_adapter.py +126 -0
- gemseo_multi_fidelity/core/mapper_linker.py +72 -0
- gemseo_multi_fidelity/core/mf_formulation.py +635 -0
- gemseo_multi_fidelity/core/mf_logger.py +216 -0
- gemseo_multi_fidelity/core/mf_opt_problem.py +480 -0
- gemseo_multi_fidelity/core/mf_scenario.py +205 -0
- gemseo_multi_fidelity/core/noise_criterion.py +94 -0
- gemseo_multi_fidelity/core/projpolytope.out +0 -0
- gemseo_multi_fidelity/core/scenario_adapter.py +568 -0
- gemseo_multi_fidelity/core/stop_criteria.py +201 -0
- gemseo_multi_fidelity/core/strict_chain.py +75 -0
- gemseo_multi_fidelity/core/utils_model_quality.py +74 -0
- gemseo_multi_fidelity/corrections/__init__.py +16 -0
- gemseo_multi_fidelity/corrections/add_corr_function.py +80 -0
- gemseo_multi_fidelity/corrections/correction_factory.py +65 -0
- gemseo_multi_fidelity/corrections/mul_corr_function.py +86 -0
- gemseo_multi_fidelity/drivers/__init__.py +16 -0
- gemseo_multi_fidelity/drivers/mf_algo_factory.py +38 -0
- gemseo_multi_fidelity/drivers/mf_driver_lib.py +462 -0
- gemseo_multi_fidelity/drivers/refinement.py +234 -0
- gemseo_multi_fidelity/drivers/settings/__init__.py +16 -0
- gemseo_multi_fidelity/drivers/settings/base_mf_driver_settings.py +59 -0
- gemseo_multi_fidelity/drivers/settings/mf_refine_settings.py +50 -0
- gemseo_multi_fidelity/formulations/__init__.py +16 -0
- gemseo_multi_fidelity/formulations/refinement.py +144 -0
- gemseo_multi_fidelity/mapping/__init__.py +16 -0
- gemseo_multi_fidelity/mapping/identity_mapper.py +74 -0
- gemseo_multi_fidelity/mapping/interp_mapper.py +422 -0
- gemseo_multi_fidelity/mapping/mapper_factory.py +70 -0
- gemseo_multi_fidelity/mapping/mapping_errors.py +46 -0
- gemseo_multi_fidelity/mapping/subset_mapper.py +122 -0
- gemseo_multi_fidelity/mf_rosenbrock/__init__.py +16 -0
- gemseo_multi_fidelity/mf_rosenbrock/delayed_disc.py +136 -0
- gemseo_multi_fidelity/mf_rosenbrock/refact_rosen_testcase.py +46 -0
- gemseo_multi_fidelity/mf_rosenbrock/rosen_mf_case.py +284 -0
- gemseo_multi_fidelity/mf_rosenbrock/rosen_mf_funcs.py +350 -0
- gemseo_multi_fidelity/models/__init__.py +16 -0
- gemseo_multi_fidelity/models/fake_updater.py +112 -0
- gemseo_multi_fidelity/models/model_updater.py +91 -0
- gemseo_multi_fidelity/models/rbf/__init__.py +16 -0
- gemseo_multi_fidelity/models/rbf/kernel_factory.py +66 -0
- gemseo_multi_fidelity/models/rbf/kernels/__init__.py +16 -0
- gemseo_multi_fidelity/models/rbf/kernels/gaussian.py +93 -0
- gemseo_multi_fidelity/models/rbf/kernels/matern_3_2.py +101 -0
- gemseo_multi_fidelity/models/rbf/kernels/matern_5_2.py +101 -0
- gemseo_multi_fidelity/models/rbf/kernels/rbf_kernel.py +172 -0
- gemseo_multi_fidelity/models/rbf/rbf_model.py +422 -0
- gemseo_multi_fidelity/models/sparse_rbf_updater.py +96 -0
- gemseo_multi_fidelity/models/taylor/__init__.py +16 -0
- gemseo_multi_fidelity/models/taylor/taylor.py +212 -0
- gemseo_multi_fidelity/models/taylor_updater.py +66 -0
- gemseo_multi_fidelity/models/updater_factory.py +62 -0
- gemseo_multi_fidelity/settings/__init__.py +16 -0
- gemseo_multi_fidelity/settings/drivers.py +22 -0
- gemseo_multi_fidelity/settings/formulations.py +16 -0
- gemseo_multi_fidelity-0.0.1.dist-info/METADATA +99 -0
- gemseo_multi_fidelity-0.0.1.dist-info/RECORD +76 -0
- gemseo_multi_fidelity-0.0.1.dist-info/WHEEL +5 -0
- gemseo_multi_fidelity-0.0.1.dist-info/entry_points.txt +2 -0
- gemseo_multi_fidelity-0.0.1.dist-info/licenses/LICENSE.txt +165 -0
- gemseo_multi_fidelity-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
|
|
2
|
+
#
|
|
3
|
+
# This program is free software; you can redistribute it and/or
|
|
4
|
+
# modify it under the terms of the GNU Lesser General Public
|
|
5
|
+
# License version 3 as published by the Free Software Foundation.
|
|
6
|
+
#
|
|
7
|
+
# This program is distributed in the hope that it will be useful,
|
|
8
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
9
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
10
|
+
# Lesser General Public License for more details.
|
|
11
|
+
#
|
|
12
|
+
# You should have received a copy of the GNU Lesser General Public License
|
|
13
|
+
# along with this program; if not, write to the Free Software Foundation,
|
|
14
|
+
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
|
15
|
+
|
|
16
|
+
# Copyright (c) 2018 AIRBUS OPERATIONS
|
|
17
|
+
|
|
18
|
+
#
|
|
19
|
+
# Contributors:
|
|
20
|
+
# INITIAL AUTHORS - API and implementation and/or documentation
|
|
21
|
+
# :author: Romain Olivanti
|
|
22
|
+
# OTHER AUTHORS - MACROSCOPIC CHANGES
|
|
23
|
+
"""Gaussian kernel."""
|
|
24
|
+
|
|
25
|
+
from __future__ import annotations
|
|
26
|
+
|
|
27
|
+
from numpy import exp
|
|
28
|
+
|
|
29
|
+
from gemseo_multi_fidelity.models.rbf.kernels.rbf_kernel import RBFKernel
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class GaussianKernel(RBFKernel):
    """Gaussian kernel ``K(r) = exp(-s * r**2)``.

    ``s`` is the stored scaling, i.e. the square of the value passed to
    the constructor or to :meth:`set_scaling`.
    """

    def __init__(self, scaling: float = 1.0) -> None:
        """Initialize the kernel.

        Args:
            scaling: The kernel scaling (squared before being stored).
        """
        super().__init__(scaling=scaling)

    def set_scaling(self, scaling: float) -> None:
        """Set the scaling of the kernel.

        Args:
            scaling: The scaling to set; its square is what gets stored.

        Raises:
            TypeError: If the scaling is not a float > 0.
        """
        # Validate the raw value here: once squared, a negative scaling
        # would pass the positivity check of the base setter unnoticed.
        if not isinstance(scaling, float) or scaling <= 0.0:
            msg = "The scaling must be a float > 0"
            raise TypeError(msg)
        # Pre-squared
        super().set_scaling(scaling**2)

    def _compute_val(self, radius: float) -> float:
        """Compute the kernel value for radius.

        Args:
            radius: The radius (a scalar, or an array evaluated elementwise).

        Returns:
            The kernel value ``exp(-s * radius**2)``.
        """
        return exp(-self._scaling * radius**2)

    def _compute_grad_fact(self, radius: float) -> float:
        """Compute the kernel gradient factor.

        Args:
            radius: The radius (a scalar, or an array evaluated elementwise).

        Returns:
            The gradient factor ``-2 * s * exp(-s * radius**2)``.
        """
        scale = self._scaling
        return -2.0 * scale * exp(-scale * radius**2)

    def _compute_hess_fact(self, radius: float) -> tuple[float, float]:
        """Compute the kernel hessians factors.

        Args:
            radius: The radius (a scalar, or an array evaluated elementwise).

        Returns:
            The tuple (multiplicative factor, diagonal factor); the diagonal
            factor ``-0.5 / s`` does not depend on the radius.
        """
        scale = self._scaling
        return 4.0 * scale**2 * exp(-scale * radius**2), -0.5 / scale

    def __str__(self) -> str:
        """Get the string representation.

        Returns:
            The string representation.
        """
        return f"K(r) = exp(-{self._scaling:.6g} * r**2)"
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
|
|
2
|
+
#
|
|
3
|
+
# This program is free software; you can redistribute it and/or
|
|
4
|
+
# modify it under the terms of the GNU Lesser General Public
|
|
5
|
+
# License version 3 as published by the Free Software Foundation.
|
|
6
|
+
#
|
|
7
|
+
# This program is distributed in the hope that it will be useful,
|
|
8
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
9
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
10
|
+
# Lesser General Public License for more details.
|
|
11
|
+
#
|
|
12
|
+
# You should have received a copy of the GNU Lesser General Public License
|
|
13
|
+
# along with this program; if not, write to the Free Software Foundation,
|
|
14
|
+
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
|
15
|
+
|
|
16
|
+
# Copyright (c) 2018 AIRBUS OPERATIONS
|
|
17
|
+
|
|
18
|
+
#
|
|
19
|
+
# Contributors:
|
|
20
|
+
# INITIAL AUTHORS - API and implementation and/or documentation
|
|
21
|
+
# :author: Romain Olivanti
|
|
22
|
+
# OTHER AUTHORS - MACROSCOPIC CHANGES
|
|
23
|
+
"""Matern 3/2 kernel."""
|
|
24
|
+
|
|
25
|
+
from __future__ import annotations
|
|
26
|
+
|
|
27
|
+
from typing import TYPE_CHECKING
|
|
28
|
+
|
|
29
|
+
from numpy import exp
|
|
30
|
+
from numpy import sqrt
|
|
31
|
+
|
|
32
|
+
from gemseo_multi_fidelity.models.rbf.kernels.rbf_kernel import RBFKernel
|
|
33
|
+
|
|
34
|
+
if TYPE_CHECKING:
|
|
35
|
+
from numpy.typing import NDArray
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class Matern32Kernel(RBFKernel):
    """Matern 3/2 kernel ``K(r) = (1 + a * r) * exp(-a * r)``.

    ``a`` is the stored scaling, i.e. ``sqrt(3)`` times the value passed to
    the constructor or to :meth:`set_scaling`.
    """

    def __init__(self, scaling: float = 1.0) -> None:
        """Initialize the kernel.

        Args:
            scaling: The kernel scaling.
        """
        super().__init__(scaling=scaling)

    def set_scaling(self, scaling: float) -> None:
        """Set the scaling of the kernel.

        Args:
            scaling: The scaling to set.

        Raises:
            TypeError: If the base class setter rejects the pre-multiplied
                value.
        """
        # Pre-multiplied by the Matern 3/2 factor
        super().set_scaling(sqrt(3.0) * scaling)

    def _compute_val(self, radius: float | NDArray) -> float | NDArray:
        """Compute the kernel value for radius.

        Args:
            radius: The radius (a scalar, or an array evaluated elementwise).

        Returns:
            The kernel value.
        """
        arg = self._scaling * radius
        return (1.0 + arg) * exp(-arg)

    def _compute_grad_fact(self, radius: float | NDArray) -> float | NDArray:
        """Compute the kernel gradient factor.

        Args:
            radius: The radius (a scalar, or an array evaluated elementwise).

        Returns:
            The kernel gradient factor ``-a**2 * exp(-a * radius)``.
        """
        scale = self._scaling
        return -(scale**2) * exp(-scale * radius)

    def _compute_hess_fact(
        self, radius: float | NDArray
    ) -> tuple[float | NDArray, float | NDArray]:
        """Compute the kernel hessians factors.

        Args:
            radius: The radius (a scalar, or an array evaluated elementwise).

        Returns:
            The tuple (multiplicative factor, diagonal factor).
        """
        scale = self._scaling
        # The generic factors are (inf, 0) at a null radius and their product
        # is nan, whereas the Hessian there is -scale**2 * I. The outer
        # product of a null vector vanishes, so only the product of the two
        # factors matters at that point: substituting 1 for a null radius
        # gives scale**3 * (-1 / scale) = -scale**2, the expected value.
        # Non-null radii are left untouched (the boolean adds 0).
        safe_radius = radius + (radius == 0.0)
        return scale**3 * exp(-scale * radius) / safe_radius, -safe_radius / scale

    def __str__(self) -> str:
        """Get the string representation.

        Returns:
            The string representation.
        """
        arg = f"{self._scaling:.6g} * r"
        return f"K(r) = (1. + {arg}) * exp(-{arg})"
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
|
|
2
|
+
#
|
|
3
|
+
# This program is free software; you can redistribute it and/or
|
|
4
|
+
# modify it under the terms of the GNU Lesser General Public
|
|
5
|
+
# License version 3 as published by the Free Software Foundation.
|
|
6
|
+
#
|
|
7
|
+
# This program is distributed in the hope that it will be useful,
|
|
8
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
9
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
10
|
+
# Lesser General Public License for more details.
|
|
11
|
+
#
|
|
12
|
+
# You should have received a copy of the GNU Lesser General Public License
|
|
13
|
+
# along with this program; if not, write to the Free Software Foundation,
|
|
14
|
+
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
|
15
|
+
|
|
16
|
+
# Copyright (c) 2018 AIRBUS OPERATIONS
|
|
17
|
+
|
|
18
|
+
#
|
|
19
|
+
# Contributors:
|
|
20
|
+
# INITIAL AUTHORS - API and implementation and/or documentation
|
|
21
|
+
# :author: Romain Olivanti
|
|
22
|
+
# OTHER AUTHORS - MACROSCOPIC CHANGES
|
|
23
|
+
"""Matern 5/2 kernel."""
|
|
24
|
+
|
|
25
|
+
from __future__ import annotations
|
|
26
|
+
|
|
27
|
+
from typing import TYPE_CHECKING
|
|
28
|
+
|
|
29
|
+
from numpy import exp
|
|
30
|
+
from numpy import sqrt
|
|
31
|
+
|
|
32
|
+
from gemseo_multi_fidelity.models.rbf.kernels.rbf_kernel import RBFKernel
|
|
33
|
+
|
|
34
|
+
if TYPE_CHECKING:
|
|
35
|
+
from numpy.typing import NDArray
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class Matern52Kernel(RBFKernel):
    """Matern 5/2 kernel ``K(r) = (1 + a * r + (a * r)**2 / 3) * exp(-a * r)``.

    ``a`` is the stored scaling, i.e. ``sqrt(5)`` times the value passed to
    the constructor or to :meth:`set_scaling`.
    """

    def __init__(self, scaling: float = 1.0) -> None:
        """Initialize the kernel.

        Args:
            scaling: The kernel scaling.
        """
        super().__init__(scaling=scaling)

    def set_scaling(self, scaling: float) -> None:
        """Set the scaling of the kernel.

        Args:
            scaling: The scaling to set.

        Raises:
            TypeError: If the base class setter rejects the pre-multiplied
                value.
        """
        # Pre-multiplied by the Matern 5/2 factor
        super().set_scaling(sqrt(5.0) * scaling)

    def _compute_val(self, radius: float | NDArray) -> float | NDArray:
        """Compute the kernel value for radius.

        Args:
            radius: The radius (a scalar, or an array evaluated elementwise).

        Returns:
            The kernel value.
        """
        arg = self._scaling * radius
        return (1.0 + arg + arg**2 / 3.0) * exp(-arg)

    def _compute_grad_fact(self, radius: float | NDArray) -> float | NDArray:
        """Compute the kernel gradient factor.

        Args:
            radius: The radius (a scalar, or an array evaluated elementwise).

        Returns:
            The kernel gradient factor.
        """
        scale = self._scaling
        return -(scale**2) / 3.0 * (1.0 + scale * radius) * exp(-scale * radius)

    def _compute_hess_fact(
        self, radius: float | NDArray
    ) -> tuple[float | NDArray, float | NDArray]:
        """Compute the kernel hessians factors.

        Args:
            radius: The radius (a scalar, or an array evaluated elementwise).

        Returns:
            The tuple (multiplicative factor, diagonal factor).
        """
        scale = self._scaling
        return scale**4 / 3.0 * exp(-scale * radius), -(1.0 + scale * radius) / scale**2

    def __str__(self) -> str:
        """Get the string representation.

        Returns:
            The string representation.
        """
        arg = f"{self._scaling:.6g} * r"
        # The quadratic term is (a * r)**2: without the parentheses the
        # printed formula would read a * (r**2), which is not what
        # _compute_val evaluates.
        return f"K(r) = (1. + {arg} + ({arg})**2 / 3.)* exp(-{arg})"
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
|
|
2
|
+
#
|
|
3
|
+
# This program is free software; you can redistribute it and/or
|
|
4
|
+
# modify it under the terms of the GNU Lesser General Public
|
|
5
|
+
# License version 3 as published by the Free Software Foundation.
|
|
6
|
+
#
|
|
7
|
+
# This program is distributed in the hope that it will be useful,
|
|
8
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
9
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
10
|
+
# Lesser General Public License for more details.
|
|
11
|
+
#
|
|
12
|
+
# You should have received a copy of the GNU Lesser General Public License
|
|
13
|
+
# along with this program; if not, write to the Free Software Foundation,
|
|
14
|
+
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
|
15
|
+
|
|
16
|
+
# Copyright (c) 2018 AIRBUS OPERATIONS
|
|
17
|
+
|
|
18
|
+
#
|
|
19
|
+
# Contributors:
|
|
20
|
+
# INITIAL AUTHORS - API and implementation and/or documentation
|
|
21
|
+
# :author: Romain Olivanti
|
|
22
|
+
# OTHER AUTHORS - MACROSCOPIC CHANGES
|
|
23
|
+
"""Abstract RBF Kernel."""
|
|
24
|
+
|
|
25
|
+
from __future__ import annotations
|
|
26
|
+
|
|
27
|
+
from typing import TYPE_CHECKING
|
|
28
|
+
from typing import Any
|
|
29
|
+
|
|
30
|
+
from numpy import einsum
|
|
31
|
+
from numpy import fill_diagonal
|
|
32
|
+
from numpy import ones
|
|
33
|
+
from numpy.linalg import norm
|
|
34
|
+
|
|
35
|
+
if TYPE_CHECKING:
|
|
36
|
+
from numpy.typing import NDArray
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class RBFKernel:
    """Abstract RBF Kernel class.

    Concrete kernels implement :meth:`_compute_val`,
    :meth:`_compute_grad_fact` and :meth:`_compute_hess_fact`; this class
    turns those radial factors into the value, gradient and hessian with
    respect to the input vector(s).
    """

    def __init__(self, scaling: float = 1.0) -> None:
        """Initialize the kernel.

        Args:
            scaling: The kernel scaling.
        """
        self._scaling = 1.0
        # Use setter to check the input
        self.set_scaling(scaling)

    @property
    def scaling(self) -> float:
        """Accessor to the scaling property.

        Returns:
            The scaling, as stored by :meth:`set_scaling`.
        """
        return self._scaling

    def set_scaling(self, scaling: float) -> None:
        """Set the scaling of the kernel.

        Args:
            scaling: The scaling to set.

        Raises:
            TypeError: If the scaling is not a float or is not > 0.
        """
        if not isinstance(scaling, float) or scaling <= 0.0:
            msg = "The scaling must be a float > 0"
            raise TypeError(msg)
        self._scaling = scaling

    @staticmethod
    def compute_radius(arg: NDArray) -> float | NDArray:
        """Compute the input for the kernel.

        Args:
            arg: The 1D or 2D ndarray.

        Returns:
            The norm of ``arg`` taken along its last axis: a scalar for a 1D
            input, one value per row for a 2D input.

        Raises:
            ValueError: If ``arg`` is neither 1D nor 2D.
        """
        norm_axis = len(arg.shape) - 1
        if norm_axis not in [0, 1]:
            msg = "arg must be 1D or 2D ndarray"
            raise ValueError(msg)
        return norm(arg, axis=norm_axis)

    def __call__(self, arg: NDArray) -> Any:
        """Evaluate the kernel for the specified arg.

        Args:
            arg: The 1D or 2D ndarray.

        Returns:
            The kernel value.
        """
        radius = self.compute_radius(arg)
        return self._compute_val(radius)

    def grad(self, arg: NDArray) -> NDArray:
        """Evaluate the gradient of the kernel for the specified arg.

        Args:
            arg: The 1D or 2D ndarray.

        Returns:
            The gradient, with the same shape as ``arg``: each vector is
            multiplied by the gradient factor of its radius.
        """
        radius = self.compute_radius(arg)
        grad_fact = self._compute_grad_fact(radius)
        if len(arg.shape) == 1:
            return grad_fact * arg
        return einsum("i,ij->ij", grad_fact, arg)

    def hess(self, arg: NDArray) -> NDArray:
        """Evaluate the hessian of the kernel for the specified arg.

        For each vector ``x`` the hessian is assembled as
        ``fact * (outer(x, x) + diag_fact * I)`` from the two factors
        returned by :meth:`_compute_hess_fact`.

        Args:
            arg: The 1D or 2D ndarray.

        Returns:
            The hessian: a 2D array for a 1D input, a stack of 2D arrays
            (one per row) for a 2D input.
        """
        radius = self.compute_radius(arg)
        fact, diag_fact = self._compute_hess_fact(radius)
        if len(arg.shape) == 1:
            # 1d input
            ret = einsum("i,j->ij", arg, arg)
            fill_diagonal(ret, ret.diagonal() + diag_fact)
            return fact * ret
        # 2d input
        ret = einsum("ij,il->ijl", arg, arg)
        n_points, dim = arg.shape
        # Broadcast unconditionally so that any scalar diagonal factor
        # (not only an exact Python float) gets one entry per point;
        # a per-point array is left unchanged by the multiplication.
        diag_shift = diag_fact * ones(n_points)
        # Shift the diagonal of every per-point matrix in one assignment.
        diag_idx = list(range(dim))
        ret[:, diag_idx, diag_idx] = ret[:, diag_idx, diag_idx] + diag_shift[:, None]
        return einsum("i, ijk->ijk", fact, ret)

    def _compute_val(self, radius: NDArray) -> Any:
        """Compute the kernel value for radius.

        Args:
            radius: The radius, as returned by :meth:`compute_radius`.

        Returns:
            The kernel value.

        Raises:
            NotImplementedError: Always; subclasses must override it.
        """
        raise NotImplementedError

    def _compute_grad_fact(self, radius: NDArray) -> Any:
        """Compute the kernel gradient factor.

        Args:
            radius: The radius, as returned by :meth:`compute_radius`.

        Returns:
            The gradient factor, i.e. the multiplier applied to the input
            vector(s) in :meth:`grad`.

        Raises:
            NotImplementedError: Always; subclasses must override it.
        """
        raise NotImplementedError

    def _compute_hess_fact(self, radius: NDArray) -> tuple[Any, Any]:
        """Compute the kernel hessians factors.

        Args:
            radius: The radius, as returned by :meth:`compute_radius`.

        Returns:
            The tuple (multiplicative factor, diagonal factor).

        Raises:
            NotImplementedError: Always; subclasses must override it.
        """
        raise NotImplementedError
|