gemseo-multi-fidelity 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. gemseo_multi_fidelity/__init__.py +17 -0
  2. gemseo_multi_fidelity/core/MFMapperAdapter_input.json +22 -0
  3. gemseo_multi_fidelity/core/MFMapperAdapter_output.json +22 -0
  4. gemseo_multi_fidelity/core/MFMapperLinker_input.json +22 -0
  5. gemseo_multi_fidelity/core/MFMapperLinker_output.json +22 -0
  6. gemseo_multi_fidelity/core/MFScenarioAdapter_input.json +39 -0
  7. gemseo_multi_fidelity/core/MFScenarioAdapter_output.json +23 -0
  8. gemseo_multi_fidelity/core/__init__.py +16 -0
  9. gemseo_multi_fidelity/core/boxed_domain.py +242 -0
  10. gemseo_multi_fidelity/core/corr_function.py +411 -0
  11. gemseo_multi_fidelity/core/criticality.py +124 -0
  12. gemseo_multi_fidelity/core/ds_mapper.py +307 -0
  13. gemseo_multi_fidelity/core/errors.py +42 -0
  14. gemseo_multi_fidelity/core/eval_mapper.py +188 -0
  15. gemseo_multi_fidelity/core/id_mapper_adapter.py +61 -0
  16. gemseo_multi_fidelity/core/mapper_adapter.py +126 -0
  17. gemseo_multi_fidelity/core/mapper_linker.py +72 -0
  18. gemseo_multi_fidelity/core/mf_formulation.py +635 -0
  19. gemseo_multi_fidelity/core/mf_logger.py +216 -0
  20. gemseo_multi_fidelity/core/mf_opt_problem.py +480 -0
  21. gemseo_multi_fidelity/core/mf_scenario.py +205 -0
  22. gemseo_multi_fidelity/core/noise_criterion.py +94 -0
  23. gemseo_multi_fidelity/core/projpolytope.out +0 -0
  24. gemseo_multi_fidelity/core/scenario_adapter.py +568 -0
  25. gemseo_multi_fidelity/core/stop_criteria.py +201 -0
  26. gemseo_multi_fidelity/core/strict_chain.py +75 -0
  27. gemseo_multi_fidelity/core/utils_model_quality.py +74 -0
  28. gemseo_multi_fidelity/corrections/__init__.py +16 -0
  29. gemseo_multi_fidelity/corrections/add_corr_function.py +80 -0
  30. gemseo_multi_fidelity/corrections/correction_factory.py +65 -0
  31. gemseo_multi_fidelity/corrections/mul_corr_function.py +86 -0
  32. gemseo_multi_fidelity/drivers/__init__.py +16 -0
  33. gemseo_multi_fidelity/drivers/mf_algo_factory.py +38 -0
  34. gemseo_multi_fidelity/drivers/mf_driver_lib.py +462 -0
  35. gemseo_multi_fidelity/drivers/refinement.py +234 -0
  36. gemseo_multi_fidelity/drivers/settings/__init__.py +16 -0
  37. gemseo_multi_fidelity/drivers/settings/base_mf_driver_settings.py +59 -0
  38. gemseo_multi_fidelity/drivers/settings/mf_refine_settings.py +50 -0
  39. gemseo_multi_fidelity/formulations/__init__.py +16 -0
  40. gemseo_multi_fidelity/formulations/refinement.py +144 -0
  41. gemseo_multi_fidelity/mapping/__init__.py +16 -0
  42. gemseo_multi_fidelity/mapping/identity_mapper.py +74 -0
  43. gemseo_multi_fidelity/mapping/interp_mapper.py +422 -0
  44. gemseo_multi_fidelity/mapping/mapper_factory.py +70 -0
  45. gemseo_multi_fidelity/mapping/mapping_errors.py +46 -0
  46. gemseo_multi_fidelity/mapping/subset_mapper.py +122 -0
  47. gemseo_multi_fidelity/mf_rosenbrock/__init__.py +16 -0
  48. gemseo_multi_fidelity/mf_rosenbrock/delayed_disc.py +136 -0
  49. gemseo_multi_fidelity/mf_rosenbrock/refact_rosen_testcase.py +46 -0
  50. gemseo_multi_fidelity/mf_rosenbrock/rosen_mf_case.py +284 -0
  51. gemseo_multi_fidelity/mf_rosenbrock/rosen_mf_funcs.py +350 -0
  52. gemseo_multi_fidelity/models/__init__.py +16 -0
  53. gemseo_multi_fidelity/models/fake_updater.py +112 -0
  54. gemseo_multi_fidelity/models/model_updater.py +91 -0
  55. gemseo_multi_fidelity/models/rbf/__init__.py +16 -0
  56. gemseo_multi_fidelity/models/rbf/kernel_factory.py +66 -0
  57. gemseo_multi_fidelity/models/rbf/kernels/__init__.py +16 -0
  58. gemseo_multi_fidelity/models/rbf/kernels/gaussian.py +93 -0
  59. gemseo_multi_fidelity/models/rbf/kernels/matern_3_2.py +101 -0
  60. gemseo_multi_fidelity/models/rbf/kernels/matern_5_2.py +101 -0
  61. gemseo_multi_fidelity/models/rbf/kernels/rbf_kernel.py +172 -0
  62. gemseo_multi_fidelity/models/rbf/rbf_model.py +422 -0
  63. gemseo_multi_fidelity/models/sparse_rbf_updater.py +96 -0
  64. gemseo_multi_fidelity/models/taylor/__init__.py +16 -0
  65. gemseo_multi_fidelity/models/taylor/taylor.py +212 -0
  66. gemseo_multi_fidelity/models/taylor_updater.py +66 -0
  67. gemseo_multi_fidelity/models/updater_factory.py +62 -0
  68. gemseo_multi_fidelity/settings/__init__.py +16 -0
  69. gemseo_multi_fidelity/settings/drivers.py +22 -0
  70. gemseo_multi_fidelity/settings/formulations.py +16 -0
  71. gemseo_multi_fidelity-0.0.1.dist-info/METADATA +99 -0
  72. gemseo_multi_fidelity-0.0.1.dist-info/RECORD +76 -0
  73. gemseo_multi_fidelity-0.0.1.dist-info/WHEEL +5 -0
  74. gemseo_multi_fidelity-0.0.1.dist-info/entry_points.txt +2 -0
  75. gemseo_multi_fidelity-0.0.1.dist-info/licenses/LICENSE.txt +165 -0
  76. gemseo_multi_fidelity-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,462 @@
+ # Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
+ #
+ # This program is free software; you can redistribute it and/or
+ # modify it under the terms of the GNU Lesser General Public
+ # License version 3 as published by the Free Software Foundation.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ # Lesser General Public License for more details.
+ #
+ # You should have received a copy of the GNU Lesser General Public License
+ # along with this program; if not, write to the Free Software Foundation,
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+
+ # Copyright (c) 2019 AIRBUS OPERATIONS
+
+ #
+ # Contributors:
+ #    INITIAL AUTHORS - API and implementation and/or documentation
+ #        :author: Romain Olivanti
+ #    OTHER AUTHORS - MACROSCOPIC CHANGES
+ #        Francois Gallard: port to GEMSEO 6
+
+ """Multi-fidelity abstract driver library."""
+
+ from __future__ import annotations
+
+ from abc import abstractmethod
+ from copy import deepcopy
+ from logging import getLogger
+ from time import time
+ from typing import TYPE_CHECKING
+ from typing import Any
+ from typing import ClassVar
+
+ from gemseo.algos.base_algorithm_library import BaseAlgorithmLibrary
+
+ from gemseo_multi_fidelity.core.mf_logger import MFLogger
+ from gemseo_multi_fidelity.core.scenario_adapter import MFScenarioAdapter
+
+ # from gemseo_multi_fidelity.drivers.core.quad_updater import QuadUpdater
+ # from gemseo_multi_fidelity.drivers.core.utils.precond_builder import PrecondBuilder
+
+ if TYPE_CHECKING:
+     from collections.abc import Callable
+     from collections.abc import Iterable
+
+     from gemseo.algos.optimization_problem import OptimizationProblem
+     from gemseo.algos.optimization_result import OptimizationResult
+     from numpy.typing import NDArray
+
+     from gemseo_multi_fidelity.core.eval_mapper import EvaluationMapper
+     from gemseo_multi_fidelity.core.mf_opt_problem import MFOptimizationProblem
+
+ LOGGER = getLogger(__name__)
+
+
+ class MFDriverLibrary(BaseAlgorithmLibrary):
+     """Multi-fidelity Driver Library."""
+
+     DRIVER_NAME = "AbstractMFDriver"
+
+     X_START = "x_start"
+     INPUT_LEVEL = "input_level"
+     RUN_OPTIONS = "run_options"
+
+     OPTIONS_DIR = "options"
+     OPTIONS_MAP: ClassVar[dict] = {}
+
+     STATUS_CONVERGED = "converged"
+     STATUS_ERROR = "error"
+
+     HESS_PAIRS_X0 = "lmem_x0"
+     TIMESTAMP = "timestamp"
+
+     _problem: MFOptimizationProblem | None
+     """The optimization problem the driver library is bound to."""
+
+     def __init__(self, algo_name: str = "Refinement") -> None:
+         """Constructor."""
+         super().__init__(algo_name)
+         self._x_0_driver = None
+         self._x_0_hf = None
+
+     @staticmethod
+     def _check_is_instance(
+         obj: Any, klass: type | tuple[type, ...], err_msg: str | None = None
+     ) -> bool | None:
+         if not isinstance(obj, klass):
+             if not err_msg:
+                 err_msg = f"{obj} is not a {klass.__name__}"
+             raise TypeError(err_msg)
+         return True
+
+     def check_run_options(
+         self, options: dict | Iterable[dict] | None
+     ) -> list[dict]:
+         """Check the run options.
+
+         Args:
+             options: The options to check, either as a single dict applied to
+                 every level or as an iterable providing one dict per level.
+
+         Returns:
+             The checked options, as one dict per level.
+         """
+         n_levels = len(self._problem.sub_opt_problems)
+
+         logger = MFLogger()
+         logger.add_line(self.DRIVER_NAME)
+
+         if options is None:
+             options = {}
+
+         if isinstance(options, dict):
+             # A single dictionary: duplicate it for every level
+             if options:
+                 logger.add_line("single dictionary to set options")
+                 logger.add_line("all options will be forwarded to all levels")
+             options = [deepcopy(options) for _ in range(n_levels)]
+         else:
+             # Is it iterable?
+             try:
+                 len_options = len(options)
+                 # Is it an iterable of dict?
+                 for elem in options:
+                     self._check_is_instance(elem, dict)
+             except TypeError as exc:
+                 msg = "options must be an iterable of dict or a single dict"
+                 raise TypeError(msg) from exc
+
+             options = list(deepcopy(options))
+             # Is the length correct?
+             if len_options != n_levels:
+                 n_missing = n_levels - len_options
+                 logger.add_line(f"missing options for {n_missing:d} levels")
+                 logger.add_line("default options will be used")
+                 options += [{} for _ in range(n_levels - len_options)]
+         logger.log()
+         return options
+
+     def set_run_options(self, options: Iterable[dict]) -> None:
+         """Set run options.
+
+         Args:
+             options: The options to set, one dict per level.
+         """
+         for level, opt_dict in enumerate(options):
+             self._problem.set_exec_options(level, opt_dict)
+
+     def _set_log_callbacks(self) -> None:
+         """Set specific callbacks dedicated to logging some events."""
+         n_levels = len(self._problem.sub_opt_problems)
+
+         for level in range(n_levels):
+             lvl_logger = MFLogger()
+             lvl_logger.add_tag_link("$LVL$", level)
+             lvl_logger.add_line(self.DRIVER_NAME)
+             lvl_logger.add_line("Running level $LVL$")
+
+             def log_level_entry(local_data, logger=lvl_logger):
+                 logger.log()
+
+             self._problem.add_callback(
+                 level, MFScenarioAdapter.ON_ENTRY_CALLBACKS, log_level_entry
+             )
+
+     def _set_timestamp_callbacks(self) -> None:
+         for prob in self._problem.sub_opt_problems:
+
+             def log_timestamp(x_vect=None, prob=prob):
+                 dtb = prob.database
+                 x_new = dtb.get_x_vect(-1)
+                 if dtb.get_function_value(self.TIMESTAMP, x_new) is None:
+                     dtb.store(x_new, {self.TIMESTAMP: time()})
+
+             prob.database.add_store_listener(log_timestamp)
+
+     def _match_level_from_dvs(self, dvs: dict) -> int | None:
+         """Find the index of the first level to which the provided dvs belong.
+
+         Args:
+             dvs: The dict of design variables.
+
+         Returns:
+             The index of the matching level, or None if no level matches.
+         """
+         if not isinstance(dvs, dict):
+             msg = f"dvs must be a dict, got {type(dvs)} instead"
+             raise TypeError(msg)
+
+         for i, prob in enumerate(self._problem.sub_opt_problems):
+             design_space = prob.design_space
+             try:
+                 design_space.check_membership(dvs)
+             except ValueError:
+                 pass
+             else:
+                 return i
+         return None
+
+     def _propagate_dvs(self, index: int, dvs: dict | None = None) -> None:
+         """Propagate the starting point to all design spaces.
+
+         Args:
+             index: The index of the level to propagate from.
+             dvs: The dict of design variables.
+         """
+         # Set the starting point at the right level
+         if dvs is not None:
+             self._problem.sub_opt_problems[index].design_space.set_current_value(dvs)
+
+         if index != 0:
+             # Propagate to upper levels
+             for mapper in self._problem.eval_mappers[index - 1 :: -1]:
+                 ds_mapper = mapper.design_space_mapper
+                 ds_mapper.map_ds_direct()
+
+         n_adapters = len(self._problem.sub_opt_problems)
+         if index != n_adapters - 1:
+             # Propagate to lower levels
+             for mapper in self._problem.eval_mappers[index:n_adapters]:
+                 ds_mapper = mapper.design_space_mapper
+                 ds_mapper.map_ds_reverse()
+
+     def set_starting_point(self, options: dict) -> None:
+         """Set starting point.
+
+         Args:
+             options: The dict of options.
+         """
+         x_start = options.get(self.X_START)
+
+         if x_start is not None:
+             x_start = dict(x_start)
+             start_level = self._match_level_from_dvs(x_start)
+             if start_level is None:
+                 err = f"{self.X_START} does not belong to any design space"
+                 raise ValueError(err)
+         else:
+             start_level = options.get(self.INPUT_LEVEL, 0)
+             if start_level >= len(self._problem.sub_opt_problems):
+                 err = f"{self.INPUT_LEVEL} does not specify a valid level"
+                 raise ValueError(err)
+
+         # Propagate the starting point at all levels
+         self._propagate_dvs(start_level, dvs=x_start)
+
+     def _store_x_0(self, start_level: int) -> None:
+         """Store x_0.
+
+         Args:
+             start_level: The index of start level.
+         """
+         start_ds = self._problem.sub_opt_problems[start_level].design_space
+         hf_ds = self._problem.sub_opt_problems[0].design_space
+         self._x_0_driver = start_ds.get_current_value(as_dict=True)
+         self._x_0_hf = hf_ds.get_current_value(as_dict=True)
+
+     def get_hifi_opt_result(self) -> OptimizationResult:
+         """Get the HF optimization result.
+
+         Retrieves the optimum from the high-fidelity problem and builds an
+         optimization result object from it.
+
+         Returns:
+             The optimization result.
+         """
+         return self._problem.sub_opt_problems[0].solution
+         # # compute the best feasible or infeasible point
+         # f_opt, x_opt, is_feas, c_opt, c_opt_grad = hf_prob.get_optimum()
+         # if f_opt is not None and not hf_prob.minimize_objective:
+         #     f_opt = -f_opt
+         # x_0 = hf_prob.design_space.convert_dict_to_array(self._x_0_hf)
+         # return OptimizationResult(
+         #     x_0=x_0,
+         #     x_opt=x_opt,
+         #     f_opt=f_opt,
+         #     status=status,
+         #     constraint_values=c_opt,
+         #     constraints_grad=c_opt_grad,
+         #     optimizer_name=self.algo_name,
+         #     message=message,
+         #     n_obj_call=hf_prob.objective.execution_statistics.n_calls,
+         #     is_feasible=is_feas
+         # )
+
+     # TODO: CLEAN (SEEMS TO BE DEAD CODE)
+     # def _pre_run(self, problem: MFOptimizationProblem, **settings: Any) -> None:
+     #     """Pre-run.
+     #
+     #     Args:
+     #         problem: The optimization problem.
+     #         settings: The settings kwargs.
+     #     """
+     #     self.set_starting_point(settings)
+     #     # Extract and set run options
+     #     run_options = self.check_run_options(settings.get(self.RUN_OPTIONS, {}))
+     #     self.set_run_options(run_options)
+     #
+     #     # Call _set_log_callbacks after the run options as the user can
+     #     # define the callbacks using the run options
+     #     self._set_log_callbacks()
+     #     self._set_timestamp_callbacks()
+     #
+     #     # Lib-specific pre-run
+     #     start_level = self._pre_run(problem, **settings)
+     #     self._store_x_0(start_level)
+
+     def _new_iteration_callback(self, x_vect: NDArray) -> None:
+         # deactivate iter handling
+         pass
+
+     def _get_result(
+         self,
+         problem: MFOptimizationProblem,
+         message: Any,
+         status: Any,
+         *args: Any,
+     ) -> OptimizationResult:
+         """Return the result of the resolution of the problem.
+
+         Args:
+             problem: The optimization problem.
+             message: The message associated with the termination criterion, if any.
+             status: The status associated with the termination criterion, if any.
+             *args: Specific arguments.
+
+         Returns:
+             The optimization result.
+         """
+         return self._RESULT_CLASS.from_optimization_problem(
+             problem.sub_opt_problems[0],
+             message=message,
+             status=status,
+             optimizer_name=self._algo_name,
+         )
+
+     @abstractmethod
+     def _run(
+         self, problem: MFOptimizationProblem, **settings: Any
+     ) -> OptimizationResult:
+         return
+
+     def _get_result(self, problem: MFOptimizationProblem) -> OptimizationResult:
+         """Return the result of the resolution of the problem.
+
+         Args:
+             problem: The optimization problem.
+
+         Returns:
+             The optimization result.
+         """
+         return problem.sub_opt_problems[0].solution
+
+     # def _post_run(
+     #     self,
+     #     problem: MFOptimizationProblem,
+     #     result: OptimizationResult,
+     #     max_design_space_dimension_to_log: int,
+     #     **settings: Any,
+     # ) -> None:
+     #     print("MF driver lib _post_run sub problems 0")
+     #     super(MFDriverLibrary)._post_run(
+     #         problem.sub_opt_problems[0],
+     #         result=result,
+     #         max_design_space_dimension_to_log=max_design_space_dimension_to_log,
+     #         **settings
+     #     )
+
+     # def execute(
+     #     self,
+     #     problem: MFOptimizationProblem,
+     #     settings_model: BaseMFDriver_Settings | None = None,
+     #     **settings: Any,
+     # ) -> OptimizationResult:
+     #
+     #     self._problem = problem
+     #     # problem.check()
+     #     settings = self._validate_settings(settings_model=settings_model, **settings)
+     #
+     #     # options = self._load_options(**settings)
+     #
+     #     # Set the starting point
+     #     self.set_starting_point(settings)
+     #     # Extract and set run options
+     #     run_options = self.check_run_options(settings.get(self.RUN_OPTIONS, {}))
+     #     self.set_run_options(run_options)
+     #
+     #     # Call _set_log_callbacks after the run options as the user can
+     #     # define the callbacks using the run options
+     #     self._set_log_callbacks()
+     #     self._set_timestamp_callbacks()
+     #
+     #     # Lib-specific pre-run
+     #     start_level = self._pre_run(problem, **settings)
+     #     self._store_x_0(start_level)
+     #
+     #     # Run the driver
+     #     result = self._run(problem, **settings)
+     #
+     #     # Lib-specific post-run
+     #     self._post_run(
+     #         problem.sub_opt_problems[0],
+     #         result,
+     #         max_design_space_dimension_to_log=40,
+     #         **settings
+     #     )
+     #
+     #     return result
+
+     def set_auto_restart(self, level: int) -> None:
+         """Set a level to make sure that it can be auto-restarted.
+
+         That is, its output point is also set as the next starting point.
+
+         Args:
+             level: The index of the level to target.
+         """
+
+         def set_restart(local_data):
+             x_best = local_data[MFScenarioAdapter.X_BEST]
+             local_data[MFScenarioAdapter.X_START] = deepcopy(x_best)
+
+         self._problem.add_callback(
+             level, MFScenarioAdapter.ON_EXIT_CALLBACKS, set_restart
+         )
+
+     @staticmethod
+     def _build_precond_build_callback(
+         opt_problem: OptimizationProblem, mapper: EvaluationMapper
+     ) -> Callable:
+         """Build a callback that builds a preconditioner.
+
+         The callback builds a preconditioner from the database of the optimization
+         problem provided and stores it in the output options of the scenario
+         adapter, to be used by the next level.
+         If a mapper is provided, the preconditioner is mapped accordingly.
+
+         Args:
+             opt_problem: The optimization problem from which the preconditioner is
+                 built.
+             mapper: The mapper to map the Hessian approximation pairs used to build
+                 the preconditioner.
+
+         Returns:
+             The callback.
+         """
+         raise NotImplementedError
+
+     @staticmethod
+     def _build_precond_forward_callback(mapper: EvaluationMapper) -> Callable:
+         """Build a preconditioner forwarding callback.
+
+         The callback forwards the preconditioner used at the current level to the
+         next one.
+
+         Args:
+             mapper: The mapper to map the preconditioner.
+
+         Returns:
+             The callback.
+         """
+         raise NotImplementedError
+
+     @staticmethod
+     def _build_pairs_builder(
+         opt_problem: OptimizationProblem, max_corr: int = 100
+     ) -> Callable:
+         raise NotImplementedError
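
The per-level option handling implemented by check_run_options above can be summarized as: a single dict is forwarded to every fidelity level, and an iterable shorter than the number of levels is padded with default (empty) options. Here is a minimal standalone sketch of that behavior; the normalize_run_options helper and the three-level setup are illustrative only and are not part of the package:

from copy import deepcopy


def normalize_run_options(options, n_levels):
    """Return one run-options dict per level, mirroring check_run_options."""
    if options is None:
        options = {}
    if isinstance(options, dict):
        # A single dict is duplicated and forwarded to every fidelity level.
        return [deepcopy(options) for _ in range(n_levels)]
    options = [dict(opts) for opts in options]
    # Levels without options fall back to defaults (empty dicts).
    options += [{} for _ in range(n_levels - len(options))]
    return options


# A single dict is broadcast to all three levels.
print(normalize_run_options({"max_iter": 50}, 3))
# -> [{'max_iter': 50}, {'max_iter': 50}, {'max_iter': 50}]

# A partial list is padded with empty dicts for the remaining levels.
print(normalize_run_options([{"max_iter": 50}], 3))
# -> [{'max_iter': 50}, {}, {}]
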
@@ -0,0 +1,234 @@
+ # Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
+ #
+ # This program is free software; you can redistribute it and/or
+ # modify it under the terms of the GNU Lesser General Public
+ # License version 3 as published by the Free Software Foundation.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ # Lesser General Public License for more details.
+ #
+ # You should have received a copy of the GNU Lesser General Public License
+ # along with this program; if not, write to the Free Software Foundation,
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+
+ # Copyright (c) 2019 AIRBUS OPERATIONS
+
+ #
+ # Contributors:
+ #    INITIAL AUTHORS - API and implementation and/or documentation
+ #        :author: Romain Olivanti
+ #    OTHER AUTHORS - MACROSCOPIC CHANGES
+ #        Francois Gallard: port to GEMSEO 6
+ """Refinement driver."""
25
+
26
+ from __future__ import annotations
27
+
28
+ from collections.abc import Iterable
29
+ from copy import deepcopy
30
+ from dataclasses import dataclass
31
+ from typing import TYPE_CHECKING
32
+ from typing import Any
33
+ from typing import ClassVar
34
+ from typing import Final
35
+
36
+ from gemseo.algos.opt.base_optimization_library import OptimizationAlgorithmDescription
37
+
38
+ from gemseo_multi_fidelity.core.scenario_adapter import MFScenarioAdapter
39
+ from gemseo_multi_fidelity.drivers.mf_driver_lib import MFDriverLibrary
40
+ from gemseo_multi_fidelity.drivers.settings.mf_refine_settings import MFRefine_Settings
41
+ from gemseo_multi_fidelity.mapping.identity_mapper import IdentityMapper
42
+
43
+ if TYPE_CHECKING:
44
+ from gemseo.algos.optimization_result import OptimizationResult
45
+
46
+ from gemseo_multi_fidelity.core.mf_opt_problem import MFOptimizationProblem
47
+ from gemseo_multi_fidelity.drivers.settings.base_mf_driver_settings import (
48
+ BaseMFDriver_Settings,
49
+ )
50
+
51
+
52
+ @dataclass
+ class RefinementAlgorithmDescription(OptimizationAlgorithmDescription):
+     """The description of the multi-fidelity refinement driver library."""
+
+     library_name: str = "MF Refinement"
+     """The library name."""
+
+     Settings: type[BaseMFDriver_Settings] = MFRefine_Settings
+     """The option validation model for the Refinement library."""
+
+
+ class Refinement(MFDriverLibrary):
+     """Refinement driver class."""
+
+     __DOC: Final[str] = ""
+
+     ALGORITHM_INFOS: ClassVar[dict[str, RefinementAlgorithmDescription]] = {
+         "MF_REFINE": RefinementAlgorithmDescription(
+             algorithm_name="MF_REFINE",
+             description=(
+                 "Refinement driver, sequentially solves the optimization problems."
+             ),
+             internal_algorithm_name="MF_REFINE",
+             require_gradient=False,
+             Settings=MFRefine_Settings,
+             handle_multiobjective=False,
+             handle_integer_variables=True,
+             handle_equality_constraints=True,
+             handle_inequality_constraints=True,
+         ),
+     }
+     """The algorithm infos."""
+
+     BUILD_PRECOND = "build_precond"
+     FORWARD_PRECOND = "forward_precond"
+
+     def __init__(self, algo_name: str = "MF_REFINE") -> None:
+         """Constructor."""
+         super().__init__(algo_name=algo_name)
+
+     @staticmethod
+     def _check_not_instance(
+         obj: Any, klass: type | tuple[type, ...], err_msg: str | None = None
+     ) -> bool | None:
+         if isinstance(obj, klass):
+             if not err_msg:
+                 if isinstance(klass, tuple):
+                     # Find which classes in the tuple match
+                     matching_classes = [cls for cls in klass if isinstance(obj, cls)]
+                     class_names = ", ".join(cls.__name__ for cls in matching_classes)
+                     err_msg = f"{obj} is a {class_names}"
+                 else:
+                     err_msg = f"{obj} is a {klass.__name__}"
+             raise TypeError(err_msg)
+         return True
+
+     @staticmethod
+     def check_iterable(
+         value: Any,
+         size: int,
+         type_to_check: type | None = None,
+         input_name: str = "",
+     ) -> Iterable:
+         """Check an iterable.
+
+         Args:
+             value: The input value.
+             size: The size to match.
+             type_to_check: The type to match.
+             input_name: The input name.
+
+         Returns:
+             The checked value.
+         """
+         try:
+             Refinement._check_not_instance(value, (dict, str))
+             if len(value) != size:
+                 err_size = f"len {input_name} != {size:d}"
+                 raise ValueError(err_size)
+             for item in value:
+                 Refinement._check_is_instance(item, type_to_check)
+         except TypeError as exc:
+             if type_to_check is not None and not isinstance(value, type_to_check):
+                 err_type = (
+                     f"{input_name} must be a {type_to_check} "
+                     f"or a list of {type_to_check}"
+                 )
+                 raise TypeError(err_type) from exc
+             value = [deepcopy(value) for _ in range(size)]
+         return value
+
+     def _check_precond_options(self) -> None:
+         """Check the preconditioner options."""
+         # Build precond
+         self.check_iterable(
+             self._settings.build_precond,
+             self._problem.n_levels,
+             type_to_check=bool,
+             input_name=self.BUILD_PRECOND,
+         )
+         # Forward precond
+         self.check_iterable(
+             self._settings.forward_precond,
+             self._problem.n_levels,
+             type_to_check=bool,
+             input_name=self.FORWARD_PRECOND,
+         )
+
+     def _build_precond_callbacks(self) -> None:
+         """Build the preconditioner callbacks."""
+         build_flags = self._settings.build_precond
+         forward_flags = self._settings.forward_precond
+         if not build_flags and not forward_flags:
+             return
+
+         if not isinstance(forward_flags, Iterable):
+             forward_flags = [forward_flags] * self._problem.n_levels
+
+         if not isinstance(build_flags, Iterable):
+             build_flags = [build_flags] * self._problem.n_levels
+
+         for i, (build_precond, forward_precond) in enumerate(
+             zip(build_flags, forward_flags, strict=False)
+         ):
+             if not (build_precond or forward_precond):
+                 # Nothing to do
+                 continue
+
+             # Select the evaluation mapper - No mapper for the top level
+             mapper = self._problem.eval_mappers[i - 1] if i != 0 else None
+             # Build flag has the priority over the forward one
+             if build_precond:
+                 # Build the callback
+                 callback = self._build_precond_build_callback(
+                     self._problem.sub_opt_problems[i], mapper
+                 )
+             elif forward_precond:
+                 if mapper is None or isinstance(
+                     mapper.design_space_mapper, IdentityMapper
+                 ):
+                     callback = self._build_precond_forward_callback(mapper)
+                 else:
+                     raise NotImplementedError
+
+             # Add it to the on exit callbacks
+             self._problem.add_callback(
+                 i, MFScenarioAdapter.ON_EXIT_CALLBACKS, callback
+             )
+
+     def _pre_run(self, problem: MFOptimizationProblem) -> int:
+         """Pre-run.
+
+         Args:
+             problem: The optimization problem.
+
+         Returns:
+             The start level.
+         """
+         # Start at the lowest level
+         start_level = self._problem.n_levels - 1
+         self._check_precond_options()
+         self._build_precond_callbacks()
+         self._store_x_0(start_level)
+         return start_level
+
+     def _run(
+         self, problem: MFOptimizationProblem, **settings: Any
+     ) -> OptimizationResult:
+         """Run.
+
+         Args:
+             problem: The optimization problem.
+             settings: The keyword arguments to pass as settings.
+
+         Returns:
+             The optimization result.
+         """
+         x_start = self._x_0_driver
+         # try:
+         self._problem.run_exec(x_start)
+         return self.get_hifi_opt_result()
+         # except Exception as err:
+         #     status = self.STATUS_ERROR
+         #     message = "{}: {}".format(err.__class__.__name__, err)
+         #     solution = OptimizationResult(message=message, status=status)
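
A usage note for the check_iterable helper above: it validates a per-level iterable against the number of fidelity levels, or broadcasts a scalar value to one entry per level. A minimal sketch of that behavior, assuming GEMSEO and gemseo-multi-fidelity are installed so the import succeeds:

from gemseo_multi_fidelity.drivers.refinement import Refinement

# A scalar flag is broadcast to one entry per fidelity level.
flags = Refinement.check_iterable(
    True, 3, type_to_check=bool, input_name="build_precond"
)
print(flags)  # [True, True, True]

# An iterable of the right size and element type is returned unchanged.
flags = Refinement.check_iterable(
    [True, False, False], 3, type_to_check=bool, input_name="build_precond"
)
print(flags)  # [True, False, False]

# A size mismatch raises ValueError; a wrong element type raises TypeError.
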