dml-dev 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
1
+
@@ -0,0 +1,15 @@
1
from project_code.src.build_helpers import BuildSource, BuildSpec


# Template: shared source-data inputs for the build join. Projects fill in
# parquet paths and the columns to carry through (see BuildSource).
source_data = BuildSource(
    paths=[],
    passthrough_cols=[],
    passthrough_cols_as_lag=[],
)


# Loaded lazily by build_helpers.get_build_spec() as
# `config.build_spec.BUILD_SPEC`.
BUILD_SPEC = BuildSpec(
    source_data=source_data,
    programs={},
    post_panel_transforms=[],
)
@@ -0,0 +1,2 @@
1
+ """Example experiment YAML files."""
2
+
@@ -0,0 +1,7 @@
1
# Experiment template: each key holds a LIST; the estimation pipeline expands
# the Cartesian product of all seven lists into concrete runs (see
# estimate_helpers.unpack_runs).
program_pointer: []           # keys into PROGRAM_REGISTRY
outcomes: []                  # outcome column names in the processed panels
covariate_set_pointer: []     # keys into COVARIATE_SET_REGISTRY
filter_set_pointer: []        # keys into FILTER_SET_REGISTRY
outcomes_model_pointer: []    # keys into OUTCOMES_MODEL_REGISTRY
propensity_model_pointer: []  # keys into PROPENSITY_MODEL_REGISTRY
num_controls_per_treat: []    # control-to-treated sampling ratios
@@ -0,0 +1 @@
1
+
@@ -0,0 +1,3 @@
1
# Maps covariate-set name -> list of covariate column names.
# Referenced by experiments via `covariate_set_pointer`.
COVARIATE_SET_REGISTRY: dict[str, list[str]] = {}

# Covariates that get one-hot encoded during estimation
# (see estimate_helpers.prepare_estimation_data).
CATEGORICAL_COVARIATES: list[str] = []
@@ -0,0 +1,4 @@
1
import polars as pl


# Maps filter-set name -> list of polars filter expressions applied to the
# panel before estimation. Referenced by experiments via `filter_set_pointer`.
FILTER_SET_REGISTRY: dict[str, list[pl.Expr]] = {}
@@ -0,0 +1,6 @@
1
from typing import Any


# Maps model name -> unfitted estimator for the outcomes nuisance model.
# Entries are cloned per run with sklearn.base.clone (see unpack_runs).
OUTCOMES_MODEL_REGISTRY: dict[str, Any] = {}

# Same, for the propensity (treatment) nuisance model.
PROPENSITY_MODEL_REGISTRY: dict[str, Any] = {}
@@ -0,0 +1,4 @@
1
from project_code.src.build_helpers import ProgramSource


# Maps program name -> ProgramSource describing its input files and column
# mappings. Referenced by experiments via `program_pointer`.
PROGRAM_REGISTRY: dict[str, ProgramSource] = {}
@@ -0,0 +1,35 @@
1
+ Metadata-Version: 2.4
2
+ Name: dml-dev
3
+ Version: 0.1.0
4
+ Summary: DoubleML build, estimation, plotting, and utility pipelines.
5
+ Author: DML Pipeline Contributors
6
+ Keywords: administrative-data,causal-inference,doubleml,observational-data,program-evaluation
7
+ Classifier: Development Status :: 3 - Alpha
8
+ Classifier: Intended Audience :: Science/Research
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: Programming Language :: Python :: 3.12
11
+ Classifier: Programming Language :: Python :: 3.13
12
+ Classifier: Topic :: Scientific/Engineering
13
+ Requires-Python: >=3.12
14
+ Description-Content-Type: text/markdown
15
+ Requires-Dist: doubleml
16
+ Requires-Dist: joblib
17
+ Requires-Dist: oi-tools[figures]
18
+ Requires-Dist: plotnine
19
+ Requires-Dist: polars
20
+ Requires-Dist: psutil
21
+ Requires-Dist: PyYAML
22
+ Requires-Dist: scikit-learn
23
+ Requires-Dist: threadpoolctl
24
+ Provides-Extra: dev
25
+ Requires-Dist: build; extra == "dev"
26
+ Requires-Dist: twine; extra == "dev"
27
+
28
+ # DML Pipeline
29
+
30
+ Reusable build, estimation, plotting, and utility code for applying DoubleML to
31
+ administrative observational data for program analysis.
32
+
33
+ The package includes the full implementation under `project_code`, plus an
34
+ empty `config_template` package that shows the required config shape without
35
+ shipping project configuration.
@@ -0,0 +1,24 @@
1
+ config_template/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
2
+ config_template/build_spec.py,sha256=7bnTZuzDt4M6ahHFzIwn0SBAW_s9hNZNaZMcFZPbRfE,272
3
+ config_template/experiments/__init__.py,sha256=jjiIUqNMd8N7pbUmn-fhoN_86ZcZN6_KRFdjD41xuOA,38
4
+ config_template/experiments/example_experiment.yaml,sha256=OYhssBSzMgs6DYVqmQHB1q0hBphCU-O-Kr6J8vyXSkg,165
5
+ config_template/registries/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
6
+ config_template/registries/covariate_sets.py,sha256=7FLp5f8Q5rgxPCL1gDE-tiDLfkFou51p26y0fK9gir0,90
7
+ config_template/registries/filter_sets.py,sha256=FjAF97-VPV_rJpgSITuxPCCCJtQDNmxL7lAha79W2E0,73
8
+ config_template/registries/models.py,sha256=gXlBvLwoFahq-HK7g5aD03VRe_1VWMiCkDfb620ZzIg,118
9
+ config_template/registries/programs.py,sha256=nLWw4fjZvzO2nVDvXSOGjVaZ0wKZj7Rq_xdSz6eacqc,107
10
+ project_code/__init__.py,sha256=bwmrbow99Z6b3_gOA36dz_i-mUUYI4ZF3g7uRRaPa5U,29
11
+ project_code/pipeline/__init__.py,sha256=iD1tczg7zGZ-iJLEeLrTN7QOqSml16SDG_BryJmoKpY,50
12
+ project_code/pipeline/build.py,sha256=n4QdVUImJ96WtEM6urjpGv1sACFX-af6xB1az8kSVfI,1551
13
+ project_code/pipeline/estimate.py,sha256=4OlaSz2wqeJ7j82e8s3-8aEJmbJpQbaLs5nQ3_OwL8s,4689
14
+ project_code/src/__init__.py,sha256=fSqPxsQ4vSqFiIeUX87hnojKT7B1_W5N3GPOp0o6o4Y,45
15
+ project_code/src/build_helpers.py,sha256=kOV4Q6NZyMOuvAuprg5VkUgZes0Yv4rJfjrodMhPAY4,9106
16
+ project_code/src/estimate_helpers.py,sha256=fCU7UKLGYfsgkbruqbTa9gt3LaVRnmDPDqZLB83-al8,8527
17
+ project_code/src/paths.py,sha256=Msvtz2Iz19RF8m7xfsaNAiAbxAm9IowEHtK58ErFDY0,1102
18
+ project_code/src/plotting.py,sha256=ISsYGfanDuMqErOa8_pdDMgmfA9ZnewuwdYCwfmKLMM,8759
19
+ project_code/src/utils.py,sha256=cQOOzc3cQtWW-0652VvYuBSbDiNxVoh2PMEQKKnUtCQ,2479
20
+ dml_dev-0.1.0.dist-info/METADATA,sha256=NpZtmk8qEBl3u9otvkgM0cl7FkuE-dx1qf1ZdVyhr4s,1259
21
+ dml_dev-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
22
+ dml_dev-0.1.0.dist-info/entry_points.txt,sha256=fuHwWSJipTQJvn9tT4ATcRpnfOEjVop6e_tj0LGHUZg,112
23
+ dml_dev-0.1.0.dist-info/top_level.txt,sha256=3VjkRcifTL17eG8IuySDOo1bh_fIMXwF4DRP0_5JScU,29
24
+ dml_dev-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,3 @@
1
+ [console_scripts]
2
+ dml-build = project_code.pipeline.build:cli
3
+ dml-estimate = project_code.pipeline.estimate:cli
@@ -0,0 +1,2 @@
1
+ config_template
2
+ project_code
@@ -0,0 +1,2 @@
1
+ """DML pipeline package."""
2
+
@@ -0,0 +1,2 @@
1
+ """Executable build and estimation pipelines."""
2
+
@@ -0,0 +1,61 @@
1
+ """Build processed cohort panels for a configured program.
2
+
3
+ This entrypoint loads the build spec, backs up any existing program output,
4
+ then writes one processed parquet file per cohort.
5
+ """
6
+
7
+ import argparse
8
+ import sys
9
+ import time
10
+ from pathlib import Path
11
+
12
+ LOCAL_DIR = Path(__file__).resolve().parents[2]
13
+ sys.path.insert(0, str(LOCAL_DIR))
14
+
15
+ from project_code.src.build_helpers import (
16
+ backup_existing_output,
17
+ build_cohort_file,
18
+ get_post_panel_transforms,
19
+ get_program_spec,
20
+ get_source_data_spec,
21
+ time_elapsed,
22
+ )
23
+
24
+
25
def main(program: str) -> None:
    """Run the build pipeline for one program name from the registry."""

    started_at = time.time()

    # Resolve the configured build recipe for this program.
    spec_source = get_source_data_spec()
    spec_program = get_program_spec(program)
    transforms = get_post_panel_transforms()

    # Preserve any previous output before overwriting.
    backup_existing_output(program)

    # One processed output file per configured source-data input file.
    for path in spec_source.paths:
        build_cohort_file(
            source_data_path=path,
            program=program,
            source_data_spec=spec_source,
            program_spec=spec_program,
            post_panel_transforms=transforms,
        )

    total_run_time = time_elapsed(started_at, time.time())
    print("\n Done")
    print(f"\n Total time: {total_run_time}")
49
+
50
+
51
def cli() -> None:
    """Command-line wrapper for package entrypoints."""

    # Single positional argument: the program name from the registry.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("program")
    parsed = arg_parser.parse_args()
    main(program=parsed.program)
58
+
59
+
60
+ if __name__ == "__main__":
61
+ cli()
@@ -0,0 +1,122 @@
1
+ """Run DoubleML estimation for a YAML experiment.
2
+
3
+ This entrypoint loads an experiment, expands its registry pointers into runs,
4
+ fits each run, and writes estimation and prediction logs.
5
+ """
6
+
7
+ import argparse
8
+ import os
9
+ import sys
10
+ import time
11
+ from datetime import datetime
12
+ from pathlib import Path
13
+
14
+ import polars as pl
15
+
16
+ if os.environ.get("NCPUS"):
17
+ os.environ["POLARS_MAX_THREADS"] = os.environ["NCPUS"]
18
+
19
+ LOCAL_DIR = Path(__file__).resolve().parents[2]
20
+ sys.path.insert(0, str(LOCAL_DIR))
21
+
22
+ from project_code.src.estimate_helpers import (
23
+ fit_doubleml_irm,
24
+ get_experiment,
25
+ prepare_estimation_data,
26
+ unpack_runs,
27
+ validate_runs,
28
+ )
29
+ from project_code.src.utils import log_process_resources, log_results, time_elapsed, trim_memory
30
+
31
+
32
def main(experiment_name: str) -> None:
    """Run all expanded estimation runs for one experiment YAML name."""

    experiment = get_experiment(experiment_name)
    runs = unpack_runs(experiment)
    validate_runs(runs)

    # Periodic resource logging; the returned object is an Event-like stop
    # signal (its .set() ends the logging in the finally block below).
    stop_resource_logging = log_process_resources(interval=30)
    try:
        for run_number, run in enumerate(runs, start=1):
            print(f"Starting run #{run_number} of {len(runs)} \n")
            start = time.time()

            df, x_cols, summary = prepare_estimation_data(run)

            # Time estimation separately from data preparation.
            start_estimation = time.time()
            dml_obj = fit_doubleml_irm(
                df=df,
                run=run,
                covariate_set_after_dummies=x_cols,
            )

            end = time.time()
            total_run_time = time_elapsed(start, end)
            estimation_run_time = time_elapsed(start_estimation, end)
            estimation_run_time_hours = (end - start_estimation) / (60 * 60)

            print("\n Starting logging...\n")
            # One-row frame: run configuration, fit results/losses, timings,
            # and sample-size summary.
            estimation_log = pl.DataFrame({
                "program": [run.program_name],
                "treatment": [run.treatment],
                "outcome": [run.outcome],
                "covariate_set_name": [run.covariate_set_pointer],
                "filter_set_name": [run.filter_set_pointer],
                "num_controls_per_treat": [run.num_controls_per_treat],
                "outcomes_model_name": [run.outcomes_model_pointer],
                "propensity_model_name": [run.propensity_model_pointer],
                "outcomes_model_class": [type(run.outcomes_model).__name__],
                "propensity_model_class": [type(run.propensity_model).__name__],
                "outcomes_model_params": [str(run.outcomes_model.get_params())],
                "propensity_model_params": [str(run.propensity_model.get_params())],
                "dml_estimate": [float(dml_obj.coef[0])],
                "dml_se": [float(dml_obj.se[0])],
                "dml_outcomes_loss": [float(dml_obj.nuisance_loss["ml_g0"][0][0])],
                "dml_prop_loss": [float(dml_obj.nuisance_loss["ml_m"][0][0])],
                "total_run_time": [total_run_time],
                "estimation_run_time": [estimation_run_time],
                "estimation_run_time_hours": [estimation_run_time_hours],
                "timestamp": [datetime.now()],
                "n_controls": [summary["n_controls"]],
                "n_unique_controls": [summary["n_unique_controls"]],
                "n_covariates": [summary["n_covariates"]],
                "n_treated": [summary["n_treated"]],
                "n_null_rows_dropped": [summary["n_null_rows_dropped"]],
                "n_rows": [summary["n_rows"]],
                "run_number": [run_number],
            })
            # Row-level frame: all model inputs plus cross-fitted nuisance
            # predictions. The [:, 0, 0] slice presumably selects the first
            # repetition / first treatment level — confirm against the
            # DoubleML predictions layout.
            predictions_log = pl.DataFrame({
                **{col: df[col].to_numpy() for col in df.columns},
                "run_number": [run_number] * len(df),
                "true_outcomes": df[run.outcome],
                "true_propensity": df[run.treatment],
                "outcomes_predictions": dml_obj.predictions["ml_g0"][:, 0, 0],
                "propensity_predictions": dml_obj.predictions["ml_m"][:, 0, 0],
            })

            log_results("estimation", estimation_log, experiment_name, run_number)
            log_results("predictions", predictions_log, experiment_name, run_number)

            print(f"""\n Run #{run_number} complete
            \n Estimation run time: {estimation_run_time}
            \n Total run time: {total_run_time}\n \n""")

            # Release the fitted object and modeling frame before the next run.
            del dml_obj, df
            trim_memory()

    finally:
        # Always stop the resource logger, even if a run raised.
        stop_resource_logging.set()
110
+
111
+
112
def cli() -> None:
    """Command-line wrapper for package entrypoints."""

    # Single positional argument: the experiment YAML name (without suffix).
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("experiment_name", type=str)
    parsed = arg_parser.parse_args()
    main(experiment_name=parsed.experiment_name)
119
+
120
+
121
+ if __name__ == "__main__":
122
+ cli()
@@ -0,0 +1,2 @@
1
+ """Shared helpers for the DML pipelines."""
2
+
@@ -0,0 +1,289 @@
1
+ from collections.abc import Callable, Sequence
2
+ from dataclasses import dataclass, field
3
+ from pathlib import Path
4
+ import shutil
5
+ import sys
6
+ import tempfile
7
+ import time
8
+
9
+ import polars as pl
10
+
11
+ from project_code.src.paths import CONFIG_DIR, processed_data_out_folder, processed_data_out_path
12
+ from project_code.src.utils import time_elapsed, trim_memory
13
+
14
# Canonical column names shared by the build and estimation pipelines.
TREATMENT_COL = "treatment"  # 1 = treated, 0 = control (filled in the join)
OBSERVATION_COL = "observation_year"  # enrollment / panel year
JOIN_KEY = "unit_id"  # key joining source data to program records
# A post-panel transform maps one LazyFrame to another.
Transform = Callable[[pl.LazyFrame], pl.LazyFrame]
18
+
19
+
20
@dataclass(frozen=True)
class RelativeCol:
    """Calendar-year columns to convert into event-time columns around enrollment."""

    # Column-name prefix; source columns are named f"{stem}_{calendar_year}"
    # (see build_lag_exprs).
    stem: str
    # Event-time offsets relative to the enrollment year (negative = lags).
    years: Sequence[int]
26
+
27
+
28
@dataclass(frozen=True, kw_only=True)
class BuildSource:
    """Input files and columns to carry from one side of the build join."""

    # Parquet files to scan for this side of the join.
    paths: Sequence[Path]
    # Columns copied through to the output unchanged.
    passthrough_cols: Sequence[pl.Expr]
    # Calendar-year columns re-expressed as event-time columns (see RelativeCol).
    passthrough_cols_as_lag: Sequence[RelativeCol]
    # Expression yielding the join key; defaults to the shared JOIN_KEY column.
    join_key_col: pl.Expr = field(default_factory=lambda: pl.col(JOIN_KEY))
36
+
37
+
38
@dataclass(frozen=True, kw_only=True)
class ProgramSource(BuildSource):
    """Program-specific source data and column mappings."""

    # Program name, used for output-folder naming and run labels.
    name: str
    # Expression yielding the treatment indicator (1 = treated).
    treatment_col: pl.Expr
    # Expression yielding the enrollment year (aliased to OBSERVATION_COL).
    enrollment_year_col: pl.Expr
45
+
46
+
47
@dataclass(frozen=True, init=False)
class BuildSpec:
    """Complete build recipe: source data, programs, and post-panel transforms."""

    source_data: BuildSource
    programs: dict[str, ProgramSource]
    post_panel_transforms: Sequence[Transform]

    def __init__(
        self,
        source_data: BuildSource | None = None,
        programs: dict[str, ProgramSource] | None = None,
        post_panel_transforms: Sequence[Transform] = (),
    ):
        # source_data is required; the None default only exists so config
        # files can pass every field by keyword.
        if source_data is None:
            raise ValueError("BuildSpec requires source_data")

        # The dataclass is frozen, so this custom __init__ must assign fields
        # via object.__setattr__ to bypass the immutability guard.
        object.__setattr__(self, "source_data", source_data)
        object.__setattr__(self, "programs", programs or {})
        object.__setattr__(self, "post_panel_transforms", post_panel_transforms)
67
+
68
+
69
def get_build_spec() -> BuildSpec:
    """Load the configured build recipe lazily to avoid import cycles.

    Importing `config.build_spec` at call time (not module import time) lets
    the config package import this module's dataclasses without a cycle.

    Returns:
        The `BUILD_SPEC` object defined by the external config package.
    """

    # Guard the insert: the original unconditionally prepended the path on
    # every call, growing sys.path with duplicates over repeated lookups.
    config_parent = str(CONFIG_DIR.parent)
    if config_parent not in sys.path:
        sys.path.insert(0, config_parent)
    from config.build_spec import BUILD_SPEC

    return BUILD_SPEC
76
+
77
+
78
def get_program_spec(program: str) -> ProgramSource:
    """Return the configured source definition for one program."""

    spec = get_build_spec()
    try:
        return spec.programs[program]
    except KeyError as err:
        # Re-raise with a friendlier message, keeping the original cause.
        raise ValueError(f"Unknown program: {program}") from err
85
+
86
+
87
def get_source_data_spec() -> BuildSource:
    """Return the shared source data input definition."""

    spec = get_build_spec()
    return spec.source_data
91
+
92
+
93
def get_post_panel_transforms() -> Sequence[Transform]:
    """Return transforms applied after event-time panel construction."""

    spec = get_build_spec()
    return spec.post_panel_transforms
97
+
98
+
99
def backup_existing_output(program: str) -> None:
    """Move existing build output aside before writing a fresh run."""

    current = processed_data_out_folder(program)
    if current.exists():
        # Keep at most one backup: drop any previous one, then rename.
        backup_dir = current.with_name(f"{current.name}_backup")
        if backup_dir.exists():
            shutil.rmtree(backup_dir)
        current.rename(backup_dir)
110
+
111
+
112
+
113
def load_program_lf(program_spec: ProgramSource) -> pl.LazyFrame:
    """Load treated program records and normalize key build columns."""

    return (
        pl.scan_parquet(program_spec.paths)
        .with_columns(
            # Normalize program-specific columns to the canonical names used
            # throughout the build.
            program_spec.join_key_col.alias(JOIN_KEY),
            program_spec.treatment_col.alias(TREATMENT_COL),
            program_spec.enrollment_year_col.alias(OBSERVATION_COL),
        )
        # Only treated rows come from the program side of the join; controls
        # appear as join misses (filled with 0 in build_cohort_file).
        .filter(pl.col(TREATMENT_COL) == 1)
    )
125
+
126
+
127
def get_treated_enrollment_years(program_lf: pl.LazyFrame) -> list[int]:
    """Collect enrollment years that define separate event-time panels."""

    # Distinct, non-null enrollment years, materialized as a series.
    years = (
        program_lf.select(OBSERVATION_COL)
        .drop_nulls()
        .unique()
        .collect()
        .to_series()
    )
    return years.sort().to_list()
139
+
140
+
141
def build_lag_exprs(
    passthrough_cols_as_lag: Sequence[RelativeCol],
    enrollment_year: int,
    available_cols: set[str],
) -> tuple[list[pl.Expr], set[str]]:
    """Build expressions that map calendar-year columns to event-time columns.

    Missing calendar-year inputs become null columns so every cohort panel has a
    stable schema.
    """

    exprs: list[pl.Expr] = []
    missing: set[str] = set()

    for rel_col in passthrough_cols_as_lag:
        for offset in rel_col.years:
            # Source column is keyed by calendar year; target by event time.
            source_name = f"{rel_col.stem}_{enrollment_year + offset}"
            label = f"{offset}" if offset >= 0 else f"L{abs(offset)}"
            target_name = f"{rel_col.stem}_{label}"

            if source_name in available_cols:
                exprs.append(pl.col(source_name).alias(target_name))
            else:
                # Record the gap and emit a null placeholder column.
                missing.add(source_name)
                exprs.append(pl.lit(None).alias(target_name))

    return exprs, missing
169
+
170
+
171
def apply_transforms(
    lf: pl.LazyFrame,
    transforms: Sequence[Transform],
) -> pl.LazyFrame:
    """Apply configured LazyFrame transforms in order."""

    result = lf
    for step in transforms:
        result = step(result)
    return result
180
+
181
+
182
def build_cohort_file(
    source_data_path: Path,
    program: str,
    source_data_spec: BuildSource,
    program_spec: ProgramSource,
    post_panel_transforms: Sequence[Transform],
) -> None:
    """Build and write one processed parquet file for one birth cohort.

    The output stacks one panel per treated enrollment year, with controls
    repeated into the panels where they are untreated or not-yet-treated.

    Args:
        source_data_path: One cohort's source parquet (stem like "...=1950").
        program: Program name, used for the output location.
        source_data_spec: Shared source-data inputs and passthrough columns.
        program_spec: Program-side inputs and column mappings.
        post_panel_transforms: Transforms applied after panel construction.
    """

    start = time.time()
    # Cohort year is encoded in the file stem after "=".
    cohort = int(source_data_path.stem.split("=")[1])
    print(f"\n \n Starting cohort {cohort}")

    # Temporary cohort window used to avoid scanning out-of-scope source files.
    if cohort < 1940 or cohort > 1995:
        return

    source_data_lf = pl.scan_parquet(source_data_path).with_columns(
        source_data_spec.join_key_col.alias(JOIN_KEY)
    )
    program_lf = load_program_lf(program_spec)
    treated_enrollment_years = get_treated_enrollment_years(program_lf)

    # Join once at calendar time, then slice into event-time panels below.
    merged_lf = source_data_lf.join(program_lf, on=JOIN_KEY, how="left")
    # Join misses are controls: treatment indicator defaults to 0.
    merged_lf = merged_lf.with_columns(pl.col(TREATMENT_COL).fill_null(0))

    available_cols = set(merged_lf.collect_schema().names())
    passthrough_cols = [
        pl.col(JOIN_KEY),
        pl.col(OBSERVATION_COL),
        pl.col(TREATMENT_COL),
        *program_spec.passthrough_cols,
        *source_data_spec.passthrough_cols,
    ]
    passthrough_cols_as_lag = [
        *program_spec.passthrough_cols_as_lag,
        *source_data_spec.passthrough_cols_as_lag,
    ]

    missing_cols = set()
    temp_dir = Path(tempfile.mkdtemp())
    cohort_panel_paths = []

    # FIX: the temp directory was previously never removed, leaking one
    # directory of intermediate parquet files per cohort. Clean it up once
    # the final output has been written (or on failure).
    try:
        for enrollment_year in treated_enrollment_years:
            # Keep treated observations from this enrollment year and controls that
            # are never treated or not yet treated.
            cohort_panel = merged_lf.filter(
                (pl.col(TREATMENT_COL) == 0) | (pl.col(OBSERVATION_COL) == enrollment_year)
            )
            cohort_panel = cohort_panel.with_columns(
                pl.col(OBSERVATION_COL).fill_null(enrollment_year)
            )

            # Build event-time columns from calendar-time source columns.
            lag_exprs, missing_lag_cols = build_lag_exprs(
                passthrough_cols_as_lag=passthrough_cols_as_lag,
                enrollment_year=enrollment_year,
                available_cols=available_cols,
            )
            missing_cols.update(missing_lag_cols)

            # Write each enrollment-year panel separately to keep memory bounded.
            temp_path = temp_dir / f"panel_{enrollment_year}.parquet"
            cohort_panel.select(
                *passthrough_cols,
                *lag_exprs,
            ).sink_parquet(temp_path, engine="streaming")
            cohort_panel_paths.append(temp_path)

            del cohort_panel
            trim_memory()

        if missing_cols:
            print(f"Warning - Lag construction is missing the following columns: {missing_cols}")

        result = pl.concat(
            [pl.scan_parquet(path) for path in cohort_panel_paths],
            how="vertical_relaxed",
        )
        # Add common post-panel features after all relative columns exist.
        result = apply_transforms(result, post_panel_transforms)

        out_path = processed_data_out_path(program, cohort)
        result.sink_parquet(out_path, engine="streaming")

        del result
        trim_memory()
    finally:
        # The temp panels are fully consumed by the sink above; drop them.
        shutil.rmtree(temp_dir, ignore_errors=True)

    end = time.time()
    cohort_run_time = time_elapsed(start, end)
    print(f"Cohort {cohort} build complete. \n Run time: {cohort_run_time}")
278
+
279
+
280
def add_derived_columns(program: str) -> None:
    """Re-apply post-panel transforms to files that have already been built.

    Each file is rewritten through a temporary sibling and then atomically
    replaced. FIX: the original sank parquet output directly onto the same
    path it was lazily scanning, which can truncate the input while it is
    still being read by the streaming engine.
    """

    folder = processed_data_out_folder(program)
    post_panel_transforms = get_post_panel_transforms()

    # Only touch parquet outputs; ignore any stray files in the folder.
    for path in sorted(folder.glob("*.parquet")):
        lf = apply_transforms(pl.scan_parquet(path), post_panel_transforms)
        tmp_path = path.with_name(path.name + ".tmp")
        lf.sink_parquet(tmp_path, engine="streaming")
        # Atomic swap: readers see either the old or the new file, never a
        # half-written one.
        tmp_path.replace(path)
@@ -0,0 +1,263 @@
1
+ from dataclasses import dataclass
2
+ from itertools import product
3
+ import re
4
+ import sys
5
+ from typing import Any
6
+
7
+ import doubleml as dml
8
+ import polars as pl
9
+ import yaml
10
+ from doubleml.utils import PSProcessorConfig
11
+ from joblib import parallel_backend
12
+ from sklearn.base import clone
13
+ from threadpoolctl import threadpool_limits
14
+
15
+ from project_code.src.build_helpers import TREATMENT_COL, JOIN_KEY
16
+ from project_code.src.paths import CONFIG_DIR, processed_data_out_folder
17
+ from project_code.src.utils import trim_memory
18
+
19
+
20
# Thread caps used in fit_doubleml_irm: each learner may use up to
# MODEL_THREADS native threads while N_JOBS_CV cross-fitting jobs run in
# parallel via joblib.
MODEL_THREADS = 8
N_JOBS_CV = 4

# Make the external `config` package importable for the lazy registry
# imports below (see unpack_runs / prepare_estimation_data).
sys.path.insert(0, str(CONFIG_DIR.parent))
24
+
25
+
26
@dataclass(frozen=True)
class Experiment:
    """Experiment pointers loaded from a YAML file.

    Every field is a list; `unpack_runs` expands the Cartesian product of all
    seven lists into concrete `Run`s.
    """

    # Keys into PROGRAM_REGISTRY.
    program_pointer: list[str]
    # Outcome column names in the processed panels.
    outcomes: list[str]
    # Keys into COVARIATE_SET_REGISTRY.
    covariate_set_pointer: list[str]
    # Keys into FILTER_SET_REGISTRY.
    filter_set_pointer: list[str]
    # Keys into OUTCOMES_MODEL_REGISTRY.
    outcomes_model_pointer: list[str]
    # Keys into PROPENSITY_MODEL_REGISTRY.
    propensity_model_pointer: list[str]
    # Control-to-treated sampling ratios (see sample_controls_per_treated).
    num_controls_per_treat: list[float]
37
+
38
+
39
@dataclass
class Run:
    """Concrete estimation run after resolving experiment pointers."""

    program_name: str
    join_key: str  # always JOIN_KEY
    treatment: str  # always TREATMENT_COL
    outcome: str
    # Resolved column names from COVARIATE_SET_REGISTRY.
    covariate_set: list[str]
    # Resolved filter expressions from FILTER_SET_REGISTRY.
    filter_set: list[Any]
    # Original registry keys, kept for labeling and logging.
    covariate_set_pointer: str
    filter_set_pointer: str
    outcomes_model_pointer: str
    propensity_model_pointer: str
    num_controls_per_treat: float
    # Per-run unfitted estimator clones (see unpack_runs).
    outcomes_model: Any
    propensity_model: Any
56
+
57
+
58
def get_experiment(name: str) -> Experiment:
    """Load one experiment YAML by name."""

    yaml_path = CONFIG_DIR / "experiments" / f"{name}.yaml"
    if not yaml_path.exists():
        raise ValueError(f"Unknown experiment: {name}")

    # An empty YAML file parses as None; substitute an empty mapping.
    with yaml_path.open() as f:
        contents = yaml.safe_load(f) or {}

    return Experiment(**contents)
69
+
70
def validate_runs(runs: list[Run]) -> None:
    """Validate that required data files and referenced columns exist.

    Problems across all runs are collected and raised together as one
    ValueError, so a misconfigured experiment fails with a full report.

    Raises:
        ValueError: listing every missing file, column, or filter reference.
    """

    # Local import avoids widening the module-level paths import; polars is
    # already imported at module scope (the original re-imported it here).
    from project_code.src.paths import processed_data_out_path

    errors = []

    def _schema_has_reference(schema: pl.Schema, column_reference: str) -> bool:
        # A reference is either a literal column name or (when it starts with
        # "^") a regex selector matched against every column in the schema.
        if column_reference in schema:
            return True
        if column_reference.startswith("^"):
            pattern = re.compile(column_reference)
            return any(pattern.fullmatch(col) for col in schema)
        return False

    for i, run in enumerate(runs):
        label = (
            f"Run #{i + 1} "
            f"({run.program_name}, {run.covariate_set_pointer}, {run.filter_set_pointer}, "
            f"{run.outcomes_model_pointer}, {run.propensity_model_pointer}, "
            f"{run.num_controls_per_treat} controls/treat)"
        )

        # Check processed data exists and required columns are present.
        # NOTE(review): cohort 1945 is used as a schema probe — confirm this
        # cohort file always exists for every program.
        data_path = processed_data_out_path(run.program_name, 1945)
        if not data_path.exists():
            errors.append(f"{label}: no processed data found for program '{run.program_name}'")
        else:
            schema = pl.scan_parquet(data_path).collect_schema()

            for var in [TREATMENT_COL, run.outcome] + run.covariate_set:
                if var not in schema:
                    errors.append(f"{label}: variable '{var}' not in schema")

            # Check column references in filter expressions
            for filter_entry in run.filter_set:
                for col in filter_entry.meta.root_names():
                    if not _schema_has_reference(schema, col):
                        errors.append(f"{label}: filter references column '{col}' not in schema")

    if errors:
        raise ValueError("Run validation failed:\n\n" + "\n".join(errors))
113
+
114
+
115
def unpack_runs(experiment: Experiment) -> list[Run]:
    """Resolve registry pointers and expand an experiment into concrete runs."""

    # Registries live in the external config package; imported lazily here
    # (sys.path is prepared at module import above).
    from config.registries.programs import PROGRAM_REGISTRY
    from config.registries.covariate_sets import COVARIATE_SET_REGISTRY
    from config.registries.filter_sets import FILTER_SET_REGISTRY
    from config.registries.models import OUTCOMES_MODEL_REGISTRY, PROPENSITY_MODEL_REGISTRY

    runs = []

    # Cartesian product of all pointer lists: one Run per combination.
    combos = product(
        experiment.program_pointer,
        experiment.outcomes,
        experiment.covariate_set_pointer,
        experiment.filter_set_pointer,
        experiment.outcomes_model_pointer,
        experiment.propensity_model_pointer,
        experiment.num_controls_per_treat,
    )

    for (
        program_pointer,
        out,
        cov_pointer,
        filt_pointer,
        outcomes_model_pointer,
        propensity_model_pointer,
        num_controls_per_treat,
    ) in combos:
        runs.append(
            Run(
                program_name=PROGRAM_REGISTRY[program_pointer].name,
                join_key=JOIN_KEY,
                treatment=TREATMENT_COL,
                outcome=out,
                covariate_set=COVARIATE_SET_REGISTRY[cov_pointer],
                filter_set=FILTER_SET_REGISTRY[filt_pointer],
                covariate_set_pointer=cov_pointer,
                filter_set_pointer=filt_pointer,
                outcomes_model_pointer=outcomes_model_pointer,
                propensity_model_pointer=propensity_model_pointer,
                num_controls_per_treat=num_controls_per_treat,
                # clone() gives each run its own unfitted copy of the
                # registry's estimator.
                outcomes_model=clone(OUTCOMES_MODEL_REGISTRY[outcomes_model_pointer]),
                propensity_model=clone(PROPENSITY_MODEL_REGISTRY[propensity_model_pointer]),
            )
        )

    return runs
163
+
164
+
165
def sample_controls_per_treated(lf: pl.LazyFrame, treatment_col: str, join_key: str, num_controls_per_treat: float) -> pl.LazyFrame:
    """Keep all treated rows plus a deterministic hash sample of controls."""

    is_treated = pl.col(treatment_col) == 1
    counts = lf.select(is_treated.sum(), (~is_treated).sum()).collect()
    n_treated, n_controls = counts.row(0)

    # Nothing to thin if there are no controls at all.
    if n_controls == 0:
        return lf

    # Hash-based thinning keeps the same control units across reruns.
    keep_prob = min(1.0, num_controls_per_treat * n_treated / n_controls)
    threshold = keep_prob * (2**64 - 1)
    return lf.filter(is_treated | (pl.col(join_key).hash() < threshold))
179
+
180
+
181
def prepare_estimation_data(run: Run):
    """Load, filter, sample, encode, and summarize modeling data for one run.

    Returns:
        Tuple of (pandas DataFrame for modeling, covariate names after
        one-hot encoding, summary dict of sample counts).
    """

    from config.registries.covariate_sets import CATEGORICAL_COVARIATES

    # All cohort files for this program, stacked lazily.
    folder_path = processed_data_out_folder(run.program_name)
    paths = sorted(folder_path.glob("*.parquet"))
    lf = pl.concat([pl.scan_parquet(path) for path in paths], how="vertical_relaxed")
    columns = [run.treatment, run.outcome, run.join_key] + run.covariate_set

    # NaN -> null first so drop_nulls removes both missing kinds in floats.
    lf = (
        lf.with_columns(pl.col(pl.Float32, pl.Float64).fill_nan(None))
        .drop_nulls(subset=columns)
    )

    # Configured filter expressions, applied in registry order.
    for filter_entry in run.filter_set:
        lf = lf.filter(filter_entry)

    lf = sample_controls_per_treated(
        lf=lf,
        treatment_col=run.treatment,
        join_key=run.join_key,
        num_controls_per_treat=run.num_controls_per_treat,
    )

    print("\n Starting collect...\n")
    lf = lf.select(columns).collect()

    print("\n Starting hot encoding...\n")
    cols_to_encode = [col for col in CATEGORICAL_COVARIATES if col in lf.columns]
    lf = lf.to_dummies(columns=cols_to_encode, drop_first=True)

    # Covariates are everything except the id, treatment, and outcome columns.
    x_cols = [x for x in lf.columns if x not in [run.join_key, run.treatment, run.outcome]]
    n_rows = lf.select(pl.len()).item()
    n_treated = lf.filter(pl.col(run.treatment) == 1).select(pl.len()).item()
    summary = {
        "n_rows": n_rows,
        "n_treated": n_treated,
        "n_controls": n_rows - n_treated,
        # Controls can repeat across enrollment-year panels; count distinct ids.
        "n_unique_controls": lf.filter(pl.col(run.treatment) == 0)
        .select(run.join_key)
        .unique()
        .select(pl.len())
        .item(),
        "n_covariates": len(x_cols),
        # NOTE(review): hard-coded 0 — rows removed by drop_nulls above are
        # not counted; confirm whether this field should track them.
        "n_null_rows_dropped": 0,
    }

    # Hand off to pandas for the DoubleML interface; free the polars frame.
    df = lf.to_pandas()
    del lf
    trim_memory()

    return df, x_cols, summary
234
+
235
+
236
def fit_doubleml_irm(df, run: Run, covariate_set_after_dummies: list[str]):
    """Fit DoubleML IRM and return the fitted DoubleML object.

    Uses the ATTE score with 5 folds and 3 repetitions; thread usage is
    capped via threadpoolctl and joblib so nested parallelism stays bounded.
    """

    print("\n Starting DML prep...\n")
    dml_data = dml.DoubleMLData(
        data=df,
        y_col=run.outcome,
        d_cols=[run.treatment],
        x_cols=covariate_set_after_dummies,
    )

    # Default-constructed propensity-score processor — presumably default
    # trimming/clipping behavior; confirm against the doubleml docs.
    ps_config = PSProcessorConfig()
    dml_obj = dml.DoubleMLIRM(
        obj_dml_data=dml_data,
        ml_g=run.outcomes_model,
        ml_m=run.propensity_model,
        score="ATTE",
        n_folds=5,
        n_rep=3,
        ps_processor_config=ps_config,
    )

    print("\n Starting DML fit...\n")
    # Cap native (OpenMP) threads per learner and joblib inner threads so
    # MODEL_THREADS * N_JOBS_CV does not oversubscribe the machine.
    with threadpool_limits(limits=MODEL_THREADS, user_api="openmp"):
        with parallel_backend("loky", inner_max_num_threads=MODEL_THREADS):
            dml_obj.fit(n_jobs_cv=N_JOBS_CV)

    return dml_obj
@@ -0,0 +1,35 @@
1
+ from pathlib import Path
2
+ import os
3
+
4
+
5
# Repository root; every path below derives from it. All directories are
# overridable via DML_PIPELINE_* environment variables for deployment.
ROOT_DIR = Path(os.environ.get("DML_PIPELINE_ROOT_DIR", Path.cwd()))

# === Data directories ===

DATA_DIR = Path(os.environ.get("DML_PIPELINE_DATA_DIR", ROOT_DIR / "data"))

# === Local working directory ===
LOCAL_DIR = Path(os.environ.get("DML_PIPELINE_LOCAL_DIR", ROOT_DIR))

# sub-directories
SRC_DIR = LOCAL_DIR / "src"
OUT_DIR = LOCAL_DIR / "outputs"
# Location of the user-supplied `config` package (build spec, registries,
# experiment YAMLs).
CONFIG_DIR = Path(os.environ.get("DML_PIPELINE_CONFIG_DIR", LOCAL_DIR / "config"))
18
+
19
def processed_data_out_folder(program: str) -> Path:
    """Return the folder holding processed build output for one program."""

    # `program` is already a string; the original wrapped it in a pointless
    # f-string and an intermediate variable.
    return LOCAL_DIR / "data" / "build" / program
22
+
23
def processed_data_out_path(program: str, cohort: int) -> Path:
    """Return one cohort's output parquet path, creating its folder."""

    out_file = (
        processed_data_out_folder(program)
        / f"{program}_panel_cohort={cohort}.parquet"
    )
    out_file.parent.mkdir(parents=True, exist_ok=True)
    return out_file
31
+
32
def get_log_out_path(result_type: str, experiment_name: str) -> Path:
    """Return one experiment's log parquet path, creating its folder."""

    log_path = OUT_DIR / "raw" / f"{experiment_name}" / f"log_{result_type}.parquet"
    log_path.parent.mkdir(parents=True, exist_ok=True)
    return log_path
@@ -0,0 +1,253 @@
1
+ import polars as pl
2
+
3
+ import plotnine as pn
4
+ from oi_tools.figures import (
5
+ OIColors,
6
+ save_figure,
7
+ scale_color_oi,
8
+ scale_fill_oi,
9
+ theme_oi,
10
+ )
11
+
12
+ from project_code.src.build_helpers import TREATMENT_COL
13
+ from project_code.src.paths import OUT_DIR, get_log_out_path
14
+
15
+
16
# Output folders for each diagnostic-plot family, all rooted under OUT_DIR.
PROPENSITY_PLOT_FOLDER = OUT_DIR / "plots" / "density_functions"
WEIGHTED_RESIDUAL_PLOT_FOLDER = OUT_DIR / "plots" / "weighted_residual_influence"
CALIBRATION_PLOT_FOLDER = OUT_DIR / "plots" / "calibration"
19
+
20
+
21
def plot_propensity_density(experiment_name: str, run_number: int):
    """Plot treated/control propensity score densities for one run."""
    predictions = pl.read_parquet(get_log_out_path("predictions", experiment_name))
    run_rows = predictions.filter(pl.col("run_number") == run_number)

    # Label rows by treatment arm for the legend.
    status_label = (
        pl.when(pl.col(TREATMENT_COL) == 1)
        .then(pl.lit("Treated"))
        .otherwise(pl.lit("Control"))
        .alias("treatment_status")
    )
    plot_df = run_rows.select(
        pl.col("propensity_predictions"),
        status_label,
    ).to_pandas()

    aesthetics = pn.aes(
        x="propensity_predictions",
        color="treatment_status",
        fill="treatment_status",
    )
    fig = (
        pn.ggplot(plot_df, aesthetics)
        + pn.geom_density(alpha=0.25)
        + pn.scale_x_continuous(limits=(0, 1))
        + scale_color_oi(name="")
        + scale_fill_oi(name="")
        + pn.labs(
            x="Estimated propensity score",
            y="Density",
            title="Propensity Score Density",
        )
        + theme_oi()
    )

    PROPENSITY_PLOT_FOLDER.mkdir(parents=True, exist_ok=True)
    save_figure(fig, PROPENSITY_PLOT_FOLDER / f"{experiment_name}_{run_number}")
    return fig
63
+
64
+
65
def _prediction_diagnostics(df: pl.DataFrame, outcome_col: str) -> pl.DataFrame:
    """Add ATT control weights and residual diagnostics to prediction rows."""
    propensity = pl.col("propensity_predictions")
    # Odds p/(1-p) are only defined on the open interval (0, 1); null elsewhere.
    att_weight = (
        pl.when(propensity.is_between(0, 1, closed="none"))
        .then(propensity / (1 - propensity))
        .otherwise(None)
        .alias("att_control_weight")
    )
    residual = (
        pl.col(outcome_col) - pl.col("outcomes_predictions")
    ).alias("outcome_residual")

    weighted = pl.col("att_control_weight") * pl.col("outcome_residual")
    return df.with_columns(att_weight, residual).with_columns(
        weighted.alias("weighted_residual_contribution"),
        weighted.abs().alias("abs_weighted_residual_contribution"),
    )
85
+
86
+
87
def _weighted_mean(df: pl.DataFrame, value_col: str, weight_col: str) -> float | None:
    """Return weighted mean, or None when all usable weight is zero/null."""
    sums = df.select(
        (pl.col(value_col) * pl.col(weight_col)).sum().alias("weighted_sum"),
        pl.col(weight_col).sum().alias("weight_sum"),
    )
    weighted_sum, weight_sum = sums.row(0)
    # A null or zero total weight means there is nothing meaningful to average.
    if weight_sum is None or weight_sum == 0:
        return None
    return weighted_sum / weight_sum
97
+
98
+
99
def _control_ess(controls: pl.DataFrame) -> float | None:
    """Effective sample size of ATT-weighted controls."""
    weight = pl.col("att_control_weight")
    total, total_sq = controls.select(
        weight.sum().alias("sum_w"),
        (weight**2).sum().alias("sum_w_sq"),
    ).row(0)
    if total_sq is None or total_sq == 0:
        return None
    # Kish effective sample size: (sum w)^2 / sum(w^2).
    return (total**2) / total_sq
109
+
110
+
111
def create_run_summary(experiment_name: str, run_number: int) -> pl.DataFrame:
    """Create a one-row table of ATT estimates, losses, samples, and diagnostics.

    Combines the estimation log (DML ATT/SE, nuisance losses, sample counts)
    with per-row prediction diagnostics: control effective sample size plus
    two "one-legged" ATT estimates (outcome-model only, IPW only) that help
    attribute how much of the DML estimate comes from the residual correction.

    Raises:
        ValueError: if the estimation log does not contain exactly one row
            for this run.
    """
    estimation_log = (
        pl.read_parquet(get_log_out_path("estimation", experiment_name))
        .filter(pl.col("run_number") == run_number)
    )
    if estimation_log.height != 1:
        raise ValueError(f"Expected one estimation log row, found {estimation_log.height}")

    predictions_log = (
        pl.read_parquet(get_log_out_path("predictions", experiment_name))
        .filter(pl.col("run_number") == run_number)
    )

    outcome_col = estimation_log["outcome"].item()
    diagnostics = _prediction_diagnostics(predictions_log, outcome_col)
    treated = diagnostics.filter(pl.col(TREATMENT_COL) == 1)
    controls = diagnostics.filter(pl.col(TREATMENT_COL) == 0)

    # Polars aggregations over an empty frame yield None; propagate None into
    # the derived columns instead of raising TypeError when a run happens to
    # have no treated rows (matches how _weighted_mean already reports None).
    mean_treated_outcome = treated.select(pl.col(outcome_col).mean()).item()
    weighted_control_outcome = _weighted_mean(controls, outcome_col, "att_control_weight")
    att_outcome_only = treated.select(
        (pl.col(outcome_col) - pl.col("outcomes_predictions")).mean()
    ).item()
    att_ipw_only = (
        None
        if mean_treated_outcome is None or weighted_control_outcome is None
        else mean_treated_outcome - weighted_control_outcome
    )
    att_dml = float(estimation_log["dml_estimate"].item())

    return pl.DataFrame({
        "att_dml": [att_dml],
        "se_dml": [float(estimation_log["dml_se"].item())],
        "outcome_model_loss": [float(estimation_log["dml_outcomes_loss"].item())],
        "propensity_model_loss": [float(estimation_log["dml_prop_loss"].item())],
        "n_treated": [int(estimation_log["n_treated"].item())],
        "n_controls": [int(estimation_log["n_controls"].item())],
        "n_total": [int(estimation_log["n_rows"].item())],
        "att_control_ess": [_control_ess(controls)],
        "att_outcome_only": [att_outcome_only],
        "att_ipw_only": [att_ipw_only],
        "att_residual_correction": [
            None if att_outcome_only is None else att_dml - att_outcome_only
        ],
    })
152
+
153
+
154
def plot_weighted_residual_influence(experiment_name: str, run_number: int):
    """Plot ATT-weighted control residuals for one run."""
    estimation_log = (
        pl.read_parquet(get_log_out_path("estimation", experiment_name))
        .filter(pl.col("run_number") == run_number)
    )
    if estimation_log.height != 1:
        raise ValueError(f"Expected one estimation log row, found {estimation_log.height}")

    outcome_col = estimation_log["outcome"].item()
    predictions_log = (
        pl.read_parquet(get_log_out_path("predictions", experiment_name))
        .filter(pl.col("run_number") == run_number)
    )
    # Drop rows where any of the plotted diagnostics is null (e.g. degenerate
    # propensity scores that produced no ATT weight).
    required_cols = [
        "att_control_weight",
        "outcome_residual",
        "abs_weighted_residual_contribution",
    ]
    controls = (
        _prediction_diagnostics(predictions_log, outcome_col)
        .filter(pl.col(TREATMENT_COL) == 0)
        .drop_nulls(subset=required_cols)
    )

    fig = (
        pn.ggplot(
            controls.to_pandas(),
            pn.aes(
                x="att_control_weight",
                y="outcome_residual",
                color="abs_weighted_residual_contribution",
            ),
        )
        + pn.geom_point(alpha=0.35)
        + pn.labs(
            x="ATT control weight",
            y="Outcome residual",
            color="Abs. weighted residual",
            title="Weighted Control Residual Influence",
        )
        + pn.scale_color_gradient(low=OIColors.BLUE, high=OIColors.RED)
        + theme_oi()
    )

    WEIGHTED_RESIDUAL_PLOT_FOLDER.mkdir(parents=True, exist_ok=True)
    save_figure(fig, WEIGHTED_RESIDUAL_PLOT_FOLDER / f"{experiment_name}_{run_number}")
    return fig
204
+
205
+
206
def calibration_plot(
    experiment_name: str,
    run_number: int,
    prediction_type: str,
    x_limits: tuple[float, float] | None = None,
):
    """Plot true values against nuisance predictions for controls in one run."""
    if prediction_type not in ("propensity", "outcomes"):
        raise ValueError("prediction_type must be either 'propensity' or 'outcomes'")

    pred_col = f"{prediction_type}_predictions"
    truth_col = f"true_{prediction_type}"

    # Restrict to control rows of the requested run.
    run_controls = (pl.col("run_number") == run_number) & (pl.col(TREATMENT_COL) == 0)
    plot_df = (
        pl.read_parquet(get_log_out_path("predictions", experiment_name))
        .filter(run_controls)
        .select(
            pl.col(pred_col).alias("prediction"),
            pl.col(truth_col).alias("actual"),
        )
        .drop_nulls()
        .to_pandas()
    )

    fig = (
        pn.ggplot(plot_df, pn.aes(x="prediction", y="actual"))
        + pn.geom_point(alpha=0.18, color=OIColors.BLUE)
        + pn.geom_smooth(se=True, color=OIColors.RED)
        + pn.labs(
            x=f"{prediction_type.title()} prediction",
            y=f"True {prediction_type}",
            title=f"{prediction_type.title()} Calibration",
        )
        + theme_oi()
    )
    if x_limits is not None:
        fig = fig + pn.scale_x_continuous(limits=x_limits)

    CALIBRATION_PLOT_FOLDER.mkdir(parents=True, exist_ok=True)
    save_figure(
        fig,
        CALIBRATION_PLOT_FOLDER / f"{experiment_name}_{run_number}_{prediction_type}",
    )
    return fig
@@ -0,0 +1,95 @@
1
+ import ctypes
2
+ import gc
3
+ import os
4
+ import sys
5
+ import threading
6
+
7
+ import polars as pl
8
+
9
+ from project_code.src.paths import OUT_DIR, get_log_out_path
10
+
11
+
12
def log_results(result_type: str,
                df: pl.DataFrame,
                experiment_name: str,
                run_number: int) -> None:
    """Persist one run's results to the experiment's parquet log.

    Run 1 starts a fresh log (deliberately overwriting any stale file left by
    a previous execution of the same experiment); later runs append by
    concatenating with the rows already on disk and rewriting the file.

    Raises:
        FileNotFoundError: if run_number > 1 but run 1 never wrote the log —
            runs must be executed in order, so this surfaces an ordering bug.
    """
    path = get_log_out_path(result_type, experiment_name)

    if run_number == 1:
        df.write_parquet(path)
    else:
        existing = pl.read_parquet(path)
        pl.concat([existing, df]).write_parquet(path)
27
+
28
+
29
def time_elapsed(start: float, end: float) -> str:
    """Format elapsed wall-clock time for logs.

    Rounds to whole seconds *before* splitting into minutes; the original
    formatted `elapsed % 60` with ``:.0f``, which could display "60 sec"
    (e.g. 119.6 s -> "1 min, 60 sec" instead of "2 min, 0 sec").
    """
    minutes, seconds = divmod(round(end - start), 60)
    return f"{minutes} min, {seconds:.0f} sec"
36
+
37
+
38
def trim_memory() -> None:
    """Ask Python and, on Linux, libc to release unused memory."""
    gc.collect()
    if not sys.platform.startswith("linux"):
        return
    # glibc-specific; guard the symbol lookup for other libc implementations.
    libc = ctypes.CDLL("libc.so.6")
    if hasattr(libc, "malloc_trim"):
        libc.malloc_trim(0)
46
+
47
+
48
def log_process_resources(interval: float = 30) -> threading.Event:
    """Start a daemon thread that prints CPU/memory/thread stats every *interval*
    seconds; set the returned event to stop it."""
    import psutil

    stop_event = threading.Event()
    proc = psutil.Process(os.getpid())

    def _monitor():
        # Prime cpu_percent so the first reading covers a real interval
        # rather than returning a meaningless 0.0.
        proc.cpu_percent(interval=None)
        # Event.wait returns True once the event is set — that ends the loop.
        while not stop_event.wait(interval):
            mem_gb = proc.memory_info().rss / 1e9
            print(
                f"CPU: {proc.cpu_percent(interval=None):.1f}% | "
                f"Memory: {mem_gb:.2f} GB | "
                f"Threads: {proc.num_threads()}",
                flush=True,
            )

    threading.Thread(target=_monitor, daemon=True).start()
    return stop_event
76
+
77
+
78
def pl_to_csv(df: pl.DataFrame, name: str) -> None:
    """Save a transposed Polars table with variable names on the left."""
    table_dir = OUT_DIR / "tables"
    table_dir.mkdir(parents=True, exist_ok=True)
    if not name.endswith(".csv"):
        name = f"{name}.csv"

    # One row per original column, observed values spread across columns —
    # this shape is easier to drop into LaTeX-style summary tables.
    value_labels = [f"value_{i}" for i in range(1, df.height + 1)]
    df.transpose(
        include_header=True,
        header_name="variable",
        column_names=value_labels,
    ).write_csv(table_dir / name)
94
+
95
+