gemseo-multi-fidelity 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. gemseo_multi_fidelity/__init__.py +17 -0
  2. gemseo_multi_fidelity/core/MFMapperAdapter_input.json +22 -0
  3. gemseo_multi_fidelity/core/MFMapperAdapter_output.json +22 -0
  4. gemseo_multi_fidelity/core/MFMapperLinker_input.json +22 -0
  5. gemseo_multi_fidelity/core/MFMapperLinker_output.json +22 -0
  6. gemseo_multi_fidelity/core/MFScenarioAdapter_input.json +39 -0
  7. gemseo_multi_fidelity/core/MFScenarioAdapter_output.json +23 -0
  8. gemseo_multi_fidelity/core/__init__.py +16 -0
  9. gemseo_multi_fidelity/core/boxed_domain.py +242 -0
  10. gemseo_multi_fidelity/core/corr_function.py +411 -0
  11. gemseo_multi_fidelity/core/criticality.py +124 -0
  12. gemseo_multi_fidelity/core/ds_mapper.py +307 -0
  13. gemseo_multi_fidelity/core/errors.py +42 -0
  14. gemseo_multi_fidelity/core/eval_mapper.py +188 -0
  15. gemseo_multi_fidelity/core/id_mapper_adapter.py +61 -0
  16. gemseo_multi_fidelity/core/mapper_adapter.py +126 -0
  17. gemseo_multi_fidelity/core/mapper_linker.py +72 -0
  18. gemseo_multi_fidelity/core/mf_formulation.py +635 -0
  19. gemseo_multi_fidelity/core/mf_logger.py +216 -0
  20. gemseo_multi_fidelity/core/mf_opt_problem.py +480 -0
  21. gemseo_multi_fidelity/core/mf_scenario.py +205 -0
  22. gemseo_multi_fidelity/core/noise_criterion.py +94 -0
  23. gemseo_multi_fidelity/core/projpolytope.out +0 -0
  24. gemseo_multi_fidelity/core/scenario_adapter.py +568 -0
  25. gemseo_multi_fidelity/core/stop_criteria.py +201 -0
  26. gemseo_multi_fidelity/core/strict_chain.py +75 -0
  27. gemseo_multi_fidelity/core/utils_model_quality.py +74 -0
  28. gemseo_multi_fidelity/corrections/__init__.py +16 -0
  29. gemseo_multi_fidelity/corrections/add_corr_function.py +80 -0
  30. gemseo_multi_fidelity/corrections/correction_factory.py +65 -0
  31. gemseo_multi_fidelity/corrections/mul_corr_function.py +86 -0
  32. gemseo_multi_fidelity/drivers/__init__.py +16 -0
  33. gemseo_multi_fidelity/drivers/mf_algo_factory.py +38 -0
  34. gemseo_multi_fidelity/drivers/mf_driver_lib.py +462 -0
  35. gemseo_multi_fidelity/drivers/refinement.py +234 -0
  36. gemseo_multi_fidelity/drivers/settings/__init__.py +16 -0
  37. gemseo_multi_fidelity/drivers/settings/base_mf_driver_settings.py +59 -0
  38. gemseo_multi_fidelity/drivers/settings/mf_refine_settings.py +50 -0
  39. gemseo_multi_fidelity/formulations/__init__.py +16 -0
  40. gemseo_multi_fidelity/formulations/refinement.py +144 -0
  41. gemseo_multi_fidelity/mapping/__init__.py +16 -0
  42. gemseo_multi_fidelity/mapping/identity_mapper.py +74 -0
  43. gemseo_multi_fidelity/mapping/interp_mapper.py +422 -0
  44. gemseo_multi_fidelity/mapping/mapper_factory.py +70 -0
  45. gemseo_multi_fidelity/mapping/mapping_errors.py +46 -0
  46. gemseo_multi_fidelity/mapping/subset_mapper.py +122 -0
  47. gemseo_multi_fidelity/mf_rosenbrock/__init__.py +16 -0
  48. gemseo_multi_fidelity/mf_rosenbrock/delayed_disc.py +136 -0
  49. gemseo_multi_fidelity/mf_rosenbrock/refact_rosen_testcase.py +46 -0
  50. gemseo_multi_fidelity/mf_rosenbrock/rosen_mf_case.py +284 -0
  51. gemseo_multi_fidelity/mf_rosenbrock/rosen_mf_funcs.py +350 -0
  52. gemseo_multi_fidelity/models/__init__.py +16 -0
  53. gemseo_multi_fidelity/models/fake_updater.py +112 -0
  54. gemseo_multi_fidelity/models/model_updater.py +91 -0
  55. gemseo_multi_fidelity/models/rbf/__init__.py +16 -0
  56. gemseo_multi_fidelity/models/rbf/kernel_factory.py +66 -0
  57. gemseo_multi_fidelity/models/rbf/kernels/__init__.py +16 -0
  58. gemseo_multi_fidelity/models/rbf/kernels/gaussian.py +93 -0
  59. gemseo_multi_fidelity/models/rbf/kernels/matern_3_2.py +101 -0
  60. gemseo_multi_fidelity/models/rbf/kernels/matern_5_2.py +101 -0
  61. gemseo_multi_fidelity/models/rbf/kernels/rbf_kernel.py +172 -0
  62. gemseo_multi_fidelity/models/rbf/rbf_model.py +422 -0
  63. gemseo_multi_fidelity/models/sparse_rbf_updater.py +96 -0
  64. gemseo_multi_fidelity/models/taylor/__init__.py +16 -0
  65. gemseo_multi_fidelity/models/taylor/taylor.py +212 -0
  66. gemseo_multi_fidelity/models/taylor_updater.py +66 -0
  67. gemseo_multi_fidelity/models/updater_factory.py +62 -0
  68. gemseo_multi_fidelity/settings/__init__.py +16 -0
  69. gemseo_multi_fidelity/settings/drivers.py +22 -0
  70. gemseo_multi_fidelity/settings/formulations.py +16 -0
  71. gemseo_multi_fidelity-0.0.1.dist-info/METADATA +99 -0
  72. gemseo_multi_fidelity-0.0.1.dist-info/RECORD +76 -0
  73. gemseo_multi_fidelity-0.0.1.dist-info/WHEEL +5 -0
  74. gemseo_multi_fidelity-0.0.1.dist-info/entry_points.txt +2 -0
  75. gemseo_multi_fidelity-0.0.1.dist-info/licenses/LICENSE.txt +165 -0
  76. gemseo_multi_fidelity-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,411 @@
1
+ # Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
2
+ #
3
+ # This program is free software; you can redistribute it and/or
4
+ # modify it under the terms of the GNU Lesser General Public
5
+ # License version 3 as published by the Free Software Foundation.
6
+ #
7
+ # This program is distributed in the hope that it will be useful,
8
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
9
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10
+ # Lesser General Public License for more details.
11
+ #
12
+ # You should have received a copy of the GNU Lesser General Public License
13
+ # along with this program; if not, write to the Free Software Foundation,
14
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15
+
16
+ # Copyright (c) 2019 AIRBUS OPERATIONS
17
+ #
18
+ # Contributors:
19
+ # INITIAL AUTHORS - API and implementation and/or documentation
20
+ # :author: Romain Olivanti
21
+ # OTHER AUTHORS - MACROSCOPIC CHANGES
22
+ """MDOCorrected Function."""
23
+
24
+ from __future__ import annotations
25
+
26
+ from typing import TYPE_CHECKING
27
+ from typing import Any
28
+
29
+ from gemseo.algos.database import Database
30
+ from gemseo.core.mdo_functions.mdo_function import MDOFunction
31
+ from numpy import array
32
+ from numpy import delete
33
+ from numpy import flip
34
+ from numpy import where
35
+ from numpy.linalg import norm
36
+
37
+ from gemseo_multi_fidelity.models.model_updater import ModelUpdater
38
+
39
+ if TYPE_CHECKING:
40
+ from numpy import ndarray
41
+ from numpy.typing import NDArray
42
+
43
+
44
class MDOCorrectedFunction(MDOFunction):
    """MDO Corrected Function class.

    Corrected function that can be dynamically updated using the provided
    correction database. The calls to the original function are protected by this
    database (values are looked up before being computed, then stored), and the
    reference data passed to :meth:`set_reference` is stored in it as well.
    """

    # Prefix of the database entries holding the reference data.
    REF_TAG = "#CORR_REF#"

    def __init__(
        self, orig_func: MDOFunction, corr_model: ModelUpdater, corr_database: Database
    ) -> None:
        """Constructor.

        Args:
            orig_func: The uncorrected function.
            corr_model: The model used for the correction.
            corr_database: The database used for the correction.

        Raises:
            TypeError: If an argument does not have the expected type.
        """
        if not isinstance(orig_func, MDOFunction):
            msg = "orig_func must be a MDOFunction"
            raise TypeError(msg)

        if not isinstance(corr_database, Database):
            msg = "corr_database must be a Database"
            raise TypeError(msg)

        if not isinstance(corr_model, ModelUpdater):
            msg = "corr_model must be an instance of ModelUpdater"
            raise TypeError(msg)

        # Clone the attributes of the original function; the function pointers are
        # then (presumably) replaced by the subclass in _build_func_pointers.
        args = [getattr(orig_func, attr) for attr in ("func", "name")]
        kwargs_names = ["f_type", "jac", "expr", "input_names", "dim", "output_names"]
        kwargs = {attr: getattr(orig_func, attr) for attr in kwargs_names}

        super().__init__(*args, **kwargs)
        # Database used to protect the calls to the original function and to
        # store the reference data.
        self._corr_database = corr_database
        self._corr_model = corr_model
        self._orig_func = self._process_orig_func(orig_func)
        self._build_func_pointers()

    def get_orig_func(self) -> MDOFunction:
        """Get the original function.

        Returns a pointer to the protected version of the original function
        (protected by the database).

        Returns:
            The original function (callback).
        """
        return self._orig_func

    def get_orig_database(self) -> Database:
        """Return the database used to protect the data.

        Returns:
            The database.
        """
        return self._corr_database

    def _process_orig_func(self, orig_func: MDOFunction) -> MDOFunction:
        """Wrap orig_func so that its calls are protected by the correction database.

        Each evaluation (and each jacobian evaluation when orig_func has one) is
        first looked up in the database; missing entries are computed with
        orig_func and then stored in the database.

        Args:
            orig_func: The function to wrap.

        Returns:
            The protected version of orig_func.
        """
        fname = orig_func.name

        def wrapped_orig_func(x_vect: NDArray) -> NDArray | float:
            """Evaluate the original function, reusing/storing database values.

            Args:
                x_vect: The design variables.

            Returns:
                The evaluation of the function at x_vect.
            """
            value = None
            if self._corr_database.get(x_vect, False):
                # x_vect is already part of the database.
                value = self._corr_database.get_function_value(fname, x_vect)

            if value is None:
                # Not evaluated yet: evaluate, then store the new value.
                value = orig_func.func(x_vect)
                self._corr_database.store(x_vect, {fname: value})
            return value

        db_func = MDOFunction(
            wrapped_orig_func,
            name=fname,
            f_type=orig_func.f_type,
            expr=orig_func.expr,
            input_names=orig_func.input_names,
            dim=orig_func.dim,
            output_names=orig_func.output_names,
        )

        if orig_func.has_jac:
            grad_name = Database.GRAD_TAG + fname

            def dwrapped_orig_func(x_vect: Any) -> NDArray:
                """Evaluate the original jacobian, reusing/storing database values.

                Args:
                    x_vect: The design variables.

                Returns:
                    The evaluation of the jacobian at x_vect.
                """
                jac = None
                if self._corr_database.get(x_vect, False):
                    # x_vect is already part of the database.
                    jac = self._corr_database.get_function_value(grad_name, x_vect)

                if jac is None:
                    # Not evaluated yet: evaluate (real part only), then store it.
                    jac = orig_func.jac(x_vect).real
                    self._corr_database.store(x_vect, {grad_name: jac})
                return jac

            db_func.jac = dwrapped_orig_func

        return db_func

    def _build_func_pointers(self) -> None:
        """Set the corrected function. To be overloaded by subclasses."""
        raise NotImplementedError

    def _compute_correction(
        self,
        x_ref: NDArray,
        val_ref: NDArray,
        grad_ref: NDArray | None = None,
        hess_ref: NDArray | None = None,
    ) -> tuple[Any, Any, Any]:
        """Compute correction.

        Actual method to compute the correction from the reference data.
        To be overloaded by subclasses.

        Args:
            x_ref: The variables.
            val_ref: The value to match at x_ref.
            grad_ref: The gradient to match at x_ref.
            hess_ref: The hessian to match at x_ref.

        Returns:
            value correction, gradient correction, hessian correction.
        """
        raise NotImplementedError

    def set_reference(
        self,
        x_ref: NDArray,
        val_ref: Any,
        grad_ref: Any = None,
        hess_ref: Any = None,
        max_norm_thr: float | None = None,
        min_norm_thr: float = 1e-4,
    ) -> None:
        """Set reference.

        Sets the reference so that the corrected function matches the reference
        data. Previous reference data with the same order of consistency, whose
        distance to x_ref is within [min_norm_thr, max_norm_thr], is also used if
        the correction model handles multiple updates.

        Args:
            x_ref: The design variables.
            val_ref: The value to match at x_ref.
            grad_ref: The gradient to match at x_ref.
            hess_ref: The hessian to match at x_ref.
            max_norm_thr: The threshold to discard points too far from x_ref.
            min_norm_thr: The threshold to discard points too close to x_ref;
                it notably excludes x_ref itself, stored just before the search.
        """
        # Always recompute the correction with the specified reference value even
        # if already stored in the database.
        val_corr, grad_corr, hess_corr = self._compute_correction(
            x_ref, val_ref, grad_ref=grad_ref, hess_ref=hess_ref
        )

        # Storing the reference rather than the correction data allows
        # not to be correction-dependent for external uses.
        self.store_reference_data(x_ref, val_ref, grad_ref=grad_ref, hess_ref=hess_ref)

        x_update = [x_ref]
        val_update = [val_corr]
        # Gradient/hessian corrections are only passed to the model when the
        # corresponding reference data was provided.
        grad_update = [grad_corr] if grad_ref is not None else None
        hess_update = [hess_corr] if hess_ref is not None else None

        # Only perform the search if required.
        if self._corr_model.HANDLES_MULTI_UPDATE:
            # min_norm_thr is applied, so that x_ref (just stored) does not
            # appear in the additional data.
            x_list, val_list, grad_list, hess_list = self.get_reference_data(
                x_ref,
                use_grad=grad_ref is not None,
                use_hess=hess_ref is not None,
                max_norm_thr=max_norm_thr,
                min_norm_thr=min_norm_thr,
            )
            if x_list is not None:
                # Some points were found, add them to the update lists.
                for i, x_r in enumerate(x_list):
                    # Recompute the correction, which should not be too costly.
                    g_r = grad_list[i] if grad_list is not None else None
                    h_r = hess_list[i] if hess_list is not None else None
                    v_c, g_c, h_c = self._compute_correction(
                        x_r, val_list[i], grad_ref=g_r, hess_ref=h_r
                    )
                    x_update.append(x_r)
                    val_update.append(v_c)
                    # Only collect the orders requested for the reference point,
                    # otherwise the update lists would be None (or inconsistent).
                    if grad_update is not None and g_c is not None:
                        grad_update.append(g_c)
                    if hess_update is not None and h_c is not None:
                        hess_update.append(h_c)

        # Update the model.
        self._corr_model.update(
            x_update, val_update, grads=grad_update, hess=hess_update
        )

    def store_reference_data(
        self, x_ref: NDArray, val_ref: Any, grad_ref: Any = None, hess_ref: Any = None
    ) -> None:
        """Store reference data in the database.

        Args:
            x_ref: The design variables.
            val_ref: The reference value.
            grad_ref: The reference gradient.
            hess_ref: The reference hessian.
        """
        ref_name = self.REF_TAG + self._orig_func.name
        data_dict = {ref_name: val_ref}
        if grad_ref is not None:
            data_dict[Database.GRAD_TAG + ref_name] = grad_ref
        if hess_ref is not None:
            # NOTE(review): get_reference_data reads hessians under
            # Database.HESS_TAG + ref_name; confirm HESS_TAG == 2 * GRAD_TAG.
            data_dict[Database.GRAD_TAG + Database.GRAD_TAG + ref_name] = hess_ref
        self._corr_database.store(x_ref, data_dict)

    def get_reference_data(
        self,
        x_ref: NDArray,
        use_grad: bool = True,
        use_hess: bool = False,
        max_norm_thr: float | None = None,
        min_norm_thr: float = 1e-4,
    ) -> tuple[NDArray | None, list[Any] | None, list[Any] | None, list[Any] | None]:
        """Get reference data.

        Gets reference data from the database, close enough to x_ref and which
        matches the required consistency. Points too close to x_ref are discarded.

        Args:
            x_ref: The input (1D array).
            use_grad: The flag to look for data with grad corrections.
            use_hess: The flag to look for data with hess corrections.
            max_norm_thr: The threshold to discard points too far from x_ref.
            min_norm_thr: The threshold to discard points too close from x_ref.

        Returns:
            x_list (most recent first), val_list, grad_list, hess_list;
            all None when no valid data is found. grad_list (resp. hess_list)
            is None when use_grad (resp. use_hess) is False.

        Raises:
            ValueError: If use_hess is set without use_grad.
        """
        if use_hess and not use_grad:
            msg = "Must allow use_grad if use_hess is activated"
            raise ValueError(msg)

        # Only data with the highest specified consistency will be kept.
        if use_hess:
            lookup_method = self._corr_database.get_hessian_history
        elif use_grad:
            lookup_method = self._corr_database.get_gradient_history
        else:
            lookup_method = self._corr_database.get_function_history

        ref_name = self.REF_TAG + self._orig_func.name
        _, x_ref_hist = lookup_method(ref_name, with_x_vect=True)

        if len(x_ref_hist) == 0:
            # No valid data available.
            return None, None, None, None

        # Reverse the history to get the last points first, then filter it.
        x_ref_hist = self._filter_x_hist(
            flip(x_ref_hist, 0),
            x_ref,
            max_norm_thr=max_norm_thr,
            min_norm_thr=min_norm_thr,
        )

        if len(x_ref_hist) == 0:
            # No valid filtered data available.
            return None, None, None, None

        # Assemble data.
        get_data = self._corr_database.get_function_value
        val_hist = [get_data(ref_name, x) for x in x_ref_hist]
        grad_hist = None
        hess_hist = None

        if use_grad:
            grad_tag = Database.GRAD_TAG + ref_name
            grad_hist = [get_data(grad_tag, x) for x in x_ref_hist]
        if use_hess:
            # NOTE(review): store_reference_data writes hessians under
            # GRAD_TAG + GRAD_TAG + ref_name; confirm both keys agree.
            hess_tag = Database.HESS_TAG + ref_name
            hess_hist = [get_data(hess_tag, x) for x in x_ref_hist]

        return x_ref_hist, val_hist, grad_hist, hess_hist

    @staticmethod
    def _filter_x_hist(
        x_hist: NDArray,
        x_ref: NDArray,
        max_norm_thr: float | None = None,
        min_norm_thr: float = 1e-4,
    ) -> NDArray:
        """Filter x_hist.

        Keeps the points of x_hist whose euclidean distance to x_ref is at most
        max_norm_thr (if given) and not smaller than min_norm_thr. The order of
        the kept points is preserved.

        Args:
            x_hist: The list of points to filter.
            x_ref: The reference point.
            max_norm_thr: The threshold to discard points too far away.
            min_norm_thr: The threshold to discard points too close.

        Returns:
            The filtered points.

        Raises:
            ValueError: If min_norm_thr is not strictly positive.
        """
        if min_norm_thr <= 0.0:
            msg = "tol must be > 0."
            raise ValueError(msg)
        x_hist = array(x_hist)
        if max_norm_thr is not None:
            # Apply the proximity criterion.
            x_hist = x_hist[norm(x_hist - x_ref, axis=1) <= max_norm_thr]

        # Delete points too close to x_ref.
        too_close = norm(x_hist - x_ref, axis=1) < min_norm_thr
        return x_hist[~too_close]
@@ -0,0 +1,124 @@
1
+ # Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
2
+ #
3
+ # This program is free software; you can redistribute it and/or
4
+ # modify it under the terms of the GNU Lesser General Public
5
+ # License version 3 as published by the Free Software Foundation.
6
+ #
7
+ # This program is distributed in the hope that it will be useful,
8
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
9
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10
+ # Lesser General Public License for more details.
11
+ #
12
+ # You should have received a copy of the GNU Lesser General Public License
13
+ # along with this program; if not, write to the Free Software Foundation,
14
+ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15
+
16
+ # Copyright (c) 2018 AIRBUS OPERATIONS
17
+ #
18
+ # Contributors:
19
+ # INITIAL AUTHORS - API and implementation and/or documentation
20
+ # :author: Romain Olivanti
21
+ # OTHER AUTHORS - MACROSCOPIC CHANGES
22
+ """Criticality criteria for bound-constrained problem."""
23
+
24
+ from __future__ import annotations
25
+
26
+ from typing import TYPE_CHECKING
27
+
28
+ from numpy import abs as np_abs
29
+ from numpy import inf
30
+ from numpy import min as np_min
31
+ from numpy import vstack
32
+ from numpy import where
33
+ from numpy import zeros_like
34
+ from numpy.linalg import norm
35
+
36
+ if TYPE_CHECKING:
37
+ from numpy.typing import NDArray
38
+
39
+ from gemseo_multi_fidelity.core.boxed_domain import BoxedDomain
40
+
41
+
42
def crit_out_1(
    x_vect: NDArray,
    grad: NDArray,
    domain: BoxedDomain,
    grad_tol: float,
    bound_tol: float,
) -> float:
    """Return the l1 norm of the bound-aware criticality vector.

    See Mouffe's thesis.

    Args:
        x_vect: The point.
        grad: The gradient at the point.
        domain: The BoxedDomain.
        grad_tol: The gradient tolerance.
        bound_tol: The tolerance on bounds.

    Returns:
        The criticality, i.e. the l1 norm of the vector returned by
        compute_crit_vect.
    """
    crit_vect = compute_crit_vect(x_vect, grad, domain, grad_tol, bound_tol)
    return norm(crit_vect, ord=1)
62
+
63
+
64
def crit_in_inf(
    x_vect: NDArray,
    grad: NDArray,
    domain: BoxedDomain,
    grad_tol: float,
    bound_tol: float,
) -> float:
    """Return the infinity norm of the bound-aware criticality vector.

    See Mouffe's thesis.

    Args:
        x_vect: The point.
        grad: The gradient at the point.
        domain: The BoxedDomain.
        grad_tol: The gradient tolerance.
        bound_tol: The tolerance on bounds.

    Returns:
        The criticality, i.e. the largest absolute component of the vector
        returned by compute_crit_vect.
    """
    crit_vect = compute_crit_vect(x_vect, grad, domain, grad_tol, bound_tol)
    return norm(crit_vect, ord=inf)
84
+
85
+
86
def compute_crit_vect(
    x_vect: NDArray,
    grad: NDArray,
    domain: BoxedDomain,
    grad_tol: float,
    bound_tol: float,
) -> NDArray:
    """Compute the criticality vect at x_vect with respect to the specified domain.

    For each component, the criticality is the minimum between the absolute
    gradient and the absolute distance to the bound the component would move
    towards along -grad (lower bound for a positive gradient, upper bound for a
    negative one), the distance being scaled by grad_tol / bound_tol.
    Components with a zero gradient have a zero criticality.

    Args:
        x_vect: The point to test.
        grad: The gradient at x_vect.
        domain: The BoxedDomain.
        grad_tol: The gradient tolerance.
        bound_tol: The tolerance on bounds.

    Returns:
        The criticality vect (float array shaped like x_vect).
    """
    crit_ratio = grad_tol / bound_tol

    # Compute bounds proximity.
    dist_upper = domain.get_dist_upper(x_vect)
    dist_lower = domain.get_dist_lower(x_vect)

    # Force a float buffer: zeros_like on an integer x_vect would return an
    # integer array and silently truncate the criticality values assigned below.
    crit = zeros_like(x_vect, dtype=float)
    # Positive gradient -> check the lower bounds; negative -> the upper bounds.
    for inds, dist in (
        (where(grad > 0.0)[0], dist_lower),
        (where(grad < 0.0)[0], dist_upper),
    ):
        crit[inds] = np_min(
            np_abs(vstack((grad[inds], crit_ratio * dist[inds]))), axis=0
        )
    return crit