openai-gabriel 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gabriel/__init__.py +61 -0
- gabriel/_version.py +1 -0
- gabriel/api.py +2284 -0
- gabriel/cli/__main__.py +60 -0
- gabriel/core/__init__.py +7 -0
- gabriel/core/llm_client.py +34 -0
- gabriel/core/pipeline.py +18 -0
- gabriel/core/prompt_template.py +152 -0
- gabriel/prompts/__init__.py +1 -0
- gabriel/prompts/bucket_prompt.jinja2 +113 -0
- gabriel/prompts/classification_prompt.jinja2 +50 -0
- gabriel/prompts/codify_prompt.jinja2 +95 -0
- gabriel/prompts/comparison_prompt.jinja2 +60 -0
- gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
- gabriel/prompts/deidentification_prompt.jinja2 +112 -0
- gabriel/prompts/extraction_prompt.jinja2 +61 -0
- gabriel/prompts/filter_prompt.jinja2 +31 -0
- gabriel/prompts/ideation_prompt.jinja2 +80 -0
- gabriel/prompts/merge_prompt.jinja2 +47 -0
- gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
- gabriel/prompts/rankings_prompt.jinja2 +49 -0
- gabriel/prompts/ratings_prompt.jinja2 +50 -0
- gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
- gabriel/prompts/seed.jinja2 +43 -0
- gabriel/prompts/snippets.jinja2 +117 -0
- gabriel/tasks/__init__.py +63 -0
- gabriel/tasks/_attribute_utils.py +69 -0
- gabriel/tasks/bucket.py +432 -0
- gabriel/tasks/classify.py +562 -0
- gabriel/tasks/codify.py +1033 -0
- gabriel/tasks/compare.py +235 -0
- gabriel/tasks/debias.py +1460 -0
- gabriel/tasks/deduplicate.py +341 -0
- gabriel/tasks/deidentify.py +316 -0
- gabriel/tasks/discover.py +524 -0
- gabriel/tasks/extract.py +455 -0
- gabriel/tasks/filter.py +169 -0
- gabriel/tasks/ideate.py +782 -0
- gabriel/tasks/merge.py +464 -0
- gabriel/tasks/paraphrase.py +531 -0
- gabriel/tasks/rank.py +2041 -0
- gabriel/tasks/rate.py +347 -0
- gabriel/tasks/seed.py +465 -0
- gabriel/tasks/whatever.py +344 -0
- gabriel/utils/__init__.py +64 -0
- gabriel/utils/audio_utils.py +42 -0
- gabriel/utils/file_utils.py +464 -0
- gabriel/utils/image_utils.py +22 -0
- gabriel/utils/jinja.py +31 -0
- gabriel/utils/logging.py +86 -0
- gabriel/utils/mapmaker.py +304 -0
- gabriel/utils/media_utils.py +78 -0
- gabriel/utils/modality_utils.py +148 -0
- gabriel/utils/openai_utils.py +5470 -0
- gabriel/utils/parsing.py +282 -0
- gabriel/utils/passage_viewer.py +2557 -0
- gabriel/utils/pdf_utils.py +20 -0
- gabriel/utils/plot_utils.py +2881 -0
- gabriel/utils/prompt_utils.py +42 -0
- gabriel/utils/word_matching.py +158 -0
- openai_gabriel-1.0.1.dist-info/METADATA +443 -0
- openai_gabriel-1.0.1.dist-info/RECORD +67 -0
- openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
- openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
- openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
- openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
- openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
gabriel/tasks/compare.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import hashlib
|
|
5
|
+
import os
|
|
6
|
+
import random
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, Dict, List, Optional
|
|
10
|
+
|
|
11
|
+
import pandas as pd
|
|
12
|
+
|
|
13
|
+
from ..core.prompt_template import PromptTemplate, resolve_template
|
|
14
|
+
from ..utils.openai_utils import get_all_responses
|
|
15
|
+
from ..utils import (
|
|
16
|
+
safest_json,
|
|
17
|
+
load_image_inputs,
|
|
18
|
+
load_audio_inputs,
|
|
19
|
+
load_pdf_inputs,
|
|
20
|
+
warn_if_modality_mismatch,
|
|
21
|
+
)
|
|
22
|
+
from ..utils.logging import announce_prompt_rendering
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class CompareConfig:
    """Settings consumed by :class:`Compare`.

    ``additional_instructions`` is normalised after construction: it is
    coerced to ``str`` and stripped, and a blank result is stored as ``None``.
    """

    # Output directory; Compare expands ``~``/env vars and creates it.
    save_dir: str = "comparison"
    # CSV file name (inside ``save_dir``) passed to get_all_responses as save_path.
    file_name: str = "comparison_responses.csv"
    # Model name forwarded to get_all_responses.
    model: str = "gpt-5-mini"
    # Concurrency level forwarded to get_all_responses.
    n_parallels: int = 650
    # NOTE(review): not read by Compare.run in this module — confirm intent.
    n_runs: int = 1
    # Forwarded to get_all_responses.
    use_dummy: bool = False
    # Forwarded to get_all_responses.
    max_timeout: Optional[float] = None
    # Passed to the prompt template as ``differentiate``.
    differentiate: bool = True
    # Extra prompt text; cleaned in __post_init__.
    additional_instructions: Optional[str] = None
    # Compare.run handles "text", "entity", "web", "image", "audio", "pdf".
    modality: str = "text"
    # Forwarded to get_all_responses.
    reasoning_effort: Optional[str] = None
    # Forwarded to get_all_responses.
    reasoning_summary: Optional[str] = None
    # Fixed presentation order; None lets Compare pick randomly per pair.
    circle_first: Optional[bool] = None

    def __post_init__(self) -> None:
        """Strip ``additional_instructions`` and map blank text to ``None``."""
        if self.additional_instructions is None:
            return
        self.additional_instructions = (
            str(self.additional_instructions).strip() or None
        )
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class Compare:
    """Compare two columns row-wise using an LLM.

    Rows where either column is null are skipped. Each remaining distinct
    ``(circle, square)`` pair is rendered into one prompt from the comparison
    template and sent through ``get_all_responses``. Each parsed JSON response
    is flattened into one output row per key. The result is indexed by the
    two input columns and has ``attribute`` and ``explanation`` columns.
    """

    def __init__(
        self,
        cfg: CompareConfig,
        template: Optional[PromptTemplate] = None,
        template_path: Optional[str] = None,
    ) -> None:
        """Prepare the output directory and resolve the prompt template.

        Args:
            cfg: Task configuration. ``cfg.save_dir`` is expanded (``~`` and
                environment variables) and created if missing. The expanded
                string is written back onto ``cfg``.
            template: Explicit template to use, if any.
            template_path: Path to a template file, if any. How ``template``,
                ``template_path`` and ``comparison_prompt.jinja2`` are
                prioritised is decided by ``resolve_template``.
        """
        expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
        expanded.mkdir(parents=True, exist_ok=True)
        cfg.save_dir = str(expanded)
        self.cfg = cfg
        self.template = resolve_template(
            template=template,
            template_path=template_path,
            reference_filename="comparison_prompt.jinja2",
        )

    async def _parse(self, raw: Any) -> Dict[str, str]:
        """Parse one raw model response into ``{attribute: explanation}``.

        Keys and values are coerced to ``str``, and ``None`` values become
        ``""``. If ``safest_json`` does not produce a dict, the result is
        ``{}``, so that pair contributes no output rows.
        """
        obj = await safest_json(raw)
        if isinstance(obj, dict):
            return {str(k): str(v) if v is not None else "" for k, v in obj.items()}
        return {}

    @staticmethod
    def _assign_identifiers(pairs: List[tuple]) -> List[str]:
        """Return one identifier per ``(circle, square)`` pair, in order.

        Identical pairs (compared by their ``str`` forms) share an identifier.
        The first pair to claim a hash keeps the legacy 8-hex-char id
        ``sha1(f"{circle}|{square}")[:8]``, so existing saved responses still
        line up. A different pair that maps to an already-claimed legacy id
        would silently overwrite the other pair's prompt and response. This
        can happen through the ambiguous ``|`` separator or through a
        truncated-hash collision. Such a pair instead gets a full-length
        digest of a length-prefixed, unambiguous encoding.
        """
        owner: Dict[str, tuple] = {}
        ids: List[str] = []
        for circle, square in pairs:
            signature = (str(circle), str(square))
            ident = hashlib.sha1(f"{circle}|{square}".encode()).hexdigest()[:8]
            if owner.setdefault(ident, signature) != signature:
                circle_s, square_s = signature
                ident = hashlib.sha1(
                    f"{len(circle_s)}:{circle_s}|{square_s}".encode()
                ).hexdigest()
                owner.setdefault(ident, signature)
            ids.append(ident)
        return ids

    @staticmethod
    def _ordered_media(circle_items: Any, square_items: Any, circle_first: bool) -> List[Any]:
        """Concatenate the two media lists, circle's first iff ``circle_first``.

        Falsy inputs (``None`` or empty) contribute nothing.
        """
        first, second = (
            (circle_items, square_items) if circle_first else (square_items, circle_items)
        )
        return [*(first or []), *(second or [])]

    @classmethod
    def _collect_media(
        cls,
        loader: Any,
        unique_pairs: Dict[str, tuple],
        id_to_circle_first: Dict[str, bool],
    ) -> Optional[Dict[str, List[Any]]]:
        """Load and order media per identifier with ``loader``.

        ``loader`` is called on the circle value, then on the square value.
        Identifiers with no media are omitted, and ``None`` is returned when
        no identifier has any.
        """
        collected: Dict[str, List[Any]] = {}
        for ident, (circle, square) in unique_pairs.items():
            circle_items = loader(circle)
            square_items = loader(square)
            media = cls._ordered_media(
                circle_items, square_items, id_to_circle_first.get(ident, False)
            )
            if media:
                collected[ident] = media
        return collected or None

    async def run(
        self,
        df: pd.DataFrame,
        circle_column_name: str,
        square_column_name: str,
        *,
        reset_files: bool = False,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Compare ``circle_column_name`` with ``square_column_name`` row-wise.

        Args:
            df: Input frame. Its index is discarded. Rows where either column
                is null are skipped, and a message is printed.
            circle_column_name: Column holding the "circle" entries.
            square_column_name: Column holding the "square" entries.
            reset_files: Forwarded to ``get_all_responses``.
            **kwargs: Forwarded to ``get_all_responses``. ``web_search``
                defaults to ``True`` only when ``cfg.modality == "web"``.

        Returns:
            A frame indexed by ``(circle, square)`` with one row per parsed
            attribute and columns ``attribute`` and ``explanation``. If no
            response parses to any attribute, the frame is empty but keeps
            the same index names and columns.

        Raises:
            RuntimeError: If ``get_all_responses`` does not return a DataFrame.
        """
        df_proc = df.reset_index(drop=True).copy()
        mask = df_proc[circle_column_name].notna() & df_proc[square_column_name].notna()
        skipped = int((~mask).sum())
        if skipped:
            print(
                f"Skipping {skipped} rows with NaN in {circle_column_name} or {square_column_name}"
            )
        df_proc = df_proc[mask].reset_index(drop=True)

        circles = df_proc[circle_column_name].tolist()
        squares = df_proc[square_column_name].tolist()
        pairs = list(zip(circles, squares))

        warn_if_modality_mismatch(circles, self.cfg.modality, column_name=circle_column_name)
        warn_if_modality_mismatch(squares, self.cfg.modality, column_name=square_column_name)

        # One identifier per row; identical pairs share one, so every
        # distinct pair is rendered, submitted and parsed exactly once.
        ids = self._assign_identifiers(pairs)
        unique_pairs: Dict[str, tuple] = {}
        for ident, pair in zip(ids, pairs):
            unique_pairs.setdefault(ident, pair)
        unique_ids = list(unique_pairs)

        id_to_circle_first: Dict[str, bool] = {
            ident: (
                self.cfg.circle_first
                if self.cfg.circle_first is not None
                else random.random() < 0.5
            )
            for ident in unique_ids
        }

        # Only text-like modalities inline the raw values into the prompt;
        # image/audio/pdf values are attached as media below instead.
        inline_text = self.cfg.modality in {"text", "entity", "web"}

        announce_prompt_rendering("Compare", len(unique_ids))

        prompts: List[str] = [
            self.template.render(
                entry_circle=circle if inline_text else "",
                entry_square=square if inline_text else "",
                differentiate=self.cfg.differentiate,
                additional_instructions=self.cfg.additional_instructions or "",
                modality=self.cfg.modality,
                circle_first=id_to_circle_first[ident],
            )
            for ident, (circle, square) in unique_pairs.items()
        ]

        prompt_images: Optional[Dict[str, List[str]]] = None
        prompt_audio: Optional[Dict[str, List[Dict[str, str]]]] = None
        prompt_pdfs: Optional[Dict[str, List[Dict[str, str]]]] = None
        if self.cfg.modality == "image":
            prompt_images = self._collect_media(
                load_image_inputs, unique_pairs, id_to_circle_first
            )
        elif self.cfg.modality == "audio":
            prompt_audio = self._collect_media(
                load_audio_inputs, unique_pairs, id_to_circle_first
            )
        elif self.cfg.modality == "pdf":
            prompt_pdfs = self._collect_media(
                load_pdf_inputs, unique_pairs, id_to_circle_first
            )

        csv_path = os.path.join(self.cfg.save_dir, self.cfg.file_name)

        kwargs.setdefault("web_search", self.cfg.modality == "web")

        df_resp_all = await get_all_responses(
            prompts=prompts,
            identifiers=unique_ids,
            prompt_images=prompt_images,
            prompt_audio=prompt_audio,
            prompt_pdfs=prompt_pdfs,
            n_parallels=self.cfg.n_parallels,
            model=self.cfg.model,
            save_path=csv_path,
            use_dummy=self.cfg.use_dummy,
            max_timeout=self.cfg.max_timeout,
            json_mode=self.cfg.modality != "audio",
            reset_files=reset_files,
            reasoning_effort=self.cfg.reasoning_effort,
            reasoning_summary=self.cfg.reasoning_summary,
            **kwargs,
        )
        if not isinstance(df_resp_all, pd.DataFrame):
            raise RuntimeError("get_all_responses returned no DataFrame")

        resp_map = dict(zip(df_resp_all.Identifier, df_resp_all.Response))
        parsed_list = await asyncio.gather(
            *[self._parse(resp_map.get(i, "")) for i in unique_ids]
        )
        parsed_by_id = dict(zip(unique_ids, parsed_list))

        # Fan results back out to every input row, preserving row order and
        # each row's own raw circle/square values.
        records: List[Dict[str, Any]] = [
            {
                circle_column_name: circle,
                square_column_name: square,
                "attribute": attr,
                "explanation": expl,
            }
            for (circle, square), ident in zip(pairs, ids)
            for attr, expl in parsed_by_id[ident].items()
        ]

        if not records:
            return pd.DataFrame(
                columns=["attribute", "explanation"],
                index=pd.MultiIndex.from_arrays(
                    [[], []], names=[circle_column_name, square_column_name]
                ),
            )
        out_df = pd.DataFrame(records)
        out_df.set_index([circle_column_name, square_column_name], inplace=True)
        return out_df
|