openai-gabriel 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. gabriel/__init__.py +61 -0
  2. gabriel/_version.py +1 -0
  3. gabriel/api.py +2284 -0
  4. gabriel/cli/__main__.py +60 -0
  5. gabriel/core/__init__.py +7 -0
  6. gabriel/core/llm_client.py +34 -0
  7. gabriel/core/pipeline.py +18 -0
  8. gabriel/core/prompt_template.py +152 -0
  9. gabriel/prompts/__init__.py +1 -0
  10. gabriel/prompts/bucket_prompt.jinja2 +113 -0
  11. gabriel/prompts/classification_prompt.jinja2 +50 -0
  12. gabriel/prompts/codify_prompt.jinja2 +95 -0
  13. gabriel/prompts/comparison_prompt.jinja2 +60 -0
  14. gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
  15. gabriel/prompts/deidentification_prompt.jinja2 +112 -0
  16. gabriel/prompts/extraction_prompt.jinja2 +61 -0
  17. gabriel/prompts/filter_prompt.jinja2 +31 -0
  18. gabriel/prompts/ideation_prompt.jinja2 +80 -0
  19. gabriel/prompts/merge_prompt.jinja2 +47 -0
  20. gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
  21. gabriel/prompts/rankings_prompt.jinja2 +49 -0
  22. gabriel/prompts/ratings_prompt.jinja2 +50 -0
  23. gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
  24. gabriel/prompts/seed.jinja2 +43 -0
  25. gabriel/prompts/snippets.jinja2 +117 -0
  26. gabriel/tasks/__init__.py +63 -0
  27. gabriel/tasks/_attribute_utils.py +69 -0
  28. gabriel/tasks/bucket.py +432 -0
  29. gabriel/tasks/classify.py +562 -0
  30. gabriel/tasks/codify.py +1033 -0
  31. gabriel/tasks/compare.py +235 -0
  32. gabriel/tasks/debias.py +1460 -0
  33. gabriel/tasks/deduplicate.py +341 -0
  34. gabriel/tasks/deidentify.py +316 -0
  35. gabriel/tasks/discover.py +524 -0
  36. gabriel/tasks/extract.py +455 -0
  37. gabriel/tasks/filter.py +169 -0
  38. gabriel/tasks/ideate.py +782 -0
  39. gabriel/tasks/merge.py +464 -0
  40. gabriel/tasks/paraphrase.py +531 -0
  41. gabriel/tasks/rank.py +2041 -0
  42. gabriel/tasks/rate.py +347 -0
  43. gabriel/tasks/seed.py +465 -0
  44. gabriel/tasks/whatever.py +344 -0
  45. gabriel/utils/__init__.py +64 -0
  46. gabriel/utils/audio_utils.py +42 -0
  47. gabriel/utils/file_utils.py +464 -0
  48. gabriel/utils/image_utils.py +22 -0
  49. gabriel/utils/jinja.py +31 -0
  50. gabriel/utils/logging.py +86 -0
  51. gabriel/utils/mapmaker.py +304 -0
  52. gabriel/utils/media_utils.py +78 -0
  53. gabriel/utils/modality_utils.py +148 -0
  54. gabriel/utils/openai_utils.py +5470 -0
  55. gabriel/utils/parsing.py +282 -0
  56. gabriel/utils/passage_viewer.py +2557 -0
  57. gabriel/utils/pdf_utils.py +20 -0
  58. gabriel/utils/plot_utils.py +2881 -0
  59. gabriel/utils/prompt_utils.py +42 -0
  60. gabriel/utils/word_matching.py +158 -0
  61. openai_gabriel-1.0.1.dist-info/METADATA +443 -0
  62. openai_gabriel-1.0.1.dist-info/RECORD +67 -0
  63. openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
  64. openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
  65. openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
  66. openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
  67. openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,455 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import asyncio
5
+ import os
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from collections import defaultdict
9
+ from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple
10
+
11
+ import pandas as pd
12
+
13
+ from ..core.prompt_template import PromptTemplate, resolve_template
14
+ from ..utils.openai_utils import get_all_responses
15
+ from ..utils import (
16
+ safest_json,
17
+ load_image_inputs,
18
+ load_audio_inputs,
19
+ load_pdf_inputs,
20
+ warn_if_modality_mismatch,
21
+ )
22
+ from ..utils.logging import announce_prompt_rendering
23
+ from ._attribute_utils import load_persisted_attributes
24
+
25
+
26
@dataclass
class ExtractConfig:
    """Configuration for :class:`Extract`.

    Holds the attribute schema, model/runtime settings, and output locations
    used by an extraction run.  ``additional_instructions`` is normalized
    after construction: surrounding whitespace is stripped and an empty or
    whitespace-only value collapses to ``None``.
    """

    # Mapping of attribute name -> natural-language description to extract.
    attributes: Dict[str, str]
    # Directory where raw and cleaned CSVs are written.
    save_dir: str = "extraction"
    # Base file name for the output CSVs.
    file_name: str = "extraction.csv"
    # Model identifier passed through to the response collector.
    model: str = "gpt-5-mini"
    # Maximum number of concurrent API calls.
    n_parallels: int = 650
    # Number of independent passes over the data (>= 1).
    n_runs: int = 1
    # When True, no real API calls are made.
    use_dummy: bool = False
    # Optional per-call timeout in seconds.
    max_timeout: Optional[float] = None
    # Free-form extra prompt instructions; normalized in __post_init__.
    additional_instructions: Optional[str] = None
    # Input modality: "entity", "text", "web", "image", "audio", or "pdf".
    modality: str = "entity"
    # How many attributes are packed into a single prompt.
    n_attributes_per_run: int = 8
    # Optional reasoning knobs forwarded to the response collector.
    reasoning_effort: Optional[str] = None
    reasoning_summary: Optional[str] = None

    def __post_init__(self) -> None:
        """Collapse blank ``additional_instructions`` to ``None``."""
        instructions = self.additional_instructions
        if instructions is None:
            return
        stripped = str(instructions).strip()
        self.additional_instructions = stripped if stripped else None
46
+
47
+
48
class Extract:
    """Extract attributes from passages using an LLM.

    One prompt is rendered per (unique input, attribute batch) pair; responses
    are gathered via ``get_all_responses``, parsed into per-entity attribute
    maps, aggregated across runs, and merged back onto the caller's DataFrame.
    Raw and cleaned CSVs are written under ``cfg.save_dir``.
    """

    def __init__(
        self,
        cfg: ExtractConfig,
        template: Optional[PromptTemplate] = None,
        template_path: Optional[str] = None,
    ) -> None:
        """Prepare the save directory and resolve the prompt template.

        ``cfg.save_dir`` is expanded (``~`` and env vars), created if missing,
        and written back as a plain string.  If neither ``template`` nor
        ``template_path`` is given, the packaged ``extraction_prompt.jinja2``
        reference template is used.
        """
        # Expand "~" and $VARS so user-supplied paths become usable.
        expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
        expanded.mkdir(parents=True, exist_ok=True)
        cfg.save_dir = str(expanded)
        self.cfg = cfg
        self.template = resolve_template(
            template=template,
            template_path=template_path,
            reference_filename="extraction_prompt.jinja2",
        )

    async def _parse(
        self, raw: Any, attrs: List[str]
    ) -> List[Tuple[Optional[str], Dict[str, str]]]:
        """Parse one raw LLM response into ``(entity_name, attr_map)`` entries.

        ``raw`` is decoded with ``safest_json`` and may be a dict (either a
        flat attribute map or a mapping of entity name -> attribute dict) or a
        list of per-entity dicts.  Every returned attribute map contains every
        name in ``attrs``; missing values default to the string ``"unknown"``.
        An unparseable response yields a single ``(None, all-unknown)`` entry.
        """
        obj = await safest_json(raw)
        attr_names = list(attrs)

        def _default_attr_map() -> Dict[str, str]:
            # Baseline: every requested attribute starts as "unknown".
            return {attr: "unknown" for attr in attr_names}

        def _clean_name(name: Any) -> Optional[str]:
            # Normalize an entity name: strip whitespace, empty -> None.
            if isinstance(name, str):
                cleaned = name.strip()
                return cleaned or None
            if name is None:
                return None
            text = str(name).strip()
            return text or None

        def _build_entry(
            entity_name: Optional[str], payload: Optional[Dict[str, Any]]
        ) -> Tuple[Optional[str], Dict[str, str]]:
            # Fill the default map with any values present in the payload;
            # non-dict payloads leave everything "unknown".
            values = _default_attr_map()
            if isinstance(payload, dict):
                for attr in attr_names:
                    val = payload.get(attr)
                    values[attr] = str(val) if val is not None else "unknown"
            return (_clean_name(entity_name), values)

        entries: List[Tuple[Optional[str], Dict[str, str]]] = []

        if isinstance(obj, dict):
            attr_keys = set(attr_names)
            # A nested dict value whose key is NOT an attribute name is
            # treated as an entity name -> attributes mapping.
            nested_candidates = [
                (key, val)
                for key, val in obj.items()
                if isinstance(val, dict)
                and (not attr_keys or key not in attr_keys)
            ]
            if nested_candidates:
                for name, payload in nested_candidates:
                    entries.append(_build_entry(name, payload))
                if entries:
                    return entries
            # Otherwise the whole dict is one anonymous entity's attributes.
            entries.append(_build_entry(None, obj))
            return entries

        if isinstance(obj, list):
            for item in obj:
                if not isinstance(item, dict):
                    continue
                # Prefer an explicit "attributes" sub-dict; else the item
                # itself holds the attribute values.
                attr_payload: Optional[Dict[str, Any]] = None
                if isinstance(item.get("attributes"), dict):
                    attr_payload = item.get("attributes")  # type: ignore[assignment]
                else:
                    attr_payload = item
                # Accept any of the common entity-name keys.
                name = item.get("entity_name") or item.get("entity") or item.get("name")
                entries.append(_build_entry(name, attr_payload))
            if entries:
                return entries

        # Fallback for None / scalar / empty responses: one all-unknown row.
        return [
            _build_entry(None, None),
        ]

    async def run(
        self,
        df: pd.DataFrame,
        column_name: str,
        *,
        reset_files: bool = False,
        types: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Extract ``cfg.attributes`` from ``df[column_name]``.

        Parameters
        ----------
        df:
            Input frame; its index is reset and all original columns are kept.
        column_name:
            Column holding the passages (or media references, per modality).
        reset_files:
            When True, ignore previously saved responses/attributes.
        types:
            Optional ``{column: type}`` map; listed output columns are coerced
            to numeric or datetime after extraction ("int"/"int64" round to
            nullable Int64).
        kwargs:
            Forwarded to ``get_all_responses``.

        Returns
        -------
        pd.DataFrame
            Input columns plus ``entity_name`` and one column per attribute;
            ``"unknown"`` values are replaced with ``pd.NA`` in the returned
            frame (the cleaned CSV on disk keeps the literal "unknown").
            Multiple extracted entities for one input produce multiple rows.
        """
        df_proc = df.reset_index(drop=True).copy()
        input_columns = list(df_proc.columns)
        base_name = os.path.splitext(self.cfg.file_name)[0]
        # Reconcile the requested attributes with any attribute set persisted
        # by a previous run in save_dir (so resumed runs stay consistent).
        self.cfg.attributes = load_persisted_attributes(
            save_dir=self.cfg.save_dir,
            incoming=self.cfg.attributes,
            reset_files=reset_files,
            task_name="Extract",
            item_name="attributes",
            legacy_filename=f"{base_name}_attrs.json",
        )
        values = df_proc[column_name].tolist()
        texts = [str(v) for v in values]

        warn_if_modality_mismatch(values, self.cfg.modality, column_name=column_name)

        # ---- deduplicate inputs -------------------------------------------
        # Each distinct passage gets an 8-hex-char SHA-1 prefix id; duplicate
        # rows share the id so each unique passage is prompted only once.
        base_ids: List[str] = []
        id_to_rows: DefaultDict[str, List[int]] = defaultdict(list)
        id_to_val: Dict[str, Any] = {}
        prompt_texts: Dict[str, str] = {}
        row_ids: List[str] = []

        for row, (passage, orig) in enumerate(zip(texts, values)):
            sha8 = hashlib.sha1(passage.encode()).hexdigest()[:8]
            row_ids.append(sha8)
            id_to_rows[sha8].append(row)
            if len(id_to_rows[sha8]) > 1:
                # Already seen this passage; only record the row mapping.
                continue
            id_to_val[sha8] = orig
            # Only textual modalities embed the passage in the prompt; media
            # modalities attach files separately and get an empty text.
            prompt_texts[sha8] = passage if self.cfg.modality in {"text", "entity", "web"} else ""
            base_ids.append(sha8)

        df_proc["_gid"] = row_ids

        if not base_ids:
            # No inputs at all: emit an all-NA result with the right columns.
            out_path = os.path.join(self.cfg.save_dir, f"{base_name}_cleaned.csv")
            result = df_proc.drop(columns=["_gid"])
            result["entity_name"] = pd.NA
            for attr in self.cfg.attributes.keys():
                result[attr] = pd.NA
            result.to_csv(out_path, index=False)
            return result

        # ---- split attributes into prompt batches -------------------------
        attr_items = list(self.cfg.attributes.items())
        attr_count = len(attr_items)
        if attr_count > self.cfg.n_attributes_per_run:
            # Ceiling division: number of batches needed.
            batches = (
                attr_count + self.cfg.n_attributes_per_run - 1
            ) // self.cfg.n_attributes_per_run
            print(
                f"[Extract] {attr_count} attributes provided. n_attributes_per_run={self.cfg.n_attributes_per_run}. "
                f"Splitting into {batches} prompt batches. Increase n_attributes_per_run if you want all attributes "
                "to be processed in the same prompt."
            )
        attr_batches: List[Dict[str, str]] = [
            dict(attr_items[i : i + self.cfg.n_attributes_per_run])
            for i in range(0, len(attr_items), self.cfg.n_attributes_per_run)
        ]

        # ---- render prompts: one per (unique input, attribute batch) ------
        prompts: List[str] = []
        ids: List[str] = []
        announce_prompt_rendering("Extract", len(base_ids) * len(attr_batches))
        for batch_idx, batch_attrs in enumerate(attr_batches):
            for ident in base_ids:
                prompts.append(
                    self.template.render(
                        text=prompt_texts[ident],
                        attributes=batch_attrs,
                        additional_instructions=self.cfg.additional_instructions or "",
                        modality=self.cfg.modality,
                    )
                )
                ids.append(f"{ident}_batch{batch_idx}")

        # ---- attach media inputs for non-text modalities ------------------
        # Each unique input's media is loaded once (from its first row) and
        # mapped to every batch identifier derived from it.
        prompt_images: Optional[Dict[str, List[str]]] = None
        prompt_audio: Optional[Dict[str, List[Dict[str, str]]]] = None
        prompt_pdfs: Optional[Dict[str, List[Dict[str, str]]]] = None
        if self.cfg.modality == "image":
            tmp: Dict[str, List[str]] = {}
            for ident, rows in id_to_rows.items():
                imgs = load_image_inputs(values[rows[0]])
                if imgs:
                    for batch_idx in range(len(attr_batches)):
                        tmp[f"{ident}_batch{batch_idx}"] = imgs
            prompt_images = tmp or None
        elif self.cfg.modality == "audio":
            tmp_a: Dict[str, List[Dict[str, str]]] = {}
            for ident, rows in id_to_rows.items():
                auds = load_audio_inputs(values[rows[0]])
                if auds:
                    for batch_idx in range(len(attr_batches)):
                        tmp_a[f"{ident}_batch{batch_idx}"] = auds
            prompt_audio = tmp_a or None
        elif self.cfg.modality == "pdf":
            tmp_p: Dict[str, List[Dict[str, str]]] = {}
            for ident, rows in id_to_rows.items():
                pdfs = load_pdf_inputs(values[rows[0]])
                if pdfs:
                    for batch_idx in range(len(attr_batches)):
                        tmp_p[f"{ident}_batch{batch_idx}"] = pdfs
            prompt_pdfs = tmp_p or None

        csv_path = os.path.join(self.cfg.save_dir, f"{base_name}_raw_responses.csv")

        # "web" modality enables web search unless the caller overrides it.
        kwargs.setdefault("web_search", self.cfg.modality == "web")

        if not isinstance(self.cfg.n_runs, int) or self.cfg.n_runs < 1:
            raise ValueError("n_runs must be an integer >= 1")

        # ---- build per-run identifier lists -------------------------------
        # Read identifiers already saved to csv_path so resumed runs reuse
        # them instead of re-querying.
        existing_ids: Set[str] = set()
        if not reset_files and os.path.exists(csv_path):
            try:
                existing_df = pd.read_csv(csv_path, usecols=["Identifier"])
                existing_ids = set(existing_df["Identifier"].astype(str))
            except Exception:
                existing_ids = set()

        run_identifier_lists: List[List[str]] = []
        for run_idx in range(1, self.cfg.n_runs + 1):
            run_ids: List[str] = []
            for ident in ids:
                if run_idx == 1:
                    # Older versions suffixed the first run with "_run1";
                    # keep that identifier if it already exists on disk.
                    legacy_ident = f"{ident}_run1"
                    run_ids.append(legacy_ident if legacy_ident in existing_ids else ident)
                else:
                    run_ids.append(f"{ident}_run{run_idx}")
            run_identifier_lists.append(run_ids)

        # The same prompt list is repeated for every run; only identifiers
        # differ between runs.
        prompts_all: List[str] = []
        ids_all: List[str] = []
        for run_ids in run_identifier_lists:
            prompts_all.extend(prompts)
            ids_all.extend(run_ids)

        # Re-key media maps from base identifiers to per-run identifiers.
        prompt_images_all: Optional[Dict[str, List[str]]] = None
        if prompt_images:
            prompt_images_all = {}
            for run_ids in run_identifier_lists:
                for base_ident, run_ident in zip(ids, run_ids):
                    imgs = prompt_images.get(base_ident)
                    if imgs:
                        prompt_images_all[run_ident] = imgs
        prompt_audio_all: Optional[Dict[str, List[Dict[str, str]]]] = None
        if prompt_audio:
            prompt_audio_all = {}
            for run_ids in run_identifier_lists:
                for base_ident, run_ident in zip(ids, run_ids):
                    auds = prompt_audio.get(base_ident)
                    if auds:
                        prompt_audio_all[run_ident] = auds
        prompt_pdfs_all: Optional[Dict[str, List[Dict[str, str]]]] = None
        if prompt_pdfs:
            prompt_pdfs_all = {}
            for run_ids in run_identifier_lists:
                for base_ident, run_ident in zip(ids, run_ids):
                    pdfs = prompt_pdfs.get(base_ident)
                    if pdfs:
                        prompt_pdfs_all[run_ident] = pdfs

        # ---- collect responses --------------------------------------------
        df_resp_all = await get_all_responses(
            prompts=prompts_all,
            identifiers=ids_all,
            prompt_images=prompt_images_all,
            prompt_audio=prompt_audio_all,
            prompt_pdfs=prompt_pdfs_all,
            n_parallels=self.cfg.n_parallels,
            model=self.cfg.model,
            save_path=csv_path,
            use_dummy=self.cfg.use_dummy,
            max_timeout=self.cfg.max_timeout,
            # JSON mode is requested for every modality except audio.
            json_mode=self.cfg.modality != "audio",
            reset_files=reset_files,
            reasoning_effort=self.cfg.reasoning_effort,
            reasoning_summary=self.cfg.reasoning_summary,
            **kwargs,
        )
        if not isinstance(df_resp_all, pd.DataFrame):
            raise RuntimeError("get_all_responses returned no DataFrame")

        # Split the combined response frame back into one frame per run and
        # strip the "_runN" suffix so identifiers match the base ids.
        df_resps = []
        for run_idx, run_ids in enumerate(run_identifier_lists, start=1):
            suffix = f"_run{run_idx}"
            sub = df_resp_all[df_resp_all.Identifier.isin(run_ids)].copy()
            sub.Identifier = sub.Identifier.str.replace(suffix + "$", "", regex=True)
            df_resps.append(sub)

        # ---- parse responses into per-entity records ----------------------
        full_records: List[Dict[str, Any]] = []
        base_attrs = list(self.cfg.attributes.keys())
        for run_idx, df_resp in enumerate(df_resps, start=1):
            # input id -> {entity name (or None) -> {attr -> value}}
            id_to_entity_vals: Dict[str, Dict[Optional[str], Dict[str, str]]] = {
                ident: {} for ident in base_ids
            }
            for ident_batch, raw in zip(df_resp.Identifier, df_resp.Response):
                if "_batch" not in ident_batch:
                    continue
                base_ident, batch_part = ident_batch.rsplit("_batch", 1)
                batch_idx = int(batch_part)
                # Only the attributes in this batch are updated from this
                # response; other batches fill in the rest.
                attrs = list(attr_batches[batch_idx].keys())
                parsed_entities = await self._parse(raw, attrs)
                entity_store = id_to_entity_vals.setdefault(base_ident, {})
                for entity_name, entity_attrs in parsed_entities:
                    key = entity_name if entity_name is not None else None
                    if key not in entity_store:
                        entity_store[key] = {attr: "unknown" for attr in base_attrs}
                    for attr in attrs:
                        entity_store[key][attr] = entity_attrs.get(attr, "unknown")
            for ident in base_ids:
                # Inputs with no parsed entities still get one all-unknown row.
                entity_map = id_to_entity_vals.get(ident) or {
                    None: {attr: "unknown" for attr in base_attrs}
                }
                for entity_name, attr_values in entity_map.items():
                    rec: Dict[str, Any] = {
                        "id": ident,
                        "entity_name": entity_name,
                        "text": id_to_val[ident],
                        "run": run_idx,
                    }
                    rec.update({attr: attr_values.get(attr, "unknown") for attr in base_attrs})
                    full_records.append(rec)

        # ---- aggregate across runs ----------------------------------------
        full_df = pd.DataFrame(full_records).set_index(["id", "entity_name", "run"])
        if self.cfg.n_runs > 1:
            # Keep the per-run values on disk before collapsing them.
            disagg_path = os.path.join(
                self.cfg.save_dir, f"{base_name}_full_disaggregated.csv"
            )
            full_df.to_csv(disagg_path, index_label=["id", "entity_name", "run"])

        def _pick_first(s: pd.Series) -> str:
            # First non-null value across runs that isn't "unknown".
            for val in s.dropna():
                if str(val).strip().lower() != "unknown":
                    return str(val)
            return "unknown"

        entity_index = full_df.index.droplevel("run").unique()
        if base_attrs:
            agg_series = {
                attr: full_df[attr]
                .groupby(level=["id", "entity_name"], sort=False)
                .apply(_pick_first)
                for attr in base_attrs
            }
            agg_df = pd.DataFrame(agg_series)
        else:
            agg_df = pd.DataFrame(index=entity_index)
        # Reindex to restore original (id, entity_name) order.
        agg_df = agg_df.reindex(entity_index)

        unknown_counts = {attr: (agg_df[attr] == "unknown").sum() for attr in base_attrs}

        # ---- merge back onto the input frame ------------------------------
        out_path = os.path.join(self.cfg.save_dir, f"{base_name}_cleaned.csv")
        agg_reset = agg_df.reset_index()
        # Left-merge on the passage hash; multi-entity inputs fan out rows.
        result = df_proc.merge(agg_reset, left_on="_gid", right_on="id", how="left")
        drop_cols = [col for col in ("_gid", "id") if col in result.columns]
        if drop_cols:
            result = result.drop(columns=drop_cols)

        # Column order: original input columns, entity_name, attributes,
        # then anything else the merge introduced.
        original_cols = [col for col in df_proc.columns if col != "_gid"]
        final_order: List[str] = []
        for col in original_cols:
            if col in result.columns:
                final_order.append(col)
        if "entity_name" in result.columns and "entity_name" not in final_order:
            final_order.append("entity_name")
        for attr in base_attrs:
            if attr in result.columns:
                final_order.append(attr)
        remaining = [col for col in result.columns if col not in final_order]
        if remaining:
            final_order.extend(remaining)
        result = result[final_order]

        # The CSV keeps the literal "unknown"; the returned frame uses NA.
        result.to_csv(out_path, index=False)

        result = result.replace("unknown", pd.NA)

        # Extra rows created by inputs that yielded multiple entities.
        duplicate_rows = len(result) - len(df_proc)

        # ---- optional type coercion ---------------------------------------
        if types:
            coerced = result.copy()
            fail_logs: Dict[str, int] = {}
            for col, typ in types.items():
                if col not in coerced:
                    continue
                orig = coerced[col]
                non_null = orig.notna()
                target = str(typ).lower()
                if target in {"datetime", "date"}:
                    conv = pd.to_datetime(orig, errors="coerce")
                else:
                    conv = pd.to_numeric(orig, errors="coerce")
                    if target in {"int", "int64"}:
                        # Round then use nullable Int64 so NA survives.
                        conv = conv.round().astype("Int64")
                coerced[col] = conv
                # Count values that were present but failed to convert.
                fail_logs[col] = int((non_null & conv.isna()).sum())
            coerced_path = os.path.join(self.cfg.save_dir, f"{base_name}_cleaned_coerced.csv")
            coerced.to_csv(coerced_path, index=False)
            for col, n_fail in fail_logs.items():
                print(f"[Extract] Failed to coerce {n_fail} values in column '{col}'.")
            result = coerced

        # ---- coverage report ----------------------------------------------
        total = len(agg_df)
        print("\n=== Extraction coverage ===")
        for attr in base_attrs:
            known = total - unknown_counts[attr]
            print(f"{attr:<55s}: {known:5d} extracted, {unknown_counts[attr]:5d} unknown")
        print("============================\n")

        if duplicate_rows > 0:
            subset_hint = ", ".join(f"'{col}'" for col in input_columns)
            print(
                "[Extract] Multiple entity names were returned for at least one input.\n"
                f"          {duplicate_rows} additional row(s) were created so the cleaned file now has {len(result)} rows versus {len(df_proc)} inputs.\n"
                "          If you only need one row per original input, deduplicate on your source columns\n"
                f"          (e.g. `result = result.drop_duplicates(subset=[{subset_hint}], keep='first')`).\n"
            )

        return result
@@ -0,0 +1,169 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import random
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Any, Dict, List, Optional, Set
8
+
9
+ import pandas as pd
10
+
11
+ from ..core.prompt_template import PromptTemplate, resolve_template
12
+ from ..utils.openai_utils import get_all_responses
13
+ from ..utils.logging import announce_prompt_rendering
14
+ from ..utils import safest_json
15
+
16
+
17
@dataclass
class FilterConfig:
    """Configuration for :class:`Filter`.

    Bundles the filtering condition, batching/shuffling behavior, and model
    settings for a filter run.  ``additional_instructions`` is normalized
    after construction: surrounding whitespace is stripped and an empty or
    whitespace-only value collapses to ``None``.
    """

    # Natural-language condition an entity must satisfy to pass the filter.
    condition: str
    # Directory where the raw response CSV is written.
    save_dir: str
    # File name for the saved responses.
    file_name: str = "filter_responses.csv"
    # Model identifier passed through to the response collector.
    model: str = "gpt-5-nano"
    # Maximum number of concurrent API calls.
    n_parallels: int = 650
    # How many entities are packed into a single prompt.
    entities_per_call: int = 150
    # Shuffle entity order (seeded) before chunking each run.
    shuffle: bool = True
    # Base seed; run index is added so each run shuffles differently.
    random_seed: int = 42
    # Number of independent voting runs.
    n_runs: int = 1
    # Fraction of runs that must agree for an entity to pass.
    threshold: float = 0.5
    # Free-form extra prompt instructions; normalized in __post_init__.
    additional_instructions: Optional[str] = None
    # When True, no real API calls are made.
    use_dummy: bool = False
    # Optional per-call timeout in seconds.
    max_timeout: Optional[float] = None
    # Allow an LLM pass to repair malformed JSON responses.
    fix_json_with_llm: bool = False
    # Timeout for the JSON-repair call, in seconds.
    json_fix_timeout: Optional[float] = 60.0

    def __post_init__(self) -> None:
        """Collapse blank ``additional_instructions`` to ``None``."""
        instructions = self.additional_instructions
        if instructions is None:
            return
        stripped = str(instructions).strip()
        self.additional_instructions = stripped if stripped else None
41
+
42
+
43
class Filter:
    """Filter entities in a DataFrame column based on a condition using an LLM.

    Unique entities are chunked into prompts, the model is asked which ones
    meet ``cfg.condition``, and per-run boolean columns plus an aggregate
    ``meets_condition`` column are appended to the input frame.
    """

    def __init__(
        self,
        cfg: FilterConfig,
        template: Optional[PromptTemplate] = None,
        template_path: Optional[str] = None,
    ) -> None:
        """Prepare the save directory and resolve the prompt template.

        ``cfg.save_dir`` is expanded (``~`` and env vars), created if missing,
        and written back as a plain string.  If neither ``template`` nor
        ``template_path`` is given, the packaged ``filter_prompt.jinja2``
        reference template is used.
        """
        # Expand "~" and $VARS so user-supplied paths become usable.
        expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
        expanded.mkdir(parents=True, exist_ok=True)
        cfg.save_dir = str(expanded)
        self.cfg = cfg
        self.template = resolve_template(
            template=template,
            template_path=template_path,
            reference_filename="filter_prompt.jinja2",
        )

    # ------------------------------------------------------------------
    async def run(
        self,
        df: pd.DataFrame,
        column_name: str,
        *,
        reset_files: bool = False,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Evaluate ``cfg.condition`` over ``df[column_name]``.

        Parameters
        ----------
        df:
            Input frame; its index is reset and all columns are preserved.
        column_name:
            Column whose (stringified) values are the entities to filter.
        reset_files:
            When True, ignore previously saved responses.
        kwargs:
            Forwarded to ``get_all_responses``.

        Returns
        -------
        pd.DataFrame
            Copy of ``df`` with one boolean ``meets_condition_run_N`` column
            per run and a final ``meets_condition`` column that is True when
            the fraction of agreeing runs reaches ``cfg.threshold``.  NaN
            entities are always False.
        """
        df_proc = df.reset_index(drop=True).copy()
        raw_entities = [str(x) for x in df_proc[column_name].dropna()]

        # unique while preserving order
        seen: Set[str] = set()
        entities: List[str] = []
        for ent in raw_entities:
            if ent not in seen:
                seen.add(ent)
                entities.append(ent)

        # ---- render prompts: each run re-chunks (optionally shuffled) -----
        prompts: List[str] = []
        identifiers: List[str] = []
        total_chunks = 0
        for run in range(self.cfg.n_runs):
            ents = list(entities)
            if self.cfg.shuffle:
                # Seed with base + run so each run gets a distinct but
                # reproducible ordering.
                rnd = random.Random(self.cfg.random_seed + run)
                rnd.shuffle(ents)
            chunks = [
                ents[i : i + self.cfg.entities_per_call]
                for i in range(0, len(ents), self.cfg.entities_per_call)
            ]
            total_chunks += len(chunks)
            for idx, chunk in enumerate(chunks):
                prompts.append(
                    self.template.render(
                        condition=self.cfg.condition,
                        entities=chunk,
                        additional_instructions=self.cfg.additional_instructions,
                    )
                )
                # Identifier encodes run and chunk: "filter_RRR_CCCCC".
                identifiers.append(f"filter_{run:03d}_{idx:05d}")

        announce_prompt_rendering("Filter", total_chunks)

        # ---- collect responses --------------------------------------------
        save_path = os.path.join(self.cfg.save_dir, self.cfg.file_name)
        if prompts:
            resp_df = await get_all_responses(
                prompts=prompts,
                identifiers=identifiers,
                n_parallels=self.cfg.n_parallels,
                model=self.cfg.model,
                save_path=save_path,
                use_dummy=self.cfg.use_dummy,
                max_timeout=self.cfg.max_timeout,
                json_mode=True,
                reset_files=reset_files,
                **kwargs,
            )
        else:
            # No entities -> empty response frame with the expected columns.
            resp_df = pd.DataFrame(columns=["Identifier", "Response"])

        resp_map: Dict[str, Any] = dict(
            zip(resp_df.get("Identifier", []), resp_df.get("Response", []))
        )

        # ---- parse responses into per-run "meets condition" sets ----------
        meets_by_run: List[Set[str]] = [set() for _ in range(self.cfg.n_runs)]
        for ident, raw in resp_map.items():
            # Recover the run index from "filter_RRR_CCCCC"; skip anything
            # that doesn't match the expected shape.
            parts = ident.split("_")
            if len(parts) < 3:
                continue
            try:
                run_idx = int(parts[1])
            except ValueError:
                continue
            parsed = await safest_json(
                raw,
                model=self.cfg.model if self.cfg.fix_json_with_llm else None,
                use_llm_fallback=self.cfg.fix_json_with_llm,
                llm_timeout=self.cfg.json_fix_timeout,
            )
            # Accept either a dict with a known key or a bare list of names.
            ent_list: Optional[List[str]] = None
            if isinstance(parsed, dict):
                val = parsed.get("entities meeting condition") or parsed.get(
                    "entities_meeting_condition"
                )
                if isinstance(val, list):
                    ent_list = [str(v) for v in val if isinstance(v, str)]
            elif isinstance(parsed, list):
                ent_list = [str(v) for v in parsed if isinstance(v, str)]
            if ent_list and 0 <= run_idx < self.cfg.n_runs:
                for ent in ent_list:
                    meets_by_run[run_idx].add(ent.strip())

        # ---- build per-run boolean columns --------------------------------
        run_cols: List[str] = []
        for run_idx in range(self.cfg.n_runs):
            # Case-insensitive match between df values and returned names.
            meets_norm = {m.lower() for m in meets_by_run[run_idx]}
            col = f"meets_condition_run_{run_idx + 1}"
            run_cols.append(col)
            df_proc[col] = [
                str(v).lower() in meets_norm if not pd.isna(v) else False
                for v in df_proc[column_name]
            ]

        # Majority vote: fraction of runs agreeing must reach the threshold.
        df_proc["meets_condition"] = (
            df_proc[run_cols].sum(axis=1) / self.cfg.n_runs >= self.cfg.threshold
        )
        return df_proc