pyAgrum-nightly 2.3.0.9.dev202512061764412981__cp310-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107) hide show
  1. pyagrum/__init__.py +165 -0
  2. pyagrum/_pyagrum.so +0 -0
  3. pyagrum/bnmixture/BNMInference.py +268 -0
  4. pyagrum/bnmixture/BNMLearning.py +376 -0
  5. pyagrum/bnmixture/BNMixture.py +464 -0
  6. pyagrum/bnmixture/__init__.py +60 -0
  7. pyagrum/bnmixture/notebook.py +1058 -0
  8. pyagrum/causal/_CausalFormula.py +280 -0
  9. pyagrum/causal/_CausalModel.py +436 -0
  10. pyagrum/causal/__init__.py +81 -0
  11. pyagrum/causal/_causalImpact.py +356 -0
  12. pyagrum/causal/_dSeparation.py +598 -0
  13. pyagrum/causal/_doAST.py +761 -0
  14. pyagrum/causal/_doCalculus.py +361 -0
  15. pyagrum/causal/_doorCriteria.py +374 -0
  16. pyagrum/causal/_exceptions.py +95 -0
  17. pyagrum/causal/_types.py +61 -0
  18. pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
  19. pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
  20. pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
  21. pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
  22. pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
  23. pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
  24. pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
  25. pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
  26. pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
  27. pyagrum/causal/notebook.py +171 -0
  28. pyagrum/clg/CLG.py +658 -0
  29. pyagrum/clg/GaussianVariable.py +111 -0
  30. pyagrum/clg/SEM.py +312 -0
  31. pyagrum/clg/__init__.py +63 -0
  32. pyagrum/clg/canonicalForm.py +408 -0
  33. pyagrum/clg/constants.py +54 -0
  34. pyagrum/clg/forwardSampling.py +202 -0
  35. pyagrum/clg/learning.py +776 -0
  36. pyagrum/clg/notebook.py +480 -0
  37. pyagrum/clg/variableElimination.py +271 -0
  38. pyagrum/common.py +60 -0
  39. pyagrum/config.py +319 -0
  40. pyagrum/ctbn/CIM.py +513 -0
  41. pyagrum/ctbn/CTBN.py +573 -0
  42. pyagrum/ctbn/CTBNGenerator.py +216 -0
  43. pyagrum/ctbn/CTBNInference.py +459 -0
  44. pyagrum/ctbn/CTBNLearner.py +161 -0
  45. pyagrum/ctbn/SamplesStats.py +671 -0
  46. pyagrum/ctbn/StatsIndepTest.py +355 -0
  47. pyagrum/ctbn/__init__.py +79 -0
  48. pyagrum/ctbn/constants.py +54 -0
  49. pyagrum/ctbn/notebook.py +264 -0
  50. pyagrum/defaults.ini +199 -0
  51. pyagrum/deprecated.py +95 -0
  52. pyagrum/explain/_ComputationCausal.py +75 -0
  53. pyagrum/explain/_ComputationConditional.py +48 -0
  54. pyagrum/explain/_ComputationMarginal.py +48 -0
  55. pyagrum/explain/_CustomShapleyCache.py +110 -0
  56. pyagrum/explain/_Explainer.py +176 -0
  57. pyagrum/explain/_Explanation.py +70 -0
  58. pyagrum/explain/_FIFOCache.py +54 -0
  59. pyagrum/explain/_ShallCausalValues.py +204 -0
  60. pyagrum/explain/_ShallConditionalValues.py +155 -0
  61. pyagrum/explain/_ShallMarginalValues.py +155 -0
  62. pyagrum/explain/_ShallValues.py +296 -0
  63. pyagrum/explain/_ShapCausalValues.py +208 -0
  64. pyagrum/explain/_ShapConditionalValues.py +126 -0
  65. pyagrum/explain/_ShapMarginalValues.py +191 -0
  66. pyagrum/explain/_ShapleyValues.py +298 -0
  67. pyagrum/explain/__init__.py +81 -0
  68. pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
  69. pyagrum/explain/_explIndependenceListForPairs.py +146 -0
  70. pyagrum/explain/_explInformationGraph.py +264 -0
  71. pyagrum/explain/notebook/__init__.py +54 -0
  72. pyagrum/explain/notebook/_bar.py +142 -0
  73. pyagrum/explain/notebook/_beeswarm.py +174 -0
  74. pyagrum/explain/notebook/_showShapValues.py +97 -0
  75. pyagrum/explain/notebook/_waterfall.py +220 -0
  76. pyagrum/explain/shapley.py +225 -0
  77. pyagrum/lib/__init__.py +46 -0
  78. pyagrum/lib/_colors.py +390 -0
  79. pyagrum/lib/bn2graph.py +299 -0
  80. pyagrum/lib/bn2roc.py +1026 -0
  81. pyagrum/lib/bn2scores.py +217 -0
  82. pyagrum/lib/bn_vs_bn.py +605 -0
  83. pyagrum/lib/cn2graph.py +305 -0
  84. pyagrum/lib/discreteTypeProcessor.py +1102 -0
  85. pyagrum/lib/discretizer.py +58 -0
  86. pyagrum/lib/dynamicBN.py +390 -0
  87. pyagrum/lib/explain.py +57 -0
  88. pyagrum/lib/export.py +84 -0
  89. pyagrum/lib/id2graph.py +258 -0
  90. pyagrum/lib/image.py +387 -0
  91. pyagrum/lib/ipython.py +307 -0
  92. pyagrum/lib/mrf2graph.py +471 -0
  93. pyagrum/lib/notebook.py +1821 -0
  94. pyagrum/lib/proba_histogram.py +552 -0
  95. pyagrum/lib/utils.py +138 -0
  96. pyagrum/pyagrum.py +31495 -0
  97. pyagrum/skbn/_MBCalcul.py +242 -0
  98. pyagrum/skbn/__init__.py +49 -0
  99. pyagrum/skbn/_learningMethods.py +282 -0
  100. pyagrum/skbn/_utils.py +297 -0
  101. pyagrum/skbn/bnclassifier.py +1014 -0
  102. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSE.md +12 -0
  103. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
  104. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/MIT.txt +18 -0
  105. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/METADATA +145 -0
  106. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/RECORD +107 -0
  107. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/WHEEL +4 -0
@@ -0,0 +1,396 @@
1
+ ############################################################################
2
+ # This file is part of the aGrUM/pyAgrum library. #
3
+ # #
4
+ # Copyright (c) 2005-2025 by #
5
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
6
+ # - Christophe GONZALES(_at_AMU) #
7
+ # #
8
+ # The aGrUM/pyAgrum library is free software; you can redistribute it #
9
+ # and/or modify it under the terms of either : #
10
+ # #
11
+ # - the GNU Lesser General Public License as published by #
12
+ # the Free Software Foundation, either version 3 of the License, #
13
+ # or (at your option) any later version, #
14
+ # - the MIT license (MIT), #
15
+ # - or both in dual license, as here. #
16
+ # #
17
+ # (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
18
+ # #
19
+ # This aGrUM/pyAgrum library is distributed in the hope that it will be #
20
+ # useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
21
+ # INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
22
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
23
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
24
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
25
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
26
+ # OTHER DEALINGS IN THE SOFTWARE. #
27
+ # #
28
+ # See LICENCES for more details. #
29
+ # #
30
+ # SPDX-FileCopyrightText: Copyright 2005-2025 #
31
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
32
+ # - Christophe GONZALES(_at_AMU) #
33
+ # SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
34
+ # #
35
+ # Contact : info_at_agrum_dot_org #
36
+ # homepage : http://agrum.gitlab.io #
37
+ # gitlab : https://gitlab.com/agrumery/agrum #
38
+ # #
39
+ ############################################################################
40
+
41
+ import pandas as pd
42
+ import numpy as np
43
+
44
+ from typing import Any
45
+
46
+ from sklearn.base import clone
47
+
48
+ from ._learners import learnerFromString
49
+
50
+
51
+ class SimplePlugIn:
52
+ """
53
+ Uses the (original) Frontdoor Adjustment Formula to derive
54
+ the plug-in estimator. Does not account for covariates.
55
+ Based on Guo et al. (2023).
56
+ (see https://www.jstor.org/stable/2337329).
57
+ """
58
+
59
+ def __init__(self, learner: str | Any | None = None, propensity_learner: str | Any | None = None) -> None:
60
+ """
61
+ Initialize the Frontdoor Adjustment estimator.
62
+
63
+ Parameters
64
+ ----------
65
+ learner: str or object, optional
66
+ Estimator for outcome variable.
67
+ If not provided, defaults to LinearRegression.
68
+ propensity_learner: str |or object, optional
69
+ Estimator for treatment proability.
70
+ If not provided, defaults to LogisticRegression.
71
+ """
72
+
73
+ if learner is None:
74
+ self.learner = learnerFromString("LinearRegression")
75
+ elif isinstance(learner, str):
76
+ self.learner = learnerFromString(learner)
77
+ else:
78
+ self.learner = clone(learner)
79
+
80
+ if propensity_learner is None:
81
+ self.propensity_learner = learnerFromString("LogisticRegression")
82
+ elif isinstance(propensity_learner, str):
83
+ self.propensity_learner = learnerFromString(propensity_learner)
84
+ else:
85
+ self.propensity_learner = clone(propensity_learner)
86
+
87
+ self.treatment_probability = None
88
+
89
+ def fit(
90
+ self,
91
+ X: np.matrix | np.ndarray | pd.DataFrame,
92
+ treatment: np.ndarray | pd.Series,
93
+ y: np.ndarray | pd.Series,
94
+ M: np.matrix | np.ndarray | pd.DataFrame,
95
+ ) -> None:
96
+ """
97
+ Fit the inference model.
98
+
99
+ Parameters
100
+ ----------
101
+ X: np.matrix or np.ndarray or pd.DataFrame
102
+ The matrix of covariates.
103
+ treatment: np.ndarray or pd.Series
104
+ The treatment assignment vector.
105
+ y: np.ndarray or pd.Series,
106
+ The outcome vector.
107
+ M: np.matrix or np.ndarray or pd.DataFrame
108
+ The mediator matrix.
109
+ """
110
+
111
+ self.learner.fit(X=pd.concat([pd.DataFrame(M), pd.DataFrame(treatment)], axis=1), y=y)
112
+
113
+ self.propensity_learner.fit(X=pd.DataFrame(M), y=treatment)
114
+
115
+ self.treatment_probability = treatment.sum() / treatment.count()
116
+
117
+ def predict(
118
+ self,
119
+ X: np.matrix | np.ndarray | pd.DataFrame,
120
+ treatment: np.ndarray | pd.Series,
121
+ y: np.ndarray | pd.Series,
122
+ M: np.matrix | np.ndarray | pd.DataFrame,
123
+ ) -> np.ndarray:
124
+ """
125
+ Predict the Idividual Causal Effect (ICE),
126
+ also referd to as the Individual Treatment Effect (ITE).
127
+
128
+ Parameters
129
+ ----------
130
+ X: np.matrix or np.ndarray or pd.DataFrame
131
+ The matrix of covariates.
132
+ treatment: np.ndarray or pd.Series
133
+ The treatment assignment vector.
134
+ y: np.ndarray or pd.Series,
135
+ The outcome vector.
136
+ M: np.matrix or np.ndarray or pd.DataFrame
137
+ The mediator matrix.
138
+
139
+ Returns
140
+ -------
141
+ np.ndarray
142
+ An array containing the predicted ICE.
143
+ """
144
+
145
+ M_control = pd.concat(
146
+ [
147
+ pd.DataFrame(M),
148
+ pd.DataFrame({self.learner.feature_names_in_[-1]: np.zeros(len(M))}, index=pd.DataFrame(M).index),
149
+ ],
150
+ axis=1,
151
+ )
152
+
153
+ M_treatment = pd.concat(
154
+ [
155
+ pd.DataFrame(M),
156
+ pd.DataFrame({self.learner.feature_names_in_[-1]: np.ones(len(M))}, index=pd.DataFrame(M).index),
157
+ ],
158
+ axis=1,
159
+ )
160
+
161
+ mu0 = self.learner.predict(X=M_control)
162
+ mu1 = self.learner.predict(X=M_treatment)
163
+
164
+ e = self.propensity_learner.predict_proba(X=M)[:, 1]
165
+ p = self.treatment_probability
166
+
167
+ return (e / p - (1 - e) / (1 - p)) * (mu1 * p + mu0 * (1 - p))
168
+
169
+ def estimate_ate(
170
+ self,
171
+ X: np.matrix | np.ndarray | pd.DataFrame,
172
+ treatment: np.ndarray | pd.Series,
173
+ y: np.ndarray | pd.Series,
174
+ M: np.matrix | np.ndarray | pd.DataFrame,
175
+ pretrain: bool = True,
176
+ ) -> float:
177
+ """
178
+ Predicts the Average Causal Effect (ACE),
179
+ also refered to as the Average Treatment Effect (ATE).
180
+ (The term ATE is used in the method name for compatibility purposes.)
181
+
182
+ Parameters
183
+ ----------
184
+ X: np.matrix or np.ndarray or pd.DataFrame
185
+ The matrix of covariates.
186
+ treatment: np.ndarray or pd.Series
187
+ The treatment assignment vector.
188
+ y: np.ndarray or pd.Series,
189
+ The outcome vector.
190
+ M: np.matrix or np.ndarray or pd.DataFrame
191
+ The mediator matrix.
192
+
193
+ Returns
194
+ -------
195
+ float
196
+ The value of the ACE.
197
+ """
198
+
199
+ return self.predict(X, treatment, y, M).mean()
200
+
201
+
202
class GeneralizedPlugIn:
    """
    Basic implementation of the second plug-in TMLE estimator.
    Must provide covariates.
    Based on Guo et al. (2023).
    (see https://arxiv.org/abs/2312.10234).
    """

    def __init__(
        self,
        learner: str | Any | None = None,
        conditional_outcome_learner: str | Any | None = None,
        propensity_score_learner: str | Any | None = None,
        pseudo_control_outcome_learner: str | Any | None = None,
        pseudo_treatment_outcome_learner: str | Any | None = None,
    ) -> None:
        """
        Initialize the Frontdoor Adjustment estimator.

        Each learner may be given as a string (resolved with ``learnerFromString``),
        as an estimator object (copied with ``sklearn.base.clone``), or omitted.

        Parameters
        ----------
        learner: str or object, optional
            Default estimator for the outcome-type models (conditional outcome and
            both pseudo-outcome learners) that are not given explicitly.
            If not provided, those models default to LinearRegression.
        conditional_outcome_learner: str or object, optional
            Estimator for the outcome given mediators, covariates and treatment.
            Falls back to ``learner``.
        propensity_score_learner: str or object, optional
            Estimator for the treatment probability given the covariates; it must
            implement ``predict_proba``. If not provided, defaults to LogisticRegression
            (``learner`` is not used here, since outcome regressors such as
            LinearRegression have no ``predict_proba``).
        pseudo_control_outcome_learner: str or object, optional
            Estimator regressing the pseudo-outcome on the covariates among control units.
            Falls back to ``learner``.
        pseudo_treatment_outcome_learner: str or object, optional
            Estimator regressing the pseudo-outcome on the covariates among treated units.
            Falls back to ``learner``.
        """

        def outcome_spec(spec: str | Any | None) -> str | Any | None:
            # An explicitly given outcome learner wins; otherwise use the shared ``learner``.
            return learner if spec is None else spec

        self.conditional_outcome_learner = self._make_learner(
            outcome_spec(conditional_outcome_learner), "LinearRegression"
        )
        self.pseudo_control_outcome_learner = self._make_learner(
            outcome_spec(pseudo_control_outcome_learner), "LinearRegression"
        )
        self.pseudo_treatment_outcome_learner = self._make_learner(
            outcome_spec(pseudo_treatment_outcome_learner), "LinearRegression"
        )
        self.propensity_score_learner = self._make_learner(propensity_score_learner, "LogisticRegression")

        # Not used by the estimator; kept so that existing reads of this attribute still work.
        self.pseudo_outcome_learner = self._make_learner(learner, "LinearRegression")

    @staticmethod
    def _make_learner(spec: str | Any | None, default: str) -> Any:
        """Build an estimator from a learner name, an estimator to clone, or ``None`` (use ``default``)."""
        if spec is None:
            return learnerFromString(default)
        if isinstance(spec, str):
            return learnerFromString(spec)
        return clone(spec)

    def fit(
        self,
        X: np.matrix | np.ndarray | pd.DataFrame,
        treatment: np.ndarray | pd.Series,
        y: np.ndarray | pd.Series,
        M: np.matrix | np.ndarray | pd.DataFrame,
    ) -> None:
        """
        Fit the inference model.

        Parameters
        ----------
        X: np.matrix or np.ndarray or pd.DataFrame
            The matrix of covariates.
        treatment: np.ndarray or pd.Series
            The treatment assignment vector (expected to be binary 0/1).
        y: np.ndarray or pd.Series,
            The outcome vector.
        M: np.matrix or np.ndarray or pd.DataFrame
            The mediator matrix.
        """
        # Outcome model mu(m, x, t): regression of y on [M, X, T] with the treatment as the
        # LAST column (``predict`` relies on that position).
        self.conditional_outcome_learner.fit(
            X=pd.concat([pd.DataFrame(M), pd.DataFrame(X), pd.DataFrame(treatment)], axis=1), y=y
        )

        # Propensity model pi(x) = P(T | X=x).
        self.propensity_score_learner.fit(X=pd.DataFrame(X), y=treatment)

    def predict(
        self,
        X: np.matrix | np.ndarray | pd.DataFrame,
        treatment: np.ndarray | pd.Series,
        y: np.ndarray | pd.Series,
        M: np.matrix | np.ndarray | pd.DataFrame,
    ) -> np.ndarray:
        """
        Predict the Individual Causal Effect (ICE),
        also referred to as the Individual Treatment Effect (ITE).

        Side effect: the two pseudo-outcome learners are (re)fitted on the given data
        on every call.

        Parameters
        ----------
        X: np.matrix or np.ndarray or pd.DataFrame
            The matrix of covariates.
        treatment: np.ndarray or pd.Series
            The treatment assignment vector (expected to be binary 0/1).
        y: np.ndarray or pd.Series,
            The outcome vector. Not used by this method.
        M: np.matrix or np.ndarray or pd.DataFrame
            The mediator matrix.

        Returns
        -------
        np.ndarray
            An array containing the predicted ICE.
        """
        mu = self.conditional_outcome_learner.predict
        pi = self.propensity_score_learner.predict_proba
        # The outcome model was trained with the treatment as its last feature column.
        treatment_column = self.conditional_outcome_learner.feature_names_in_[-1]

        def xi(m, x) -> np.ndarray:
            # Pseudo-outcome: sum over t' of pi(t' | x) * mu(m, x, t').
            m_df = pd.DataFrame(m)
            x_df = pd.DataFrame(x)

            def with_treatment(value: float) -> pd.DataFrame:
                forced = pd.DataFrame({treatment_column: np.full(len(m_df), value)}, index=m_df.index)
                return pd.concat([m_df, x_df, forced], axis=1)

            # Computed once; column 0 is P(T=0 | x), column 1 is P(T=1 | x).
            proba = pi(x_df)
            return mu(with_treatment(0.0)) * proba[:, 0] + mu(with_treatment(1.0)) * proba[:, 1]

        control = treatment == 0
        treated = treatment == 1

        # gamma_t(x): regression of the pseudo-outcome on the covariates within arm T=t.
        self.pseudo_control_outcome_learner.fit(X=X[control], y=xi(M[control], X[control]))
        self.pseudo_treatment_outcome_learner.fit(X=X[treated], y=xi(M[treated], X[treated]))

        gamma0 = self.pseudo_control_outcome_learner.predict(X)
        gamma1 = self.pseudo_treatment_outcome_learner.predict(X)

        return gamma1 - gamma0

    def estimate_ate(
        self,
        X: np.matrix | np.ndarray | pd.DataFrame,
        treatment: np.ndarray | pd.Series,
        y: np.ndarray | pd.Series,
        M: np.matrix | np.ndarray | pd.DataFrame,
        pretrain: bool = True,
    ) -> float:
        """
        Predicts the Average Causal Effect (ACE),
        also referred to as the Average Treatment Effect (ATE).
        (The term ATE is used in the method name for compatibility purposes.)

        Parameters
        ----------
        X: np.matrix or np.ndarray or pd.DataFrame
            The matrix of covariates.
        treatment: np.ndarray or pd.Series
            The treatment assignment vector.
        y: np.ndarray or pd.Series,
            The outcome vector.
        M: np.matrix or np.ndarray or pd.DataFrame
            The mediator matrix.
        pretrain: bool, optional
            Not used by this estimator; accepted for signature compatibility.

        Returns
        -------
        float
            The value of the ACE (mean of the predicted ICE).
        """
        return self.predict(X, treatment, y, M).mean()
@@ -0,0 +1,118 @@
1
+ ############################################################################
2
+ # This file is part of the aGrUM/pyAgrum library. #
3
+ # #
4
+ # Copyright (c) 2005-2025 by #
5
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
6
+ # - Christophe GONZALES(_at_AMU) #
7
+ # #
8
+ # The aGrUM/pyAgrum library is free software; you can redistribute it #
9
+ # and/or modify it under the terms of either : #
10
+ # #
11
+ # - the GNU Lesser General Public License as published by #
12
+ # the Free Software Foundation, either version 3 of the License, #
13
+ # or (at your option) any later version, #
14
+ # - the MIT license (MIT), #
15
+ # - or both in dual license, as here. #
16
+ # #
17
+ # (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
18
+ # #
19
+ # This aGrUM/pyAgrum library is distributed in the hope that it will be #
20
+ # useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
21
+ # INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
22
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
23
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
24
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
25
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
26
+ # OTHER DEALINGS IN THE SOFTWARE. #
27
+ # #
28
+ # See LICENCES for more details. #
29
+ # #
30
+ # SPDX-FileCopyrightText: Copyright 2005-2025 #
31
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
32
+ # - Christophe GONZALES(_at_AMU) #
33
+ # SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
34
+ # #
35
+ # Contact : info_at_agrum_dot_org #
36
+ # homepage : http://agrum.gitlab.io #
37
+ # gitlab : https://gitlab.com/agrumery/agrum #
38
+ # #
39
+ ############################################################################
40
+
41
+ from typing import Any
42
+ from ._utils import MisspecifiedLearnerError
43
+
44
+
45
# Maps each supported learner name to the module defining the class of the same name.
# Modules are imported lazily, so optional dependencies (e.g. xgboost) are only needed
# when the corresponding learner is actually requested.
_LEARNER_MODULES = {
    "LinearRegression": "sklearn.linear_model",
    "LogisticRegression": "sklearn.linear_model",
    "Ridge": "sklearn.linear_model",
    "Lasso": "sklearn.linear_model",
    "PoissonRegressor": "sklearn.linear_model",
    "HuberRegressor": "sklearn.linear_model",
    "DecisionTreeRegressor": "sklearn.tree",
    "RandomForestRegressor": "sklearn.ensemble",
    "GradientBoostingRegressor": "sklearn.ensemble",
    "AdaBoostRegressor": "sklearn.ensemble",
    "SVR": "sklearn.svm",
    "KNeighborsRegressor": "sklearn.neighbors",
    "XGBRegressor": "xgboost",
    "XGBClassifier": "xgboost",
}


def learnerFromString(learner_string: str) -> Any:
    """
    Retrieve a scikit-learn learner based on a string specification.

    Parameters
    ----------
    learner_string: str
        The string specifying a supported model: one of the keys of
        ``_LEARNER_MODULES`` (scikit-learn estimators, plus ``XGBRegressor`` and
        ``XGBClassifier`` from xgboost).

    Returns
    -------
    sklearn.base.BaseEstimator
        A new, default-constructed instance of the estimator corresponding to the
        specified string. This object will be used as the learner.

    Raises
    ------
    MisspecifiedLearnerError
        If ``learner_string`` is not a supported learner name.
    ImportError
        If the module providing the learner (e.g. xgboost) is not installed.
    """
    import importlib

    module_name = _LEARNER_MODULES.get(learner_string) if isinstance(learner_string, str) else None
    if module_name is None:
        raise MisspecifiedLearnerError(learner_string)

    # Import the defining module explicitly: importing a sibling sklearn submodule does not
    # make e.g. ``sklearn.tree`` available.
    module = importlib.import_module(module_name)
    return getattr(module, learner_string)()