skxperiments 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- skxperiments/__init__.py +5 -0
- skxperiments/core/__init__.py +42 -0
- skxperiments/core/assignment.py +589 -0
- skxperiments/core/base.py +512 -0
- skxperiments/core/exceptions.py +145 -0
- skxperiments/core/potential_outcomes.py +168 -0
- skxperiments/core/results.py +624 -0
- skxperiments/design/__init__.py +22 -0
- skxperiments/design/balance.py +182 -0
- skxperiments/design/blocked_crd.py +157 -0
- skxperiments/design/crd.py +162 -0
- skxperiments/design/factorial.py +174 -0
- skxperiments/design/power.py +233 -0
- skxperiments/design/rerandomized_crd.py +319 -0
- skxperiments/diagnostics/__init__.py +21 -0
- skxperiments/diagnostics/aa_test.py +277 -0
- skxperiments/diagnostics/balance_report.py +224 -0
- skxperiments/diagnostics/srm.py +327 -0
- skxperiments/estimators/__init__.py +23 -0
- skxperiments/estimators/blocked_difference_in_means.py +197 -0
- skxperiments/estimators/cuped.py +280 -0
- skxperiments/estimators/difference_in_means.py +161 -0
- skxperiments/estimators/factorial_estimator.py +213 -0
- skxperiments/estimators/lin_estimator.py +298 -0
- skxperiments/inference/__init__.py +17 -0
- skxperiments/inference/bootstrap.py +450 -0
- skxperiments/inference/multiple.py +365 -0
- skxperiments/inference/neyman.py +386 -0
- skxperiments/inference/randomization_test.py +319 -0
- skxperiments/pipeline.py +366 -0
- skxperiments/reporting/__init__.py +30 -0
- skxperiments/reporting/plots.py +411 -0
- skxperiments/reporting/summary.py +185 -0
- skxperiments-0.1.0.dev0.dist-info/METADATA +272 -0
- skxperiments-0.1.0.dev0.dist-info/RECORD +36 -0
- skxperiments-0.1.0.dev0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,512 @@
|
|
|
1
|
+
"""Abstract base classes for designs, estimators, and inference methods.
|
|
2
|
+
|
|
3
|
+
These classes define the API contract that all concrete implementations
|
|
4
|
+
must follow, ensuring consistency across the library.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import inspect
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import TYPE_CHECKING, Any
|
|
11
|
+
|
|
12
|
+
import pandas as pd
|
|
13
|
+
|
|
14
|
+
from skxperiments.core.exceptions import (
|
|
15
|
+
DesignEstimatorMismatch,
|
|
16
|
+
InvalidDesignError,
|
|
17
|
+
NotFittedError,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from skxperiments.core.assignment import BaseAssignment
|
|
22
|
+
from skxperiments.core.results import Results
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _check_assignment_type(
|
|
26
|
+
obj: Any,
|
|
27
|
+
assignment: Any,
|
|
28
|
+
expected_type: type | tuple[type, ...],
|
|
29
|
+
) -> None:
|
|
30
|
+
"""Validate that ``assignment`` is an instance of ``expected_type``.
|
|
31
|
+
|
|
32
|
+
Module-level helper shared by ``BaseEstimator._validate_assignment_type``
|
|
33
|
+
and ``BaseInference._validate_assignment_type``. Both ABCs expose
|
|
34
|
+
thin wrappers that delegate here, so the validation logic and the
|
|
35
|
+
error message format live in a single place.
|
|
36
|
+
|
|
37
|
+
Parameters
|
|
38
|
+
----------
|
|
39
|
+
obj : Any
|
|
40
|
+
The estimator or inference instance calling this helper. Used
|
|
41
|
+
to populate ``estimator_name`` in the raised exception.
|
|
42
|
+
assignment : Any
|
|
43
|
+
The assignment object to validate.
|
|
44
|
+
expected_type : type or tuple of type
|
|
45
|
+
Acceptable type(s) for ``assignment``.
|
|
46
|
+
|
|
47
|
+
Raises
|
|
48
|
+
------
|
|
49
|
+
DesignEstimatorMismatch
|
|
50
|
+
If ``assignment`` is not an instance of any type in
|
|
51
|
+
``expected_type``.
|
|
52
|
+
"""
|
|
53
|
+
if not isinstance(assignment, expected_type):
|
|
54
|
+
if isinstance(expected_type, tuple):
|
|
55
|
+
expected_name = " or ".join(t.__name__ for t in expected_type)
|
|
56
|
+
else:
|
|
57
|
+
expected_name = expected_type.__name__
|
|
58
|
+
|
|
59
|
+
raise DesignEstimatorMismatch(
|
|
60
|
+
estimator_name=type(obj).__name__,
|
|
61
|
+
received_type=type(assignment).__name__,
|
|
62
|
+
expected_type=expected_name,
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class BaseDesign(ABC):
    """Abstract base class for all experimental designs.

    Subclasses must implement the randomize() method, which takes a
    DataFrame and returns a BaseAssignment object.

    Parameters are discovered from the subclass ``__init__`` signature and
    read back with ``getattr`` under the same name, so constructor
    arguments should be stored on ``self`` unchanged.

    Examples
    --------
    Subclasses define their parameters in __init__:

    >>> class CRD(BaseDesign):
    ...     def __init__(self, n_treated=None, seed=None):
    ...         self.n_treated = n_treated
    ...         self.seed = seed
    ...     def randomize(self, df):
    ...         ...
    """

    @abstractmethod
    def randomize(self, df: pd.DataFrame) -> "BaseAssignment":
        """Perform randomization and return an Assignment object.

        Parameters
        ----------
        df : pd.DataFrame
            DataFrame containing the experimental units.

        Returns
        -------
        BaseAssignment
            Assignment object with treatment assignments.
        """

    def get_params(self) -> dict[str, Any]:
        """Get parameters of this design.

        Uses inspect.signature to introspect __init__ parameters.
        Works in subclasses without override. Variadic ``*args`` /
        ``**kwargs`` entries are skipped because they are not named
        parameters. This matters for a subclass that defines no
        ``__init__`` of its own: it inherits
        ``object.__init__(*args, **kwargs)`` and therefore has no
        parameters.

        Returns
        -------
        dict
            Parameter names mapped to their current values (``None`` when
            the instance has no attribute with that name).
        """
        variadic = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        sig = inspect.signature(self.__init__)  # type: ignore[misc]
        return {
            name: getattr(self, name, None)
            for name, param in sig.parameters.items()
            if name != "self" and param.kind not in variadic
        }

    def set_params(self, **params: Any) -> "BaseDesign":
        """Set parameters of this design.

        All names are validated before any attribute is changed, so a call
        containing an invalid name leaves the design untouched.

        Parameters
        ----------
        **params
            Keyword arguments with parameter names and new values.

        Returns
        -------
        BaseDesign
            Returns self.

        Raises
        ------
        InvalidDesignError
            If a parameter name does not exist. The first unknown name, in
            argument order, is reported.
        """
        valid_params = self.get_params()
        unknown = [key for key in params if key not in valid_params]
        if unknown:
            raise InvalidDesignError(
                f"Invalid parameter '{unknown[0]}' for {type(self).__name__}. "
                f"Valid parameters: {list(valid_params.keys())}."
            )
        for key, value in params.items():
            setattr(self, key, value)
        return self

    def __repr__(self) -> str:
        """Return string representation with parameters.

        Returns
        -------
        str
            Format: ClassName(param1=val1, param2=val2)
        """
        params = self.get_params()
        params_str = ", ".join(f"{k}={v!r}" for k, v in params.items())
        return f"{type(self).__name__}({params_str})"
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class BaseEstimator(ABC):
|
|
160
|
+
"""Abstract base class for all causal estimators.
|
|
161
|
+
|
|
162
|
+
Subclasses must implement fit() and estimate() methods.
|
|
163
|
+
The fit() method receives a BaseAssignment object (not a DataFrame).
|
|
164
|
+
|
|
165
|
+
Examples
|
|
166
|
+
--------
|
|
167
|
+
>>> class DifferenceInMeans(BaseEstimator):
|
|
168
|
+
... def __init__(self, outcome_col="y"):
|
|
169
|
+
... self.outcome_col = outcome_col
|
|
170
|
+
... def fit(self, assignment):
|
|
171
|
+
... ...
|
|
172
|
+
... def estimate(self):
|
|
173
|
+
... ...
|
|
174
|
+
"""
|
|
175
|
+
|
|
176
|
+
@abstractmethod
|
|
177
|
+
def fit(self, assignment: Any) -> "BaseEstimator":
|
|
178
|
+
"""Fit the estimator using an assignment object.
|
|
179
|
+
|
|
180
|
+
Parameters
|
|
181
|
+
----------
|
|
182
|
+
assignment : BaseAssignment
|
|
183
|
+
Assignment object containing data and treatment assignments.
|
|
184
|
+
|
|
185
|
+
Returns
|
|
186
|
+
-------
|
|
187
|
+
BaseEstimator
|
|
188
|
+
Returns self.
|
|
189
|
+
"""
|
|
190
|
+
|
|
191
|
+
@abstractmethod
|
|
192
|
+
def estimate(self) -> "Results":
|
|
193
|
+
"""Compute the causal estimate.
|
|
194
|
+
|
|
195
|
+
Returns
|
|
196
|
+
-------
|
|
197
|
+
Results
|
|
198
|
+
Results object with the estimate and metadata.
|
|
199
|
+
"""
|
|
200
|
+
|
|
201
|
+
def get_params(self) -> dict[str, Any]:
|
|
202
|
+
"""Get parameters of this estimator.
|
|
203
|
+
|
|
204
|
+
Returns
|
|
205
|
+
-------
|
|
206
|
+
dict
|
|
207
|
+
Parameter names mapped to their current values.
|
|
208
|
+
"""
|
|
209
|
+
sig = inspect.signature(self.__init__) # type: ignore[misc]
|
|
210
|
+
params = {}
|
|
211
|
+
for name in sig.parameters:
|
|
212
|
+
if name == "self":
|
|
213
|
+
continue
|
|
214
|
+
params[name] = getattr(self, name, None)
|
|
215
|
+
return params
|
|
216
|
+
|
|
217
|
+
def set_params(self, **params: Any) -> "BaseEstimator":
|
|
218
|
+
"""Set parameters of this estimator.
|
|
219
|
+
|
|
220
|
+
Parameters
|
|
221
|
+
----------
|
|
222
|
+
**params
|
|
223
|
+
Keyword arguments with parameter names and new values.
|
|
224
|
+
|
|
225
|
+
Returns
|
|
226
|
+
-------
|
|
227
|
+
BaseEstimator
|
|
228
|
+
Returns self.
|
|
229
|
+
|
|
230
|
+
Raises
|
|
231
|
+
------
|
|
232
|
+
InvalidDesignError
|
|
233
|
+
If a parameter name does not exist.
|
|
234
|
+
"""
|
|
235
|
+
valid_params = self.get_params()
|
|
236
|
+
for key, value in params.items():
|
|
237
|
+
if key not in valid_params:
|
|
238
|
+
raise InvalidDesignError(
|
|
239
|
+
f"Invalid parameter '{key}' for {type(self).__name__}. "
|
|
240
|
+
f"Valid parameters: {list(valid_params.keys())}."
|
|
241
|
+
)
|
|
242
|
+
setattr(self, key, value)
|
|
243
|
+
return self
|
|
244
|
+
|
|
245
|
+
def _check_is_fitted(self) -> None:
|
|
246
|
+
"""Check if the estimator has been fitted.
|
|
247
|
+
|
|
248
|
+
Raises
|
|
249
|
+
------
|
|
250
|
+
NotFittedError
|
|
251
|
+
If no attributes ending in underscore are found.
|
|
252
|
+
"""
|
|
253
|
+
fitted_attrs = [
|
|
254
|
+
attr for attr in self.__dict__ if attr.endswith("_")
|
|
255
|
+
]
|
|
256
|
+
if not fitted_attrs:
|
|
257
|
+
raise NotFittedError(
|
|
258
|
+
class_name=type(self).__name__,
|
|
259
|
+
required_methods=["fit"],
|
|
260
|
+
)
|
|
261
|
+
|
|
262
|
+
def _validate_assignment_type(
|
|
263
|
+
self,
|
|
264
|
+
assignment: Any,
|
|
265
|
+
expected_type: type | tuple[type, ...],
|
|
266
|
+
) -> None:
|
|
267
|
+
"""Validate that the assignment is of the expected type(s).
|
|
268
|
+
|
|
269
|
+
Thin wrapper that delegates to the module-level
|
|
270
|
+
``_check_assignment_type``. Kept as a method for API
|
|
271
|
+
compatibility with concrete estimators that call
|
|
272
|
+
``self._validate_assignment_type(...)``.
|
|
273
|
+
|
|
274
|
+
Parameters
|
|
275
|
+
----------
|
|
276
|
+
assignment : Any
|
|
277
|
+
The assignment object to validate.
|
|
278
|
+
expected_type : type or tuple of type
|
|
279
|
+
The expected type(s) of the assignment. A tuple may be passed
|
|
280
|
+
when the estimator accepts multiple Assignment types (e.g.,
|
|
281
|
+
LinEstimator accepts both CRDAssignment and BlockedAssignment).
|
|
282
|
+
|
|
283
|
+
Raises
|
|
284
|
+
------
|
|
285
|
+
DesignEstimatorMismatch
|
|
286
|
+
If the assignment type does not match any of the expected types.
|
|
287
|
+
"""
|
|
288
|
+
_check_assignment_type(self, assignment, expected_type)
|
|
289
|
+
|
|
290
|
+
def __repr__(self) -> str:
|
|
291
|
+
"""Return string representation with parameters.
|
|
292
|
+
|
|
293
|
+
Returns
|
|
294
|
+
-------
|
|
295
|
+
str
|
|
296
|
+
Format: ClassName(param1=val1, param2=val2)
|
|
297
|
+
"""
|
|
298
|
+
params = self.get_params()
|
|
299
|
+
params_str = ", ".join(f"{k}={v!r}" for k, v in params.items())
|
|
300
|
+
return f"{type(self).__name__}({params_str})"
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
class BaseInference(ABC):
|
|
304
|
+
"""Abstract base class for all inference methods.
|
|
305
|
+
|
|
306
|
+
Subclasses configure their dependencies (e.g., a wrapped estimator)
|
|
307
|
+
in ``__init__``, receive a ``BaseAssignment`` in ``fit()``, and
|
|
308
|
+
produce a new ``Results`` object via ``estimate()``.
|
|
309
|
+
|
|
310
|
+
Subclasses **must** implement both ``fit()`` and ``estimate()``.
|
|
311
|
+
|
|
312
|
+
Contract
|
|
313
|
+
--------
|
|
314
|
+
- ``fit(assignment)`` populates instance attributes ending in
|
|
315
|
+
underscore (e.g., ``observed_statistic_``, ``p_value_``).
|
|
316
|
+
- ``estimate()`` produces a **new** ``Results`` object. It must
|
|
317
|
+
not mutate the estimator's ``Results`` or any other input.
|
|
318
|
+
- Subclasses may accept arbitrary parameters in ``__init__`` (e.g.,
|
|
319
|
+
a configured estimator, ``n_permutations``, ``alpha``).
|
|
320
|
+
|
|
321
|
+
Examples
|
|
322
|
+
--------
|
|
323
|
+
>>> class RandomizationTest(BaseInference):
|
|
324
|
+
... def __init__(self, estimator, n_permutations=1000):
|
|
325
|
+
... self.estimator = estimator
|
|
326
|
+
... self.n_permutations = n_permutations
|
|
327
|
+
... def fit(self, assignment):
|
|
328
|
+
... ...
|
|
329
|
+
... def estimate(self):
|
|
330
|
+
... ...
|
|
331
|
+
"""
|
|
332
|
+
|
|
333
|
+
@abstractmethod
|
|
334
|
+
def fit(self, assignment: Any) -> "BaseInference":
|
|
335
|
+
"""Fit the inference method using an assignment object.
|
|
336
|
+
|
|
337
|
+
Parameters
|
|
338
|
+
----------
|
|
339
|
+
assignment : BaseAssignment
|
|
340
|
+
Assignment object containing data and treatment assignments.
|
|
341
|
+
|
|
342
|
+
Returns
|
|
343
|
+
-------
|
|
344
|
+
BaseInference
|
|
345
|
+
Returns self.
|
|
346
|
+
"""
|
|
347
|
+
|
|
348
|
+
@abstractmethod
|
|
349
|
+
def estimate(self) -> "Results":
|
|
350
|
+
"""Compute the inferential result.
|
|
351
|
+
|
|
352
|
+
Returns
|
|
353
|
+
-------
|
|
354
|
+
Results
|
|
355
|
+
Results object with point estimate (copied from the underlying
|
|
356
|
+
estimator) and inferential quantities (p_value, ci, se as
|
|
357
|
+
applicable) populated. Always returns a NEW Results object;
|
|
358
|
+
never mutates the estimator's Results.
|
|
359
|
+
"""
|
|
360
|
+
|
|
361
|
+
def get_params(self) -> dict[str, Any]:
|
|
362
|
+
"""Get parameters of this inference method.
|
|
363
|
+
|
|
364
|
+
Returns
|
|
365
|
+
-------
|
|
366
|
+
dict
|
|
367
|
+
Parameter names mapped to their current values.
|
|
368
|
+
"""
|
|
369
|
+
sig = inspect.signature(self.__init__) # type: ignore[misc]
|
|
370
|
+
params = {}
|
|
371
|
+
for name in sig.parameters:
|
|
372
|
+
if name == "self":
|
|
373
|
+
continue
|
|
374
|
+
params[name] = getattr(self, name, None)
|
|
375
|
+
return params
|
|
376
|
+
|
|
377
|
+
def set_params(self, **params: Any) -> "BaseInference":
|
|
378
|
+
"""Set parameters of this inference method.
|
|
379
|
+
|
|
380
|
+
Parameters
|
|
381
|
+
----------
|
|
382
|
+
**params
|
|
383
|
+
Keyword arguments with parameter names and new values.
|
|
384
|
+
|
|
385
|
+
Returns
|
|
386
|
+
-------
|
|
387
|
+
BaseInference
|
|
388
|
+
Returns self.
|
|
389
|
+
|
|
390
|
+
Raises
|
|
391
|
+
------
|
|
392
|
+
InvalidDesignError
|
|
393
|
+
If a parameter name does not exist.
|
|
394
|
+
"""
|
|
395
|
+
valid_params = self.get_params()
|
|
396
|
+
for key, value in params.items():
|
|
397
|
+
if key not in valid_params:
|
|
398
|
+
raise InvalidDesignError(
|
|
399
|
+
f"Invalid parameter '{key}' for {type(self).__name__}. "
|
|
400
|
+
f"Valid parameters: {list(valid_params.keys())}."
|
|
401
|
+
)
|
|
402
|
+
setattr(self, key, value)
|
|
403
|
+
return self
|
|
404
|
+
|
|
405
|
+
def _check_is_fitted(self) -> None:
|
|
406
|
+
"""Check if the inference method has been fitted.
|
|
407
|
+
|
|
408
|
+
Raises
|
|
409
|
+
------
|
|
410
|
+
NotFittedError
|
|
411
|
+
If no attributes ending in underscore are found.
|
|
412
|
+
"""
|
|
413
|
+
fitted_attrs = [
|
|
414
|
+
attr for attr in self.__dict__ if attr.endswith("_")
|
|
415
|
+
]
|
|
416
|
+
if not fitted_attrs:
|
|
417
|
+
raise NotFittedError(
|
|
418
|
+
class_name=type(self).__name__,
|
|
419
|
+
required_methods=["fit"],
|
|
420
|
+
)
|
|
421
|
+
|
|
422
|
+
def _validate_assignment_type(
|
|
423
|
+
self,
|
|
424
|
+
assignment: Any,
|
|
425
|
+
expected_type: type | tuple[type, ...],
|
|
426
|
+
) -> None:
|
|
427
|
+
"""Validate that the assignment is of the expected type(s).
|
|
428
|
+
|
|
429
|
+
Thin wrapper that delegates to the module-level
|
|
430
|
+
``_check_assignment_type``. Mirrors
|
|
431
|
+
``BaseEstimator._validate_assignment_type`` so inference
|
|
432
|
+
classes have the same validation surface as estimators.
|
|
433
|
+
|
|
434
|
+
Parameters
|
|
435
|
+
----------
|
|
436
|
+
assignment : Any
|
|
437
|
+
The assignment object to validate.
|
|
438
|
+
expected_type : type or tuple of type
|
|
439
|
+
The expected type(s) of the assignment. A tuple may be passed
|
|
440
|
+
when the inference method accepts multiple Assignment types
|
|
441
|
+
(e.g., RandomizationTest accepts both CRDAssignment and
|
|
442
|
+
BlockedAssignment).
|
|
443
|
+
|
|
444
|
+
Raises
|
|
445
|
+
------
|
|
446
|
+
DesignEstimatorMismatch
|
|
447
|
+
If the assignment type does not match any of the expected types.
|
|
448
|
+
"""
|
|
449
|
+
_check_assignment_type(self, assignment, expected_type)
|
|
450
|
+
|
|
451
|
+
def __repr__(self) -> str:
|
|
452
|
+
"""Return string representation with parameters.
|
|
453
|
+
|
|
454
|
+
Returns
|
|
455
|
+
-------
|
|
456
|
+
str
|
|
457
|
+
Format: ClassName(param1=val1, param2=val2)
|
|
458
|
+
"""
|
|
459
|
+
params = self.get_params()
|
|
460
|
+
params_str = ", ".join(f"{k}={v!r}" for k, v in params.items())
|
|
461
|
+
return f"{type(self).__name__}({params_str})"
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
@dataclass
class DiagnosticsReport:
    """Collection of diagnostic findings, grouped by severity.

    Attributes
    ----------
    flags : list of str
        Critical issues that should be addressed.
    warnings : list of str
        Non-critical warnings for the user's attention.

    Examples
    --------
    >>> report = DiagnosticsReport()
    >>> report.summary()
    ✅ No issues found.
    """

    # default_factory gives every instance its own fresh list.
    flags: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)

    def summary(self) -> None:
        """Print the report to stdout, one finding per line.

        All flags come first, each prefixed with ❌. Then all warnings,
        each prefixed with ⚠️. When both lists are empty, a single
        ✅ No issues found. line is printed instead.
        """
        if not (self.flags or self.warnings):
            print("✅ No issues found.")
            return

        for item in self.flags:
            print(f"❌ {item}")

        for item in self.warnings:
            print(f"⚠️ {item}")

    def __repr__(self) -> str:
        """Return a compact representation showing only the counts.

        Returns
        -------
        str
            Format: DiagnosticsReport(flags=N, warnings=N)
        """
        n_flags = len(self.flags)
        n_warnings = len(self.warnings)
        return "DiagnosticsReport(flags={}, warnings={})".format(
            n_flags, n_warnings
        )
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""Custom exceptions for the skxperiments library.
|
|
2
|
+
|
|
3
|
+
All exceptions inherit from SkxperimentsError, enabling users to catch
|
|
4
|
+
all library-specific errors with a single except clause.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SkxperimentsError(Exception):
    """Root of the skxperiments exception hierarchy.

    Catching this type catches every error raised by the library.

    Parameters
    ----------
    message : str
        Human-readable error description. It is kept on ``self.message``
        and is also the exception's single positional argument, so
        ``str(exc)`` returns it.
    """

    def __init__(self, message: str) -> None:
        super().__init__(message)
        self.message = message
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DesignEstimatorMismatch(SkxperimentsError):
    """Raised when an estimator receives an incompatible Assignment type.

    Parameters
    ----------
    estimator_name : str
        Name of the estimator that detected the mismatch.
    received_type : str
        Name of the Assignment type that was received.
    expected_type : str
        Name of the Assignment type that was expected.
    suggestion : str or None, optional
        Suggested alternative estimator or design, by default None.

    Examples
    --------
    >>> raise DesignEstimatorMismatch(
    ...     estimator_name="DifferenceInMeans",
    ...     received_type="BlockedAssignment",
    ...     expected_type="CRDAssignment",
    ...     suggestion="BlockedDifferenceInMeans",
    ... )
    """

    def __init__(
        self,
        estimator_name: str,
        received_type: str,
        expected_type: str,
        suggestion: str | None = None,
    ) -> None:
        self.estimator_name = estimator_name
        self.received_type = received_type
        self.expected_type = expected_type
        self.suggestion = suggestion

        message = (
            f"[{estimator_name}] expects {expected_type} "
            f"but received {received_type}."
        )
        if suggestion is not None:
            message += f" Suggestion: use {suggestion} instead."

        super().__init__(message)

    def __reduce__(self) -> tuple:
        # The default BaseException.__reduce__ rebuilds the exception as
        # cls(*self.args). Here args holds only the formatted message, but
        # __init__ requires three arguments, so pickle/copy would fail with
        # TypeError (e.g. when an error crosses a process boundary).
        # Replay the real constructor arguments and restore __dict__.
        return (
            type(self),
            (
                self.estimator_name,
                self.received_type,
                self.expected_type,
                self.suggestion,
            ),
            self.__dict__,
        )
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class NotFittedError(SkxperimentsError):
    """Raised when methods dependent on fit() are called before fitting.

    Parameters
    ----------
    class_name : str
        Name of the class that is not yet fitted.
    required_methods : list of str
        Methods that must be called before using the object.

    Examples
    --------
    >>> raise NotFittedError(
    ...     class_name="DifferenceInMeans",
    ...     required_methods=["fit"],
    ... )
    """

    def __init__(self, class_name: str, required_methods: list[str]) -> None:
        self.class_name = class_name
        self.required_methods = required_methods

        methods_str = ", ".join(f"{m}()" for m in required_methods)
        message = (
            f"[{class_name}] is not fitted. "
            f"Call {methods_str} before using this object."
        )
        super().__init__(message)

    def __reduce__(self) -> tuple:
        # The default BaseException.__reduce__ would call cls(message), which
        # does not match this two-argument __init__ and breaks pickle/copy.
        # Rebuild from the original arguments and restore __dict__.
        return (
            type(self),
            (self.class_name, self.required_methods),
            self.__dict__,
        )
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class InsufficientDataError(SkxperimentsError):
    """Raised when the number of units or block size is insufficient.

    Parameters
    ----------
    context : str
        Description of the operation requiring more data.
    minimum : int
        Minimum number of units required.
    received : int
        Actual number of units received.

    Examples
    --------
    >>> raise InsufficientDataError(
    ...     context="CRD randomization",
    ...     minimum=2,
    ...     received=1,
    ... )
    """

    def __init__(self, context: str, minimum: int, received: int) -> None:
        self.context = context
        self.minimum = minimum
        self.received = received

        message = (
            f"{context} requires at least {minimum} units, "
            f"but received {received}."
        )
        super().__init__(message)

    def __reduce__(self) -> tuple:
        # The default BaseException.__reduce__ would call cls(message), which
        # does not match this three-argument __init__ and breaks pickle/copy.
        # Rebuild from the original arguments and restore __dict__.
        return (
            type(self),
            (self.context, self.minimum, self.received),
            self.__dict__,
        )
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class InvalidDesignError(SkxperimentsError):
    """Signals that design (or estimator/inference) parameters conflict.

    Parameters
    ----------
    message : str
        Explanation of what is inconsistent. It is stored on
        ``self.message`` by the base class.

    Examples
    --------
    >>> raise InvalidDesignError("Treatment probability must be in (0, 1).")
    """

    def __init__(self, message: str) -> None:
        super().__init__(message=message)
|