openai_gabriel-1.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gabriel/__init__.py +61 -0
- gabriel/_version.py +1 -0
- gabriel/api.py +2284 -0
- gabriel/cli/__main__.py +60 -0
- gabriel/core/__init__.py +7 -0
- gabriel/core/llm_client.py +34 -0
- gabriel/core/pipeline.py +18 -0
- gabriel/core/prompt_template.py +152 -0
- gabriel/prompts/__init__.py +1 -0
- gabriel/prompts/bucket_prompt.jinja2 +113 -0
- gabriel/prompts/classification_prompt.jinja2 +50 -0
- gabriel/prompts/codify_prompt.jinja2 +95 -0
- gabriel/prompts/comparison_prompt.jinja2 +60 -0
- gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
- gabriel/prompts/deidentification_prompt.jinja2 +112 -0
- gabriel/prompts/extraction_prompt.jinja2 +61 -0
- gabriel/prompts/filter_prompt.jinja2 +31 -0
- gabriel/prompts/ideation_prompt.jinja2 +80 -0
- gabriel/prompts/merge_prompt.jinja2 +47 -0
- gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
- gabriel/prompts/rankings_prompt.jinja2 +49 -0
- gabriel/prompts/ratings_prompt.jinja2 +50 -0
- gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
- gabriel/prompts/seed.jinja2 +43 -0
- gabriel/prompts/snippets.jinja2 +117 -0
- gabriel/tasks/__init__.py +63 -0
- gabriel/tasks/_attribute_utils.py +69 -0
- gabriel/tasks/bucket.py +432 -0
- gabriel/tasks/classify.py +562 -0
- gabriel/tasks/codify.py +1033 -0
- gabriel/tasks/compare.py +235 -0
- gabriel/tasks/debias.py +1460 -0
- gabriel/tasks/deduplicate.py +341 -0
- gabriel/tasks/deidentify.py +316 -0
- gabriel/tasks/discover.py +524 -0
- gabriel/tasks/extract.py +455 -0
- gabriel/tasks/filter.py +169 -0
- gabriel/tasks/ideate.py +782 -0
- gabriel/tasks/merge.py +464 -0
- gabriel/tasks/paraphrase.py +531 -0
- gabriel/tasks/rank.py +2041 -0
- gabriel/tasks/rate.py +347 -0
- gabriel/tasks/seed.py +465 -0
- gabriel/tasks/whatever.py +344 -0
- gabriel/utils/__init__.py +64 -0
- gabriel/utils/audio_utils.py +42 -0
- gabriel/utils/file_utils.py +464 -0
- gabriel/utils/image_utils.py +22 -0
- gabriel/utils/jinja.py +31 -0
- gabriel/utils/logging.py +86 -0
- gabriel/utils/mapmaker.py +304 -0
- gabriel/utils/media_utils.py +78 -0
- gabriel/utils/modality_utils.py +148 -0
- gabriel/utils/openai_utils.py +5470 -0
- gabriel/utils/parsing.py +282 -0
- gabriel/utils/passage_viewer.py +2557 -0
- gabriel/utils/pdf_utils.py +20 -0
- gabriel/utils/plot_utils.py +2881 -0
- gabriel/utils/prompt_utils.py +42 -0
- gabriel/utils/word_matching.py +158 -0
- openai_gabriel-1.0.1.dist-info/METADATA +443 -0
- openai_gabriel-1.0.1.dist-info/RECORD +67 -0
- openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
- openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
- openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
- openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
- openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
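
The two task modules diffed below ship inside the gabriel package. A minimal import sketch, assuming the usual pip install flow and the module paths given in the file list above:

    # pip install openai-gabriel
    from gabriel.tasks.deduplicate import Deduplicate, DeduplicateConfig
    from gabriel.tasks.deidentify import Deidentifier, DeidentifyConfig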
gabriel/tasks/deduplicate.py (new file)

@@ -0,0 +1,341 @@
from __future__ import annotations

import asyncio
import hashlib
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from scipy.cluster.vq import kmeans2

from ..core.prompt_template import PromptTemplate, resolve_template
from ..utils.openai_utils import get_all_responses
from ..utils import (
    safest_json,
    safe_json,
    get_all_embeddings,
    warn_if_modality_mismatch,
)
from ..utils.logging import announce_prompt_rendering


@dataclass
class DeduplicateConfig:
    """Configuration for :class:`Deduplicate`."""

    save_dir: str = "deduplicate"
    file_name: str = "deduplicate_responses.csv"
    model: str = "gpt-5-mini"
    n_parallels: int = 650
    n_runs: int = 3
    use_dummy: bool = False
    max_timeout: Optional[float] = None
    additional_instructions: Optional[str] = None
    use_embeddings: bool = True
    group_size: int = 500
    modality: str = "entity"
    max_words_per_text: int = 500

    def __post_init__(self) -> None:
        if self.additional_instructions is not None:
            cleaned = str(self.additional_instructions).strip()
            self.additional_instructions = cleaned or None


class Deduplicate:
    """LLM-assisted deduplication for a single DataFrame column."""

    def __init__(
        self,
        cfg: DeduplicateConfig,
        template: Optional[PromptTemplate] = None,
        template_path: Optional[str] = None,
    ) -> None:
        expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
        expanded.mkdir(parents=True, exist_ok=True)
        cfg.save_dir = str(expanded)
        self.cfg = cfg
        self.template = resolve_template(
            template=template,
            template_path=template_path,
            reference_filename="deduplicate_prompt.jinja2",
        )

    # ------------------------------------------------------------------
    @staticmethod
    def _deduplicate(series: pd.Series) -> Tuple[List[str], Dict[str, List[str]], Dict[str, str]]:
        """Return (unique_values, rep_to_group, orig_to_rep) for a Series."""
        rep_map: Dict[str, str] = {}
        groups: Dict[str, List[str]] = {}
        orig_to_rep: Dict[str, str] = {}
        for val in series.dropna().astype(str):
            norm = re.sub(r"[^0-9a-z]+", "", val.lower())
            if norm in rep_map:
                rep = rep_map[norm]
                groups[rep].append(val)
            else:
                rep_map[norm] = val
                groups[val] = [val]
            orig_to_rep[val] = rep_map[norm]
        uniques = list(groups.keys())
        return uniques, groups, orig_to_rep

    # ------------------------------------------------------------------
    @staticmethod
    def _print_stats(before: pd.Series, after: pd.Series, *, run_idx: int, total_runs: int) -> None:
        total = before.notna().sum()
        diff = (before.fillna("<NA>") != after.fillna("<NA>")).sum()
        percent = (diff / total * 100) if total else 0.0
        unique_mapped = after.dropna().nunique()
        avg_per_map = (total / unique_mapped) if unique_mapped else 0.0
        print(
            f"[Deduplicate] Run {run_idx + 1}/{total_runs}: {diff} deduplications "
            f"({percent:.2f}% of {total})."
        )
        print(
            f"[Deduplicate] Unique mapped terms: {unique_mapped}; "
            f"avg terms per mapping: {avg_per_map:.2f}."
        )

    # ------------------------------------------------------------------
    async def _run_once(
        self,
        df_proc: pd.DataFrame,
        *,
        column_name: str,
        output_col: str,
        raw_texts: Optional[Dict[str, str]] = None,
        reset_files: bool,
        run_idx: int,
        **kwargs: Any,
    ) -> None:
        uniques, groups, orig_to_rep = self._deduplicate(df_proc[column_name])

        use_embeddings = self.cfg.use_embeddings and len(uniques) >= self.cfg.group_size

        batches: List[List[str]] = []
        if use_embeddings:
            embed_texts = uniques
            if raw_texts is not None:
                embed_texts = [raw_texts.get(u, u) for u in uniques]
            emb = await get_all_embeddings(
                texts=embed_texts,
                identifiers=uniques,
                save_path=os.path.join(self.cfg.save_dir, "deduplicate_embeddings.pkl"),
                reset_file=reset_files and run_idx == 0,
                use_dummy=self.cfg.use_dummy,
                verbose=False,
            )
            if emb:
                arr = np.array([emb[u] for u in uniques], dtype=float)
                k = max(1, int(np.ceil(len(uniques) / self.cfg.group_size)))
                _, labels = kmeans2(arr, k, minit="points")
                clusters: List[List[str]] = [[] for _ in range(k)]
                for term, lbl in zip(uniques, labels):
                    clusters[int(lbl)].append(term)
                current: List[str] = []
                for cluster in clusters:
                    for term in cluster:
                        current.append(term)
                        if len(current) >= self.cfg.group_size:
                            batches.append(current)
                            current = []
                if current:
                    batches.append(current)
        if not batches:
            sorted_uniques = sorted(uniques, key=lambda x: x.lower())
            for i in range(0, len(sorted_uniques), self.cfg.group_size):
                batches.append(sorted_uniques[i : i + self.cfg.group_size])

        prompts: List[str] = []
        identifiers: List[str] = []
        announce_prompt_rendering("Deduplicate", len(batches) * max(1, self.cfg.n_runs))
        for idx, items in enumerate(batches):
            raw_terms: Any = items
            if raw_texts is not None:
                raw_terms = {ident: raw_texts.get(ident, "") for ident in items}
            prompts.append(
                self.template.render(
                    group_id=f"deduplicate_{idx:05d}",
                    raw_terms=raw_terms,
                    additional_instructions=self.cfg.additional_instructions or "",
                    modality=self.cfg.modality,
                )
            )
            identifiers.append(f"deduplicate_{idx:05d}")

        base, ext = os.path.splitext(self.cfg.file_name)
        if self.cfg.n_runs > 1:
            response_file = f"{base}_run{run_idx + 1}{ext}"
        else:
            response_file = self.cfg.file_name
        save_path = os.path.join(self.cfg.save_dir, response_file)
        if prompts:
            resp_df = await get_all_responses(
                prompts=prompts,
                identifiers=identifiers,
                n_parallels=self.cfg.n_parallels,
                model=self.cfg.model,
                save_path=save_path,
                use_dummy=self.cfg.use_dummy,
                max_timeout=self.cfg.max_timeout,
                json_mode=True,
                reset_files=reset_files,
                **kwargs,
            )
        else:
            resp_df = pd.DataFrame(columns=["Identifier", "Response"])

        resp_map = dict(zip(resp_df.get("Identifier", []), resp_df.get("Response", [])))
        parsed = await asyncio.gather(
            *[safest_json(resp_map.get(i, "")) for i in identifiers]
        )

        mappings: Dict[str, str] = {}
        for items, res in zip(batches, parsed):
            if isinstance(res, str):
                res = safe_json(res)
            if isinstance(res, dict):
                for rep, vals in res.items():
                    if isinstance(vals, list):
                        for val in vals:
                            if isinstance(val, str) and val in items:
                                mappings[val] = rep
            elif isinstance(res, list):
                for row in res:
                    if isinstance(row, str):
                        row = safe_json(row)
                    if isinstance(row, dict):
                        inp = row.get("input")
                        mapped = row.get("mapped")
                        if (
                            isinstance(inp, str)
                            and isinstance(mapped, str)
                            and inp in items
                        ):
                            mappings[inp] = mapped

        for rep in uniques:
            mappings.setdefault(rep, rep)

        mapped_vals: List[Optional[str]] = []
        for val in df_proc[column_name]:
            if pd.isna(val):
                mapped_vals.append(val)
            else:
                rep = orig_to_rep.get(str(val), str(val))
                mapped_vals.append(mappings.get(rep, rep))
        df_proc[output_col] = mapped_vals

    # ------------------------------------------------------------------

    async def run(
        self,
        df: pd.DataFrame,
        *,
        column_name: str,
        reset_files: bool = False,
        nruns: Optional[int] = None,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Deduplicate a column using LLM assistance."""

        df_proc = df.reset_index(drop=True).copy()
        n_runs = nruns if nruns is not None else self.cfg.n_runs
        values = df_proc[column_name].tolist()
        warn_if_modality_mismatch(values, self.cfg.modality, column_name=column_name)
        current_col = column_name
        raw_texts: Optional[Dict[str, str]] = None
        id_to_original: Dict[str, Any] = {}

        if self.cfg.modality == "text":
            if self.cfg.group_size > 25:
                print(
                    "[Deduplicate] For modality='text', consider a smaller group_size "
                    "(e.g., 25) to keep prompts concise."
                )
            ids: List[Optional[str]] = []
            truncated: List[Optional[str]] = []
            truncated_count = 0
            raw_texts = {}
            for value in values:
                if pd.isna(value):
                    ids.append(None)
                    truncated.append(None)
                    continue
                text = str(value)
                sha8 = hashlib.sha1(text.encode()).hexdigest()[:8]
                ids.append(sha8)
                if sha8 not in id_to_original:
                    id_to_original[sha8] = value
                words = text.split()
                if len(words) > self.cfg.max_words_per_text:
                    truncated_count += 1
                    clipped = " ".join(words[: self.cfg.max_words_per_text])
                else:
                    clipped = text
                truncated.append(clipped)
                raw_texts[sha8] = clipped
            df_proc["_dedup_id"] = ids
            df_proc[f"{column_name}_truncated"] = truncated
            total = sum(1 for v in values if not pd.isna(v))
            frac = (truncated_count / total * 100) if total else 0.0
            print(
                f"[Deduplicate] Truncated {truncated_count}/{total} texts "
                f"({frac:.2f}%) to max_words_per_text={self.cfg.max_words_per_text}."
            )
            current_col = "_dedup_id"
        for i in range(n_runs):
            if self.cfg.modality == "text":
                base = f"mapped_{column_name}_ids"
                if n_runs == 1:
                    output_col = base
                elif i == n_runs - 1:
                    output_col = f"{base}_final"
                else:
                    output_col = f"{base}_run{i + 1}"
            else:
                if n_runs == 1:
                    output_col = f"mapped_{column_name}"
                elif i == n_runs - 1:
                    output_col = f"mapped_{column_name}_final"
                else:
                    output_col = f"mapped_{column_name}_run{i + 1}"
            await self._run_once(
                df_proc,
                column_name=current_col,
                output_col=output_col,
                raw_texts=raw_texts,
                reset_files=reset_files,
                run_idx=i,
                **kwargs,
            )
            self._print_stats(
                df_proc[current_col],
                df_proc[output_col],
                run_idx=i,
                total_runs=n_runs,
            )
            current_col = output_col
        if self.cfg.modality == "text":
            if n_runs > 1:
                df_proc[f"mapped_{column_name}_ids"] = df_proc[current_col]
            final_ids = (
                df_proc[f"mapped_{column_name}_ids"]
                if f"mapped_{column_name}_ids" in df_proc.columns
                else df_proc[current_col]
            )
            mapped_texts: List[Any] = []
            for val in final_ids:
                if pd.isna(val):
                    mapped_texts.append(val)
                else:
                    mapped_texts.append(id_to_original.get(str(val), val))
            df_proc[f"mapped_{column_name}"] = mapped_texts
        elif n_runs > 1:
            df_proc[f"mapped_{column_name}"] = df_proc[current_col]
        return df_proc
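
A minimal usage sketch for the class above, not taken from the package docs; use_dummy=True is assumed to substitute canned responses for live API calls, as the config field suggests, and the sample data is hypothetical:

    import asyncio

    import pandas as pd

    from gabriel.tasks.deduplicate import Deduplicate, DeduplicateConfig

    # "U.S.A." and "usa" share the normalized key "usa" after
    # re.sub(r"[^0-9a-z]+", "", val.lower()), so _deduplicate collapses
    # them to one representative before any prompts are rendered.
    df = pd.DataFrame({"country": ["U.S.A.", "usa", "United States", None]})

    cfg = DeduplicateConfig(
        save_dir="dedup_demo",  # created on init via mkdir(parents=True, exist_ok=True)
        n_runs=1,               # single run -> results land in "mapped_country"
        use_dummy=True,         # assumption: dummy mode avoids real model calls
    )
    task = Deduplicate(cfg)
    result = asyncio.run(task.run(df, column_name="country"))
    print(result["mapped_country"])

With n_runs left at the default of 3, each run consumes the previous run's output column; the last run is suffixed _final and then copied to mapped_country.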
gabriel/tasks/deidentify.py (new file)

@@ -0,0 +1,316 @@
from __future__ import annotations

import ast
import json
import os
import re
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd

from ..core.prompt_template import PromptTemplate, resolve_template
from ..utils import safest_json
from ..utils.openai_utils import get_all_responses
from ..utils.logging import announce_prompt_rendering


# ────────────────────────────
# Configuration dataclass
# ────────────────────────────
@dataclass
class DeidentifyConfig:
    """Configuration for :class:`Deidentifier`."""

    model: str = "gpt-5-mini"
    n_parallels: int = 650
    save_dir: str = "deidentify"
    file_name: str = "deidentified.csv"
    use_dummy: bool = False
    max_timeout: Optional[float] = None
    max_words_per_call: int = 7500
    additional_instructions: Optional[str] = None
    reasoning_effort: Optional[str] = None
    reasoning_summary: Optional[str] = None
    n_passes: int = 1
    use_existing_mappings_only: bool = False

    def __post_init__(self) -> None:
        if not isinstance(self.n_passes, int) or self.n_passes < 1:
            raise ValueError("n_passes must be an integer >= 1")
        if self.additional_instructions is not None:
            cleaned = str(self.additional_instructions).strip()
            self.additional_instructions = cleaned or None


# ────────────────────────────
# Main de-identification task
# ────────────────────────────
class Deidentifier:
    """Iterative de-identification of sensitive entities in text."""

    def __init__(
        self,
        cfg: DeidentifyConfig,
        template: Optional[PromptTemplate] = None,
        template_path: Optional[str] = None,
    ) -> None:
        expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
        expanded.mkdir(parents=True, exist_ok=True)
        cfg.save_dir = str(expanded)
        self.cfg = cfg
        self.template = resolve_template(
            template=template,
            template_path=template_path,
            reference_filename="deidentification_prompt.jinja2",
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _chunk_by_words(text: str, max_words: int) -> List[str]:
        words = text.split()
        if len(words) <= max_words:
            return [text]
        return [" ".join(words[i : i + max_words]) for i in range(0, len(words), max_words)]

    @staticmethod
    def _coerce_mapping(value: Any) -> Optional[Dict[str, Any]]:
        """Attempt to convert ``value`` into a mapping dictionary."""

        if value is None:
            return None
        try:
            if pd.isna(value):  # type: ignore[arg-type]
                return None
        except Exception:
            pass

        if isinstance(value, dict):
            return deepcopy(value)

        if isinstance(value, str):
            cleaned = value.strip()
            if not cleaned or cleaned.lower() in {"nan", "none"}:
                return None
            try:
                parsed = json.loads(cleaned)
            except Exception:
                try:
                    parsed = ast.literal_eval(cleaned)
                except Exception:
                    return None
            if isinstance(parsed, dict):
                return deepcopy(parsed)
            return None

        return None

    # ------------------------------------------------------------------
    # Main entry point
    # ------------------------------------------------------------------
    async def run(
        self,
        df: pd.DataFrame,
        column_name: str,
        *,
        grouping_column: Optional[str] = None,
        mapping_column: Optional[str] = None,
        reset_files: bool = False,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Deidentify all texts in ``df[column_name]``.

        Parameters
        ----------
        df:
            Input DataFrame.
        column_name:
            Name of the column containing the text to de-identify.
        grouping_column:
            Optional column whose values determine which rows belong to the same
            individual/entity. When omitted, each row is treated independently.
        mapping_column:
            Optional column containing pre-existing mapping dictionaries. The
            first non-empty mapping encountered for each group is used as the
            warm start and is also the mapping reused when
            ``use_existing_mappings_only`` is ``True``.
        reset_files:
            When ``True``, intermediate CSV logs from :func:`get_all_responses`
            are regenerated.
        **kwargs:
            Additional keyword arguments forwarded to
            :func:`gabriel.utils.openai_utils.get_all_responses`.
        """

        if column_name not in df.columns:
            raise ValueError(f"Column '{column_name}' not found in DataFrame")
        if grouping_column is not None and grouping_column not in df.columns:
            raise ValueError(
                f"Grouping column '{grouping_column}' not found in DataFrame"
            )
        if mapping_column is not None and mapping_column not in df.columns:
            raise ValueError(
                f"Mapping column '{mapping_column}' not found in DataFrame"
            )

        df_proc = df.reset_index(drop=True).copy()

        if grouping_column is None:
            df_proc["group_id"] = df_proc.index.astype(str)
        else:
            df_proc["group_id"] = df_proc[grouping_column].astype(str)

        group_ids = df_proc["group_id"].unique().tolist()
        base_name = Path(self.cfg.file_name).stem
        csv_path = Path(self.cfg.save_dir) / f"{base_name}_cleaned.csv"
        raw_prefix = Path(self.cfg.save_dir) / f"{base_name}_raw_responses"

        group_segments: Dict[str, List[str]] = {}
        for gid in group_ids:
            segs: List[str] = []
            texts = (
                df_proc.loc[df_proc["group_id"] == gid, column_name]
                .fillna("")
                .astype(str)
                .tolist()
            )
            for text in texts:
                segs.extend(self._chunk_by_words(text, self.cfg.max_words_per_call))
            group_segments[gid] = segs

        group_to_map: Dict[str, dict] = {gid: {} for gid in group_ids}

        warm_start_count = 0
        if mapping_column is not None:
            for gid, subset in df_proc.groupby("group_id"):
                values = subset[mapping_column].tolist()
                mapping = next(
                    (m for m in (self._coerce_mapping(v) for v in values) if m is not None),
                    None,
                )
                if mapping is not None:
                    group_to_map[str(gid)] = mapping
                    warm_start_count += 1
            if warm_start_count:
                print(
                    f"[Deidentify] Loaded {warm_start_count} warm-start mapping(s) "
                    f"from column '{mapping_column}'."
                )
            else:
                print(
                    f"[Deidentify] Column '{mapping_column}' provided but no usable "
                    "mappings were found."
                )

        print(
            "[Deidentify] Tip: edit the first mapping for each group and rerun with "
            "use_existing_mappings_only=True to apply your changes without new LLM calls."
        )

        max_rounds = max(len(segs) for segs in group_segments.values()) if group_segments else 0

        if self.cfg.use_existing_mappings_only:
            print(
                "[Deidentify] use_existing_mappings_only=True -> skipping LLM calls and "
                "reusing provided mappings."
            )
            missing = [gid for gid, mapping in group_to_map.items() if not mapping]
            if missing:
                print(
                    "[Deidentify] Warning: no mapping provided for "
                    f"{len(missing)} group(s). Their text will be returned unchanged."
                )
        else:
            for pass_idx in range(self.cfg.n_passes):
                if self.cfg.n_passes > 1:
                    print(
                        f"[Deidentify] Starting pass {pass_idx + 1}/{self.cfg.n_passes}."
                    )

                for rnd in range(max_rounds):
                    prompts: List[str] = []
                    identifiers: List[str] = []
                    id_to_gid: Dict[str, str] = {}

                    for gid in group_ids:
                        segs = group_segments[gid]
                        if rnd >= len(segs):
                            continue

                        ident = f"{gid}_pass{pass_idx}_seg_{rnd}"
                        identifiers.append(ident)
                        id_to_gid[ident] = gid
                        prompts.append(
                            self.template.render(
                                text=segs[rnd],
                                current_map=json.dumps(
                                    group_to_map.get(gid, {}), ensure_ascii=False
                                ),
                                additional_instructions=self.cfg.additional_instructions,
                            )
                        )

                    if not prompts:
                        continue

                    announce_prompt_rendering("Deidentify", len(prompts))

                    save_path = raw_prefix.with_name(
                        f"{raw_prefix.name}_pass{pass_idx}_round{rnd}.csv"
                    )
                    batch_df = await get_all_responses(
                        prompts=prompts,
                        identifiers=identifiers,
                        n_parallels=self.cfg.n_parallels,
                        model=self.cfg.model,
                        save_path=str(save_path),
                        use_dummy=self.cfg.use_dummy,
                        max_timeout=self.cfg.max_timeout,
                        json_mode=True,
                        reasoning_effort=self.cfg.reasoning_effort,
                        reasoning_summary=self.cfg.reasoning_summary,
                        reset_files=reset_files,
                        **kwargs,
                    )

                    for ident, resp in zip(batch_df["Identifier"], batch_df["Response"]):
                        gid = id_to_gid.get(ident)
                        if gid is None:
                            continue
                        main = resp[0] if isinstance(resp, list) and resp else ""
                        parsed = await safest_json(main)
                        if isinstance(parsed, dict):
                            group_to_map[gid] = parsed

        mappings_col: List[dict] = []
        deidentified_texts: List[str] = []
        for _, row in df_proc.iterrows():
            gid = row["group_id"]
            mapping = group_to_map.get(gid, {})
            mappings_col.append(mapping)
            text = str(row[column_name])
            deid_text = text
            pairs: List[Tuple[str, str]] = []
            for entry in mapping.values():
                if isinstance(entry, dict):
                    casted = entry.get("casted form", "")
                    for real in entry.get("real forms", []) or []:
                        if casted and real:
                            pairs.append((real, casted))
            pairs.sort(key=lambda x: len(x[0]), reverse=True)
            for real, casted in pairs:
                escaped = re.escape(real)
                pattern = re.compile(
                    rf"(?<!\w){escaped}(?!\w)", flags=re.IGNORECASE
                )
                deid_text = pattern.sub(casted, deid_text)
            deidentified_texts.append(deid_text)

        df_proc["mapping"] = mappings_col
        df_proc["deidentified_text"] = deidentified_texts
        df_proc.to_csv(csv_path, index=False)
        return df_proc
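
The replacement loop at the end of run only needs mapping entries shaped as {"casted form": ..., "real forms": [...]}; combined with use_existing_mappings_only=True, that allows de-identification with no model calls at all. A minimal sketch with a hand-written mapping (the names and sample text are hypothetical):

    import asyncio

    import pandas as pd

    from gabriel.tasks.deidentify import Deidentifier, DeidentifyConfig

    # Mapping schema mirrored from the replacement loop in run():
    # {alias: {"casted form": replacement, "real forms": [originals]}}
    mapping = {
        "person_1": {
            "casted form": "Alex Doe",
            "real forms": ["Jane Smith", "J. Smith"],
        }
    }
    df = pd.DataFrame(
        {
            "note": ["Jane Smith called. J. Smith will return Monday."],
            "person": ["jane"],  # rows sharing a value form one group
            "map": [mapping],    # warm-start mapping for that group
        }
    )

    cfg = DeidentifyConfig(save_dir="deid_demo", use_existing_mappings_only=True)
    task = Deidentifier(cfg)
    result = asyncio.run(
        task.run(df, "note", grouping_column="person", mapping_column="map")
    )
    print(result["deidentified_text"].iloc[0])
    # "Alex Doe called. Alex Doe will return Monday."

Longer "real forms" are substituted first (pairs are sorted by length, descending), so "Jane Smith" is replaced before the shorter "J. Smith" could match inside it.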