openai-gabriel 1.0.1 (openai_gabriel-1.0.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. gabriel/__init__.py +61 -0
  2. gabriel/_version.py +1 -0
  3. gabriel/api.py +2284 -0
  4. gabriel/cli/__main__.py +60 -0
  5. gabriel/core/__init__.py +7 -0
  6. gabriel/core/llm_client.py +34 -0
  7. gabriel/core/pipeline.py +18 -0
  8. gabriel/core/prompt_template.py +152 -0
  9. gabriel/prompts/__init__.py +1 -0
  10. gabriel/prompts/bucket_prompt.jinja2 +113 -0
  11. gabriel/prompts/classification_prompt.jinja2 +50 -0
  12. gabriel/prompts/codify_prompt.jinja2 +95 -0
  13. gabriel/prompts/comparison_prompt.jinja2 +60 -0
  14. gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
  15. gabriel/prompts/deidentification_prompt.jinja2 +112 -0
  16. gabriel/prompts/extraction_prompt.jinja2 +61 -0
  17. gabriel/prompts/filter_prompt.jinja2 +31 -0
  18. gabriel/prompts/ideation_prompt.jinja2 +80 -0
  19. gabriel/prompts/merge_prompt.jinja2 +47 -0
  20. gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
  21. gabriel/prompts/rankings_prompt.jinja2 +49 -0
  22. gabriel/prompts/ratings_prompt.jinja2 +50 -0
  23. gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
  24. gabriel/prompts/seed.jinja2 +43 -0
  25. gabriel/prompts/snippets.jinja2 +117 -0
  26. gabriel/tasks/__init__.py +63 -0
  27. gabriel/tasks/_attribute_utils.py +69 -0
  28. gabriel/tasks/bucket.py +432 -0
  29. gabriel/tasks/classify.py +562 -0
  30. gabriel/tasks/codify.py +1033 -0
  31. gabriel/tasks/compare.py +235 -0
  32. gabriel/tasks/debias.py +1460 -0
  33. gabriel/tasks/deduplicate.py +341 -0
  34. gabriel/tasks/deidentify.py +316 -0
  35. gabriel/tasks/discover.py +524 -0
  36. gabriel/tasks/extract.py +455 -0
  37. gabriel/tasks/filter.py +169 -0
  38. gabriel/tasks/ideate.py +782 -0
  39. gabriel/tasks/merge.py +464 -0
  40. gabriel/tasks/paraphrase.py +531 -0
  41. gabriel/tasks/rank.py +2041 -0
  42. gabriel/tasks/rate.py +347 -0
  43. gabriel/tasks/seed.py +465 -0
  44. gabriel/tasks/whatever.py +344 -0
  45. gabriel/utils/__init__.py +64 -0
  46. gabriel/utils/audio_utils.py +42 -0
  47. gabriel/utils/file_utils.py +464 -0
  48. gabriel/utils/image_utils.py +22 -0
  49. gabriel/utils/jinja.py +31 -0
  50. gabriel/utils/logging.py +86 -0
  51. gabriel/utils/mapmaker.py +304 -0
  52. gabriel/utils/media_utils.py +78 -0
  53. gabriel/utils/modality_utils.py +148 -0
  54. gabriel/utils/openai_utils.py +5470 -0
  55. gabriel/utils/parsing.py +282 -0
  56. gabriel/utils/passage_viewer.py +2557 -0
  57. gabriel/utils/pdf_utils.py +20 -0
  58. gabriel/utils/plot_utils.py +2881 -0
  59. gabriel/utils/prompt_utils.py +42 -0
  60. gabriel/utils/word_matching.py +158 -0
  61. openai_gabriel-1.0.1.dist-info/METADATA +443 -0
  62. openai_gabriel-1.0.1.dist-info/RECORD +67 -0
  63. openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
  64. openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
  65. openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
  66. openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
  67. openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
gabriel/tasks/deduplicate.py
@@ -0,0 +1,341 @@
+ from __future__ import annotations
+
+ import asyncio
+ import hashlib
+ import os
+ import re
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import pandas as pd
+ from scipy.cluster.vq import kmeans2
+
+ from ..core.prompt_template import PromptTemplate, resolve_template
+ from ..utils.openai_utils import get_all_responses
+ from ..utils import (
+     safest_json,
+     safe_json,
+     get_all_embeddings,
+     warn_if_modality_mismatch,
+ )
+ from ..utils.logging import announce_prompt_rendering
+
+
+ @dataclass
+ class DeduplicateConfig:
+     """Configuration for :class:`Deduplicate`."""
+
+     save_dir: str = "deduplicate"
+     file_name: str = "deduplicate_responses.csv"
+     model: str = "gpt-5-mini"
+     n_parallels: int = 650
+     n_runs: int = 3
+     use_dummy: bool = False
+     max_timeout: Optional[float] = None
+     additional_instructions: Optional[str] = None
+     use_embeddings: bool = True
+     group_size: int = 500
+     modality: str = "entity"
+     max_words_per_text: int = 500
+
+     def __post_init__(self) -> None:
+         if self.additional_instructions is not None:
+             cleaned = str(self.additional_instructions).strip()
+             self.additional_instructions = cleaned or None
+
+
+ class Deduplicate:
+     """LLM-assisted deduplication for a single DataFrame column."""
+
+     def __init__(
+         self,
+         cfg: DeduplicateConfig,
+         template: Optional[PromptTemplate] = None,
+         template_path: Optional[str] = None,
+     ) -> None:
+         expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
+         expanded.mkdir(parents=True, exist_ok=True)
+         cfg.save_dir = str(expanded)
+         self.cfg = cfg
+         self.template = resolve_template(
+             template=template,
+             template_path=template_path,
+             reference_filename="deduplicate_prompt.jinja2",
+         )
+
+     # ------------------------------------------------------------------
+     @staticmethod
+     def _deduplicate(series: pd.Series) -> Tuple[List[str], Dict[str, List[str]], Dict[str, str]]:
+         """Return (unique_values, rep_to_group, orig_to_rep) for a Series."""
+         rep_map: Dict[str, str] = {}
+         groups: Dict[str, List[str]] = {}
+         orig_to_rep: Dict[str, str] = {}
+         for val in series.dropna().astype(str):
+             norm = re.sub(r"[^0-9a-z]+", "", val.lower())
+             if norm in rep_map:
+                 rep = rep_map[norm]
+                 groups[rep].append(val)
+             else:
+                 rep_map[norm] = val
+                 groups[val] = [val]
+             orig_to_rep[val] = rep_map[norm]
+         uniques = list(groups.keys())
+         return uniques, groups, orig_to_rep
+
+     # ------------------------------------------------------------------
+     @staticmethod
+     def _print_stats(before: pd.Series, after: pd.Series, *, run_idx: int, total_runs: int) -> None:
+         total = before.notna().sum()
+         diff = (before.fillna("<NA>") != after.fillna("<NA>")).sum()
+         percent = (diff / total * 100) if total else 0.0
+         unique_mapped = after.dropna().nunique()
+         avg_per_map = (total / unique_mapped) if unique_mapped else 0.0
+         print(
+             f"[Deduplicate] Run {run_idx + 1}/{total_runs}: {diff} deduplications "
+             f"({percent:.2f}% of {total})."
+         )
+         print(
+             f"[Deduplicate] Unique mapped terms: {unique_mapped}; "
+             f"avg terms per mapping: {avg_per_map:.2f}."
+         )
+
+     # ------------------------------------------------------------------
+     async def _run_once(
+         self,
+         df_proc: pd.DataFrame,
+         *,
+         column_name: str,
+         output_col: str,
+         raw_texts: Optional[Dict[str, str]] = None,
+         reset_files: bool,
+         run_idx: int,
+         **kwargs: Any,
+     ) -> None:
+         uniques, groups, orig_to_rep = self._deduplicate(df_proc[column_name])
+
+         use_embeddings = self.cfg.use_embeddings and len(uniques) >= self.cfg.group_size
+
+         batches: List[List[str]] = []
+         if use_embeddings:
+             embed_texts = uniques
+             if raw_texts is not None:
+                 embed_texts = [raw_texts.get(u, u) for u in uniques]
+             emb = await get_all_embeddings(
+                 texts=embed_texts,
+                 identifiers=uniques,
+                 save_path=os.path.join(self.cfg.save_dir, "deduplicate_embeddings.pkl"),
+                 reset_file=reset_files and run_idx == 0,
+                 use_dummy=self.cfg.use_dummy,
+                 verbose=False,
+             )
+             if emb:
+                 arr = np.array([emb[u] for u in uniques], dtype=float)
+                 k = max(1, int(np.ceil(len(uniques) / self.cfg.group_size)))
+                 _, labels = kmeans2(arr, k, minit="points")
+                 clusters: List[List[str]] = [[] for _ in range(k)]
+                 for term, lbl in zip(uniques, labels):
+                     clusters[int(lbl)].append(term)
+                 current: List[str] = []
+                 for cluster in clusters:
+                     for term in cluster:
+                         current.append(term)
+                         if len(current) >= self.cfg.group_size:
+                             batches.append(current)
+                             current = []
+                 if current:
+                     batches.append(current)
+         if not batches:
+             sorted_uniques = sorted(uniques, key=lambda x: x.lower())
+             for i in range(0, len(sorted_uniques), self.cfg.group_size):
+                 batches.append(sorted_uniques[i : i + self.cfg.group_size])
+
+         prompts: List[str] = []
+         identifiers: List[str] = []
+         announce_prompt_rendering("Deduplicate", len(batches) * max(1, self.cfg.n_runs))
+         for idx, items in enumerate(batches):
+             raw_terms: Any = items
+             if raw_texts is not None:
+                 raw_terms = {ident: raw_texts.get(ident, "") for ident in items}
+             prompts.append(
+                 self.template.render(
+                     group_id=f"deduplicate_{idx:05d}",
+                     raw_terms=raw_terms,
+                     additional_instructions=self.cfg.additional_instructions or "",
+                     modality=self.cfg.modality,
+                 )
+             )
+             identifiers.append(f"deduplicate_{idx:05d}")
+
+         base, ext = os.path.splitext(self.cfg.file_name)
+         if self.cfg.n_runs > 1:
+             response_file = f"{base}_run{run_idx + 1}{ext}"
+         else:
+             response_file = self.cfg.file_name
+         save_path = os.path.join(self.cfg.save_dir, response_file)
+         if prompts:
+             resp_df = await get_all_responses(
+                 prompts=prompts,
+                 identifiers=identifiers,
+                 n_parallels=self.cfg.n_parallels,
+                 model=self.cfg.model,
+                 save_path=save_path,
+                 use_dummy=self.cfg.use_dummy,
+                 max_timeout=self.cfg.max_timeout,
+                 json_mode=True,
+                 reset_files=reset_files,
+                 **kwargs,
+             )
+         else:
+             resp_df = pd.DataFrame(columns=["Identifier", "Response"])
+
+         resp_map = dict(zip(resp_df.get("Identifier", []), resp_df.get("Response", [])))
+         parsed = await asyncio.gather(
+             *[safest_json(resp_map.get(i, "")) for i in identifiers]
+         )
+
+         mappings: Dict[str, str] = {}
+         for items, res in zip(batches, parsed):
+             if isinstance(res, str):
+                 res = safe_json(res)
+             if isinstance(res, dict):
+                 for rep, vals in res.items():
+                     if isinstance(vals, list):
+                         for val in vals:
+                             if isinstance(val, str) and val in items:
+                                 mappings[val] = rep
+             elif isinstance(res, list):
+                 for row in res:
+                     if isinstance(row, str):
+                         row = safe_json(row)
+                     if isinstance(row, dict):
+                         inp = row.get("input")
+                         mapped = row.get("mapped")
+                         if (
+                             isinstance(inp, str)
+                             and isinstance(mapped, str)
+                             and inp in items
+                         ):
+                             mappings[inp] = mapped
+
+         for rep in uniques:
+             mappings.setdefault(rep, rep)
+
+         mapped_vals: List[Optional[str]] = []
+         for val in df_proc[column_name]:
+             if pd.isna(val):
+                 mapped_vals.append(val)
+             else:
+                 rep = orig_to_rep.get(str(val), str(val))
+                 mapped_vals.append(mappings.get(rep, rep))
+         df_proc[output_col] = mapped_vals
+
+     # ------------------------------------------------------------------
+
+     async def run(
+         self,
+         df: pd.DataFrame,
+         *,
+         column_name: str,
+         reset_files: bool = False,
+         nruns: Optional[int] = None,
+         **kwargs: Any,
+     ) -> pd.DataFrame:
+         """Deduplicate a column using LLM assistance."""
+
+         df_proc = df.reset_index(drop=True).copy()
+         n_runs = nruns if nruns is not None else self.cfg.n_runs
+         values = df_proc[column_name].tolist()
+         warn_if_modality_mismatch(values, self.cfg.modality, column_name=column_name)
+         current_col = column_name
+         raw_texts: Optional[Dict[str, str]] = None
+         id_to_original: Dict[str, Any] = {}
+
+         if self.cfg.modality == "text":
+             if self.cfg.group_size > 25:
+                 print(
+                     "[Deduplicate] For modality='text', consider a smaller group_size "
+                     "(e.g., 25) to keep prompts concise."
+                 )
+             ids: List[Optional[str]] = []
+             truncated: List[Optional[str]] = []
+             truncated_count = 0
+             raw_texts = {}
+             for value in values:
+                 if pd.isna(value):
+                     ids.append(None)
+                     truncated.append(None)
+                     continue
+                 text = str(value)
+                 sha8 = hashlib.sha1(text.encode()).hexdigest()[:8]
+                 ids.append(sha8)
+                 if sha8 not in id_to_original:
+                     id_to_original[sha8] = value
+                 words = text.split()
+                 if len(words) > self.cfg.max_words_per_text:
+                     truncated_count += 1
+                     clipped = " ".join(words[: self.cfg.max_words_per_text])
+                 else:
+                     clipped = text
+                 truncated.append(clipped)
+                 raw_texts[sha8] = clipped
+             df_proc["_dedup_id"] = ids
+             df_proc[f"{column_name}_truncated"] = truncated
+             total = sum(1 for v in values if not pd.isna(v))
+             frac = (truncated_count / total * 100) if total else 0.0
+             print(
+                 f"[Deduplicate] Truncated {truncated_count}/{total} texts "
+                 f"({frac:.2f}%) to max_words_per_text={self.cfg.max_words_per_text}."
+             )
+             current_col = "_dedup_id"
+         for i in range(n_runs):
+             if self.cfg.modality == "text":
+                 base = f"mapped_{column_name}_ids"
+                 if n_runs == 1:
+                     output_col = base
+                 elif i == n_runs - 1:
+                     output_col = f"{base}_final"
+                 else:
+                     output_col = f"{base}_run{i + 1}"
+             else:
+                 if n_runs == 1:
+                     output_col = f"mapped_{column_name}"
+                 elif i == n_runs - 1:
+                     output_col = f"mapped_{column_name}_final"
+                 else:
+                     output_col = f"mapped_{column_name}_run{i + 1}"
+             await self._run_once(
+                 df_proc,
+                 column_name=current_col,
+                 output_col=output_col,
+                 raw_texts=raw_texts,
+                 reset_files=reset_files,
+                 run_idx=i,
+                 **kwargs,
+             )
+             self._print_stats(
+                 df_proc[current_col],
+                 df_proc[output_col],
+                 run_idx=i,
+                 total_runs=n_runs,
+             )
+             current_col = output_col
+         if self.cfg.modality == "text":
+             if n_runs > 1:
+                 df_proc[f"mapped_{column_name}_ids"] = df_proc[current_col]
+             final_ids = (
+                 df_proc[f"mapped_{column_name}_ids"]
+                 if f"mapped_{column_name}_ids" in df_proc.columns
+                 else df_proc[current_col]
+             )
+             mapped_texts: List[Any] = []
+             for val in final_ids:
+                 if pd.isna(val):
+                     mapped_texts.append(val)
+                 else:
+                     mapped_texts.append(id_to_original.get(str(val), val))
+             df_proc[f"mapped_{column_name}"] = mapped_texts
+         elif n_runs > 1:
+             df_proc[f"mapped_{column_name}"] = df_proc[current_col]
+         return df_proc
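
For orientation, a minimal usage sketch of the deduplication task follows. It is not part of the wheel contents; it assumes Deduplicate and DeduplicateConfig are importable from gabriel.tasks.deduplicate as defined above, and that use_dummy=True substitutes placeholder responses for real OpenAI calls (so the mapping simply falls back to identity in this dry run).

    # Hypothetical dry run of the Deduplicate task (illustrative only).
    import asyncio

    import pandas as pd

    from gabriel.tasks.deduplicate import Deduplicate, DeduplicateConfig

    df = pd.DataFrame({"entity": ["OpenAI", "Open AI", "openai inc.", None]})
    cfg = DeduplicateConfig(save_dir="dedup_demo", n_runs=1, use_dummy=True)
    task = Deduplicate(cfg)

    # run() is async; with n_runs=1 and the default modality ("entity") the
    # result gains a single "mapped_entity" column.
    result = asyncio.run(task.run(df, column_name="entity"))
    print(result[["entity", "mapped_entity"]])
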
gabriel/tasks/deidentify.py
@@ -0,0 +1,316 @@
+ from __future__ import annotations
+
+ import ast
+ import json
+ import os
+ import re
+ from copy import deepcopy
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import pandas as pd
+
+ from ..core.prompt_template import PromptTemplate, resolve_template
+ from ..utils import safest_json
+ from ..utils.openai_utils import get_all_responses
+ from ..utils.logging import announce_prompt_rendering
+
+
+ # ────────────────────────────
+ # Configuration dataclass
+ # ────────────────────────────
+ @dataclass
+ class DeidentifyConfig:
+     """Configuration for :class:`Deidentifier`."""
+
+     model: str = "gpt-5-mini"
+     n_parallels: int = 650
+     save_dir: str = "deidentify"
+     file_name: str = "deidentified.csv"
+     use_dummy: bool = False
+     max_timeout: Optional[float] = None
+     max_words_per_call: int = 7500
+     additional_instructions: Optional[str] = None
+     reasoning_effort: Optional[str] = None
+     reasoning_summary: Optional[str] = None
+     n_passes: int = 1
+     use_existing_mappings_only: bool = False
+
+     def __post_init__(self) -> None:
+         if not isinstance(self.n_passes, int) or self.n_passes < 1:
+             raise ValueError("n_passes must be an integer >= 1")
+         if self.additional_instructions is not None:
+             cleaned = str(self.additional_instructions).strip()
+             self.additional_instructions = cleaned or None
+
+
+ # ────────────────────────────
+ # Main de-identification task
+ # ────────────────────────────
+ class Deidentifier:
+     """Iterative de-identification of sensitive entities in text."""
+
+     def __init__(
+         self,
+         cfg: DeidentifyConfig,
+         template: Optional[PromptTemplate] = None,
+         template_path: Optional[str] = None,
+     ) -> None:
+         expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
+         expanded.mkdir(parents=True, exist_ok=True)
+         cfg.save_dir = str(expanded)
+         self.cfg = cfg
+         self.template = resolve_template(
+             template=template,
+             template_path=template_path,
+             reference_filename="deidentification_prompt.jinja2",
+         )
+
+     # ------------------------------------------------------------------
+     # Helpers
+     # ------------------------------------------------------------------
+     @staticmethod
+     def _chunk_by_words(text: str, max_words: int) -> List[str]:
+         words = text.split()
+         if len(words) <= max_words:
+             return [text]
+         return [" ".join(words[i : i + max_words]) for i in range(0, len(words), max_words)]
+
+     @staticmethod
+     def _coerce_mapping(value: Any) -> Optional[Dict[str, Any]]:
+         """Attempt to convert ``value`` into a mapping dictionary."""
+
+         if value is None:
+             return None
+         try:
+             if pd.isna(value):  # type: ignore[arg-type]
+                 return None
+         except Exception:
+             pass
+
+         if isinstance(value, dict):
+             return deepcopy(value)
+
+         if isinstance(value, str):
+             cleaned = value.strip()
+             if not cleaned or cleaned.lower() in {"nan", "none"}:
+                 return None
+             try:
+                 parsed = json.loads(cleaned)
+             except Exception:
+                 try:
+                     parsed = ast.literal_eval(cleaned)
+                 except Exception:
+                     return None
+             if isinstance(parsed, dict):
+                 return deepcopy(parsed)
+             return None
+
+         return None
+
+     # ------------------------------------------------------------------
+     # Main entry point
+     # ------------------------------------------------------------------
+     async def run(
+         self,
+         df: pd.DataFrame,
+         column_name: str,
+         *,
+         grouping_column: Optional[str] = None,
+         mapping_column: Optional[str] = None,
+         reset_files: bool = False,
+         **kwargs: Any,
+     ) -> pd.DataFrame:
+         """Deidentify all texts in ``df[column_name]``.
+
+         Parameters
+         ----------
+         df:
+             Input DataFrame.
+         column_name:
+             Name of the column containing the text to de-identify.
+         grouping_column:
+             Optional column whose values determine which rows belong to the same
+             individual/entity. When omitted, each row is treated independently.
+         mapping_column:
+             Optional column containing pre-existing mapping dictionaries. The
+             first non-empty mapping encountered for each group is used as the
+             warm start and is also the mapping reused when
+             ``use_existing_mappings_only`` is ``True``.
+         reset_files:
+             When ``True``, intermediate CSV logs from :func:`get_all_responses`
+             are regenerated.
+         **kwargs:
+             Additional keyword arguments forwarded to
+             :func:`gabriel.utils.openai_utils.get_all_responses`.
+         """
+
+         if column_name not in df.columns:
+             raise ValueError(f"Column '{column_name}' not found in DataFrame")
+         if grouping_column is not None and grouping_column not in df.columns:
+             raise ValueError(
+                 f"Grouping column '{grouping_column}' not found in DataFrame"
+             )
+         if mapping_column is not None and mapping_column not in df.columns:
+             raise ValueError(
+                 f"Mapping column '{mapping_column}' not found in DataFrame"
+             )
+
+         df_proc = df.reset_index(drop=True).copy()
+
+         if grouping_column is None:
+             df_proc["group_id"] = df_proc.index.astype(str)
+         else:
+             df_proc["group_id"] = df_proc[grouping_column].astype(str)
+
+         group_ids = df_proc["group_id"].unique().tolist()
+         base_name = Path(self.cfg.file_name).stem
+         csv_path = Path(self.cfg.save_dir) / f"{base_name}_cleaned.csv"
+         raw_prefix = Path(self.cfg.save_dir) / f"{base_name}_raw_responses"
+
+         group_segments: Dict[str, List[str]] = {}
+         for gid in group_ids:
+             segs: List[str] = []
+             texts = (
+                 df_proc.loc[df_proc["group_id"] == gid, column_name]
+                 .fillna("")
+                 .astype(str)
+                 .tolist()
+             )
+             for text in texts:
+                 segs.extend(self._chunk_by_words(text, self.cfg.max_words_per_call))
+             group_segments[gid] = segs
+
+         group_to_map: Dict[str, dict] = {gid: {} for gid in group_ids}
+
+         warm_start_count = 0
+         if mapping_column is not None:
+             for gid, subset in df_proc.groupby("group_id"):
+                 values = subset[mapping_column].tolist()
+                 mapping = next(
+                     (m for m in (self._coerce_mapping(v) for v in values) if m is not None),
+                     None,
+                 )
+                 if mapping is not None:
+                     group_to_map[str(gid)] = mapping
+                     warm_start_count += 1
+             if warm_start_count:
+                 print(
+                     f"[Deidentify] Loaded {warm_start_count} warm-start mapping(s) "
+                     f"from column '{mapping_column}'."
+                 )
+             else:
+                 print(
+                     f"[Deidentify] Column '{mapping_column}' provided but no usable "
+                     "mappings were found."
+                 )
+
+         print(
+             "[Deidentify] Tip: edit the first mapping for each group and rerun with "
+             "use_existing_mappings_only=True to apply your changes without new LLM calls."
+         )
+
+         max_rounds = max(len(segs) for segs in group_segments.values()) if group_segments else 0
+
+         if self.cfg.use_existing_mappings_only:
+             print(
+                 "[Deidentify] use_existing_mappings_only=True -> skipping LLM calls and "
+                 "reusing provided mappings."
+             )
+             missing = [gid for gid, mapping in group_to_map.items() if not mapping]
+             if missing:
+                 print(
+                     "[Deidentify] Warning: no mapping provided for "
+                     f"{len(missing)} group(s). Their text will be returned unchanged."
+                 )
+         else:
+             for pass_idx in range(self.cfg.n_passes):
+                 if self.cfg.n_passes > 1:
+                     print(
+                         f"[Deidentify] Starting pass {pass_idx + 1}/{self.cfg.n_passes}."
+                     )
+
+                 for rnd in range(max_rounds):
+                     prompts: List[str] = []
+                     identifiers: List[str] = []
+                     id_to_gid: Dict[str, str] = {}
+
+                     for gid in group_ids:
+                         segs = group_segments[gid]
+                         if rnd >= len(segs):
+                             continue
+
+                         ident = f"{gid}_pass{pass_idx}_seg_{rnd}"
+                         identifiers.append(ident)
+                         id_to_gid[ident] = gid
+                         prompts.append(
+                             self.template.render(
+                                 text=segs[rnd],
+                                 current_map=json.dumps(
+                                     group_to_map.get(gid, {}), ensure_ascii=False
+                                 ),
+                                 additional_instructions=self.cfg.additional_instructions,
+                             )
+                         )
+
+                     if not prompts:
+                         continue
+
+                     announce_prompt_rendering("Deidentify", len(prompts))
+
+                     save_path = raw_prefix.with_name(
+                         f"{raw_prefix.name}_pass{pass_idx}_round{rnd}.csv"
+                     )
+                     batch_df = await get_all_responses(
+                         prompts=prompts,
+                         identifiers=identifiers,
+                         n_parallels=self.cfg.n_parallels,
+                         model=self.cfg.model,
+                         save_path=str(save_path),
+                         use_dummy=self.cfg.use_dummy,
+                         max_timeout=self.cfg.max_timeout,
+                         json_mode=True,
+                         reasoning_effort=self.cfg.reasoning_effort,
+                         reasoning_summary=self.cfg.reasoning_summary,
+                         reset_files=reset_files,
+                         **kwargs,
+                     )
+
+                     for ident, resp in zip(batch_df["Identifier"], batch_df["Response"]):
+                         gid = id_to_gid.get(ident)
+                         if gid is None:
+                             continue
+                         main = resp[0] if isinstance(resp, list) and resp else ""
+                         parsed = await safest_json(main)
+                         if isinstance(parsed, dict):
+                             group_to_map[gid] = parsed
+
+         mappings_col: List[dict] = []
+         deidentified_texts: List[str] = []
+         for _, row in df_proc.iterrows():
+             gid = row["group_id"]
+             mapping = group_to_map.get(gid, {})
+             mappings_col.append(mapping)
+             text = str(row[column_name])
+             deid_text = text
+             pairs: List[Tuple[str, str]] = []
+             for entry in mapping.values():
+                 if isinstance(entry, dict):
+                     casted = entry.get("casted form", "")
+                     for real in entry.get("real forms", []) or []:
+                         if casted and real:
+                             pairs.append((real, casted))
+             pairs.sort(key=lambda x: len(x[0]), reverse=True)
+             for real, casted in pairs:
+                 escaped = re.escape(real)
+                 pattern = re.compile(
+                     rf"(?<!\w){escaped}(?!\w)", flags=re.IGNORECASE
+                 )
+                 deid_text = pattern.sub(casted, deid_text)
+             deidentified_texts.append(deid_text)
+
+         df_proc["mapping"] = mappings_col
+         df_proc["deidentified_text"] = deidentified_texts
+         df_proc.to_csv(csv_path, index=False)
+         return df_proc
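
Analogously, a minimal sketch of the de-identification task, again not part of the wheel. The import path, the asyncio.run call, and the assumption that use_dummy=True skips real model calls (leaving the mapping empty and the text unchanged in this dry run) are inferred from the code above rather than from package documentation.

    # Hypothetical dry run of the Deidentifier task (illustrative only).
    import asyncio

    import pandas as pd

    from gabriel.tasks.deidentify import Deidentifier, DeidentifyConfig

    df = pd.DataFrame(
        {
            "note": ["Jane Doe saw Dr. Smith on 3 May.", "Jane Doe called back."],
            "patient": ["p1", "p1"],
        }
    )
    cfg = DeidentifyConfig(save_dir="deid_demo", use_dummy=True)
    task = Deidentifier(cfg)

    # Rows sharing a grouping_column value share one entity mapping; the result
    # gains "mapping" and "deidentified_text" columns and is also written to
    # <save_dir>/deidentified_cleaned.csv.
    result = asyncio.run(task.run(df, "note", grouping_column="patient"))
    print(result[["note", "deidentified_text"]])
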