gemseo-multi-fidelity 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. gemseo_multi_fidelity/__init__.py +17 -0
  2. gemseo_multi_fidelity/core/MFMapperAdapter_input.json +22 -0
  3. gemseo_multi_fidelity/core/MFMapperAdapter_output.json +22 -0
  4. gemseo_multi_fidelity/core/MFMapperLinker_input.json +22 -0
  5. gemseo_multi_fidelity/core/MFMapperLinker_output.json +22 -0
  6. gemseo_multi_fidelity/core/MFScenarioAdapter_input.json +39 -0
  7. gemseo_multi_fidelity/core/MFScenarioAdapter_output.json +23 -0
  8. gemseo_multi_fidelity/core/__init__.py +16 -0
  9. gemseo_multi_fidelity/core/boxed_domain.py +242 -0
  10. gemseo_multi_fidelity/core/corr_function.py +411 -0
  11. gemseo_multi_fidelity/core/criticality.py +124 -0
  12. gemseo_multi_fidelity/core/ds_mapper.py +307 -0
  13. gemseo_multi_fidelity/core/errors.py +42 -0
  14. gemseo_multi_fidelity/core/eval_mapper.py +188 -0
  15. gemseo_multi_fidelity/core/id_mapper_adapter.py +61 -0
  16. gemseo_multi_fidelity/core/mapper_adapter.py +126 -0
  17. gemseo_multi_fidelity/core/mapper_linker.py +72 -0
  18. gemseo_multi_fidelity/core/mf_formulation.py +635 -0
  19. gemseo_multi_fidelity/core/mf_logger.py +216 -0
  20. gemseo_multi_fidelity/core/mf_opt_problem.py +480 -0
  21. gemseo_multi_fidelity/core/mf_scenario.py +205 -0
  22. gemseo_multi_fidelity/core/noise_criterion.py +94 -0
  23. gemseo_multi_fidelity/core/projpolytope.out +0 -0
  24. gemseo_multi_fidelity/core/scenario_adapter.py +568 -0
  25. gemseo_multi_fidelity/core/stop_criteria.py +201 -0
  26. gemseo_multi_fidelity/core/strict_chain.py +75 -0
  27. gemseo_multi_fidelity/core/utils_model_quality.py +74 -0
  28. gemseo_multi_fidelity/corrections/__init__.py +16 -0
  29. gemseo_multi_fidelity/corrections/add_corr_function.py +80 -0
  30. gemseo_multi_fidelity/corrections/correction_factory.py +65 -0
  31. gemseo_multi_fidelity/corrections/mul_corr_function.py +86 -0
  32. gemseo_multi_fidelity/drivers/__init__.py +16 -0
  33. gemseo_multi_fidelity/drivers/mf_algo_factory.py +38 -0
  34. gemseo_multi_fidelity/drivers/mf_driver_lib.py +462 -0
  35. gemseo_multi_fidelity/drivers/refinement.py +234 -0
  36. gemseo_multi_fidelity/drivers/settings/__init__.py +16 -0
  37. gemseo_multi_fidelity/drivers/settings/base_mf_driver_settings.py +59 -0
  38. gemseo_multi_fidelity/drivers/settings/mf_refine_settings.py +50 -0
  39. gemseo_multi_fidelity/formulations/__init__.py +16 -0
  40. gemseo_multi_fidelity/formulations/refinement.py +144 -0
  41. gemseo_multi_fidelity/mapping/__init__.py +16 -0
  42. gemseo_multi_fidelity/mapping/identity_mapper.py +74 -0
  43. gemseo_multi_fidelity/mapping/interp_mapper.py +422 -0
  44. gemseo_multi_fidelity/mapping/mapper_factory.py +70 -0
  45. gemseo_multi_fidelity/mapping/mapping_errors.py +46 -0
  46. gemseo_multi_fidelity/mapping/subset_mapper.py +122 -0
  47. gemseo_multi_fidelity/mf_rosenbrock/__init__.py +16 -0
  48. gemseo_multi_fidelity/mf_rosenbrock/delayed_disc.py +136 -0
  49. gemseo_multi_fidelity/mf_rosenbrock/refact_rosen_testcase.py +46 -0
  50. gemseo_multi_fidelity/mf_rosenbrock/rosen_mf_case.py +284 -0
  51. gemseo_multi_fidelity/mf_rosenbrock/rosen_mf_funcs.py +350 -0
  52. gemseo_multi_fidelity/models/__init__.py +16 -0
  53. gemseo_multi_fidelity/models/fake_updater.py +112 -0
  54. gemseo_multi_fidelity/models/model_updater.py +91 -0
  55. gemseo_multi_fidelity/models/rbf/__init__.py +16 -0
  56. gemseo_multi_fidelity/models/rbf/kernel_factory.py +66 -0
  57. gemseo_multi_fidelity/models/rbf/kernels/__init__.py +16 -0
  58. gemseo_multi_fidelity/models/rbf/kernels/gaussian.py +93 -0
  59. gemseo_multi_fidelity/models/rbf/kernels/matern_3_2.py +101 -0
  60. gemseo_multi_fidelity/models/rbf/kernels/matern_5_2.py +101 -0
  61. gemseo_multi_fidelity/models/rbf/kernels/rbf_kernel.py +172 -0
  62. gemseo_multi_fidelity/models/rbf/rbf_model.py +422 -0
  63. gemseo_multi_fidelity/models/sparse_rbf_updater.py +96 -0
  64. gemseo_multi_fidelity/models/taylor/__init__.py +16 -0
  65. gemseo_multi_fidelity/models/taylor/taylor.py +212 -0
  66. gemseo_multi_fidelity/models/taylor_updater.py +66 -0
  67. gemseo_multi_fidelity/models/updater_factory.py +62 -0
  68. gemseo_multi_fidelity/settings/__init__.py +16 -0
  69. gemseo_multi_fidelity/settings/drivers.py +22 -0
  70. gemseo_multi_fidelity/settings/formulations.py +16 -0
  71. gemseo_multi_fidelity-0.0.1.dist-info/METADATA +99 -0
  72. gemseo_multi_fidelity-0.0.1.dist-info/RECORD +76 -0
  73. gemseo_multi_fidelity-0.0.1.dist-info/WHEEL +5 -0
  74. gemseo_multi_fidelity-0.0.1.dist-info/entry_points.txt +2 -0
  75. gemseo_multi_fidelity-0.0.1.dist-info/licenses/LICENSE.txt +165 -0
  76. gemseo_multi_fidelity-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,201 @@
1
+ # Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
2
+ #
3
+ # This program is free software; you can redistribute it and/or
4
+ # modify it under the terms of the GNU Lesser General Public
5
+ # License version 3 as published by the Free Software Foundation.
6
+ #
7
+ # This program is distributed in the hope that it will be useful,
8
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
9
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10
+ # Lesser General Public License for more details.
11
+ #
12
+ # You should have received a copy of the GNU Lesser General Public License
13
+ # along with this program; if not, write to the Free Software Foundation,
14
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15
+
16
+ # Copyright (c) 2018 AIRBUS OPERATIONS
17
+
18
+ #
19
+ # Contributors:
20
+ # INITIAL AUTHORS - API and implementation and/or documentation
21
+ # :author: Romain Olivanti
22
+ # OTHER AUTHORS - MACROSCOPIC CHANGES
23
+ """Stopping criteria."""
24
+
25
+ from __future__ import annotations
26
+
27
+ from typing import TYPE_CHECKING
28
+
29
+ from gemseo_multi_fidelity.core.criticality import crit_out_1
30
+
31
+ if TYPE_CHECKING:
32
+ from numpy.typing import NDArray
33
+
34
+ from gemseo_multi_fidelity.core.boxed_domain import BoxedDomain
35
+
36
+
37
def check_positive_float(value: float) -> float:
    """Check that a value is a strictly positive float.

    Args:
        value: The value to check.

    Returns:
        The checked value, unchanged.

    Raises:
        TypeError: If ``value`` is not a ``float`` instance.
        ValueError: If ``value`` is not strictly positive (NaN is rejected too).
    """
    if not isinstance(value, float):
        msg = f"{value} not a float"
        raise TypeError(msg)
    # ``not value > 0.0`` rather than ``value <= 0.0`` so that NaN is rejected.
    if not value > 0.0:
        # The original message lacked the f-prefix and printed a literal "{}".
        msg = f"{value} not > 0."
        raise ValueError(msg)
    return value
53
+
54
+
55
class MaxIterCriterion:
    """Criterion triggered once a maximum number of iterations is exceeded."""

    def __init__(self, max_iter: int) -> None:
        """Constructor.

        Args:
            max_iter: The maximum number of iterations (converted with ``int``).
        """
        self._curr_it = 0
        self._max_it = int(max_iter)

    @property
    def curr_it(self) -> int:
        """The current iteration.

        Returns:
            The number of calls made so far to :meth:`must_stop`.
        """
        return self._curr_it

    def must_stop(self, x_vect=None) -> bool:
        """Count one more iteration and tell whether the limit is exceeded.

        Args:
            x_vect: Unused by this criterion.

        Returns:
            ``True`` once the iteration count exceeds the maximum,
            ``False`` otherwise.
        """
        self._curr_it = self._curr_it + 1
        return not self._curr_it <= self._max_it
84
+
85
+
86
class ProjGradCriterion:
    """Projected gradient criterion class.

    The stop is triggered when the value returned by ``crit_out_1`` for the
    current point and gradient is strictly below the projected gradient
    tolerance.
    """

    def __init__(self, domain: BoxedDomain, grad_tol: float, bound_tol: float) -> None:
        """Constructor.

        Args:
            domain: The domain of the optimization.
            grad_tol: The projected gradient tolerance (strictly positive float).
            bound_tol: The tolerance on the problem bounds (strictly positive float).

        Raises:
            TypeError: If a tolerance is not a float.
            ValueError: If a tolerance is not strictly positive.
        """
        self._grad_tol = None
        self._bound_tol = None
        self.domain = domain

        # Assign through the properties so that both tolerances are validated.
        self.grad_tol = grad_tol
        self.bound_tol = bound_tol

    @property
    def grad_tol(self) -> float:
        """Get the projected gradient tolerance.

        Returns:
            The projected gradient tolerance.
        """
        return self._grad_tol

    @grad_tol.setter
    def grad_tol(self, grad_tol: float) -> None:
        """Set the projected gradient tolerance.

        Args:
            grad_tol: The projected gradient tolerance.

        Raises:
            TypeError: If ``grad_tol`` is not a float.
            ValueError: If ``grad_tol`` is not strictly positive.
        """
        self._grad_tol = check_positive_float(grad_tol)

    @property
    def bound_tol(self) -> float:
        """Get the bound tolerance.

        Returns:
            The bound tolerance.
        """
        return self._bound_tol

    @bound_tol.setter
    def bound_tol(self, bound_tol: float) -> None:
        """Set the bound tolerance.

        Args:
            bound_tol: The tolerance on the problem bounds.

        Raises:
            TypeError: If ``bound_tol`` is not a float.
            ValueError: If ``bound_tol`` is not strictly positive.
        """
        self._bound_tol = check_positive_float(bound_tol)

    def must_stop(self, x_vect: NDArray, grad: NDArray) -> bool:
        """Check if a stop is triggered by the criterion.

        Args:
            x_vect: The current point.
            grad: The current gradient.

        Returns:
            ``True`` if a stop is triggered by the criterion, ``False`` otherwise.
        """
        return (
            crit_out_1(x_vect, grad, self.domain, self._grad_tol, self._bound_tol)
            < self._grad_tol
        )
154
+
155
+
156
class DecrCriterion:
    """Criterion triggered once the decrease has been small often enough.

    Every call to :meth:`must_stop` with a decrease strictly between ``0`` and
    the tolerance increments an internal counter. That counter is never reset,
    and the stop is triggered as soon as it reaches ``n_under``.
    """

    def __init__(self, decr_tol: float, n_under=1) -> None:
        """Constructor.

        Args:
            decr_tol: The decrease tolerance.
            n_under: The number of decreases below the tolerance
                required to trigger the stop.
        """
        self._decr_tol = None
        self.n_under_crit = n_under
        self._curr_n_under = 0
        # Go through the property so that the tolerance is validated.
        self.decr_tol = decr_tol

    @property
    def decr_tol(self) -> float:
        """The decrease tolerance.

        Returns:
            The decrease tolerance.
        """
        return self._decr_tol

    @decr_tol.setter
    def decr_tol(self, decr_tol) -> None:
        """Validate and store the decrease tolerance.

        Args:
            decr_tol: The decrease tolerance.
        """
        self._decr_tol = check_positive_float(decr_tol)

    def must_stop(self, decr: float) -> bool:
        """Record one decrease and tell whether the stop is triggered.

        Args:
            decr: The current decrease.

        Returns:
            ``True`` once enough small decreases have been recorded,
            ``False`` otherwise.
        """
        small_decrease = 0.0 < decr < self.decr_tol
        if small_decrease:
            self._curr_n_under += 1
        return self._curr_n_under >= self.n_under_crit
@@ -0,0 +1,75 @@
1
+ # Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
2
+ #
3
+ # This program is free software; you can redistribute it and/or
4
+ # modify it under the terms of the GNU Lesser General Public
5
+ # License version 3 as published by the Free Software Foundation.
6
+ #
7
+ # This program is distributed in the hope that it will be useful,
8
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
9
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10
+ # Lesser General Public License for more details.
11
+ #
12
+ # You should have received a copy of the GNU Lesser General Public License
13
+ # along with this program; if not, write to the Free Software Foundation,
14
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15
+
16
+ # Copyright (c) 2019 AIRBUS OPERATIONS
17
+
18
+ #
19
+ # Contributors:
20
+ # INITIAL AUTHORS - API and implementation and/or documentation
21
+ # :author: Romain Olivanti
22
+ # OTHER AUTHORS - MACROSCOPIC CHANGES
23
+ """Strict MDO chain."""
24
+
25
+ from __future__ import annotations
26
+
27
+ from gemseo.core.chains.chain import MDOChain
28
+
29
+
30
class StrictChain(MDOChain):
    """Strict MDO chain class.

    Unlike a regular chain, the chain data do not accumulate the data of every
    discipline. After each discipline runs, ``self.io.data`` is reset and
    refilled with that discipline's outputs only. At the end of an execution it
    therefore holds only the outputs of the last discipline.
    """

    default_grammar_type = "JSONGrammar"
    """The default grammar type."""

    def _initialize_grammars(self) -> None:
        """Initialize the grammars of the chain.

        The input grammar is cleared and updated from the first discipline's
        input grammar. The output grammar is cleared and updated from the last
        discipline's output grammar.
        """
        self.io.input_grammar.clear()
        self.io.output_grammar.clear()
        self.io.input_grammar.update(self.disciplines[0].input_grammar)
        self.io.output_grammar.update(self.disciplines[-1].output_grammar)

    def _execute(self) -> None:
        """Run the disciplines in sequence.

        Each discipline is executed with the current ``self.io.data``. Those
        data are then replaced by the outputs of that discipline only, so the
        next discipline receives only its predecessor's outputs.
        """
        for discipline in self.disciplines:
            discipline.execute(self.io.data)
            # Erase all the chain data (inputs and earlier outputs alike).
            self.io.data = {}
            # Keep only the outputs of the discipline that has just run.
            outputs = discipline.get_output_data()
            self.io.data.update(outputs)
@@ -0,0 +1,74 @@
1
+ # Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
2
+ #
3
+ # This program is free software; you can redistribute it and/or
4
+ # modify it under the terms of the GNU Lesser General Public
5
+ # License version 3 as published by the Free Software Foundation.
6
+ #
7
+ # This program is distributed in the hope that it will be useful,
8
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
9
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10
+ # Lesser General Public License for more details.
11
+ #
12
+ # You should have received a copy of the GNU Lesser General Public License
13
+ # along with this program; if not, write to the Free Software Foundation,
14
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15
+
16
+ # Copyright (c) 2019 AIRBUS OPERATIONS
17
+
18
+ #
19
+ # Contributors:
20
+ # INITIAL AUTHORS - API and implementation and/or documentation
21
+ # :author: Romain Olivanti
22
+ # OTHER AUTHORS - MACROSCOPIC CHANGES
23
+ """Model quality utilities."""
24
+
25
+ from __future__ import annotations
26
+
27
+ import logging
28
+
29
+ from numpy import finfo
30
+
31
# Module logger, used to report when the accuracy ratio is safeguarded.
LOGGER = logging.getLogger(__name__)
# Machine epsilon of Python floats; decreases below it are treated as numerical noise.
MACHINE_EPS = finfo(float).eps
33
+
34
+
35
def compute_accuracy_ratio(real_decr: float, pred_decr: float) -> float:
    """Compute the accuracy ratio ``real_decr / pred_decr`` in a safe manner.

    The ratio can be used to assess the quality of a model. When the predicted
    decrease is below the machine precision in absolute value, no division is
    performed. Instead, a safeguard value is returned and the reason is logged
    at INFO level:

    - ``-1.0`` if both decreases are exactly zero (no improvement possible);
    - ``1.0`` if both are below the machine precision but not both exactly zero;
    - ``-1.0`` if only the predicted decrease is below the machine precision.

    Args:
        real_decr: The real decrease.
        pred_decr: The predicted decrease.

    Returns:
        The accuracy ratio.
    """
    if abs(pred_decr) < MACHINE_EPS:
        # The sign of the predicted decrease cannot be trusted.
        msg = "Safeguard model accuracy: "
        if abs(real_decr) < MACHINE_EPS:
            if pred_decr == 0.0 and real_decr == 0.0:
                # Exactly zero, no progress can be achieved.
                msg += "no improvement"
                ratio = -1.0
            else:
                # Both values are at the noise level, so the division is
                # hazardous. The best option is to force the success.
                msg += "both the predicted and real decrease are close "
                msg += "to the machine precision"
                ratio = 1.0
        else:
            # The model underpredicts the real decrease / increase and its sign
            # is not reliable: force the failure. Use ``+=`` (as in the other
            # branches) so the "Safeguard model accuracy: " prefix is kept.
            msg += "only the predicted decrease is close to the "
            msg += "machine precision"
            ratio = -1.0
        LOGGER.info(msg)
    else:
        # Standard computation.
        ratio = real_decr / pred_decr
    return ratio
@@ -0,0 +1,16 @@
1
+ # Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
2
+ #
3
+ # This program is free software; you can redistribute it and/or
4
+ # modify it under the terms of the GNU Lesser General Public
5
+ # License version 3 as published by the Free Software Foundation.
6
+ #
7
+ # This program is distributed in the hope that it will be useful,
8
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
9
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10
+ # Lesser General Public License for more details.
11
+ #
12
+ # You should have received a copy of the GNU Lesser General Public License
13
+ # along with this program; if not, write to the Free Software Foundation,
14
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15
+
16
+ """Corrections."""
@@ -0,0 +1,80 @@
1
+ # Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
2
+ #
3
+ # This program is free software; you can redistribute it and/or
4
+ # modify it under the terms of the GNU Lesser General Public
5
+ # License version 3 as published by the Free Software Foundation.
6
+ #
7
+ # This program is distributed in the hope that it will be useful,
8
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
9
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10
+ # Lesser General Public License for more details.
11
+ #
12
+ # You should have received a copy of the GNU Lesser General Public License
13
+ # along with this program; if not, write to the Free Software Foundation,
14
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15
+
16
+ # Copyright (c) 2019 AIRBUS OPERATIONS
17
+
18
+ #
19
+ # Contributors:
20
+ # INITIAL AUTHORS - API and implementation and/or documentation
21
+ # :author: Romain Olivanti
22
+ # OTHER AUTHORS - MACROSCOPIC CHANGES
23
+ """Additive correction."""
24
+
25
+ from __future__ import annotations
26
+
27
+ from typing import TYPE_CHECKING
28
+
29
+ from gemseo_multi_fidelity.core.corr_function import MDOCorrectedFunction
30
+
31
+ if TYPE_CHECKING:
32
+ from numpy.typing import NDArray
33
+
34
+
35
class MDOAddCorrFunc(MDOCorrectedFunction):
    """Additive correction.

    The corrected function is the sum of the original function and the
    correction model function.
    """

    def _build_func_pointers(self) -> None:
        """Build the evaluation and Jacobian callables of the corrected function.

        While the correction model is not set, both callables fall back to the
        original function.
        """
        # Let the MDOFunction ``+`` operator overloading build the sum.
        sum_function = self._orig_func + self._corr_model.function
        self.expr = sum_function.expr

        def func(x_vect: NDArray) -> NDArray:
            if not self._corr_model.is_set:
                return self._orig_func.func(x_vect)
            return sum_function.func(x_vect)

        self.func = func

        if not self._orig_func.has_jac:
            return

        def jac(x_vect):
            if not self._corr_model.is_set:
                return self._orig_func.jac(x_vect)
            return sum_function.jac(x_vect)

        self.jac = jac

    def _compute_correction(
        self,
        x_ref: NDArray,
        val_ref: NDArray,
        grad_ref: NDArray | None = None,
        hess_ref: NDArray | None = None,
    ) -> tuple[NDArray, NDArray, NDArray]:
        """Compute the additive correction and its derivatives at a point.

        Each quantity is the difference between the reference value and the
        corresponding quantity of the original function at ``x_ref``.

        Args:
            x_ref: The reference point.
            val_ref: The reference value.
            grad_ref: The reference gradient, if any.
            hess_ref: The reference Hessian, if any.

        Returns:
            The value, gradient and Hessian of the correction. The gradient
            (resp. Hessian) is ``None`` when ``grad_ref`` (resp. ``hess_ref``)
            is ``None``.
        """
        val_corr = val_ref - self._orig_func.evaluate(x_ref)
        grad_corr = None if grad_ref is None else grad_ref - self._orig_func.jac(x_ref)
        hess_corr = None if hess_ref is None else hess_ref - self._orig_func.hess(x_ref)
        return val_corr, grad_corr, hess_corr
@@ -0,0 +1,65 @@
1
+ # Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
2
+ #
3
+ # This program is free software; you can redistribute it and/or
4
+ # modify it under the terms of the GNU Lesser General Public
5
+ # License version 3 as published by the Free Software Foundation.
6
+ #
7
+ # This program is distributed in the hope that it will be useful,
8
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
9
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10
+ # Lesser General Public License for more details.
11
+ #
12
+ # You should have received a copy of the GNU Lesser General Public License
13
+ # along with this program; if not, write to the Free Software Foundation,
14
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15
+
16
+ # Copyright (c) 2018 AIRBUS OPERATIONS
17
+
18
+ #
19
+ # Contributors:
20
+ # INITIAL AUTHORS - API and implementation and/or documentation
21
+ # :author: Romain Olivanti
22
+ # OTHER AUTHORS - MACROSCOPIC CHANGES
23
+ """A factory to create a corrected function from its name."""
24
+
25
+ from __future__ import annotations
26
+
27
+ from typing import TYPE_CHECKING
28
+ from typing import ClassVar
29
+
30
+ from gemseo_multi_fidelity.corrections.add_corr_function import MDOAddCorrFunc
31
+ from gemseo_multi_fidelity.corrections.mul_corr_function import MDOMulCorrFunc
32
+
33
+ if TYPE_CHECKING:
34
+ from gemseo_multi_fidelity.core.corr_function import MDOCorrectedFunction
35
+
36
+
37
class CorrectionFactory:
    """Factory of corrected functions, selected by correction name."""

    # Name of the additive correction (MDOAddCorrFunc).
    ADD_CORR = "additive"
    # Name of the multiplicative correction (MDOMulCorrFunc).
    MUL_CORR = "multiplicative"

    AVAILABLE_CORRS: ClassVar[list] = [ADD_CORR, MUL_CORR]
    """Available corrections."""

    def create(self, name: str, *args, **kwargs) -> MDOCorrectedFunction:
        """Instantiate the corrected function registered under a name.

        Args:
            name: The name of the correction, one of ``AVAILABLE_CORRS``.
            *args: The positional arguments passed to the corrected function class.
            **kwargs: The keyword arguments passed to the corrected function class.

        Returns:
            The corrected function.

        Raises:
            ValueError: If ``name`` is not a known correction.
        """
        if name == self.ADD_CORR:
            return MDOAddCorrFunc(*args, **kwargs)
        if name == self.MUL_CORR:
            return MDOMulCorrFunc(*args, **kwargs)
        msg = f"{name} not in {self.AVAILABLE_CORRS}"
        raise ValueError(msg)
@@ -0,0 +1,86 @@
1
+ # Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
2
+ #
3
+ # This program is free software; you can redistribute it and/or
4
+ # modify it under the terms of the GNU Lesser General Public
5
+ # License version 3 as published by the Free Software Foundation.
6
+ #
7
+ # This program is distributed in the hope that it will be useful,
8
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
9
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10
+ # Lesser General Public License for more details.
11
+ #
12
+ # You should have received a copy of the GNU Lesser General Public License
13
+ # along with this program; if not, write to the Free Software Foundation,
14
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15
+
16
+ # Copyright (c) 2019 AIRBUS OPERATIONS
17
+
18
+ #
19
+ # Contributors:
20
+ # INITIAL AUTHORS - API and implementation and/or documentation
21
+ # :author: Romain Olivanti
22
+ # OTHER AUTHORS - MACROSCOPIC CHANGES
23
+ """Multiplicative correction."""
24
+
25
+ from __future__ import annotations
26
+
27
+ from typing import TYPE_CHECKING
28
+
29
+ from numpy import einsum
30
+
31
+ from gemseo_multi_fidelity.core.corr_function import MDOCorrectedFunction
32
+
33
+ if TYPE_CHECKING:
34
+ from numpy.typing import NDArray
35
+
36
+
37
class MDOMulCorrFunc(MDOCorrectedFunction):
    """Multiplicative correction.

    The corrected function is the product of the original function and the
    correction model function.
    """

    def _build_func_pointers(self) -> None:
        """Build the evaluation and Jacobian callables of the corrected function.

        While the correction model is not set, both callables fall back to the
        original function.
        """
        # Rather simple here: rely on the MDOFunction ``*`` operator overloading.
        mul_func = self._orig_func * self._corr_model.function
        self.expr = mul_func.expr

        def func(x_vect: NDArray) -> NDArray:
            if self._corr_model.is_set:
                return mul_func.func(x_vect)
            return self._orig_func.func(x_vect)

        self.func = func

        if self._orig_func.has_jac:

            def jac(x_vect: NDArray) -> NDArray:
                if self._corr_model.is_set:
                    return mul_func.jac(x_vect)
                return self._orig_func.jac(x_vect)

            self.jac = jac

    def _compute_correction(
        self,
        x_ref: NDArray,
        val_ref: NDArray,
        grad_ref: NDArray | None = None,
        hess_ref: NDArray | None = None,
    ) -> tuple[NDArray, NDArray | None, NDArray | None]:
        """Compute the multiplicative correction and its derivatives at a point.

        With ``f`` the original function and ``c`` the correction, the reference
        quantities are matched by the product rule:
        ``val_ref = c f``,
        ``grad_ref = c grad(f) + f grad(c)`` and
        ``hess_ref = c H(f) + grad(f) grad(c)^T + grad(c) grad(f)^T + f H(c)``.
        These are solved in turn for ``c``, ``grad(c)`` and ``H(c)``.

        Args:
            x_ref: The reference point.
            val_ref: The reference value.
            grad_ref: The reference gradient, if any.
            hess_ref: The reference Hessian, if any; requires ``grad_ref``.

        Returns:
            The value, gradient and Hessian of the correction. The gradient
            (resp. Hessian) is ``None`` when ``grad_ref`` (resp. ``hess_ref``)
            is ``None``.

        Raises:
            ValueError: If ``hess_ref`` is given without ``grad_ref``, since the
                Hessian of the correction depends on its gradient.
        """
        # Previously this case crashed with UnboundLocalError on grad_func/grad_corr.
        if hess_ref is not None and grad_ref is None:
            msg = "The Hessian correction requires the reference gradient."
            raise ValueError(msg)

        val_func = self._orig_func(x_ref)
        val_corr = val_ref / val_func
        grad_corr = None
        hess_corr = None

        if grad_ref is not None:
            grad_func = self._orig_func.jac(x_ref)
            grad_corr = (grad_ref - val_corr * grad_func) / val_func
            if hess_ref is not None:
                # Outer product grad(f) grad(c)^T; its transpose is the other cross term.
                tmp = einsum("i,j->ij", grad_func, grad_corr)
                hess_corr = (
                    hess_ref - tmp - tmp.T - self._orig_func.hess(x_ref) * val_corr
                ) / val_func

        return val_corr, grad_corr, hess_corr
@@ -0,0 +1,16 @@
1
+ # Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
2
+ #
3
+ # This program is free software; you can redistribute it and/or
4
+ # modify it under the terms of the GNU Lesser General Public
5
+ # License version 3 as published by the Free Software Foundation.
6
+ #
7
+ # This program is distributed in the hope that it will be useful,
8
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
9
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10
+ # Lesser General Public License for more details.
11
+ #
12
+ # You should have received a copy of the GNU Lesser General Public License
13
+ # along with this program; if not, write to the Free Software Foundation,
14
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15
+
16
+ """Drivers."""
@@ -0,0 +1,38 @@
1
+ # Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
2
+ #
3
+ # This program is free software; you can redistribute it and/or
4
+ # modify it under the terms of the GNU Lesser General Public
5
+ # License version 3 as published by the Free Software Foundation.
6
+ #
7
+ # This program is distributed in the hope that it will be useful,
8
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
9
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10
+ # Lesser General Public License for more details.
11
+ #
12
+ # You should have received a copy of the GNU Lesser General Public License
13
+ # along with this program; if not, write to the Free Software Foundation,
14
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15
+
16
+ # Copyright (c) 2019 AIRBUS OPERATIONS
17
+
18
+ #
19
+ # Contributors:
20
+ # INITIAL AUTHORS - API and implementation and/or documentation
21
+ # :author: Romain Olivanti
22
+ # OTHER AUTHORS - MACROSCOPIC CHANGES
23
+ # :author: Francois Gallard - portage to GEMSEO 6
24
+ """Multi-fidelity algorithms factory."""
25
+
26
+ from __future__ import annotations
27
+
28
+ from gemseo.algos.base_algo_factory import BaseAlgoFactory
29
+
30
+ from gemseo_multi_fidelity.drivers.mf_driver_lib import MFDriverLibrary
31
+
32
+
33
class MFAlgoFactory(BaseAlgoFactory):
    """Factory of the multi-fidelity drivers.

    It is configured with the ``gemseo_multi_fidelity.drivers`` package and the
    :class:`MFDriverLibrary` class. Presumably, per the ``BaseAlgoFactory``
    conventions, that package is scanned for subclasses of that class; verify
    against the GEMSEO version in use.
    """

    # Packages in which the driver libraries are looked for.
    _PACKAGE_NAMES = ("gemseo_multi_fidelity.drivers",)

    # Base class of the driver libraries handled by this factory.
    _CLASS = MFDriverLibrary