pyAgrum-nightly 2.3.1.9.dev202512261765915415__cp310-abi3-macosx_10_15_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyagrum/__init__.py +165 -0
- pyagrum/_pyagrum.so +0 -0
- pyagrum/bnmixture/BNMInference.py +268 -0
- pyagrum/bnmixture/BNMLearning.py +376 -0
- pyagrum/bnmixture/BNMixture.py +464 -0
- pyagrum/bnmixture/__init__.py +60 -0
- pyagrum/bnmixture/notebook.py +1058 -0
- pyagrum/causal/_CausalFormula.py +280 -0
- pyagrum/causal/_CausalModel.py +436 -0
- pyagrum/causal/__init__.py +81 -0
- pyagrum/causal/_causalImpact.py +356 -0
- pyagrum/causal/_dSeparation.py +598 -0
- pyagrum/causal/_doAST.py +761 -0
- pyagrum/causal/_doCalculus.py +361 -0
- pyagrum/causal/_doorCriteria.py +374 -0
- pyagrum/causal/_exceptions.py +95 -0
- pyagrum/causal/_types.py +61 -0
- pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
- pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
- pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
- pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
- pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
- pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
- pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
- pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
- pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
- pyagrum/causal/notebook.py +172 -0
- pyagrum/clg/CLG.py +658 -0
- pyagrum/clg/GaussianVariable.py +111 -0
- pyagrum/clg/SEM.py +312 -0
- pyagrum/clg/__init__.py +63 -0
- pyagrum/clg/canonicalForm.py +408 -0
- pyagrum/clg/constants.py +54 -0
- pyagrum/clg/forwardSampling.py +202 -0
- pyagrum/clg/learning.py +776 -0
- pyagrum/clg/notebook.py +480 -0
- pyagrum/clg/variableElimination.py +271 -0
- pyagrum/common.py +60 -0
- pyagrum/config.py +319 -0
- pyagrum/ctbn/CIM.py +513 -0
- pyagrum/ctbn/CTBN.py +573 -0
- pyagrum/ctbn/CTBNGenerator.py +216 -0
- pyagrum/ctbn/CTBNInference.py +459 -0
- pyagrum/ctbn/CTBNLearner.py +161 -0
- pyagrum/ctbn/SamplesStats.py +671 -0
- pyagrum/ctbn/StatsIndepTest.py +355 -0
- pyagrum/ctbn/__init__.py +79 -0
- pyagrum/ctbn/constants.py +54 -0
- pyagrum/ctbn/notebook.py +264 -0
- pyagrum/defaults.ini +199 -0
- pyagrum/deprecated.py +95 -0
- pyagrum/explain/_ComputationCausal.py +75 -0
- pyagrum/explain/_ComputationConditional.py +48 -0
- pyagrum/explain/_ComputationMarginal.py +48 -0
- pyagrum/explain/_CustomShapleyCache.py +110 -0
- pyagrum/explain/_Explainer.py +176 -0
- pyagrum/explain/_Explanation.py +70 -0
- pyagrum/explain/_FIFOCache.py +54 -0
- pyagrum/explain/_ShallCausalValues.py +204 -0
- pyagrum/explain/_ShallConditionalValues.py +155 -0
- pyagrum/explain/_ShallMarginalValues.py +155 -0
- pyagrum/explain/_ShallValues.py +296 -0
- pyagrum/explain/_ShapCausalValues.py +208 -0
- pyagrum/explain/_ShapConditionalValues.py +126 -0
- pyagrum/explain/_ShapMarginalValues.py +191 -0
- pyagrum/explain/_ShapleyValues.py +298 -0
- pyagrum/explain/__init__.py +81 -0
- pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
- pyagrum/explain/_explIndependenceListForPairs.py +146 -0
- pyagrum/explain/_explInformationGraph.py +264 -0
- pyagrum/explain/notebook/__init__.py +54 -0
- pyagrum/explain/notebook/_bar.py +142 -0
- pyagrum/explain/notebook/_beeswarm.py +174 -0
- pyagrum/explain/notebook/_showShapValues.py +97 -0
- pyagrum/explain/notebook/_waterfall.py +220 -0
- pyagrum/explain/shapley.py +225 -0
- pyagrum/lib/__init__.py +46 -0
- pyagrum/lib/_colors.py +390 -0
- pyagrum/lib/bn2graph.py +299 -0
- pyagrum/lib/bn2roc.py +1026 -0
- pyagrum/lib/bn2scores.py +217 -0
- pyagrum/lib/bn_vs_bn.py +605 -0
- pyagrum/lib/cn2graph.py +305 -0
- pyagrum/lib/discreteTypeProcessor.py +1102 -0
- pyagrum/lib/discretizer.py +58 -0
- pyagrum/lib/dynamicBN.py +390 -0
- pyagrum/lib/explain.py +57 -0
- pyagrum/lib/export.py +84 -0
- pyagrum/lib/id2graph.py +258 -0
- pyagrum/lib/image.py +387 -0
- pyagrum/lib/ipython.py +307 -0
- pyagrum/lib/mrf2graph.py +471 -0
- pyagrum/lib/notebook.py +1821 -0
- pyagrum/lib/proba_histogram.py +552 -0
- pyagrum/lib/utils.py +138 -0
- pyagrum/pyagrum.py +31495 -0
- pyagrum/skbn/_MBCalcul.py +242 -0
- pyagrum/skbn/__init__.py +49 -0
- pyagrum/skbn/_learningMethods.py +282 -0
- pyagrum/skbn/_utils.py +297 -0
- pyagrum/skbn/bnclassifier.py +1014 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSE.md +12 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSES/MIT.txt +18 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/METADATA +145 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/RECORD +107 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,466 @@
|
|
|
1
|
+
############################################################################
|
|
2
|
+
# This file is part of the aGrUM/pyAgrum library. #
|
|
3
|
+
# #
|
|
4
|
+
# Copyright (c) 2005-2025 by #
|
|
5
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
6
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
7
|
+
# #
|
|
8
|
+
# The aGrUM/pyAgrum library is free software; you can redistribute it #
|
|
9
|
+
# and/or modify it under the terms of either : #
|
|
10
|
+
# #
|
|
11
|
+
# - the GNU Lesser General Public License as published by #
|
|
12
|
+
# the Free Software Foundation, either version 3 of the License, #
|
|
13
|
+
# or (at your option) any later version, #
|
|
14
|
+
# - the MIT license (MIT), #
|
|
15
|
+
# - or both in dual license, as here. #
|
|
16
|
+
# #
|
|
17
|
+
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
|
|
18
|
+
# #
|
|
19
|
+
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
|
|
20
|
+
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
|
|
21
|
+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
|
|
22
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
|
|
23
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
|
|
24
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
|
|
25
|
+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
|
|
26
|
+
# OTHER DEALINGS IN THE SOFTWARE. #
|
|
27
|
+
# #
|
|
28
|
+
# See LICENCES for more details. #
|
|
29
|
+
# #
|
|
30
|
+
# SPDX-FileCopyrightText: Copyright 2005-2025 #
|
|
31
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
32
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
33
|
+
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
|
|
34
|
+
# #
|
|
35
|
+
# Contact : info_at_agrum_dot_org #
|
|
36
|
+
# homepage : http://agrum.gitlab.io #
|
|
37
|
+
# gitlab : https://gitlab.com/agrumery/agrum #
|
|
38
|
+
# #
|
|
39
|
+
############################################################################
|
|
40
|
+
|
|
41
|
+
import pyagrum as gum
|
|
42
|
+
import pyagrum.causal as csl
|
|
43
|
+
|
|
44
|
+
from collections import deque
|
|
45
|
+
|
|
46
|
+
EXCEPTION_TEXT = "\n(Call `.use[estimator_name]()` to select an estimator.)"
|
|
47
|
+
|
|
48
|
+
RCT_ESTIMATORS_LIST = "\n- CausalModelEstimator\n- DM"
|
|
49
|
+
|
|
50
|
+
BACKDOOR_ESTIMATORS_LIST = "\n- CausalModelEstimator\n- SLearner\n- TLearner\n- XLearner\n- PStratification\n- IPW"
|
|
51
|
+
|
|
52
|
+
FRONTDOOR_ESTIMATORS_LIST = "\n- CausalModelEstimator\n- SimplePlugIn\n- GeneralizedPlugIn"
|
|
53
|
+
|
|
54
|
+
IV_ESTIMATORS_LIST = "\n- CausalModelEstimator\n- Wald\n- WaldIPW\n- NormalizedWaldIPW\n- TSLS"
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class MisspecifiedAdjustmentError(ValueError):
|
|
58
|
+
def __init__(self, before=None) -> None:
|
|
59
|
+
self.message = (
|
|
60
|
+
f"Please select a valid adjustment before {before}. \n"
|
|
61
|
+
"The supported adjustments are:"
|
|
62
|
+
"\n- randomized controlled trial\t\t: call `.useRCTAdjustment()`"
|
|
63
|
+
"\n- backdoor\t\t\t\t: call `.useBackdoorAdjustment()`"
|
|
64
|
+
"\n- generalized frontdoor\t\t\t: call `.useFrontdoorAdjustment()`"
|
|
65
|
+
"\n- generalized instrumental variable\t: call `.useIVAdjustment()`"
|
|
66
|
+
)
|
|
67
|
+
super().__init__(self.message)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class MisspecifiedLearnerError(ValueError):
|
|
71
|
+
def __init__(self, learner_name=None) -> None:
|
|
72
|
+
self.message = (
|
|
73
|
+
f"The specified learner string: `{learner_name}` is not "
|
|
74
|
+
"recognized or does not correspond to any supported learner.\n"
|
|
75
|
+
"Consider passing the appropriate scikit-learn estimator object "
|
|
76
|
+
"directly, which should implement the `.fit()`, `.predict()`, "
|
|
77
|
+
"and `.predict_proba()` methods, or use one of the following "
|
|
78
|
+
"supported learner strings:"
|
|
79
|
+
"\n- LinearRegression"
|
|
80
|
+
"\n- Ridge"
|
|
81
|
+
"\n- Lasso"
|
|
82
|
+
"\n- PoissonRegressor"
|
|
83
|
+
"\n- DecisionTreeRegressor"
|
|
84
|
+
"\n- RandomForestRegressor"
|
|
85
|
+
"\n- GradientBoostingRegressor"
|
|
86
|
+
"\n- AdaBoostRegressor"
|
|
87
|
+
"\n- SVR"
|
|
88
|
+
"\n- KNeighborsRegressor"
|
|
89
|
+
"\n- XGBRegressor"
|
|
90
|
+
"\n- XGBClassifier"
|
|
91
|
+
)
|
|
92
|
+
super().__init__(self.message)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class EmptyConditionError(ZeroDivisionError):
|
|
96
|
+
def __init__(self) -> None:
|
|
97
|
+
self.message = (
|
|
98
|
+
"No matching instances found in the data for the "
|
|
99
|
+
"provided conditions.\nPlease ensure the conditions "
|
|
100
|
+
"are correctly specified or consider using a Pandas "
|
|
101
|
+
"DataFrame with these conditions containing intervention "
|
|
102
|
+
"and control instances for estimation purposes."
|
|
103
|
+
)
|
|
104
|
+
super().__init__(self.message)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class InvalidConditionError(ValueError):
|
|
108
|
+
def __init__(self) -> None:
|
|
109
|
+
self.message = "Invalid Conditional.\nPlease use a Pandas DataFrame, string or Nonetype as the conditional."
|
|
110
|
+
super().__init__(self.message)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class RCTError(ValueError):
|
|
114
|
+
def __init__(self, estimator_name=None) -> None:
|
|
115
|
+
self.message = (
|
|
116
|
+
f"The specified estimator: '{estimator_name}' is not supported "
|
|
117
|
+
"by the Randomized Controlled Trial criterion. "
|
|
118
|
+
"\nPlease choose a supported estimator:"
|
|
119
|
+
+ RCT_ESTIMATORS_LIST
|
|
120
|
+
+ "\nIf the outcome variable is a cause of other covariates in the "
|
|
121
|
+
"causal graph, Backdoor estimators may also be used." + EXCEPTION_TEXT
|
|
122
|
+
)
|
|
123
|
+
super().__init__(self.message)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class BackdoorError(ValueError):
|
|
127
|
+
def __init__(self, estimator_name=None) -> None:
|
|
128
|
+
self.message = (
|
|
129
|
+
f"The specified estimator: '{estimator_name}' is not supported "
|
|
130
|
+
"by the backdoor criterion. "
|
|
131
|
+
"\nPlease choose a supported estimator:" + BACKDOOR_ESTIMATORS_LIST + EXCEPTION_TEXT
|
|
132
|
+
)
|
|
133
|
+
super().__init__(self.message)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class FrontdoorError(ValueError):
|
|
137
|
+
def __init__(self, estimator_name=None) -> None:
|
|
138
|
+
self.message = (
|
|
139
|
+
f"The specified estimator: '{estimator_name}' is not supported "
|
|
140
|
+
"by the (genralized) frontdoor criterion. "
|
|
141
|
+
"\nPlease choose a supported estimator:" + FRONTDOOR_ESTIMATORS_LIST + EXCEPTION_TEXT
|
|
142
|
+
)
|
|
143
|
+
super().__init__(self.message)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class IVError(ValueError):
|
|
147
|
+
def __init__(self, estimator_name=None) -> None:
|
|
148
|
+
self.message = (
|
|
149
|
+
f"The specified estimator: '{estimator_name}' is not supported "
|
|
150
|
+
"by the (conditional) instrumental variable criterion. "
|
|
151
|
+
"\nPlease choose a supported estimator:" + IV_ESTIMATORS_LIST + EXCEPTION_TEXT
|
|
152
|
+
)
|
|
153
|
+
super().__init__(self.message)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def RCT(causal_model: csl.CausalModel, intervention: str, outcome: str) -> set[str] | None:
|
|
157
|
+
"""
|
|
158
|
+
Determine the Randomized Controlled Trial (RCT) adjustment.
|
|
159
|
+
|
|
160
|
+
Parameters
|
|
161
|
+
----------
|
|
162
|
+
intervention: str
|
|
163
|
+
Intervention (treatment) variable.
|
|
164
|
+
outcome: str
|
|
165
|
+
Outcome variable.
|
|
166
|
+
|
|
167
|
+
Returns
|
|
168
|
+
-------
|
|
169
|
+
set[str] or None
|
|
170
|
+
Set with the names of the confounders if ignorability.
|
|
171
|
+
None if ignorability is not satisfied.
|
|
172
|
+
"""
|
|
173
|
+
cbn_without_T_Y = gum.BayesNet(causal_model.causalBN())
|
|
174
|
+
t = cbn_without_T_Y.idFromName(intervention)
|
|
175
|
+
y = cbn_without_T_Y.idFromName(outcome)
|
|
176
|
+
|
|
177
|
+
if cbn_without_T_Y.existsArc(t, y):
|
|
178
|
+
cbn_without_T_Y.eraseArc(t, y)
|
|
179
|
+
|
|
180
|
+
if csl._dSeparation.isDSep(cbn_without_T_Y, {t}, {y}, set()):
|
|
181
|
+
return {cbn_without_T_Y.variable(pa).name() for pa in cbn_without_T_Y.parents(y)}
|
|
182
|
+
else:
|
|
183
|
+
return None
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _verifyFrontDoorDSep(cbn: gum.BayesNet, t: int, y: int, M: set[int], W: set[int]) -> bool:
|
|
187
|
+
"""
|
|
188
|
+
Verify the generalized frontdoor adjustment d-Sepatation assumptions.
|
|
189
|
+
|
|
190
|
+
Parameters
|
|
191
|
+
----------
|
|
192
|
+
cbn: gum.BayesNet
|
|
193
|
+
The causal Baysian Network.
|
|
194
|
+
t: int
|
|
195
|
+
The intervention node ID.
|
|
196
|
+
y: int
|
|
197
|
+
The outcome node ID.
|
|
198
|
+
M: set[int]
|
|
199
|
+
The set of mediator node IDs.
|
|
200
|
+
W: set[int]
|
|
201
|
+
The set of confounder node IDs.
|
|
202
|
+
|
|
203
|
+
Returns
|
|
204
|
+
-------
|
|
205
|
+
bool
|
|
206
|
+
True if the M is d-Sep. from {t} in the mutilated graph without
|
|
207
|
+
the arcs t->M, and M is d-Sep. from {y} in the mutilated graph
|
|
208
|
+
without the arcs M->{y} and t and y are not neighbors.
|
|
209
|
+
"""
|
|
210
|
+
|
|
211
|
+
cbn_without_T_M = gum.BayesNet(cbn)
|
|
212
|
+
cbn_without_M_Y = gum.BayesNet(cbn)
|
|
213
|
+
|
|
214
|
+
for m in M:
|
|
215
|
+
if cbn_without_T_M.existsArc(t, m):
|
|
216
|
+
cbn_without_T_M.eraseArc(t, m)
|
|
217
|
+
if cbn_without_M_Y.existsArc(m, y):
|
|
218
|
+
cbn_without_M_Y.eraseArc(m, y)
|
|
219
|
+
|
|
220
|
+
res = (
|
|
221
|
+
csl._dSeparation.isDSep(cbn_without_T_M, {t}, M, W)
|
|
222
|
+
and csl._dSeparation.isDSep(cbn_without_M_Y, {y}, M, W | {t})
|
|
223
|
+
and t not in cbn.parents(y) | cbn.children(y)
|
|
224
|
+
)
|
|
225
|
+
|
|
226
|
+
return res
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def generalizedFrontDoor(causal_model: csl.CausalModel, intervention: str, outcome: str) -> tuple[set[str]] | None:
|
|
230
|
+
"""
|
|
231
|
+
Identify the generalised frontdoor adjustment set and covariates.
|
|
232
|
+
|
|
233
|
+
Parameters
|
|
234
|
+
----------
|
|
235
|
+
intervention: str
|
|
236
|
+
Intervention (treatment) variable.
|
|
237
|
+
outcome: str
|
|
238
|
+
Outcome variable.
|
|
239
|
+
|
|
240
|
+
Returns
|
|
241
|
+
-------
|
|
242
|
+
tuple[set[str]] or None
|
|
243
|
+
Set with the names of the mediators,
|
|
244
|
+
set with the names of covariates, or None if not applicable.
|
|
245
|
+
"""
|
|
246
|
+
|
|
247
|
+
obn = causal_model.observationalBN()
|
|
248
|
+
cbn = causal_model.causalBN()
|
|
249
|
+
|
|
250
|
+
mediators = csl._doorCriteria.nodes_on_dipath(obn, obn.idFromName(intervention), obn.idFromName(outcome))
|
|
251
|
+
mediators = {obn.variable(m).name() for m in mediators}
|
|
252
|
+
|
|
253
|
+
confounders = set()
|
|
254
|
+
|
|
255
|
+
for m in mediators:
|
|
256
|
+
backdoor_T_M = causal_model.backDoor(intervention, m)
|
|
257
|
+
backdoor_M_Y = causal_model.backDoor(m, outcome)
|
|
258
|
+
backdoor_T_M = set() if backdoor_T_M is None else backdoor_T_M
|
|
259
|
+
backdoor_M_Y = set() if backdoor_M_Y is None else backdoor_M_Y
|
|
260
|
+
confounders |= backdoor_T_M | backdoor_M_Y
|
|
261
|
+
|
|
262
|
+
confounders = confounders - {intervention}
|
|
263
|
+
|
|
264
|
+
# Clone with latent variables:
|
|
265
|
+
# Sometime the causal structure is changed while cloning,
|
|
266
|
+
# so extra operations must be made
|
|
267
|
+
mutilated_causal_model = causal_model.clone()
|
|
268
|
+
|
|
269
|
+
for id in causal_model.latentVariablesIds():
|
|
270
|
+
childrens = cbn.children(id)
|
|
271
|
+
childrens = {cbn.variable(c).name() for c in childrens}
|
|
272
|
+
if cbn.variable(id).name() not in mutilated_causal_model.names().values():
|
|
273
|
+
mutilated_causal_model.addLatentVariable(cbn.variable(id).name(), tuple(childrens))
|
|
274
|
+
|
|
275
|
+
for c in confounders:
|
|
276
|
+
if mutilated_causal_model.existsArc(c, intervention):
|
|
277
|
+
mutilated_causal_model.eraseCausalArc(c, intervention)
|
|
278
|
+
if mutilated_causal_model.existsArc(c, outcome):
|
|
279
|
+
mutilated_causal_model.eraseCausalArc(c, outcome)
|
|
280
|
+
for m in mediators:
|
|
281
|
+
if mutilated_causal_model.existsArc(c, m):
|
|
282
|
+
mutilated_causal_model.eraseCausalArc(c, m)
|
|
283
|
+
|
|
284
|
+
frontdoor = mutilated_causal_model.frontDoor(cause=intervention, effect=outcome)
|
|
285
|
+
|
|
286
|
+
valid_fd = _verifyFrontDoorDSep(
|
|
287
|
+
cbn,
|
|
288
|
+
cbn.idFromName(intervention),
|
|
289
|
+
cbn.idFromName(outcome),
|
|
290
|
+
{cbn.idFromName(m) for m in mediators},
|
|
291
|
+
{cbn.idFromName(m) for m in confounders},
|
|
292
|
+
)
|
|
293
|
+
|
|
294
|
+
return (None, None) if frontdoor is None or len(mediators) == 0 or not valid_fd else (frontdoor, confounders)
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def _findPath(
|
|
298
|
+
G: gum.UndiGraph,
|
|
299
|
+
a: int,
|
|
300
|
+
b: int,
|
|
301
|
+
) -> list[int]:
|
|
302
|
+
"""
|
|
303
|
+
Find a path in the mixed graph `G` from node `a` to node `b`.
|
|
304
|
+
|
|
305
|
+
Parameters
|
|
306
|
+
----------
|
|
307
|
+
G: gum.MixedGraph
|
|
308
|
+
The graph.
|
|
309
|
+
a: int
|
|
310
|
+
The starting node ID.
|
|
311
|
+
b: int
|
|
312
|
+
The ending node ID.
|
|
313
|
+
|
|
314
|
+
Returns
|
|
315
|
+
-------
|
|
316
|
+
list[int]
|
|
317
|
+
The path from node `a` to `b`.
|
|
318
|
+
"""
|
|
319
|
+
|
|
320
|
+
stack = deque()
|
|
321
|
+
stack.append((a, [a]))
|
|
322
|
+
visited = set()
|
|
323
|
+
|
|
324
|
+
while stack:
|
|
325
|
+
(node, path) = stack.pop()
|
|
326
|
+
|
|
327
|
+
if node == b:
|
|
328
|
+
return path
|
|
329
|
+
|
|
330
|
+
if node not in visited:
|
|
331
|
+
visited.add(node)
|
|
332
|
+
|
|
333
|
+
for neighbor in G.neighbours(node):
|
|
334
|
+
if neighbor not in visited:
|
|
335
|
+
stack.append((neighbor, path + [neighbor]))
|
|
336
|
+
|
|
337
|
+
return []
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def _nearestSeparator(obn: gum.BayesNet, cbn: gum.BayesNet, t: int, y: int, z: int) -> set[int]:
|
|
341
|
+
"""
|
|
342
|
+
Find the nearest separator set in the `causal_model` according to `(y,w)`.
|
|
343
|
+
|
|
344
|
+
(see https://www.ijcai.org/Proceedings/15/Papers/457.pdf)
|
|
345
|
+
|
|
346
|
+
Parameters
|
|
347
|
+
----------
|
|
348
|
+
causal_model: csl.CausalModel
|
|
349
|
+
The causal graph.
|
|
350
|
+
t: int
|
|
351
|
+
The intervention node ID.
|
|
352
|
+
y: int
|
|
353
|
+
The outcome node ID.
|
|
354
|
+
z: int
|
|
355
|
+
The instrument node ID.
|
|
356
|
+
|
|
357
|
+
Returns
|
|
358
|
+
-------
|
|
359
|
+
set[int]
|
|
360
|
+
The nearest separator set in the mutilated graph of
|
|
361
|
+
`causal_model` with respect to `t`.
|
|
362
|
+
"""
|
|
363
|
+
|
|
364
|
+
M = obn.nodes()
|
|
365
|
+
W = set()
|
|
366
|
+
|
|
367
|
+
moralized_ancestral_graph = cbn.moralizedAncestralGraph({z, y})
|
|
368
|
+
|
|
369
|
+
while True:
|
|
370
|
+
# Moralized Graph controlling for W
|
|
371
|
+
csl._dSeparation._removeZ(moralized_ancestral_graph, W)
|
|
372
|
+
|
|
373
|
+
path = _findPath(moralized_ancestral_graph, y, z)
|
|
374
|
+
|
|
375
|
+
if path == list() or set(path[1:-1]) & M == set():
|
|
376
|
+
break
|
|
377
|
+
|
|
378
|
+
w = next((node for node in path[1:-1] if node in M), None)
|
|
379
|
+
if w is not None:
|
|
380
|
+
W.add(w)
|
|
381
|
+
|
|
382
|
+
if csl._dSeparation.isDSep(cbn, {z}, {y}, W):
|
|
383
|
+
return W
|
|
384
|
+
else:
|
|
385
|
+
return None
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def _ancestralInstrument(causal_model: csl.CausalModel, t: int, y: int, z: int) -> set[int]:
|
|
389
|
+
"""
|
|
390
|
+
Find the ancetral instrument conditioning set `W` in the `causal_model`
|
|
391
|
+
with `t` as intervention, `y` as outcome and `z` as instrument.
|
|
392
|
+
|
|
393
|
+
(see https://www.ijcai.org/Proceedings/15/Papers/457.pdf)
|
|
394
|
+
|
|
395
|
+
Parameters
|
|
396
|
+
----------
|
|
397
|
+
causal_model: csl.CausalModel
|
|
398
|
+
The causal graph.
|
|
399
|
+
t: int
|
|
400
|
+
The intervention node ID.
|
|
401
|
+
y: int
|
|
402
|
+
The outcome node ID.
|
|
403
|
+
z: int
|
|
404
|
+
The instrument node ID.
|
|
405
|
+
|
|
406
|
+
Returns
|
|
407
|
+
-------
|
|
408
|
+
set[int]
|
|
409
|
+
the ancetral instrument conditioning set `W`.
|
|
410
|
+
"""
|
|
411
|
+
|
|
412
|
+
mutilated_obn = gum.BayesNet(causal_model.observationalBN())
|
|
413
|
+
mutilated_cbn = gum.BayesNet(causal_model.causalBN())
|
|
414
|
+
|
|
415
|
+
if mutilated_obn.existsArc(t, y):
|
|
416
|
+
mutilated_obn.eraseArc(t, y)
|
|
417
|
+
if mutilated_cbn.existsArc(t, y):
|
|
418
|
+
mutilated_cbn.eraseArc(t, y)
|
|
419
|
+
|
|
420
|
+
W = _nearestSeparator(mutilated_obn, mutilated_cbn, t, y, z)
|
|
421
|
+
if W is None or bool(W & mutilated_cbn.descendants(y)) or t in W:
|
|
422
|
+
return None
|
|
423
|
+
elif not csl._dSeparation.isDSep(mutilated_cbn, {z}, {t}, W):
|
|
424
|
+
return W - {t}
|
|
425
|
+
else:
|
|
426
|
+
return None
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def instrumentalVariable(causal_model: csl.CausalModel, intervention: str, outcome: str) -> tuple[set[str], set[str]]:
|
|
430
|
+
"""
|
|
431
|
+
Identifies the instrumental variables and covariates, using ancestral
|
|
432
|
+
instruments.
|
|
433
|
+
|
|
434
|
+
(see https://www.ijcai.org/Proceedings/15/Papers/457.pdf)
|
|
435
|
+
(see https://ftp.cs.ucla.edu/pub/stat_ser/r303-reprint.pdf)
|
|
436
|
+
|
|
437
|
+
Parameters
|
|
438
|
+
----------
|
|
439
|
+
intervention: str
|
|
440
|
+
Intervention (treatment) variable.
|
|
441
|
+
outcome: str
|
|
442
|
+
Outcome variable.
|
|
443
|
+
|
|
444
|
+
Returns
|
|
445
|
+
------
|
|
446
|
+
tuple[set[str], set[str]] or None
|
|
447
|
+
Set with the names of the instrumental variables,
|
|
448
|
+
"""
|
|
449
|
+
|
|
450
|
+
obn = causal_model.observationalBN()
|
|
451
|
+
|
|
452
|
+
t = intervention
|
|
453
|
+
y = outcome
|
|
454
|
+
if not isinstance(intervention, int):
|
|
455
|
+
t = obn.idFromName(intervention)
|
|
456
|
+
|
|
457
|
+
if not isinstance(outcome, int):
|
|
458
|
+
y = obn.idFromName(outcome)
|
|
459
|
+
|
|
460
|
+
tensor_instruments = obn.parents(intervention)
|
|
461
|
+
|
|
462
|
+
for z in tensor_instruments:
|
|
463
|
+
W = _ancestralInstrument(causal_model, t, y, z)
|
|
464
|
+
if W is not None:
|
|
465
|
+
return (obn.variable(z).name(), {obn.variable(w).name() for w in W})
|
|
466
|
+
return (None, None)
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
############################################################################
|
|
2
|
+
# This file is part of the aGrUM/pyAgrum library. #
|
|
3
|
+
# #
|
|
4
|
+
# Copyright (c) 2005-2025 by #
|
|
5
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
6
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
7
|
+
# #
|
|
8
|
+
# The aGrUM/pyAgrum library is free software; you can redistribute it #
|
|
9
|
+
# and/or modify it under the terms of either : #
|
|
10
|
+
# #
|
|
11
|
+
# - the GNU Lesser General Public License as published by #
|
|
12
|
+
# the Free Software Foundation, either version 3 of the License, #
|
|
13
|
+
# or (at your option) any later version, #
|
|
14
|
+
# - the MIT license (MIT), #
|
|
15
|
+
# - or both in dual license, as here. #
|
|
16
|
+
# #
|
|
17
|
+
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
|
|
18
|
+
# #
|
|
19
|
+
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
|
|
20
|
+
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
|
|
21
|
+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
|
|
22
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
|
|
23
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
|
|
24
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
|
|
25
|
+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
|
|
26
|
+
# OTHER DEALINGS IN THE SOFTWARE. #
|
|
27
|
+
# #
|
|
28
|
+
# See LICENCES for more details. #
|
|
29
|
+
# #
|
|
30
|
+
# SPDX-FileCopyrightText: Copyright 2005-2025 #
|
|
31
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
32
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
33
|
+
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
|
|
34
|
+
# #
|
|
35
|
+
# Contact : info_at_agrum_dot_org #
|
|
36
|
+
# homepage : http://agrum.gitlab.io #
|
|
37
|
+
# gitlab : https://gitlab.com/agrumery/agrum #
|
|
38
|
+
# #
|
|
39
|
+
############################################################################
|
|
40
|
+
|
|
41
|
+
"""
|
|
42
|
+
This file defines some helpers for handling causal concepts in notebooks
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
from typing import Union, Optional, Dict
|
|
46
|
+
import IPython
|
|
47
|
+
|
|
48
|
+
import pyagrum
|
|
49
|
+
import pyagrum.lib.notebook as gnb
|
|
50
|
+
import pyagrum.causal as csl
|
|
51
|
+
|
|
52
|
+
from pyagrum.causal._types import NameSet
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def getCausalModel(cm: csl.CausalModel, size=None) -> str:
|
|
56
|
+
"""
|
|
57
|
+
return a HTML representing the causal model
|
|
58
|
+
|
|
59
|
+
Parameters
|
|
60
|
+
----------
|
|
61
|
+
cm: CausalModel
|
|
62
|
+
the causal model
|
|
63
|
+
size: int|str
|
|
64
|
+
the size of the rendered graph
|
|
65
|
+
|
|
66
|
+
Returns
|
|
67
|
+
-------
|
|
68
|
+
pydot.Dot
|
|
69
|
+
the dot representation
|
|
70
|
+
"""
|
|
71
|
+
if size is None:
|
|
72
|
+
size = pyagrum.config["causal", "default_graph_size"]
|
|
73
|
+
return gnb.getDot(cm.toDot(), size)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def showCausalModel(cm: csl.CausalModel, size=None):
|
|
77
|
+
"""
|
|
78
|
+
Shows a pydot svg representation of the causal DAG
|
|
79
|
+
|
|
80
|
+
Parameters
|
|
81
|
+
----------
|
|
82
|
+
cm: CausalModel
|
|
83
|
+
the causal model
|
|
84
|
+
size: int|str
|
|
85
|
+
the size of the rendered graph
|
|
86
|
+
"""
|
|
87
|
+
if size is None:
|
|
88
|
+
size = pyagrum.config["causal", "default_graph_size"]
|
|
89
|
+
gnb.showDot(cm.toDot(), size=size)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def getCausalImpact(
|
|
93
|
+
model: csl.CausalModel,
|
|
94
|
+
on: Union[str, NameSet],
|
|
95
|
+
doing: Union[str, NameSet],
|
|
96
|
+
knowing: Optional[NameSet] = None,
|
|
97
|
+
values: Optional[Dict[str, int]] = None,
|
|
98
|
+
):
|
|
99
|
+
"""
|
|
100
|
+
return a HTML representing of the three values defining a causal impact : formula, value, explanation
|
|
101
|
+
|
|
102
|
+
Parameters
|
|
103
|
+
----------
|
|
104
|
+
model: CausalModel
|
|
105
|
+
the causal model
|
|
106
|
+
on: str | Set[str]
|
|
107
|
+
the impacted variable(s)
|
|
108
|
+
doing: str | Set[str]
|
|
109
|
+
the interventions
|
|
110
|
+
knowing: str | Set[str]
|
|
111
|
+
the observations
|
|
112
|
+
values: Dict[str,int] default=None
|
|
113
|
+
value for certain variables
|
|
114
|
+
|
|
115
|
+
Returns
|
|
116
|
+
-------
|
|
117
|
+
HTML
|
|
118
|
+
"""
|
|
119
|
+
formula, impact, explanation = csl.causalImpact(model, on, doing, knowing, values)
|
|
120
|
+
|
|
121
|
+
gnb.flow.clear()
|
|
122
|
+
gnb.flow.add(getCausalModel(model), caption="Causal Model")
|
|
123
|
+
|
|
124
|
+
if formula is None:
|
|
125
|
+
gnb.flow.add(explanation, caption="Impossible")
|
|
126
|
+
else:
|
|
127
|
+
gnb.flow.add(
|
|
128
|
+
"\n\n$$\n\\begin{equation*}" + formula.toLatex() + "\\end{equation*}\n$$\n\n",
|
|
129
|
+
caption="Explanation : " + explanation,
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
if formula is None:
|
|
133
|
+
res = "No result"
|
|
134
|
+
else:
|
|
135
|
+
if impact.variable(0).domainSize() < 5:
|
|
136
|
+
res = impact
|
|
137
|
+
else:
|
|
138
|
+
res = gnb.getProba(impact)
|
|
139
|
+
gnb.flow.add(res, caption="Impact")
|
|
140
|
+
|
|
141
|
+
return gnb.flow.html()
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def showCausalImpact(
|
|
145
|
+
model: csl.CausalModel,
|
|
146
|
+
on: Union[str, NameSet],
|
|
147
|
+
doing: Union[str, NameSet],
|
|
148
|
+
knowing: Optional[NameSet] = None,
|
|
149
|
+
values: Optional[Dict[str, int]] = None,
|
|
150
|
+
):
|
|
151
|
+
"""
|
|
152
|
+
display a HTML representing of the three values defining a causal impact : formula, value, explanation
|
|
153
|
+
|
|
154
|
+
Parameters
|
|
155
|
+
----------
|
|
156
|
+
model: CausalModel
|
|
157
|
+
the causal model
|
|
158
|
+
on: str | Set[str]
|
|
159
|
+
the impacted variable(s)
|
|
160
|
+
doing: str | Set[str]
|
|
161
|
+
the interventions
|
|
162
|
+
knowing: str | Set[str]
|
|
163
|
+
the observations
|
|
164
|
+
values: Dict[str,int] default=None
|
|
165
|
+
value for certain variables
|
|
166
|
+
"""
|
|
167
|
+
html = getCausalImpact(model, on, doing, knowing, values)
|
|
168
|
+
IPython.display.display(html)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
csl.CausalModel._repr_html_ = lambda self: gnb.getDot(self.toDot(), size=pyagrum.config["causal", "default_graph_size"])
|
|
172
|
+
csl.CausalFormula._repr_html_ = lambda self: f"$${self.toLatex()}$$"
|