pyAgrum-nightly 2.3.1.9.dev202512261765915415__cp310-abi3-macosx_10_15_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyagrum/__init__.py +165 -0
- pyagrum/_pyagrum.so +0 -0
- pyagrum/bnmixture/BNMInference.py +268 -0
- pyagrum/bnmixture/BNMLearning.py +376 -0
- pyagrum/bnmixture/BNMixture.py +464 -0
- pyagrum/bnmixture/__init__.py +60 -0
- pyagrum/bnmixture/notebook.py +1058 -0
- pyagrum/causal/_CausalFormula.py +280 -0
- pyagrum/causal/_CausalModel.py +436 -0
- pyagrum/causal/__init__.py +81 -0
- pyagrum/causal/_causalImpact.py +356 -0
- pyagrum/causal/_dSeparation.py +598 -0
- pyagrum/causal/_doAST.py +761 -0
- pyagrum/causal/_doCalculus.py +361 -0
- pyagrum/causal/_doorCriteria.py +374 -0
- pyagrum/causal/_exceptions.py +95 -0
- pyagrum/causal/_types.py +61 -0
- pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
- pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
- pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
- pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
- pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
- pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
- pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
- pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
- pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
- pyagrum/causal/notebook.py +172 -0
- pyagrum/clg/CLG.py +658 -0
- pyagrum/clg/GaussianVariable.py +111 -0
- pyagrum/clg/SEM.py +312 -0
- pyagrum/clg/__init__.py +63 -0
- pyagrum/clg/canonicalForm.py +408 -0
- pyagrum/clg/constants.py +54 -0
- pyagrum/clg/forwardSampling.py +202 -0
- pyagrum/clg/learning.py +776 -0
- pyagrum/clg/notebook.py +480 -0
- pyagrum/clg/variableElimination.py +271 -0
- pyagrum/common.py +60 -0
- pyagrum/config.py +319 -0
- pyagrum/ctbn/CIM.py +513 -0
- pyagrum/ctbn/CTBN.py +573 -0
- pyagrum/ctbn/CTBNGenerator.py +216 -0
- pyagrum/ctbn/CTBNInference.py +459 -0
- pyagrum/ctbn/CTBNLearner.py +161 -0
- pyagrum/ctbn/SamplesStats.py +671 -0
- pyagrum/ctbn/StatsIndepTest.py +355 -0
- pyagrum/ctbn/__init__.py +79 -0
- pyagrum/ctbn/constants.py +54 -0
- pyagrum/ctbn/notebook.py +264 -0
- pyagrum/defaults.ini +199 -0
- pyagrum/deprecated.py +95 -0
- pyagrum/explain/_ComputationCausal.py +75 -0
- pyagrum/explain/_ComputationConditional.py +48 -0
- pyagrum/explain/_ComputationMarginal.py +48 -0
- pyagrum/explain/_CustomShapleyCache.py +110 -0
- pyagrum/explain/_Explainer.py +176 -0
- pyagrum/explain/_Explanation.py +70 -0
- pyagrum/explain/_FIFOCache.py +54 -0
- pyagrum/explain/_ShallCausalValues.py +204 -0
- pyagrum/explain/_ShallConditionalValues.py +155 -0
- pyagrum/explain/_ShallMarginalValues.py +155 -0
- pyagrum/explain/_ShallValues.py +296 -0
- pyagrum/explain/_ShapCausalValues.py +208 -0
- pyagrum/explain/_ShapConditionalValues.py +126 -0
- pyagrum/explain/_ShapMarginalValues.py +191 -0
- pyagrum/explain/_ShapleyValues.py +298 -0
- pyagrum/explain/__init__.py +81 -0
- pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
- pyagrum/explain/_explIndependenceListForPairs.py +146 -0
- pyagrum/explain/_explInformationGraph.py +264 -0
- pyagrum/explain/notebook/__init__.py +54 -0
- pyagrum/explain/notebook/_bar.py +142 -0
- pyagrum/explain/notebook/_beeswarm.py +174 -0
- pyagrum/explain/notebook/_showShapValues.py +97 -0
- pyagrum/explain/notebook/_waterfall.py +220 -0
- pyagrum/explain/shapley.py +225 -0
- pyagrum/lib/__init__.py +46 -0
- pyagrum/lib/_colors.py +390 -0
- pyagrum/lib/bn2graph.py +299 -0
- pyagrum/lib/bn2roc.py +1026 -0
- pyagrum/lib/bn2scores.py +217 -0
- pyagrum/lib/bn_vs_bn.py +605 -0
- pyagrum/lib/cn2graph.py +305 -0
- pyagrum/lib/discreteTypeProcessor.py +1102 -0
- pyagrum/lib/discretizer.py +58 -0
- pyagrum/lib/dynamicBN.py +390 -0
- pyagrum/lib/explain.py +57 -0
- pyagrum/lib/export.py +84 -0
- pyagrum/lib/id2graph.py +258 -0
- pyagrum/lib/image.py +387 -0
- pyagrum/lib/ipython.py +307 -0
- pyagrum/lib/mrf2graph.py +471 -0
- pyagrum/lib/notebook.py +1821 -0
- pyagrum/lib/proba_histogram.py +552 -0
- pyagrum/lib/utils.py +138 -0
- pyagrum/pyagrum.py +31495 -0
- pyagrum/skbn/_MBCalcul.py +242 -0
- pyagrum/skbn/__init__.py +49 -0
- pyagrum/skbn/_learningMethods.py +282 -0
- pyagrum/skbn/_utils.py +297 -0
- pyagrum/skbn/bnclassifier.py +1014 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSE.md +12 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSES/MIT.txt +18 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/METADATA +145 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/RECORD +107 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/WHEEL +4 -0
pyagrum/clg/learning.py
ADDED
|
@@ -0,0 +1,776 @@
|
|
|
1
|
+
############################################################################
|
|
2
|
+
# This file is part of the aGrUM/pyAgrum library. #
|
|
3
|
+
# #
|
|
4
|
+
# Copyright (c) 2005-2025 by #
|
|
5
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
6
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
7
|
+
# #
|
|
8
|
+
# The aGrUM/pyAgrum library is free software; you can redistribute it #
|
|
9
|
+
# and/or modify it under the terms of either : #
|
|
10
|
+
# #
|
|
11
|
+
# - the GNU Lesser General Public License as published by #
|
|
12
|
+
# the Free Software Foundation, either version 3 of the License, #
|
|
13
|
+
# or (at your option) any later version, #
|
|
14
|
+
# - the MIT license (MIT), #
|
|
15
|
+
# - or both in dual license, as here. #
|
|
16
|
+
# #
|
|
17
|
+
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
|
|
18
|
+
# #
|
|
19
|
+
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
|
|
20
|
+
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
|
|
21
|
+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
|
|
22
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
|
|
23
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
|
|
24
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
|
|
25
|
+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
|
|
26
|
+
# OTHER DEALINGS IN THE SOFTWARE. #
|
|
27
|
+
# #
|
|
28
|
+
# See LICENCES for more details. #
|
|
29
|
+
# #
|
|
30
|
+
# SPDX-FileCopyrightText: Copyright 2005-2025 #
|
|
31
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
32
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
33
|
+
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
|
|
34
|
+
# #
|
|
35
|
+
# Contact : info_at_agrum_dot_org #
|
|
36
|
+
# homepage : http://agrum.gitlab.io #
|
|
37
|
+
# gitlab : https://gitlab.com/agrumery/agrum #
|
|
38
|
+
# #
|
|
39
|
+
############################################################################
|
|
40
|
+
|
|
41
|
+
"""
|
|
42
|
+
Use Rademacher averages to bound the FWER (Family-Wise Error Rate) of the independence tests used for Local Causal Discovery or for the PC algorithm.
|
|
43
|
+
(see "Bounding the Family-Wise Error Rate in Local Causal Discovery using Rademacher Averages", Dario Simionato, Fabio Vandin, 2022)
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
import warnings
|
|
47
|
+
|
|
48
|
+
import pandas as pd
|
|
49
|
+
import numpy as np
|
|
50
|
+
import itertools
|
|
51
|
+
from sklearn.linear_model import LinearRegression
|
|
52
|
+
from typing import Dict, List, Set, Tuple, FrozenSet
|
|
53
|
+
|
|
54
|
+
from .constants import NodeId
|
|
55
|
+
from .CLG import CLG
|
|
56
|
+
from .GaussianVariable import GaussianVariable
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class CLGLearner:
  """
  Learn the structure and the parameters of a CLG model from a CSV file.

  The independence tests rely on Rademacher averages to bound the FWER (Family-Wise Error Rate).
  (see "Bounding the Family-Wise Error Rate in Local Causal Discovery using Rademacher Averages", Dario Simionato, Fabio Vandin, 2022)
  """

  # CLG holding one GaussianVariable per column of the data file
  _model: CLG
  # node id -> list of the values observed for that node
  id2samples: Dict[NodeId, List]
  # the data file, as read by pandas
  _df: pd.DataFrame
  # (Xi, Xj) -> conditioning set that separated Xi and Xj (filled by Repeat_II, in both directions)
  sepset: Dict[Tuple[NodeId, NodeId], Set[NodeId]]
  # supremum deviation; stays None until supremum_deviation() has been run
  _SD: float
  # set of all node ids
  _V: Set[int]
  # maximum number of hypotheses that can be tested
  _N: int
  # (frozenset({X, Y}), frozenset(Z)) -> per-sample terms whose mean is the estimated correlation
  r_XYZ: Dict[Tuple[FrozenSet[NodeId], FrozenSet[NodeId]], List[float]]

  def __init__(self, filename: str, *, n_sample: int = 15, fwer_delta: float = 0.05):
    """
    Parameters
    ----------
    filename : str
      The path of the data file (read with ``pandas.read_csv``; one column per variable).
    n_sample : int
      Number of samples for the Monte-Carlo Empirical Rademacher Average.
    fwer_delta : float ∈ (0,1]
      Family-Wise Error Rate.
    """
    self._model = CLG()  # the CLG model
    self.id2samples = {}
    self.sepset = {}

    self._df = pd.read_csv(filename)
    # add all the variables to the CLG model
    for name in self._df.columns:
      self._model.add(GaussianVariable(name, np.mean(self._df[name]), np.std(self._df[name])))

    # collect the samples in a dict whose keys are NodeIds
    for node in self._model.nodes():
      self.id2samples[node] = self._df[self._model.name(node)].tolist()

    self._V = set(self._model.nodes())  # set of NodeId
    L = len(self._V)
    # L(L-1)/2 unordered pairs times 2^(L-2) conditioning sets, i.e. L(L-1)2^(L-3),
    # written so that the value stays an integer when L < 3
    self._N = L * (L - 1) * 2 ** max(L - 2, 0) // 2
    self.r_XYZ = {}

    # lazy computation of the supremum deviation: test_indep() calls supremum_deviation()
    # the first time it finds _SD set to None
    self._n_sample = n_sample
    self._fwer_delta = fwer_delta
    self._SD = None

  def Pearson_coeff(self, X, Y, Z):
    """
    Estimate Pearson's linear correlation between X and Y (using the residuals of linear regressions on Z when Z is not empty).

    Nothing is returned: the per-sample terms are stored in ``self.r_XYZ[(frozenset({X, Y}), frozenset(Z))]``
    and their mean is the estimated correlation.

    Parameters
    ----------
    X : NodeId
      id of the first variable tested.
    Y : NodeId
      id of the second variable tested.
    Z : Set[NodeId]
      The conditioned variable's id set (any empty container means "no conditioning").
    """
    K = len(self.id2samples[X])  # number of samples

    if not Z:
      # x and y are the observations of X and Y (copied: the stored samples are never modified)
      x = np.array(self.id2samples[X], dtype=float)
      y = np.array(self.id2samples[Y], dtype=float)
    else:
      feature_name = [self._model.name(z) for z in Z]
      sample_Z = self._df[feature_name]

      # x : residuals of the linear regression of X on the variables in Z
      sample_X = self._df[self._model.name(X)]
      regressor_x = LinearRegression()
      regressor_x.fit(sample_Z, sample_X)
      x = np.asarray(sample_X - regressor_x.predict(sample_Z), dtype=float)

      # y : residuals of the linear regression of Y on the variables in Z
      sample_Y = self._df[self._model.name(Y)]
      regressor_y = LinearRegression()
      regressor_y.fit(sample_Z, sample_Y)
      y = np.asarray(sample_Y - regressor_y.predict(sample_Z), dtype=float)

    # center both vectors around 0, then scale them to unit standard deviation
    x = x - np.mean(x)
    y = y - np.mean(y)
    x = x / np.std(x)
    y = y / np.std(y)

    self.r_XYZ[(frozenset({X, Y}), frozenset(Z))] = x * y * K / (K - 1)

  @staticmethod
  def generate_XYZ(l):
    """
    Generate all the possible combinations of X, Y and Z.

    Parameters
    ----------
    l : Iterable[NodeId]
      The node ids.

    Yields
    ------
    Tuple[NodeId, NodeId, Set[NodeId]]
      Every triple (X, Y, Z) with X < Y and Z a subset of the other nodes.
    """
    s = set(l)
    # X < Y keeps each unordered pair exactly once
    for X, Y in itertools.product(s, repeat=2):
      if X < Y:
        for Z in CLGLearner.generate_subsets(s - {X, Y}):
          yield X, Y, Z

  def supremum_deviation(self, n_sample: int, fwer_delta: float):
    """
    Use n-MCERA to get supremum deviation.

    ``r_XYZ`` is (re)computed for every triple (X, Y, Z); the result is also stored in ``self._SD``
    and the two arguments in ``self._n_sample`` and ``self._fwer_delta``.

    Parameters
    ----------
    n_sample : int
      The MC number n in n-MCERA.
    fwer_delta : float ∈ (0,1]
      Threshold.

    Returns
    -------
    SD : float
      The supremum deviation.

    Raises
    ------
    ValueError
      If there are fewer than two variables (no hypothesis can be tested).
    """
    if len(self._V) < 2:
      raise ValueError("At least two variables are needed to compute the supremum deviation.")

    K = len(next(iter(self.id2samples.values())))  # number of samples

    # sigma: the n_sample × K matrix of i.i.d. Rademacher random variables (-1 or +1, equiprobable)
    sigma = 2.0 * np.random.randint(0, 2, size=(n_sample, K)) - 1.0

    # compute r_XYZ for every pair of variables X,Y and every set of conditioned variables Z
    for X, Y, Z in CLGLearner.generate_XYZ(self._V):
      self.Pearson_coeff(X, Y, Z)

    # one pass over the hypotheses: their means, and for each Rademacher vector
    # the supremum of (1/K) * sum_i sigma[j][i] * r[i] (np.fmax ignores NaN candidates)
    means = []
    sups = np.full(n_sample, -np.inf)
    for r in self.r_XYZ.values():
      means.append(np.mean(r))
      sups = np.fmax(sups, sigma.dot(r) / K)

    # [a, b] is the range of F
    a = min(means)
    b = max(means)

    # n-MCERA: n-samples Monte-Carlo Empirical Rademacher Average
    R = sups.sum() / n_sample

    # the final SD: supremum deviation
    z = max(np.abs(a), np.abs(b))
    c = np.abs(b - a)

    temp = np.log(4 / (fwer_delta / self._N))
    R_hat = R + 2 * z * np.sqrt(temp / (2 * n_sample * K))
    SD = 2 * R_hat
    SD = SD + np.sqrt(c * (4 * K * R_hat + c * temp) * temp) / K
    SD = SD + c * temp / K
    SD = SD + c * np.sqrt(temp / (2 * K))

    self._n_sample = n_sample
    self._fwer_delta = fwer_delta
    self._SD = SD

    return SD

  def test_indep(self, X, Y, Z):
    """
    Test whether X and Y are independent given Z.

    X and Y are declared independent when 0 lies within the supremum deviation of the
    estimated correlation (mean of ``r_XYZ``). The supremum deviation is computed on first use.

    Parameters
    ----------
    X : NodeId
      The id of the first variable tested.
    Y : NodeId
      The id of the second variable tested.
    Z : Set[NodeId]
      The conditioned variable's id set.

    Returns
    -------
    bool
      True if X and Y are indep given Z, False if not indep.
    """
    if self._SD is None:  # lazy supremum_deviation computation
      self.supremum_deviation(self._n_sample, self._fwer_delta)
    SD = self._SD
    r_XYZ = self.r_XYZ[(frozenset({X, Y}), frozenset(Z))].mean()

    # is 0 inside [r_XYZ - SD, r_XYZ + SD] ?
    return bool((r_XYZ - SD <= 0) and (r_XYZ + SD >= 0))

  @staticmethod
  def generate_subsets(S: Set[NodeId]):
    """
    Generator that iterates on all the subsets of S (from the smallest to the biggest).

    Parameters
    ----------
    S : Set[NodeId]
      The set of variables.

    Yields
    ------
    Set[NodeId]
      The subsets of S, by increasing size (the empty set first).
    """
    elements = list(S)
    for size in range(len(elements) + 1):
      for subset in itertools.combinations(elements, size):
        yield set(subset)

  def RAveL_PC(self, T):
    """
    Find the Parent-Children of variable T with FWER lower than Delta.

    Parameters
    ----------
    T : NodeId
      The id of the target variable T.

    Returns
    -------
    Set[NodeId]
      The Parent-Children of variable T with FWER lower than Delta.
    """
    # X is kept unless some conditioning set makes it independent of T
    # (any() stops at the first such set)
    return {
      X
      for X in self._V - {T}
      if not any(self.test_indep(T, X, Z) for Z in self.generate_subsets(self._V - {X, T}))
    }

  def RAveL_MB(self, T: NodeId) -> Set[NodeId]:
    """
    Find the Markov Boundary of variable T with FWER lower than Delta.

    Parameters
    ----------
    T : NodeId
      The id of the target variable T.

    Returns
    -------
    MB : Set[NodeId]
      The Markov Boundary of variable T with FWER lower than Delta.
    """
    # find PC(T)
    MB = self.RAveL_PC(T)

    # add Spouse(T) to PC(T) in order to get MB(T); iterate on a copy since MB grows
    for X in list(MB):
      for Y in self.RAveL_PC(X):
        if Y not in MB and Y != T and not self.test_indep(T, Y, self._V - {Y, T}):
          MB.add(Y)

    return MB

  def Repeat_II(self, order, C, l, verbose=False):
    """
    This function is the second part of the Step1 of PC algorithm.

    The ordered pairs of adjacent vertices are scanned and the first edge found to be
    conditionally independent given a set of size l is removed from C (and its sepset recorded).

    Parameters
    ----------
    order : List[NodeId]
      The order of the variables.
    C : Dict[NodeId, Set[NodeId]]
      The temporary skeleton (modified in place).
    l : int
      The size of the sepset
    verbose : bool
      Whether to report the removed edge (through ``warnings.warn``).

    Returns
    -------
    found_edge : bool
      True if an edge has been removed, False if not.
    """
    nodes = [order[i] for i in range(len(self._V))]

    # select an ordered pair of vertices (Xi, Xj) that are adjacent in C and satisfy |C[Xi]\{Xj}| ≥ l
    for Xi, Xj in itertools.permutations(nodes, 2):
      if Xi not in C[Xj] or Xj not in C[Xi]:  # Xi and Xj are not adjacent in C
        continue
      candidates = C[Xi] - {Xj}
      if len(candidates) < l:
        continue

      # choose a set S ⊆ C[Xi]\{Xj} with |S| = l
      for S in itertools.combinations(candidates, l):
        if self.test_indep(Xi, Xj, set(S)):
          if verbose:
            warnings.warn("{0} and {1} are conditionally independent given {2}".format(Xi, Xj, S))
          # delete edge Xi − Xj from C
          C[Xi].remove(Xj)
          C[Xj].remove(Xi)
          # let sepset(Xi,Xj) = sepset(Xj,Xi) = S
          self.sepset[(Xi, Xj)] = set(S)
          self.sepset[(Xj, Xi)] = set(S)
          return True

    return False

  def Adjacency_search(self, order, verbose=False):
    """
    This function is the first step of PC-algo: Adjacency Search.
    test_indep() is used as the independence test.

    Parameters
    ----------
    order : List[NodeId]
      A particular order of the Nodes.
    verbose : bool
      Whether to report the process of Adjacency Search.

    Returns
    -------
    C : Dict[NodeId, Set[NodeId]]
      The temporary skeleton.
    sepset : Dict[Tuple[NodeId, NodeId], Set[NodeId]]
      Sepset(which will be used in Step2&3 of PC-Algo).
    """
    # form the complete undirected graph C (as an adjacency list) on the vertex set V
    V = list(self._V)
    C = {v: set(V) - {v} for v in V}
    nodes = [order[i] for i in range(len(V))]

    def all_satisfied(l):
      """Check that every pair of adjacent vertices (Xi,Xj) in C satisfies |C[Xi]\\{Xj}| ≤ l."""
      return all(
        len(C[Xi] - {Xj}) <= l
        for Xi, Xj in itertools.permutations(nodes, 2)
        if Xi in C[Xj] and Xj in C[Xi]
      )

    l = 0
    while True:
      # remove edges separated by a set of size l until none is found
      while self.Repeat_II(order, C, l, verbose):
        pass

      # until all pairs of adjacent vertices (Xi,Xj) in C satisfy |C[Xi]\{Xj}| ≤ l
      if all_satisfied(l):
        return C, self.sepset
      l += 1

  def three_rules(self, C, verbose=False):
    """
    This function is the third step of PC-algo.
    Orient as many of the remaining undirected edges as possible by repeated application of the three rules.

    In C, an arc a -> b is encoded by ``b in C[a] and a not in C[b]`` and an undirected edge
    by both memberships. Each pass applies each rule to every undirected edge; passes are
    repeated until one of them orients nothing.

    Parameters
    ----------
    C : Dict[NodeId, Set[NodeId]]
      The temporary skeleton (modified in place).
    verbose : bool
      Whether to report the process of this function.

    Returns
    -------
    C : Dict[NodeId, Set[NodeId]]
      The final skeleton (of Step3).
    """

    def arc(a, b):  # a -> b
      return b in C[a] and a not in C[b]

    def adjacent(a, b):
      return b in C[a] or a in C[b]

    while True:
      new_oriented = False

      # Rule 1: Orient Xj − Xk into Xj -> Xk whenever there is a directed edge Xi -> Xj such that Xi and Xk are not adjacent
      # (otherwise a new v-structure would be created)
      for Xj in self._V:
        for Xk in tuple(C[Xj]):
          if Xj not in C[Xk]:  # not an undirected edge Xj - Xk
            continue
          if any(arc(Xi, Xj) and not adjacent(Xi, Xk) for Xi in self._V - {Xk, Xj}):
            if verbose:
              warnings.warn("Rule 1 applied:{0}->{1}".format(Xj, Xk))
            C[Xk].remove(Xj)
            new_oriented = True

      # Rule 2: Orient Xi − Xj into Xi -> Xj whenever there is a chain Xi -> Xk -> Xj
      # (otherwise a directed cycle is created)
      for Xi in self._V:
        for Xj in tuple(C[Xi]):
          if Xi not in C[Xj]:  # not an undirected edge Xi - Xj
            continue
          if any(Xi not in C[Xk] and arc(Xk, Xj) for Xk in C[Xi] - {Xj}):
            if verbose:
              warnings.warn("Rule 2 applied:{0}->{1}".format(Xi, Xj))
            C[Xj].remove(Xi)
            new_oriented = True

      # Rule 3: Orient Xi − Xj into Xi → Xj whenever there are two chains Xi − Xk → Xj and Xi − Xl → Xj such that Xk and Xl are not adjacent
      # (otherwise a new v-structure or a directed cycle is created)
      for Xi in self._V:
        for Xj in tuple(C[Xi]):
          if Xi not in C[Xj]:  # not an undirected edge Xi - Xj
            continue
          if any(
            Xi in C[Xk]
            and Xi in C[Xl]  # Xi - Xk and Xi - Xl
            and arc(Xk, Xj)
            and arc(Xl, Xj)  # Xk -> Xj and Xl -> Xj
            and not adjacent(Xk, Xl)  # the rule only holds for non-adjacent Xk and Xl
            for Xk, Xl in itertools.combinations(C[Xi] - {Xj}, 2)
          ):
            if verbose:
              warnings.warn("Rule 3 applied:{0}->{1}".format(Xi, Xj))
            C[Xj].remove(Xi)
            new_oriented = True

      # stop if no more edges can be oriented
      if not new_oriented:
        break

    return C

  def Step4(self, C, verbose=False):
    """
    This function is the fourth step of PC-algo.
    Orient one remaining undirected edge Xi - Xj into Xi -> Xj when the standard deviation
    of the samples of Xi is not larger than the one of Xj.

    Parameters
    ----------
    C : Dict[NodeId, Set[NodeId]]
      The temporary skeleton (modified in place).
    verbose : bool
      Whether to report the process of Step4.

    Returns
    -------
    C : Dict[NodeId, Set[NodeId]]
      The final skeleton (of Step4).
    new_oriented : bool
      Whether there is a new edge oriented in the fourth step.
    """
    # Rule 0: Orient Xi − Xj into Xi -> Xj whenever Var(Xi) <= Var(Xj) (but only once)
    for Xi in self._V:
      for Xj in C[Xi]:
        if Xi in C[Xj] and np.std(self.id2samples[Xi]) <= np.std(self.id2samples[Xj]):
          if verbose:
            warnings.warn("Rule 0 applied:{0}->{1}".format(Xi, Xj))
          C[Xj].remove(Xi)
          # Rule 0 is applied only once per call
          return C, True

    return C, False

  def PC_algorithm(self, order, verbose=False):
    """
    This function is an advanced version of PC-algo.
    test_indep() (based on Rademacher averages) is used as the independence test.
    The edges left undirected in the skeleton C are oriented by comparing the variances of their two nodes.

    Parameters
    ----------
    order : List[NodeId]
      A particular order of the Nodes.
    verbose : bool
      Whether to report the process of the PC algorithm.

    Returns
    -------
    C : Dict[NodeId, Set[NodeId]]
      A directed graph DAG representing the causal structure.
    """
    # Step 1: Apply Adjacency_search() to obtain a skeleton C and a set of sepsets
    if verbose:
      warnings.warn("Step 1: Apply Adjacency_search() to obtain a skeleton C and a set of sepsets")
    C, sepset = self.Adjacency_search(order, verbose)

    # Step 2: Find the v-structures
    if verbose:
      warnings.warn("Step 2: Find the v-structures")
    for Xk in self._V:
      # combinations() takes a snapshot of C[Xk], so C[Xk] can be modified in the loop
      for Xi, Xj in itertools.combinations(C[Xk], 2):
        if (Xi, Xj) not in sepset or Xk in sepset[(Xi, Xj)]:
          continue
        if Xi in C[Xj] or Xj in C[Xi]:  # Xi and Xj are adjacent
          continue

        i_undirected = Xi in C[Xk] and Xk in C[Xi]  # Xi - Xk
        j_undirected = Xj in C[Xk] and Xk in C[Xj]  # Xj - Xk
        i_into_k = Xi not in C[Xk] and Xk in C[Xi]  # Xi -> Xk already
        j_into_k = Xj not in C[Xk] and Xk in C[Xj]  # Xj -> Xk already

        # Orient Xi -> Xk <- Xj (only the edges that are still undirected need a change)
        if i_undirected and j_undirected:
          to_orient = (Xi, Xj)
        elif i_undirected and j_into_k:
          to_orient = (Xi,)
        elif i_into_k and j_undirected:
          to_orient = (Xj,)
        else:
          continue

        if verbose:
          warnings.warn("V-structure found:{0}->{1}<-{2}".format(Xi, Xk, Xj))
        for parent in to_orient:
          C[Xk].remove(parent)

    # Repeat the following steps until no more edges can be oriented by Step 4
    while True:
      # Step 3: Orient as many of the remaining undirected edges as possible by repeatedly application of the following three rules
      if verbose:
        warnings.warn(
          "Step 3: Orient as many of the remaining undirected edges as possible by repeatedly application of the three rules"
        )
      C = self.three_rules(C, verbose)

      # Step 4: Orient the remaining undirected edge by comparing variances of two nodes
      if verbose:
        warnings.warn("Step 4: Orient one remaining undirected edge by comparing variances of the two nodes")
      C, new_oriented = self.Step4(C, verbose)

      # Stop if no more edges can be oriented by Step 4
      if not new_oriented:
        break

    # Return the final DAG skeleton
    return C

  def estimate_parameters(self, C):
    """
    This function is used to estimate the parameters of the CLG model.

    Parameters
    ----------
    C : Dict[NodeId, Set[NodeId]]
      A directed graph DAG representing the causal structure (node -> set of its children).

    Returns
    -------
    id2mu : Dict[NodeId, float]
      The estimated mean of each node (the intercept of the regression for a node with parents).
    id2sigma : Dict[NodeId, float]
      The estimated standard deviation of each node (of the regression residuals for a node with parents).
    arc2coef : Dict[Tuple[NodeId, NodeId], float]
      The estimated coefficients of each arc.
    """
    id2mu = {}
    id2sigma = {}
    arc2coef = {}

    # find the parents of each node
    parents = {Xi: {Xj for Xj in self._V - {Xi} if Xi in C[Xj]} for Xi in self._V}

    # breadth-first order from the roots to the leaves; each node is scheduled only once
    schedule = [Xi for Xi in self._V if len(parents[Xi]) == 0]
    scheduled = set(schedule)
    position = 0
    while position < len(schedule):
      for Xj in C[schedule[position]]:
        if Xj not in scheduled:
          scheduled.add(Xj)
          schedule.append(Xj)
      position += 1
    # nodes that cannot be reached from a root still get their parameters
    schedule.extend(Xi for Xi in self._V if Xi not in scheduled)

    for Xi in schedule:
      if len(parents[Xi]) == 0:
        # no parent: empirical mean and standard deviation of Xi
        id2mu[Xi] = np.mean(self.id2samples[Xi])
        id2sigma[Xi] = np.std(self.id2samples[Xi])
        continue

      # feature matrix (one column per parent) and target vector
      parent_list = list(parents[Xi])
      X = np.column_stack([np.asarray(self.id2samples[Xj], dtype=float) for Xj in parent_list])
      y = np.array(self.id2samples[Xi])

      # estimate the coefficients of the arcs from Xi's parents to Xi
      linear = LinearRegression()
      linear.fit(X, y)
      for j, Xj in enumerate(parent_list):
        arc2coef[(Xj, Xi)] = linear.coef_[j]

      id2mu[Xi] = linear.intercept_
      id2sigma[Xi] = np.std(y - linear.predict(X))

    return id2mu, id2sigma, arc2coef

  def learnCLG(self):
    """
    First use PC algorithm to learn the skeleton of the CLG model.
    Then estimate the parameters of the CLG model.
    Finally create a CLG model and return it.

    Means and standard deviations are rounded to 3 decimals, arc coefficients to 2 decimals.

    Returns
    -------
    learned_clg : CLG
      The learned CLG model.
    """
    learned_clg = CLG()

    # use PC algorithm to learn the structure of the CLG model
    C = self.PC_algorithm(order=self._model.nodes(), verbose=False)

    # estimate the parameters of the CLG model
    id2mu, id2sigma, arc2coef = self.estimate_parameters(C)

    # add the nodes to the CLG model
    for node in self._model.nodes():
      learned_clg.add(
        GaussianVariable(
          name=self._model.variable(node).name(), mu=float(f"{id2mu[node]:.3f}"), sigma=float(f"{id2sigma[node]:.3f}")
        )
      )

    # add the arcs to the CLG model
    for (tail, head), coef in arc2coef.items():
      learned_clg.addArc(tail, head, float(f"{coef:.2f}"))

    return learned_clg

  def fitParameters(self, clg):
    """
    In this function, we fit the parameters of the CLG model.

    Means and standard deviations are rounded to 3 decimals, arc coefficients to 2 decimals.

    Parameters
    ----------
    clg : CLG
      The CLG model whose parameters are to be changed (modified in place).
    """
    # get the DAG of the CLG model
    C = clg.dag2dict()

    # estimate the parameters of the CLG model
    id2mu, id2sigma, arc2coef = self.estimate_parameters(C)

    # change the parameters of the CLG model
    for node in clg.nodes():
      clg.setMu(node, float(f"{id2mu[node]:.3f}"))
      clg.setSigma(node, float(f"{id2sigma[node]:.3f}"))

    for (tail, head), coef in arc2coef.items():
      clg.setCoef(tail, head, float(f"{coef:.2f}"))
|