pyAgrum-nightly 2.3.0.9.dev202512061764412981__cp310-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107) hide show
  1. pyagrum/__init__.py +165 -0
  2. pyagrum/_pyagrum.so +0 -0
  3. pyagrum/bnmixture/BNMInference.py +268 -0
  4. pyagrum/bnmixture/BNMLearning.py +376 -0
  5. pyagrum/bnmixture/BNMixture.py +464 -0
  6. pyagrum/bnmixture/__init__.py +60 -0
  7. pyagrum/bnmixture/notebook.py +1058 -0
  8. pyagrum/causal/_CausalFormula.py +280 -0
  9. pyagrum/causal/_CausalModel.py +436 -0
  10. pyagrum/causal/__init__.py +81 -0
  11. pyagrum/causal/_causalImpact.py +356 -0
  12. pyagrum/causal/_dSeparation.py +598 -0
  13. pyagrum/causal/_doAST.py +761 -0
  14. pyagrum/causal/_doCalculus.py +361 -0
  15. pyagrum/causal/_doorCriteria.py +374 -0
  16. pyagrum/causal/_exceptions.py +95 -0
  17. pyagrum/causal/_types.py +61 -0
  18. pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
  19. pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
  20. pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
  21. pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
  22. pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
  23. pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
  24. pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
  25. pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
  26. pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
  27. pyagrum/causal/notebook.py +171 -0
  28. pyagrum/clg/CLG.py +658 -0
  29. pyagrum/clg/GaussianVariable.py +111 -0
  30. pyagrum/clg/SEM.py +312 -0
  31. pyagrum/clg/__init__.py +63 -0
  32. pyagrum/clg/canonicalForm.py +408 -0
  33. pyagrum/clg/constants.py +54 -0
  34. pyagrum/clg/forwardSampling.py +202 -0
  35. pyagrum/clg/learning.py +776 -0
  36. pyagrum/clg/notebook.py +480 -0
  37. pyagrum/clg/variableElimination.py +271 -0
  38. pyagrum/common.py +60 -0
  39. pyagrum/config.py +319 -0
  40. pyagrum/ctbn/CIM.py +513 -0
  41. pyagrum/ctbn/CTBN.py +573 -0
  42. pyagrum/ctbn/CTBNGenerator.py +216 -0
  43. pyagrum/ctbn/CTBNInference.py +459 -0
  44. pyagrum/ctbn/CTBNLearner.py +161 -0
  45. pyagrum/ctbn/SamplesStats.py +671 -0
  46. pyagrum/ctbn/StatsIndepTest.py +355 -0
  47. pyagrum/ctbn/__init__.py +79 -0
  48. pyagrum/ctbn/constants.py +54 -0
  49. pyagrum/ctbn/notebook.py +264 -0
  50. pyagrum/defaults.ini +199 -0
  51. pyagrum/deprecated.py +95 -0
  52. pyagrum/explain/_ComputationCausal.py +75 -0
  53. pyagrum/explain/_ComputationConditional.py +48 -0
  54. pyagrum/explain/_ComputationMarginal.py +48 -0
  55. pyagrum/explain/_CustomShapleyCache.py +110 -0
  56. pyagrum/explain/_Explainer.py +176 -0
  57. pyagrum/explain/_Explanation.py +70 -0
  58. pyagrum/explain/_FIFOCache.py +54 -0
  59. pyagrum/explain/_ShallCausalValues.py +204 -0
  60. pyagrum/explain/_ShallConditionalValues.py +155 -0
  61. pyagrum/explain/_ShallMarginalValues.py +155 -0
  62. pyagrum/explain/_ShallValues.py +296 -0
  63. pyagrum/explain/_ShapCausalValues.py +208 -0
  64. pyagrum/explain/_ShapConditionalValues.py +126 -0
  65. pyagrum/explain/_ShapMarginalValues.py +191 -0
  66. pyagrum/explain/_ShapleyValues.py +298 -0
  67. pyagrum/explain/__init__.py +81 -0
  68. pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
  69. pyagrum/explain/_explIndependenceListForPairs.py +146 -0
  70. pyagrum/explain/_explInformationGraph.py +264 -0
  71. pyagrum/explain/notebook/__init__.py +54 -0
  72. pyagrum/explain/notebook/_bar.py +142 -0
  73. pyagrum/explain/notebook/_beeswarm.py +174 -0
  74. pyagrum/explain/notebook/_showShapValues.py +97 -0
  75. pyagrum/explain/notebook/_waterfall.py +220 -0
  76. pyagrum/explain/shapley.py +225 -0
  77. pyagrum/lib/__init__.py +46 -0
  78. pyagrum/lib/_colors.py +390 -0
  79. pyagrum/lib/bn2graph.py +299 -0
  80. pyagrum/lib/bn2roc.py +1026 -0
  81. pyagrum/lib/bn2scores.py +217 -0
  82. pyagrum/lib/bn_vs_bn.py +605 -0
  83. pyagrum/lib/cn2graph.py +305 -0
  84. pyagrum/lib/discreteTypeProcessor.py +1102 -0
  85. pyagrum/lib/discretizer.py +58 -0
  86. pyagrum/lib/dynamicBN.py +390 -0
  87. pyagrum/lib/explain.py +57 -0
  88. pyagrum/lib/export.py +84 -0
  89. pyagrum/lib/id2graph.py +258 -0
  90. pyagrum/lib/image.py +387 -0
  91. pyagrum/lib/ipython.py +307 -0
  92. pyagrum/lib/mrf2graph.py +471 -0
  93. pyagrum/lib/notebook.py +1821 -0
  94. pyagrum/lib/proba_histogram.py +552 -0
  95. pyagrum/lib/utils.py +138 -0
  96. pyagrum/pyagrum.py +31495 -0
  97. pyagrum/skbn/_MBCalcul.py +242 -0
  98. pyagrum/skbn/__init__.py +49 -0
  99. pyagrum/skbn/_learningMethods.py +282 -0
  100. pyagrum/skbn/_utils.py +297 -0
  101. pyagrum/skbn/bnclassifier.py +1014 -0
  102. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSE.md +12 -0
  103. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
  104. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/MIT.txt +18 -0
  105. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/METADATA +145 -0
  106. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/RECORD +107 -0
  107. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/WHEEL +4 -0
@@ -0,0 +1,1175 @@
1
+ ############################################################################
2
+ # This file is part of the aGrUM/pyAgrum library. #
3
+ # #
4
+ # Copyright (c) 2005-2025 by #
5
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
6
+ # - Christophe GONZALES(_at_AMU) #
7
+ # #
8
+ # The aGrUM/pyAgrum library is free software; you can redistribute it #
9
+ # and/or modify it under the terms of either : #
10
+ # #
11
+ # - the GNU Lesser General Public License as published by #
12
+ # the Free Software Foundation, either version 3 of the License, #
13
+ # or (at your option) any later version, #
14
+ # - the MIT license (MIT), #
15
+ # - or both in dual license, as here. #
16
+ # #
17
+ # (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
18
+ # #
19
+ # This aGrUM/pyAgrum library is distributed in the hope that it will be #
20
+ # useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
21
+ # INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
22
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
23
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
24
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
25
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
26
+ # OTHER DEALINGS IN THE SOFTWARE. #
27
+ # #
28
+ # See LICENCES for more details. #
29
+ # #
30
+ # SPDX-FileCopyrightText: Copyright 2005-2025 #
31
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
32
+ # - Christophe GONZALES(_at_AMU) #
33
+ # SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
34
+ # #
35
+ # Contact : info_at_agrum_dot_org #
36
+ # homepage : http://agrum.gitlab.io #
37
+ # gitlab : https://gitlab.com/agrumery/agrum #
38
+ # #
39
+ ############################################################################
40
+
41
+ import pandas as pd
42
+ import numpy as np
43
+
44
+ from typing import Any
45
+
46
+ import pyagrum.causal as csl
47
+
48
+ from ._utils import (
49
+ MisspecifiedAdjustmentError,
50
+ EmptyConditionError,
51
+ InvalidConditionError,
52
+ RCTError,
53
+ BackdoorError,
54
+ FrontdoorError,
55
+ IVError,
56
+ RCT,
57
+ generalizedFrontDoor,
58
+ instrumentalVariable,
59
+ RCT_ESTIMATORS_LIST,
60
+ BACKDOOR_ESTIMATORS_LIST,
61
+ FRONTDOOR_ESTIMATORS_LIST,
62
+ IV_ESTIMATORS_LIST,
63
+ )
64
+
65
+ from ._causalBNEstimator import CausalBNEstimator
66
+
67
+ from ._RCTEstimators import DM
68
+
69
+ from ._backdoorEstimators import SLearner, TLearner, XLearner, PStratification, IPW
70
+ from ._frontdoorEstimators import SimplePlugIn, GeneralizedPlugIn
71
+ from ._IVEstimators import Wald, WaldIPW, NormalizedWaldIPW, TSLS
72
+
73
+
74
+ class CausalEffectEstimation:
75
+ """
76
+ Estimates causal effects using a dataset and a causal graph within
77
+ the Neyman-Rubin Potential Outcomes framework.
78
+
79
+ This class performs causal identification based on user-specified datasets
80
+ and causal graphical models. It determines the appropriate adjustment method
81
+ — such as backdoor, front-door, or instrumental variables (IV) — to optimally
82
+ estimate the causal effect (or treatment effect) between the intervention
83
+ (treatment assignment) and the outcome.
84
+
85
+ The class integrates domain-specific statistical estimators and recent
86
+ advancements in machine learning techniques to estimate various causal
87
+ effects, including the Average Causal Effect (ACE), Conditional Average
88
+ Causal Effect (CACE), Individual Causal Effect (ICE), and Local Average
89
+ Treatment Effect (LATE), among others.
90
+
91
+ This module is inspired by the works of
92
+ :cite:t:`wager2020stats` and :cite:t:`neal2020introduction`.
93
+
94
+ Raises
95
+ ------
96
+ AssertionError
97
+ If the input dataframe is empty, indicating that predictions
98
+ cannot be made.
99
+ ValueError
100
+ If the provided estimator_string does not correspond to any
101
+ supported estimator.
102
+ """
103
+
104
+ _RCT = "Randomized Controlled Trial"
105
+ _BACKDOOR = "Backdoor"
106
+ _FRONTDOOR = "Generalized Frontdoor"
107
+ _IV = "Generalized Instrumental Variable"
108
+ _UNKNOWN = "Unknown"
109
+
110
def __init__(self, df: pd.DataFrame, causal_model: csl.CausalModel) -> None:
    """
    Bind the estimator to a dataset and a causal model.

    No adjustment is selected here; every variable role starts unset until
    one of the ``use*Adjustment`` methods or ``identifyAdjustmentSet`` is
    called.

    Parameters
    ----------
    df: pd.DataFrame
        The dataset for causal effect estimation.
    causal_model: csl.CausalModel
        The causal model for causal effect identification.
    """

    self._df, self._causal_model = df, causal_model

    # Name of the selected adjustment (None until one is chosen).
    self._adjustment = None
    # str roles: intervention (treatment), outcome, instrumental variable.
    self._T = self._y = self._w = None
    # set[str] roles: mediator set, confounder/covariate set.
    self._M = self._X = None
    # Any: the causal estimator, created by the fit* methods.
    self._estimator = None
131
+
132
def __str__(self) -> str:
    """
    Describe the instance in a human-readable, multi-line form.

    The text contains the object identity, a summary of the DataFrame
    (shape, columns, memory usage in MB), a summary of the causal model
    (names, causal BN, observational BN) and, for each variable role that
    is currently set, one extra line.

    Returns
    -------
    str
        The formatted description.
    """

    cls = type(self)
    frame = self._df
    model = self._causal_model
    frame_cls = frame.__class__

    header = f"<{cls.__module__}.{cls.__name__} object at {hex(id(self))}>"

    frame_summary = "".join(
        [
            f"<{frame_cls.__module__}.{frame_cls.__name__} object at {hex(id(frame))}>",
            f"\n\t- shape\t\t: {frame.shape}",
            f"\n\t- columns\t: {frame.columns}",
            f"\n\t- memory usage\t: {frame.memory_usage().sum() / 1e6} MB",
        ]
    )

    model_summary = "".join(
        [
            f"{model}",
            f"\n\t- names\t\t: {model.names()}",
            f"\n\t- causal BN\t: {model.causalBN()}",
            f"\n\t- observ. BN\t: {model.observationalBN()}",
        ]
    )

    pieces = [f"{header}\n\n Dataframe\t: {frame_summary}\n Causal Model\t: {model_summary}"]

    # Only roles that have been assigned are listed, in this fixed order.
    optional_fields = (
        ("Adjustment", self._adjustment),
        ("Intervention", self._T),
        ("Outcome", self._y),
        ("Confounders", self._X),
        ("Mediators", self._M),
        ("Instrument", self._w),
        ("Estimator", self._estimator),
    )
    pieces.extend(f"\n {label}\t: {value}" for label, value in optional_fields if value is not None)

    return "".join(pieces)
182
+
183
+ # Causal identification
184
+
185
+ def useRCTAdjustment(self, intervention: str, outcome: str, confounders: set[str]) -> None:
186
+ """
187
+ Specify the Randomized Controlled Trial (RCT) Adjustment.
188
+
189
+ Note: This method does not verify if the specified adjustment is
190
+ appropriate within the causal graph. If unsure, use
191
+ `.identifyAdjustment()` to automatically determine the correct
192
+ adjustment set.
193
+
194
+ Parameters
195
+ ----------
196
+ intervention: str
197
+ Intervention (or treatment) variable.
198
+ outcome: str
199
+ Outcome variable.
200
+ confounders: set[str] or None
201
+ Set of confounder variables (or covariates).
202
+ """
203
+
204
+ self._adjustment = self._RCT
205
+ self._T = intervention
206
+ self._y = outcome
207
+ self._X = set() if confounders is None else confounders
208
+ self._w = None
209
+ self._M = None
210
+ self._estimator = None
211
+
212
+ def useBackdoorAdjustment(self, intervention: str, outcome: str, confounders: set[str]) -> None:
213
+ """
214
+ Specify the Backdoor Adjustment.
215
+
216
+ Note: This method does not verify if the specified adjustment is
217
+ appropriate within the causal graph. If unsure, use
218
+ `.identifyAdjustment()` to automatically determine the correct
219
+ adjustment set.
220
+
221
+ Parameters
222
+ ----------
223
+ intervention: str
224
+ Intervention (or treatment) variable.
225
+ outcome: str
226
+ Outcome variable.
227
+ confounders: set[str] or None
228
+ Set of confounder variables (or covariates).
229
+ """
230
+
231
+ self._adjustment = self._BACKDOOR
232
+ self._T = intervention
233
+ self._y = outcome
234
+ self._X = set() if confounders is None else confounders
235
+ self._w = None
236
+ self._M = None
237
+ self._estimator = None
238
+
239
+ def useFrontdoorAdjustment(
240
+ self, intervention: str, outcome: str, mediators: set[str], confounders: set[str] | None = None
241
+ ) -> None:
242
+ """
243
+ Specify the (General) Frontdoor Adjustment.
244
+ :cite:t:`guo2023targeted`.
245
+
246
+ Note: This method does not verify if the specified adjustment is
247
+ appropriate within the causal graph. If unsure, use
248
+ `.identifyAdjustment()` to automatically determine the correct
249
+ adjustment set.
250
+
251
+ Parameters
252
+ ----------
253
+ intervention: str
254
+ Intervention (or treatment) variable.
255
+ outcome: str
256
+ Outcome variable.
257
+ mediators: set[str]
258
+ Mediator variables.
259
+ confounders: set[str] or None, optional
260
+ Set of confounder variables (or covariates).
261
+ """
262
+
263
+ self._adjustment = self._FRONTDOOR
264
+ self._T = intervention
265
+ self._y = outcome
266
+ self._M = mediators
267
+ self._X = set() if confounders is None else confounders
268
+ self._w = None
269
+ self._estimator = None
270
+
271
+ def useIVAdjustment(
272
+ self, intervention: str, outcome: str, instrument: str, confounders: set[str] | None = None
273
+ ) -> None:
274
+ """
275
+ Specify the (Generalized) Instrumental Variable Adjustment.
276
+ :cite:t:`brito2012generalized`,
277
+ :cite:t:`van2015efficiently`.
278
+
279
+
280
+ Note: This method does not verify if the specified adjustment is
281
+ appropriate within the causal graph. If unsure, use
282
+ `.identifyAdjustment()` to automatically determine the correct
283
+ adjustment set.
284
+
285
+ Parameters
286
+ ----------
287
+ intervention: str
288
+ Intervention (or treatment) variable.
289
+ outcome: str
290
+ Outcome variable.
291
+ instruments: str
292
+ Instrumental variable.
293
+ confounders: set[str] or None, optional
294
+ Set of confounder variables (or covariates).
295
+ """
296
+
297
+ self._adjustment = self._IV
298
+ self._T = intervention
299
+ self._y = outcome
300
+ self._w = instrument
301
+ self._X = set() if confounders is None else confounders
302
+ self._estimator = None
303
+
304
def useUnknownAdjustment(
    self,
    intervention: str,
    outcome: str,
) -> None:
    """
    Specify an Unknown Adjustment.

    Only the intervention and outcome are recorded. Covariates, mediators,
    the instrument and any fitted estimator are cleared, so roles from a
    previously selected adjustment cannot leak into later estimation
    (e.g. the instrument passed to the Causal BN estimator).

    Note: This method does not verify if the specified adjustment is
    appropriate within the causal graph. If unsure, use
    `.identifyAdjustmentSet()` to automatically determine the correct
    adjustment set.

    Parameters
    ----------
    intervention: str
        Intervention (or treatment) variable.
    outcome: str
        Outcome variable.
    """

    self._adjustment = self._UNKNOWN
    self._T = intervention
    self._y = outcome
    # No adjustment set is known: reset every other role to its initial state.
    self._X = None
    self._M = None
    self._w = None
    self._estimator = None
328
+
329
def identifyAdjustmentSet(self, intervention: str, outcome: str, verbose: bool = True) -> str:
    """
    Identify a sufficient adjustment for the effect of ``intervention`` on
    ``outcome`` and select it on this instance.

    The criteria are tried in this order, and the first one that succeeds
    is selected through the matching ``use*Adjustment`` method:
    Randomized Controlled Trial, Backdoor, Generalized Frontdoor,
    Generalized Instrumental Variable. When none succeeds, the Unknown
    adjustment is selected (no exception is raised in that case).

    Parameters
    ----------
    intervention: str
        Intervention (treatment) variable.
    outcome: str
        Outcome variable.
    verbose: bool
        If True, prints the estimators that can be used with the
        selected adjustment. Default is True.

    Returns
    -------
    str
        The name of the selected adjustment.

    Raises
    ------
    ValueError
        If the treatment column does not take exactly the values 0 and 1.
    """

    if set(self._df[intervention].unique()) != {0, 1}:
        raise ValueError(
            "Treatment must be binary with values 0 and 1.\n"
            "Please make sure that the datatype is `int` with "
            "the positivity assumption satisfied "
            "(i.e. there is at least one occurence of both 0 and 1)."
        )

    def _found(name: str, estimators: str) -> str:
        # Common text printed when an adjustment has been found.
        return name + " adjustment found. \n\n" + "Supported estimators include:" + estimators

    def _select() -> str:
        # Criteria are evaluated lazily: later (potentially costly) graph
        # searches are skipped as soon as an earlier criterion succeeds.
        rct = RCT(self._causal_model, intervention, outcome)
        if rct is not None:
            self.useRCTAdjustment(intervention, outcome, rct)
            return _found(self._RCT, RCT_ESTIMATORS_LIST) + (
                "\nIf the outcome variable is a cause of other covariates "
                "in the causal graph,\nBackdoor estimators may also be used."
            )

        backdoor = self._causal_model.backDoor(intervention, outcome)
        if backdoor is not None:
            self.useBackdoorAdjustment(intervention, outcome, backdoor)
            return _found(self._BACKDOOR, BACKDOOR_ESTIMATORS_LIST)

        frontdoor, fd_covariates = generalizedFrontDoor(self._causal_model, intervention, outcome)
        if frontdoor is not None:
            self.useFrontdoorAdjustment(intervention, outcome, frontdoor, fd_covariates)
            return _found(self._FRONTDOOR, FRONTDOOR_ESTIMATORS_LIST)

        instrumental_variable, iv_covariates = instrumentalVariable(self._causal_model, intervention, outcome)
        if instrumental_variable is not None:
            self.useIVAdjustment(intervention, outcome, instrumental_variable, iv_covariates)
            return _found(self._IV, IV_ESTIMATORS_LIST)

        self.useUnknownAdjustment(intervention, outcome)
        return (
            "No adjustment set found among: "
            "RCT, Backdoor, Generalized Frontdoor, or Generalized IV.\n\n"
            "The only supported estimator without a known adjustment is "
            "the Causal Bayesian Network Estimator, which can estimate "
            "the causal effect if identifiable using do-Calculus.\n"
            "Use `.fitCausalBNEstimator()` to apply this estimator."
        )

    suggestion_text = _select()

    if verbose:
        print(suggestion_text)

    return self._adjustment
406
+
407
+ # Model fitting
408
+
409
def _fitEstimator(self, **fit_params) -> Any:
    """
    Fit ``self._estimator`` on the stored dataset using the variable roles
    of the currently selected adjustment.

    The estimator's ``fit`` is called with the keyword arguments ``X``
    (covariate columns), ``treatment`` and ``y``. The frontdoor adjustment
    also passes ``M`` (mediator columns), and the IV adjustment also passes
    ``w`` (instrument columns). If the IV call raises ``TypeError``, it is
    retried with the instrument passed as ``assignment`` instead of ``w``.

    Parameters
    ----------
    **fit_params: Any
        Additional keyword arguments forwarded to the estimator's ``fit``.

    Returns
    -------
    Any
        Whatever the estimator's ``fit`` method returns.

    Raises
    ------
    MisspecifiedAdjustmentError
        If no supported adjustment (RCT, Backdoor, Frontdoor or IV) has
        been selected before fitting an estimator.
    """

    if self._adjustment not in (self._RCT, self._BACKDOOR, self._FRONTDOOR, self._IV):
        raise MisspecifiedAdjustmentError("fitting an estimator")

    df = self._df
    # Passed through `**` below (not merged into one dict) so that a
    # duplicate key in fit_params still raises TypeError instead of silently
    # overriding the data columns.
    fit_kwargs = {
        "X": df[[*self._X]],
        "treatment": df[self._T],
        "y": df[self._y],
    }

    if self._adjustment == self._FRONTDOOR:
        fit_kwargs["M"] = df[[*self._M]]

    elif self._adjustment == self._IV:
        # The instrument is documented as a single column name; unpacking a
        # str with `[*...]` would split it into characters, so wrap it.
        # Iterables of names are still accepted as before.
        instrument_cols = [self._w] if isinstance(self._w, str) else [*self._w]
        instruments = df[instrument_cols]
        try:
            return self._estimator.fit(**fit_kwargs, w=instruments, **fit_params)
        except TypeError:
            # Estimators that do not accept `w` take the instrument as
            # `assignment` instead.
            return self._estimator.fit(**fit_kwargs, assignment=instruments, **fit_params)

    return self._estimator.fit(**fit_kwargs, **fit_params)
470
+
471
+ # Custom estimators
472
+
473
def fitCausalBNEstimator(self) -> Any:
    """
    Fit the Causal Bayesian Network Estimator.

    The estimator is built from the causal model together with the current
    intervention, outcome and instrument, then fitted on the stored
    dataset. According to the estimator's design, it relies on
    do-calculus identification and lazy propagation inference from the
    pyAgrum causal module.

    Note: In the case of instrumental variables, the causal effect is
    estimated using heuristic methods, as this adjustment is not
    identifiable through do-calculus.
    """

    bn_estimator = CausalBNEstimator(self._causal_model, self._T, self._y, self._w)
    self._estimator = bn_estimator
    bn_estimator.fit(self._df)
489
+
490
def fitCustomEstimator(self, estimator: Any) -> Any:
    """
    Install a user-supplied estimator and fit it with the current adjustment.

    The object must expose `.fit()`, `.predict()` and `.estimate_ate()` in
    the style of CausalML estimators.
    :cite:t:`chen2020causalml`.

    Note: compatibility with the current adjustment is not guaranteed.

    Parameters
    ----------
    estimator: Any
        Estimator object following the CausalML method declarations.
    """

    # _fitEstimator reads self._estimator, so it must be stored first.
    self._estimator = estimator
    self._fitEstimator()
508
+
509
+ # RCT
510
+
511
def fitDM(self) -> Any:
    """
    Fit the Difference in Means (DM) Estimator.

    The DM estimator computes the Average Causal Effect (ACE) under the
    ignorability assumption of Randomized Controlled Trials (RCT) by taking
    the difference of the mean outcome between the treated and untreated
    populations.

    Raises
    ------
    BackdoorError, FrontdoorError, IVError
        If the selected adjustment is Backdoor, Frontdoor or IV
        respectively. Fitting itself raises MisspecifiedAdjustmentError
        when no supported adjustment is selected.
    """

    # Adjustment names are strings: compare by value, not identity.
    if self._adjustment == self._BACKDOOR:
        raise BackdoorError("DM")
    if self._adjustment == self._FRONTDOOR:
        raise FrontdoorError("DM")
    if self._adjustment == self._IV:
        raise IVError("DM")

    self._estimator = DM()
    self._fitEstimator()
530
+
531
+ # Backdoor
532
+
533
def fitSLearner(self, **estimator_params) -> Any:
    """
    Fit the S-Learner Estimator.

    A basic implementation of the S-learner based on Kunzel et al. (2018)
    :cite:t:`kunzel2019metalearners`.

    Parameters
    ----------
    **estimator_params: Any
        Forwarded to the SLearner constructor, e.g.:
        learner: str or Any, optional
            Base estimator for all learners.
            If not provided, defaults to LinearRegression.

    Raises
    ------
    IVError, FrontdoorError
        If the selected adjustment is IV or Frontdoor.
    RCTError
        If no confounders/covariates are set.
    """

    # Adjustment names are strings: compare by value, not identity.
    if self._adjustment == self._IV:
        raise IVError("SLearner")
    if self._adjustment == self._FRONTDOOR:
        raise FrontdoorError("SLearner")
    # This learner conditions on covariates, so at least one is required.
    if self._X is None or len(self._X) == 0:
        raise RCTError("SLearner")

    self._estimator = SLearner(**estimator_params)
    self._fitEstimator()
556
+
557
def fitTLearner(self, **estimator_params) -> Any:
    """
    Fit the T-Learner Estimator.

    A basic implementation of the T-learner based on Kunzel et al. (2018)
    :cite:t:`kunzel2019metalearners`.

    Parameters
    ----------
    **estimator_params: Any
        Forwarded to the TLearner constructor, e.g.:
        learner: str or Any, optional
            Base estimator for all learners.
            If not provided, defaults to LinearRegression.
        control_learner: str or Any, optional
            Estimator for control group outcome.
            Overrides `learner` if specified.
        treatment_learner: str or Any, optional
            Estimator for treatment group outcome.
            Overrides `learner` if specified.

    Raises
    ------
    IVError, FrontdoorError
        If the selected adjustment is IV or Frontdoor.
    RCTError
        If no confounders/covariates are set.
    """

    # Adjustment names are strings: compare by value, not identity.
    if self._adjustment == self._IV:
        raise IVError("TLearner")
    if self._adjustment == self._FRONTDOOR:
        raise FrontdoorError("TLearner")
    # This learner conditions on covariates, so at least one is required.
    if self._X is None or len(self._X) == 0:
        raise RCTError("TLearner")

    self._estimator = TLearner(**estimator_params)
    self._fitEstimator()
586
+
587
def fitXLearner(self, **estimator_params) -> Any:
    """
    Fit the X-Learner Estimator.

    A basic implementation of the X-learner based on Kunzel et al. (2018)
    :cite:t:`kunzel2019metalearners`.

    Parameters
    ----------
    **estimator_params: Any
        Forwarded to the XLearner constructor, e.g.:
        learner: str or Any, optional
            Base estimator for all learners.
            If not provided, defaults to LinearRegression.
        control_outcome_learner: str or Any, optional
            Estimator for control group outcome.
            Overrides `learner` if specified.
        treatment_outcome_learner: str or Any, optional
            Estimator for treatment group outcome.
            Overrides `learner` if specified.
        control_effect_learner: str or Any, optional
            Estimator for control group effect.
            Overrides `learner` if specified.
        treatment_effect_learner: str or Any, optional
            Estimator for treatment group effect.
            Overrides `learner` if specified.
        propensity_score_learner: str or Any, optional
            Estimator for propensity score.
            If not provided, defaults to LogisticRegression.

    Raises
    ------
    IVError, FrontdoorError
        If the selected adjustment is IV or Frontdoor.
    RCTError
        If no confounders/covariates are set.
    """

    # Adjustment names are strings: compare by value, not identity.
    if self._adjustment == self._IV:
        raise IVError("XLearner")
    if self._adjustment == self._FRONTDOOR:
        raise FrontdoorError("XLearner")
    # This learner conditions on covariates, so at least one is required.
    if self._X is None or len(self._X) == 0:
        raise RCTError("XLearner")

    self._estimator = XLearner(**estimator_params)
    self._fitEstimator()
625
+
626
def fitPStratification(self, **estimator_params) -> Any:
    """
    Fit the Propensity score Stratification Estimator.

    A basic implementation of Propensity Stratification estimator
    based on Lunceford et al. (2004)
    :cite:t:`lunceford2004stratification`.

    Parameters
    ----------
    **estimator_params: Any
        Forwarded to the PStratification constructor, e.g.:
        propensity_score_learner: str or Any, optional
            Estimator for propensity score.
            If not provided, defaults to LogisticRegression.
        num_strata: int, optional
            The number of strata.
            Default is 100.

    Raises
    ------
    IVError, FrontdoorError
        If the selected adjustment is IV or Frontdoor.
    RCTError
        If no confounders/covariates are set.
    """

    # Adjustment names are strings: compare by value, not identity.
    if self._adjustment == self._IV:
        raise IVError("PStratification")
    if self._adjustment == self._FRONTDOOR:
        raise FrontdoorError("PStratification")
    # The propensity score is learned from covariates, so at least one is required.
    if self._X is None or len(self._X) == 0:
        raise RCTError("PStratification")

    self._estimator = PStratification(**estimator_params)
    self._fitEstimator()
653
+
654
def fitIPW(self, **estimator_params) -> Any:
    """
    Fit the Inverse Propensity score Weighting Estimator.

    A basic implementation of the Inverse Propensity Score Weighting (IPW)
    estimator based on Lunceford et al. (2004)
    :cite:t:`lunceford2004stratification`.

    Parameters
    ----------
    **estimator_params: Any
        Forwarded to the IPW constructor, e.g.:
        propensity_score_learner: str or Any, optional
            Estimator for propensity score.
            If not provided, defaults to LogisticRegression.

    Raises
    ------
    IVError, FrontdoorError
        If the selected adjustment is IV or Frontdoor.
    RCTError
        If no confounders/covariates are set.
    """

    # Adjustment names are strings: compare by value, not identity.
    if self._adjustment == self._IV:
        raise IVError("IPW")
    if self._adjustment == self._FRONTDOOR:
        raise FrontdoorError("IPW")
    # The propensity score is learned from covariates, so at least one is required.
    if self._X is None or len(self._X) == 0:
        raise RCTError("IPW")

    self._estimator = IPW(**estimator_params)
    self._fitEstimator()
678
+
679
+ # Frontdoor
680
+
681
def fitSimplePlugIn(self, **estimator_params) -> Any:
    """
    Fit the Plug-in Estimator.

    Uses the (original) Frontdoor Adjustment Formula to derive
    the plug-in estimator. Does not account for covariates.
    Inspired by Guo et al. (2023).

    :cite:t:`fulcher2020robust`,
    :cite:t:`guo2023targeted`.

    Parameters
    ----------
    **estimator_params: Any
        Forwarded to the SimplePlugIn constructor, e.g.:
        learner: str or object, optional
            Estimator for outcome variable.
            If not provided, defaults to LinearRegression.
        propensity_learner: str or object, optional
            Estimator for treatment probability.
            If not provided, defaults to LogisticRegression.

    Raises
    ------
    RCTError, BackdoorError, IVError
        If the selected adjustment is RCT, Backdoor or IV respectively.
    """

    # Adjustment names are strings: compare by value, not identity.
    if self._adjustment == self._RCT:
        raise RCTError("PlugIn")
    if self._adjustment == self._BACKDOOR:
        raise BackdoorError("PlugIn")
    if self._adjustment == self._IV:
        raise IVError("PlugIn")

    self._estimator = SimplePlugIn(**estimator_params)
    self._fitEstimator()
711
+
712
def fitGeneralizedPlugIn(self, **estimator_params) -> Any:
    """
    Fit the Generalized plug-in Estimator.

    Basic implementation of the second plug-in estimator in
    Guo et al. (2023). Covariates must be provided.
    :cite:t:`fulcher2020robust`,
    :cite:t:`guo2023targeted`.

    Parameters
    ----------
    **estimator_params: Any
        Forwarded to the GeneralizedPlugIn constructor.

    Raises
    ------
    RCTError, BackdoorError, IVError
        If the selected adjustment is RCT, Backdoor or IV respectively.
    """

    # Adjustment names are strings: compare by value, not identity.
    if self._adjustment == self._RCT:
        raise RCTError("GeneralizedPlugIn")
    if self._adjustment == self._BACKDOOR:
        raise BackdoorError("GeneralizedPlugIn")
    if self._adjustment == self._IV:
        raise IVError("GeneralizedPlugIn")

    self._estimator = GeneralizedPlugIn(**estimator_params)
    self._fitEstimator()
736
+
737
+ # Instrumental Variable
738
+
739
def fitWald(self) -> Any:
    """
    Fit the Wald Estimator.

    An implementation of the Wald estimator which computes the
    Local Average Causal Effect (LACE), also known as the Local Average
    Treatment Effect (LATE).
    Only supports binary instruments.

    Raises
    ------
    RCTError, BackdoorError, FrontdoorError
        If the selected adjustment is RCT, Backdoor or Frontdoor
        respectively.
    """

    # Adjustment names are strings: compare by value, not identity.
    if self._adjustment == self._RCT:
        raise RCTError("Wald")
    if self._adjustment == self._BACKDOOR:
        raise BackdoorError("Wald")
    if self._adjustment == self._FRONTDOOR:
        raise FrontdoorError("Wald")

    self._estimator = Wald()
    self._fitEstimator()
758
+
759
def fitWaldIPW(self, **estimator_params) -> Any:
    """
    Fit the Wald Inverse Probability Weighting Estimator.

    A basic implementation of the Wald estimand with Inverse Propensity
    Score Weighting which computes the Local Average Causal Effect (LACE).
    Only supports binary instruments.
    :cite:t:`choi2021instrumental`.

    Parameters
    ----------
    **estimator_params: Any
        Forwarded to the WaldIPW constructor, e.g.:
        iv_probability_learner: str or Any, optional
            Estimator for instrumental variable probability.
            If not provided, defaults to LogisticRegression.

    Raises
    ------
    RCTError, BackdoorError, FrontdoorError
        If the selected adjustment is RCT, Backdoor or Frontdoor
        respectively.
    """

    # Adjustment names are strings: compare by value, not identity.
    if self._adjustment == self._RCT:
        raise RCTError("WaldIPW")
    if self._adjustment == self._BACKDOOR:
        raise BackdoorError("WaldIPW")
    if self._adjustment == self._FRONTDOOR:
        raise FrontdoorError("WaldIPW")

    self._estimator = WaldIPW(**estimator_params)
    self._fitEstimator()
784
+
785
def fitNormalizedWaldIPW(self, **estimator_params) -> Any:
    """
    Fit the Normalized Wald Inverse Probability Weighting Estimator.

    A basic implementation of the normalized Wald estimator with Inverse
    Propensity Score Weighting which computes the Local Average Causal
    Effect (LACE).
    Only supports binary instruments.
    :cite:t:`choi2021instrumental`.

    Parameters
    ----------
    **estimator_params: Any
        Forwarded to the NormalizedWaldIPW constructor, e.g.:
        iv_probability_learner: str or Any, optional
            Estimator for instrumental variable probability.
            If not provided, defaults to LogisticRegression.

    Raises
    ------
    RCTError, BackdoorError, FrontdoorError
        If the selected adjustment is RCT, Backdoor or Frontdoor
        respectively.
    """

    # Adjustment names are strings: compare by value, not identity.
    if self._adjustment == self._RCT:
        raise RCTError("NormalizedWaldIPW")
    if self._adjustment == self._BACKDOOR:
        raise BackdoorError("NormalizedWaldIPW")
    if self._adjustment == self._FRONTDOOR:
        raise FrontdoorError("NormalizedWaldIPW")

    self._estimator = NormalizedWaldIPW(**estimator_params)
    self._fitEstimator()
810
+
811
def fitTSLS(self, **estimator_params) -> Any:
    """
    Fit the Two Stage Least Squares Estimator.

    Builds a `TSLS` estimator from the supplied keyword arguments and fits
    it. This is a basic implementation of the Two Stage Least-Squares
    Estimator; only linear models exposing a `.coef_` attribute are
    supported.
    :cite:t:`angrist1995two`.

    Parameters
    ----------
    learner: str or Any, optional
        Base estimator for all learners.
        If not provided, defaults to LinearRegression.
    treatment_learner: str or Any, optional
        Estimator for treatment assignment.
        Overrides `learner` if specified.
    outcome_learner: str or Any, optional
        Estimator for outcome.
        Overrides `learner` if specified.

    Raises
    ------
    RCTError, BackdoorError, FrontdoorError
        If the selected adjustment is, respectively, the RCT, backdoor or
        frontdoor adjustment.
    """
    current = self._adjustment
    if current is self._RCT:
        error_cls = RCTError
    elif current is self._BACKDOOR:
        error_cls = BackdoorError
    elif current is self._FRONTDOOR:
        error_cls = FrontdoorError
    else:
        error_cls = None

    if error_cls is not None:
        raise error_cls("TSLS")

    self._estimator = TSLS(**estimator_params)
    self._fitEstimator()
841
+
842
+ # Estimation
843
+
844
+ def estimateCausalEffect(
845
+ self, conditional: pd.DataFrame | str | None = None, **estimation_params: Any
846
+ ) -> float | np.ndarray:
847
+ """
848
+ Estimate the causal or treatment effect based on the initialized data.
849
+
850
+ Parameters
851
+ ----------
852
+ conditional: pd.DataFrame, str, or None, optional
853
+ Specifies conditions for estimating causal effects.
854
+
855
+ - If `pd.DataFrame`, estimates the Individual Causal Effect (ICE)
856
+ for each row.
857
+ - If `str`, estimates the Conditional Average Causal Effect (CACE).
858
+ The string must be a valid pandas query.
859
+ - If `None`, estimates the Average Causal Effect (ACE).
860
+ Default is `None`.
861
+
862
+ estimation_params: dict of str to Any, optional
863
+ Additional parameters for the estimation method.
864
+ Keys are parameter names, and values are the corresponding
865
+ parameter values. Default is `None`.
866
+
867
+ Returns
868
+ -------
869
+ float or np.ndarray
870
+ If `return_ci` is `False`, returns the estimated causal effect
871
+ as a float.
872
+
873
+ If `return_ci` is `True`, returns a tuple containing:
874
+ - The estimated causal effect (float)
875
+ - The lower and upper bounds of the confidence interval
876
+ (tuple of floats)
877
+
878
+ Raises
879
+ ------
880
+ ValueError
881
+ No adjustment have been selected before making the estimate.
882
+ """
883
+
884
+ assert self._estimator is not None, "Please fit an estimator before attempting to make an estimate."
885
+
886
+ match self._adjustment:
887
+ case self._IV:
888
+ return self._estimateIVCausalEffect(conditional, **estimation_params)
889
+ case self._FRONTDOOR:
890
+ return self._estimateFrontdoorCausalEffect(conditional, **estimation_params)
891
+ case self._BACKDOOR:
892
+ return self._estimateBackdoorCausalEffect(conditional, **estimation_params)
893
+ case self._RCT:
894
+ return self._estimateRCTCausalEffect(conditional, **estimation_params)
895
+ case self._UNKNOWN:
896
+ return self._estimateUnknownCausalEffect(conditional, **estimation_params)
897
+ case _:
898
+ raise MisspecifiedAdjustmentError("making an estimate")
899
+
900
+ def _estimateRCTCausalEffect(
901
+ self, conditional: pd.DataFrame | str | None = None, **estimation_params: Any
902
+ ) -> float | np.ndarray:
903
+ """
904
+ Estimate the causal or treatment effect using RCT adjustment.
905
+
906
+ Parameters
907
+ ----------
908
+ conditional: pd.DataFrame, str, or None, optional
909
+ Conditions for estimating causal effects.
910
+ estimation_params: dict[str, Any], optional
911
+ Additional parameters for the estimation method.
912
+
913
+ Returns
914
+ -------
915
+ float or np.ndarray
916
+ The estimated causal effect.
917
+
918
+ Raises
919
+ ------
920
+ ValueError
921
+ The inputed conditional is invalid.
922
+ """
923
+
924
+ assert self._estimator is not None, "Please fit an estimator before attempting to make an estimate."
925
+
926
+ if estimation_params is None:
927
+ estimation_params = dict()
928
+ # ICE
929
+ if isinstance(conditional, pd.DataFrame):
930
+ conditional = pd.DataFrame(conditional)
931
+ return self._estimator.predict(
932
+ X=conditional[[*self._X]], treatment=conditional[self._T], y=conditional[self._y], **estimation_params
933
+ )
934
+ # CACE
935
+ elif isinstance(conditional, str):
936
+ cond_df = self._df.query(conditional)
937
+ if len(cond_df) == 0:
938
+ raise EmptyConditionError()
939
+ predictions = self._estimator.predict(
940
+ X=cond_df[[*self._X]], treatment=cond_df[self._T], y=cond_df[self._y], **estimation_params
941
+ )
942
+ return predictions.mean()
943
+ # ACE
944
+ elif conditional is None:
945
+ return self._estimator.estimate_ate(
946
+ X=self._df[[*self._X]], treatment=self._df[self._T], y=self._df[self._y], **estimation_params
947
+ )
948
+ else:
949
+ raise InvalidConditionError()
950
+
951
+ def _estimateIVCausalEffect(
952
+ self, conditional: pd.DataFrame | str | None = None, **estimation_params: Any
953
+ ) -> float | np.ndarray:
954
+ """
955
+ Estimate the causal or treatment effect using instrumental
956
+ variable adjustment.
957
+
958
+ Parameters
959
+ ----------
960
+ conditional: pd.DataFrame, str, or None, optional
961
+ Conditions for estimating causal effects.
962
+ estimation_params: dict[str, Any], optional
963
+ Additional parameters for the estimation method.
964
+
965
+ Returns
966
+ -------
967
+ float or np.ndarray
968
+ The estimated causal effect.
969
+
970
+ Raises
971
+ ------
972
+ ValueError
973
+ The inputed conditional is invalid.
974
+ """
975
+
976
+ assert self._estimator is not None, "Please fit an estimator before attempting to make an estimate."
977
+
978
+ if estimation_params is None:
979
+ estimation_params = dict()
980
+ # ICE
981
+ if isinstance(conditional, pd.DataFrame):
982
+ conditional = pd.DataFrame(conditional)
983
+ return self._estimator.predict(
984
+ X=conditional[[*self._X]],
985
+ w=conditional[[*self._w]],
986
+ treatment=conditional[self._T],
987
+ y=conditional[self._y],
988
+ **estimation_params,
989
+ )
990
+ # CACE
991
+ elif isinstance(conditional, str):
992
+ cond_df = self._df.query(conditional)
993
+ if len(cond_df) == 0:
994
+ raise EmptyConditionError()
995
+ predictions = self._estimator.predict(
996
+ X=cond_df[[*self._X]],
997
+ w=cond_df[[*self._w]],
998
+ treatment=cond_df[self._T],
999
+ y=cond_df[self._y],
1000
+ **estimation_params,
1001
+ )
1002
+ return predictions.mean()
1003
+ # ACE
1004
+ elif conditional is None:
1005
+ return self._estimator.estimate_ate(
1006
+ X=self._df[[*self._X]],
1007
+ w=self._df[[*self._w]],
1008
+ treatment=self._df[self._T],
1009
+ y=self._df[self._y],
1010
+ pretrain=True,
1011
+ **estimation_params,
1012
+ )
1013
+ else:
1014
+ raise InvalidConditionError()
1015
+
1016
+ def _estimateFrontdoorCausalEffect(
1017
+ self, conditional: pd.DataFrame | str | None = None, **estimation_params: Any
1018
+ ) -> float | np.ndarray:
1019
+ """
1020
+ Estimate the causal or treatment effect using generalized
1021
+ frontdoor adjustment.
1022
+
1023
+ Parameters
1024
+ ----------
1025
+ conditional: pd.DataFrame, str, or None, optional
1026
+ Conditions for estimating treatment effects.
1027
+ estimation_params: dict[str, Any], optional
1028
+ Additional parameters for the estimation method.
1029
+
1030
+ Returns
1031
+ -------
1032
+ float or np.ndarray
1033
+ The estimated treatment effect.
1034
+
1035
+ Raises
1036
+ ------
1037
+ ValueError
1038
+ The inputed conditional is invalid.
1039
+ """
1040
+
1041
+ assert self._estimator is not None, "Please fit an estimator before attempting to make an estimate."
1042
+
1043
+ if estimation_params is None:
1044
+ estimation_params = dict()
1045
+ # ICE
1046
+ if isinstance(conditional, pd.DataFrame):
1047
+ conditional = pd.DataFrame(conditional)
1048
+ return self._estimator.predict(
1049
+ X=conditional[[*self._X]],
1050
+ treatment=conditional[self._T],
1051
+ y=conditional[self._y],
1052
+ M=conditional[[*self._M]],
1053
+ **estimation_params,
1054
+ )
1055
+ # CACE
1056
+ elif isinstance(conditional, str):
1057
+ cond_df = self._df.query(conditional)
1058
+ if len(cond_df) == 0:
1059
+ raise EmptyConditionError()
1060
+ predictions = self._estimator.predict(
1061
+ X=cond_df[[*self._X]],
1062
+ treatment=cond_df[self._T],
1063
+ y=cond_df[self._y],
1064
+ M=cond_df[[*self._M]],
1065
+ **estimation_params,
1066
+ )
1067
+ return predictions.mean()
1068
+ # ACE
1069
+ elif conditional is None:
1070
+ return self._estimator.estimate_ate(
1071
+ X=self._df[[*self._X]],
1072
+ treatment=self._df[self._T],
1073
+ y=self._df[self._y],
1074
+ M=self._df[[*self._M]],
1075
+ pretrain=True,
1076
+ **estimation_params,
1077
+ )
1078
+ else:
1079
+ raise InvalidConditionError()
1080
+
1081
+ def _estimateBackdoorCausalEffect(
1082
+ self, conditional: pd.DataFrame | str | None = None, **estimation_params: Any
1083
+ ) -> float | np.ndarray:
1084
+ """
1085
+ Estimate the causal or treatment effect using backdoor adjustment.
1086
+
1087
+ Parameters
1088
+ ----------
1089
+ conditional: pd.DataFrame, str, or None, optional
1090
+ Specifies conditions for estimating treatment effects.
1091
+ estimation_params: dict[str, Any], optional
1092
+ Additional parameters for the estimation method.
1093
+
1094
+ Returns
1095
+ -------
1096
+ float or np.ndarray
1097
+ The estimated treatment effect.
1098
+
1099
+ Raises
1100
+ ------
1101
+ ValueError
1102
+ The inputed conditional is invalid.
1103
+ """
1104
+
1105
+ assert self._estimator is not None, "Please fit an estimator before attempting to make an estimate."
1106
+
1107
+ if estimation_params is None:
1108
+ estimation_params = dict()
1109
+ # ICE
1110
+ if isinstance(conditional, pd.DataFrame):
1111
+ conditional = pd.DataFrame(conditional)
1112
+ return self._estimator.predict(
1113
+ X=conditional[[*self._X]], treatment=conditional[self._T], y=conditional[self._y], **estimation_params
1114
+ )
1115
+ # CACE
1116
+ elif isinstance(conditional, str):
1117
+ cond_df = self._df.query(conditional)
1118
+ if len(cond_df) == 0:
1119
+ raise EmptyConditionError()
1120
+ predictions = self._estimator.predict(
1121
+ X=cond_df[[*self._X]], treatment=cond_df[self._T], y=cond_df[self._y], **estimation_params
1122
+ )
1123
+ return predictions.mean()
1124
+ # ACE
1125
+ elif conditional is None:
1126
+ return self._estimator.estimate_ate(
1127
+ X=self._df[[*self._X]], treatment=self._df[self._T], y=self._df[self._y], pretrain=True, **estimation_params
1128
+ )
1129
+ else:
1130
+ raise InvalidConditionError()
1131
+
1132
+ def _estimateUnknownCausalEffect(
1133
+ self, conditional: pd.DataFrame | str | None = None, **estimation_params: Any
1134
+ ) -> float | np.ndarray:
1135
+ """
1136
+ Estimate the causal or treatment effect using unknown adjustment.
1137
+
1138
+ Parameters
1139
+ ----------
1140
+ conditional: pd.DataFrame, str, or None, optional
1141
+ Specifies conditions for estimating treatment effects.
1142
+ estimation_params: dict[str, Any], optional
1143
+ Additional parameters for the estimation method.
1144
+
1145
+ Returns
1146
+ -------
1147
+ float or np.ndarray
1148
+ The estimated treatment effect.
1149
+
1150
+ Raises
1151
+ ------
1152
+ ValueError
1153
+ The inputed conditional is invalid.
1154
+ """
1155
+
1156
+ assert self._estimator is not None, "Please fit an estimator before attempting to make an estimate."
1157
+
1158
+ if estimation_params is None:
1159
+ estimation_params = dict()
1160
+ # ICE
1161
+ if isinstance(conditional, pd.DataFrame):
1162
+ conditional = pd.DataFrame(conditional)
1163
+ return self._estimator.predict(treatment=conditional[self._T], y=conditional[self._y], **estimation_params)
1164
+ # CACE
1165
+ elif isinstance(conditional, str):
1166
+ cond_df = self._df.query(conditional)
1167
+ if len(cond_df) == 0:
1168
+ raise EmptyConditionError()
1169
+ predictions = self._estimator.predict(treatment=cond_df[self._T], y=cond_df[self._y], **estimation_params)
1170
+ return predictions.mean()
1171
+ # ACE
1172
+ elif conditional is None:
1173
+ return self._estimator.estimate_ate(treatment=self._df[self._T], y=self._df[self._y], **estimation_params)
1174
+ else:
1175
+ raise InvalidConditionError()