diddesign 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diddesign/__init__.py ADDED
@@ -0,0 +1,89 @@
1
+ """diddesign: Double Difference-in-Differences for Python.
2
+
3
+ This package implements the multiple-pre-treatment DID estimator proposed by
4
+ Egami and Yamauchi (2023, Political Analysis). It combines standard DID and
5
+ sequential DID via efficient GMM weighting, extending to K-DID for panels
6
+ with three or more pre-treatment periods and to staggered-adoption designs
7
+ with lead-specific estimates.
8
+
9
+ The public interface consists of two estimation functions—:func:`did` for
10
+ treatment effect estimation and :func:`did_check` for pre-treatment
11
+ diagnostics—together with immutable result objects (:class:`DidResult`,
12
+ :class:`DidCheckResult`) whose frame accessors return pandas DataFrames
13
+ for downstream analysis, plotting, and LaTeX export.
14
+ """
15
+
16
+ from importlib.metadata import PackageNotFoundError, version as _package_version
17
+
18
+ from .core.data_contracts import DataContractError, DidDataError
19
+ from .diagnostics import DidCheckDiagnosticRow, DidCheckPatternRow, DidCheckResult, DidCheckTrendRow, did_check
20
+ from .diagnostics_reporter import DiagnosticsReporter
21
+ from .errors import (
22
+ DidError,
23
+ DidRuntimeError,
24
+ DidValueError,
25
+ DidWarning,
26
+ ErrorCode,
27
+ WarningCode,
28
+ did_warn,
29
+ )
30
+ from .estimators import did
31
+ from .formula import DidFormulaSpec, did_formula
32
+ from .plotting import check, fit
33
+ from .results import (
34
+ DidBootstrapDraw,
35
+ DidBootstrapDrawK,
36
+ DidEstimateRow,
37
+ DidGmmAuditRow,
38
+ DidGmmRow,
39
+ DidResult,
40
+ DidWeightRow,
41
+ format_summary,
42
+ summary,
43
+ )
44
+ from .visualization import plot_diagnostics, plot_estimates, plot_pattern, plot_placebo, plot_trends
45
+
46
# Alternate all-caps spelling bound to the very same class object as
# ``DidResult``.
DIDResult = DidResult

# Prefer the version recorded in the installed distribution's metadata; when
# no "diddesign" distribution can be found, fall back to the literal below.
try:
    __version__ = _package_version("diddesign")
except PackageNotFoundError:
    __version__ = "0.1.0"

# Names exported by ``from diddesign import *``. Every public name imported
# above is listed, including ``DidGmmAuditRow``, which is imported from
# ``.results`` alongside the other row classes.
__all__ = [
    "__version__",
    "DataContractError",
    "DiagnosticsReporter",
    "DidDataError",
    "DidError",
    "DidFormulaSpec",
    "DidRuntimeError",
    "DidValueError",
    "DidWarning",
    "ErrorCode",
    "WarningCode",
    "did_warn",
    "DidBootstrapDraw",
    "DidBootstrapDrawK",
    "DidCheckDiagnosticRow",
    "DidCheckPatternRow",
    "DidCheckResult",
    "DidCheckTrendRow",
    "DidEstimateRow",
    "DidGmmAuditRow",
    "DidGmmRow",
    "DidResult",
    "DIDResult",
    "DidWeightRow",
    "check",
    "did",
    "did_check",
    "did_formula",
    "fit",
    "format_summary",
    "plot_diagnostics",
    "plot_estimates",
    "plot_pattern",
    "plot_placebo",
    "plot_trends",
    "summary",
]
@@ -0,0 +1,12 @@
1
+ """Core data validation helpers."""
2
+
3
+ from .data_contracts import DataContractError, DidDataError, NormalizedDataContract, normalize_design_data
4
+ from .validation import validate_sa_panel_preconditions
5
+
6
# Names exported by a star-import of this subpackage: the two error classes,
# the normalized contract with its builder, and the panel precondition check.
__all__ = [
    "DataContractError", "DidDataError",
    "NormalizedDataContract", "normalize_design_data",
    "validate_sa_panel_preconditions",
]
@@ -0,0 +1,172 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+
6
+ from ..errors import ErrorCode, DidValueError
7
+ from .validation import (
8
+ DataContractError,
9
+ DidDataError,
10
+ materialize_rows,
11
+ require_binary_indicator,
12
+ require_column,
13
+ resolve_time_order_metadata,
14
+ validate_rcs_post_indicator,
15
+ validate_sa_panel_preconditions,
16
+ validate_design,
17
+ validate_sa_treatment_path,
18
+ validate_standard_did_panel_treatment_path,
19
+ validate_unique_panel_cells,
20
+ )
21
+
22
+
23
+ def _require_distinct_role_columns(**roles: str | None) -> None:
24
+ seen: dict[str, str] = {}
25
+ for role, column in roles.items():
26
+ if column is None:
27
+ continue
28
+ previous_role = seen.get(column)
29
+ if previous_role is not None:
30
+ raise DidValueError(
31
+ ErrorCode.E001,
32
+ f"{role} column must be distinct from {previous_role} column.",
33
+ context={"role": role, "previous_role": previous_role, "column": column},
34
+ )
35
+ seen[column] = role
36
+
37
+
38
@dataclass(frozen=True)
class NormalizedDataContract:
    """Immutable record of the column roles and checks behind a validated dataset."""

    design: str
    data_type: str
    branch: str
    outcome: str
    treatment: str
    time: str
    unit_id: str | None
    post: str | None
    time_order: tuple[Any, ...]
    cluster_default: str | None
    validation_trace: tuple[str, ...]

    def as_metadata(self) -> dict[str, Any]:
        """Return the contract as a plain dict keyed by field name.

        Keys appear in field-declaration order and the values are the stored
        attribute objects themselves (no copies are made).
        """
        keys = (
            "design",
            "data_type",
            "branch",
            "outcome",
            "treatment",
            "time",
            "unit_id",
            "post",
            "time_order",
            "cluster_default",
            "validation_trace",
        )
        return {key: getattr(self, key) for key in keys}
68
+
69
+
70
def normalize_design_data(
    rows,
    *,
    outcome: str,
    treatment: str,
    time: str,
    unit_id: str | None = None,
    post: str | None = None,
    design: str = "did",
    data_type: str = "panel",
) -> NormalizedDataContract:
    """Validate ``rows`` for the requested design and describe the result.

    Checks run in a fixed order: the design/data-type combination, presence
    of the outcome, treatment and time columns, distinctness of the role
    columns, a binary treatment indicator, and finally the checks specific
    to panel data (``data_type == "panel"``) or to the repeated
    cross-section layout taken for any other ``data_type``.

    Parameters
    ----------
    rows
        Raw input; handed to ``materialize_rows`` before any check runs.
    outcome, treatment, time
        Names of the outcome, treatment and time columns.
    unit_id
        Name of the unit identifier column. Required for panel data, where
        it also becomes ``cluster_default``.
    post
        Name of the post-period indicator column. Required when
        ``data_type`` is not ``"panel"``.
    design
        ``"sa"`` selects the ``validate_sa_*`` panel checks; any other value
        takes the standard panel checks.
    data_type
        ``"panel"`` selects the panel branch.

    Returns
    -------
    NormalizedDataContract
        Column roles, resolved time order, branch label and the ordered
        trace of checks that passed.

    Raises
    ------
    DidValueError
        With ``ErrorCode.E001`` when ``unit_id`` is missing for panel data
        or ``post`` is missing otherwise. The validation helpers called
        along the way raise their own errors.
    """
    table = materialize_rows(rows)
    validate_design(design, data_type)

    # Role columns every design needs; outcome is the only role checked
    # with allow_missing=True.
    require_column(table, outcome, allow_missing=True, field_name="outcome")
    require_column(table, treatment, field_name="treatment")
    require_column(table, time, field_name="time")

    panel = data_type == "panel"
    # Only the roles that apply to the active data type take part in the
    # distinctness check.
    _require_distinct_role_columns(
        outcome=outcome,
        treatment=treatment,
        time=time,
        unit_id=unit_id if panel else None,
        post=post if data_type == "rcs" else None,
    )
    require_binary_indicator(table, column=treatment, label="treatment")

    trace: list[str] = ["required:outcome", "required:treatment", "required:time", "binary:treatment"]

    if not panel:
        if post is None:
            raise DidValueError(
                ErrorCode.E001,
                "post is required for repeated cross-section data.",
                context={"field_name": "post", "data_type": data_type},
            )
        require_column(table, post, field_name="post")
        require_binary_indicator(table, column=post, label="post")
        periods, label_kind = resolve_time_order_metadata(table, time=time)
        validate_rcs_post_indicator(table, time=time, post=post, time_order=periods)
        trace += ["required:post", "binary:post", f"time-order:{label_kind}"]
        branch = "did-rcs"
        cluster: str | None = None
    else:
        if unit_id is None:
            raise DidValueError(
                ErrorCode.E001,
                "unit_id is required for panel data.",
                context={"field_name": "unit_id", "data_type": data_type},
            )
        require_column(table, unit_id, field_name="unit_id")
        trace.append("required:unit_id")
        if design == "sa":
            # The precondition check itself supplies the time order here;
            # the metadata lookup is used only for the label kind.
            periods = validate_sa_panel_preconditions(table, unit_id=unit_id, time=time)
            validate_sa_treatment_path(
                table,
                unit_id=unit_id,
                time=time,
                treatment=treatment,
                time_order=periods,
            )
            _ignored_order, label_kind = resolve_time_order_metadata(table, time=time)
            trace += [
                "balanced-panel",
                "unique:unit-time",
                f"time-order:{label_kind}",
                "absorbing:treatment",
            ]
        else:
            validate_unique_panel_cells(table, unit_id=unit_id, time=time)
            periods, label_kind = resolve_time_order_metadata(table, time=time)
            validate_standard_did_panel_treatment_path(
                table,
                unit_id=unit_id,
                time=time,
                treatment=treatment,
                time_order=periods,
            )
            trace += ["unique:unit-time", f"time-order:{label_kind}"]
        branch = f"{design}-panel"
        cluster = unit_id

    return NormalizedDataContract(
        design=design,
        data_type=data_type,
        branch=branch,
        outcome=outcome,
        treatment=treatment,
        time=time,
        unit_id=unit_id,
        post=post,
        time_order=periods,
        cluster_default=cluster,
        validation_trace=tuple(trace),
    )
165
+
166
+
167
# Public surface of this module: the two error classes imported from the
# sibling ``validation`` module, plus the contract type and its builder.
__all__ = [
    "DataContractError", "DidDataError",
    "NormalizedDataContract", "normalize_design_data",
]
@@ -0,0 +1,105 @@
1
+ """Automatic string-to-integer encoding for DID structural variables.
2
+
3
+ Mathematical justification:
4
+ - time column: Correct temporal ordering is critical for sDID (ΔY_t = Y_t - Y_{t-1}).
5
+ Strings are encoded by lexicographic order (matching Stata's egen group() behavior).
6
+ - unit_id column: Only grouping matters (ordering irrelevant for bootstrap clustering).
7
+ - cluster column: Only grouping matters (ordering irrelevant for bootstrap blocking).
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import warnings
13
+ from typing import Any
14
+
15
+ import pandas as pd
16
+
17
+ from ..errors import WarningCode, did_warn
18
+
19
+
20
def auto_encode_string_columns(
    df: pd.DataFrame,
    *,
    time: str | None = None,
    unit_id: str | None = None,
    id_cluster: str | None = None,
) -> tuple[pd.DataFrame, dict[str, dict[Any, int]]]:
    """Replace string-valued structural columns with integer codes.

    Parameters
    ----------
    df : pd.DataFrame
        Input data frame; it is copied, never modified.
    time : str | None
        Time column name.
    unit_id : str | None
        Unit identifier column name.
    id_cluster : str | None
        Cluster variable column name. Skipped when it names the same column
        as ``unit_id``.

    Returns
    -------
    tuple[pd.DataFrame, dict[str, dict[Any, int]]]
        The encoded copy of ``df`` and, for each column that was encoded, the
        mapping ``{original_value: code}``.

    Notes
    -----
    For every role the distinct non-null values are sorted and numbered
    0, 1, 2, ... in that order, so the codes are deterministic. Values absent
    from the mapping (missing values) are left to ``Series.map``.

    A ``W001`` warning is issued for each encoded column. Columns that are not
    given, are absent from ``df``, or are not string-valued are left as is.
    """
    result = df.copy()
    maps: dict[str, dict[Any, int]] = {}

    # Role -> column, in processing order. The cluster column is dropped when
    # it names the same column as the unit identifier.
    requested: dict[str, str | None] = {"time": time, "unit_id": unit_id}
    if id_cluster != unit_id:
        requested["id_cluster"] = id_cluster

    for role, col_name in requested.items():
        if col_name is None or col_name not in df.columns:
            continue

        # Always read from the untouched input frame, not the partly encoded copy.
        source = df[col_name]
        if not _is_string_column(source):
            continue

        levels = sorted(source.dropna().unique())
        codes = dict(zip(levels, range(len(levels))))

        maps[col_name] = codes
        result[col_name] = source.map(codes)

        did_warn(
            WarningCode.W001,
            f"String column '{col_name}' ({role}) automatically encoded to integer.",
            context={"column": col_name, "role": role, "n_levels": len(codes)},
            stacklevel=3,
        )

    return result, maps
96
+
97
+
98
+ def _is_string_column(series: pd.Series) -> bool:
99
+ """Check if a pandas Series contains string (object/string) dtype."""
100
+ if series.dtype == object:
101
+ non_null = series.dropna()
102
+ if len(non_null) == 0:
103
+ return False
104
+ return all(isinstance(v, str) for v in non_null.head(100))
105
+ return pd.api.types.is_string_dtype(series)