gemseo-multi-fidelity 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. gemseo_multi_fidelity/__init__.py +17 -0
  2. gemseo_multi_fidelity/core/MFMapperAdapter_input.json +22 -0
  3. gemseo_multi_fidelity/core/MFMapperAdapter_output.json +22 -0
  4. gemseo_multi_fidelity/core/MFMapperLinker_input.json +22 -0
  5. gemseo_multi_fidelity/core/MFMapperLinker_output.json +22 -0
  6. gemseo_multi_fidelity/core/MFScenarioAdapter_input.json +39 -0
  7. gemseo_multi_fidelity/core/MFScenarioAdapter_output.json +23 -0
  8. gemseo_multi_fidelity/core/__init__.py +16 -0
  9. gemseo_multi_fidelity/core/boxed_domain.py +242 -0
  10. gemseo_multi_fidelity/core/corr_function.py +411 -0
  11. gemseo_multi_fidelity/core/criticality.py +124 -0
  12. gemseo_multi_fidelity/core/ds_mapper.py +307 -0
  13. gemseo_multi_fidelity/core/errors.py +42 -0
  14. gemseo_multi_fidelity/core/eval_mapper.py +188 -0
  15. gemseo_multi_fidelity/core/id_mapper_adapter.py +61 -0
  16. gemseo_multi_fidelity/core/mapper_adapter.py +126 -0
  17. gemseo_multi_fidelity/core/mapper_linker.py +72 -0
  18. gemseo_multi_fidelity/core/mf_formulation.py +635 -0
  19. gemseo_multi_fidelity/core/mf_logger.py +216 -0
  20. gemseo_multi_fidelity/core/mf_opt_problem.py +480 -0
  21. gemseo_multi_fidelity/core/mf_scenario.py +205 -0
  22. gemseo_multi_fidelity/core/noise_criterion.py +94 -0
  23. gemseo_multi_fidelity/core/projpolytope.out +0 -0
  24. gemseo_multi_fidelity/core/scenario_adapter.py +568 -0
  25. gemseo_multi_fidelity/core/stop_criteria.py +201 -0
  26. gemseo_multi_fidelity/core/strict_chain.py +75 -0
  27. gemseo_multi_fidelity/core/utils_model_quality.py +74 -0
  28. gemseo_multi_fidelity/corrections/__init__.py +16 -0
  29. gemseo_multi_fidelity/corrections/add_corr_function.py +80 -0
  30. gemseo_multi_fidelity/corrections/correction_factory.py +65 -0
  31. gemseo_multi_fidelity/corrections/mul_corr_function.py +86 -0
  32. gemseo_multi_fidelity/drivers/__init__.py +16 -0
  33. gemseo_multi_fidelity/drivers/mf_algo_factory.py +38 -0
  34. gemseo_multi_fidelity/drivers/mf_driver_lib.py +462 -0
  35. gemseo_multi_fidelity/drivers/refinement.py +234 -0
  36. gemseo_multi_fidelity/drivers/settings/__init__.py +16 -0
  37. gemseo_multi_fidelity/drivers/settings/base_mf_driver_settings.py +59 -0
  38. gemseo_multi_fidelity/drivers/settings/mf_refine_settings.py +50 -0
  39. gemseo_multi_fidelity/formulations/__init__.py +16 -0
  40. gemseo_multi_fidelity/formulations/refinement.py +144 -0
  41. gemseo_multi_fidelity/mapping/__init__.py +16 -0
  42. gemseo_multi_fidelity/mapping/identity_mapper.py +74 -0
  43. gemseo_multi_fidelity/mapping/interp_mapper.py +422 -0
  44. gemseo_multi_fidelity/mapping/mapper_factory.py +70 -0
  45. gemseo_multi_fidelity/mapping/mapping_errors.py +46 -0
  46. gemseo_multi_fidelity/mapping/subset_mapper.py +122 -0
  47. gemseo_multi_fidelity/mf_rosenbrock/__init__.py +16 -0
  48. gemseo_multi_fidelity/mf_rosenbrock/delayed_disc.py +136 -0
  49. gemseo_multi_fidelity/mf_rosenbrock/refact_rosen_testcase.py +46 -0
  50. gemseo_multi_fidelity/mf_rosenbrock/rosen_mf_case.py +284 -0
  51. gemseo_multi_fidelity/mf_rosenbrock/rosen_mf_funcs.py +350 -0
  52. gemseo_multi_fidelity/models/__init__.py +16 -0
  53. gemseo_multi_fidelity/models/fake_updater.py +112 -0
  54. gemseo_multi_fidelity/models/model_updater.py +91 -0
  55. gemseo_multi_fidelity/models/rbf/__init__.py +16 -0
  56. gemseo_multi_fidelity/models/rbf/kernel_factory.py +66 -0
  57. gemseo_multi_fidelity/models/rbf/kernels/__init__.py +16 -0
  58. gemseo_multi_fidelity/models/rbf/kernels/gaussian.py +93 -0
  59. gemseo_multi_fidelity/models/rbf/kernels/matern_3_2.py +101 -0
  60. gemseo_multi_fidelity/models/rbf/kernels/matern_5_2.py +101 -0
  61. gemseo_multi_fidelity/models/rbf/kernels/rbf_kernel.py +172 -0
  62. gemseo_multi_fidelity/models/rbf/rbf_model.py +422 -0
  63. gemseo_multi_fidelity/models/sparse_rbf_updater.py +96 -0
  64. gemseo_multi_fidelity/models/taylor/__init__.py +16 -0
  65. gemseo_multi_fidelity/models/taylor/taylor.py +212 -0
  66. gemseo_multi_fidelity/models/taylor_updater.py +66 -0
  67. gemseo_multi_fidelity/models/updater_factory.py +62 -0
  68. gemseo_multi_fidelity/settings/__init__.py +16 -0
  69. gemseo_multi_fidelity/settings/drivers.py +22 -0
  70. gemseo_multi_fidelity/settings/formulations.py +16 -0
  71. gemseo_multi_fidelity-0.0.1.dist-info/METADATA +99 -0
  72. gemseo_multi_fidelity-0.0.1.dist-info/RECORD +76 -0
  73. gemseo_multi_fidelity-0.0.1.dist-info/WHEEL +5 -0
  74. gemseo_multi_fidelity-0.0.1.dist-info/entry_points.txt +2 -0
  75. gemseo_multi_fidelity-0.0.1.dist-info/licenses/LICENSE.txt +165 -0
  76. gemseo_multi_fidelity-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,422 @@
1
+ # Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
2
+ #
3
+ # This program is free software; you can redistribute it and/or
4
+ # modify it under the terms of the GNU Lesser General Public
5
+ # License version 3 as published by the Free Software Foundation.
6
+ #
7
+ # This program is distributed in the hope that it will be useful,
8
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
9
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10
+ # Lesser General Public License for more details.
11
+ #
12
+ # You should have received a copy of the GNU Lesser General Public License
13
+ # along with this program; if not, write to the Free Software Foundation,
14
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15
+
16
+ # Copyright (c) 2018 AIRBUS OPERATIONS
17
+
18
+ #
19
+ # Contributors:
20
+ # INITIAL AUTHORS - API and implementation and/or documentation
21
+ # :author: Romain Olivanti
22
+ # OTHER AUTHORS - MACROSCOPIC CHANGES
23
+ """RBF Model."""
24
+
25
+ from __future__ import annotations
26
+
27
+ from typing import TYPE_CHECKING
28
+ from typing import Any
29
+
30
+ from numpy import asarray
31
+ from numpy import atleast_1d
32
+ from numpy import atleast_2d
33
+ from numpy import concatenate
34
+ from numpy import einsum
35
+ from numpy import empty
36
+ from numpy import inner
37
+ from numpy import split
38
+ from numpy import zeros
39
+ from scipy.linalg import pinvh
40
+ from scipy.linalg import solve
41
+
42
+ from gemseo_multi_fidelity.models.rbf.kernel_factory import KernelFactory
43
+
44
+ if TYPE_CHECKING:
45
+ from numpy.typing import NDArray
46
+
47
+
48
+ class RBFModel:
49
+ """RBF Model with gradient support."""
50
+
51
+ def __init__(self, kernel: str = "gaussian", scaling: float = 1.0) -> None:
52
+ """Constructor.
53
+
54
+ Args:
55
+ kernel: The type of kernel to use.
56
+ scaling: The scaling for the kernel.
57
+ """
58
+ self._kernel_factory = KernelFactory()
59
+ self._kernel = None
60
+ self._centers = None
61
+ self._weights = None
62
+ self._grad_weights = None
63
+
64
+ self._use_mem_eff_train = False
65
+ self._use_mem_eff_eval = False
66
+ self._mem_eff_check_overridden = False
67
+
68
+ self.set_kernel(kernel)
69
+ self.set_scaling(scaling)
70
+
71
+ def set_kernel(self, kernel):
72
+ """Set the type of RBF kernel used.
73
+
74
+ Args:
75
+ kernel: The type of RBF kernel.
76
+ """
77
+ self._kernel = self._kernel_factory.create(kernel)
78
+
79
+ def set_scaling(self, scaling) -> None:
80
+ """Set the scaling parameter of the RBF kernel.
81
+
82
+ Args:
83
+ scaling: The scaling.
84
+ """
85
+ self._kernel.set_scaling(scaling)
86
+
87
+ def override_eff_check(
88
+ self, mem_eff_train: bool = True, mem_eff_eval: bool = False
89
+ ) -> None:
90
+ """Override check for memory efficient routines.
91
+
92
+ Overrides the check used to choose whether to use 'memory efficient' routines
93
+ for training and evaluating. Trade-off between speed and memory usage.
94
+
95
+ Args:
96
+ mem_eff_train: ``True`` to use memory efficient routines for training.
97
+ mem_eff_eval: ``True`` to use memory efficient routines for evaluating.
98
+ """
99
+ self._use_mem_eff_train = bool(mem_eff_train)
100
+ self._use_mem_eff_eval = bool(mem_eff_eval)
101
+ self._mem_eff_check_overridden = True
102
+
103
+ def train(self, centers: NDArray, values: NDArray, gradients: NDArray = None):
104
+ """Train the model using the specified inputs, outputs.
105
+
106
+ Gradient outputs are optional. Fewer gradient outputs than standard outputs can
107
+ be provided.
108
+ They will be associated with the first centers provided.
109
+
110
+ Args:
111
+ centers: The training inputs.
112
+ values: The training outputs.
113
+ gradients: The training gradient outputs.
114
+ """
115
+ self._centers = atleast_2d(centers).copy()
116
+ values = atleast_1d(values)
117
+ dim = self.get_dim()
118
+ n_centers = self.get_n_centers()
119
+ is_gradient_enhanced = gradients is not None
120
+
121
+ if len(values) != n_centers:
122
+ msg = "One value must be provided for each center"
123
+ raise ValueError(msg)
124
+ if is_gradient_enhanced:
125
+ gradients = atleast_2d(gradients)
126
+ if gradients.shape[1] != dim:
127
+ msg = (
128
+ "Gradients dim inconsistent with centers: "
129
+ f"{gradients.shape[1]:d} != {dim:d}"
130
+ )
131
+ raise ValueError(msg)
132
+ if gradients.shape[0] > n_centers:
133
+ msg = "Cannot provide more gradients than centers"
134
+ raise ValueError(msg)
135
+
136
+ if not self._mem_eff_check_overridden:
137
+ # TODO
138
+ # Choose whether to use memory efficient routines
139
+ self._use_mem_eff_eval = False
140
+ self._use_mem_eff_train = True
141
+
142
+ if is_gradient_enhanced and self._use_mem_eff_train:
143
+ train_method = self._compute_weights_mem_eff
144
+ else:
145
+ train_method = self._compute_weights_dir
146
+ self._weights, self._grad_weights = train_method(values, gradients)
147
+
148
+ def get_n_centers(self) -> int:
149
+ """Return the number of centers used for the model.
150
+
151
+ Returns:
152
+ The number of centers.
153
+ """
154
+ return self._centers.shape[0] if self._centers is not None else 0
155
+
156
+ def get_dim(self) -> int:
157
+ """Return the dimension of the input space.
158
+
159
+ Returns:
160
+ The dimension of the input space.
161
+ """
162
+ return self._centers.shape[1] if self._centers is not None else None
163
+
164
+ def _asbl_sqr_fun(self) -> NDArray:
165
+ """Assemble the squared matrix (fun-fun) of the RBF system.
166
+
167
+ Returns:
168
+ The 2x2 matrix.
169
+ """
170
+ # Initialize the RBF matrix
171
+ mat = empty((self.get_n_centers(),) * 2)
172
+
173
+ for i, center in enumerate(self._centers):
174
+ mat[i, :] = self._kernel(center - self._centers)
175
+ return mat
176
+
177
+ def _asbl_rect_grad(self, n_grads: int) -> NDArray:
178
+ """Assemble the rectangular matrix (grad-fun) of the RBF system.
179
+
180
+ Args:
181
+ n_grads: The number of centers linked to a gradient output.
182
+
183
+ Returns:
184
+ The 2x2 matrix.
185
+ """
186
+ mat = empty((n_grads * self.get_dim(), self.get_n_centers()))
187
+ grad_centers = self._centers[0:n_grads]
188
+
189
+ for i, center in enumerate(self._centers):
190
+ mat[:, i] = self._kernel.grad(grad_centers - center).flatten()
191
+ return mat
192
+
193
+ def _asbl_sqr_grad(self, n_grads: int) -> NDArray:
194
+ """Assemble the squared matrix (grad-grad) of the RBF system.
195
+
196
+ Args:
197
+ n_grads: The number of centers linked to a gradient output.
198
+
199
+ Returns:
200
+ The 2x2 matrix.
201
+ """
202
+ dim = self.get_dim()
203
+ mat_size = n_grads * dim
204
+ mat = zeros((mat_size,) * 2)
205
+ grad_centers = self._centers[0:n_grads]
206
+
207
+ for i, center in enumerate(grad_centers):
208
+ start_i = i * dim
209
+ end_i = start_i + dim
210
+ mat[start_i:end_i, :] = (
211
+ self._kernel
212
+ .hess(center - grad_centers)
213
+ .swapaxes(0, 1)
214
+ .reshape((dim, mat_size))
215
+ )
216
+ return mat
217
+
218
+ def _asbl_rbf_mat(self, n_grads: int) -> NDArray:
219
+ """Convenience method to directly solve the RBF system.
220
+
221
+ Assembles the full RBF matrix.
222
+
223
+ Args:
224
+ n_grads: The number of centers linked to a gradient output.
225
+
226
+ Returns:
227
+ The 2x2 matrix.
228
+ """
229
+ n_centers = self.get_n_centers()
230
+ dim = self.get_dim()
231
+ mat_size = n_centers + dim * n_grads
232
+
233
+ # Initialize the RBF matrix
234
+ mat = empty((mat_size,) * 2)
235
+ mat[0:n_centers, 0:n_centers] = self._asbl_sqr_fun()
236
+ if n_grads > 0:
237
+ mat[n_centers::, 0:n_centers] = self._asbl_rect_grad(n_grads)
238
+ mat[0:n_centers, n_centers::] = -mat[n_centers::, 0:n_centers].T
239
+ mat[n_centers::, n_centers::] = self._asbl_sqr_grad(n_grads)
240
+ return mat
241
+
242
+ def _compute_weights_dir(
243
+ self, values: NDArray, gradients: NDArray
244
+ ) -> tuple[NDArray, NDArray | None]:
245
+ """Compute the RBF weights by directly solving the RBF system.
246
+
247
+ Args:
248
+ values: The outputs to match (1D ndarray).
249
+ gradients: The gradient outputs to match (2D ndarray).
250
+
251
+ Returns:
252
+ The list (weights, grad_weights).
253
+ """
254
+ # Symmetric, positive definite only if not gradient enhanced
255
+ is_grad_enhanced = gradients is not None
256
+ if is_grad_enhanced:
257
+ inputs = concatenate([values, gradients.flatten()])
258
+ n_grads = gradients.shape[0]
259
+ else:
260
+ inputs = values
261
+ n_grads = 0
262
+
263
+ res = solve(
264
+ self._asbl_rbf_mat(n_grads),
265
+ inputs,
266
+ assume_a="gen" if is_grad_enhanced else "pos",
267
+ overwrite_a=True,
268
+ overwrite_b=is_grad_enhanced,
269
+ check_finite=False,
270
+ )
271
+
272
+ if is_grad_enhanced:
273
+ # Split the weights
274
+ weights, grad_weights = split(res, [self.get_n_centers()])
275
+ grad_weights = grad_weights.reshape((n_grads, self.get_dim()))
276
+ else:
277
+ weights = res
278
+ grad_weights = None
279
+ return weights, grad_weights
280
+
281
+ def _compute_weights_mem_eff(self, values, gradients) -> tuple[Any, Any]:
282
+ """Compute the RBF weights by benefiting from the structure of the problem.
283
+
284
+ Expected to be more memory efficient when the dimension is high and when several
285
+ gradients are provided.
286
+
287
+ Args:
288
+ values: The outputs to match (1D ndarray).
289
+ gradients: The gradient outputs to match (2D ndarray).
290
+
291
+ Returns:
292
+ The list (weights, grad_weights).
293
+ """
294
+ # Could be optimized further by using sparse representations
295
+ n_grads = gradients.shape[0]
296
+ # Compute the inverse of sqr_fun using the pinvh routine
297
+ # This will use the fact that sqr_fun i symmetric positive definite
298
+ inv_fun = pinvh(
299
+ self._asbl_sqr_fun(), lower=True, return_rank=False, check_finite=False
300
+ )
301
+ # Compute first temp term of fun_weights
302
+ fun_weights = inv_fun.dot(values)
303
+ # Assemble the rectangular gradient matrix
304
+ rect_grad = self._asbl_rect_grad(n_grads)
305
+ tmp_mixed_inputs = gradients.flatten() - rect_grad.dot(fun_weights)
306
+ prod_inv_grad = inv_fun.dot(-rect_grad.T)
307
+ del inv_fun
308
+ tmp_sqr = self._asbl_sqr_grad(n_grads) - rect_grad.dot(prod_inv_grad)
309
+ del rect_grad
310
+ # Symmetrical -> lower = True
311
+
312
+ grad_weights = solve(
313
+ tmp_sqr,
314
+ tmp_mixed_inputs,
315
+ lower=True,
316
+ overwrite_a=True,
317
+ overwrite_b=True,
318
+ check_finite=False,
319
+ )
320
+ del tmp_sqr
321
+ del tmp_mixed_inputs
322
+ fun_weights -= prod_inv_grad.dot(grad_weights)
323
+ return fun_weights, grad_weights.reshape((n_grads, self.get_dim()))
324
+
325
+ def __call__(self, x_vect: NDArray) -> float:
326
+ """Evaluate the model at x_vect.
327
+
328
+ Args:
329
+ x_vect: The point at which the model is evaluated (1d ndarray).
330
+
331
+ Returns:
332
+ The prediction.
333
+ """
334
+ diff = asarray(x_vect) - self._centers
335
+ if self._use_mem_eff_eval:
336
+ call_method = self._compute_val_mem_eff
337
+ else:
338
+ call_method = self._compute_val_vect
339
+ return call_method(diff)
340
+
341
+ def _compute_val_vect(self, diff: NDArray) -> Any:
342
+ """Compute the prediction of the model at x_vect. (Vectorized version).
343
+
344
+ Args:
345
+ diff: The kernel inputs (x_vect - centers) (2D ndarray).
346
+
347
+ Returns:
348
+ The prediction.
349
+ """
350
+ ret = inner(self._weights, self._kernel(diff))
351
+ if self._grad_weights is not None:
352
+ ret += einsum("ij,ij", self._grad_weights, self._kernel.grad(diff))
353
+ return ret
354
+
355
+ def _compute_val_mem_eff(self, diff: NDArray) -> float:
356
+ """Compute the prediction of the model at x_vect. (Less overhead expected).
357
+
358
+ Args:
359
+ diff: The kernel inputs (x_vect - centers) (2d ndarray).
360
+
361
+ Returns:
362
+ The prediction.
363
+ """
364
+ ret = 0.0
365
+ for weight, diff_i in zip(self._weights, diff, strict=False):
366
+ ret += weight * self._kernel(diff_i)
367
+ if self._grad_weights is not None:
368
+ for weight, diff_i in zip(
369
+ self._grad_weights, diff[0 : len(self._grad_weights)], strict=False
370
+ ):
371
+ ret += inner(weight, self._kernel.grad(diff_i))
372
+ return ret
373
+
374
+ def grad(self, x_vect: NDArray):
375
+ """Evaluate the gradient of the model at x_vect.
376
+
377
+ Args:
378
+ x_vect: The point at which the gradient of the model is evaluated
379
+ (1D ndarray).
380
+
381
+ Returns:
382
+ The prediction of the gradient.
383
+ """
384
+ diff = asarray(x_vect) - self._centers
385
+ if self._use_mem_eff_eval:
386
+ grad_method = self._compute_grad_mem_eff
387
+ else:
388
+ grad_method = self._compute_grad_vect
389
+ return grad_method(diff)
390
+
391
+ def _compute_grad_vect(self, diff: NDArray) -> Any:
392
+ """Compute the gradient of the model (Vectorized version).
393
+
394
+ Args:
395
+ diff: The kernel inputs (x_vect - centers) (2D ndarray).
396
+
397
+ Returns:
398
+ The prediction of the gradient.
399
+ """
400
+ ret = einsum("i,ij->j", self._weights, self._kernel.grad(diff))
401
+ if self._grad_weights is not None:
402
+ ret += einsum("ijk,ik->j", self._kernel.hess(diff), self._grad_weights)
403
+ return ret
404
+
405
+ def _compute_grad_mem_eff(self, diff) -> NDArray:
406
+ """Compute the gradient of the model (Less overhead expected).
407
+
408
+ Args:
409
+ diff: The kernel inputs (x_vect - centers) (2D ndarray).
410
+
411
+ Returns:
412
+ The prediction of the gradient.
413
+ """
414
+ ret = zeros(self.get_dim())
415
+ for weight, diff_i in zip(self._weights, diff, strict=False):
416
+ ret += weight * self._kernel.grad(diff_i)
417
+ if self._grad_weights is not None:
418
+ for weight, diff_i in zip(
419
+ self._grad_weights, diff[0 : len(self._grad_weights)], strict=False
420
+ ):
421
+ ret += self._kernel.hess(diff_i).dot(weight)
422
+ return ret
@@ -0,0 +1,96 @@
1
+ # Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
2
+ #
3
+ # This program is free software; you can redistribute it and/or
4
+ # modify it under the terms of the GNU Lesser General Public
5
+ # License version 3 as published by the Free Software Foundation.
6
+ #
7
+ # This program is distributed in the hope that it will be useful,
8
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
9
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10
+ # Lesser General Public License for more details.
11
+ #
12
+ # You should have received a copy of the GNU Lesser General Public License
13
+ # along with this program; if not, write to the Free Software Foundation,
14
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15
+
16
+ # Copyright (c) 2018 AIRBUS OPERATIONS
17
+
18
+ #
19
+ # Contributors:
20
+ # INITIAL AUTHORS - API and implementation and/or documentation
21
+ # :author: Romain Olivanti
22
+ # OTHER AUTHORS - MACROSCOPIC CHANGES
23
+ """Sparse RBF updater."""
24
+
25
+ from __future__ import annotations
26
+
27
+ import logging
28
+ from typing import TYPE_CHECKING
29
+
30
+ from gemseo.core.mdo_functions.mdo_function import MDOFunction
31
+
32
+ from gemseo_multi_fidelity.models.model_updater import ModelUpdater
33
+ from gemseo_multi_fidelity.models.rbf.rbf_model import RBFModel
34
+
35
+ if TYPE_CHECKING:
36
+ from numpy.typing import NDArray
37
+
38
+ LOGGER = logging.getLogger(__name__)
39
+
40
+
41
class SparseRBFUpdater(ModelUpdater):
    """Updater for the Sparse RBF model.

    Trains an :class:`RBFModel` on at most a fixed number of the most recent
    values and gradients, and exposes it as an :class:`MDOFunction`.
    """

    HANDLES_MULTI_UPDATE = True

    # Names of the keyword arguments accepted by the constructor.
    N_FUNC_MAX = "rbf_max_val"
    N_GRAD_MAX = "rbf_max_grad"
    KERNEL = "rbf_kernel"
    SCALING = "rbf_scaling"

    def __init__(self, **kwargs) -> None:
        """Constructor.

        Args:
            **kwargs: The keyword arguments:
                ``rbf_max_val``: the maximum number of value points (default: 20),
                ``rbf_max_grad``: the maximum number of gradient points (default: 20),
                ``rbf_kernel``: the kernel name (default: ``"matern_5_2"``),
                ``rbf_scaling``: the kernel scaling (default: 1.0).

        Raises:
            ValueError: If ``rbf_max_val`` or ``rbf_max_grad`` is lower than 1.
        """
        n_func_max = int(kwargs.get(self.N_FUNC_MAX, 20))
        n_grad_max = int(kwargs.get(self.N_GRAD_MAX, 20))
        kernel = kwargs.get(self.KERNEL, "matern_5_2")
        scaling = kwargs.get(self.SCALING, 1.0)

        if n_func_max < 1:
            msg = f"{self.N_FUNC_MAX} must be > 0"
            raise ValueError(msg)
        if n_grad_max < 1:
            msg = f"{self.N_GRAD_MAX} must be > 0"
            raise ValueError(msg)

        self._n_func_max = n_func_max
        self._n_grad_max = n_grad_max

        self._model = RBFModel(kernel=kernel, scaling=scaling)
        super().__init__()

    def _set_function(self) -> None:
        """Expose the RBF model (value and gradient) as ``self.function``."""
        self.function = MDOFunction(
            self._model, "RBF", jac=self._model.grad, input_names=["x"], dim=1
        )

    def _update(
        self,
        vects: NDArray,
        vals: NDArray,
        grads: NDArray | None = None,
        hess: NDArray | None = None,
    ) -> None:
        """Retrain the RBF model on a capped number of points.

        Args:
            vects: The input points.
            vals: The output values.
            grads: The output gradients, if any.
            hess: The output Hessians; unused by this updater.
        """
        # Cap the numbers of value and gradient points to the configured maxima;
        # gradients are associated with the first points kept.
        n_func_points = min(len(vects), self._n_func_max)
        msg = f"Sparse RBF model relying on {n_func_points:d} value(s)"
        if grads is not None:
            n_grad_points = min(n_func_points, self._n_grad_max)
            gradients = grads[0:n_grad_points]
            msg += f" and {n_grad_points:d} gradient(s)"
        else:
            gradients = None
        LOGGER.info(msg)

        # TODO cross validation
        self._model.train(
            vects[0:n_func_points], vals[0:n_func_points], gradients=gradients
        )
@@ -0,0 +1,16 @@
1
+ # Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
2
+ #
3
+ # This program is free software; you can redistribute it and/or
4
+ # modify it under the terms of the GNU Lesser General Public
5
+ # License version 3 as published by the Free Software Foundation.
6
+ #
7
+ # This program is distributed in the hope that it will be useful,
8
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
9
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10
+ # Lesser General Public License for more details.
11
+ #
12
+ # You should have received a copy of the GNU Lesser General Public License
13
+ # along with this program; if not, write to the Free Software Foundation,
14
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15
+
16
+ """Taylor model."""