openai-gabriel 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gabriel/__init__.py +61 -0
- gabriel/_version.py +1 -0
- gabriel/api.py +2284 -0
- gabriel/cli/__main__.py +60 -0
- gabriel/core/__init__.py +7 -0
- gabriel/core/llm_client.py +34 -0
- gabriel/core/pipeline.py +18 -0
- gabriel/core/prompt_template.py +152 -0
- gabriel/prompts/__init__.py +1 -0
- gabriel/prompts/bucket_prompt.jinja2 +113 -0
- gabriel/prompts/classification_prompt.jinja2 +50 -0
- gabriel/prompts/codify_prompt.jinja2 +95 -0
- gabriel/prompts/comparison_prompt.jinja2 +60 -0
- gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
- gabriel/prompts/deidentification_prompt.jinja2 +112 -0
- gabriel/prompts/extraction_prompt.jinja2 +61 -0
- gabriel/prompts/filter_prompt.jinja2 +31 -0
- gabriel/prompts/ideation_prompt.jinja2 +80 -0
- gabriel/prompts/merge_prompt.jinja2 +47 -0
- gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
- gabriel/prompts/rankings_prompt.jinja2 +49 -0
- gabriel/prompts/ratings_prompt.jinja2 +50 -0
- gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
- gabriel/prompts/seed.jinja2 +43 -0
- gabriel/prompts/snippets.jinja2 +117 -0
- gabriel/tasks/__init__.py +63 -0
- gabriel/tasks/_attribute_utils.py +69 -0
- gabriel/tasks/bucket.py +432 -0
- gabriel/tasks/classify.py +562 -0
- gabriel/tasks/codify.py +1033 -0
- gabriel/tasks/compare.py +235 -0
- gabriel/tasks/debias.py +1460 -0
- gabriel/tasks/deduplicate.py +341 -0
- gabriel/tasks/deidentify.py +316 -0
- gabriel/tasks/discover.py +524 -0
- gabriel/tasks/extract.py +455 -0
- gabriel/tasks/filter.py +169 -0
- gabriel/tasks/ideate.py +782 -0
- gabriel/tasks/merge.py +464 -0
- gabriel/tasks/paraphrase.py +531 -0
- gabriel/tasks/rank.py +2041 -0
- gabriel/tasks/rate.py +347 -0
- gabriel/tasks/seed.py +465 -0
- gabriel/tasks/whatever.py +344 -0
- gabriel/utils/__init__.py +64 -0
- gabriel/utils/audio_utils.py +42 -0
- gabriel/utils/file_utils.py +464 -0
- gabriel/utils/image_utils.py +22 -0
- gabriel/utils/jinja.py +31 -0
- gabriel/utils/logging.py +86 -0
- gabriel/utils/mapmaker.py +304 -0
- gabriel/utils/media_utils.py +78 -0
- gabriel/utils/modality_utils.py +148 -0
- gabriel/utils/openai_utils.py +5470 -0
- gabriel/utils/parsing.py +282 -0
- gabriel/utils/passage_viewer.py +2557 -0
- gabriel/utils/pdf_utils.py +20 -0
- gabriel/utils/plot_utils.py +2881 -0
- gabriel/utils/prompt_utils.py +42 -0
- gabriel/utils/word_matching.py +158 -0
- openai_gabriel-1.0.1.dist-info/METADATA +443 -0
- openai_gabriel-1.0.1.dist-info/RECORD +67 -0
- openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
- openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
- openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
- openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
- openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
gabriel/tasks/extract.py
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import asyncio
|
|
5
|
+
import os
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from collections import defaultdict
|
|
9
|
+
from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple
|
|
10
|
+
|
|
11
|
+
import pandas as pd
|
|
12
|
+
|
|
13
|
+
from ..core.prompt_template import PromptTemplate, resolve_template
|
|
14
|
+
from ..utils.openai_utils import get_all_responses
|
|
15
|
+
from ..utils import (
|
|
16
|
+
safest_json,
|
|
17
|
+
load_image_inputs,
|
|
18
|
+
load_audio_inputs,
|
|
19
|
+
load_pdf_inputs,
|
|
20
|
+
warn_if_modality_mismatch,
|
|
21
|
+
)
|
|
22
|
+
from ..utils.logging import announce_prompt_rendering
|
|
23
|
+
from ._attribute_utils import load_persisted_attributes
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class ExtractConfig:
    """Settings for the :class:`Extract` task.

    ``attributes`` maps each attribute name to the description the LLM is
    given; everything else tunes model choice, concurrency, persistence,
    and prompt batching.
    """

    attributes: Dict[str, str]
    save_dir: str = "extraction"          # output directory (created on init)
    file_name: str = "extraction.csv"     # basis for derived output file names
    model: str = "gpt-5-mini"
    n_parallels: int = 650                # max concurrent API calls
    n_runs: int = 1                       # independent repetitions per input
    use_dummy: bool = False               # dummy responses instead of real calls
    max_timeout: Optional[float] = None
    additional_instructions: Optional[str] = None
    modality: str = "entity"              # "text"/"entity"/"web"/"image"/"audio"/"pdf"
    n_attributes_per_run: int = 8         # attributes packed into one prompt
    reasoning_effort: Optional[str] = None
    reasoning_summary: Optional[str] = None

    def __post_init__(self) -> None:
        # Normalise blank / whitespace-only instructions down to None so
        # downstream code can test truthiness instead of emptiness.
        raw = self.additional_instructions
        if raw is not None:
            self.additional_instructions = str(raw).strip() or None
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class Extract:
    """Extract attributes from passages using an LLM.

    For each input row, renders a prompt per batch of attributes, collects
    JSON responses via ``get_all_responses``, parses them into per-entity
    attribute maps, aggregates across runs, and writes cleaned CSV output
    under ``cfg.save_dir``.
    """

    def __init__(
        self,
        cfg: ExtractConfig,
        template: Optional[PromptTemplate] = None,
        template_path: Optional[str] = None,
    ) -> None:
        # Expand ~ and env vars in save_dir, create it, and write the
        # normalised path back onto the config (cfg is mutated).
        expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
        expanded.mkdir(parents=True, exist_ok=True)
        cfg.save_dir = str(expanded)
        self.cfg = cfg
        # Falls back to the packaged extraction_prompt.jinja2 when neither
        # template nor template_path is given.
        self.template = resolve_template(
            template=template,
            template_path=template_path,
            reference_filename="extraction_prompt.jinja2",
        )

    async def _parse(
        self, raw: Any, attrs: List[str]
    ) -> List[Tuple[Optional[str], Dict[str, str]]]:
        """Parse one raw LLM response into ``(entity_name, {attr: value})`` pairs.

        Missing or null attribute values become the sentinel string
        ``"unknown"``. Always returns at least one entry; an unparseable
        response yields a single all-unknown entry with no entity name.
        """
        obj = await safest_json(raw)
        attr_names = list(attrs)

        def _default_attr_map() -> Dict[str, str]:
            # Every requested attribute starts as "unknown".
            return {attr: "unknown" for attr in attr_names}

        def _clean_name(name: Any) -> Optional[str]:
            # Coerce any entity-name value to a stripped string, or None
            # when empty/absent.
            if isinstance(name, str):
                cleaned = name.strip()
                return cleaned or None
            if name is None:
                return None
            text = str(name).strip()
            return text or None

        def _build_entry(
            entity_name: Optional[str], payload: Optional[Dict[str, Any]]
        ) -> Tuple[Optional[str], Dict[str, str]]:
            # Fill the default map with stringified values from the payload.
            values = _default_attr_map()
            if isinstance(payload, dict):
                for attr in attr_names:
                    val = payload.get(attr)
                    values[attr] = str(val) if val is not None else "unknown"
            return (_clean_name(entity_name), values)

        entries: List[Tuple[Optional[str], Dict[str, str]]] = []

        if isinstance(obj, dict):
            # A dict response may either be {entity_name: {attr: val}} or a
            # flat {attr: val} map. Keys that collide with attribute names
            # are treated as attributes, not entities.
            attr_keys = set(attr_names)
            nested_candidates = [
                (key, val)
                for key, val in obj.items()
                if isinstance(val, dict)
                and (not attr_keys or key not in attr_keys)
            ]
            if nested_candidates:
                for name, payload in nested_candidates:
                    entries.append(_build_entry(name, payload))
                if entries:
                    return entries
            # Flat map: single anonymous entity.
            entries.append(_build_entry(None, obj))
            return entries

        if isinstance(obj, list):
            # A list response holds one dict per entity; either with an
            # explicit "attributes" sub-dict or with attributes inline.
            for item in obj:
                if not isinstance(item, dict):
                    continue
                attr_payload: Optional[Dict[str, Any]] = None
                if isinstance(item.get("attributes"), dict):
                    attr_payload = item.get("attributes")  # type: ignore[assignment]
                else:
                    attr_payload = item
                name = item.get("entity_name") or item.get("entity") or item.get("name")
                entries.append(_build_entry(name, attr_payload))
            if entries:
                return entries

        # Fallback: unparseable / empty response -> one all-unknown entry.
        return [
            _build_entry(None, None),
        ]

    async def run(
        self,
        df: pd.DataFrame,
        column_name: str,
        *,
        reset_files: bool = False,
        types: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Run extraction over ``df[column_name]`` and return the cleaned frame.

        Parameters
        ----------
        df : input DataFrame; every original column is preserved in output.
        column_name : column holding the passage / entity / media reference.
        reset_files : discard any previously saved responses and attributes.
        types : optional {column: type-name} map; matching output columns are
            coerced ("datetime"/"date" via to_datetime, otherwise numeric,
            with "int"/"int64" rounded to nullable Int64).
        kwargs : forwarded to ``get_all_responses``.
        """
        df_proc = df.reset_index(drop=True).copy()
        input_columns = list(df_proc.columns)
        base_name = os.path.splitext(self.cfg.file_name)[0]
        # Reconcile requested attributes with any persisted from a prior run
        # (so resumed jobs keep a consistent attribute set).
        self.cfg.attributes = load_persisted_attributes(
            save_dir=self.cfg.save_dir,
            incoming=self.cfg.attributes,
            reset_files=reset_files,
            task_name="Extract",
            item_name="attributes",
            legacy_filename=f"{base_name}_attrs.json",
        )
        values = df_proc[column_name].tolist()
        texts = [str(v) for v in values]

        warn_if_modality_mismatch(values, self.cfg.modality, column_name=column_name)

        # Deduplicate inputs: identical passages share one sha1-derived id,
        # so each unique value is only sent to the model once.
        base_ids: List[str] = []
        id_to_rows: DefaultDict[str, List[int]] = defaultdict(list)
        id_to_val: Dict[str, Any] = {}
        prompt_texts: Dict[str, str] = {}
        row_ids: List[str] = []

        for row, (passage, orig) in enumerate(zip(texts, values)):
            sha8 = hashlib.sha1(passage.encode()).hexdigest()[:8]
            row_ids.append(sha8)
            id_to_rows[sha8].append(row)
            if len(id_to_rows[sha8]) > 1:
                continue  # duplicate passage; already scheduled
            id_to_val[sha8] = orig
            # Only text-like modalities embed the passage in the prompt;
            # media modalities attach the content separately below.
            prompt_texts[sha8] = passage if self.cfg.modality in {"text", "entity", "web"} else ""
            base_ids.append(sha8)

        df_proc["_gid"] = row_ids

        if not base_ids:
            # Nothing to do: emit an all-NA cleaned file and return.
            out_path = os.path.join(self.cfg.save_dir, f"{base_name}_cleaned.csv")
            result = df_proc.drop(columns=["_gid"])
            result["entity_name"] = pd.NA
            for attr in self.cfg.attributes.keys():
                result[attr] = pd.NA
            result.to_csv(out_path, index=False)
            return result

        # Split attributes into prompt-sized batches of n_attributes_per_run.
        attr_items = list(self.cfg.attributes.items())
        attr_count = len(attr_items)
        if attr_count > self.cfg.n_attributes_per_run:
            batches = (
                attr_count + self.cfg.n_attributes_per_run - 1
            ) // self.cfg.n_attributes_per_run  # ceiling division
            print(
                f"[Extract] {attr_count} attributes provided. n_attributes_per_run={self.cfg.n_attributes_per_run}. "
                f"Splitting into {batches} prompt batches. Increase n_attributes_per_run if you want all attributes "
                "to be processed in the same prompt."
            )
        attr_batches: List[Dict[str, str]] = [
            dict(attr_items[i : i + self.cfg.n_attributes_per_run])
            for i in range(0, len(attr_items), self.cfg.n_attributes_per_run)
        ]

        # One prompt per (attribute batch x unique input); identifier format
        # "<sha8>_batch<idx>" is relied on when parsing responses below.
        prompts: List[str] = []
        ids: List[str] = []
        announce_prompt_rendering("Extract", len(base_ids) * len(attr_batches))
        for batch_idx, batch_attrs in enumerate(attr_batches):
            for ident in base_ids:
                prompts.append(
                    self.template.render(
                        text=prompt_texts[ident],
                        attributes=batch_attrs,
                        additional_instructions=self.cfg.additional_instructions or "",
                        modality=self.cfg.modality,
                    )
                )
                ids.append(f"{ident}_batch{batch_idx}")

        # For media modalities, load the attachment once per unique input and
        # map it onto every batch identifier for that input.
        prompt_images: Optional[Dict[str, List[str]]] = None
        prompt_audio: Optional[Dict[str, List[Dict[str, str]]]] = None
        prompt_pdfs: Optional[Dict[str, List[Dict[str, str]]]] = None
        if self.cfg.modality == "image":
            tmp: Dict[str, List[str]] = {}
            for ident, rows in id_to_rows.items():
                imgs = load_image_inputs(values[rows[0]])
                if imgs:
                    for batch_idx in range(len(attr_batches)):
                        tmp[f"{ident}_batch{batch_idx}"] = imgs
            prompt_images = tmp or None
        elif self.cfg.modality == "audio":
            tmp_a: Dict[str, List[Dict[str, str]]] = {}
            for ident, rows in id_to_rows.items():
                auds = load_audio_inputs(values[rows[0]])
                if auds:
                    for batch_idx in range(len(attr_batches)):
                        tmp_a[f"{ident}_batch{batch_idx}"] = auds
            prompt_audio = tmp_a or None
        elif self.cfg.modality == "pdf":
            tmp_p: Dict[str, List[Dict[str, str]]] = {}
            for ident, rows in id_to_rows.items():
                pdfs = load_pdf_inputs(values[rows[0]])
                if pdfs:
                    for batch_idx in range(len(attr_batches)):
                        tmp_p[f"{ident}_batch{batch_idx}"] = pdfs
            prompt_pdfs = tmp_p or None

        csv_path = os.path.join(self.cfg.save_dir, f"{base_name}_raw_responses.csv")

        # "web" modality implies web search unless the caller overrode it.
        kwargs.setdefault("web_search", self.cfg.modality == "web")

        if not isinstance(self.cfg.n_runs, int) or self.cfg.n_runs < 1:
            raise ValueError("n_runs must be an integer >= 1")

        # Load identifiers already present in the response CSV so a resumed
        # job can reuse legacy "_run1"-suffixed ids from older versions.
        existing_ids: Set[str] = set()
        if not reset_files and os.path.exists(csv_path):
            try:
                existing_df = pd.read_csv(csv_path, usecols=["Identifier"])
                existing_ids = set(existing_df["Identifier"].astype(str))
            except Exception:
                existing_ids = set()

        # Build per-run identifier lists. Run 1 uses bare ids unless a legacy
        # "<id>_run1" already exists on disk; runs 2+ always get a suffix.
        run_identifier_lists: List[List[str]] = []
        for run_idx in range(1, self.cfg.n_runs + 1):
            run_ids: List[str] = []
            for ident in ids:
                if run_idx == 1:
                    legacy_ident = f"{ident}_run1"
                    run_ids.append(legacy_ident if legacy_ident in existing_ids else ident)
                else:
                    run_ids.append(f"{ident}_run{run_idx}")
            run_identifier_lists.append(run_ids)

        # Flatten all runs into one request batch (same prompts each run).
        prompts_all: List[str] = []
        ids_all: List[str] = []
        for run_ids in run_identifier_lists:
            prompts_all.extend(prompts)
            ids_all.extend(run_ids)

        # Re-key media attachments under the per-run identifiers.
        prompt_images_all: Optional[Dict[str, List[str]]] = None
        if prompt_images:
            prompt_images_all = {}
            for run_ids in run_identifier_lists:
                for base_ident, run_ident in zip(ids, run_ids):
                    imgs = prompt_images.get(base_ident)
                    if imgs:
                        prompt_images_all[run_ident] = imgs
        prompt_audio_all: Optional[Dict[str, List[Dict[str, str]]]] = None
        if prompt_audio:
            prompt_audio_all = {}
            for run_ids in run_identifier_lists:
                for base_ident, run_ident in zip(ids, run_ids):
                    auds = prompt_audio.get(base_ident)
                    if auds:
                        prompt_audio_all[run_ident] = auds
        prompt_pdfs_all: Optional[Dict[str, List[Dict[str, str]]]] = None
        if prompt_pdfs:
            prompt_pdfs_all = {}
            for run_ids in run_identifier_lists:
                for base_ident, run_ident in zip(ids, run_ids):
                    pdfs = prompt_pdfs.get(base_ident)
                    if pdfs:
                        prompt_pdfs_all[run_ident] = pdfs

        df_resp_all = await get_all_responses(
            prompts=prompts_all,
            identifiers=ids_all,
            prompt_images=prompt_images_all,
            prompt_audio=prompt_audio_all,
            prompt_pdfs=prompt_pdfs_all,
            n_parallels=self.cfg.n_parallels,
            model=self.cfg.model,
            save_path=csv_path,
            use_dummy=self.cfg.use_dummy,
            max_timeout=self.cfg.max_timeout,
            # NOTE(review): JSON mode is disabled for audio — presumably the
            # audio-capable models don't support it; confirm upstream.
            json_mode=self.cfg.modality != "audio",
            reset_files=reset_files,
            reasoning_effort=self.cfg.reasoning_effort,
            reasoning_summary=self.cfg.reasoning_summary,
            **kwargs,
        )
        if not isinstance(df_resp_all, pd.DataFrame):
            raise RuntimeError("get_all_responses returned no DataFrame")

        # Split responses back out per run and strip the run suffix so every
        # sub-frame is keyed by the bare "<sha8>_batch<idx>" identifier.
        df_resps = []
        for run_idx, run_ids in enumerate(run_identifier_lists, start=1):
            suffix = f"_run{run_idx}"
            sub = df_resp_all[df_resp_all.Identifier.isin(run_ids)].copy()
            sub.Identifier = sub.Identifier.str.replace(suffix + "$", "", regex=True)
            df_resps.append(sub)

        # Merge batch-level parses into one attribute map per (input, entity).
        full_records: List[Dict[str, Any]] = []
        base_attrs = list(self.cfg.attributes.keys())
        for run_idx, df_resp in enumerate(df_resps, start=1):
            id_to_entity_vals: Dict[str, Dict[Optional[str], Dict[str, str]]] = {
                ident: {} for ident in base_ids
            }
            for ident_batch, raw in zip(df_resp.Identifier, df_resp.Response):
                if "_batch" not in ident_batch:
                    continue
                base_ident, batch_part = ident_batch.rsplit("_batch", 1)
                batch_idx = int(batch_part)
                attrs = list(attr_batches[batch_idx].keys())
                parsed_entities = await self._parse(raw, attrs)
                entity_store = id_to_entity_vals.setdefault(base_ident, {})
                for entity_name, entity_attrs in parsed_entities:
                    key = entity_name if entity_name is not None else None
                    if key not in entity_store:
                        entity_store[key] = {attr: "unknown" for attr in base_attrs}
                    # Only overwrite the attributes this batch asked about.
                    for attr in attrs:
                        entity_store[key][attr] = entity_attrs.get(attr, "unknown")
            for ident in base_ids:
                # Inputs with no parsed entities get one all-unknown record.
                entity_map = id_to_entity_vals.get(ident) or {
                    None: {attr: "unknown" for attr in base_attrs}
                }
                for entity_name, attr_values in entity_map.items():
                    rec: Dict[str, Any] = {
                        "id": ident,
                        "entity_name": entity_name,
                        "text": id_to_val[ident],
                        "run": run_idx,
                    }
                    rec.update({attr: attr_values.get(attr, "unknown") for attr in base_attrs})
                    full_records.append(rec)

        full_df = pd.DataFrame(full_records).set_index(["id", "entity_name", "run"])
        if self.cfg.n_runs > 1:
            # Keep per-run values on disk before aggregating across runs.
            disagg_path = os.path.join(
                self.cfg.save_dir, f"{base_name}_full_disaggregated.csv"
            )
            full_df.to_csv(disagg_path, index_label=["id", "entity_name", "run"])

        def _pick_first(s: pd.Series) -> str:
            # Across runs, take the first non-"unknown" value (first-run
            # priority), else "unknown".
            for val in s.dropna():
                if str(val).strip().lower() != "unknown":
                    return str(val)
            return "unknown"

        entity_index = full_df.index.droplevel("run").unique()
        if base_attrs:
            agg_series = {
                attr: full_df[attr]
                .groupby(level=["id", "entity_name"], sort=False)
                .apply(_pick_first)
                for attr in base_attrs
            }
            agg_df = pd.DataFrame(agg_series)
        else:
            agg_df = pd.DataFrame(index=entity_index)
        # Reindex to keep every (id, entity) pair even if groupby dropped some.
        agg_df = agg_df.reindex(entity_index)

        unknown_counts = {attr: (agg_df[attr] == "unknown").sum() for attr in base_attrs}

        # Join aggregated attributes back onto the original rows; inputs that
        # produced multiple entities fan out into multiple rows here.
        out_path = os.path.join(self.cfg.save_dir, f"{base_name}_cleaned.csv")
        agg_reset = agg_df.reset_index()
        result = df_proc.merge(agg_reset, left_on="_gid", right_on="id", how="left")
        drop_cols = [col for col in ("_gid", "id") if col in result.columns]
        if drop_cols:
            result = result.drop(columns=drop_cols)

        # Column order: original columns, then entity_name, then attributes,
        # then anything else.
        original_cols = [col for col in df_proc.columns if col != "_gid"]
        final_order: List[str] = []
        for col in original_cols:
            if col in result.columns:
                final_order.append(col)
        if "entity_name" in result.columns and "entity_name" not in final_order:
            final_order.append("entity_name")
        for attr in base_attrs:
            if attr in result.columns:
                final_order.append(attr)
        remaining = [col for col in result.columns if col not in final_order]
        if remaining:
            final_order.extend(remaining)
        result = result[final_order]

        # The CSV keeps the "unknown" sentinel; the returned frame uses NA.
        result.to_csv(out_path, index=False)

        result = result.replace("unknown", pd.NA)

        duplicate_rows = len(result) - len(df_proc)

        if types:
            # Optional dtype coercion; failures become NA and are counted.
            coerced = result.copy()
            fail_logs: Dict[str, int] = {}
            for col, typ in types.items():
                if col not in coerced:
                    continue
                orig = coerced[col]
                non_null = orig.notna()
                target = str(typ).lower()
                if target in {"datetime", "date"}:
                    conv = pd.to_datetime(orig, errors="coerce")
                else:
                    conv = pd.to_numeric(orig, errors="coerce")
                    if target in {"int", "int64"}:
                        conv = conv.round().astype("Int64")
                coerced[col] = conv
                fail_logs[col] = int((non_null & conv.isna()).sum())
            coerced_path = os.path.join(self.cfg.save_dir, f"{base_name}_cleaned_coerced.csv")
            coerced.to_csv(coerced_path, index=False)
            for col, n_fail in fail_logs.items():
                print(f"[Extract] Failed to coerce {n_fail} values in column '{col}'.")
            result = coerced

        # Coverage summary per attribute.
        total = len(agg_df)
        print("\n=== Extraction coverage ===")
        for attr in base_attrs:
            known = total - unknown_counts[attr]
            print(f"{attr:<55s}: {known:5d} extracted, {unknown_counts[attr]:5d} unknown")
        print("============================\n")

        if duplicate_rows > 0:
            subset_hint = ", ".join(f"'{col}'" for col in input_columns)
            print(
                "[Extract] Multiple entity names were returned for at least one input.\n"
                f"    {duplicate_rows} additional row(s) were created so the cleaned file now has {len(result)} rows versus {len(df_proc)} inputs.\n"
                "    If you only need one row per original input, deduplicate on your source columns\n"
                f"    (e.g. `result = result.drop_duplicates(subset=[{subset_hint}], keep='first')`).\n"
            )

        return result
|
gabriel/tasks/filter.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import random
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Dict, List, Optional, Set
|
|
8
|
+
|
|
9
|
+
import pandas as pd
|
|
10
|
+
|
|
11
|
+
from ..core.prompt_template import PromptTemplate, resolve_template
|
|
12
|
+
from ..utils.openai_utils import get_all_responses
|
|
13
|
+
from ..utils.logging import announce_prompt_rendering
|
|
14
|
+
from ..utils import safest_json
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class FilterConfig:
    """Configuration for :class:`Filter`."""

    condition: str                         # natural-language condition to test
    save_dir: str
    file_name: str = "filter_responses.csv"
    model: str = "gpt-5-nano"
    n_parallels: int = 650                 # max concurrent API calls
    entities_per_call: int = 150           # chunk size per prompt
    shuffle: bool = True                   # reshuffle entities each run
    random_seed: int = 42
    n_runs: int = 1
    threshold: float = 0.5                 # fraction of runs needed to pass
    additional_instructions: Optional[str] = None
    use_dummy: bool = False
    max_timeout: Optional[float] = None
    fix_json_with_llm: bool = False
    json_fix_timeout: Optional[float] = 60.0

    def __post_init__(self) -> None:
        # Collapse blank / whitespace-only instructions to None so callers
        # can rely on simple truthiness checks.
        extra = self.additional_instructions
        if extra is not None:
            self.additional_instructions = str(extra).strip() or None
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class Filter:
    """Filter entities in a DataFrame column based on a condition using an LLM.

    Unique entities are batched into prompts asking which ones meet
    ``cfg.condition``; membership is voted across ``cfg.n_runs`` shuffled
    repetitions and thresholded into a boolean ``meets_condition`` column.
    """

    def __init__(
        self,
        cfg: FilterConfig,
        template: Optional[PromptTemplate] = None,
        template_path: Optional[str] = None,
    ) -> None:
        # Expand ~ / env vars, create the output directory, and store the
        # normalised path back on the config (cfg is mutated).
        expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
        expanded.mkdir(parents=True, exist_ok=True)
        cfg.save_dir = str(expanded)
        self.cfg = cfg
        # Defaults to the packaged filter_prompt.jinja2 template.
        self.template = resolve_template(
            template=template,
            template_path=template_path,
            reference_filename="filter_prompt.jinja2",
        )

    # ------------------------------------------------------------------
    async def run(
        self,
        df: pd.DataFrame,
        column_name: str,
        *,
        reset_files: bool = False,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Annotate ``df`` with per-run and aggregate condition columns.

        Returns a copy of ``df`` with one boolean ``meets_condition_run_N``
        column per run plus the thresholded ``meets_condition`` column.
        Extra ``kwargs`` are forwarded to ``get_all_responses``.
        """
        df_proc = df.reset_index(drop=True).copy()
        raw_entities = [str(x) for x in df_proc[column_name].dropna()]

        # unique while preserving order
        seen: Set[str] = set()
        entities: List[str] = []
        for ent in raw_entities:
            if ent not in seen:
                seen.add(ent)
                entities.append(ent)

        # Build one prompt per chunk per run; identifier format
        # "filter_<run:03d>_<chunk:05d>" is parsed when reading results back.
        prompts: List[str] = []
        identifiers: List[str] = []
        total_chunks = 0
        for run in range(self.cfg.n_runs):
            ents = list(entities)
            if self.cfg.shuffle:
                # Deterministic per-run shuffle so results are reproducible.
                rnd = random.Random(self.cfg.random_seed + run)
                rnd.shuffle(ents)
            chunks = [
                ents[i : i + self.cfg.entities_per_call]
                for i in range(0, len(ents), self.cfg.entities_per_call)
            ]
            total_chunks += len(chunks)
            for idx, chunk in enumerate(chunks):
                prompts.append(
                    self.template.render(
                        condition=self.cfg.condition,
                        entities=chunk,
                        additional_instructions=self.cfg.additional_instructions,
                    )
                )
                identifiers.append(f"filter_{run:03d}_{idx:05d}")

        announce_prompt_rendering("Filter", total_chunks)

        save_path = os.path.join(self.cfg.save_dir, self.cfg.file_name)
        if prompts:
            resp_df = await get_all_responses(
                prompts=prompts,
                identifiers=identifiers,
                n_parallels=self.cfg.n_parallels,
                model=self.cfg.model,
                save_path=save_path,
                use_dummy=self.cfg.use_dummy,
                max_timeout=self.cfg.max_timeout,
                json_mode=True,
                reset_files=reset_files,
                **kwargs,
            )
        else:
            # No entities at all -> empty response frame, all-False output.
            resp_df = pd.DataFrame(columns=["Identifier", "Response"])

        resp_map: Dict[str, Any] = dict(
            zip(resp_df.get("Identifier", []), resp_df.get("Response", []))
        )

        # Collect, per run, the set of entity strings the model said meet
        # the condition.
        meets_by_run: List[Set[str]] = [set() for _ in range(self.cfg.n_runs)]
        for ident, raw in resp_map.items():
            parts = ident.split("_")
            if len(parts) < 3:
                continue  # not one of our "filter_<run>_<chunk>" ids
            try:
                run_idx = int(parts[1])
            except ValueError:
                continue
            parsed = await safest_json(
                raw,
                model=self.cfg.model if self.cfg.fix_json_with_llm else None,
                use_llm_fallback=self.cfg.fix_json_with_llm,
                llm_timeout=self.cfg.json_fix_timeout,
            )
            # Accept either {"entities meeting condition": [...]} (with the
            # underscored key as fallback) or a bare list of strings.
            ent_list: Optional[List[str]] = None
            if isinstance(parsed, dict):
                val = parsed.get("entities meeting condition") or parsed.get(
                    "entities_meeting_condition"
                )
                if isinstance(val, list):
                    ent_list = [str(v) for v in val if isinstance(v, str)]
            elif isinstance(parsed, list):
                ent_list = [str(v) for v in parsed if isinstance(v, str)]
            if ent_list and 0 <= run_idx < self.cfg.n_runs:
                for ent in ent_list:
                    meets_by_run[run_idx].add(ent.strip())

        # Per-run boolean columns via case-insensitive membership.
        # NOTE(review): model-returned names are strip()ed above but the
        # DataFrame values are only lowercased here, so values with leading/
        # trailing whitespace can never match — confirm this is intended.
        run_cols: List[str] = []
        for run_idx in range(self.cfg.n_runs):
            meets_norm = {m.lower() for m in meets_by_run[run_idx]}
            col = f"meets_condition_run_{run_idx + 1}"
            run_cols.append(col)
            df_proc[col] = [
                str(v).lower() in meets_norm if not pd.isna(v) else False
                for v in df_proc[column_name]
            ]

        # Majority-style vote: fraction of runs that flagged the row must
        # reach the configured threshold.
        df_proc["meets_condition"] = (
            df_proc[run_cols].sum(axis=1) / self.cfg.n_runs >= self.cfg.threshold
        )
        return df_proc
|