gemseo-multi-fidelity 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76):
  1. gemseo_multi_fidelity/__init__.py +17 -0
  2. gemseo_multi_fidelity/core/MFMapperAdapter_input.json +22 -0
  3. gemseo_multi_fidelity/core/MFMapperAdapter_output.json +22 -0
  4. gemseo_multi_fidelity/core/MFMapperLinker_input.json +22 -0
  5. gemseo_multi_fidelity/core/MFMapperLinker_output.json +22 -0
  6. gemseo_multi_fidelity/core/MFScenarioAdapter_input.json +39 -0
  7. gemseo_multi_fidelity/core/MFScenarioAdapter_output.json +23 -0
  8. gemseo_multi_fidelity/core/__init__.py +16 -0
  9. gemseo_multi_fidelity/core/boxed_domain.py +242 -0
  10. gemseo_multi_fidelity/core/corr_function.py +411 -0
  11. gemseo_multi_fidelity/core/criticality.py +124 -0
  12. gemseo_multi_fidelity/core/ds_mapper.py +307 -0
  13. gemseo_multi_fidelity/core/errors.py +42 -0
  14. gemseo_multi_fidelity/core/eval_mapper.py +188 -0
  15. gemseo_multi_fidelity/core/id_mapper_adapter.py +61 -0
  16. gemseo_multi_fidelity/core/mapper_adapter.py +126 -0
  17. gemseo_multi_fidelity/core/mapper_linker.py +72 -0
  18. gemseo_multi_fidelity/core/mf_formulation.py +635 -0
  19. gemseo_multi_fidelity/core/mf_logger.py +216 -0
  20. gemseo_multi_fidelity/core/mf_opt_problem.py +480 -0
  21. gemseo_multi_fidelity/core/mf_scenario.py +205 -0
  22. gemseo_multi_fidelity/core/noise_criterion.py +94 -0
  23. gemseo_multi_fidelity/core/projpolytope.out +0 -0
  24. gemseo_multi_fidelity/core/scenario_adapter.py +568 -0
  25. gemseo_multi_fidelity/core/stop_criteria.py +201 -0
  26. gemseo_multi_fidelity/core/strict_chain.py +75 -0
  27. gemseo_multi_fidelity/core/utils_model_quality.py +74 -0
  28. gemseo_multi_fidelity/corrections/__init__.py +16 -0
  29. gemseo_multi_fidelity/corrections/add_corr_function.py +80 -0
  30. gemseo_multi_fidelity/corrections/correction_factory.py +65 -0
  31. gemseo_multi_fidelity/corrections/mul_corr_function.py +86 -0
  32. gemseo_multi_fidelity/drivers/__init__.py +16 -0
  33. gemseo_multi_fidelity/drivers/mf_algo_factory.py +38 -0
  34. gemseo_multi_fidelity/drivers/mf_driver_lib.py +462 -0
  35. gemseo_multi_fidelity/drivers/refinement.py +234 -0
  36. gemseo_multi_fidelity/drivers/settings/__init__.py +16 -0
  37. gemseo_multi_fidelity/drivers/settings/base_mf_driver_settings.py +59 -0
  38. gemseo_multi_fidelity/drivers/settings/mf_refine_settings.py +50 -0
  39. gemseo_multi_fidelity/formulations/__init__.py +16 -0
  40. gemseo_multi_fidelity/formulations/refinement.py +144 -0
  41. gemseo_multi_fidelity/mapping/__init__.py +16 -0
  42. gemseo_multi_fidelity/mapping/identity_mapper.py +74 -0
  43. gemseo_multi_fidelity/mapping/interp_mapper.py +422 -0
  44. gemseo_multi_fidelity/mapping/mapper_factory.py +70 -0
  45. gemseo_multi_fidelity/mapping/mapping_errors.py +46 -0
  46. gemseo_multi_fidelity/mapping/subset_mapper.py +122 -0
  47. gemseo_multi_fidelity/mf_rosenbrock/__init__.py +16 -0
  48. gemseo_multi_fidelity/mf_rosenbrock/delayed_disc.py +136 -0
  49. gemseo_multi_fidelity/mf_rosenbrock/refact_rosen_testcase.py +46 -0
  50. gemseo_multi_fidelity/mf_rosenbrock/rosen_mf_case.py +284 -0
  51. gemseo_multi_fidelity/mf_rosenbrock/rosen_mf_funcs.py +350 -0
  52. gemseo_multi_fidelity/models/__init__.py +16 -0
  53. gemseo_multi_fidelity/models/fake_updater.py +112 -0
  54. gemseo_multi_fidelity/models/model_updater.py +91 -0
  55. gemseo_multi_fidelity/models/rbf/__init__.py +16 -0
  56. gemseo_multi_fidelity/models/rbf/kernel_factory.py +66 -0
  57. gemseo_multi_fidelity/models/rbf/kernels/__init__.py +16 -0
  58. gemseo_multi_fidelity/models/rbf/kernels/gaussian.py +93 -0
  59. gemseo_multi_fidelity/models/rbf/kernels/matern_3_2.py +101 -0
  60. gemseo_multi_fidelity/models/rbf/kernels/matern_5_2.py +101 -0
  61. gemseo_multi_fidelity/models/rbf/kernels/rbf_kernel.py +172 -0
  62. gemseo_multi_fidelity/models/rbf/rbf_model.py +422 -0
  63. gemseo_multi_fidelity/models/sparse_rbf_updater.py +96 -0
  64. gemseo_multi_fidelity/models/taylor/__init__.py +16 -0
  65. gemseo_multi_fidelity/models/taylor/taylor.py +212 -0
  66. gemseo_multi_fidelity/models/taylor_updater.py +66 -0
  67. gemseo_multi_fidelity/models/updater_factory.py +62 -0
  68. gemseo_multi_fidelity/settings/__init__.py +16 -0
  69. gemseo_multi_fidelity/settings/drivers.py +22 -0
  70. gemseo_multi_fidelity/settings/formulations.py +16 -0
  71. gemseo_multi_fidelity-0.0.1.dist-info/METADATA +99 -0
  72. gemseo_multi_fidelity-0.0.1.dist-info/RECORD +76 -0
  73. gemseo_multi_fidelity-0.0.1.dist-info/WHEEL +5 -0
  74. gemseo_multi_fidelity-0.0.1.dist-info/entry_points.txt +2 -0
  75. gemseo_multi_fidelity-0.0.1.dist-info/licenses/LICENSE.txt +165 -0
  76. gemseo_multi_fidelity-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,635 @@
1
+ # Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
2
+ #
3
+ # This program is free software; you can redistribute it and/or
4
+ # modify it under the terms of the GNU Lesser General Public
5
+ # License version 3 as published by the Free Software Foundation.
6
+ #
7
+ # This program is distributed in the hope that it will be useful,
8
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
9
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10
+ # Lesser General Public License for more details.
11
+ #
12
+ # You should have received a copy of the GNU Lesser General Public License
13
+ # along with this program; if not, write to the Free Software Foundation,
14
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15
+
16
+ # Copyright (c) 2019 AIRBUS OPERATIONS
17
+
18
+ #
19
+ # Contributors:
20
+ # INITIAL AUTHORS - API and implementation and/or documentation
21
+ # :author: Romain Olivanti
22
+ # OTHER AUTHORS - MACROSCOPIC CHANGES
23
+ """Baseclass for all multi-fidelity formulations."""
24
+
25
+ from __future__ import annotations
26
+
27
+ from copy import deepcopy
28
+ from typing import TYPE_CHECKING
29
+ from typing import Any
30
+
31
+ from gemseo.algos.database import Database
32
+ from gemseo.formulations.base_mdo_formulation import BaseMDOFormulation
33
+ from gemseo.scenarios.mdo_scenario import MDOScenario
34
+
35
+ from gemseo_multi_fidelity.core.ds_mapper import DesignSpaceMapper
36
+ from gemseo_multi_fidelity.core.errors import ConsistencyError
37
+ from gemseo_multi_fidelity.core.eval_mapper import EvaluationMapper
38
+ from gemseo_multi_fidelity.core.mf_opt_problem import MFOptimizationProblem
39
+ from gemseo_multi_fidelity.corrections.correction_factory import CorrectionFactory
40
+ from gemseo_multi_fidelity.mapping.identity_mapper import IdentityMapper
41
+ from gemseo_multi_fidelity.mapping.mapper_factory import DSMapperFactory
42
+ from gemseo_multi_fidelity.models.updater_factory import UpdaterFactory
43
+
44
+ if TYPE_CHECKING:
45
+ from collections.abc import Callable
46
+ from collections.abc import Iterable
47
+ from collections.abc import Mapping
48
+ from collections.abc import Sequence
49
+
50
+ from gemseo.algos.design_space import DesignSpace
51
+ from gemseo.core.discipline.discipline import Discipline
52
+ from gemseo.core.mdo_functions.mdo_function import MDOFunction
53
+
54
+ from gemseo_multi_fidelity.core.corr_function import MDOCorrectedFunction
55
+
56
+
57
class MFFormulation(BaseMDOFormulation):
    """Abstract MF formulation class.

    To be extended in subclasses for use.

    The MFFormulation creates the functions for the multi-fidelity drivers from the
    scenarios provided.

    It defines the process implicitly.
    """

    NAME = "MFFormulation"

    # Keys looked up in the user-provided mapping dictionaries.
    DS_MAPPER_TYPE = "mapper_type"
    CORR_TYPE = "correction_type"
    CORR_MODEL_TYPE = "correction_model"

    # Multi-fidelity problem built by ``_init_mf_opt_problem``.
    mf_opt_problem: MFOptimizationProblem

    def __init__(
        self,
        disciplines: Sequence[MDOScenario],
        objective_name: str,
        design_space: DesignSpace,
        design_space_mapping: DesignSpaceMapper | dict | Iterable | None = None,
        functions_mapping: Iterable[Mapping[str, Any]] | None = None,
        correction_mapping: Iterable[Mapping[str, Any]] | None = None,
        databases: Iterable[Database] | None = None,
        **options,
    ) -> None:
        """Constructor.

        ``design_space_mapping`` is the mapping logic between the design space of each
        fidelity level. If ``None``, design spaces are assumed to be identical. A dict
        can be provided to specify mappings between two scenarios using their names
        as keys. Alternatively, a list can be provided, whose length must be
        exactly equal to the number of scenarios - 1.

        Either preexisting ``DSMapper`` objects can be passed or alternatively factory
        names and their options.

        design_space_mapping = {scenario_upper_name: mapping_logic} (as dict)
        design_space_mapping = [mapping_logic_upper, ..., mapping_logic_lower]
        (as iterable)

        ``functions_mapping``: If the user chooses to specify scenarios with
        non-matching names between the problem functions, an iterable of dict must be
        provided to build the link between the different names. Otherwise, the functions
        are expected to have identical names.

        functions_mapping = [{upper_name: lower_name, }, ...]

        ``correction_mapping``: An iterable of dict can be provided to specify which
        type of correction and which correction model to use, while providing the
        options to set them.
        If provided, the correction mapping must be of len(scenarios) - 1,
        as the top level is not corrected.

        The consistency of the definition of the ``MFMDOScenario`` is checked.

        Args:
            disciplines: The list of MDOScenarios, ordered from the highest to the
                lowest fidelity [HF_scenario, ..., LF_scenario].
            objective_name: The name of the objective function from the high-fidelity
                scenario. Not read by this constructor.
            design_space: The design space of the high-fidelity scenario.
            design_space_mapping: The design space mapping (None, dict or fully
                specified iterable).
            functions_mapping: The functions mapping (None or iterable of dict).
            correction_mapping: The correction mapping (None or iterable of dict).
            databases: The Database objects or their file_path
                Databases storing previous calls to the functions of the scenarios.
                [sc_0, ..., sc_n] -> [dtb_0, ..., dtb_n]
                Each elem can be None as well.
            **options: Extra options. Not read by this constructor.

        Raises:
            TypeError: If an item of ``disciplines`` is not an ``MDOScenario`` or if
                a mapping argument has the wrong type.
            ValueError: If a mapping does not have one entry per pair of levels.
            ConsistencyError: If the design spaces or the function names of the
                scenarios are not consistent with the mappings.
        """
        # NOTE: the constructor of the base class is not called here; the
        # standard attributes are taken from the top-level scenario instead.
        # Make sure all the disciplines provided are actually scenarios.
        self._scenarios = self.check_scenarios(disciplines, design_space)

        # Standard attributes of the BaseMDOFormulation are linked to the top level
        # for overall consistency.
        top_formulation = self._scenarios[0].formulation
        self._disciplines = top_formulation.disciplines
        self.optimization_problem = top_formulation.optimization_problem

        # Specific attributes; ``_opt_adapters`` is only read by the setters
        # created in ``_build_option_setter`` and ``_build_callback_setter``.
        self._correction_mapping = []
        self._eval_mappers = []
        self._databases = []
        self._opt_adapters = []

        self._correction_mapping = self._init_corr_mapping(correction_mapping)
        self._eval_mappers = self._build_eval_mappers(
            design_space_mapping, functions_mapping
        )
        self._build_corrected_functions(databases)
        self._init_mf_opt_problem()

    @property
    def disciplines(self) -> tuple[Discipline, ...]:
        """The sub scenarios."""
        return self._scenarios

    def get_top_level_disciplines(self) -> tuple[Discipline, ...]:
        """Get the top-level disciplines.

        Returns:
            The disciplines of the formulation of the highest-fidelity scenario.
        """
        return self._disciplines

    @staticmethod
    def check_scenarios(
        scenarios: Sequence[MDOScenario], design_space: DesignSpace
    ) -> Sequence[MDOScenario]:
        """Check scenarios.

        Args:
            scenarios: The scenarios.
            design_space: The design space.

        Returns:
            The checked scenarios (the very same object that was passed).

        Raises:
            TypeError: If an item is not an ``MDOScenario``.
            ConsistencyError: If ``design_space`` is not the design space object
                of the first scenario.
        """
        # Only MDOScenarios allowed
        for scenario in scenarios:
            if not isinstance(scenario, MDOScenario):
                msg = f"{scenario} not a MDOScenario"
                raise TypeError(msg)

        # Identity, not equality: the very same object is expected.
        if scenarios[0].design_space is not design_space:
            msg = (
                "The design space provided does not belong "
                "to the high-fidelity scenario"
            )
            raise ConsistencyError(msg)
        return scenarios

    @staticmethod
    def _check_type(item, expected_type):
        """Raise a bare ``TypeError`` if ``item`` is not an ``expected_type``."""
        if not isinstance(item, expected_type):
            raise TypeError

    @staticmethod
    def _check_not_type(item, expected_type):
        """Raise a bare ``TypeError`` if ``item`` is an ``expected_type``."""
        if isinstance(item, expected_type):
            raise TypeError

    def get_reference_databases(self) -> Sequence[Database]:
        """Accessor to the databases storing the calls to the uncorrected functions.

        Returns:
            The list of ``Database``.
        """
        return self._databases

    def _build_def_mapper(self, scenario_index: int) -> IdentityMapper:
        """Build mapper.

        Builds an IdentityMapper between scenario[scenario_index] and
        scenario[scenario_index + 1].

        Args:
            scenario_index: The index of the scenario.

        Returns:
            The IdentityMapper.
        """
        # The lower level comes first, the upper level second, as in
        # ``_build_ds_mappers_from_list`` (input, then output design space).
        lower_space = self._scenarios[scenario_index + 1].design_space
        upper_space = self._scenarios[scenario_index].design_space
        return IdentityMapper(lower_space, upper_space)

    def _build_ds_mappers_from_dict(
        self, design_space_mapping: dict
    ) -> Sequence[DesignSpaceMapper]:
        """Build the DesignSpaceMappers from a dict.

        Scenarios whose name is missing from the dict (or mapped to ``None``) get
        an identity mapper towards the next lower level.

        Args:
            design_space_mapping: The design space mapping.

        Returns:
            The list of DesignSpaceMapper.

        Raises:
            RuntimeError: If two scenarios share the same name.
        """
        # Check that scenario names are all different
        scenarios_names = [s.name for s in self._scenarios]

        if len(set(scenarios_names)) != len(self._scenarios):
            err = (
                "Cannot have scenarios with identical names "
                "if a dict is passed as design_space_mapping arg"
            )
            raise RuntimeError(err)

        ds_mapper_list = []

        # The lowest scenario is not mapped with a lower level
        for i, name in enumerate(scenarios_names[0:-1]):
            mapper_link = design_space_mapping.get(name)
            if mapper_link is None:
                mapper_link = self._build_def_mapper(i)
            ds_mapper_list.append(mapper_link)

        return self._build_ds_mappers_from_list(ds_mapper_list)

    def _build_ds_mappers_from_list(
        self, design_space_mapping: Iterable[DesignSpaceMapper]
    ) -> Sequence[DesignSpaceMapper]:
        """Build the DesignSpaceMappers from a fully specified iterable.

        Each item is either a ready-made ``DesignSpaceMapper``, whose design spaces
        must be the ones of the two scenarios it links, or a dict holding the
        factory name under ``DS_MAPPER_TYPE`` plus the factory options.

        Args:
            design_space_mapping: The design space mapping.

        Returns:
            The list of DesignSpaceMapper.

        Raises:
            TypeError: If ``design_space_mapping`` is a string or has no length.
            ValueError: If there is not exactly one item per pair of levels or if
                a dict item does not provide the mapper type.
            ConsistencyError: If a mapper does not point to the design spaces of
                the scenarios it links.
        """
        # Make sure the input is an iterable of correct len
        try:
            _ = len(design_space_mapping)
            MFFormulation._check_not_type(design_space_mapping, str)
        except TypeError as exc:
            msg = "design_space_mapping must be an iterable of mapper specifications"
            # An exception instance must be raised, not the message itself.
            raise TypeError(msg) from exc

        if len(design_space_mapping) != len(self._scenarios) - 1:
            err = (
                "A design space mapping must be provided between each scenario "
                "if an iterable is passed"
            )
            raise ValueError(err)

        mappers = []

        for i, mapper_logic in enumerate(design_space_mapping):
            in_scenario = self._scenarios[i + 1]
            out_scenario = self._scenarios[i]
            if isinstance(mapper_logic, DesignSpaceMapper):
                # Make sure stored design spaces point to the design spaces of
                # the scenarios
                if mapper_logic.design_space_in is not in_scenario.design_space:
                    err = (
                        f"{mapper_logic} input design space "
                        f"does not point to the design space of {in_scenario}"
                    )
                    raise ConsistencyError(err)
                if mapper_logic.design_space_out is not out_scenario.design_space:
                    err = (
                        f"{mapper_logic} output design space "
                        f"does not point to the design space of {out_scenario}"
                    )
                    raise ConsistencyError(err)
                mapper = mapper_logic
            else:
                # Copy as the mapper type is popped afterwards; the caller's
                # dict must stay untouched.
                mapper_options = mapper_logic.copy()
                try:
                    ds_type = mapper_options.pop(self.DS_MAPPER_TYPE)
                except KeyError as exc:
                    msg = "Type of DesignSpaceMapper to be used not provided"
                    raise ValueError(msg) from exc
                mapper = DSMapperFactory().create(
                    ds_type,
                    in_scenario.design_space,
                    out_scenario.design_space,
                    **mapper_options,
                )
            mappers.append(mapper)
        return mappers

    def _init_ds_mappers(
        self, design_space_mapping: DesignSpaceMapper | dict | Iterable | None
    ) -> Sequence[DesignSpaceMapper]:
        """Check the design_space_mapping data provided.

        Builds missing or partially defined DesignSpaceMapper objects.

        Args:
            design_space_mapping: The design space mapper (None, dict or fully specified
                iterable).

        Returns:
            The list of DesignSpaceMapper.
        """
        if design_space_mapping is None:
            # Assume identity mappings between all the design spaces
            return [
                self._build_def_mapper(i) for i in range(len(self._scenarios) - 1)
            ]
        if isinstance(design_space_mapping, dict):
            return self._build_ds_mappers_from_dict(design_space_mapping)
        return self._build_ds_mappers_from_list(design_space_mapping)

    def _check_func_mapping_consistency(
        self, functions_mapping: Iterable[Mapping[str, Any]]
    ) -> None:
        """Check that everything is correctly connected.

        For each pair of consecutive levels, every function of the upper problem
        must be mapped to an existing function of the lower problem, and no lower
        function may be used twice.

        Args:
            functions_mapping: The list of functions mapping.

        Raises:
            KeyError: If a function of an upper level is missing from its mapping.
            ConsistencyError: If a function is mapped to an unknown lower-level
                function or to one that is already connected.
        """
        for i, mapping in enumerate(functions_mapping):
            upper_prob = self._scenarios[i].formulation.optimization_problem
            lower_prob = self._scenarios[i + 1].formulation.optimization_problem
            lower_f_names = lower_prob.function_names

            # Store lower funcs that have already been connected to avoid that
            # two upper funcs are connected to the same lower func
            already_connected = set()
            for upp_name in upper_prob.function_names:
                connected_to = mapping.get(upp_name)
                if connected_to is None:
                    msg = f"Missing {upp_name} in functions_mapping[{i:d}]"
                    raise KeyError(msg)
                if connected_to not in lower_f_names:
                    err = (
                        f"{upp_name} in functions_mapping[{i:d}] connected"
                        f" to a func that does not exist: {connected_to}"
                    )
                    raise ConsistencyError(err)
                if connected_to in already_connected:
                    err = (
                        f"{upp_name} in functions_mapping[{i:d}] connected"
                        f" to a func that has already been connected: {connected_to}"
                    )
                    raise ConsistencyError(err)
                already_connected.add(connected_to)

    def _init_func_mapping(
        self, functions_mapping: Iterable[Mapping[str, Any]] | None = None
    ) -> Iterable[Mapping[str, Any]]:
        """Check the function mapping provided.

        If ``None`` all function names are assumed to be identical between the levels.

        Args:
            functions_mapping: The list of functions mapping (None, or iterable of
                dict).

        Returns:
            The list of checked functions mapping.

        Raises:
            TypeError: If an item of ``functions_mapping`` is not a dict.
            ValueError: If there is not exactly one mapping per pair of levels.
        """
        n_scenarios = len(self._scenarios)

        if functions_mapping is None:
            if n_scenarios != 1:
                # Assume all the scenarios have identical func names
                top_prob = self._scenarios[0].formulation.optimization_problem
                f_names = top_prob.function_names
                functions_mapping = [
                    {f_name: f_name for f_name in f_names}
                    for _ in range(n_scenarios - 1)
                ]
            else:
                functions_mapping = []
        else:
            # Make sure the mapping is consistent in type and len
            try:
                for elem in functions_mapping:
                    MFFormulation._check_type(elem, dict)
            except TypeError as exc:
                msg = "functions_mapping must be an iterable of dict"
                # An exception instance must be raised, not the message itself.
                raise TypeError(msg) from exc

            if len(functions_mapping) != n_scenarios - 1:
                err = "A func mapping must be provided between each scenario"
                raise ValueError(err)

        self._check_func_mapping_consistency(functions_mapping)
        return functions_mapping

    def _build_eval_mappers(
        self,
        design_space_mapping: DesignSpaceMapper | dict | Iterable | None,
        functions_mapping: Iterable[Mapping[str, Any]] | None,
    ) -> Sequence[EvaluationMapper]:
        """Build the evaluation mappers.

        Args:
            design_space_mapping: The design space mapping as given by the user
                (None, dict or fully specified iterable).
            functions_mapping: The functions mapping as given by the user (None or
                iterable of dict).

        Returns:
            The list of evaluation mappers, one per pair of consecutive levels.
        """
        # Check the inputs to yield the final objects required
        ds_mappers = self._init_ds_mappers(design_space_mapping)
        funcs_mappers = self._init_func_mapping(functions_mapping)

        return [
            EvaluationMapper(ds_mapper, func_mapper, jac_tag=Database.GRAD_TAG)
            for ds_mapper, func_mapper in zip(ds_mappers, funcs_mappers, strict=False)
        ]

    def _init_corr_mapping(
        self, correction_mapping: Iterable[Mapping[str, Any]] | None
    ) -> Iterable[Mapping[str, Any]]:
        """Check the correction mapping provided.

        The mapping returned holds one dict per corrected level, with an entry
        (possibly empty) for each function of the problem of that level. A
        user-provided mapping is deep-copied before being completed.

        Args:
            correction_mapping: The list of correction mapping (None, or iterable of
                dict).

        Returns:
            The list of checked correction mapping.

        Raises:
            TypeError: If an item of ``correction_mapping`` is not a dict.
            ValueError: If there is not exactly one dict per corrected level.
        """
        n_scenarios = len(self._scenarios)
        if correction_mapping is None:
            # Empty correction data, default values will be used later
            correction_mapping = [{} for _ in range(n_scenarios - 1)]
        else:
            try:
                for elem in correction_mapping:
                    MFFormulation._check_type(elem, dict)
            except TypeError as exc:
                msg = "correction_mapping must be an iterable of dict"
                # An exception instance must be raised, not the message itself.
                raise TypeError(msg) from exc
            if len(correction_mapping) != n_scenarios - 1:
                err = "A correction mapping must be provided for each sub scenario"
                raise ValueError(err)
            correction_mapping = deepcopy(correction_mapping)

        # The top level is not corrected, hence the scenarios from index 1 on.
        for corr_dict, scenario in zip(
            correction_mapping, self._scenarios[1::], strict=False
        ):
            prob = scenario.formulation.optimization_problem
            for func in prob.function_names:
                if func not in corr_dict:
                    corr_dict[func] = {}
        return correction_mapping

    def _build_corrected_function(
        self, level: int, func: MDOFunction, database: Database
    ) -> MDOCorrectedFunction:
        """Build a corrected function from the mapping data.

        Args:
            level: The index of the level to select the right correction mapping.
            func: The function to correct.
            database: The Database to store the call to the uncorrected functions.

        Returns:
            The MDOCorrectedFunction.
        """
        # The top level has no correction data, hence the shift by one.
        corr_data = self._correction_mapping[level - 1][func.name]
        # Default set to additive correction as it can handle all outputs
        corr_type = corr_data.get(self.CORR_TYPE, CorrectionFactory.ADD_CORR)
        # Default set to Taylor model
        mod_type = corr_data.get(self.CORR_MODEL_TYPE, UpdaterFactory.TAYLOR)
        # Build the correction model.
        # NOTE(review): ``corr_data`` is forwarded in full, so the type keys read
        # above are passed to the factory as well when present - confirm that
        # the factory accepts them.
        corr_model = UpdaterFactory().create(mod_type, **corr_data)
        # Build the corrected function
        return CorrectionFactory().create(corr_type, func, corr_model, database)

    def _build_corrected_functions(
        self, databases: Iterable[Database] | None
    ) -> None:
        """Build corrected functions.

        Collects the database of the optimization problem of each scenario and
        replaces the MDOFunctions of every level but the top one by
        MDOCorrectedFunctions.

        Args:
            databases: The list of ``Database`` objects storing previous calls to the
                functions of the scenarios. Currently not used: the databases of
                the optimization problems of the scenarios are used instead.
        """
        # Rebuilt from scratch at each call to avoid duplicates for double calls
        # (should only happen for testing purposes).
        self._databases = [
            scn.formulation.optimization_problem.database for scn in self._scenarios
        ]

        # Replace MDOFunctions by MDOCorrectedFunctions
        for i, (scn, dtb) in enumerate(
            zip(self._scenarios, self._databases, strict=False)
        ):
            opt_prob = scn.formulation.optimization_problem

            if i == 0:
                # The top level is not corrected. ``dtb`` is the database
                # collected from this very problem just above.
                opt_prob.database = dtb
                continue

            # Objective
            opt_prob.objective = self._build_corrected_function(
                i, opt_prob.objective, dtb
            )
            # Constraints
            opt_prob.constraints = [
                self._build_corrected_function(i, con, dtb)
                for con in opt_prob.constraints
            ]

    def _init_mf_opt_problem(self) -> None:
        """Create the multi-fidelity problem from the problems of the scenarios."""
        sub_probs = [scn.formulation.optimization_problem for scn in self._scenarios]
        self.mf_opt_problem = MFOptimizationProblem(
            sub_probs, self._eval_mappers, self._databases
        )

    def _build_workflow_runner(self):
        """Build the object stored as ``mf_opt_problem.run_exec``.

        To be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def _build_option_setter(self) -> Callable[[int, str, Any], None]:
        """Build the function setting an option for a given level.

        Returns:
            The function setting an option.
        """

        def set_option(level: int, option: str, value: Any) -> None:
            """Set an option for a specified level.

            Args:
                level: The level.
                option: The option to set.
                value: The value to set.
            """
            adapter = self._opt_adapters[level]
            adapter.default_input_data[option] = value

        return set_option

    def _build_callback_setter(self) -> Callable[[int, str, Callable], None]:
        """Build the function adding a callback for a given level.

        Returns:
            The function adding a callback.
        """

        def add_callback(level: int, target: str, callback: Callable) -> None:
            """Add a callback for a specific level.

            Args:
                level: The level.
                target: The type of callback to set.
                callback: The callable.
            """
            adapter = self._opt_adapters[level]
            callbacks = adapter.default_input_data.get(target)
            if callbacks is None:
                # Init with an empty list
                callbacks = []
                adapter.default_input_data[target] = callbacks
            elif callable(callbacks):
                # Convert the single callable to a list and store that list,
                # otherwise the callback appended below would be lost.
                callbacks = [callbacks]
                adapter.default_input_data[target] = callbacks
            callbacks.append(callback)

        return add_callback

    def _build_driver_interface(self) -> None:
        """Build the interface to be used by the multi-fidelity drivers.

        To be called by subclasses at the end of __init__.
        """
        run_exec = self._build_workflow_runner()
        callback_setter = self._build_callback_setter()
        option_setter = self._build_option_setter()

        self.mf_opt_problem.run_exec = run_exec
        self.mf_opt_problem._add_callback = callback_setter
        self.mf_opt_problem._set_option = option_setter

    def get_expected_workflow(
        self,
    ) -> Iterable[Any]:
        """Get the expected workflow.

        Returns:
            An empty list.
        """
        return []

    def get_expected_dataflow(
        self,
    ) -> Iterable[Any]:
        """Get the expected dataflow.

        Returns:
            An empty list.
        """
        return []