openai-gabriel 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gabriel/__init__.py +61 -0
- gabriel/_version.py +1 -0
- gabriel/api.py +2284 -0
- gabriel/cli/__main__.py +60 -0
- gabriel/core/__init__.py +7 -0
- gabriel/core/llm_client.py +34 -0
- gabriel/core/pipeline.py +18 -0
- gabriel/core/prompt_template.py +152 -0
- gabriel/prompts/__init__.py +1 -0
- gabriel/prompts/bucket_prompt.jinja2 +113 -0
- gabriel/prompts/classification_prompt.jinja2 +50 -0
- gabriel/prompts/codify_prompt.jinja2 +95 -0
- gabriel/prompts/comparison_prompt.jinja2 +60 -0
- gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
- gabriel/prompts/deidentification_prompt.jinja2 +112 -0
- gabriel/prompts/extraction_prompt.jinja2 +61 -0
- gabriel/prompts/filter_prompt.jinja2 +31 -0
- gabriel/prompts/ideation_prompt.jinja2 +80 -0
- gabriel/prompts/merge_prompt.jinja2 +47 -0
- gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
- gabriel/prompts/rankings_prompt.jinja2 +49 -0
- gabriel/prompts/ratings_prompt.jinja2 +50 -0
- gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
- gabriel/prompts/seed.jinja2 +43 -0
- gabriel/prompts/snippets.jinja2 +117 -0
- gabriel/tasks/__init__.py +63 -0
- gabriel/tasks/_attribute_utils.py +69 -0
- gabriel/tasks/bucket.py +432 -0
- gabriel/tasks/classify.py +562 -0
- gabriel/tasks/codify.py +1033 -0
- gabriel/tasks/compare.py +235 -0
- gabriel/tasks/debias.py +1460 -0
- gabriel/tasks/deduplicate.py +341 -0
- gabriel/tasks/deidentify.py +316 -0
- gabriel/tasks/discover.py +524 -0
- gabriel/tasks/extract.py +455 -0
- gabriel/tasks/filter.py +169 -0
- gabriel/tasks/ideate.py +782 -0
- gabriel/tasks/merge.py +464 -0
- gabriel/tasks/paraphrase.py +531 -0
- gabriel/tasks/rank.py +2041 -0
- gabriel/tasks/rate.py +347 -0
- gabriel/tasks/seed.py +465 -0
- gabriel/tasks/whatever.py +344 -0
- gabriel/utils/__init__.py +64 -0
- gabriel/utils/audio_utils.py +42 -0
- gabriel/utils/file_utils.py +464 -0
- gabriel/utils/image_utils.py +22 -0
- gabriel/utils/jinja.py +31 -0
- gabriel/utils/logging.py +86 -0
- gabriel/utils/mapmaker.py +304 -0
- gabriel/utils/media_utils.py +78 -0
- gabriel/utils/modality_utils.py +148 -0
- gabriel/utils/openai_utils.py +5470 -0
- gabriel/utils/parsing.py +282 -0
- gabriel/utils/passage_viewer.py +2557 -0
- gabriel/utils/pdf_utils.py +20 -0
- gabriel/utils/plot_utils.py +2881 -0
- gabriel/utils/prompt_utils.py +42 -0
- gabriel/utils/word_matching.py +158 -0
- openai_gabriel-1.0.1.dist-info/METADATA +443 -0
- openai_gabriel-1.0.1.dist-info/RECORD +67 -0
- openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
- openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
- openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
- openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
- openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,562 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import hashlib
|
|
3
|
+
import os
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
import re
|
|
6
|
+
import random
|
|
7
|
+
from collections import defaultdict
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import Any, DefaultDict, Dict, List, Optional, Set
|
|
10
|
+
import json
|
|
11
|
+
|
|
12
|
+
import pandas as pd
|
|
13
|
+
|
|
14
|
+
from ..core.prompt_template import PromptTemplate, resolve_template
|
|
15
|
+
from ..utils.openai_utils import get_all_responses
|
|
16
|
+
from ..utils import (
|
|
17
|
+
safest_json,
|
|
18
|
+
load_image_inputs,
|
|
19
|
+
load_audio_inputs,
|
|
20
|
+
load_pdf_inputs,
|
|
21
|
+
warn_if_modality_mismatch,
|
|
22
|
+
)
|
|
23
|
+
from ..utils.logging import announce_prompt_rendering
|
|
24
|
+
from ._attribute_utils import load_persisted_attributes
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _collect_predictions(row: pd.Series) -> List[str]:
|
|
28
|
+
"""Return labels whose values evaluate to ``True``.
|
|
29
|
+
|
|
30
|
+
Parameters
|
|
31
|
+
----------
|
|
32
|
+
row:
|
|
33
|
+
A series containing only label columns.
|
|
34
|
+
|
|
35
|
+
Returns
|
|
36
|
+
-------
|
|
37
|
+
list of str
|
|
38
|
+
Labels for which the value is truthy.
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
return [lab for lab, val in row.items() if bool(val)]
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# ────────────────────────────
|
|
45
|
+
# Configuration dataclass
|
|
46
|
+
# ────────────────────────────
|
|
47
|
+
@dataclass
class ClassifyConfig:
    """Configuration for :class:`Classify`."""

    labels: Dict[str, str]  # {"label_name": "description", ...}
    # Output directory; expanded and created by Classify.__init__, and used
    # for raw-response, disaggregated, and cleaned CSV files.
    save_dir: str = "classifier"
    # Base file name whose stem derives all output file names
    # (e.g. "<stem>_raw_responses.csv", "<stem>_cleaned.csv").
    file_name: str = "classify_responses.csv"
    # Model name forwarded to get_all_responses.
    model: str = "gpt-5-mini"
    # Number of parallel requests forwarded to get_all_responses.
    n_parallels: int = 650
    # How many independent classification runs to aggregate over (must be >= 1).
    n_runs: int = 1
    # Fraction of n_runs that must mark a label True for the aggregated
    # value to be True (see _min_freq inside Classify.run).
    min_frequency: float = 0.6
    # Extra prompt instructions; whitespace-stripped in __post_init__,
    # empty strings collapse to None.
    additional_instructions: Optional[str] = None
    # When True, get_all_responses runs in dummy mode (no real API calls).
    use_dummy: bool = False
    # Per-request timeout forwarded to get_all_responses.
    max_timeout: Optional[float] = None
    # Input modality; run() branches on "image"/"audio"/"pdf"/"web",
    # and treats {"text", "entity", "web"} as inline-text modalities.
    modality: str = "text"
    # Maximum number of labels per prompt; larger label sets are split
    # into multiple prompt batches.
    n_attributes_per_run: int = 8
    # Forwarded to get_all_responses (reasoning models).
    reasoning_effort: Optional[str] = None
    reasoning_summary: Optional[str] = None
    # When True, run() classifies (circle, square) pairs instead of
    # single passages and requires the circle/square column names.
    differentiate: bool = False
    # Fixed circle/square presentation order in differentiate mode;
    # None means a random order is drawn per item.
    circle_first: Optional[bool] = None
    # Default "search_context_size" kwarg for web search requests.
    search_context_size: str = "medium"

    def __post_init__(self) -> None:
        # Normalize additional_instructions: strip whitespace and treat an
        # all-whitespace string the same as "not provided".
        if self.additional_instructions is not None:
            cleaned = str(self.additional_instructions).strip()
            self.additional_instructions = cleaned or None
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# ────────────────────────────
|
|
76
|
+
# Main Basic classifier task
|
|
77
|
+
# ────────────────────────────
|
|
78
|
+
class Classify:
    """Robust passage classifier using an LLM.

    * Accepts a list of *texts* (not a DataFrame) just like :class:`Rate`.
    * Persists/reads cached responses via the **save_path** attribute (same pattern as
      :class:`Rate`).
    """

    # Matches a ``` ... ``` (optionally ```json) fenced block so fenced model
    # output can be unwrapped before JSON parsing (see _parse).
    _FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.S)

    # -----------------------------------------------------------------
    def __init__(
        self,
        cfg: ClassifyConfig,
        template: Optional[PromptTemplate] = None,
        template_path: Optional[str] = None,
    ) -> None:  # noqa: D401,E501
        """Store the configuration and resolve the prompt template.

        Parameters
        ----------
        cfg:
            Task configuration. ``cfg.save_dir`` is mutated in place to its
            expanded form, and the directory is created if missing.
        template:
            Optional pre-built :class:`PromptTemplate`, forwarded to
            ``resolve_template``.
        template_path:
            Optional path to a template file, forwarded to
            ``resolve_template`` together with the packaged reference
            filename ``classification_prompt.jinja2``.
        """
        # Expand "~" and environment variables before creating the directory.
        expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
        expanded.mkdir(parents=True, exist_ok=True)
        cfg.save_dir = str(expanded)
        self.cfg = cfg
        self.template = resolve_template(
            template=template,
            template_path=template_path,
            reference_filename="classification_prompt.jinja2",
        )
|
|
104
|
+
|
|
105
|
+
# -----------------------------------------------------------------
|
|
106
|
+
# Helpers for parsing raw model output
|
|
107
|
+
# -----------------------------------------------------------------
|
|
108
|
+
@staticmethod
|
|
109
|
+
def _regex(raw: str, labels: List[str]) -> Dict[str, Optional[bool]]:
|
|
110
|
+
out: Dict[str, Optional[bool]] = {}
|
|
111
|
+
for lab in labels:
|
|
112
|
+
pat = re.compile(
|
|
113
|
+
rf'\s*"?\s*{re.escape(lab)}\s*"?\s*:\s*(true|false)', re.I | re.S
|
|
114
|
+
)
|
|
115
|
+
m = pat.search(raw)
|
|
116
|
+
out[lab] = None if not m else m.group(1).lower() == "true"
|
|
117
|
+
return out
|
|
118
|
+
|
|
119
|
+
async def _parse(self, resp: Any, labels: List[str]) -> Dict[str, Optional[bool]]:
|
|
120
|
+
# unwrap common response containers (list-of-one, bytes, fenced blocks)
|
|
121
|
+
if isinstance(resp, list) and len(resp) == 1:
|
|
122
|
+
resp = resp[0]
|
|
123
|
+
if isinstance(resp, (bytes, bytearray)):
|
|
124
|
+
resp = resp.decode()
|
|
125
|
+
data: Optional[Any] = None
|
|
126
|
+
if isinstance(resp, str):
|
|
127
|
+
m = self._FENCE_RE.search(resp)
|
|
128
|
+
if m:
|
|
129
|
+
resp = m.group(1).strip()
|
|
130
|
+
|
|
131
|
+
data = await safest_json(resp)
|
|
132
|
+
elif isinstance(resp, dict):
|
|
133
|
+
data = resp
|
|
134
|
+
if isinstance(data, dict):
|
|
135
|
+
norm = {
|
|
136
|
+
k.strip().lower(): (
|
|
137
|
+
True
|
|
138
|
+
if str(v).strip().lower() in {"true", "yes", "1"}
|
|
139
|
+
else (
|
|
140
|
+
False
|
|
141
|
+
if str(v).strip().lower() in {"false", "no", "0"}
|
|
142
|
+
else None
|
|
143
|
+
)
|
|
144
|
+
)
|
|
145
|
+
for k, v in data.items()
|
|
146
|
+
}
|
|
147
|
+
return {lab: norm.get(lab.lower(), None) for lab in labels}
|
|
148
|
+
|
|
149
|
+
# fallback to regex extraction
|
|
150
|
+
return self._regex(str(resp), labels)
|
|
151
|
+
|
|
152
|
+
# -----------------------------------------------------------------
|
|
153
|
+
# Main entry point
|
|
154
|
+
# -----------------------------------------------------------------
|
|
155
|
+
    async def run(
        self,
        df: pd.DataFrame,
        column_name: Optional[str] = None,
        *,
        circle_column_name: Optional[str] = None,
        square_column_name: Optional[str] = None,
        reset_files: bool = False,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Classify items and return ``df`` with label columns.

        Parameters
        ----------
        df:
            Input frame; one row per item to classify.
        column_name:
            Column holding the passages (required unless
            ``cfg.differentiate`` is True).
        circle_column_name, square_column_name:
            Paired columns to compare; required when ``cfg.differentiate``.
        reset_files:
            When True, ignore cached responses/persisted labels and start
            fresh.
        **kwargs:
            Forwarded to ``get_all_responses``.

        Returns
        -------
        pd.DataFrame
            ``df`` (index reset) with one column per label holding
            True/False/None plus a ``predicted_classes`` list column.
        """

        # --- validate the column arguments for the chosen mode -----------
        if self.cfg.differentiate:
            if circle_column_name is None or square_column_name is None:
                raise ValueError(
                    "circle_column_name and square_column_name are required when differentiate is True"
                )
        elif column_name is None:
            raise ValueError("column_name is required when differentiate is False")

        df_proc = df.reset_index(drop=True).copy()
        base_name = os.path.splitext(self.cfg.file_name)[0]

        # Reconcile the incoming label set with any label set persisted in
        # save_dir from a previous run (keeps cached responses consistent).
        self.cfg.labels = load_persisted_attributes(
            save_dir=self.cfg.save_dir,
            incoming=self.cfg.labels,
            reset_files=reset_files,
            task_name="Classify",
            item_name="labels",
            legacy_filename=f"{base_name}_attrs.json",
        )

        # --- split labels into prompt batches of n_attributes_per_run ----
        label_items = list(self.cfg.labels.items())
        label_count = len(label_items)
        if label_count > self.cfg.n_attributes_per_run:
            # Ceiling division, for the informational message only.
            batches = (
                label_count + self.cfg.n_attributes_per_run - 1
            ) // self.cfg.n_attributes_per_run
            print(
                f"[Classify] {label_count} labels provided. n_attributes_per_run={self.cfg.n_attributes_per_run}. "
                f"Splitting into {batches} prompt batches. Increase n_attributes_per_run if you want all attributes "
                "to be processed in the same prompt."
            )
        label_batches: List[Dict[str, str]] = [
            dict(label_items[i : i + self.cfg.n_attributes_per_run])
            for i in range(0, len(label_items), self.cfg.n_attributes_per_run)
        ]

        # --- build prompts, deduplicating identical inputs by sha1 id ----
        prompts: List[str] = []
        ids: List[str] = []
        base_ids: List[str] = []
        id_to_circle_first: Dict[str, bool] = {}
        # sha1-id -> all df row positions that share that (deduplicated) input
        id_to_rows: DefaultDict[str, List[int]] = defaultdict(list)
        id_to_val: Dict[str, Any] = {}
        prompt_texts: Dict[str, str] = {}
        prompt_circles: Dict[str, str] = {}
        prompt_squares: Dict[str, str] = {}

        if self.cfg.differentiate:
            circles = df_proc[circle_column_name].tolist()  # type: ignore[index]
            squares = df_proc[square_column_name].tolist()  # type: ignore[index]
            warn_if_modality_mismatch(
                circles, self.cfg.modality, column_name=str(circle_column_name)
            )
            warn_if_modality_mismatch(
                squares, self.cfg.modality, column_name=str(square_column_name)
            )
            for row, (circ, sq) in enumerate(zip(circles, squares)):
                # Whitespace-normalized pair -> stable 8-hex-char identifier.
                clean = " ".join(str(circ).split()) + "|" + " ".join(str(sq).split())
                sha8 = hashlib.sha1(clean.encode()).hexdigest()[:8]
                id_to_rows[sha8].append(row)
                if len(id_to_rows[sha8]) > 1:
                    # Duplicate input: reuse the first occurrence's prompt.
                    continue
                id_to_val[sha8] = (circ, sq)
                # Only inline the raw values for text-like modalities; other
                # modalities attach media separately below.
                prompt_circles[sha8] = (
                    circ if self.cfg.modality in {"text", "entity", "web"} else ""
                )
                prompt_squares[sha8] = (
                    sq if self.cfg.modality in {"text", "entity", "web"} else ""
                )
                # Presentation order: fixed by cfg.circle_first, else random.
                circle_first_flag = (
                    self.cfg.circle_first
                    if self.cfg.circle_first is not None
                    else random.random() < 0.5
                )
                id_to_circle_first[sha8] = circle_first_flag
                base_ids.append(sha8)
            announce_prompt_rendering(
                "Classify",
                len(base_ids) * len(label_batches),
            )
            for batch_idx, batch_labels in enumerate(label_batches):
                for ident in base_ids:
                    prompts.append(
                        self.template.render(
                            entry_circle=prompt_circles[ident],
                            entry_square=prompt_squares[ident],
                            attributes=batch_labels,
                            additional_instructions=self.cfg.additional_instructions,
                            differentiate=True,
                            modality=self.cfg.modality,
                            circle_first=id_to_circle_first[ident],
                        )
                    )
                    ids.append(f"{ident}_batch{batch_idx}")
        else:
            values = df_proc[column_name].tolist()  # type: ignore[index]
            warn_if_modality_mismatch(values, self.cfg.modality, column_name=str(column_name))
            for row, val in enumerate(values):
                # Whitespace-normalized value -> stable 8-hex-char identifier.
                clean = " ".join(str(val).split())
                sha8 = hashlib.sha1(clean.encode()).hexdigest()[:8]
                id_to_rows[sha8].append(row)
                if len(id_to_rows[sha8]) > 1:
                    # Duplicate input: reuse the first occurrence's prompt.
                    continue
                id_to_val[sha8] = values[row]
                prompt_texts[sha8] = (
                    str(values[row])
                    if self.cfg.modality in {"text", "entity", "web"}
                    else ""
                )
                base_ids.append(sha8)
            announce_prompt_rendering(
                "Classify",
                len(base_ids) * len(label_batches),
            )
            for batch_idx, batch_labels in enumerate(label_batches):
                for ident in base_ids:
                    prompts.append(
                        self.template.render(
                            text=prompt_texts[ident],
                            attributes=batch_labels,
                            additional_instructions=self.cfg.additional_instructions,
                            modality=self.cfg.modality,
                        )
                    )
                    ids.append(f"{ident}_batch{batch_idx}")

        # --- per-identifier media attachments for non-text modalities ----
        prompt_images: Optional[Dict[str, List[str]]] = None
        prompt_audio: Optional[Dict[str, List[Dict[str, str]]]] = None
        prompt_pdfs: Optional[Dict[str, List[Dict[str, str]]]] = None

        if self.cfg.modality == "image":
            tmp: Dict[str, List[str]] = {}
            for ident, rows in id_to_rows.items():
                imgs: List[str] = []
                if self.cfg.differentiate:
                    circ, sq = id_to_val[ident]
                    circ_imgs = load_image_inputs(circ)
                    sq_imgs = load_image_inputs(sq)
                    # Attachment order mirrors the rendered circle/square order.
                    if id_to_circle_first.get(ident, False):
                        if circ_imgs:
                            imgs.extend(circ_imgs)
                        if sq_imgs:
                            imgs.extend(sq_imgs)
                    else:
                        if sq_imgs:
                            imgs.extend(sq_imgs)
                        if circ_imgs:
                            imgs.extend(circ_imgs)
                else:
                    imgs = load_image_inputs(id_to_val[ident])
                if imgs:
                    # Same attachments for every label batch of this input.
                    for batch_idx in range(len(label_batches)):
                        tmp[f"{ident}_batch{batch_idx}"] = imgs
            prompt_images = tmp or None
        elif self.cfg.modality == "audio":
            tmp_a: Dict[str, List[Dict[str, str]]] = {}
            for ident, rows in id_to_rows.items():
                auds: List[Dict[str, str]] = []
                if self.cfg.differentiate:
                    circ, sq = id_to_val[ident]
                    circ_auds = load_audio_inputs(circ)
                    sq_auds = load_audio_inputs(sq)
                    if id_to_circle_first.get(ident, False):
                        if circ_auds:
                            auds.extend(circ_auds)
                        if sq_auds:
                            auds.extend(sq_auds)
                    else:
                        if sq_auds:
                            auds.extend(sq_auds)
                        if circ_auds:
                            auds.extend(circ_auds)
                else:
                    auds = load_audio_inputs(id_to_val[ident])
                if auds:
                    for batch_idx in range(len(label_batches)):
                        tmp_a[f"{ident}_batch{batch_idx}"] = auds
            prompt_audio = tmp_a or None
        elif self.cfg.modality == "pdf":
            tmp_p: Dict[str, List[Dict[str, str]]] = {}
            for ident, rows in id_to_rows.items():
                pdfs: List[Dict[str, str]] = []
                if self.cfg.differentiate:
                    circ, sq = id_to_val[ident]
                    circ_pdfs = load_pdf_inputs(circ)
                    sq_pdfs = load_pdf_inputs(sq)
                    if id_to_circle_first.get(ident, False):
                        if circ_pdfs:
                            pdfs.extend(circ_pdfs)
                        if sq_pdfs:
                            pdfs.extend(sq_pdfs)
                    else:
                        if sq_pdfs:
                            pdfs.extend(sq_pdfs)
                        if circ_pdfs:
                            pdfs.extend(circ_pdfs)
                else:
                    pdfs = load_pdf_inputs(id_to_val[ident])
                if pdfs:
                    for batch_idx in range(len(label_batches)):
                        tmp_p[f"{ident}_batch{batch_idx}"] = pdfs
            prompt_pdfs = tmp_p or None

        csv_path = os.path.join(self.cfg.save_dir, f"{base_name}_raw_responses.csv")

        kwargs.setdefault("web_search", self.cfg.modality == "web")
        kwargs.setdefault("search_context_size", self.cfg.search_context_size)

        if not isinstance(self.cfg.n_runs, int) or self.cfg.n_runs < 1:
            raise ValueError("n_runs must be an integer >= 1")

        # --- build per-run identifier lists, honoring the cache file -----
        # NOTE(review): assumes the cache CSV written by get_all_responses
        # has an "Identifier" column — confirm against openai_utils.
        existing_ids: Set[str] = set()
        if not reset_files and os.path.exists(csv_path):
            try:
                existing_df = pd.read_csv(csv_path, usecols=["Identifier"])
                existing_ids = set(existing_df["Identifier"].astype(str))
            except Exception:
                existing_ids = set()

        run_identifier_lists: List[List[str]] = []
        for run_idx in range(1, self.cfg.n_runs + 1):
            run_ids: List[str] = []
            for ident in ids:
                if run_idx == 1:
                    # Older caches stored run 1 with an explicit "_run1"
                    # suffix; reuse that id when it exists so cached rows
                    # are picked up instead of re-requested.
                    legacy_ident = f"{ident}_run1"
                    run_ids.append(legacy_ident if legacy_ident in existing_ids else ident)
                else:
                    run_ids.append(f"{ident}_run{run_idx}")
            run_identifier_lists.append(run_ids)

        # The same prompts are submitted once per run, under per-run ids.
        prompts_all: List[str] = []
        ids_all: List[str] = []
        for run_ids in run_identifier_lists:
            prompts_all.extend(prompts)
            ids_all.extend(run_ids)

        # Replicate media attachment maps under the per-run identifiers.
        prompt_images_all: Optional[Dict[str, List[str]]] = None
        if prompt_images:
            prompt_images_all = {}
            for run_ids in run_identifier_lists:
                for base_ident, run_ident in zip(ids, run_ids):
                    imgs = prompt_images.get(base_ident)
                    if imgs:
                        prompt_images_all[run_ident] = imgs
        prompt_audio_all: Optional[Dict[str, List[Dict[str, str]]]] = None
        if prompt_audio:
            prompt_audio_all = {}
            for run_ids in run_identifier_lists:
                for base_ident, run_ident in zip(ids, run_ids):
                    auds = prompt_audio.get(base_ident)
                    if auds:
                        prompt_audio_all[run_ident] = auds
        prompt_pdfs_all: Optional[Dict[str, List[Dict[str, str]]]] = None
        if prompt_pdfs:
            prompt_pdfs_all = {}
            for run_ids in run_identifier_lists:
                for base_ident, run_ident in zip(ids, run_ids):
                    pdfs = prompt_pdfs.get(base_ident)
                    if pdfs:
                        prompt_pdfs_all[run_ident] = pdfs

        # --- fire all requests (cached rows are reused via save_path) ----
        df_resp_all = await get_all_responses(
            prompts=prompts_all,
            identifiers=ids_all,
            prompt_images=prompt_images_all,
            prompt_audio=prompt_audio_all,
            prompt_pdfs=prompt_pdfs_all,
            n_parallels=self.cfg.n_parallels,
            save_path=csv_path,
            reset_files=reset_files,
            json_mode=self.cfg.modality != "audio",
            model=self.cfg.model,
            use_dummy=self.cfg.use_dummy,
            max_timeout=self.cfg.max_timeout,
            reasoning_effort=self.cfg.reasoning_effort,
            reasoning_summary=self.cfg.reasoning_summary,
            print_example_prompt=True,
            **kwargs,
        )
        if not isinstance(df_resp_all, pd.DataFrame):
            raise RuntimeError("get_all_responses returned no DataFrame")

        # --- split responses back into per-run frames, stripping suffixes
        df_resps = []
        for run_idx, run_ids in enumerate(run_identifier_lists, start=1):
            suffix = f"_run{run_idx}"
            sub = df_resp_all[df_resp_all.Identifier.isin(run_ids)].copy()
            # Anchored replace: only trailing "_runN" is removed; run-1 rows
            # without the legacy suffix pass through unchanged.
            sub.Identifier = sub.Identifier.str.replace(
                suffix + "$", "", regex=True
            )
            df_resps.append(sub)

        # parse each run and construct disaggregated records
        full_records: List[Dict[str, Any]] = []
        total_orphans = 0
        all_labels = list(self.cfg.labels.keys())
        for run_idx, df_resp in enumerate(df_resps, start=1):
            # Start every known input at all-None; parsed batches fill in.
            id_to_labels: Dict[str, Dict[str, Optional[bool]]] = {
                ident: {lab: None for lab in all_labels} for ident in base_ids
            }
            orphans = 0
            for ident_batch, raw in zip(df_resp.Identifier, df_resp.Response):
                if "_batch" not in ident_batch:
                    continue
                base_ident, batch_part = ident_batch.rsplit("_batch", 1)
                if base_ident not in id_to_rows:
                    # Cached response whose input is no longer in df.
                    orphans += 1
                    continue
                batch_idx = int(batch_part)
                labs = list(label_batches[batch_idx].keys())
                parsed = await self._parse(raw, labs)
                for lab in labs:
                    id_to_labels[base_ident][lab] = parsed.get(lab)
            total_orphans += orphans
            # One record per unique input per run (duplicates re-attach on merge).
            for ident in base_ids:
                parsed = id_to_labels.get(ident, {lab: None for lab in all_labels})
                if self.cfg.differentiate:
                    circ_val, sq_val = id_to_val[ident]
                    rec = {"circle": circ_val, "square": sq_val, "run": run_idx}
                else:
                    rec = {"text": id_to_val[ident], "run": run_idx}
                rec.update({lab: parsed.get(lab) for lab in all_labels})
                full_records.append(rec)

        if total_orphans:
            print(
                f"[Classify] WARNING: {total_orphans} response(s) had no matching passage this run."
            )

        # --- aggregate runs into one row per unique input -----------------
        if self.cfg.differentiate:
            full_df = pd.DataFrame(full_records).set_index(["circle", "square", "run"])
            index_cols = ["circle", "square", "run"]
            group_cols = ["circle", "square"]
        else:
            full_df = pd.DataFrame(full_records).set_index(["text", "run"])
            index_cols = ["text", "run"]
            group_cols = ["text"]
        if self.cfg.n_runs > 1:
            disagg_path = os.path.join(
                self.cfg.save_dir, f"{base_name}_full_disaggregated.csv"
            )
            full_df.to_csv(disagg_path, index_label=index_cols)

        # aggregate across runs using a minimum frequency threshold
        def _min_freq(s: pd.Series) -> Optional[bool]:
            # All runs unparsed -> unknown.
            if s.notna().sum() == 0:
                return None
            # NOTE(review): the True-proportion is taken over n_runs, not
            # over the count of non-null runs — confirm this is intended.
            true_count = s.fillna(False).infer_objects(copy=False).sum()
            prop = true_count / self.cfg.n_runs
            return prop >= self.cfg.min_frequency

        agg_df = pd.DataFrame(
            {
                lab: full_df[lab].groupby(group_cols).apply(_min_freq)
                for lab in self.cfg.labels
            }
        )

        filled = agg_df.dropna(how="all").shape[0]
        print(f"[Classify] Filled {filled}/{len(agg_df)} unique texts.")

        # Per-label coverage report (share of unique inputs with a non-null value).
        total = len(agg_df)
        print("\n=== Label coverage (non-null) ===")
        for lab in self.cfg.labels:
            n = agg_df[lab].notna().sum()
            print(f"{lab:<55s}: {n / total:6.2%} ({n}/{total})")
        print("=================================\n")

        # --- merge aggregated labels back onto the original rows ----------
        out_path = os.path.join(self.cfg.save_dir, f"{base_name}_cleaned.csv")
        if self.cfg.differentiate:
            result = df_proc.merge(
                agg_df,
                left_on=[circle_column_name, square_column_name],
                right_index=True,
                how="left",
            )
        else:
            result = df_proc.merge(
                agg_df, left_on=column_name, right_index=True, how="left"
            )

        label_cols = list(self.cfg.labels.keys())

        # Put the passage column first for readability.
        if not self.cfg.differentiate and column_name in result.columns:
            cols = result.columns.tolist()
            cols.remove(column_name)
            cols.insert(0, column_name)
            result = result[cols]

        result.insert(1, "predicted_classes", result[label_cols].apply(_collect_predictions, axis=1))

        # CSV copy stores the list column as JSON; the returned frame keeps lists.
        result_to_save = result.copy()
        result_to_save["predicted_classes"] = result_to_save["predicted_classes"].apply(json.dumps)
        result_to_save.to_csv(out_path, index=False)

        # keep raw response files for reference

        return result
|