openai-gabriel 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gabriel/__init__.py +61 -0
- gabriel/_version.py +1 -0
- gabriel/api.py +2284 -0
- gabriel/cli/__main__.py +60 -0
- gabriel/core/__init__.py +7 -0
- gabriel/core/llm_client.py +34 -0
- gabriel/core/pipeline.py +18 -0
- gabriel/core/prompt_template.py +152 -0
- gabriel/prompts/__init__.py +1 -0
- gabriel/prompts/bucket_prompt.jinja2 +113 -0
- gabriel/prompts/classification_prompt.jinja2 +50 -0
- gabriel/prompts/codify_prompt.jinja2 +95 -0
- gabriel/prompts/comparison_prompt.jinja2 +60 -0
- gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
- gabriel/prompts/deidentification_prompt.jinja2 +112 -0
- gabriel/prompts/extraction_prompt.jinja2 +61 -0
- gabriel/prompts/filter_prompt.jinja2 +31 -0
- gabriel/prompts/ideation_prompt.jinja2 +80 -0
- gabriel/prompts/merge_prompt.jinja2 +47 -0
- gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
- gabriel/prompts/rankings_prompt.jinja2 +49 -0
- gabriel/prompts/ratings_prompt.jinja2 +50 -0
- gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
- gabriel/prompts/seed.jinja2 +43 -0
- gabriel/prompts/snippets.jinja2 +117 -0
- gabriel/tasks/__init__.py +63 -0
- gabriel/tasks/_attribute_utils.py +69 -0
- gabriel/tasks/bucket.py +432 -0
- gabriel/tasks/classify.py +562 -0
- gabriel/tasks/codify.py +1033 -0
- gabriel/tasks/compare.py +235 -0
- gabriel/tasks/debias.py +1460 -0
- gabriel/tasks/deduplicate.py +341 -0
- gabriel/tasks/deidentify.py +316 -0
- gabriel/tasks/discover.py +524 -0
- gabriel/tasks/extract.py +455 -0
- gabriel/tasks/filter.py +169 -0
- gabriel/tasks/ideate.py +782 -0
- gabriel/tasks/merge.py +464 -0
- gabriel/tasks/paraphrase.py +531 -0
- gabriel/tasks/rank.py +2041 -0
- gabriel/tasks/rate.py +347 -0
- gabriel/tasks/seed.py +465 -0
- gabriel/tasks/whatever.py +344 -0
- gabriel/utils/__init__.py +64 -0
- gabriel/utils/audio_utils.py +42 -0
- gabriel/utils/file_utils.py +464 -0
- gabriel/utils/image_utils.py +22 -0
- gabriel/utils/jinja.py +31 -0
- gabriel/utils/logging.py +86 -0
- gabriel/utils/mapmaker.py +304 -0
- gabriel/utils/media_utils.py +78 -0
- gabriel/utils/modality_utils.py +148 -0
- gabriel/utils/openai_utils.py +5470 -0
- gabriel/utils/parsing.py +282 -0
- gabriel/utils/passage_viewer.py +2557 -0
- gabriel/utils/pdf_utils.py +20 -0
- gabriel/utils/plot_utils.py +2881 -0
- gabriel/utils/prompt_utils.py +42 -0
- gabriel/utils/word_matching.py +158 -0
- openai_gabriel-1.0.1.dist-info/METADATA +443 -0
- openai_gabriel-1.0.1.dist-info/RECORD +67 -0
- openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
- openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
- openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
- openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
- openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
gabriel/tasks/debias.py
ADDED
|
@@ -0,0 +1,1460 @@
|
|
|
1
|
+
"""Econometric debiasing utilities for GABRIEL."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
import random
|
|
8
|
+
import re
|
|
9
|
+
import warnings
|
|
10
|
+
from dataclasses import asdict, dataclass, field
|
|
11
|
+
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
|
|
12
|
+
|
|
13
|
+
import matplotlib.pyplot as plt
|
|
14
|
+
import numpy as np
|
|
15
|
+
import pandas as pd
|
|
16
|
+
|
|
17
|
+
try: # pragma: no cover - tqdm is optional
|
|
18
|
+
from tqdm.auto import tqdm
|
|
19
|
+
except Exception: # pragma: no cover
|
|
20
|
+
def tqdm(iterable: Iterable, **_: Any) -> Iterable:
|
|
21
|
+
return iterable
|
|
22
|
+
|
|
23
|
+
from .classify import Classify, ClassifyConfig
|
|
24
|
+
from .codify import Codify, CodifyConfig
|
|
25
|
+
from .extract import Extract, ExtractConfig
|
|
26
|
+
from .paraphrase import Paraphrase, ParaphraseConfig
|
|
27
|
+
from .rank import Rank, RankConfig
|
|
28
|
+
from .rate import Rate, RateConfig
|
|
29
|
+
try:  # statsmodels is optional; fall back to a lightweight solver if missing
    from ..utils.plot_utils import fit_ols as _fit_ols
except Exception:  # pragma: no cover - fallback exercised when statsmodels absent

    def fit_ols(
        y: np.ndarray,
        X: np.ndarray,
        *,
        robust: bool = True,  # noqa: ARG001 - signature parity with primary implementation
        varnames: Optional[List[str]] = None,  # noqa: ARG001
    ) -> Dict[str, Any]:
        """Minimal OLS routine used when :mod:`statsmodels` is unavailable.

        The implementation mirrors the API of
        :func:`gabriel.utils.plot_utils.fit_ols` closely enough for the
        debiasing pipeline, returning coefficient estimates, approximate
        standard errors, and residuals. ``robust`` and ``varnames`` are
        accepted for signature compatibility but otherwise ignored.
        """

        y_arr = np.asarray(y, dtype=float)
        X_arr = np.asarray(X, dtype=float)
        if y_arr.ndim != 1:
            y_arr = y_arr.reshape(-1)
        if X_arr.ndim != 2:
            raise ValueError("Design matrix must be two-dimensional")

        # lstsq tolerates rank-deficient designs (minimum-norm solution).
        beta, _, _, _ = np.linalg.lstsq(X_arr, y_arr, rcond=None)
        resid = y_arr - X_arr @ beta
        n, k_plus1 = X_arr.shape
        df_resid = n - k_plus1

        XtX = X_arr.T @ X_arr
        try:
            XtX_inv = np.linalg.inv(XtX)
        except np.linalg.LinAlgError:  # pseudo-inverse when design is singular
            XtX_inv = np.linalg.pinv(XtX)

        if df_resid > 0:
            # Classical homoskedastic covariance estimate.
            sigma2 = float(resid @ resid) / df_resid
            cov = sigma2 * XtX_inv
            se = np.sqrt(np.diag(cov))
            rse = float(np.sqrt(sigma2))
        else:
            # No residual degrees of freedom: standard errors undefined.
            sigma2 = 0.0
            se = np.full(beta.shape, np.nan)
            rse = np.nan

        with np.errstate(divide="ignore", invalid="ignore"):
            t_vals = beta / se
        # Exact p-values would need a t distribution (scipy/statsmodels);
        # reported as NaN by this lightweight fallback.
        p_vals = np.full(beta.shape, np.nan)

        mean_y = float(y_arr.mean()) if n else 0.0
        ss_tot = float(np.sum((y_arr - mean_y) ** 2))
        ss_res = float(np.sum(resid ** 2))
        r2 = 1.0 - ss_res / ss_tot if ss_tot else 0.0
        adj_r2 = (
            1.0 - (1.0 - r2) * (n - 1) / df_resid if df_resid > 0 else np.nan
        )

        return {
            "coef": np.asarray(beta),
            "se": np.asarray(se),
            "t": np.asarray(t_vals),
            "p": np.asarray(p_vals),
            "r2": float(r2),
            "adj_r2": float(adj_r2),
            "n": int(n),
            "k": int(k_plus1 - 1),
            "rse": rse,
            "F": np.nan,
            "resid": np.asarray(resid),
            "varnames": varnames,
            "sm_results": None,
        }
else:  # pragma: no cover - executed when statsmodels dependency is available
    fit_ols = _fit_ols
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _fit_ols_fallback(
    y: np.ndarray,
    X: np.ndarray,
    *,
    robust: bool = True,  # noqa: ARG001 - signature parity
    varnames: Optional[List[str]] = None,  # noqa: ARG001
) -> Dict[str, Any]:
    """Minimal OLS routine used when statsmodels-backed fitting is unavailable."""

    outcome = np.asarray(y, dtype=float)
    design = np.asarray(X, dtype=float)
    if outcome.ndim != 1:
        outcome = outcome.reshape(-1)
    if design.ndim != 2:
        raise ValueError("Design matrix must be two-dimensional")

    # Minimum-norm least squares; handles rank-deficient designs.
    coef = np.linalg.lstsq(design, outcome, rcond=None)[0]
    residuals = outcome - design @ coef
    n_obs, n_params = design.shape
    dof = n_obs - n_params

    gram = design.T @ design
    try:
        gram_inv = np.linalg.inv(gram)
    except np.linalg.LinAlgError:
        # Singular design: fall back to the Moore-Penrose pseudo-inverse.
        gram_inv = np.linalg.pinv(gram)

    if dof > 0:
        # Classical homoskedastic covariance estimate.
        sigma_sq = float(residuals @ residuals) / dof
        std_err = np.sqrt(np.diag(sigma_sq * gram_inv))
        resid_se = float(np.sqrt(sigma_sq))
    else:
        # Saturated model: standard errors are undefined.
        std_err = np.full(coef.shape, np.nan)
        resid_se = np.nan

    with np.errstate(divide="ignore", invalid="ignore"):
        t_stats = coef / std_err
    # p-values require a t distribution; reported as NaN by this fallback.
    p_stats = np.full(coef.shape, np.nan)

    baseline = float(outcome.mean()) if n_obs else 0.0
    centered = outcome - baseline
    total_ss = float(centered @ centered)
    resid_ss = float(residuals @ residuals)
    r_sq = 1.0 - resid_ss / total_ss if total_ss else 0.0
    adj_r_sq = 1.0 - (1.0 - r_sq) * (n_obs - 1) / dof if dof > 0 else np.nan

    return {
        "coef": np.asarray(coef),
        "se": np.asarray(std_err),
        "t": np.asarray(t_stats),
        "p": np.asarray(p_stats),
        "r2": float(r_sq),
        "adj_r2": float(adj_r_sq),
        "n": int(n_obs),
        "k": int(n_params - 1),
        "rse": resid_se,
        "F": np.nan,
        "resid": np.asarray(residuals),
        "varnames": varnames,
        "sm_results": None,
    }
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _safe_fit_ols(
    y: np.ndarray,
    X: np.ndarray,
    *,
    robust: bool,
    varnames: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Attempt statsmodels-backed fitting, falling back to NumPy OLS when needed."""

    try:
        # fit_ols may be the statsmodels-backed implementation, which can
        # raise ImportError lazily at call time when statsmodels is absent.
        return fit_ols(y, X, robust=robust, varnames=varnames)
    except ImportError:
        if robust:
            warnings.warn(
                "statsmodels is unavailable; falling back to non-robust OLS for debiasing diagnostics.",
                RuntimeWarning,
                stacklevel=2,
            )
        return _fit_ols_fallback(y, X, robust=False, varnames=varnames)
191
|
+
|
|
192
|
+
# Default root directory where run artefacts (CSV results, metadata JSON) land.
DEFAULT_SAVE_DIR = os.path.expanduser("~/Documents/runs")
# Strategy for stripping the target signal from the text.
RemovalMethod = Literal["codify", "paraphrase"]
# Measurement task used on both raw and stripped text.
MeasurementMode = Literal["rate", "classify", "extract", "rank"]
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
@dataclass
class DebiasRegressionResult:
    """Container for regression diagnostics of a debiasing run."""

    variant: str
    display_name: str
    strip_percentage: Optional[int]
    correlation: Optional[float]
    mean_original: Optional[float]
    mean_stripped: Optional[float]
    attenuation_pct: Optional[float] = None
    diff_regression: Optional[Dict[str, Any]] = None
    twostep_regression: Optional[Dict[str, Any]] = None
    stage1_regression: Optional[Dict[str, Any]] = None
    stage1_delta: Optional[float] = None
    debiased_columns: Dict[str, str] = field(default_factory=dict)

    def as_dict(self) -> Dict[str, Any]:
        """Return a JSON serialisable representation of the result."""

        def _jsonify(value: Any) -> Any:
            # Recursively unwrap containers and turn NumPy scalars into
            # native Python types so json.dump can serialise the payload.
            if isinstance(value, dict):
                return {key: _jsonify(item) for key, item in value.items()}
            if isinstance(value, (list, tuple)):
                return [_jsonify(item) for item in value]
            if isinstance(value, np.generic):
                return value.item()
            return value

        # Plain identification fields first, then the numeric/nested payloads
        # run through the converter; insertion order matches the dataclass.
        serialised = {
            "variant": self.variant,
            "display_name": self.display_name,
            "strip_percentage": self.strip_percentage,
        }
        for name in (
            "correlation",
            "mean_original",
            "mean_stripped",
            "attenuation_pct",
            "diff_regression",
            "twostep_regression",
            "stage1_regression",
            "stage1_delta",
            "debiased_columns",
        ):
            serialised[name] = _jsonify(getattr(self, name))
        return serialised
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
@dataclass
class DebiasConfig:
    """Configuration for :class:`DebiasPipeline`."""

    # Measurement task applied to raw and stripped text.
    mode: MeasurementMode = "rate"
    # Key in ``attributes`` whose measurements are debiased; defaults to the
    # first attribute key when left as None (see _validate_config).
    measurement_attribute: Optional[str] = None
    # Key in ``signal_dictionary`` naming the signal to strip; auto-filled
    # when None (see _validate_config).
    removal_attribute: Optional[str] = None
    # Attribute name -> description measured by the selected mode.
    attributes: Dict[str, str] = field(default_factory=dict)
    # Signal name -> description of the confound to remove.
    signal_dictionary: Dict[str, str] = field(default_factory=dict)
    # How the signal is removed from the text: "codify" or "paraphrase".
    removal_method: RemovalMethod = "codify"
    # Whether to rate leftover signal prevalence on the stripped text.
    remaining_signal: bool = True
    # Attribute name for the remaining-signal rating (auto-filled if blank).
    remaining_signal_attribute: Optional[str] = None
    # Description for the remaining-signal rating; required when the
    # attribute above is set explicitly.
    remaining_signal_description: Optional[str] = None
    # Root directory for run artefacts.
    save_dir: str = DEFAULT_SAVE_DIR
    # Subdirectory name for this run (derived from attributes when None).
    run_name: Optional[str] = None
    # Percentages of the signal to strip; normalised to [100] when None.
    strip_percentages: Optional[List[int]] = None
    # signal_dictionary keys to strip when using codify; defaults to the
    # removal attribute (see _validate_config).
    categories_to_strip: Optional[List[str]] = None
    # Optional custom prompt template path forwarded to the measurement tasks.
    template_path: Optional[str] = None
    # Model name forwarded to the underlying task configs.
    model: str = "gpt-5-mini"
    # Concurrency setting forwarded to the underlying task configs.
    n_parallels: int = 650
    # Extra keyword arguments forwarded to the measurement task.
    measurement_kwargs: Dict[str, Any] = field(default_factory=dict)
    # Extra keyword arguments forwarded to the removal task.
    removal_kwargs: Dict[str, Any] = field(default_factory=dict)
    # Use offline dummy responses instead of real API calls.
    use_dummy: bool = False
    # Request robust standard errors in regression diagnostics.
    robust_regression: bool = True
    # Seed for randomised steps (usage not shown in this excerpt).
    random_seed: int = 12345
    # Print progress information to stdout.
    verbose: bool = True
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
@dataclass
class DebiasResult:
    """Return object for :func:`gabriel.debias`."""

    results: pd.DataFrame
    metadata: Dict[str, Any]
    regression: Dict[str, DebiasRegressionResult]

    def to_dict(self) -> Dict[str, Any]:
        """Serialise the result metadata for convenient inspection."""

        # The DataFrame itself is deliberately excluded; only the run
        # metadata and per-variant regression summaries are serialised.
        serialised_regressions = {
            name: summary.as_dict() for name, summary in self.regression.items()
        }
        return {"metadata": self.metadata, "regression": serialised_regressions}
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
class DebiasPipeline:
|
|
288
|
+
"""Coordinate debiasing runs that rely on core Gabriel tasks."""
|
|
289
|
+
|
|
290
|
+
    def __init__(self, config: DebiasConfig) -> None:
        """Store *config*, normalise it in place, and create the run directory."""
        self.cfg = config
        # Fills defaults and raises on inconsistent settings; mutates self.cfg.
        self._validate_config()
        base_dir = os.path.expandvars(os.path.expanduser(self.cfg.save_dir))
        os.makedirs(base_dir, exist_ok=True)
        run_name = self.cfg.run_name or self._default_run_name()
        self.run_dir = os.path.join(base_dir, run_name)
        os.makedirs(self.run_dir, exist_ok=True)
|
298
|
+
|
|
299
|
+
# ------------------------------------------------------------------
|
|
300
|
+
def _default_run_name(self) -> str:
|
|
301
|
+
base_name = (
|
|
302
|
+
self.cfg.measurement_attribute
|
|
303
|
+
or self.cfg.removal_attribute
|
|
304
|
+
or "signal"
|
|
305
|
+
)
|
|
306
|
+
cleaned = re.sub(r"[^a-zA-Z0-9_-]+", "_", base_name)
|
|
307
|
+
prefix = "debias"
|
|
308
|
+
cleaned = cleaned.strip("_")
|
|
309
|
+
if cleaned:
|
|
310
|
+
return f"{prefix}_{cleaned}"
|
|
311
|
+
return prefix
|
|
312
|
+
|
|
313
|
+
# ------------------------------------------------------------------
|
|
314
|
+
    def _validate_config(self) -> None:
        """Validate and normalise ``self.cfg`` in place.

        Fills defaults for the measurement/removal attributes, the
        remaining-signal rating, categories to strip, and strip
        percentages; raises :class:`ValueError` on inconsistent settings.
        """
        valid_modes = {"rate", "classify", "extract", "rank"}
        if self.cfg.mode not in valid_modes:
            raise ValueError("mode must be one of {'rate', 'classify', 'extract', 'rank'}")
        # Defensive copies so later mutation never aliases caller dicts.
        self.cfg.attributes = dict(self.cfg.attributes or {})
        if not self.cfg.attributes:
            raise ValueError("attributes must be supplied for the selected mode")

        self.cfg.signal_dictionary = dict(self.cfg.signal_dictionary or {})
        if not self.cfg.signal_dictionary:
            raise ValueError("signal_dictionary must describe the signal to remove")
        attr_keys = list(self.cfg.attributes.keys())
        first_attr = attr_keys[0]
        measurement_attr = self.cfg.measurement_attribute
        if measurement_attr is not None and measurement_attr not in self.cfg.attributes:
            raise ValueError(
                f"Measurement attribute '{measurement_attr}' must be a key in attributes"
            )
        if measurement_attr is None:
            # Default to the first attribute key (insertion order).
            measurement_attr = first_attr
            if self.cfg.verbose:
                msg = (
                    "[Debias] measurement_attribute not provided; "
                    f"defaulting to '{measurement_attr}'."
                )
                if len(attr_keys) > 1:
                    msg += " Debiasing will use the first attribute provided."
                print(msg)
        self.cfg.measurement_attribute = measurement_attr

        removal_attr = self.cfg.removal_attribute
        if removal_attr is None:
            # Prefer a signal entry matching the measurement attribute;
            # otherwise take the first signal_dictionary key.
            if measurement_attr in self.cfg.signal_dictionary:
                removal_attr = measurement_attr
                if self.cfg.verbose:
                    print(
                        "[Debias] removal_attribute not provided; "
                        f"defaulting to measurement attribute '{removal_attr}'."
                    )
            else:
                removal_attr = next(iter(self.cfg.signal_dictionary))
                if self.cfg.verbose:
                    print(
                        "[Debias] removal_attribute not provided; "
                        f"defaulting to '{removal_attr}'."
                    )
        elif removal_attr not in self.cfg.signal_dictionary:
            raise ValueError(
                f"Removal attribute '{removal_attr}' must be a key in signal_dictionary"
            )
        self.cfg.removal_attribute = removal_attr
        if self.cfg.removal_method not in {"codify", "paraphrase"}:
            raise ValueError("removal_method must be 'codify' or 'paraphrase'")

        if not bool(self.cfg.remaining_signal):
            # Remaining-signal rating disabled: clear related settings.
            self.cfg.remaining_signal_attribute = None
            self.cfg.remaining_signal_description = None
        else:
            attr_name = (self.cfg.remaining_signal_attribute or "").strip()
            if not attr_name:
                # Auto-generate both the attribute name and its description.
                default_attr = f"prevalence of {self.cfg.removal_attribute}"
                default_desc = (
                    "Prevalence of any direct mentions or allusions to concepts related to "
                    f"{self.cfg.removal_attribute} anywhere in the text."
                )
                self.cfg.remaining_signal_attribute = default_attr
                self.cfg.remaining_signal_description = default_desc
                if self.cfg.verbose:
                    print(
                        "[Debias] remaining_signal_attribute not provided; "
                        f"defaulting to '{default_attr}'."
                    )
            elif not self.cfg.remaining_signal_description:
                raise ValueError(
                    "remaining_signal_description must be provided when remaining_signal_attribute is set"
                )

        if self.cfg.categories_to_strip is None:
            # codify strips the removal attribute by default; paraphrase
            # needs no explicit category list.
            if self.cfg.removal_method == "codify":
                self.cfg.categories_to_strip = [self.cfg.removal_attribute]
            else:
                self.cfg.categories_to_strip = []
        else:
            # Keep only known signal keys, de-duplicated, preserving order.
            categories: List[str] = []
            for name in self.cfg.categories_to_strip:
                if name in self.cfg.signal_dictionary and name not in categories:
                    categories.append(name)
            if self.cfg.removal_method == "codify" and not categories:
                raise ValueError(
                    "categories_to_strip must contain at least one key from signal_dictionary when using codify"
                )
            self.cfg.categories_to_strip = categories

        if self.cfg.strip_percentages is None:
            percentages: List[int] = [100]
        else:
            # Coerce to ints, clamp to [0, 100], and de-duplicate.
            percentages = []
            for pct in self.cfg.strip_percentages:
                try:
                    val = int(pct)
                except Exception as exc:  # pragma: no cover - defensive
                    raise ValueError(f"Invalid strip percentage: {pct!r}") from exc
                val = max(0, min(100, val))
                if val not in percentages:
                    percentages.append(val)
            if not percentages:
                percentages = [100]
        percentages.sort()
        self.cfg.strip_percentages = percentages

        # Defensive copies of pass-through kwargs.
        self.cfg.measurement_kwargs = dict(self.cfg.measurement_kwargs or {})
        self.cfg.removal_kwargs = dict(self.cfg.removal_kwargs or {})
|
|
426
|
+
|
|
427
|
+
# ------------------------------------------------------------------
|
|
428
|
+
    async def run(
        self,
        df: pd.DataFrame,
        column_name: str,
        *,
        reset_files: bool = False,
    ) -> DebiasResult:
        """Execute the full debiasing pipeline over ``df[column_name]``.

        Measures baseline attributes, produces stripped text variants,
        re-measures each variant, runs regression diagnostics, and writes
        results plus metadata under ``self.run_dir``.
        """
        if column_name not in df.columns:
            raise ValueError(f"Column '{column_name}' not found in DataFrame")

        # Work on a copy keyed by a stable synthetic row id so measurement
        # outputs can be re-aligned to the master frame.
        df_master = df.copy().reset_index(drop=True)
        row_index = pd.RangeIndex(start=0, stop=len(df_master), step=1, name="__debias_row_id")
        df_master.index = row_index

        if self.cfg.verbose:
            print(f"[Debias] Running debiasing pipeline on {len(df_master)} rows.")

        attr_keys = list(self.cfg.attributes.keys())
        if self.cfg.verbose and len(attr_keys) > 1:
            print(
                "[Debias] Multiple attributes detected; debiasing will focus on "
                f"'{self.cfg.measurement_attribute}' while measuring all attributes "
                "on both raw and stripped text."
            )

        # Baseline measurements: reuse columns already present, only
        # measure attributes missing from the input frame.
        existing_raw_attrs = [attr for attr in attr_keys if attr in df_master.columns]
        missing_raw_attrs = [attr for attr in attr_keys if attr not in df_master.columns]

        if existing_raw_attrs and self.cfg.verbose:
            print(
                "[Debias] Using existing raw measurement columns: "
                + ", ".join(existing_raw_attrs)
            )

        if missing_raw_attrs:
            if self.cfg.verbose:
                print(
                    "[Debias] Measuring baseline signals for missing attributes: "
                    + ", ".join(missing_raw_attrs)
                )
            base_measure = await self._run_measurement(
                df_master,
                column_name=column_name,
                mode=self.cfg.mode,
                save_label="original",
                attributes={k: self.cfg.attributes[k] for k in missing_raw_attrs},
                template_path=self.cfg.template_path,
                extra_kwargs=self.cfg.measurement_kwargs,
                default_model=self.cfg.model,
                reset_files=reset_files,
            )
            self._attach_measurement(
                df_master,
                base_measure,
                missing_raw_attrs,
                variant_key="original",
                display_name="original",
            )
        elif self.cfg.verbose:
            print("[Debias] All raw measurement attributes already exist; skipping baseline rating step.")

        # Produce stripped-text variants with the configured removal method.
        if self.cfg.removal_method == "codify":
            variant_info = await self._prepare_codify_variants(
                df_master, column_name, reset_files=reset_files
            )
        else:
            variant_info = await self._prepare_paraphrase_variant(
                df_master, column_name, reset_files=reset_files
            )

        if variant_info and self.cfg.verbose:
            print("[Debias] Measuring stripped variants...")

        variant_count = len(variant_info)
        for info in variant_info.values():
            info["label_suffix"] = self._variant_label_suffix(info, variant_count)

        # Measure every attribute on each stripped variant.
        disable_progress = not (self.cfg.verbose and bool(variant_info))
        for key in tqdm(list(variant_info.keys()), desc="Variants", disable=disable_progress):
            info = variant_info[key]
            measure_df = await self._run_measurement(
                df_master,
                column_name=info["text_column"],
                mode=self.cfg.mode,
                save_label=key,
                attributes=self.cfg.attributes,
                template_path=self.cfg.template_path,
                extra_kwargs=self.cfg.measurement_kwargs,
                default_model=self.cfg.model,
                reset_files=reset_files,
            )
            column_map = self._attach_measurement(
                df_master,
                measure_df,
                attr_keys,
                variant_key=key,
                display_name=info["label_suffix"],
            )
            info["measurement_columns"] = column_map
            self._print_average_attenuation(
                df_master,
                attr_keys,
                column_map,
                label_suffix=info["label_suffix"],
                display_name=info["display"],
            )

        # Regression diagnostics per variant (skip variants whose target
        # measurement column was never produced).
        regression_info: Dict[str, DebiasRegressionResult] = {}
        for key, info in variant_info.items():
            stripped_column = info["measurement_columns"].get(
                self.cfg.measurement_attribute
            )
            if not stripped_column:
                continue
            remaining_signal_column = await self._measure_remaining_signal(
                df_master,
                info=info,
                save_label=key,
                label_suffix=info["label_suffix"],
                reset_files=reset_files,
            )
            summary = self._run_debiasing(
                df_master,
                original_column=self.cfg.measurement_attribute,
                stripped_column=stripped_column,
                remaining_signal_column=remaining_signal_column,
                variant_key=key,
                display_name=info["display"],
                strip_percentage=info.get("strip_percentage"),
                label_suffix=info["label_suffix"],
            )
            regression_info[key] = summary

        # Persist metadata (JSON) and the augmented frame (CSV).
        metadata = {
            "config": self._serialise_config(),
            "variants": [regression_info[k].as_dict() for k in regression_info],
            "result_path": os.path.join(self.run_dir, "debias_results.csv"),
        }
        metadata_path = os.path.join(self.run_dir, "debias_metadata.json")
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)
        # Recorded after dumping so the JSON file does not reference itself.
        metadata["metadata_path"] = metadata_path

        results_df = df_master.reset_index(drop=True)
        results_path = metadata["result_path"]
        results_df.to_csv(results_path, index=False)
        if self.cfg.verbose:
            print(f"[Debias] Results saved to {results_path}")

        return DebiasResult(
            results=results_df,
            metadata=metadata,
            regression=regression_info,
        )
|
|
582
|
+
|
|
583
|
+
# ------------------------------------------------------------------
|
|
584
|
+
    async def _run_measurement(
        self,
        df: pd.DataFrame,
        *,
        column_name: str,
        mode: MeasurementMode,
        save_label: str,
        attributes: Optional[Dict[str, str]],
        template_path: Optional[str],
        extra_kwargs: Optional[Dict[str, Any]],
        default_model: str,
        reset_files: bool,
    ) -> pd.DataFrame:
        """Run one measurement task (rate/classify/extract/rank) over a column.

        Returns the task's output frame re-indexed by ``__debias_row_id`` so
        callers can align it with the master frame.
        """
        # Pop pipeline-level knobs out of the kwargs before they reach the
        # task config constructors.
        kwargs = dict(extra_kwargs or {})
        response_fn = kwargs.pop("response_fn", None)
        get_all_responses_fn = kwargs.pop("get_all_responses_fn", None)
        save_dir = kwargs.pop("save_dir", os.path.join(self.run_dir, save_label))
        os.makedirs(save_dir, exist_ok=True)
        reset_files_local = bool(kwargs.pop("reset_files", False)) or bool(reset_files)
        # Only the text column plus the row-id index column are passed on.
        df_reset = df[[column_name]].reset_index()
        run_kwargs: Dict[str, Any] = {}
        if response_fn is not None:
            run_kwargs["response_fn"] = response_fn
        if get_all_responses_fn is not None:
            run_kwargs["get_all_responses_fn"] = get_all_responses_fn
        if mode == "rate":
            cfg = RateConfig(
                attributes=attributes or {},
                save_dir=save_dir,
                model=kwargs.pop("model", default_model),
                n_parallels=kwargs.pop("n_parallels", self.cfg.n_parallels),
                use_dummy=kwargs.pop("use_dummy", self.cfg.use_dummy),
                **kwargs,
            )
            runner = Rate(cfg, template_path=template_path)
            result = await runner.run(
                df_reset,
                column_name,
                reset_files=reset_files_local,
                **run_kwargs,
            )
        elif mode == "classify":
            # Classify calls its attribute mapping "labels".
            cfg = ClassifyConfig(
                labels=attributes or {},
                save_dir=save_dir,
                model=kwargs.pop("model", default_model),
                n_parallels=kwargs.pop("n_parallels", self.cfg.n_parallels),
                use_dummy=kwargs.pop("use_dummy", self.cfg.use_dummy),
                **kwargs,
            )
            runner = Classify(cfg, template_path=template_path)
            result = await runner.run(
                df_reset,
                column_name,
                reset_files=reset_files_local,
                **run_kwargs,
            )
        elif mode == "extract":
            cfg = ExtractConfig(
                attributes=attributes or {},
                save_dir=save_dir,
                model=kwargs.pop("model", default_model),
                n_parallels=kwargs.pop("n_parallels", self.cfg.n_parallels),
                use_dummy=kwargs.pop("use_dummy", self.cfg.use_dummy),
                **kwargs,
            )
            runner = Extract(cfg, template_path=template_path)
            result = await runner.run(
                df_reset,
                column_name,
                reset_files=reset_files_local,
                **run_kwargs,
            )
        elif mode == "rank":
            cfg = RankConfig(
                attributes=attributes or {},
                save_dir=save_dir,
                model=kwargs.pop("model", default_model),
                n_parallels=kwargs.pop("n_parallels", self.cfg.n_parallels),
                use_dummy=kwargs.pop("use_dummy", self.cfg.use_dummy),
                **kwargs,
            )
            runner = Rank(cfg, template_path=template_path)
            # Rank needs an explicit id column to track items across rounds.
            result = await runner.run(
                df_reset,
                column_name,
                id_column="__debias_row_id",
                reset_files=reset_files_local,
                **run_kwargs,
            )
        else:  # pragma: no cover - defensive
            raise ValueError(f"Unsupported mode: {mode}")

        # Normalise the output index to the synthetic row id. Falls back to
        # an "identifier" column, then to positional alignment with ``df``.
        if "__debias_row_id" in result.columns:
            result = result.set_index("__debias_row_id")
        elif "identifier" in result.columns:
            result = result.set_index("identifier")
        else:
            result.index = df.index
        result.index.name = "__debias_row_id"
        return result
|
|
685
|
+
|
|
686
|
+
# ------------------------------------------------------------------
|
|
687
|
+
def _attach_measurement(
|
|
688
|
+
self,
|
|
689
|
+
df_master: pd.DataFrame,
|
|
690
|
+
measurement_df: pd.DataFrame,
|
|
691
|
+
attributes: List[str],
|
|
692
|
+
*,
|
|
693
|
+
variant_key: str,
|
|
694
|
+
display_name: str,
|
|
695
|
+
) -> Dict[str, str]:
|
|
696
|
+
column_map: Dict[str, str] = {}
|
|
697
|
+
for attr in attributes:
|
|
698
|
+
if attr not in measurement_df.columns:
|
|
699
|
+
continue
|
|
700
|
+
if variant_key == "original":
|
|
701
|
+
target_name = attr
|
|
702
|
+
else:
|
|
703
|
+
target_name = f"{attr} ({display_name})"
|
|
704
|
+
df_master[target_name] = measurement_df[attr].reindex(df_master.index)
|
|
705
|
+
column_map[attr] = target_name
|
|
706
|
+
return column_map
|
|
707
|
+
|
|
708
|
+
# ------------------------------------------------------------------
|
|
709
|
+
def _remaining_signal_attributes(self) -> Dict[str, str]:
|
|
710
|
+
if not self.cfg.remaining_signal_attribute:
|
|
711
|
+
return {}
|
|
712
|
+
return {
|
|
713
|
+
self.cfg.remaining_signal_attribute: str(self.cfg.remaining_signal_description)
|
|
714
|
+
}
|
|
715
|
+
|
|
716
|
+
# ------------------------------------------------------------------
|
|
717
|
+
async def _measure_remaining_signal(
    self,
    df_master: pd.DataFrame,
    *,
    info: Dict[str, Any],
    save_label: str,
    label_suffix: str,
    reset_files: bool,
) -> Optional[str]:
    """Rate how much of the target signal survives in a stripped variant.

    Runs a "rate" measurement over the variant's text column (``info["text_column"]``)
    using the configured remaining-signal attribute, then attaches the result
    to ``df_master`` as ``"<attr> (<label_suffix>)"``.

    Returns the name of the attached column, or ``None`` when no
    remaining-signal attribute is configured. If the target column already
    exists in ``df_master`` it is reused without re-measuring (resume support).
    """
    attrs = self._remaining_signal_attributes()
    if not attrs:
        # No remaining-signal attribute configured — nothing to measure.
        return None
    measurement_attr = self.cfg.remaining_signal_attribute
    # _remaining_signal_attributes() is non-empty only when the attribute is set.
    assert measurement_attr is not None
    target_column = f"{measurement_attr} ({label_suffix})"
    if target_column in df_master.columns:
        # Column already present (e.g. from a previous run) — skip the LLM call.
        if self.cfg.verbose:
            print(
                "[Debias] Using existing remaining-signal column: "
                f"{target_column}"
            )
        return target_column
    if self.cfg.verbose:
        print(
            "[Debias] Measuring remaining-signal prevalence on "
            f"variant '{info['display']}'."
        )
    # Delegate to the shared measurement runner in "rate" mode; results are
    # persisted under a dedicated save label so they don't collide with the
    # main measurement outputs.
    remaining_df = await self._run_measurement(
        df_master,
        column_name=info["text_column"],
        mode="rate",
        save_label=f"{save_label}_remaining_signal",
        attributes=attrs,
        template_path=self.cfg.template_path,
        extra_kwargs=self.cfg.measurement_kwargs,
        default_model=self.cfg.model,
        reset_files=reset_files,
    )
    column_map = self._attach_measurement(
        df_master,
        remaining_df,
        [measurement_attr],
        variant_key=f"{save_label}_remaining_signal",
        display_name=label_suffix,
    )
    # Fall back to the expected name if the attribute was absent from the
    # measurement output (attach skips missing columns).
    return column_map.get(measurement_attr, target_column)
|
|
763
|
+
|
|
764
|
+
# ------------------------------------------------------------------
|
|
765
|
+
def _variant_label_suffix(self, info: Dict[str, Any], variant_count: int) -> str:
|
|
766
|
+
pct = info.get("strip_percentage")
|
|
767
|
+
if pct is None:
|
|
768
|
+
return "stripped"
|
|
769
|
+
if pct == 100 and variant_count == 1:
|
|
770
|
+
return "stripped"
|
|
771
|
+
return f"stripped {pct}%"
|
|
772
|
+
|
|
773
|
+
async def _prepare_codify_variants(
    self,
    df: pd.DataFrame,
    column_name: str,
    *,
    reset_files: bool,
) -> Dict[str, Dict[str, Any]]:
    """Create stripped text variants via codify-then-remove.

    Runs the Codify task over ``column_name`` to locate snippets matching the
    configured signal dictionary, then for each configured strip percentage
    writes a new column to ``df`` (mutated in place) with that share of the
    found snippets removed.

    Returns a mapping of variant key (``"stripped_NNNpct"``) to metadata:
    ``text_column`` (new column name), ``display`` label, and
    ``strip_percentage``.
    """
    # Work on a copy so popping consumed options doesn't mutate the config.
    kwargs = dict(self.cfg.removal_kwargs or {})
    # Either an explicit kwarg or the caller's flag forces a fresh run.
    reset_files_local = bool(kwargs.pop("reset_files", False)) or bool(reset_files)
    response_fn = kwargs.pop("response_fn", None)
    get_all_responses_fn = kwargs.pop("get_all_responses_fn", None)
    save_dir = kwargs.pop("save_dir", os.path.join(self.run_dir, "codify"))
    os.makedirs(save_dir, exist_ok=True)
    additional_instructions = kwargs.pop("additional_instructions", "")
    # Only forward the response hooks when explicitly supplied.
    run_kwargs: Dict[str, Any] = {}
    if response_fn is not None:
        run_kwargs["response_fn"] = response_fn
    if get_all_responses_fn is not None:
        run_kwargs["get_all_responses_fn"] = get_all_responses_fn
    kwargs.setdefault("n_rounds", 3)
    # Remaining kwargs flow straight into CodifyConfig; explicit pops above
    # take their defaults from the Debias config.
    cfg = CodifyConfig(
        save_dir=save_dir,
        model=kwargs.pop("model", self.cfg.model),
        n_parallels=kwargs.pop("n_parallels", self.cfg.n_parallels),
        use_dummy=kwargs.pop("use_dummy", self.cfg.use_dummy),
        **kwargs,
    )
    runner = Codify(cfg)
    # reset_index() so the runner sees a clean positional index; snippets are
    # later looked up by df's own index positions.
    codify_df = await runner.run(
        df.reset_index(),
        column_name,
        categories=self.cfg.signal_dictionary,
        additional_instructions=additional_instructions,
        reset_files=reset_files_local,
        **run_kwargs,
    )
    variants: Dict[str, Dict[str, Any]] = {}
    categories = self.cfg.categories_to_strip or []
    for pct in self.cfg.strip_percentages:
        if pct <= 0:
            # 0% (or negative) stripping would be a no-op variant.
            continue
        key = f"stripped_{pct:03d}pct"
        stripped_base = f"{column_name} (stripped)"
        # A single 100% strip gets the plain "(stripped)" name; otherwise the
        # percentage is embedded so multiple variants can coexist.
        single_full_strip = len(self.cfg.strip_percentages or []) == 1 and pct == 100
        display = "stripped" if single_full_strip else f"stripped {pct}%"
        new_col = stripped_base if single_full_strip else f"{column_name} ({display})"
        df[new_col] = [
            self._strip_passages(
                original_text=str(df.at[idx, column_name]),
                snippets=self._collect_snippets(codify_df, idx, categories),
                pct_strip=pct,
                row_idx=int(idx),
            )
            for idx in df.index
        ]
        variants[key] = {
            "text_column": new_col,
            "display": display,
            "strip_percentage": pct,
        }
    return variants
|
|
834
|
+
|
|
835
|
+
# ------------------------------------------------------------------
|
|
836
|
+
async def _prepare_paraphrase_variant(
    self,
    df: pd.DataFrame,
    column_name: str,
    *,
    reset_files: bool,
) -> Dict[str, Dict[str, Any]]:
    """Create a single stripped variant by LLM paraphrase.

    Runs the Paraphrase task with instructions (built from the signal
    dictionary unless overridden via ``removal_kwargs["instructions"]``) that
    ask the model to remove every reference to the signal(s), and writes the
    result to ``df`` (mutated in place) as ``"<column_name> (stripped)"``.

    Returns a single-entry variant mapping under the key ``"paraphrase"``
    with ``strip_percentage`` set to ``None`` (paraphrasing has no
    percentage notion).
    """
    # Work on a copy so popping consumed options doesn't mutate the config.
    kwargs = dict(self.cfg.removal_kwargs or {})
    reset_files_local = bool(kwargs.pop("reset_files", False)) or bool(reset_files)
    response_fn = kwargs.pop("response_fn", None)
    get_all_responses_fn = kwargs.pop("get_all_responses_fn", None)
    save_dir = kwargs.pop("save_dir", os.path.join(self.run_dir, "paraphrase"))
    os.makedirs(save_dir, exist_ok=True)
    revised_name = f"{column_name} (stripped)"
    instructions = kwargs.pop("instructions", None) or self._build_paraphrase_instructions()
    response_kwargs: Dict[str, Any] = {}
    run_kwargs: Dict[str, Any] = {}
    if response_fn is not None:
        run_kwargs["response_fn"] = response_fn
    if get_all_responses_fn is not None:
        run_kwargs["get_all_responses_fn"] = get_all_responses_fn
    if "n_rounds" in kwargs:
        response_kwargs["n_rounds"] = kwargs.pop("n_rounds")
    # Legacy alias: completion_max_rounds maps onto n_rounds, but an explicit
    # n_rounds always wins.
    if "completion_max_rounds" in kwargs and "n_rounds" not in response_kwargs:
        replacement = kwargs.pop("completion_max_rounds")
        warnings.warn(
            "completion_max_rounds is deprecated; use n_rounds instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        if replacement is not None:
            response_kwargs["n_rounds"] = replacement
    response_kwargs.setdefault("n_rounds", 3)
    cfg = ParaphraseConfig(
        instructions=instructions,
        revised_column_name=revised_name,
        save_dir=save_dir,
        model=kwargs.pop("model", self.cfg.model),
        n_parallels=kwargs.pop("n_parallels", self.cfg.n_parallels),
        use_dummy=kwargs.pop("use_dummy", self.cfg.use_dummy),
        **kwargs,
    )
    runner = Paraphrase(cfg)
    paraphrased = await runner.run(
        df.reset_index(),
        column_name,
        reset_files=reset_files_local,
        **run_kwargs,
        **response_kwargs,
    )
    # Align the paraphrased texts back onto df's original index.
    df[revised_name] = paraphrased[revised_name].reindex(df.index)
    return {
        "paraphrase": {
            "text_column": revised_name,
            "display": "stripped",
            "strip_percentage": None,
        }
    }
|
|
894
|
+
|
|
895
|
+
# ------------------------------------------------------------------
|
|
896
|
+
def _build_paraphrase_instructions(self) -> str:
|
|
897
|
+
lines = [
|
|
898
|
+
"Rewrite the passage so that every reference to the following signal(s) is removed while preserving all other content:",
|
|
899
|
+
]
|
|
900
|
+
for key, desc in self.cfg.signal_dictionary.items():
|
|
901
|
+
lines.append(f"- {key}: {desc}")
|
|
902
|
+
lines.append(
|
|
903
|
+
"Keep the remaining content identical. Be exhaustive: the final output must have absolutely no content which references, mentions, manifests, or alludes to the aformentioned signal(s). All such content must be stripped from the text; ensure nothing at all relevant remains. When in doubt, remove it. The stripping must be perfect; don't output anything related to / manifesting the signal(s). Your output is the original text, with everything else unchanged, but all this content completely and entirely removed."
|
|
904
|
+
)
|
|
905
|
+
return "\n".join(lines)
|
|
906
|
+
|
|
907
|
+
# ------------------------------------------------------------------
|
|
908
|
+
def _collect_snippets(
|
|
909
|
+
self,
|
|
910
|
+
codify_df: pd.DataFrame,
|
|
911
|
+
row_idx: int,
|
|
912
|
+
categories: List[str],
|
|
913
|
+
) -> List[str]:
|
|
914
|
+
snippets: List[str] = []
|
|
915
|
+
for cat in categories:
|
|
916
|
+
if cat not in codify_df.columns:
|
|
917
|
+
continue
|
|
918
|
+
raw = codify_df.at[row_idx, cat]
|
|
919
|
+
if isinstance(raw, list):
|
|
920
|
+
snippets.extend(str(s) for s in raw if s)
|
|
921
|
+
elif isinstance(raw, str):
|
|
922
|
+
try:
|
|
923
|
+
parsed = json.loads(raw)
|
|
924
|
+
if isinstance(parsed, list):
|
|
925
|
+
snippets.extend(str(s) for s in parsed if s)
|
|
926
|
+
except Exception:
|
|
927
|
+
if raw:
|
|
928
|
+
snippets.append(str(raw))
|
|
929
|
+
return snippets
|
|
930
|
+
|
|
931
|
+
# ------------------------------------------------------------------
|
|
932
|
+
def _strip_passages(
|
|
933
|
+
self,
|
|
934
|
+
*,
|
|
935
|
+
original_text: str,
|
|
936
|
+
snippets: List[str],
|
|
937
|
+
pct_strip: int,
|
|
938
|
+
row_idx: int,
|
|
939
|
+
) -> str:
|
|
940
|
+
if pct_strip <= 0 or not snippets:
|
|
941
|
+
return self._normalise_ws(original_text)
|
|
942
|
+
unique: List[str] = []
|
|
943
|
+
seen = set()
|
|
944
|
+
for snippet in snippets:
|
|
945
|
+
snippet = snippet.strip()
|
|
946
|
+
if not snippet:
|
|
947
|
+
continue
|
|
948
|
+
if snippet not in seen:
|
|
949
|
+
seen.add(snippet)
|
|
950
|
+
unique.append(snippet)
|
|
951
|
+
if not unique:
|
|
952
|
+
return self._normalise_ws(original_text)
|
|
953
|
+
n_total = len(unique)
|
|
954
|
+
n_remove = max(0, min(n_total, int(round(n_total * (pct_strip / 100.0)))))
|
|
955
|
+
if n_remove == 0:
|
|
956
|
+
return self._normalise_ws(original_text)
|
|
957
|
+
seed_val = f"{self.cfg.random_seed}:{row_idx}:{pct_strip}:{n_total}"
|
|
958
|
+
rng = random.Random()
|
|
959
|
+
rng.seed(seed_val)
|
|
960
|
+
to_remove = rng.sample(unique, n_remove)
|
|
961
|
+
cleaned = original_text
|
|
962
|
+
for snippet in sorted(to_remove, key=len, reverse=True):
|
|
963
|
+
cleaned = cleaned.replace(snippet, " ")
|
|
964
|
+
return self._normalise_ws(cleaned)
|
|
965
|
+
|
|
966
|
+
# ------------------------------------------------------------------
|
|
967
|
+
def _run_debiasing(
    self,
    df: pd.DataFrame,
    original_column: str,
    stripped_column: str,
    *,
    remaining_signal_column: Optional[str],
    variant_key: str,
    display_name: str,
    strip_percentage: Optional[int],
    label_suffix: str,
) -> DebiasRegressionResult:
    """Compute debiased scores for one variant and summarise the regressions.

    Mutates ``df`` in place by adding:
      * a simple-difference column ``"<attr> (raw - <label>)"`` = original - stripped;
      * when a remaining-signal proxy is available, a two-step column
        ``"<attr> (debiased[, <label>])"`` = original - stripped + delta * remaining,
        where delta is the stage-1 slope of stripped on the remaining signal.

    Returns a DebiasRegressionResult; early-exits with a mostly-empty result
    when fewer than 3 usable rows exist or the stripped column has no
    variation.
    """
    cols = [original_column, stripped_column]
    if remaining_signal_column:
        cols.append(remaining_signal_column)
    measurement_attr = self.cfg.measurement_attribute
    # Coerce everything to numeric; rows lacking either core column are dropped.
    reg_df = df[cols].apply(pd.to_numeric, errors="coerce")
    reg_df = reg_df.dropna(subset=[original_column, stripped_column])
    if len(reg_df) < 3:
        # Too few observations for any regression — return an empty summary.
        if self.cfg.verbose:
            print(
                f"[Debias] Not enough observations for debiasing on variant '{display_name}'."
            )
        return DebiasRegressionResult(
            variant=variant_key,
            display_name=display_name,
            strip_percentage=strip_percentage,
            correlation=None,
            mean_original=None,
            mean_stripped=None,
        )

    y_series = reg_df[original_column]
    s_series = reg_df[stripped_column]

    if np.allclose(np.var(s_series.values), 0.0):
        # A constant stripped column makes correlation/regression meaningless.
        if self.cfg.verbose:
            print(
                f"[Debias] Stripped column has no variation for variant '{display_name}'."
            )
        return DebiasRegressionResult(
            variant=variant_key,
            display_name=display_name,
            strip_percentage=strip_percentage,
            correlation=None,
            mean_original=float(y_series.mean()),
            mean_stripped=float(s_series.mean()),
        )

    attenuation_pct = self._attenuation_pct(y_series, s_series)
    if self.cfg.verbose and attenuation_pct is not None:
        print(
            "[Debias] Mean attenuation from original to stripped "
            f"for '{display_name}': {attenuation_pct:.2f}%"
        )

    # --- Simple difference debiasing: raw minus stripped. ---
    diff_col = (
        f"{measurement_attr} (raw - stripped)"
        if label_suffix == "stripped"
        else f"{measurement_attr} (raw - {label_suffix})"
    )
    df.loc[reg_df.index, diff_col] = y_series - s_series
    pretty_diff_col = diff_col

    mean_pct_diff, mean_abs_pct_diff = self._avg_pct_diff(
        y_series,
        df.loc[reg_df.index, diff_col],
    )
    if self.cfg.verbose:
        print(
            "[Debias] Simple difference debiasing percent change "
            f"(signed / absolute): {mean_pct_diff:.2f}% / {mean_abs_pct_diff:.2f}%"
        )

    diff_reg = self._regress_against_original(
        df,
        original_column=original_column,
        debiased_column=diff_col,
        title="original ~ (original - stripped)",
        plot_filename=f"{variant_key}_diff_vs_original.png",
    )

    # --- Two-step debiasing: add back delta * remaining-signal proxy. ---
    twostep_col = (
        f"{measurement_attr} (debiased)"
        if label_suffix == "stripped"
        else f"{measurement_attr} (debiased, {label_suffix})"
    )
    pretty_twostep_col = twostep_col
    stage1_reg: Optional[Dict[str, Any]] = None
    delta: Optional[float] = None
    if remaining_signal_column and remaining_signal_column in df.columns:
        stage1_reg, delta = self._stage1_remaining_signal_regression(
            df,
            stripped_column=stripped_column,
            remaining_signal_column=remaining_signal_column,
            display_name=display_name,
        )
        if delta is not None:
            # Only rows where all three series are numeric enter the two-step column.
            valid_idx, y_vals, s_vals, r_vals = self._aligned_series(
                df,
                original_column,
                stripped_column,
                remaining_signal_column,
            )
            df.loc[valid_idx, twostep_col] = y_vals - s_vals + delta * r_vals
            # pretty_twostep_col aliases twostep_col, so this re-assignment is a no-op
            # kept for symmetry with pretty_diff_col.
            df.loc[valid_idx, pretty_twostep_col] = df.loc[valid_idx, twostep_col]
            mean_pct_two, mean_abs_pct_two = self._avg_pct_diff(
                y_vals,
                df.loc[valid_idx, twostep_col],
            )
            if self.cfg.verbose:
                print(
                    "[Debias] Two-step debiasing percent change "
                    f"(signed / absolute): {mean_pct_two:.2f}% / {mean_abs_pct_two:.2f}%"
                )
    else:
        if self.cfg.verbose and self.cfg.remaining_signal_attribute:
            print(
                "[Debias] Remaining-signal attribute configured but no remaining-signal "
                f"column was available for variant '{display_name}'."
            )

    twostep_reg: Optional[Dict[str, Any]] = None
    if twostep_col in df.columns and df[twostep_col].notna().sum() >= 3:
        twostep_reg = self._regress_against_original(
            df,
            original_column=original_column,
            debiased_column=twostep_col,
            title="original ~ debiased (two-step)",
            plot_filename=f"{variant_key}_twostep_vs_original.png",
        )

    correlation = float(y_series.corr(s_series))
    summary = DebiasRegressionResult(
        variant=variant_key,
        display_name=display_name,
        strip_percentage=strip_percentage,
        correlation=correlation,
        mean_original=float(y_series.mean()),
        mean_stripped=float(s_series.mean()),
        attenuation_pct=attenuation_pct,
        diff_regression=diff_reg,
        twostep_regression=twostep_reg,
        stage1_regression=stage1_reg,
        stage1_delta=delta,
        debiased_columns={
            "diff": pretty_diff_col,
            "twostep": pretty_twostep_col,
            "remaining_signal": remaining_signal_column or "",
        },
    )
    return summary
|
|
1119
|
+
|
|
1120
|
+
# ------------------------------------------------------------------
|
|
1121
|
+
def _aligned_series(
|
|
1122
|
+
self,
|
|
1123
|
+
df: pd.DataFrame,
|
|
1124
|
+
original_column: str,
|
|
1125
|
+
stripped_column: str,
|
|
1126
|
+
remaining_signal_column: str,
|
|
1127
|
+
) -> Tuple[pd.Index, pd.Series, pd.Series, pd.Series]:
|
|
1128
|
+
tmp = df[[original_column, stripped_column, remaining_signal_column]].apply(
|
|
1129
|
+
pd.to_numeric,
|
|
1130
|
+
errors="coerce",
|
|
1131
|
+
)
|
|
1132
|
+
tmp = tmp.dropna()
|
|
1133
|
+
return (
|
|
1134
|
+
tmp.index,
|
|
1135
|
+
tmp[original_column],
|
|
1136
|
+
tmp[stripped_column],
|
|
1137
|
+
tmp[remaining_signal_column],
|
|
1138
|
+
)
|
|
1139
|
+
|
|
1140
|
+
# ------------------------------------------------------------------
|
|
1141
|
+
def _print_average_attenuation(
|
|
1142
|
+
self,
|
|
1143
|
+
df: pd.DataFrame,
|
|
1144
|
+
attributes: List[str],
|
|
1145
|
+
column_map: Dict[str, str],
|
|
1146
|
+
*,
|
|
1147
|
+
label_suffix: str,
|
|
1148
|
+
display_name: str,
|
|
1149
|
+
) -> None:
|
|
1150
|
+
if not self.cfg.verbose:
|
|
1151
|
+
return
|
|
1152
|
+
attenuations: List[float] = []
|
|
1153
|
+
for attr in attributes:
|
|
1154
|
+
stripped_col = column_map.get(attr)
|
|
1155
|
+
if not stripped_col or attr not in df.columns:
|
|
1156
|
+
continue
|
|
1157
|
+
raw_series = pd.to_numeric(df[attr], errors="coerce")
|
|
1158
|
+
stripped_series = pd.to_numeric(df[stripped_col], errors="coerce")
|
|
1159
|
+
tmp = pd.concat([raw_series, stripped_series], axis=1).dropna()
|
|
1160
|
+
if tmp.empty:
|
|
1161
|
+
continue
|
|
1162
|
+
raw_mean = float(tmp.iloc[:, 0].mean())
|
|
1163
|
+
stripped_mean = float(tmp.iloc[:, 1].mean())
|
|
1164
|
+
if np.isclose(raw_mean, 0.0):
|
|
1165
|
+
continue
|
|
1166
|
+
pct = (stripped_mean - raw_mean) / raw_mean * 100.0
|
|
1167
|
+
attenuations.append(float(pct))
|
|
1168
|
+
if not attenuations:
|
|
1169
|
+
print(
|
|
1170
|
+
"[Debias] Unable to compute average percent attenuation from raw to stripped "
|
|
1171
|
+
f"for variant '{display_name}'."
|
|
1172
|
+
)
|
|
1173
|
+
return
|
|
1174
|
+
avg_pct = float(np.mean(attenuations))
|
|
1175
|
+
print(
|
|
1176
|
+
"[Debias] Average percent attenuation from raw to "
|
|
1177
|
+
f"'{label_suffix}' across measurement attributes: {avg_pct:.2f}%"
|
|
1178
|
+
)
|
|
1179
|
+
|
|
1180
|
+
# ------------------------------------------------------------------
|
|
1181
|
+
def _attenuation_pct(
|
|
1182
|
+
self,
|
|
1183
|
+
original: pd.Series,
|
|
1184
|
+
stripped: pd.Series,
|
|
1185
|
+
) -> Optional[float]:
|
|
1186
|
+
mean_original = float(pd.to_numeric(original, errors="coerce").mean())
|
|
1187
|
+
mean_stripped = float(pd.to_numeric(stripped, errors="coerce").mean())
|
|
1188
|
+
if np.isclose(mean_original, 0.0):
|
|
1189
|
+
return None
|
|
1190
|
+
return float((mean_original - mean_stripped) / mean_original * 100.0)
|
|
1191
|
+
|
|
1192
|
+
# ------------------------------------------------------------------
|
|
1193
|
+
def _avg_pct_diff(
|
|
1194
|
+
self,
|
|
1195
|
+
raw: pd.Series,
|
|
1196
|
+
debiased: pd.Series,
|
|
1197
|
+
) -> Tuple[float, float]:
|
|
1198
|
+
raw_num = pd.to_numeric(raw, errors="coerce")
|
|
1199
|
+
deb_num = pd.to_numeric(debiased, errors="coerce")
|
|
1200
|
+
pct = (deb_num - raw_num) / raw_num * 100.0
|
|
1201
|
+
pct = pct.replace([np.inf, -np.inf], np.nan)
|
|
1202
|
+
pct = pct[raw_num != 0]
|
|
1203
|
+
return float(pct.mean()), float(pct.abs().mean())
|
|
1204
|
+
|
|
1205
|
+
# ------------------------------------------------------------------
|
|
1206
|
+
def _stage1_remaining_signal_regression(
    self,
    df: pd.DataFrame,
    *,
    stripped_column: str,
    remaining_signal_column: str,
    display_name: str,
) -> Tuple[Optional[Dict[str, Any]], Optional[float]]:
    """Regress the stripped score on the remaining-signal proxy (stage 1).

    Fits ``stripped ~ 1 + remaining_signal`` by OLS and returns the
    serialised regression plus the slope (delta). Returns ``(None, None)``
    when fewer than 3 usable rows exist or the proxy has no variation.
    """
    usable = (
        df[[stripped_column, remaining_signal_column]]
        .apply(pd.to_numeric, errors="coerce")
        .dropna()
    )
    if len(usable) < 3:
        if self.cfg.verbose:
            print(
                "[Debias] Not enough observations for remaining-signal regression "
                f"on variant '{display_name}'."
            )
        return None, None

    proxy = usable[remaining_signal_column].values
    if np.allclose(np.var(proxy), 0.0):
        # A constant regressor would make the slope unidentifiable.
        if self.cfg.verbose:
            print(
                "[Debias] Remaining-signal proxy has no variation for "
                f"variant '{display_name}'."
            )
        return None, None

    outcome = usable[stripped_column].values
    coef_names = ["Intercept", remaining_signal_column]
    design = np.column_stack([np.ones(len(usable)), proxy])
    fitted = _safe_fit_ols(
        outcome,
        design,
        robust=self.cfg.robust_regression,
        varnames=coef_names,
    )
    if self.cfg.verbose:
        self._print_generic_regression_table(
            fitted,
            names=coef_names,
            title=f"Stage-1 regression: stripped ~ remaining signal [{display_name}]",
        )
    # delta is the slope on the remaining-signal proxy (second coefficient).
    delta = float(fitted["coef"][1])
    return self._regression_dict(fitted, coef_names), delta
|
|
1251
|
+
|
|
1252
|
+
# ------------------------------------------------------------------
|
|
1253
|
+
def _regress_against_original(
    self,
    df: pd.DataFrame,
    *,
    original_column: str,
    debiased_column: str,
    title: str,
    plot_filename: str,
) -> Dict[str, Any]:
    """Fit ``original ~ 1 + debiased`` by OLS and save a binned scatter plot.

    Returns the serialised regression dict, or ``{}`` when fewer than 3
    numeric rows are available. Prints the coefficient table in verbose mode.
    """
    usable = (
        df[[original_column, debiased_column]]
        .apply(pd.to_numeric, errors="coerce")
        .dropna()
    )
    if len(usable) < 3:
        return {}
    response = usable[original_column].values
    predictor = usable[debiased_column].values
    coef_names = ["Intercept", debiased_column]
    design = np.column_stack([np.ones(len(usable)), predictor])
    fitted = _safe_fit_ols(
        response,
        design,
        robust=self.cfg.robust_regression,
        varnames=coef_names,
    )
    if self.cfg.verbose:
        self._print_generic_regression_table(
            fitted,
            names=coef_names,
            title=title,
        )
    # The plot is saved regardless of verbosity.
    self._save_binned_scatter_with_fit(
        usable,
        x_col=debiased_column,
        y_col=original_column,
        reg_res=fitted,
        filename=plot_filename,
        title=title,
        bins=20,
    )
    return self._regression_dict(fitted, coef_names) or {}
|
|
1293
|
+
|
|
1294
|
+
# ------------------------------------------------------------------
|
|
1295
|
+
def _save_binned_scatter_with_fit(
    self,
    df: pd.DataFrame,
    *,
    x_col: str,
    y_col: str,
    reg_res: Dict[str, Any],
    filename: str,
    title: str,
    bins: int,
) -> Optional[str]:
    """Save a quantile-binned scatter of y vs x with the fitted OLS line.

    Bins ``x_col`` into (up to) ``bins`` quantile bins, plots per-bin means
    with standard-error bars, overlays the line implied by ``reg_res``'s
    intercept/slope, and writes the figure to ``run_dir/filename``.

    Returns the saved path, or ``None`` when no numeric data is available.

    Fix: ``plt.cm.get_cmap(...)`` was deprecated in matplotlib 3.7 and removed
    in 3.9; use the stable ``plt.cm.rainbow`` colormap attribute instead.
    """
    if df.empty:
        return None
    data = df[[x_col, y_col]].dropna().copy()
    if data.empty:
        return None
    data[x_col] = pd.to_numeric(data[x_col], errors="coerce")
    data[y_col] = pd.to_numeric(data[y_col], errors="coerce")
    data = data.dropna()
    if data.empty:
        return None

    # Never request more bins than there are observations.
    n_bins = max(1, min(int(bins), len(data)))
    data["_bin"] = pd.qcut(data[x_col], q=n_bins, duplicates="drop")
    grp = data.groupby("_bin", observed=True)
    xm = grp[x_col].mean()
    ym = grp[y_col].mean()
    counts = grp[y_col].count().astype(float)
    std = grp[y_col].std(ddof=1)
    # Standard error of each bin mean; single-observation bins yield NaN -> 0.
    with np.errstate(divide="ignore", invalid="ignore"):
        yerr = std / np.sqrt(counts)
    yerr = yerr.fillna(0.0)

    x_vals = data[x_col].values.astype(float)
    beta0, beta1 = reg_res["coef"][0], reg_res["coef"][1]
    x_min = float(np.min(x_vals))
    x_max = float(np.max(x_vals))
    if x_max == x_min:
        # Degenerate x-range: pad so the fit line is still visible.
        pad = 1.0 if x_max == 0 else 0.05 * abs(x_max)
        x_min -= pad
        x_max += pad
    x_line = np.linspace(x_min, x_max, 200)
    y_line = beta0 + beta1 * x_line

    fig, ax = plt.subplots(figsize=(7, 5), dpi=400)
    # plt.cm.rainbow works across matplotlib versions (cm.get_cmap was removed in 3.9).
    colours = plt.cm.rainbow(np.linspace(0, 1, len(xm)))
    ax.errorbar(
        xm,
        ym,
        yerr=yerr,
        fmt="o",
        color="black",
        ecolor="black",
        capsize=3,
        markersize=6,
    )
    ax.scatter(xm, ym, c=colours, s=50, zorder=3)
    ax.plot(x_line, y_line, color="black", linewidth=2, label="OLS fit")
    ax.set_title(title)
    ax.set_xlabel(x_col)
    ax.set_ylabel(y_col)
    ax.margins(x=0.05, y=0.05)
    ax.legend(loc="best")
    ax.grid(True, alpha=0.2)
    fig.tight_layout()

    plot_path = os.path.join(self.run_dir, filename)
    fig.savefig(plot_path)
    plt.close(fig)
    if self.cfg.verbose:
        print(f"[Debias] Saved regression plot to {plot_path}")
    return plot_path
|
|
1367
|
+
|
|
1368
|
+
# ------------------------------------------------------------------
|
|
1369
|
+
def _print_regression_table(
    self,
    result: Dict[str, Any],
    rename_map: Dict[str, str],
    stripped_column: str,
    title: str,
) -> None:
    """Print a regression table naming the slope after the stripped column.

    The coefficient rows are labelled ``Intercept`` and the display name of
    ``stripped_column`` from ``rename_map``. The table/footer formatting is
    identical to ``_print_generic_regression_table``, so delegate to it
    rather than duplicating the rendering code.
    """
    names = ["Intercept", rename_map[stripped_column]]
    self._print_generic_regression_table(result, names=names, title=title)
|
|
1391
|
+
|
|
1392
|
+
# ------------------------------------------------------------------
|
|
1393
|
+
def _print_generic_regression_table(
|
|
1394
|
+
self,
|
|
1395
|
+
result: Dict[str, Any],
|
|
1396
|
+
*,
|
|
1397
|
+
names: List[str],
|
|
1398
|
+
title: str,
|
|
1399
|
+
) -> None:
|
|
1400
|
+
table = pd.DataFrame(
|
|
1401
|
+
{
|
|
1402
|
+
"coef": result["coef"],
|
|
1403
|
+
"se": result["se"],
|
|
1404
|
+
"t": result["t"],
|
|
1405
|
+
"p": result["p"],
|
|
1406
|
+
},
|
|
1407
|
+
index=names,
|
|
1408
|
+
)
|
|
1409
|
+
print(f"\n{title}")
|
|
1410
|
+
print(table.round(6).to_string())
|
|
1411
|
+
print(
|
|
1412
|
+
f"R^2 = {result['r2']:.4f}, adj. R^2 = {result['adj_r2']:.4f}, n = {result['n']}"
|
|
1413
|
+
)
|
|
1414
|
+
|
|
1415
|
+
# ------------------------------------------------------------------
|
|
1416
|
+
def _regression_dict(
|
|
1417
|
+
self,
|
|
1418
|
+
result: Optional[Dict[str, Any]],
|
|
1419
|
+
names: List[str],
|
|
1420
|
+
) -> Optional[Dict[str, Any]]:
|
|
1421
|
+
if result is None:
|
|
1422
|
+
return None
|
|
1423
|
+
return {
|
|
1424
|
+
"coef": {name: float(val) for name, val in zip(names, result["coef"])},
|
|
1425
|
+
"se": {name: float(val) for name, val in zip(names, result["se"])},
|
|
1426
|
+
"t": {name: float(val) for name, val in zip(names, result["t"])},
|
|
1427
|
+
"p": {name: float(val) for name, val in zip(names, result["p"])},
|
|
1428
|
+
"r2": float(result["r2"]),
|
|
1429
|
+
"adj_r2": float(result["adj_r2"]),
|
|
1430
|
+
"n": int(result["n"]),
|
|
1431
|
+
}
|
|
1432
|
+
|
|
1433
|
+
# ------------------------------------------------------------------
|
|
1434
|
+
def _serialise_config(self) -> Dict[str, Any]:
|
|
1435
|
+
def _convert(value: Any) -> Any:
|
|
1436
|
+
if isinstance(value, dict):
|
|
1437
|
+
return {k: _convert(v) for k, v in value.items()}
|
|
1438
|
+
if isinstance(value, (list, tuple, set)):
|
|
1439
|
+
return [_convert(v) for v in value]
|
|
1440
|
+
if isinstance(value, np.generic):
|
|
1441
|
+
return value.item()
|
|
1442
|
+
if callable(value):
|
|
1443
|
+
name = getattr(value, "__name__", None)
|
|
1444
|
+
return name or repr(value)
|
|
1445
|
+
try:
|
|
1446
|
+
json.dumps(value)
|
|
1447
|
+
return value
|
|
1448
|
+
except TypeError:
|
|
1449
|
+
return repr(value)
|
|
1450
|
+
|
|
1451
|
+
cfg_dict = asdict(self.cfg)
|
|
1452
|
+
cfg_dict["save_dir"] = os.path.expandvars(os.path.expanduser(self.cfg.save_dir))
|
|
1453
|
+
cfg_dict["measurement_kwargs"] = _convert(self.cfg.measurement_kwargs)
|
|
1454
|
+
cfg_dict["removal_kwargs"] = _convert(self.cfg.removal_kwargs)
|
|
1455
|
+
return cfg_dict
|
|
1456
|
+
|
|
1457
|
+
# ------------------------------------------------------------------
|
|
1458
|
+
@staticmethod
|
|
1459
|
+
def _normalise_ws(text: str) -> str:
|
|
1460
|
+
return " ".join(str(text or "").split())
|