gemseo-multi-fidelity 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. gemseo_multi_fidelity/__init__.py +17 -0
  2. gemseo_multi_fidelity/core/MFMapperAdapter_input.json +22 -0
  3. gemseo_multi_fidelity/core/MFMapperAdapter_output.json +22 -0
  4. gemseo_multi_fidelity/core/MFMapperLinker_input.json +22 -0
  5. gemseo_multi_fidelity/core/MFMapperLinker_output.json +22 -0
  6. gemseo_multi_fidelity/core/MFScenarioAdapter_input.json +39 -0
  7. gemseo_multi_fidelity/core/MFScenarioAdapter_output.json +23 -0
  8. gemseo_multi_fidelity/core/__init__.py +16 -0
  9. gemseo_multi_fidelity/core/boxed_domain.py +242 -0
  10. gemseo_multi_fidelity/core/corr_function.py +411 -0
  11. gemseo_multi_fidelity/core/criticality.py +124 -0
  12. gemseo_multi_fidelity/core/ds_mapper.py +307 -0
  13. gemseo_multi_fidelity/core/errors.py +42 -0
  14. gemseo_multi_fidelity/core/eval_mapper.py +188 -0
  15. gemseo_multi_fidelity/core/id_mapper_adapter.py +61 -0
  16. gemseo_multi_fidelity/core/mapper_adapter.py +126 -0
  17. gemseo_multi_fidelity/core/mapper_linker.py +72 -0
  18. gemseo_multi_fidelity/core/mf_formulation.py +635 -0
  19. gemseo_multi_fidelity/core/mf_logger.py +216 -0
  20. gemseo_multi_fidelity/core/mf_opt_problem.py +480 -0
  21. gemseo_multi_fidelity/core/mf_scenario.py +205 -0
  22. gemseo_multi_fidelity/core/noise_criterion.py +94 -0
  23. gemseo_multi_fidelity/core/projpolytope.out +0 -0
  24. gemseo_multi_fidelity/core/scenario_adapter.py +568 -0
  25. gemseo_multi_fidelity/core/stop_criteria.py +201 -0
  26. gemseo_multi_fidelity/core/strict_chain.py +75 -0
  27. gemseo_multi_fidelity/core/utils_model_quality.py +74 -0
  28. gemseo_multi_fidelity/corrections/__init__.py +16 -0
  29. gemseo_multi_fidelity/corrections/add_corr_function.py +80 -0
  30. gemseo_multi_fidelity/corrections/correction_factory.py +65 -0
  31. gemseo_multi_fidelity/corrections/mul_corr_function.py +86 -0
  32. gemseo_multi_fidelity/drivers/__init__.py +16 -0
  33. gemseo_multi_fidelity/drivers/mf_algo_factory.py +38 -0
  34. gemseo_multi_fidelity/drivers/mf_driver_lib.py +462 -0
  35. gemseo_multi_fidelity/drivers/refinement.py +234 -0
  36. gemseo_multi_fidelity/drivers/settings/__init__.py +16 -0
  37. gemseo_multi_fidelity/drivers/settings/base_mf_driver_settings.py +59 -0
  38. gemseo_multi_fidelity/drivers/settings/mf_refine_settings.py +50 -0
  39. gemseo_multi_fidelity/formulations/__init__.py +16 -0
  40. gemseo_multi_fidelity/formulations/refinement.py +144 -0
  41. gemseo_multi_fidelity/mapping/__init__.py +16 -0
  42. gemseo_multi_fidelity/mapping/identity_mapper.py +74 -0
  43. gemseo_multi_fidelity/mapping/interp_mapper.py +422 -0
  44. gemseo_multi_fidelity/mapping/mapper_factory.py +70 -0
  45. gemseo_multi_fidelity/mapping/mapping_errors.py +46 -0
  46. gemseo_multi_fidelity/mapping/subset_mapper.py +122 -0
  47. gemseo_multi_fidelity/mf_rosenbrock/__init__.py +16 -0
  48. gemseo_multi_fidelity/mf_rosenbrock/delayed_disc.py +136 -0
  49. gemseo_multi_fidelity/mf_rosenbrock/refact_rosen_testcase.py +46 -0
  50. gemseo_multi_fidelity/mf_rosenbrock/rosen_mf_case.py +284 -0
  51. gemseo_multi_fidelity/mf_rosenbrock/rosen_mf_funcs.py +350 -0
  52. gemseo_multi_fidelity/models/__init__.py +16 -0
  53. gemseo_multi_fidelity/models/fake_updater.py +112 -0
  54. gemseo_multi_fidelity/models/model_updater.py +91 -0
  55. gemseo_multi_fidelity/models/rbf/__init__.py +16 -0
  56. gemseo_multi_fidelity/models/rbf/kernel_factory.py +66 -0
  57. gemseo_multi_fidelity/models/rbf/kernels/__init__.py +16 -0
  58. gemseo_multi_fidelity/models/rbf/kernels/gaussian.py +93 -0
  59. gemseo_multi_fidelity/models/rbf/kernels/matern_3_2.py +101 -0
  60. gemseo_multi_fidelity/models/rbf/kernels/matern_5_2.py +101 -0
  61. gemseo_multi_fidelity/models/rbf/kernels/rbf_kernel.py +172 -0
  62. gemseo_multi_fidelity/models/rbf/rbf_model.py +422 -0
  63. gemseo_multi_fidelity/models/sparse_rbf_updater.py +96 -0
  64. gemseo_multi_fidelity/models/taylor/__init__.py +16 -0
  65. gemseo_multi_fidelity/models/taylor/taylor.py +212 -0
  66. gemseo_multi_fidelity/models/taylor_updater.py +66 -0
  67. gemseo_multi_fidelity/models/updater_factory.py +62 -0
  68. gemseo_multi_fidelity/settings/__init__.py +16 -0
  69. gemseo_multi_fidelity/settings/drivers.py +22 -0
  70. gemseo_multi_fidelity/settings/formulations.py +16 -0
  71. gemseo_multi_fidelity-0.0.1.dist-info/METADATA +99 -0
  72. gemseo_multi_fidelity-0.0.1.dist-info/RECORD +76 -0
  73. gemseo_multi_fidelity-0.0.1.dist-info/WHEEL +5 -0
  74. gemseo_multi_fidelity-0.0.1.dist-info/entry_points.txt +2 -0
  75. gemseo_multi_fidelity-0.0.1.dist-info/licenses/LICENSE.txt +165 -0
  76. gemseo_multi_fidelity-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,216 @@
1
+ # Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
2
+ #
3
+ # This program is free software; you can redistribute it and/or
4
+ # modify it under the terms of the GNU Lesser General Public
5
+ # License version 3 as published by the Free Software Foundation.
6
+ #
7
+ # This program is distributed in the hope that it will be useful,
8
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
9
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10
+ # Lesser General Public License for more details.
11
+ #
12
+ # You should have received a copy of the GNU Lesser General Public License
13
+ # along with this program; if not, write to the Free Software Foundation,
14
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15
+
16
+ # Copyright (c) 2019 AIRBUS OPERATIONS
17
+
18
+ #
19
+ # Contributors:
20
+ # INITIAL AUTHORS - API and implementation and/or documentation
21
+ # :author: Romain Olivanti
22
+ # OTHER AUTHORS - MACROSCOPIC CHANGES
23
+ """MF Logger.
24
+
25
+ Provides simple features to get a common multi-fidelity output, distinct from the rest
26
+ of the data logged.
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ from typing import TYPE_CHECKING
32
+
33
+ if TYPE_CHECKING:
34
+ from logging import Logger
35
+
36
+
37
class MFLogger:
    """Multi-fidelity logger.

    Lines are buffered with :meth:`add_line` and :meth:`add_wrap_line`. Tags
    registered with :meth:`add_tag_link` are substituted in them, and :meth:`log`
    writes the resulting framed block through the underlying logger, one ``info``
    call per formatted line.
    """

    # Sentinel stored among the buffered lines to request a full-width separator.
    TAG_WRAP_LINE = "@WRAP_LINE@"

    def __init__(
        self,
        logger: Logger | None = None,
        padding: int = 2,
        min_wrap: int = 3,
        wrap_char: str = "-",
        align_center_at: int = 33,
    ) -> None:
        """Constructor.

        Args:
            logger: The logger to use to write the log. If ``None``, the logger
                returned by ``logging.getLogger(__name__)`` for this module is used.
            padding: The number of spaces between the wrap chars and the content
                of each logged line.
            min_wrap: The minimum number of wrap chars on each side of the longest
                line logged.
            wrap_char: The single char used to wrap the content.
            align_center_at: The column on which the center of the logged block is
                aligned. If a line is too long for that, the block is not shifted.

        Raises:
            TypeError: If ``padding``, ``min_wrap`` or ``align_center_at`` is not
                an int, or if ``wrap_char`` is not a str.
            ValueError: If one of those ints is negative, or if ``wrap_char`` is
                not a single char.
        """
        self._check_inputs(padding, min_wrap, wrap_char, align_center_at)

        if logger is None:
            import logging

            logger = logging.getLogger(__name__)

        self._logger = logger
        self._padding = padding
        self._min_wrap = min_wrap
        self._wrap_char = wrap_char
        self._align_center_at = align_center_at

        # Tag -> text substituted for it when the log data is built.
        self._map_replace: dict[str, str] = {}
        # Raw buffered lines (may contain tags or TAG_WRAP_LINE).
        self._info_lines: list[str] = []
        # Formatted lines produced by the last call to _build_log_data.
        self._log_data: list[str] = []

    @staticmethod
    def _check_inputs(
        padding: int, min_wrap: int, wrap_char: str, align_center_at: int
    ) -> None:
        """Check the inputs.

        Args:
            padding: The padding between the wrap chars and the content.
            min_wrap: The minimum number of wrap chars on the longest line.
            wrap_char: The char used to wrap the content.
            align_center_at: The column on which the logged block is centered.

        Raises:
            TypeError: If an integer setting is not an int (bools are rejected)
                or if ``wrap_char`` is not a str.
            ValueError: If an integer setting is negative or if ``wrap_char`` is
                not a single char.
        """
        for name, value in (
            ("padding", padding),
            ("min_wrap", min_wrap),
            ("align_center_at", align_center_at),
        ):
            # bool is a subclass of int but is never a meaningful width.
            if not isinstance(value, int) or isinstance(value, bool):
                msg = f"{name} not an int"
                raise TypeError(msg)
            if value < 0:
                msg = f"{name} must be >= 0"
                raise ValueError(msg)
        if not isinstance(wrap_char, str):
            msg = "wrap char must be a str"
            raise TypeError(msg)
        if len(wrap_char) != 1:
            msg = f"{wrap_char} must be a single char"
            raise ValueError(msg)

    def add_line(self, line: str, clear: bool = False) -> None:
        """Add a line in memory to be logged later.

        Args:
            line: The line to log; it may contain tags registered with
                :meth:`add_tag_link`.
            clear: Whether to clear the buffered lines before adding this one.
        """
        if clear:
            self.clear()
        self._info_lines.append(line)

    def add_wrap_line(self) -> None:
        """Add a full-width line made only of wrap chars."""
        self._info_lines.append(self.TAG_WRAP_LINE)

    def add_tag_link(self, tag: str, value: str) -> None:
        """Register a tag to be replaced by a value when logging.

        Args:
            tag: The tag name; converted to ``str``.
            value: The tag value; converted to ``str``.
        """
        self._map_replace[str(tag)] = str(value)

    def get_tag_links(self) -> dict:
        """Get tag links.

        Returns:
            A copy of the tag -> value mapping.
        """
        return self._map_replace.copy()

    def clear(self) -> None:
        """Clear the buffered lines (the tag links are kept)."""
        self._info_lines = []

    def _substitute_tags(self, line: str) -> str:
        """Return ``line`` with every registered tag replaced by its value."""
        for tag, value in self._map_replace.items():
            line = line.replace(tag, value)
        return line

    def _build_log_data(self) -> None:
        """Build the formatted lines to log from the buffered lines.

        Tags are substituted first. The frame width is then chosen so that the
        longest line gets ``padding`` spaces and at least ``min_wrap`` wrap chars
        on each side, and it is rounded up to an odd value. Finally the block is
        shifted so that its center lands on column ``align_center_at`` when
        possible.
        """
        margin = 2 * (self._padding + self._min_wrap)
        full_lines = []
        max_len = 0
        for line in self._info_lines:
            if line == self.TAG_WRAP_LINE:
                # Separator lines do not contribute to the width.
                full_lines.append(line)
                continue
            full_line = self._substitute_tags(line)
            full_lines.append(full_line)
            max_len = max(max_len, len(full_line) + margin)

        # An odd width gives the frame a single center column.
        if max_len % 2 == 0:
            max_len += 1

        # Shift the block right so that its center is at align_center_at; when the
        # block is too wide for that, it is not shifted at all.
        line_center = max_len // 2 + 1
        spacing = " " * max(self._align_center_at - line_center, 0)

        dashed_line = spacing + self._wrap_char * max_len
        padding = " " * self._padding

        log_data = ["", dashed_line]
        for line in full_lines:
            if line == self.TAG_WRAP_LINE:
                log_data.append(dashed_line)
                continue
            n_dash = max_len - len(line) - 2 * self._padding
            # An odd number of wrap chars is compensated by one extra space on the
            # right so that every line is exactly max_len chars wide.
            extra_space = " " * (n_dash % 2)
            half_dash = self._wrap_char * (n_dash // 2)
            log_data.append(
                f"{spacing}{half_dash}{padding}{line}{extra_space}{padding}{half_dash}"
            )
        log_data.append(dashed_line)
        self._log_data = log_data

    def log(self, clear: bool = False) -> None:
        """Format the buffered lines and send each one to ``logger.info``.

        Args:
            clear: Whether to clear the buffered lines after logging.
        """
        self._build_log_data()
        for line in self._log_data:
            self._logger.info(line)
        if clear:
            self.clear()
@@ -0,0 +1,480 @@
1
+ # Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
2
+ #
3
+ # This program is free software; you can redistribute it and/or
4
+ # modify it under the terms of the GNU Lesser General Public
5
+ # License version 3 as published by the Free Software Foundation.
6
+ #
7
+ # This program is distributed in the hope that it will be useful,
8
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
9
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10
+ # Lesser General Public License for more details.
11
+ #
12
+ # You should have received a copy of the GNU Lesser General Public License
13
+ # along with this program; if not, write to the Free Software Foundation,
14
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15
+
16
+ # Copyright (c) 2019 AIRBUS OPERATIONS
17
+
18
+ #
19
+ # Contributors:
20
+ # INITIAL AUTHORS - API and implementation and/or documentation
21
+ # :author: Romain Olivanti
22
+ # OTHER AUTHORS - MACROSCOPIC CHANGES
23
+ """MF optimization problem."""
24
+
25
+ from __future__ import annotations
26
+
27
+ import logging
28
+ from typing import TYPE_CHECKING
29
+ from typing import Any
30
+
31
+ import h5py
32
+ from gemseo.algos.database import Database
33
+ from gemseo.algos.optimization_problem import OptimizationProblem
34
+ from gemseo.utils.hdf5 import get_hdf5_group
35
+ from numpy import array
36
+ from numpy import issubdtype
37
+ from numpy import ndarray
38
+ from numpy import number as np_number
39
+ from six import string_types
40
+
41
+ from gemseo_multi_fidelity.core.ds_mapper import DesignSpaceMapper
42
+ from gemseo_multi_fidelity.core.errors import ConsistencyError
43
+ from gemseo_multi_fidelity.core.eval_mapper import EvaluationMapper
44
+
45
+ if TYPE_CHECKING:
46
+ from collections.abc import Callable
47
+ from collections.abc import Iterable
48
+ from pathlib import Path
49
+
50
+ from gemseo.algos.design_space import DesignSpace
51
+ from gemseo.core.mdo_functions.collections.constraints import Constraints
52
+ from numpy.typing import NDArray
53
+ from typing_extensions import Self
54
+
55
# Module-level logger, used by MFOptimizationProblem.to_hdf / from_hdf to report
# the HDF export and import.
LOGGER = logging.getLogger(__name__)
56
+
57
+
58
class MFOptimizationProblem:
    """Multi-fidelity optimization problem class.

    Gathers one :class:`OptimizationProblem` per fidelity level, the
    high-fidelity one first, together with the evaluation mappers linking
    consecutive levels and one reference database per level. The
    ``OptimizationProblem``-like properties (design space, constraints, database,
    ...) are forwarded to the high-fidelity sub-problem.
    """

    # Prefixes of the HDF groups storing the lower-fidelity data; the problem and
    # database groups are suffixed by the level index, the function mapping
    # groups by the evaluation mapper index.
    MF_PROB_GROUP = "mf_problem"
    MF_DTB_GROUP = "ref_database"
    MF_F_MAP_GROUP = "func_mapping"

    def __init__(
        self,
        sub_opt_problems: Iterable[OptimizationProblem],
        eval_mappers: Iterable[EvaluationMapper],
        ref_databases: Iterable[Database],
    ) -> None:
        """Constructor.

        Args:
            sub_opt_problems: The optimization problems of each level, the
                high-fidelity one first. Must support ``len`` and indexing.
            eval_mappers: The evaluation mappers, one less than the sub-problems.
                The ``i``-th one must have the design space of level ``i + 1`` as
                input design space and that of level ``i`` as output design space.
            ref_databases: The reference databases, one per sub-problem.

        Raises:
            TypeError: If an element does not have the expected type.
            ConsistencyError: If the numbers of elements or the design space
                links are inconsistent.
        """
        self._check_inputs(sub_opt_problems, eval_mappers, ref_databases)
        self.sub_opt_problems: list[OptimizationProblem] = sub_opt_problems
        self.eval_mappers: list[EvaluationMapper] = eval_mappers
        self.ref_databases: list[Database] = ref_databases

        # Callables expected to be assigned from outside this class; the methods
        # using them raise RuntimeError while they are still None.
        self._run_exec: Callable | None = None
        self._set_option: Callable | None = None
        self._add_callback: Callable | None = None

    @property
    def design_space(self) -> DesignSpace:
        """The design space of the high-fidelity sub-problem."""
        return self.sub_opt_problems[0].design_space

    @property
    def constraints(self) -> Constraints:
        """The constraints of the high-fidelity sub-problem."""
        return self.sub_opt_problems[0].constraints

    @property
    def is_linear(self):
        """Whether the high-fidelity sub-problem is linear."""
        return self.sub_opt_problems[0].is_linear

    def check(self) -> None:
        """Check (no-op for the multi-fidelity problem)."""

    def preprocess_functions(self, *args, **kwargs) -> None:
        """Preprocess functions (no-op for the multi-fidelity problem)."""

    @property
    def new_iter_observables(self) -> list:
        """The new iteration observables; always empty."""
        return []

    @property
    def database(self) -> Database:
        """The database of the high-fidelity sub-problem."""
        return self.sub_opt_problems[0].database

    def _new_iteration_callback(self, x_vect: NDArray) -> None:
        """Do nothing: iterations are tracked by the sub-problems themselves.

        Args:
            x_vect: The design variables values (unused).
        """
        return

    @staticmethod
    def check_is_instance(obj: Any, klass: type | tuple[type, ...]) -> None:
        """Check that an object is an instance of a class.

        Args:
            obj: The object to check.
            klass: The expected class, or a tuple of accepted classes.

        Raises:
            TypeError: If ``obj`` is not an instance of ``klass``.
        """
        if isinstance(obj, klass):
            return
        if isinstance(klass, tuple):
            expected = " or ".join(cls_.__name__ for cls_ in klass)
        else:
            expected = klass.__name__
        msg = f"{obj} not a {expected}"
        raise TypeError(msg)

    def _check_inputs_types(
        self,
        sub_opt_problems: Iterable[OptimizationProblem],
        eval_mappers: Iterable[EvaluationMapper],
        ref_databases: Iterable[Database],
    ) -> None:
        """Check the type of every element of the inputs.

        Raises:
            TypeError: If an element does not have the expected type.
        """
        for objects, klass in zip(
            (sub_opt_problems, eval_mappers, ref_databases),
            (OptimizationProblem, EvaluationMapper, Database),
            strict=True,
        ):
            for obj in objects:
                self.check_is_instance(obj, klass)

    def _check_inputs_len(
        self,
        sub_opt_problems: Iterable[OptimizationProblem],
        eval_mappers: Iterable[EvaluationMapper],
        ref_databases: Iterable[Database],
    ) -> None:
        """Check that there is one database per level and one mapper per link.

        Raises:
            ConsistencyError: If the numbers of elements are inconsistent.
        """
        n_probs = len(sub_opt_problems)

        if len(ref_databases) != n_probs:
            msg = (
                "Inconsistent number of sub-optimization problems and "
                f"reference databases: {n_probs:d} != {len(ref_databases):d}"
            )
            raise ConsistencyError(msg)
        if len(eval_mappers) != n_probs - 1:
            msg = (
                f"Should be {n_probs - 1:d} "
                f"evaluation mappers not {len(eval_mappers):d}"
            )
            raise ConsistencyError(msg)

    @staticmethod
    def _check_ds_consistency(
        sub_opt_problems: Iterable[OptimizationProblem],
        eval_mappers: Iterable[EvaluationMapper],
    ) -> None:
        """Check that each mapper references the design spaces of its levels.

        The ``i``-th mapper must use, by identity, the design space of level
        ``i + 1`` as input and the design space of level ``i`` as output.

        Raises:
            ConsistencyError: If a mapper is not linked to the right design spaces.
        """
        for i, mapper in enumerate(eval_mappers):
            ds_mapper = mapper.design_space_mapper
            ds_upper = sub_opt_problems[i].design_space
            ds_lower = sub_opt_problems[i + 1].design_space
            if (
                ds_mapper.design_space_in is not ds_lower
                or ds_mapper.design_space_out is not ds_upper
            ):
                msg = (
                    "design spaces not correctly linked with "
                    f"evaluation mappers: level {i:d}"
                )
                raise ConsistencyError(msg)

    def _check_inputs(
        self,
        sub_opt_problems: Iterable[OptimizationProblem],
        eval_mappers: Iterable[EvaluationMapper],
        ref_databases: Iterable[Database],
    ) -> None:
        """Check the types, the lengths and the design space links of the inputs."""
        self._check_inputs_types(sub_opt_problems, eval_mappers, ref_databases)
        self._check_inputs_len(sub_opt_problems, eval_mappers, ref_databases)
        self._check_ds_consistency(sub_opt_problems, eval_mappers)

    @property
    def n_levels(self) -> int:
        """The number of fidelity levels."""
        return len(self.sub_opt_problems)

    @property
    def run_exec(self) -> Callable:
        """The execution callable.

        Raises:
            RuntimeError: If it has not been set.
        """
        if self._run_exec is None:
            msg = "Run exec not set"
            raise RuntimeError(msg)
        return self._run_exec

    @run_exec.setter
    def run_exec(self, func: Callable) -> None:
        self._run_exec = func

    def _check_level(self, level: int) -> None:
        """Check that ``level`` indexes an existing sub-problem.

        Raises:
            ValueError: If ``level`` is out of range.
        """
        if not 0 <= level < len(self.sub_opt_problems):
            err = (
                f"level = {level:d} does not specify a valid level: "
                f"{list(range(len(self.sub_opt_problems)))}"
            )
            raise ValueError(err)

    def add_callback(self, level: int, target: str, callback: Callable) -> None:
        """Add a callback to a level.

        Args:
            level: The level.
            target: The target name.
            callback: The callback.

        Raises:
            ValueError: If the level is invalid.
            RuntimeError: If ``_add_callback`` has not been set.
        """
        self._check_level(level)
        if self._add_callback is None:
            msg = "_add_callback not set"
            raise RuntimeError(msg)
        self._add_callback(level, target, callback)

    def set_exec_option(self, level: int, option: str, value: Any) -> None:
        """Set an execution option of a level.

        Args:
            level: The level.
            option: The option name.
            value: The option value.

        Raises:
            ValueError: If the level is invalid.
            RuntimeError: If ``_set_option`` has not been set.
        """
        self._check_level(level)
        if self._set_option is None:
            msg = "_set_option not set"
            raise RuntimeError(msg)
        self._set_option(level, option, value)

    def set_exec_options(self, level: int, options_dict: dict) -> None:
        """Set several execution options of a level.

        Args:
            level: The level.
            options_dict: The option names bound to their values.
        """
        for name, value in options_dict.items():
            self.set_exec_option(level, name, value)

    def to_hdf(
        self, file_path: Path | str, append: bool = False, hdf_node_path: str = ""
    ) -> None:
        """Export to HDF.

        Exports the multi-fidelity optimization problem to HDF file.
        The high-fidelity problem can still be imported as a standard optimization
        problem using this file.

        Args:
            file_path: The path of the file to store the data.
            append: If ``True``, data is appended to the file if not empty.
            hdf_node_path: The path of the HDF node in which the problem should be
                exported, empty to select the root node.
        """
        LOGGER.info("Export multi-fidelity optimization problem to file: %s", file_path)
        # The high-fidelity problem is stored at the node itself so that it can be
        # read back as a standard optimization problem.
        self.sub_opt_problems[0].to_hdf(
            file_path, append=append, hdf_node_path=hdf_node_path
        )

        root_path = f"{hdf_node_path}/" if hdf_node_path else ""
        opt_node_paths = []
        dtb_node_paths = []

        # Always open in append mode: exporting the high-fidelity problem has
        # already cleaned the node when not in append mode.
        with h5py.File(file_path, mode="a") as h5file:
            hdf_node = get_hdf5_group(h5file, hdf_node_path)
            for level in range(1, self.n_levels):
                opt_name = f"{self.MF_PROB_GROUP}_{level:d}"
                dtb_name = f"{self.MF_DTB_GROUP}_{level:d}"
                f_map_name = f"{self.MF_F_MAP_GROUP}_{level - 1:d}"
                hdf_node.require_group(opt_name)
                hdf_node.require_group(dtb_name)
                f_map_grp = hdf_node.require_group(f_map_name)

                mapping = self.eval_mappers[level - 1].output_mapping
                for f_in, f_out in mapping.items():
                    f_out_array = array(f_out)
                    self.__store_h5data(
                        f_map_grp, f_out_array, f_in, f_out_array.dtype
                    )

                # Full paths, as each object below reopens the file to write.
                opt_node_paths.append(root_path + opt_name)
                dtb_node_paths.append(root_path + dtb_name)

        for opt_prob, ref_dtb, opt_path, dtb_path in zip(
            self.sub_opt_problems[1:],
            self.ref_databases[1:],
            opt_node_paths,
            dtb_node_paths,
            strict=False,
        ):
            opt_prob.to_hdf(file_path, append=True, hdf_node_path=opt_path)
            ref_dtb.to_hdf(file_path=file_path, append=True, hdf_node_path=dtb_path)

    @classmethod
    def from_hdf(
        cls,
        file_path: str | Path,
        x_tolerance: float = 1e-10,
        hdf_node_path: str | None = None,
    ) -> Self:
        """Import a multi-fidelity optimization problem from an HDF file.

        Args:
            file_path: The path of the file to read.
            x_tolerance: The tolerance to read a point from the database.
            hdf_node_path: The path of the HDF node from which the problem should be
                imported, ``None`` or empty to specify the root node.

        Returns:
            The read multi-fidelity optimization problem. Its evaluation mappers
            only provide the output mapping; they cannot map anything.

        Raises:
            ConsistencyError: If the numbers of stored problems, databases and
                function mappings differ.
        """
        LOGGER.info(
            "Import multi-fidelity optimization problem from file: %s", file_path
        )
        root_path = f"{hdf_node_path}/" if hdf_node_path else ""

        with h5py.File(file_path, mode="r") as h5file:
            hdf_node = get_hdf5_group(h5file, hdf_node_path)

            keys = list(hdf_node)
            n_sub_lvl = sum(key.startswith(cls.MF_PROB_GROUP) for key in keys)
            n_ref_dtb = sum(key.startswith(cls.MF_DTB_GROUP) for key in keys)
            n_f_map = sum(key.startswith(cls.MF_F_MAP_GROUP) for key in keys)
            if n_sub_lvl != n_ref_dtb or n_sub_lvl != n_f_map:
                msg = "Mismatch between the number of data stored for each level"
                raise ConsistencyError(msg)

            hifi_problem = OptimizationProblem.from_hdf(
                file_path, x_tolerance=x_tolerance, hdf_node_path=hdf_node_path
            )
            sub_opt_problems = [hifi_problem]
            # The high-fidelity reference database is the problem's own database.
            ref_databases = [hifi_problem.database]
            func_mappings = []

            for i in range(n_sub_lvl):
                opt_path = f"{root_path}{cls.MF_PROB_GROUP}_{i + 1:d}"
                dtb_path = f"{root_path}{cls.MF_DTB_GROUP}_{i + 1:d}"
                f_map_name = f"{cls.MF_F_MAP_GROUP}_{i:d}"

                sub_opt_problems.append(
                    OptimizationProblem.from_hdf(
                        file_path, x_tolerance=x_tolerance, hdf_node_path=opt_path
                    )
                )
                ref_databases.append(
                    Database.from_hdf(file_path=file_path, hdf_node_path=dtb_path)
                )

                # Cannot be missing: the number of mappings was checked above.
                f_map_grp = hdf_node[f_map_name]
                # Read the stored value of each dataset, not the Dataset object.
                func_mappings.append({
                    str(f_in): cls._read_h5_str(f_map_grp[f_in]) for f_in in f_map_grp
                })

        # Limited evaluation mappers: they only provide the output mapping.
        eval_mappers = [
            EvaluationMapper(
                DesignSpaceMapper(lower.design_space, upper.design_space),
                func_mapping,
            )
            for func_mapping, upper, lower in zip(
                func_mappings, sub_opt_problems[:-1], sub_opt_problems[1:], strict=False
            )
        ]
        return cls(sub_opt_problems, eval_mappers, ref_databases)

    @staticmethod
    def _read_h5_str(dataset: Any) -> str:
        """Read a scalar string stored in an HDF5 dataset.

        Args:
            dataset: The dataset (anything supporting ``dataset[()]``).

        Returns:
            The stored value as a ``str``; bytes are decoded.
        """
        value = dataset[()]
        if isinstance(value, ndarray) and value.size == 1:
            value = value.item()
        if isinstance(value, bytes):
            return value.decode()
        return str(value)

    @staticmethod
    def __store_h5data(
        group: Any, data_array: NDArray, dataset_name: str, dtype: str | None = None
    ) -> None:
        """Store an array in a HDF5 file group, replacing any existing dataset.

        Args:
            group: The group pointer.
            data_array: The data to be stored.
            dataset_name: The name of the dataset to store the array.
            dtype: Ignored. The HDF5 dtype is derived from ``data_array``:
                variable-length strings for non-numeric arrays and iterables,
                otherwise inferred by h5py.
        """
        h5_dtype = None
        is_numeric_array = isinstance(data_array, ndarray) and issubdtype(
            data_array.dtype, np_number
        )
        attr = data_array
        if isinstance(data_array, str):
            attr = data_array.encode("ascii", "ignore")
        elif isinstance(data_array, bytes):
            attr = data_array.decode()
        elif isinstance(data_array, ndarray) and not is_numeric_array:
            attr = str(data_array)
            h5_dtype = h5py.special_dtype(vlen=str)
        elif hasattr(data_array, "__iter__") and not is_numeric_array:
            attr = [
                att.encode("ascii", "ignore") if isinstance(att, str) else att
                for att in data_array
            ]
            h5_dtype = h5py.special_dtype(vlen=str)

        if dataset_name in group:
            del group[dataset_name]
        group.create_dataset(dataset_name, data=attr, dtype=h5_dtype)