skxperiments 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (36)
  1. skxperiments/__init__.py +5 -0
  2. skxperiments/core/__init__.py +42 -0
  3. skxperiments/core/assignment.py +589 -0
  4. skxperiments/core/base.py +512 -0
  5. skxperiments/core/exceptions.py +145 -0
  6. skxperiments/core/potential_outcomes.py +168 -0
  7. skxperiments/core/results.py +624 -0
  8. skxperiments/design/__init__.py +22 -0
  9. skxperiments/design/balance.py +182 -0
  10. skxperiments/design/blocked_crd.py +157 -0
  11. skxperiments/design/crd.py +162 -0
  12. skxperiments/design/factorial.py +174 -0
  13. skxperiments/design/power.py +233 -0
  14. skxperiments/design/rerandomized_crd.py +319 -0
  15. skxperiments/diagnostics/__init__.py +21 -0
  16. skxperiments/diagnostics/aa_test.py +277 -0
  17. skxperiments/diagnostics/balance_report.py +224 -0
  18. skxperiments/diagnostics/srm.py +327 -0
  19. skxperiments/estimators/__init__.py +23 -0
  20. skxperiments/estimators/blocked_difference_in_means.py +197 -0
  21. skxperiments/estimators/cuped.py +280 -0
  22. skxperiments/estimators/difference_in_means.py +161 -0
  23. skxperiments/estimators/factorial_estimator.py +213 -0
  24. skxperiments/estimators/lin_estimator.py +298 -0
  25. skxperiments/inference/__init__.py +17 -0
  26. skxperiments/inference/bootstrap.py +450 -0
  27. skxperiments/inference/multiple.py +365 -0
  28. skxperiments/inference/neyman.py +386 -0
  29. skxperiments/inference/randomization_test.py +319 -0
  30. skxperiments/pipeline.py +366 -0
  31. skxperiments/reporting/__init__.py +30 -0
  32. skxperiments/reporting/plots.py +411 -0
  33. skxperiments/reporting/summary.py +185 -0
  34. skxperiments-0.1.0.dev0.dist-info/METADATA +272 -0
  35. skxperiments-0.1.0.dev0.dist-info/RECORD +36 -0
  36. skxperiments-0.1.0.dev0.dist-info/WHEEL +4 -0
@@ -0,0 +1,512 @@
1
+ """Abstract base classes for designs, estimators, and inference methods.
2
+
3
+ These classes define the API contract that all concrete implementations
4
+ must follow, ensuring consistency across the library.
5
+ """
6
+
7
+ import inspect
8
+ from abc import ABC, abstractmethod
9
+ from dataclasses import dataclass, field
10
+ from typing import TYPE_CHECKING, Any
11
+
12
+ import pandas as pd
13
+
14
+ from skxperiments.core.exceptions import (
15
+ DesignEstimatorMismatch,
16
+ InvalidDesignError,
17
+ NotFittedError,
18
+ )
19
+
20
+ if TYPE_CHECKING:
21
+ from skxperiments.core.assignment import BaseAssignment
22
+ from skxperiments.core.results import Results
23
+
24
+
25
+ def _check_assignment_type(
26
+ obj: Any,
27
+ assignment: Any,
28
+ expected_type: type | tuple[type, ...],
29
+ ) -> None:
30
+ """Validate that ``assignment`` is an instance of ``expected_type``.
31
+
32
+ Module-level helper shared by ``BaseEstimator._validate_assignment_type``
33
+ and ``BaseInference._validate_assignment_type``. Both ABCs expose
34
+ thin wrappers that delegate here, so the validation logic and the
35
+ error message format live in a single place.
36
+
37
+ Parameters
38
+ ----------
39
+ obj : Any
40
+ The estimator or inference instance calling this helper. Used
41
+ to populate ``estimator_name`` in the raised exception.
42
+ assignment : Any
43
+ The assignment object to validate.
44
+ expected_type : type or tuple of type
45
+ Acceptable type(s) for ``assignment``.
46
+
47
+ Raises
48
+ ------
49
+ DesignEstimatorMismatch
50
+ If ``assignment`` is not an instance of any type in
51
+ ``expected_type``.
52
+ """
53
+ if not isinstance(assignment, expected_type):
54
+ if isinstance(expected_type, tuple):
55
+ expected_name = " or ".join(t.__name__ for t in expected_type)
56
+ else:
57
+ expected_name = expected_type.__name__
58
+
59
+ raise DesignEstimatorMismatch(
60
+ estimator_name=type(obj).__name__,
61
+ received_type=type(assignment).__name__,
62
+ expected_type=expected_name,
63
+ )
64
+
65
+
66
class BaseDesign(ABC):
    """Abstract base class for all experimental designs.

    Subclasses must implement the randomize() method, which takes a
    DataFrame and returns a BaseAssignment object.

    ``get_params`` introspects the subclass ``__init__`` signature, so
    each named ``__init__`` argument should be stored on the instance as
    an attribute of the same name.

    Examples
    --------
    Subclasses define their parameters in __init__:

    >>> class CRD(BaseDesign):
    ...     def __init__(self, n_treated=None, seed=None):
    ...         self.n_treated = n_treated
    ...         self.seed = seed
    ...     def randomize(self, df):
    ...         ...
    """

    @abstractmethod
    def randomize(self, df: pd.DataFrame) -> "BaseAssignment":
        """Perform randomization and return an Assignment object.

        Parameters
        ----------
        df : pd.DataFrame
            DataFrame containing the experimental units.

        Returns
        -------
        BaseAssignment
            Assignment object with treatment assignments.
        """

    def get_params(self) -> dict[str, Any]:
        """Get parameters of this design.

        Uses ``inspect.signature`` to introspect ``__init__`` parameters,
        so it works in subclasses without override. Variadic ``*args`` /
        ``**kwargs`` are skipped because they are not named parameters
        and have no matching attribute. This also means a subclass that
        defines no ``__init__`` reports no parameters. A named parameter
        with no matching attribute maps to ``None``.

        Returns
        -------
        dict
            Parameter names mapped to their current values.
        """
        sig = inspect.signature(self.__init__)  # type: ignore[misc]
        return {
            name: getattr(self, name, None)
            for name, param in sig.parameters.items()
            if name != "self"
            and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
        }

    def set_params(self, **params: Any) -> "BaseDesign":
        """Set parameters of this design.

        All names are validated before any attribute is assigned. An
        invalid name therefore leaves the object unchanged instead of
        partially updated.

        Parameters
        ----------
        **params
            Keyword arguments with parameter names and new values.

        Returns
        -------
        BaseDesign
            Returns self.

        Raises
        ------
        InvalidDesignError
            If a parameter name does not exist.
        """
        valid_params = self.get_params()
        unknown = [key for key in params if key not in valid_params]
        if unknown:
            raise InvalidDesignError(
                f"Invalid parameter '{unknown[0]}' for {type(self).__name__}. "
                f"Valid parameters: {list(valid_params.keys())}."
            )
        for key, value in params.items():
            setattr(self, key, value)
        return self

    def __repr__(self) -> str:
        """Return string representation with parameters.

        Returns
        -------
        str
            Format: ClassName(param1=val1, param2=val2)
        """
        params_str = ", ".join(f"{k}={v!r}" for k, v in self.get_params().items())
        return f"{type(self).__name__}({params_str})"
157
+
158
+
159
+ class BaseEstimator(ABC):
160
+ """Abstract base class for all causal estimators.
161
+
162
+ Subclasses must implement fit() and estimate() methods.
163
+ The fit() method receives a BaseAssignment object (not a DataFrame).
164
+
165
+ Examples
166
+ --------
167
+ >>> class DifferenceInMeans(BaseEstimator):
168
+ ... def __init__(self, outcome_col="y"):
169
+ ... self.outcome_col = outcome_col
170
+ ... def fit(self, assignment):
171
+ ... ...
172
+ ... def estimate(self):
173
+ ... ...
174
+ """
175
+
176
+ @abstractmethod
177
+ def fit(self, assignment: Any) -> "BaseEstimator":
178
+ """Fit the estimator using an assignment object.
179
+
180
+ Parameters
181
+ ----------
182
+ assignment : BaseAssignment
183
+ Assignment object containing data and treatment assignments.
184
+
185
+ Returns
186
+ -------
187
+ BaseEstimator
188
+ Returns self.
189
+ """
190
+
191
+ @abstractmethod
192
+ def estimate(self) -> "Results":
193
+ """Compute the causal estimate.
194
+
195
+ Returns
196
+ -------
197
+ Results
198
+ Results object with the estimate and metadata.
199
+ """
200
+
201
+ def get_params(self) -> dict[str, Any]:
202
+ """Get parameters of this estimator.
203
+
204
+ Returns
205
+ -------
206
+ dict
207
+ Parameter names mapped to their current values.
208
+ """
209
+ sig = inspect.signature(self.__init__) # type: ignore[misc]
210
+ params = {}
211
+ for name in sig.parameters:
212
+ if name == "self":
213
+ continue
214
+ params[name] = getattr(self, name, None)
215
+ return params
216
+
217
+ def set_params(self, **params: Any) -> "BaseEstimator":
218
+ """Set parameters of this estimator.
219
+
220
+ Parameters
221
+ ----------
222
+ **params
223
+ Keyword arguments with parameter names and new values.
224
+
225
+ Returns
226
+ -------
227
+ BaseEstimator
228
+ Returns self.
229
+
230
+ Raises
231
+ ------
232
+ InvalidDesignError
233
+ If a parameter name does not exist.
234
+ """
235
+ valid_params = self.get_params()
236
+ for key, value in params.items():
237
+ if key not in valid_params:
238
+ raise InvalidDesignError(
239
+ f"Invalid parameter '{key}' for {type(self).__name__}. "
240
+ f"Valid parameters: {list(valid_params.keys())}."
241
+ )
242
+ setattr(self, key, value)
243
+ return self
244
+
245
+ def _check_is_fitted(self) -> None:
246
+ """Check if the estimator has been fitted.
247
+
248
+ Raises
249
+ ------
250
+ NotFittedError
251
+ If no attributes ending in underscore are found.
252
+ """
253
+ fitted_attrs = [
254
+ attr for attr in self.__dict__ if attr.endswith("_")
255
+ ]
256
+ if not fitted_attrs:
257
+ raise NotFittedError(
258
+ class_name=type(self).__name__,
259
+ required_methods=["fit"],
260
+ )
261
+
262
+ def _validate_assignment_type(
263
+ self,
264
+ assignment: Any,
265
+ expected_type: type | tuple[type, ...],
266
+ ) -> None:
267
+ """Validate that the assignment is of the expected type(s).
268
+
269
+ Thin wrapper that delegates to the module-level
270
+ ``_check_assignment_type``. Kept as a method for API
271
+ compatibility with concrete estimators that call
272
+ ``self._validate_assignment_type(...)``.
273
+
274
+ Parameters
275
+ ----------
276
+ assignment : Any
277
+ The assignment object to validate.
278
+ expected_type : type or tuple of type
279
+ The expected type(s) of the assignment. A tuple may be passed
280
+ when the estimator accepts multiple Assignment types (e.g.,
281
+ LinEstimator accepts both CRDAssignment and BlockedAssignment).
282
+
283
+ Raises
284
+ ------
285
+ DesignEstimatorMismatch
286
+ If the assignment type does not match any of the expected types.
287
+ """
288
+ _check_assignment_type(self, assignment, expected_type)
289
+
290
+ def __repr__(self) -> str:
291
+ """Return string representation with parameters.
292
+
293
+ Returns
294
+ -------
295
+ str
296
+ Format: ClassName(param1=val1, param2=val2)
297
+ """
298
+ params = self.get_params()
299
+ params_str = ", ".join(f"{k}={v!r}" for k, v in params.items())
300
+ return f"{type(self).__name__}({params_str})"
301
+
302
+
303
+ class BaseInference(ABC):
304
+ """Abstract base class for all inference methods.
305
+
306
+ Subclasses configure their dependencies (e.g., a wrapped estimator)
307
+ in ``__init__``, receive a ``BaseAssignment`` in ``fit()``, and
308
+ produce a new ``Results`` object via ``estimate()``.
309
+
310
+ Subclasses **must** implement both ``fit()`` and ``estimate()``.
311
+
312
+ Contract
313
+ --------
314
+ - ``fit(assignment)`` populates instance attributes ending in
315
+ underscore (e.g., ``observed_statistic_``, ``p_value_``).
316
+ - ``estimate()`` produces a **new** ``Results`` object. It must
317
+ not mutate the estimator's ``Results`` or any other input.
318
+ - Subclasses may accept arbitrary parameters in ``__init__`` (e.g.,
319
+ a configured estimator, ``n_permutations``, ``alpha``).
320
+
321
+ Examples
322
+ --------
323
+ >>> class RandomizationTest(BaseInference):
324
+ ... def __init__(self, estimator, n_permutations=1000):
325
+ ... self.estimator = estimator
326
+ ... self.n_permutations = n_permutations
327
+ ... def fit(self, assignment):
328
+ ... ...
329
+ ... def estimate(self):
330
+ ... ...
331
+ """
332
+
333
+ @abstractmethod
334
+ def fit(self, assignment: Any) -> "BaseInference":
335
+ """Fit the inference method using an assignment object.
336
+
337
+ Parameters
338
+ ----------
339
+ assignment : BaseAssignment
340
+ Assignment object containing data and treatment assignments.
341
+
342
+ Returns
343
+ -------
344
+ BaseInference
345
+ Returns self.
346
+ """
347
+
348
+ @abstractmethod
349
+ def estimate(self) -> "Results":
350
+ """Compute the inferential result.
351
+
352
+ Returns
353
+ -------
354
+ Results
355
+ Results object with point estimate (copied from the underlying
356
+ estimator) and inferential quantities (p_value, ci, se as
357
+ applicable) populated. Always returns a NEW Results object;
358
+ never mutates the estimator's Results.
359
+ """
360
+
361
+ def get_params(self) -> dict[str, Any]:
362
+ """Get parameters of this inference method.
363
+
364
+ Returns
365
+ -------
366
+ dict
367
+ Parameter names mapped to their current values.
368
+ """
369
+ sig = inspect.signature(self.__init__) # type: ignore[misc]
370
+ params = {}
371
+ for name in sig.parameters:
372
+ if name == "self":
373
+ continue
374
+ params[name] = getattr(self, name, None)
375
+ return params
376
+
377
+ def set_params(self, **params: Any) -> "BaseInference":
378
+ """Set parameters of this inference method.
379
+
380
+ Parameters
381
+ ----------
382
+ **params
383
+ Keyword arguments with parameter names and new values.
384
+
385
+ Returns
386
+ -------
387
+ BaseInference
388
+ Returns self.
389
+
390
+ Raises
391
+ ------
392
+ InvalidDesignError
393
+ If a parameter name does not exist.
394
+ """
395
+ valid_params = self.get_params()
396
+ for key, value in params.items():
397
+ if key not in valid_params:
398
+ raise InvalidDesignError(
399
+ f"Invalid parameter '{key}' for {type(self).__name__}. "
400
+ f"Valid parameters: {list(valid_params.keys())}."
401
+ )
402
+ setattr(self, key, value)
403
+ return self
404
+
405
+ def _check_is_fitted(self) -> None:
406
+ """Check if the inference method has been fitted.
407
+
408
+ Raises
409
+ ------
410
+ NotFittedError
411
+ If no attributes ending in underscore are found.
412
+ """
413
+ fitted_attrs = [
414
+ attr for attr in self.__dict__ if attr.endswith("_")
415
+ ]
416
+ if not fitted_attrs:
417
+ raise NotFittedError(
418
+ class_name=type(self).__name__,
419
+ required_methods=["fit"],
420
+ )
421
+
422
+ def _validate_assignment_type(
423
+ self,
424
+ assignment: Any,
425
+ expected_type: type | tuple[type, ...],
426
+ ) -> None:
427
+ """Validate that the assignment is of the expected type(s).
428
+
429
+ Thin wrapper that delegates to the module-level
430
+ ``_check_assignment_type``. Mirrors
431
+ ``BaseEstimator._validate_assignment_type`` so inference
432
+ classes have the same validation surface as estimators.
433
+
434
+ Parameters
435
+ ----------
436
+ assignment : Any
437
+ The assignment object to validate.
438
+ expected_type : type or tuple of type
439
+ The expected type(s) of the assignment. A tuple may be passed
440
+ when the inference method accepts multiple Assignment types
441
+ (e.g., RandomizationTest accepts both CRDAssignment and
442
+ BlockedAssignment).
443
+
444
+ Raises
445
+ ------
446
+ DesignEstimatorMismatch
447
+ If the assignment type does not match any of the expected types.
448
+ """
449
+ _check_assignment_type(self, assignment, expected_type)
450
+
451
+ def __repr__(self) -> str:
452
+ """Return string representation with parameters.
453
+
454
+ Returns
455
+ -------
456
+ str
457
+ Format: ClassName(param1=val1, param2=val2)
458
+ """
459
+ params = self.get_params()
460
+ params_str = ", ".join(f"{k}={v!r}" for k, v in params.items())
461
+ return f"{type(self).__name__}({params_str})"
462
+
463
+
464
@dataclass
class DiagnosticsReport:
    """Collection of diagnostic findings split by severity.

    Attributes
    ----------
    flags : list of str
        Critical issues that should be addressed.
    warnings : list of str
        Non-critical warnings for the user's attention.

    Examples
    --------
    >>> report = DiagnosticsReport()
    >>> report.summary()
    ✅ No issues found.
    """

    flags: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)

    def summary(self) -> None:
        """Write the findings to stdout, one line per entry.

        All flags are printed first (❌ prefix), followed by all warnings
        (⚠️ prefix). With no findings at all, a single ✅ line is printed.
        """
        if not (self.flags or self.warnings):
            print("✅ No issues found.")
            return

        sections = (("❌", self.flags), ("⚠️", self.warnings))
        for marker, entries in sections:
            for entry in entries:
                print(f"{marker} {entry}")

    def __repr__(self) -> str:
        """Summarize the report as counts of flags and warnings.

        Returns
        -------
        str
            Format: DiagnosticsReport(flags=N, warnings=N)
        """
        n_flags = len(self.flags)
        n_warnings = len(self.warnings)
        return f"DiagnosticsReport(flags={n_flags}, warnings={n_warnings})"
@@ -0,0 +1,145 @@
1
+ """Custom exceptions for the skxperiments library.
2
+
3
+ All exceptions inherit from SkxperimentsError, enabling users to catch
4
+ all library-specific errors with a single except clause.
5
+ """
6
+
7
+
8
class SkxperimentsError(Exception):
    """Base exception for all skxperiments errors.

    Parameters
    ----------
    message : str
        Human-readable error description.

    Notes
    -----
    Subclasses may define ``__init__`` with a different signature than
    this class (e.g. several keyword fields that are formatted into a
    message). ``BaseException``'s default pickling/copy protocol rebuilds
    an instance by calling ``cls(*self.args)``, i.e. ``cls(message)``,
    which raises ``TypeError`` for such subclasses. That would break
    ``pickle``, ``copy.copy``/``copy.deepcopy`` and propagating errors
    out of worker processes. ``__reduce__`` is overridden so instances
    are rebuilt from ``args`` and restored from their ``__dict__``
    without re-running ``__init__``.
    """

    def __init__(self, message: str) -> None:
        self.message = message
        super().__init__(self.message)

    def __reduce__(self) -> tuple:
        """Support pickling/copying regardless of the subclass signature.

        Returns
        -------
        tuple
            ``(copyreg.__newobj__, (cls, *args), state)``. Reconstruction
            calls ``cls.__new__(cls, *args)`` (which restores ``args``)
            and then ``__setstate__(state)`` (which restores instance
            attributes such as ``message``).
        """
        import copyreg

        return (copyreg.__newobj__, (type(self), *self.args), self.__dict__)
20
+
21
+
22
class DesignEstimatorMismatch(SkxperimentsError):
    """Signal that an estimator was given an Assignment type it cannot use.

    Parameters
    ----------
    estimator_name : str
        Name of the estimator that detected the mismatch.
    received_type : str
        Name of the Assignment type that was received.
    expected_type : str
        Name of the Assignment type that was expected.
    suggestion : str or None, optional
        Suggested alternative estimator or design, by default None. When
        given, it is appended to the message.

    Examples
    --------
    >>> raise DesignEstimatorMismatch(
    ...     estimator_name="DifferenceInMeans",
    ...     received_type="BlockedAssignment",
    ...     expected_type="CRDAssignment",
    ...     suggestion="BlockedDifferenceInMeans",
    ... )
    """

    def __init__(
        self,
        estimator_name: str,
        received_type: str,
        expected_type: str,
        suggestion: str | None = None,
    ) -> None:
        # Keep the structured fields so callers can inspect them.
        self.estimator_name = estimator_name
        self.received_type = received_type
        self.expected_type = expected_type
        self.suggestion = suggestion

        sentences = [
            f"[{estimator_name}] expects {expected_type} "
            f"but received {received_type}."
        ]
        if suggestion is not None:
            sentences.append(f"Suggestion: use {suggestion} instead.")
        super().__init__(" ".join(sentences))
66
+
67
+
68
class NotFittedError(SkxperimentsError):
    """Signal that an object was used before its fitting step ran.

    Parameters
    ----------
    class_name : str
        Name of the class that is not yet fitted.
    required_methods : list of str
        Methods that must be called before using the object. Each is
        rendered as ``name()`` in the message.

    Examples
    --------
    >>> raise NotFittedError(
    ...     class_name="DifferenceInMeans",
    ...     required_methods=["fit"],
    ... )
    """

    def __init__(self, class_name: str, required_methods: list[str]) -> None:
        self.class_name = class_name
        self.required_methods = required_methods

        calls = ", ".join(map("{}()".format, required_methods))
        super().__init__(
            f"[{class_name}] is not fitted. "
            f"Call {calls} before using this object."
        )
96
+
97
+
98
class InsufficientDataError(SkxperimentsError):
    """Signal that too few units (or too small a block) were supplied.

    Parameters
    ----------
    context : str
        Description of the operation requiring more data.
    minimum : int
        Minimum number of units required.
    received : int
        Actual number of units received.

    Examples
    --------
    >>> raise InsufficientDataError(
    ...     context="CRD randomization",
    ...     minimum=2,
    ...     received=1,
    ... )
    """

    def __init__(self, context: str, minimum: int, received: int) -> None:
        self.context, self.minimum, self.received = context, minimum, received
        super().__init__(
            f"{context} requires at least {minimum} units, but received {received}."
        )
129
+
130
+
131
class InvalidDesignError(SkxperimentsError):
    """Signal inconsistent or unknown design/estimator parameters.

    Also raised by ``set_params`` on the base classes when an unknown
    parameter name is passed.

    Parameters
    ----------
    message : str
        Description of the inconsistency.

    Examples
    --------
    >>> raise InvalidDesignError("Treatment probability must be in (0, 1).")
    """

    def __init__(self, message: str) -> None:
        # The message is passed through unchanged to the library base class.
        super().__init__(message=message)