gemseo-multi-fidelity 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gemseo_multi_fidelity/__init__.py +17 -0
- gemseo_multi_fidelity/core/MFMapperAdapter_input.json +22 -0
- gemseo_multi_fidelity/core/MFMapperAdapter_output.json +22 -0
- gemseo_multi_fidelity/core/MFMapperLinker_input.json +22 -0
- gemseo_multi_fidelity/core/MFMapperLinker_output.json +22 -0
- gemseo_multi_fidelity/core/MFScenarioAdapter_input.json +39 -0
- gemseo_multi_fidelity/core/MFScenarioAdapter_output.json +23 -0
- gemseo_multi_fidelity/core/__init__.py +16 -0
- gemseo_multi_fidelity/core/boxed_domain.py +242 -0
- gemseo_multi_fidelity/core/corr_function.py +411 -0
- gemseo_multi_fidelity/core/criticality.py +124 -0
- gemseo_multi_fidelity/core/ds_mapper.py +307 -0
- gemseo_multi_fidelity/core/errors.py +42 -0
- gemseo_multi_fidelity/core/eval_mapper.py +188 -0
- gemseo_multi_fidelity/core/id_mapper_adapter.py +61 -0
- gemseo_multi_fidelity/core/mapper_adapter.py +126 -0
- gemseo_multi_fidelity/core/mapper_linker.py +72 -0
- gemseo_multi_fidelity/core/mf_formulation.py +635 -0
- gemseo_multi_fidelity/core/mf_logger.py +216 -0
- gemseo_multi_fidelity/core/mf_opt_problem.py +480 -0
- gemseo_multi_fidelity/core/mf_scenario.py +205 -0
- gemseo_multi_fidelity/core/noise_criterion.py +94 -0
- gemseo_multi_fidelity/core/projpolytope.out +0 -0
- gemseo_multi_fidelity/core/scenario_adapter.py +568 -0
- gemseo_multi_fidelity/core/stop_criteria.py +201 -0
- gemseo_multi_fidelity/core/strict_chain.py +75 -0
- gemseo_multi_fidelity/core/utils_model_quality.py +74 -0
- gemseo_multi_fidelity/corrections/__init__.py +16 -0
- gemseo_multi_fidelity/corrections/add_corr_function.py +80 -0
- gemseo_multi_fidelity/corrections/correction_factory.py +65 -0
- gemseo_multi_fidelity/corrections/mul_corr_function.py +86 -0
- gemseo_multi_fidelity/drivers/__init__.py +16 -0
- gemseo_multi_fidelity/drivers/mf_algo_factory.py +38 -0
- gemseo_multi_fidelity/drivers/mf_driver_lib.py +462 -0
- gemseo_multi_fidelity/drivers/refinement.py +234 -0
- gemseo_multi_fidelity/drivers/settings/__init__.py +16 -0
- gemseo_multi_fidelity/drivers/settings/base_mf_driver_settings.py +59 -0
- gemseo_multi_fidelity/drivers/settings/mf_refine_settings.py +50 -0
- gemseo_multi_fidelity/formulations/__init__.py +16 -0
- gemseo_multi_fidelity/formulations/refinement.py +144 -0
- gemseo_multi_fidelity/mapping/__init__.py +16 -0
- gemseo_multi_fidelity/mapping/identity_mapper.py +74 -0
- gemseo_multi_fidelity/mapping/interp_mapper.py +422 -0
- gemseo_multi_fidelity/mapping/mapper_factory.py +70 -0
- gemseo_multi_fidelity/mapping/mapping_errors.py +46 -0
- gemseo_multi_fidelity/mapping/subset_mapper.py +122 -0
- gemseo_multi_fidelity/mf_rosenbrock/__init__.py +16 -0
- gemseo_multi_fidelity/mf_rosenbrock/delayed_disc.py +136 -0
- gemseo_multi_fidelity/mf_rosenbrock/refact_rosen_testcase.py +46 -0
- gemseo_multi_fidelity/mf_rosenbrock/rosen_mf_case.py +284 -0
- gemseo_multi_fidelity/mf_rosenbrock/rosen_mf_funcs.py +350 -0
- gemseo_multi_fidelity/models/__init__.py +16 -0
- gemseo_multi_fidelity/models/fake_updater.py +112 -0
- gemseo_multi_fidelity/models/model_updater.py +91 -0
- gemseo_multi_fidelity/models/rbf/__init__.py +16 -0
- gemseo_multi_fidelity/models/rbf/kernel_factory.py +66 -0
- gemseo_multi_fidelity/models/rbf/kernels/__init__.py +16 -0
- gemseo_multi_fidelity/models/rbf/kernels/gaussian.py +93 -0
- gemseo_multi_fidelity/models/rbf/kernels/matern_3_2.py +101 -0
- gemseo_multi_fidelity/models/rbf/kernels/matern_5_2.py +101 -0
- gemseo_multi_fidelity/models/rbf/kernels/rbf_kernel.py +172 -0
- gemseo_multi_fidelity/models/rbf/rbf_model.py +422 -0
- gemseo_multi_fidelity/models/sparse_rbf_updater.py +96 -0
- gemseo_multi_fidelity/models/taylor/__init__.py +16 -0
- gemseo_multi_fidelity/models/taylor/taylor.py +212 -0
- gemseo_multi_fidelity/models/taylor_updater.py +66 -0
- gemseo_multi_fidelity/models/updater_factory.py +62 -0
- gemseo_multi_fidelity/settings/__init__.py +16 -0
- gemseo_multi_fidelity/settings/drivers.py +22 -0
- gemseo_multi_fidelity/settings/formulations.py +16 -0
- gemseo_multi_fidelity-0.0.1.dist-info/METADATA +99 -0
- gemseo_multi_fidelity-0.0.1.dist-info/RECORD +76 -0
- gemseo_multi_fidelity-0.0.1.dist-info/WHEEL +5 -0
- gemseo_multi_fidelity-0.0.1.dist-info/entry_points.txt +2 -0
- gemseo_multi_fidelity-0.0.1.dist-info/licenses/LICENSE.txt +165 -0
- gemseo_multi_fidelity-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,411 @@
|
|
|
1
|
+
# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
|
|
2
|
+
#
|
|
3
|
+
# This program is free software; you can redistribute it and/or
|
|
4
|
+
# modify it under the terms of the GNU Lesser General Public
|
|
5
|
+
# License version 3 as published by the Free Software Foundation.
|
|
6
|
+
#
|
|
7
|
+
# This program is distributed in the hope that it will be useful,
|
|
8
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
9
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
10
|
+
# Lesser General Public License for more details.
|
|
11
|
+
#
|
|
12
|
+
# You should have received a copy of the GNU Lesser General Public License
|
|
13
|
+
# along with this program; if not, write to the Free Software Foundation,
|
|
14
|
+
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
|
15
|
+
|
|
16
|
+
# Copyright (c) 2019 AIRBUS OPERATIONS
|
|
17
|
+
#
|
|
18
|
+
# Contributors:
|
|
19
|
+
# INITIAL AUTHORS - API and implementation and/or documentation
|
|
20
|
+
# :author: Romain Olivanti
|
|
21
|
+
# OTHER AUTHORS - MACROSCOPIC CHANGES
|
|
22
|
+
"""MDOCorrected Function."""
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
from typing import TYPE_CHECKING
|
|
27
|
+
from typing import Any
|
|
28
|
+
|
|
29
|
+
from gemseo.algos.database import Database
|
|
30
|
+
from gemseo.core.mdo_functions.mdo_function import MDOFunction
|
|
31
|
+
from numpy import array
|
|
32
|
+
from numpy import delete
|
|
33
|
+
from numpy import flip
|
|
34
|
+
from numpy import where
|
|
35
|
+
from numpy.linalg import norm
|
|
36
|
+
|
|
37
|
+
from gemseo_multi_fidelity.models.model_updater import ModelUpdater
|
|
38
|
+
|
|
39
|
+
if TYPE_CHECKING:
|
|
40
|
+
from numpy import ndarray
|
|
41
|
+
from numpy.typing import NDArray
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class MDOCorrectedFunction(MDOFunction):
    """MDO Corrected Function class.

    Corrected function that can be dynamically updated using the provided
    correction database. Calls to the original function are protected by this
    database: each value (and, when the original function has a jacobian, each
    jacobian) is first looked up in it and stored in it after evaluation. The
    reference data passed to :meth:`set_reference` is stored in the same database.

    Subclasses must implement :meth:`_build_func_pointers` and
    :meth:`_compute_correction`.
    """

    # Prefix prepended to the original function name to build the database key
    # under which the reference data is stored.
    REF_TAG = "#CORR_REF#"

    def __init__(
        self, orig_func: MDOFunction, corr_model: ModelUpdater, corr_database: Database
    ) -> None:
        """Constructor.

        Args:
            orig_func: The uncorrected function.
            corr_model: The model used for the correction.
            corr_database: The database used for the correction.

        Raises:
            TypeError: When ``orig_func`` is not an ``MDOFunction``,
                ``corr_database`` is not a ``Database`` or ``corr_model`` is not
                a ``ModelUpdater``.
        """
        if not isinstance(orig_func, MDOFunction):
            msg = "orig_func must be a MDOFunction"
            raise TypeError(msg)

        if not isinstance(corr_database, Database):
            msg = "corr_database must be a Database"
            raise TypeError(msg)

        if not isinstance(corr_model, ModelUpdater):
            msg = "corr_model must be an instance of ModelUpdater"
            raise TypeError(msg)

        # Clone the specified attributes of the original function.
        args_names = ["func", "name"]
        args = [getattr(orig_func, attr) for attr in args_names]
        kwargs_names = ["f_type", "jac", "expr", "input_names", "dim", "output_names"]
        kwargs = {attr: getattr(orig_func, attr) for attr in kwargs_names}

        super().__init__(*args, **kwargs)
        # Reference to the database used to protect the calls to the original
        # functions and to store the reference data.
        self._corr_database = corr_database
        self._corr_model = corr_model
        self._orig_func = self._process_orig_func(orig_func)
        # Hook implemented by subclasses to set up the corrected function.
        self._build_func_pointers()

    def get_orig_func(self) -> MDOFunction:
        """Get the original function.

        Returns a pointer to the protected version of the original function (protected
        by the database).

        Returns:
            The original function (callback).
        """
        return self._orig_func

    def get_orig_database(self) -> Database:
        """Return the database used to protect the data.

        Returns:
            The database.
        """
        return self._corr_database

    def _process_orig_func(self, orig_func: MDOFunction) -> MDOFunction:
        """Process orig_func to make sure that its calls are protected by a database.

        Args:
            orig_func: The function to wrap.

        Returns:
            The protected version of orig_func.
        """
        # Make sure it is protected by the database
        fname = orig_func.name

        def wrapped_orig_func(x_vect: NDArray) -> Any:
            """Wrapped original function.

            Makes sure each call is stored in the provided database.

            Args:
                x_vect: The design variable.

            Returns:
                The evaluation of the function at x_vect.
            """
            value = None
            if self._corr_database.get(x_vect, False):
                # x_vect is part of the database
                value = self._corr_database.get_function_value(fname, x_vect)

            if value is None:
                # Not evaluated yet, evaluate
                value = orig_func.func(x_vect)
                values_dict = {fname: value}
                # Store the new value
                self._corr_database.store(x_vect, values_dict)
            return value

        db_func = MDOFunction(
            wrapped_orig_func,
            name=fname,
            f_type=orig_func.f_type,
            expr=orig_func.expr,
            input_names=orig_func.input_names,
            dim=orig_func.dim,
            output_names=orig_func.output_names,
        )

        if orig_func.has_jac:

            def dwrapped_orig_func(x_vect: Any) -> NDArray:
                """Wrapped original jacobian.

                Makes sure each call is stored in the provided database.

                Args:
                    x_vect: The design variable.

                Returns:
                    The evaluation of the jacobian at x_vect.
                """
                jac = None
                if self._corr_database.get(x_vect, False):
                    # x_vect is part of the database
                    jac = self._corr_database.get_function_value(
                        Database.GRAD_TAG + fname, x_vect
                    )

                if jac is None:
                    # Not evaluated yet, evaluate; only the real part is kept.
                    jac = orig_func.jac(x_vect).real
                    values_dict = {Database.GRAD_TAG + fname: jac}
                    # Store the new value
                    self._corr_database.store(x_vect, values_dict)
                return jac

            db_func.jac = dwrapped_orig_func

        # Return only the reference of the db_func
        return db_func

    def _build_func_pointers(self):
        """Set the corrected function. To be overloaded by subclasses."""
        raise NotImplementedError

    def _compute_correction(
        self,
        x_ref: NDArray,
        val_ref: NDArray,
        grad_ref: NDArray = None,
        hess_ref: NDArray = None,
    ):
        """Compute correction.

        Actual method to compute the correction from the reference data.
        To be overloaded by subclasses.

        Args:
            x_ref: The variables.
            val_ref: The value to match at x_ref.
            grad_ref: The gradient to match at x_ref.
            hess_ref: The hessian to match at x_ref.

        Returns:
            value correction, gradient correction, hessian correction.
        """
        raise NotImplementedError

    def set_reference(
        self,
        x_ref: NDArray,
        val_ref: Any,
        grad_ref: Any = None,
        hess_ref: Any = None,
        max_norm_thr: float | None = None,
    ) -> None:
        """Set reference.

        Sets the reference so that the corrected function matches the reference data.
        Previous correction data with the same order of consistency, whose norm with
        x_ref is within max_norm_thr, is also used if the correction model handles it.

        Args:
            x_ref: The design variables.
            val_ref: The value to match at x_ref.
            grad_ref: The gradient to match at x_ref.
            hess_ref: The hessian to match at x_ref.
            max_norm_thr: The threshold to discard points too far from x_ref.
        """
        # Always recompute the correction with the specified reference value even
        # if already stored in the database
        val_corr, grad_corr, hess_corr = self._compute_correction(
            x_ref, val_ref, grad_ref=grad_ref, hess_ref=hess_ref
        )

        # Store the reference data.
        # Storing the reference rather than the correction data allows
        # not to be correction-dependent for external uses.
        self.store_reference_data(x_ref, val_ref, grad_ref=grad_ref, hess_ref=hess_ref)

        x_update = [x_ref]
        val_update = [val_corr]
        # Gradient/hessian corrections are only passed to the model when the
        # corresponding reference data was provided.
        grad_update = None if grad_ref is None else [grad_corr]
        hess_update = None if hess_ref is None else [hess_corr]

        # Only perform the search if required
        if self._corr_model.HANDLES_MULTI_UPDATE:
            # tols will be applied, so that the previous point will not
            # appear in the additional data
            x_list, val_list, grad_list, hess_list = self.get_reference_data(
                x_ref,
                use_grad=grad_ref is not None,
                use_hess=hess_ref is not None,
                max_norm_thr=max_norm_thr,
                min_norm_thr=1e-4,
            )
            if x_list is not None:
                # Some points were found, add them to the update lists
                for i, x_r in enumerate(x_list):
                    # Normally everything has already been evaluated
                    # Recompute the correction which should not be too costly
                    v_r = val_list[i]
                    g_r = grad_list[i] if grad_list is not None else None
                    h_r = hess_list[i] if hess_list is not None else None
                    v_c, g_c, h_c = self._compute_correction(
                        x_r, v_r, grad_ref=g_r, hess_ref=h_r
                    )
                    x_update.append(x_r)
                    val_update.append(v_c)
                    # Only extend the lists that were requested: appending to a
                    # None list would raise if a subclass returned a correction
                    # that was not asked for.
                    if grad_update is not None and g_c is not None:
                        grad_update.append(g_c)
                    if hess_update is not None and h_c is not None:
                        hess_update.append(h_c)

        # Update the model
        self._corr_model.update(
            x_update, val_update, grads=grad_update, hess=hess_update
        )

    def store_reference_data(
        self, x_ref: NDArray, val_ref: Any, grad_ref: Any = None, hess_ref: Any = None
    ) -> None:
        """Store reference data in the database.

        Args:
            x_ref: The design variables.
            val_ref: The reference value.
            grad_ref: The reference gradient.
            hess_ref: The reference hessian.
        """
        ref_name = self.REF_TAG + self._orig_func.name
        data_dict = {ref_name: val_ref}
        if grad_ref is not None:
            data_dict[Database.GRAD_TAG + ref_name] = grad_ref
        if hess_ref is not None:
            data_dict[self._get_hessian_name(ref_name)] = hess_ref
        self._corr_database.store(x_ref, data_dict)

    def get_reference_data(
        self,
        x_ref: NDArray,
        use_grad: bool = True,
        use_hess: bool = False,
        max_norm_thr: float | None = None,
        min_norm_thr: float = 1e-4,
    ) -> tuple[list[ndarray], list[Any], list[Any], list[Any]]:
        """Get reference data.

        Gets reference data from the database, close enough to x_ref and which matches
        the required consistency. Points too close to x_ref are discarded.

        Args:
            x_ref: The input (1D array).
            use_grad: The flag to look for data with grad corrections.
            use_hess: The flag to look for data with hess corrections.
            max_norm_thr: The threshold to discard points too far from x_ref.
            min_norm_thr: The threshold to discard points too close from x_ref.

        Returns:
            x_list, val_list, grad_list, hess_list. All four are None when no
            valid point is found; grad_list (resp. hess_list) is None when
            use_grad (resp. use_hess) is False.

        Raises:
            ValueError: When use_hess is True but use_grad is False.
        """
        if use_hess and not use_grad:
            msg = "Must allow use_grad if use_hess is activated"
            raise ValueError(msg)
        # Only data with the highest specified consistency will be kept
        lookup_method = self._corr_database.get_function_history
        if use_grad:
            lookup_method = self._corr_database.get_gradient_history
        if use_hess:
            # NOTE(review): assumes get_hessian_history looks hessians up under
            # the same key as _get_hessian_name - confirm against Database.
            lookup_method = self._corr_database.get_hessian_history

        ref_name = self.REF_TAG + self._orig_func.name

        _, x_ref_hist = lookup_method(ref_name, with_x_vect=True)

        if len(x_ref_hist) == 0:
            # No valid data available
            return [None] * 4

        # Reverse the list to get the last points first
        x_ref_hist = flip(x_ref_hist, 0)
        x_ref_hist = self._filter_x_hist(
            x_ref_hist, x_ref, max_norm_thr=max_norm_thr, min_norm_thr=min_norm_thr
        )

        if len(x_ref_hist) == 0:
            # No valid filtered data available, return
            return [None] * 4

        # Assemble data
        get_data = self._corr_database.get_function_value

        val_hist = [get_data(ref_name, x) for x in x_ref_hist]
        grad_hist = None
        hess_hist = None

        if use_grad:
            grad_tag = Database.GRAD_TAG + ref_name
            grad_hist = [get_data(grad_tag, x) for x in x_ref_hist]
        if use_hess:
            # Same key as the one used by store_reference_data.
            hess_tag = self._get_hessian_name(ref_name)
            hess_hist = [get_data(hess_tag, x) for x in x_ref_hist]

        return x_ref_hist, val_hist, grad_hist, hess_hist

    @staticmethod
    def _get_hessian_name(name: str) -> str:
        """Return the database key under which the hessian of ``name`` is stored.

        Shared by :meth:`store_reference_data` and :meth:`get_reference_data` so that
        reference hessians are written and read back under the same key.

        Args:
            name: The name of the stored quantity.

        Returns:
            The database key of its hessian.
        """
        return Database.GRAD_TAG + Database.GRAD_TAG + name

    @staticmethod
    def _filter_x_hist(
        x_hist: NDArray,
        x_ref: NDArray,
        max_norm_thr: float | None = None,
        min_norm_thr: float = 1e-4,
    ) -> NDArray:
        """Filter x_hist.

        Keeps the points of x_hist whose Euclidean distance to x_ref is not below
        min_norm_thr and, when max_norm_thr is given, not above max_norm_thr.
        The order of the kept points is preserved.

        Args:
            x_hist: The list of points to filter.
            x_ref: The reference point.
            max_norm_thr: The threshold to discard points too far away.
            min_norm_thr: The threshold to discard points too close.

        Returns:
            The filtered points.

        Raises:
            ValueError: When min_norm_thr is not strictly positive.
        """
        if min_norm_thr <= 0.0:
            # "tol" refers to min_norm_thr.
            msg = "tol must be > 0."
            raise ValueError(msg)
        x_hist = array(x_hist)
        distances = norm(x_hist - x_ref, axis=1)
        # Written as a negation so that NaN distances are not discarded by the
        # lower threshold alone.
        keep = ~(distances < min_norm_thr)
        if max_norm_thr is not None:
            # Apply the proximity criterion
            keep &= distances <= max_norm_thr
        return x_hist[keep]
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
|
|
2
|
+
#
|
|
3
|
+
# This program is free software; you can redistribute it and/or
|
|
4
|
+
# modify it under the terms of the GNU Lesser General Public
|
|
5
|
+
# License version 3 as published by the Free Software Foundation.
|
|
6
|
+
#
|
|
7
|
+
# This program is distributed in the hope that it will be useful,
|
|
8
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
9
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
10
|
+
# Lesser General Public License for more details.
|
|
11
|
+
#
|
|
12
|
+
# You should have received a copy of the GNU Lesser General Public License
|
|
13
|
+
# along with this program; if not, write to the Free Software Foundation,
|
|
14
|
+
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
|
15
|
+
|
|
16
|
+
# Copyright (c) 2018 AIRBUS OPERATIONS
|
|
17
|
+
#
|
|
18
|
+
# Contributors:
|
|
19
|
+
# INITIAL AUTHORS - API and implementation and/or documentation
|
|
20
|
+
# :author: Romain Olivanti
|
|
21
|
+
# OTHER AUTHORS - MACROSCOPIC CHANGES
|
|
22
|
+
"""Criticality criteria for bound-constrained problem."""
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
from typing import TYPE_CHECKING
|
|
27
|
+
|
|
28
|
+
from numpy import abs as np_abs
|
|
29
|
+
from numpy import inf
|
|
30
|
+
from numpy import min as np_min
|
|
31
|
+
from numpy import vstack
|
|
32
|
+
from numpy import where
|
|
33
|
+
from numpy import zeros_like
|
|
34
|
+
from numpy.linalg import norm
|
|
35
|
+
|
|
36
|
+
if TYPE_CHECKING:
|
|
37
|
+
from numpy.typing import NDArray
|
|
38
|
+
|
|
39
|
+
from gemseo_multi_fidelity.core.boxed_domain import BoxedDomain
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def crit_out_1(
    x_vect: NDArray,
    grad: NDArray,
    domain: BoxedDomain,
    grad_tol: float,
    bound_tol: float,
) -> float:
    """Return the l1-norm of the criticality vector at ``x_vect``.

    See Mouffe's thesis.

    Args:
        x_vect: The point.
        grad: The gradient at point.
        domain: The BoxedDomain.
        grad_tol: The gradient tolerance.
        bound_tol: The tolerance on bounds.

    Returns:
        The criticality, i.e. the l1-norm of the vector computed by
        ``compute_crit_vect``.
    """
    crit_vect = compute_crit_vect(
        x_vect, grad, domain, grad_tol=grad_tol, bound_tol=bound_tol
    )
    return norm(crit_vect, ord=1)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def crit_in_inf(
    x_vect: NDArray,
    grad: NDArray,
    domain: BoxedDomain,
    grad_tol: float,
    bound_tol: float,
) -> float:
    """Return the infinity-norm of the criticality vector at ``x_vect``.

    See Mouffe's thesis.

    Args:
        x_vect: The point.
        grad: The gradient at point.
        domain: The BoxedDomain.
        grad_tol: The gradient tolerance.
        bound_tol: The tolerance on bounds.

    Returns:
        The criticality, i.e. the largest absolute component of the vector
        computed by ``compute_crit_vect``.
    """
    crit_vect = compute_crit_vect(
        x_vect, grad, domain, grad_tol=grad_tol, bound_tol=bound_tol
    )
    return norm(crit_vect, ord=inf)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def compute_crit_vect(
    x_vect: NDArray,
    grad: NDArray,
    domain: BoxedDomain,
    grad_tol: float,
    bound_tol: float,
) -> NDArray:
    """Compute the criticality vect at x_vect with respect to the specified domain.

    With ``ratio = grad_tol / bound_tol``, each component is:

    - ``min(|grad[i]|, ratio * |dist_lower[i]|)`` if ``grad[i] > 0``
      (the descent direction points towards the lower bound),
    - ``min(|grad[i]|, ratio * |dist_upper[i]|)`` if ``grad[i] < 0``
      (the descent direction points towards the upper bound),
    - ``0`` otherwise.

    Args:
        x_vect: The point to test.
        grad: The gradient at x_vect.
        domain: The BoxedDomain.
        grad_tol: The gradient tolerance.
        bound_tol: The tolerance on bounds.

    Returns:
        The criticality vect, as a floating-point array shaped like x_vect.
    """
    crit_ratio = grad_tol / bound_tol

    # Compute bounds proximity
    dist_upper = domain.get_dist_upper(x_vect)
    dist_lower = domain.get_dist_lower(x_vect)

    # Force a floating-point dtype: with an integer x_vect, zeros_like would
    # create an integer array and silently truncate the criticality values.
    crit = zeros_like(x_vect, dtype=float)
    # Positive grad
    pos_grad = where(grad > 0.0)[0]
    # Check if lower bounds are active
    crit[pos_grad] = np_min(
        np_abs(vstack((grad[pos_grad], crit_ratio * dist_lower[pos_grad]))), axis=0
    )
    # Negative grad
    neg_grad = where(grad < 0.0)[0]
    # Check if upper bounds are active
    crit[neg_grad] = np_min(
        np_abs(vstack((grad[neg_grad], crit_ratio * dist_upper[neg_grad]))), axis=0
    )
    return crit
|