openai-gabriel 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gabriel/__init__.py +61 -0
- gabriel/_version.py +1 -0
- gabriel/api.py +2284 -0
- gabriel/cli/__main__.py +60 -0
- gabriel/core/__init__.py +7 -0
- gabriel/core/llm_client.py +34 -0
- gabriel/core/pipeline.py +18 -0
- gabriel/core/prompt_template.py +152 -0
- gabriel/prompts/__init__.py +1 -0
- gabriel/prompts/bucket_prompt.jinja2 +113 -0
- gabriel/prompts/classification_prompt.jinja2 +50 -0
- gabriel/prompts/codify_prompt.jinja2 +95 -0
- gabriel/prompts/comparison_prompt.jinja2 +60 -0
- gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
- gabriel/prompts/deidentification_prompt.jinja2 +112 -0
- gabriel/prompts/extraction_prompt.jinja2 +61 -0
- gabriel/prompts/filter_prompt.jinja2 +31 -0
- gabriel/prompts/ideation_prompt.jinja2 +80 -0
- gabriel/prompts/merge_prompt.jinja2 +47 -0
- gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
- gabriel/prompts/rankings_prompt.jinja2 +49 -0
- gabriel/prompts/ratings_prompt.jinja2 +50 -0
- gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
- gabriel/prompts/seed.jinja2 +43 -0
- gabriel/prompts/snippets.jinja2 +117 -0
- gabriel/tasks/__init__.py +63 -0
- gabriel/tasks/_attribute_utils.py +69 -0
- gabriel/tasks/bucket.py +432 -0
- gabriel/tasks/classify.py +562 -0
- gabriel/tasks/codify.py +1033 -0
- gabriel/tasks/compare.py +235 -0
- gabriel/tasks/debias.py +1460 -0
- gabriel/tasks/deduplicate.py +341 -0
- gabriel/tasks/deidentify.py +316 -0
- gabriel/tasks/discover.py +524 -0
- gabriel/tasks/extract.py +455 -0
- gabriel/tasks/filter.py +169 -0
- gabriel/tasks/ideate.py +782 -0
- gabriel/tasks/merge.py +464 -0
- gabriel/tasks/paraphrase.py +531 -0
- gabriel/tasks/rank.py +2041 -0
- gabriel/tasks/rate.py +347 -0
- gabriel/tasks/seed.py +465 -0
- gabriel/tasks/whatever.py +344 -0
- gabriel/utils/__init__.py +64 -0
- gabriel/utils/audio_utils.py +42 -0
- gabriel/utils/file_utils.py +464 -0
- gabriel/utils/image_utils.py +22 -0
- gabriel/utils/jinja.py +31 -0
- gabriel/utils/logging.py +86 -0
- gabriel/utils/mapmaker.py +304 -0
- gabriel/utils/media_utils.py +78 -0
- gabriel/utils/modality_utils.py +148 -0
- gabriel/utils/openai_utils.py +5470 -0
- gabriel/utils/parsing.py +282 -0
- gabriel/utils/passage_viewer.py +2557 -0
- gabriel/utils/pdf_utils.py +20 -0
- gabriel/utils/plot_utils.py +2881 -0
- gabriel/utils/prompt_utils.py +42 -0
- gabriel/utils/word_matching.py +158 -0
- openai_gabriel-1.0.1.dist-info/METADATA +443 -0
- openai_gabriel-1.0.1.dist-info/RECORD +67 -0
- openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
- openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
- openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
- openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
- openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,524 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import io
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import re
|
|
7
|
+
import zipfile
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
|
12
|
+
|
|
13
|
+
import pandas as pd
|
|
14
|
+
|
|
15
|
+
from .codify import Codify, CodifyConfig
|
|
16
|
+
from .compare import Compare, CompareConfig
|
|
17
|
+
from .bucket import Bucket, BucketConfig
|
|
18
|
+
from .classify import Classify, ClassifyConfig
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class DiscoverConfig:
    """Configuration for :class:`Discover`.

    Groups output/model settings, candidate-discovery knobs, bucketing
    knobs, and reasoning/timeout passthroughs for the sub-tasks.
    """

    save_dir: str = "discover"
    model: str = "gpt-5-mini"
    n_parallels: int = 650
    n_runs: int = 1
    min_frequency: float = 0.6
    bucket_count: int = 10
    additional_instructions: Optional[str] = None
    differentiate: bool = True
    max_words_per_call: int = 1000
    max_categories_per_call: int = 8
    use_dummy: bool = False
    modality: str = "text"
    n_terms_per_prompt: int = 250
    repeat_bucketing: int = 5
    repeat_voting: int = 25
    next_round_frac: float = 0.25
    top_k_per_round: int = 1
    raw_term_definitions: bool = True
    reasoning_effort: Optional[str] = None
    reasoning_summary: Optional[str] = None
    max_timeout: Optional[float] = None

    def __post_init__(self) -> None:
        # Normalise blank / whitespace-only instructions to None so
        # downstream code can rely on a simple truthiness check.
        if self.additional_instructions is None:
            return
        cleaned = str(self.additional_instructions).strip()
        self.additional_instructions = cleaned if cleaned else None
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class Discover:
    """High-level feature discovery pipeline.

    Depending on the inputs, the pipeline will either:
    1. Use :class:`Codify` to discover raw feature candidates from a single column, or
    2. Use :class:`Compare` to surface differentiating attributes between two columns.

    The discovered terms are then grouped into buckets via :class:`Bucket` and finally
    applied back onto the dataset using :class:`Classify`.
    """

    def __init__(self, cfg: DiscoverConfig) -> None:
        # Expand "~" and environment variables, create the directory up
        # front, and store the normalised path back onto the config so
        # every sub-task sees the same resolved save_dir.
        target = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
        target.mkdir(parents=True, exist_ok=True)
        cfg.save_dir = str(target)
        self.cfg = cfg
|
|
69
|
+
|
|
70
|
+
def _to_serializable(self, value: Any) -> Any:
|
|
71
|
+
if isinstance(value, pd.DataFrame):
|
|
72
|
+
safe = value.copy()
|
|
73
|
+
safe = safe.where(pd.notna(safe), None)
|
|
74
|
+
return safe.to_dict(orient="records")
|
|
75
|
+
if isinstance(value, dict):
|
|
76
|
+
return {str(k): self._to_serializable(v) for k, v in value.items()}
|
|
77
|
+
if isinstance(value, (list, tuple)):
|
|
78
|
+
return [self._to_serializable(v) for v in value]
|
|
79
|
+
if isinstance(value, pd.Series):
|
|
80
|
+
safe_series = value.where(pd.notna(value), None)
|
|
81
|
+
return safe_series.tolist()
|
|
82
|
+
try:
|
|
83
|
+
if pd.isna(value): # type: ignore[arg-type]
|
|
84
|
+
return None
|
|
85
|
+
except Exception:
|
|
86
|
+
pass
|
|
87
|
+
if isinstance(value, (pd.Timestamp, pd.Timedelta)):
|
|
88
|
+
return value.isoformat()
|
|
89
|
+
return value
|
|
90
|
+
|
|
91
|
+
def _persist_result_snapshot(self, result: Dict[str, Any]) -> None:
|
|
92
|
+
payload = {
|
|
93
|
+
"generated_at": datetime.utcnow().isoformat() + "Z",
|
|
94
|
+
"results": {k: self._to_serializable(v) for k, v in result.items()},
|
|
95
|
+
}
|
|
96
|
+
out_path = os.path.join(self.cfg.save_dir, "discover_results_snapshot.json")
|
|
97
|
+
try:
|
|
98
|
+
with open(out_path, "w", encoding="utf-8") as f:
|
|
99
|
+
json.dump(payload, f, ensure_ascii=False, indent=2)
|
|
100
|
+
except Exception:
|
|
101
|
+
pass
|
|
102
|
+
|
|
103
|
+
def _value_to_dataframe(self, name: str, value: Any) -> Optional[pd.DataFrame]:
|
|
104
|
+
if isinstance(value, pd.DataFrame):
|
|
105
|
+
return value.copy()
|
|
106
|
+
if isinstance(value, pd.Series):
|
|
107
|
+
return value.to_frame().reset_index(drop=True)
|
|
108
|
+
if isinstance(value, dict):
|
|
109
|
+
rows = [{"key": str(k), "value": v} for k, v in value.items()]
|
|
110
|
+
df = pd.DataFrame(rows)
|
|
111
|
+
if name == "buckets":
|
|
112
|
+
df = df.rename(columns={"key": "bucket", "value": "definition"})
|
|
113
|
+
return df
|
|
114
|
+
if isinstance(value, list):
|
|
115
|
+
if not value:
|
|
116
|
+
return pd.DataFrame()
|
|
117
|
+
if all(isinstance(item, dict) for item in value):
|
|
118
|
+
return pd.DataFrame(value)
|
|
119
|
+
return pd.DataFrame({"value": value})
|
|
120
|
+
return None
|
|
121
|
+
|
|
122
|
+
def _export_result_archive(self, result: Dict[str, Any]) -> None:
|
|
123
|
+
tables: Dict[str, pd.DataFrame] = {}
|
|
124
|
+
for key, value in result.items():
|
|
125
|
+
df = self._value_to_dataframe(key, value)
|
|
126
|
+
if df is not None:
|
|
127
|
+
tables[key] = df
|
|
128
|
+
if not tables:
|
|
129
|
+
return
|
|
130
|
+
archive_path = os.path.join(self.cfg.save_dir, "discover_results_export.zip")
|
|
131
|
+
try:
|
|
132
|
+
with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
|
|
133
|
+
name_counts: Dict[str, int] = {}
|
|
134
|
+
for key, df in tables.items():
|
|
135
|
+
csv_buffer = io.StringIO()
|
|
136
|
+
df.to_csv(csv_buffer, index=False)
|
|
137
|
+
safe = re.sub(r"[^0-9A-Za-z._-]+", "_", key).strip("._-") or "table"
|
|
138
|
+
count = name_counts.get(safe, 0)
|
|
139
|
+
name_counts[safe] = count + 1
|
|
140
|
+
if count:
|
|
141
|
+
safe = f"{safe}_{count}"
|
|
142
|
+
filename = f"{safe}.csv"
|
|
143
|
+
zf.writestr(filename, csv_buffer.getvalue())
|
|
144
|
+
except Exception:
|
|
145
|
+
pass
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
async def run(
    self,
    df: pd.DataFrame,
    *,
    column_name: Optional[str] = None,
    circle_column_name: Optional[str] = None,
    square_column_name: Optional[str] = None,
    reset_files: bool = False,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Execute the discovery pipeline.

    Args:
        df: Input dataframe.
        column_name: Column to analyse when using a single column pipeline.
        circle_column_name: First column when contrasting two columns.
        square_column_name: Second column when contrasting two columns.
        reset_files: Forwarded to underlying tasks to control caching.
        **kwargs: Accepted for forward compatibility; currently unused here.

    Raises:
        ValueError: If neither or both of the single-column / two-column
            modes are selected.

    Returns:
        Dictionary with intermediate and final results. Keys include:
        ``candidates`` (raw candidate terms), ``buckets`` (bucket definitions),
        ``classification`` (original dataframe with label columns), ``summary`` (if
        circle/square columns were provided) containing per-label differences (``difference_pct``
        expresses circle minus square in percentage points),
        and optionally
        ``compare`` or ``codify`` depending on which stage was used for candidate
        generation.
    """

    # Exactly one mode must be chosen: single-column OR circle/square pair.
    single = column_name is not None
    pair = circle_column_name is not None and square_column_name is not None
    if single == pair:
        raise ValueError(
            "Provide either column_name or both circle_column_name and square_column_name"
        )

    # The mode dictates whether downstream tasks run in "differentiate" mode.
    if single:
        self.cfg.differentiate = False
    elif pair:
        self.cfg.differentiate = True

    compare_df: Optional[pd.DataFrame] = None
    codify_df: Optional[pd.DataFrame] = None

    # ── 1. candidate discovery ─────────────────────────────────────
    if single:
        coder_cfg = CodifyConfig(
            save_dir=os.path.join(self.cfg.save_dir, "codify"),
            model=self.cfg.model,
            n_parallels=self.cfg.n_parallels,
            max_words_per_call=self.cfg.max_words_per_call,
            max_categories_per_call=self.cfg.max_categories_per_call,
            debug_print=False,
            use_dummy=self.cfg.use_dummy,
            reasoning_effort=self.cfg.reasoning_effort,
            reasoning_summary=self.cfg.reasoning_summary,
            max_timeout=self.cfg.max_timeout,
        )
        coder = Codify(coder_cfg)
        codify_df = await coder.run(
            df,
            column_name,  # type: ignore[arg-type]
            categories=None,
            additional_instructions=self.cfg.additional_instructions or "",
            reset_files=reset_files,
        )
        # Collect one definition per discovered term; first occurrence wins.
        term_defs: Dict[str, str] = {}
        if "coded_passages" in codify_df:
            for entry in codify_df["coded_passages"].dropna():
                if isinstance(entry, dict):
                    for k, v in entry.items():
                        if k not in term_defs:
                            # A list value's first element is treated as
                            # the representative definition.
                            if isinstance(v, list) and v:
                                term_defs[k] = str(v[0])
                            else:
                                term_defs[k] = str(v) if v is not None else ""
        if self.cfg.raw_term_definitions:
            # Single-row frame whose cell is the whole term->definition dict.
            candidate_df = (
                pd.DataFrame({"term": [term_defs]})
                if term_defs
                else pd.DataFrame({"term": []})
            )
        else:
            candidate_df = pd.DataFrame({"term": sorted(set(term_defs.keys()))})
    else:
        cmp_cfg = CompareConfig(
            save_dir=os.path.join(self.cfg.save_dir, "compare"),
            model=self.cfg.model,
            n_parallels=self.cfg.n_parallels,
            use_dummy=self.cfg.use_dummy,
            max_timeout=self.cfg.max_timeout,
            differentiate=self.cfg.differentiate,
            additional_instructions=self.cfg.additional_instructions,
            modality=self.cfg.modality,
            reasoning_effort=self.cfg.reasoning_effort,
            reasoning_summary=self.cfg.reasoning_summary,
        )
        cmp = Compare(cmp_cfg)
        compare_df = await cmp.run(
            df,
            circle_column_name,  # type: ignore[arg-type]
            square_column_name,  # type: ignore[arg-type]
            reset_files=reset_files,
        )
        # One definition per attribute; first non-NA occurrence wins.
        term_defs = {}
        for attr, expl in zip(
            compare_df["attribute"], compare_df["explanation"]
        ):
            if pd.notna(attr) and attr not in term_defs:
                term_defs[attr] = str(expl) if pd.notna(expl) else ""
        if self.cfg.raw_term_definitions:
            candidate_df = (
                pd.DataFrame({"term": [term_defs]})
                if term_defs
                else pd.DataFrame({"term": []})
            )
        else:
            candidate_df = pd.DataFrame({"term": sorted(set(term_defs.keys()))})

    # ── 2. bucketisation ───────────────────────────────────────────
    bucket_df: pd.DataFrame
    if candidate_df.empty:
        bucket_df = pd.DataFrame(columns=["bucket", "definition"])
    else:
        buck_cfg = BucketConfig(
            bucket_count=self.cfg.bucket_count,
            save_dir=os.path.join(self.cfg.save_dir, "bucket"),
            model=self.cfg.model,
            n_parallels=self.cfg.n_parallels,
            use_dummy=self.cfg.use_dummy,
            additional_instructions=self.cfg.additional_instructions,
            differentiate=self.cfg.differentiate if pair else False,
            n_terms_per_prompt=self.cfg.n_terms_per_prompt,
            repeat_bucketing=self.cfg.repeat_bucketing,
            repeat_voting=self.cfg.repeat_voting,
            next_round_frac=self.cfg.next_round_frac,
            top_k_per_round=self.cfg.top_k_per_round,
            raw_term_definitions=self.cfg.raw_term_definitions,
            reasoning_effort=self.cfg.reasoning_effort,
            reasoning_summary=self.cfg.reasoning_summary,
            max_timeout=self.cfg.max_timeout,
        )
        buck = Bucket(buck_cfg)
        bucket_df = await buck.run(
            candidate_df,
            "term",
            reset_files=reset_files,
        )

    # bucket name -> definition mapping used as classification labels.
    labels = (
        dict(zip(bucket_df["bucket"], bucket_df["definition"]))
        if not bucket_df.empty
        else {}
    )

    # ── 3. classification ──────────────────────────────────────────
    classify_result: pd.DataFrame
    summary_df: Optional[pd.DataFrame] = None
    if not labels:
        # Nothing to classify; pass the input through unchanged.
        classify_result = df.reset_index(drop=True).copy()
    elif pair:
        base_cfg = {
            "model": self.cfg.model,
            "n_parallels": self.cfg.n_parallels,
            "n_runs": self.cfg.n_runs,
            "min_frequency": self.cfg.min_frequency,
            "use_dummy": self.cfg.use_dummy,
            "modality": self.cfg.modality,
            "reasoning_effort": self.cfg.reasoning_effort,
            "reasoning_summary": self.cfg.reasoning_summary,
            "n_attributes_per_run": 8,
            "differentiate": True,
            "additional_instructions": self.cfg.additional_instructions or "",
            "max_timeout": self.cfg.max_timeout,
        }

        def swap_cs(text: str) -> str:
            # Swap every case-insensitive "circle"<->"square" occurrence.
            def repl(match: re.Match[str]) -> str:
                word = match.group(0)
                return "square" if word.lower() == "circle" else "circle"
            return re.sub(r"(?i)circle|square", repl, text)

        def build_combined_metadata(
            base_labels: Dict[str, str]
        ) -> Tuple[Dict[str, str], Dict[str, str]]:
            # For every label build an "actual" entry plus a circle/square-
            # swapped "inverted" twin, and record the rename each maps to.
            combined_local: Dict[str, str] = {}
            rename_local: Dict[str, str] = {}
            for lab, desc in base_labels.items():
                actual_key = lab
                swapped_lab = swap_cs(lab)
                if swapped_lab == lab:
                    # Label mentions neither word; disambiguate explicitly.
                    swapped_lab = f"{lab} (inverted)"
                combined_local[actual_key] = desc
                rename_local[actual_key] = f"{lab}_actual"
                combined_local[swapped_lab] = swap_cs(desc)
                rename_local[swapped_lab] = f"{lab}_inverted"
            return combined_local, rename_local

        def derive_base_from_combined(
            combined_local: Dict[str, str]
        ) -> Tuple[Dict[str, str], Dict[str, str]]:
            # Inverse of build_combined_metadata: pair each cached label
            # with its inverted twin and recover canonical base labels.
            base_local: Dict[str, str] = {}
            rename_local: Dict[str, str] = {}
            processed: Set[str] = set()
            for key, desc in combined_local.items():
                if key in processed:
                    continue
                swapped_key = swap_cs(key)
                inverted_key: Optional[str] = None
                if swapped_key != key and swapped_key in combined_local:
                    inverted_key = swapped_key
                else:
                    candidate = f"{key} (inverted)"
                    if candidate in combined_local:
                        inverted_key = candidate
                canonical_key = key
                canonical_desc = desc
                # If we started from the "(inverted)" twin, flip roles.
                if canonical_key.endswith(" (inverted)") and swapped_key in combined_local:
                    canonical_key = swapped_key
                    canonical_desc = combined_local[canonical_key]
                    inverted_key = key
                base_local[canonical_key] = canonical_desc
                rename_local[canonical_key] = f"{canonical_key}_actual"
                if inverted_key and inverted_key in combined_local:
                    rename_local[inverted_key] = f"{canonical_key}_inverted"
                    processed.add(inverted_key)
                processed.add(canonical_key)
            return base_local, rename_local

        combined_labels, rename_map = build_combined_metadata(labels)

        clf_cfg = ClassifyConfig(
            labels=combined_labels,
            save_dir=os.path.join(self.cfg.save_dir, "classify"),
            **base_cfg,  # type: ignore[arg-type]
        )

        clf = Classify(clf_cfg)

        combined_df = await clf.run(
            df,
            circle_column_name=circle_column_name,  # type: ignore[arg-type]
            square_column_name=square_column_name,  # type: ignore[arg-type]
            reset_files=reset_files,
        )

        # Classify may have served cached results whose labels differ from
        # the buckets generated above; prefer the cached labels.
        actual_combined_labels = dict(clf.cfg.labels)
        if set(actual_combined_labels.keys()) != set(combined_labels.keys()):
            print(
                "[Discover] Detected mismatch between cached classifier labels and generated buckets; "
                "using cached labels from the classification cache instead."
            )
            combined_labels = actual_combined_labels
            labels, rename_map = derive_base_from_combined(combined_labels)
            if labels:
                bucket_df = pd.DataFrame(
                    {
                        "bucket": list(labels.keys()),
                        "definition": list(labels.values()),
                    }
                )
            else:
                bucket_df = pd.DataFrame(columns=["bucket", "definition"])

        classify_result = combined_df.rename(columns=rename_map)

        # Keep only labels whose actual AND inverted columns survived.
        available: Dict[str, str] = {}
        missing: List[str] = []
        for lab, desc in labels.items():
            actual_col = f"{lab}_actual"
            inverted_col = f"{lab}_inverted"
            if (
                actual_col not in classify_result.columns
                or inverted_col not in classify_result.columns
            ):
                missing.append(lab)
                continue
            available[lab] = desc
        if missing:
            print(
                "[Discover] Warning: classification cache is missing the following labels, "
                "so they were skipped:",
                ", ".join(missing),
            )
            if not bucket_df.empty:
                bucket_df = bucket_df[bucket_df["bucket"].isin(available.keys())]
            labels = available

        # Per-label summary: counts/percentages for actual vs inverted,
        # with net_pct = actual_pct - inverted_pct in percentage points.
        summary_records: List[Dict[str, Any]] = []
        for lab in labels:
            actual_col = f"{lab}_actual"
            inverted_col = f"{lab}_inverted"
            actual_true = (
                classify_result[actual_col]
                .fillna(False)
                .infer_objects(copy=False)
                .sum()
            )
            inverted_true = (
                classify_result[inverted_col]
                .fillna(False)
                .infer_objects(copy=False)
                .sum()
            )
            # Rows where at least one of the two columns holds a verdict.
            total = classify_result[[actual_col, inverted_col]].notna().any(axis=1).sum()
            actual_pct = (actual_true / total * 100) if total else None
            inverted_pct = (inverted_true / total * 100) if total else None
            net_pct = (
                (actual_pct - inverted_pct)
                if actual_pct is not None and inverted_pct is not None
                else None
            )
            summary_records.append({
                "label": lab,
                "actual_true": actual_true,
                "inverted_true": inverted_true,
                "total": total,
                "actual_pct": actual_pct,
                "inverted_pct": inverted_pct,
                "net_pct": net_pct,
            })
        summary_df = pd.DataFrame(summary_records)
    else:
        clf_cfg = ClassifyConfig(
            labels=labels,
            save_dir=os.path.join(self.cfg.save_dir, "classify"),
            model=self.cfg.model,
            n_parallels=self.cfg.n_parallels,
            n_runs=self.cfg.n_runs,
            min_frequency=self.cfg.min_frequency,
            additional_instructions=self.cfg.additional_instructions or "",
            use_dummy=self.cfg.use_dummy,
            modality=self.cfg.modality,
            reasoning_effort=self.cfg.reasoning_effort,
            reasoning_summary=self.cfg.reasoning_summary,
            n_attributes_per_run=8,
            max_timeout=self.cfg.max_timeout,
        )
        clf = Classify(clf_cfg)
        classify_result = await clf.run(
            df,
            column_name,  # type: ignore[arg-type]
            reset_files=reset_files,
        )
        # As above: reconcile with labels the classifier actually used
        # (they may come from its cache).
        actual_labels = dict(clf.cfg.labels)
        if actual_labels != labels:
            print(
                "[Discover] Detected mismatch between cached classifier labels and generated buckets; "
                "using cached labels from the classification cache instead."
            )
            labels = actual_labels
            if labels:
                bucket_df = pd.DataFrame(
                    {"bucket": list(labels.keys()), "definition": list(labels.values())}
                )
            else:
                bucket_df = pd.DataFrame(columns=["bucket", "definition"])
        else:
            labels = actual_labels

    # Assemble outputs; optional keys are added only when populated.
    result: Dict[str, Any] = {
        "candidates": candidate_df,
        "buckets": labels,
        "classification": classify_result,
    }
    if not bucket_df.empty:
        result["bucket_df"] = bucket_df
    if summary_df is not None:
        result["summary"] = summary_df
    if compare_df is not None:
        result["compare"] = compare_df
    if codify_df is not None:
        result["codify"] = codify_df
    self._persist_result_snapshot(result)
    self._export_result_archive(result)
    return result
|