openai-gabriel 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. gabriel/__init__.py +61 -0
  2. gabriel/_version.py +1 -0
  3. gabriel/api.py +2284 -0
  4. gabriel/cli/__main__.py +60 -0
  5. gabriel/core/__init__.py +7 -0
  6. gabriel/core/llm_client.py +34 -0
  7. gabriel/core/pipeline.py +18 -0
  8. gabriel/core/prompt_template.py +152 -0
  9. gabriel/prompts/__init__.py +1 -0
  10. gabriel/prompts/bucket_prompt.jinja2 +113 -0
  11. gabriel/prompts/classification_prompt.jinja2 +50 -0
  12. gabriel/prompts/codify_prompt.jinja2 +95 -0
  13. gabriel/prompts/comparison_prompt.jinja2 +60 -0
  14. gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
  15. gabriel/prompts/deidentification_prompt.jinja2 +112 -0
  16. gabriel/prompts/extraction_prompt.jinja2 +61 -0
  17. gabriel/prompts/filter_prompt.jinja2 +31 -0
  18. gabriel/prompts/ideation_prompt.jinja2 +80 -0
  19. gabriel/prompts/merge_prompt.jinja2 +47 -0
  20. gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
  21. gabriel/prompts/rankings_prompt.jinja2 +49 -0
  22. gabriel/prompts/ratings_prompt.jinja2 +50 -0
  23. gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
  24. gabriel/prompts/seed.jinja2 +43 -0
  25. gabriel/prompts/snippets.jinja2 +117 -0
  26. gabriel/tasks/__init__.py +63 -0
  27. gabriel/tasks/_attribute_utils.py +69 -0
  28. gabriel/tasks/bucket.py +432 -0
  29. gabriel/tasks/classify.py +562 -0
  30. gabriel/tasks/codify.py +1033 -0
  31. gabriel/tasks/compare.py +235 -0
  32. gabriel/tasks/debias.py +1460 -0
  33. gabriel/tasks/deduplicate.py +341 -0
  34. gabriel/tasks/deidentify.py +316 -0
  35. gabriel/tasks/discover.py +524 -0
  36. gabriel/tasks/extract.py +455 -0
  37. gabriel/tasks/filter.py +169 -0
  38. gabriel/tasks/ideate.py +782 -0
  39. gabriel/tasks/merge.py +464 -0
  40. gabriel/tasks/paraphrase.py +531 -0
  41. gabriel/tasks/rank.py +2041 -0
  42. gabriel/tasks/rate.py +347 -0
  43. gabriel/tasks/seed.py +465 -0
  44. gabriel/tasks/whatever.py +344 -0
  45. gabriel/utils/__init__.py +64 -0
  46. gabriel/utils/audio_utils.py +42 -0
  47. gabriel/utils/file_utils.py +464 -0
  48. gabriel/utils/image_utils.py +22 -0
  49. gabriel/utils/jinja.py +31 -0
  50. gabriel/utils/logging.py +86 -0
  51. gabriel/utils/mapmaker.py +304 -0
  52. gabriel/utils/media_utils.py +78 -0
  53. gabriel/utils/modality_utils.py +148 -0
  54. gabriel/utils/openai_utils.py +5470 -0
  55. gabriel/utils/parsing.py +282 -0
  56. gabriel/utils/passage_viewer.py +2557 -0
  57. gabriel/utils/pdf_utils.py +20 -0
  58. gabriel/utils/plot_utils.py +2881 -0
  59. gabriel/utils/prompt_utils.py +42 -0
  60. gabriel/utils/word_matching.py +158 -0
  61. openai_gabriel-1.0.1.dist-info/METADATA +443 -0
  62. openai_gabriel-1.0.1.dist-info/RECORD +67 -0
  63. openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
  64. openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
  65. openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
  66. openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
  67. openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,562 @@
+ from __future__ import annotations
+
+ import hashlib
+ import os
+ from pathlib import Path
+ import re
+ import random
+ from collections import defaultdict
+ from dataclasses import dataclass
+ from typing import Any, DefaultDict, Dict, List, Optional, Set
+ import json
+
+ import pandas as pd
+
+ from ..core.prompt_template import PromptTemplate, resolve_template
+ from ..utils.openai_utils import get_all_responses
+ from ..utils import (
+     safest_json,
+     load_image_inputs,
+     load_audio_inputs,
+     load_pdf_inputs,
+     warn_if_modality_mismatch,
+ )
+ from ..utils.logging import announce_prompt_rendering
+ from ._attribute_utils import load_persisted_attributes
+
+
+ def _collect_predictions(row: pd.Series) -> List[str]:
+     """Return labels whose values evaluate to ``True``.
+
+     Parameters
+     ----------
+     row:
+         A series containing only label columns.
+
+     Returns
+     -------
+     list of str
+         Labels for which the value is truthy.
+     """
+
+     return [lab for lab, val in row.items() if bool(val)]
+
+
+ # ────────────────────────────
+ # Configuration dataclass
+ # ────────────────────────────
+ @dataclass
+ class ClassifyConfig:
+     """Configuration for :class:`Classify`."""
+
+     labels: Dict[str, str]  # {"label_name": "description", ...}
+     save_dir: str = "classifier"
+     file_name: str = "classify_responses.csv"
+     model: str = "gpt-5-mini"
+     n_parallels: int = 650
+     n_runs: int = 1
+     min_frequency: float = 0.6
+     additional_instructions: Optional[str] = None
+     use_dummy: bool = False
+     max_timeout: Optional[float] = None
+     modality: str = "text"
+     n_attributes_per_run: int = 8
+     reasoning_effort: Optional[str] = None
+     reasoning_summary: Optional[str] = None
+     differentiate: bool = False
+     circle_first: Optional[bool] = None
+     search_context_size: str = "medium"
+
+     def __post_init__(self) -> None:
+         if self.additional_instructions is not None:
+             cleaned = str(self.additional_instructions).strip()
+             self.additional_instructions = cleaned or None
+
+
+ # ────────────────────────────
+ # Main classifier task
+ # ────────────────────────────
+ class Classify:
+     """Robust passage classifier using an LLM.
+
+     * Accepts a DataFrame and the name of the column(s) to classify (same
+       interface as :class:`Rate`).
+     * Persists/reads cached responses via the **save_path** attribute (same
+       pattern as :class:`Rate`).
+     """
+
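+     # Matches a Markdown code fence (``` or ```json) and captures its body,
+     # so fenced JSON replies can be unwrapped before parsing.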
+     _FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.S)
+
+     # -----------------------------------------------------------------
+     def __init__(
+         self,
+         cfg: ClassifyConfig,
+         template: Optional[PromptTemplate] = None,
+         template_path: Optional[str] = None,
+     ) -> None:  # noqa: D401,E501
+         expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
+         expanded.mkdir(parents=True, exist_ok=True)
+         cfg.save_dir = str(expanded)
+         self.cfg = cfg
+         self.template = resolve_template(
+             template=template,
+             template_path=template_path,
+             reference_filename="classification_prompt.jinja2",
+         )
+
+     # -----------------------------------------------------------------
+     # Helpers for parsing raw model output
+     # -----------------------------------------------------------------
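+     # Regex fallback: scans the raw text for `"label": true|false` pairs when
+     # strict JSON parsing fails.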
+     @staticmethod
+     def _regex(raw: str, labels: List[str]) -> Dict[str, Optional[bool]]:
+         out: Dict[str, Optional[bool]] = {}
+         for lab in labels:
+             pat = re.compile(
+                 rf'\s*"?\s*{re.escape(lab)}\s*"?\s*:\s*(true|false)', re.I | re.S
+             )
+             m = pat.search(raw)
+             out[lab] = None if not m else m.group(1).lower() == "true"
+         return out
+
+     async def _parse(self, resp: Any, labels: List[str]) -> Dict[str, Optional[bool]]:
+         # unwrap common response containers (list-of-one, bytes, fenced blocks)
+         if isinstance(resp, list) and len(resp) == 1:
+             resp = resp[0]
+         if isinstance(resp, (bytes, bytearray)):
+             resp = resp.decode()
+         data: Optional[Any] = None
+         if isinstance(resp, str):
+             m = self._FENCE_RE.search(resp)
+             if m:
+                 resp = m.group(1).strip()
+
+             data = await safest_json(resp)
+         elif isinstance(resp, dict):
+             data = resp
+         if isinstance(data, dict):
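+             # Normalize keys to lowercase and coerce common truthy/falsey
+             # strings ("true"/"yes"/"1", "false"/"no"/"0") to booleans;
+             # anything else maps to None (undecided).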
+             norm = {
+                 k.strip().lower(): (
+                     True
+                     if str(v).strip().lower() in {"true", "yes", "1"}
+                     else (
+                         False
+                         if str(v).strip().lower() in {"false", "no", "0"}
+                         else None
+                     )
+                 )
+                 for k, v in data.items()
+             }
+             return {lab: norm.get(lab.lower(), None) for lab in labels}
+
+         # fallback to regex extraction
+         return self._regex(str(resp), labels)
+
+     # -----------------------------------------------------------------
+     # Main entry point
+     # -----------------------------------------------------------------
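+     # run() pipeline: dedupe inputs -> chunk labels into batches -> render one
+     # prompt per (item, batch) -> query the model (optionally over several
+     # runs) -> parse booleans per label -> aggregate runs -> merge back onto df.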
+     async def run(
+         self,
+         df: pd.DataFrame,
+         column_name: Optional[str] = None,
+         *,
+         circle_column_name: Optional[str] = None,
+         square_column_name: Optional[str] = None,
+         reset_files: bool = False,
+         **kwargs: Any,
+     ) -> pd.DataFrame:
+         """Classify items and return ``df`` with label columns."""
+
+         if self.cfg.differentiate:
+             if circle_column_name is None or square_column_name is None:
+                 raise ValueError(
+                     "circle_column_name and square_column_name are required when differentiate is True"
+                 )
+         elif column_name is None:
+             raise ValueError("column_name is required when differentiate is False")
+
+         df_proc = df.reset_index(drop=True).copy()
+         base_name = os.path.splitext(self.cfg.file_name)[0]
+
+         self.cfg.labels = load_persisted_attributes(
+             save_dir=self.cfg.save_dir,
+             incoming=self.cfg.labels,
+             reset_files=reset_files,
+             task_name="Classify",
+             item_name="labels",
+             legacy_filename=f"{base_name}_attrs.json",
+         )
+
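+         # Chunk the labels into groups of n_attributes_per_run; each chunk is
+         # asked about in its own prompt so a single request never carries too
+         # many attributes.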
+         label_items = list(self.cfg.labels.items())
+         label_count = len(label_items)
+         if label_count > self.cfg.n_attributes_per_run:
+             batches = (
+                 label_count + self.cfg.n_attributes_per_run - 1
+             ) // self.cfg.n_attributes_per_run
+             print(
+                 f"[Classify] {label_count} labels provided. n_attributes_per_run={self.cfg.n_attributes_per_run}. "
+                 f"Splitting into {batches} prompt batches. Increase n_attributes_per_run if you want all attributes "
+                 "to be processed in the same prompt."
+             )
+         label_batches: List[Dict[str, str]] = [
+             dict(label_items[i : i + self.cfg.n_attributes_per_run])
+             for i in range(0, len(label_items), self.cfg.n_attributes_per_run)
+         ]
+
+         prompts: List[str] = []
+         ids: List[str] = []
+         base_ids: List[str] = []
+         id_to_circle_first: Dict[str, bool] = {}
+         id_to_rows: DefaultDict[str, List[int]] = defaultdict(list)
+         id_to_val: Dict[str, Any] = {}
+         prompt_texts: Dict[str, str] = {}
+         prompt_circles: Dict[str, str] = {}
+         prompt_squares: Dict[str, str] = {}
+
+         if self.cfg.differentiate:
+             circles = df_proc[circle_column_name].tolist()  # type: ignore[index]
+             squares = df_proc[square_column_name].tolist()  # type: ignore[index]
+             warn_if_modality_mismatch(
+                 circles, self.cfg.modality, column_name=str(circle_column_name)
+             )
+             warn_if_modality_mismatch(
+                 squares, self.cfg.modality, column_name=str(square_column_name)
+             )
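+             # Deduplicate: identical (circle, square) pairs share one prompt,
+             # keyed by the first 8 hex chars of a SHA-1 over the
+             # whitespace-normalized pair.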
+             for row, (circ, sq) in enumerate(zip(circles, squares)):
+                 clean = " ".join(str(circ).split()) + "|" + " ".join(str(sq).split())
+                 sha8 = hashlib.sha1(clean.encode()).hexdigest()[:8]
+                 id_to_rows[sha8].append(row)
+                 if len(id_to_rows[sha8]) > 1:
+                     continue
+                 id_to_val[sha8] = (circ, sq)
+                 prompt_circles[sha8] = (
+                     circ if self.cfg.modality in {"text", "entity", "web"} else ""
+                 )
+                 prompt_squares[sha8] = (
+                     sq if self.cfg.modality in {"text", "entity", "web"} else ""
+                 )
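+                 # Decide presentation order for this pair: honour
+                 # cfg.circle_first when set, otherwise randomize per pair.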
+                 circle_first_flag = (
+                     self.cfg.circle_first
+                     if self.cfg.circle_first is not None
+                     else random.random() < 0.5
+                 )
+                 id_to_circle_first[sha8] = circle_first_flag
+                 base_ids.append(sha8)
+             announce_prompt_rendering(
+                 "Classify",
+                 len(base_ids) * len(label_batches),
+             )
+             for batch_idx, batch_labels in enumerate(label_batches):
+                 for ident in base_ids:
+                     prompts.append(
+                         self.template.render(
+                             entry_circle=prompt_circles[ident],
+                             entry_square=prompt_squares[ident],
+                             attributes=batch_labels,
+                             additional_instructions=self.cfg.additional_instructions,
+                             differentiate=True,
+                             modality=self.cfg.modality,
+                             circle_first=id_to_circle_first[ident],
+                         )
+                     )
+                     ids.append(f"{ident}_batch{batch_idx}")
+         else:
+             values = df_proc[column_name].tolist()  # type: ignore[index]
+             warn_if_modality_mismatch(values, self.cfg.modality, column_name=str(column_name))
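+             # Same dedup scheme as the differentiate branch: one prompt per
+             # unique whitespace-normalized value.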
+             for row, val in enumerate(values):
+                 clean = " ".join(str(val).split())
+                 sha8 = hashlib.sha1(clean.encode()).hexdigest()[:8]
+                 id_to_rows[sha8].append(row)
+                 if len(id_to_rows[sha8]) > 1:
+                     continue
+                 id_to_val[sha8] = values[row]
+                 prompt_texts[sha8] = (
+                     str(values[row])
+                     if self.cfg.modality in {"text", "entity", "web"}
+                     else ""
+                 )
+                 base_ids.append(sha8)
+             announce_prompt_rendering(
+                 "Classify",
+                 len(base_ids) * len(label_batches),
+             )
+             for batch_idx, batch_labels in enumerate(label_batches):
+                 for ident in base_ids:
+                     prompts.append(
+                         self.template.render(
+                             text=prompt_texts[ident],
+                             attributes=batch_labels,
+                             additional_instructions=self.cfg.additional_instructions,
+                             modality=self.cfg.modality,
+                         )
+                     )
+                     ids.append(f"{ident}_batch{batch_idx}")
+
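+         # For non-text modalities, load the media once per unique item and
+         # attach it under every "<id>_batch<n>" identifier so each label batch
+         # reuses the same files.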
+         prompt_images: Optional[Dict[str, List[str]]] = None
+         prompt_audio: Optional[Dict[str, List[Dict[str, str]]]] = None
+         prompt_pdfs: Optional[Dict[str, List[Dict[str, str]]]] = None
+
+         if self.cfg.modality == "image":
+             tmp: Dict[str, List[str]] = {}
+             for ident, rows in id_to_rows.items():
+                 imgs: List[str] = []
+                 if self.cfg.differentiate:
+                     circ, sq = id_to_val[ident]
+                     circ_imgs = load_image_inputs(circ)
+                     sq_imgs = load_image_inputs(sq)
+                     if id_to_circle_first.get(ident, False):
+                         if circ_imgs:
+                             imgs.extend(circ_imgs)
+                         if sq_imgs:
+                             imgs.extend(sq_imgs)
+                     else:
+                         if sq_imgs:
+                             imgs.extend(sq_imgs)
+                         if circ_imgs:
+                             imgs.extend(circ_imgs)
+                 else:
+                     imgs = load_image_inputs(id_to_val[ident])
+                 if imgs:
+                     for batch_idx in range(len(label_batches)):
+                         tmp[f"{ident}_batch{batch_idx}"] = imgs
+             prompt_images = tmp or None
+         elif self.cfg.modality == "audio":
+             tmp_a: Dict[str, List[Dict[str, str]]] = {}
+             for ident, rows in id_to_rows.items():
+                 auds: List[Dict[str, str]] = []
+                 if self.cfg.differentiate:
+                     circ, sq = id_to_val[ident]
+                     circ_auds = load_audio_inputs(circ)
+                     sq_auds = load_audio_inputs(sq)
+                     if id_to_circle_first.get(ident, False):
+                         if circ_auds:
+                             auds.extend(circ_auds)
+                         if sq_auds:
+                             auds.extend(sq_auds)
+                     else:
+                         if sq_auds:
+                             auds.extend(sq_auds)
+                         if circ_auds:
+                             auds.extend(circ_auds)
+                 else:
+                     auds = load_audio_inputs(id_to_val[ident])
+                 if auds:
+                     for batch_idx in range(len(label_batches)):
+                         tmp_a[f"{ident}_batch{batch_idx}"] = auds
+             prompt_audio = tmp_a or None
+         elif self.cfg.modality == "pdf":
+             tmp_p: Dict[str, List[Dict[str, str]]] = {}
+             for ident, rows in id_to_rows.items():
+                 pdfs: List[Dict[str, str]] = []
+                 if self.cfg.differentiate:
+                     circ, sq = id_to_val[ident]
+                     circ_pdfs = load_pdf_inputs(circ)
+                     sq_pdfs = load_pdf_inputs(sq)
+                     if id_to_circle_first.get(ident, False):
+                         if circ_pdfs:
+                             pdfs.extend(circ_pdfs)
+                         if sq_pdfs:
+                             pdfs.extend(sq_pdfs)
+                     else:
+                         if sq_pdfs:
+                             pdfs.extend(sq_pdfs)
+                         if circ_pdfs:
+                             pdfs.extend(circ_pdfs)
+                 else:
+                     pdfs = load_pdf_inputs(id_to_val[ident])
+                 if pdfs:
+                     for batch_idx in range(len(label_batches)):
+                         tmp_p[f"{ident}_batch{batch_idx}"] = pdfs
+             prompt_pdfs = tmp_p or None
+
+         csv_path = os.path.join(self.cfg.save_dir, f"{base_name}_raw_responses.csv")
+
+         kwargs.setdefault("web_search", self.cfg.modality == "web")
+         kwargs.setdefault("search_context_size", self.cfg.search_context_size)
+
+         if not isinstance(self.cfg.n_runs, int) or self.cfg.n_runs < 1:
+             raise ValueError("n_runs must be an integer >= 1")
+
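+         # Build the identifier list for each run. Run 1 reuses a bare id unless
+         # a legacy "<id>_run1" entry already exists in the response cache;
+         # every later run gets an explicit "_run<k>" suffix.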
+         existing_ids: Set[str] = set()
+         if not reset_files and os.path.exists(csv_path):
+             try:
+                 existing_df = pd.read_csv(csv_path, usecols=["Identifier"])
+                 existing_ids = set(existing_df["Identifier"].astype(str))
+             except Exception:
+                 existing_ids = set()
+
+         run_identifier_lists: List[List[str]] = []
+         for run_idx in range(1, self.cfg.n_runs + 1):
+             run_ids: List[str] = []
+             for ident in ids:
+                 if run_idx == 1:
+                     legacy_ident = f"{ident}_run1"
+                     run_ids.append(legacy_ident if legacy_ident in existing_ids else ident)
+                 else:
+                     run_ids.append(f"{ident}_run{run_idx}")
+             run_identifier_lists.append(run_ids)
+
+         prompts_all: List[str] = []
+         ids_all: List[str] = []
+         for run_ids in run_identifier_lists:
+             prompts_all.extend(prompts)
+             ids_all.extend(run_ids)
+
+         prompt_images_all: Optional[Dict[str, List[str]]] = None
+         if prompt_images:
+             prompt_images_all = {}
+             for run_ids in run_identifier_lists:
+                 for base_ident, run_ident in zip(ids, run_ids):
+                     imgs = prompt_images.get(base_ident)
+                     if imgs:
+                         prompt_images_all[run_ident] = imgs
+         prompt_audio_all: Optional[Dict[str, List[Dict[str, str]]]] = None
+         if prompt_audio:
+             prompt_audio_all = {}
+             for run_ids in run_identifier_lists:
+                 for base_ident, run_ident in zip(ids, run_ids):
+                     auds = prompt_audio.get(base_ident)
+                     if auds:
+                         prompt_audio_all[run_ident] = auds
+         prompt_pdfs_all: Optional[Dict[str, List[Dict[str, str]]]] = None
+         if prompt_pdfs:
+             prompt_pdfs_all = {}
+             for run_ids in run_identifier_lists:
+                 for base_ident, run_ident in zip(ids, run_ids):
+                     pdfs = prompt_pdfs.get(base_ident)
+                     if pdfs:
+                         prompt_pdfs_all[run_ident] = pdfs
+
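+         # Send every prompt (runs x label batches x unique items) through the
+         # shared async client. Responses are written to csv_path (save_path);
+         # judging by the legacy-id handling above, existing entries appear to
+         # be reused across calls unless reset_files is set.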
+         df_resp_all = await get_all_responses(
+             prompts=prompts_all,
+             identifiers=ids_all,
+             prompt_images=prompt_images_all,
+             prompt_audio=prompt_audio_all,
+             prompt_pdfs=prompt_pdfs_all,
+             n_parallels=self.cfg.n_parallels,
+             save_path=csv_path,
+             reset_files=reset_files,
+             json_mode=self.cfg.modality != "audio",
+             model=self.cfg.model,
+             use_dummy=self.cfg.use_dummy,
+             max_timeout=self.cfg.max_timeout,
+             reasoning_effort=self.cfg.reasoning_effort,
+             reasoning_summary=self.cfg.reasoning_summary,
+             print_example_prompt=True,
+             **kwargs,
+         )
+         if not isinstance(df_resp_all, pd.DataFrame):
+             raise RuntimeError("get_all_responses returned no DataFrame")
+
+         df_resps = []
+         for run_idx, run_ids in enumerate(run_identifier_lists, start=1):
+             suffix = f"_run{run_idx}"
+             sub = df_resp_all[df_resp_all.Identifier.isin(run_ids)].copy()
+             sub.Identifier = sub.Identifier.str.replace(
+                 suffix + "$", "", regex=True
+             )
+             df_resps.append(sub)
+
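+         # Re-assemble one label dict per unique item and run: each
+         # "<id>_batch<n>" response fills in the labels of batch n; labels that
+         # never got a usable answer stay None.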
+         # parse each run and construct disaggregated records
+         full_records: List[Dict[str, Any]] = []
+         total_orphans = 0
+         all_labels = list(self.cfg.labels.keys())
+         for run_idx, df_resp in enumerate(df_resps, start=1):
+             id_to_labels: Dict[str, Dict[str, Optional[bool]]] = {
+                 ident: {lab: None for lab in all_labels} for ident in base_ids
+             }
+             orphans = 0
+             for ident_batch, raw in zip(df_resp.Identifier, df_resp.Response):
+                 if "_batch" not in ident_batch:
+                     continue
+                 base_ident, batch_part = ident_batch.rsplit("_batch", 1)
+                 if base_ident not in id_to_rows:
+                     orphans += 1
+                     continue
+                 batch_idx = int(batch_part)
+                 labs = list(label_batches[batch_idx].keys())
+                 parsed = await self._parse(raw, labs)
+                 for lab in labs:
+                     id_to_labels[base_ident][lab] = parsed.get(lab)
+             total_orphans += orphans
+             for ident in base_ids:
+                 parsed = id_to_labels.get(ident, {lab: None for lab in all_labels})
+                 if self.cfg.differentiate:
+                     circ_val, sq_val = id_to_val[ident]
+                     rec = {"circle": circ_val, "square": sq_val, "run": run_idx}
+                 else:
+                     rec = {"text": id_to_val[ident], "run": run_idx}
+                 rec.update({lab: parsed.get(lab) for lab in all_labels})
+                 full_records.append(rec)
+
+         if total_orphans:
+             print(
+                 f"[Classify] WARNING: {total_orphans} response(s) had no matching passage this run."
+             )
+
+         if self.cfg.differentiate:
+             full_df = pd.DataFrame(full_records).set_index(["circle", "square", "run"])
+             index_cols = ["circle", "square", "run"]
+             group_cols = ["circle", "square"]
+         else:
+             full_df = pd.DataFrame(full_records).set_index(["text", "run"])
+             index_cols = ["text", "run"]
+             group_cols = ["text"]
+         if self.cfg.n_runs > 1:
+             disagg_path = os.path.join(
+                 self.cfg.save_dir, f"{base_name}_full_disaggregated.csv"
+             )
+             full_df.to_csv(disagg_path, index_label=index_cols)
+
+         # aggregate across runs using a minimum frequency threshold
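+         # e.g. with n_runs=5 and min_frequency=0.6 a label must come back True
+         # in at least 3 runs. Note the denominator is n_runs, so runs with a
+         # missing answer effectively count as False.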
+         def _min_freq(s: pd.Series) -> Optional[bool]:
+             if s.notna().sum() == 0:
+                 return None
+             true_count = s.fillna(False).infer_objects(copy=False).sum()
+             prop = true_count / self.cfg.n_runs
+             return prop >= self.cfg.min_frequency
+
+         agg_df = pd.DataFrame(
+             {
+                 lab: full_df[lab].groupby(group_cols).apply(_min_freq)
+                 for lab in self.cfg.labels
+             }
+         )
+
+         filled = agg_df.dropna(how="all").shape[0]
+         print(f"[Classify] Filled {filled}/{len(agg_df)} unique texts.")
+
+         total = len(agg_df)
+         print("\n=== Label coverage (non-null) ===")
+         for lab in self.cfg.labels:
+             n = agg_df[lab].notna().sum()
+             print(f"{lab:<55s}: {n / total:6.2%} ({n}/{total})")
+         print("=================================\n")
+
+         out_path = os.path.join(self.cfg.save_dir, f"{base_name}_cleaned.csv")
+         if self.cfg.differentiate:
+             result = df_proc.merge(
+                 agg_df,
+                 left_on=[circle_column_name, square_column_name],
+                 right_index=True,
+                 how="left",
+             )
+         else:
+             result = df_proc.merge(
+                 agg_df, left_on=column_name, right_index=True, how="left"
+             )
+
+         label_cols = list(self.cfg.labels.keys())
+
+         if not self.cfg.differentiate and column_name in result.columns:
+             cols = result.columns.tolist()
+             cols.remove(column_name)
+             cols.insert(0, column_name)
+             result = result[cols]
+
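+         # "predicted_classes" lists every label whose aggregated value is
+         # truthy for the row, e.g. ["label_a", "label_c"].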
+         result.insert(1, "predicted_classes", result[label_cols].apply(_collect_predictions, axis=1))
+
+         result_to_save = result.copy()
+         result_to_save["predicted_classes"] = result_to_save["predicted_classes"].apply(json.dumps)
+         result_to_save.to_csv(out_path, index=False)
+
+         # keep raw response files for reference
+
+         return result