openai-gabriel 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. gabriel/__init__.py +61 -0
  2. gabriel/_version.py +1 -0
  3. gabriel/api.py +2284 -0
  4. gabriel/cli/__main__.py +60 -0
  5. gabriel/core/__init__.py +7 -0
  6. gabriel/core/llm_client.py +34 -0
  7. gabriel/core/pipeline.py +18 -0
  8. gabriel/core/prompt_template.py +152 -0
  9. gabriel/prompts/__init__.py +1 -0
  10. gabriel/prompts/bucket_prompt.jinja2 +113 -0
  11. gabriel/prompts/classification_prompt.jinja2 +50 -0
  12. gabriel/prompts/codify_prompt.jinja2 +95 -0
  13. gabriel/prompts/comparison_prompt.jinja2 +60 -0
  14. gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
  15. gabriel/prompts/deidentification_prompt.jinja2 +112 -0
  16. gabriel/prompts/extraction_prompt.jinja2 +61 -0
  17. gabriel/prompts/filter_prompt.jinja2 +31 -0
  18. gabriel/prompts/ideation_prompt.jinja2 +80 -0
  19. gabriel/prompts/merge_prompt.jinja2 +47 -0
  20. gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
  21. gabriel/prompts/rankings_prompt.jinja2 +49 -0
  22. gabriel/prompts/ratings_prompt.jinja2 +50 -0
  23. gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
  24. gabriel/prompts/seed.jinja2 +43 -0
  25. gabriel/prompts/snippets.jinja2 +117 -0
  26. gabriel/tasks/__init__.py +63 -0
  27. gabriel/tasks/_attribute_utils.py +69 -0
  28. gabriel/tasks/bucket.py +432 -0
  29. gabriel/tasks/classify.py +562 -0
  30. gabriel/tasks/codify.py +1033 -0
  31. gabriel/tasks/compare.py +235 -0
  32. gabriel/tasks/debias.py +1460 -0
  33. gabriel/tasks/deduplicate.py +341 -0
  34. gabriel/tasks/deidentify.py +316 -0
  35. gabriel/tasks/discover.py +524 -0
  36. gabriel/tasks/extract.py +455 -0
  37. gabriel/tasks/filter.py +169 -0
  38. gabriel/tasks/ideate.py +782 -0
  39. gabriel/tasks/merge.py +464 -0
  40. gabriel/tasks/paraphrase.py +531 -0
  41. gabriel/tasks/rank.py +2041 -0
  42. gabriel/tasks/rate.py +347 -0
  43. gabriel/tasks/seed.py +465 -0
  44. gabriel/tasks/whatever.py +344 -0
  45. gabriel/utils/__init__.py +64 -0
  46. gabriel/utils/audio_utils.py +42 -0
  47. gabriel/utils/file_utils.py +464 -0
  48. gabriel/utils/image_utils.py +22 -0
  49. gabriel/utils/jinja.py +31 -0
  50. gabriel/utils/logging.py +86 -0
  51. gabriel/utils/mapmaker.py +304 -0
  52. gabriel/utils/media_utils.py +78 -0
  53. gabriel/utils/modality_utils.py +148 -0
  54. gabriel/utils/openai_utils.py +5470 -0
  55. gabriel/utils/parsing.py +282 -0
  56. gabriel/utils/passage_viewer.py +2557 -0
  57. gabriel/utils/pdf_utils.py +20 -0
  58. gabriel/utils/plot_utils.py +2881 -0
  59. gabriel/utils/prompt_utils.py +42 -0
  60. gabriel/utils/word_matching.py +158 -0
  61. openai_gabriel-1.0.1.dist-info/METADATA +443 -0
  62. openai_gabriel-1.0.1.dist-info/RECORD +67 -0
  63. openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
  64. openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
  65. openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
  66. openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
  67. openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,235 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import hashlib
5
+ import os
6
+ import random
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import Any, Dict, List, Optional
10
+
11
+ import pandas as pd
12
+
13
+ from ..core.prompt_template import PromptTemplate, resolve_template
14
+ from ..utils.openai_utils import get_all_responses
15
+ from ..utils import (
16
+ safest_json,
17
+ load_image_inputs,
18
+ load_audio_inputs,
19
+ load_pdf_inputs,
20
+ warn_if_modality_mismatch,
21
+ )
22
+ from ..utils.logging import announce_prompt_rendering
23
+
24
+
25
@dataclass
class CompareConfig:
    """Configuration for a :class:`Compare` run.

    After construction, ``additional_instructions`` is normalised: it is
    converted to ``str`` and stripped, and an all-whitespace (or empty) value
    collapses to ``None``.
    """

    # Output directory; Compare.__init__ expands ~ / env vars and creates it.
    save_dir: str = "comparison"
    # CSV (inside save_dir) passed to get_all_responses as save_path.
    file_name: str = "comparison_responses.csv"
    # Model name forwarded to get_all_responses.
    model: str = "gpt-5-mini"
    # Forwarded to get_all_responses as n_parallels.
    n_parallels: int = 650
    # NOTE(review): not read by Compare in this file — confirm whether it is
    # consumed anywhere else.
    n_runs: int = 1
    # Forwarded to get_all_responses as use_dummy.
    use_dummy: bool = False
    # Forwarded to get_all_responses as max_timeout.
    max_timeout: Optional[float] = None
    # Passed to the prompt template as `differentiate`.
    differentiate: bool = True
    # Extra prompt text; normalised in __post_init__.
    additional_instructions: Optional[str] = None
    # Compare.run inlines text for "text"/"entity"/"web" and attaches media
    # for "image"/"audio"/"pdf".
    modality: str = "text"
    # Forwarded to get_all_responses.
    reasoning_effort: Optional[str] = None
    # Forwarded to get_all_responses.
    reasoning_summary: Optional[str] = None
    # Fixed presentation order; None means a random order per pair.
    circle_first: Optional[bool] = None

    def __post_init__(self) -> None:
        raw = self.additional_instructions
        if raw is None:
            return
        stripped = str(raw).strip()
        self.additional_instructions = stripped if stripped else None
45
+
46
+
47
class Compare:
    """Compare two columns row-wise using an LLM.

    For every row, the "circle" and "square" entries are rendered into the
    comparison prompt (text-like modalities) or attached as media (image,
    audio and pdf modalities), sent through ``get_all_responses``, and each
    key/value pair of the parsed JSON reply becomes one output row.

    NOTE(review): ``cfg.n_runs`` is not read anywhere in this class.
    """

    # Modalities whose raw cell values are inlined into the prompt text. For
    # any other modality the template receives empty entries and the content
    # is attached as media (image/audio/pdf) instead.
    _TEXT_MODALITIES = frozenset({"text", "entity", "web"})

    def __init__(
        self,
        cfg: CompareConfig,
        template: Optional[PromptTemplate] = None,
        template_path: Optional[str] = None,
    ) -> None:
        """Prepare the output directory and resolve the prompt template.

        Args:
            cfg: Run configuration. ``cfg.save_dir`` is expanded (``~`` and
                environment variables), created if missing, and written back
                to ``cfg`` as a string.
            template: Explicit prompt template, if any.
            template_path: Path to a template file, if any.

        Template selection is delegated to ``resolve_template`` with
        ``comparison_prompt.jinja2`` as the reference filename.
        """
        expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
        expanded.mkdir(parents=True, exist_ok=True)
        cfg.save_dir = str(expanded)
        self.cfg = cfg
        self.template = resolve_template(
            template=template,
            template_path=template_path,
            reference_filename="comparison_prompt.jinja2",
        )

    async def _parse(self, raw: Any) -> Dict[str, str]:
        """Parse one raw model response into ``{attribute: explanation}``.

        Keys and non-``None`` values are stringified and ``None`` becomes
        ``""``. Anything ``safest_json`` does not turn into a dict yields an
        empty dict.
        """
        obj = await safest_json(raw)
        if isinstance(obj, dict):
            return {str(k): str(v) if v is not None else "" for k, v in obj.items()}
        return {}

    @staticmethod
    def _make_identifier(circle: Any, square: Any, owners: Dict[str, Any]) -> str:
        """Return the request identifier for one (circle, square) pair.

        The primary identifier is the first 8 hex digits of the SHA-1 of
        ``"{circle}|{square}"``. This keeps identifiers compatible with
        response files written by earlier runs.

        ``owners`` maps every identifier handed out so far to the pair that
        owns it. If the short identifier already belongs to a *different*
        pair, the full SHA-1 of an unambiguous encoding of the pair is used
        instead. Such a clash can come from the 8-hex truncation, or from two
        pairs whose ``|``-joined text coincides. Without this fallback, the
        later pair's prompt and response would silently overwrite the
        earlier one's.

        Identical pairs always receive the same identifier. The result is
        deterministic for a given row order.
        """
        key = (str(circle), str(square))
        short = hashlib.sha1(f"{circle}|{square}".encode()).hexdigest()[:8]
        if owners.setdefault(short, key) == key:
            return short
        full = hashlib.sha1(repr(key).encode()).hexdigest()
        owners.setdefault(full, key)
        return full

    @staticmethod
    def _order_media(
        circle_items: Optional[List[Any]],
        square_items: Optional[List[Any]],
        circle_first: bool,
    ) -> List[Any]:
        """Concatenate circle and square media in presentation order.

        Falsy inputs (``None`` or empty) contribute nothing.
        """
        if circle_first:
            first, second = circle_items, square_items
        else:
            first, second = square_items, circle_items
        return [*(first or []), *(second or [])]

    def _collect_media(
        self,
        loader: Any,
        ids: List[str],
        pairs: List[Any],
        circle_first: Dict[str, bool],
    ) -> Optional[Dict[str, List[Any]]]:
        """Load media for each pair and key it by identifier.

        ``loader`` is one of the ``load_*_inputs`` helpers. It is called on
        the circle entry, then on the square entry. Pairs that yield no media
        are omitted, and ``None`` is returned when no pair yields any.
        """
        media: Dict[str, List[Any]] = {}
        for ident, (circle, square) in zip(ids, pairs):
            items = self._order_media(
                loader(circle), loader(square), circle_first.get(ident, False)
            )
            if items:
                media[ident] = items
        return media or None

    async def run(
        self,
        df: pd.DataFrame,
        circle_column_name: str,
        square_column_name: str,
        *,
        reset_files: bool = False,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Compare ``circle_column_name`` against ``square_column_name`` per row.

        Rows where either column is NaN are dropped, with a printed notice.
        Each distinct (circle, square) pair is sent to the model once, and
        rows that repeat a pair share its response. The order in which the
        two entries are presented is ``cfg.circle_first`` when it is set,
        otherwise a coin flip per distinct pair.

        Args:
            df: Input frame; its index is discarded.
            circle_column_name: Column holding the "circle" entries.
            square_column_name: Column holding the "square" entries.
            reset_files: Forwarded to ``get_all_responses``.
            **kwargs: Extra options forwarded to ``get_all_responses``.
                ``web_search`` defaults to ``cfg.modality == "web"``.

        Returns:
            One row per (input row, returned attribute), with ``attribute``
            and ``explanation`` columns. When non-empty, the frame is indexed
            by the circle and square values. Rows whose response does not
            parse to a dict contribute nothing.

        Raises:
            RuntimeError: If ``get_all_responses`` does not return a DataFrame.
        """
        df_proc = df.reset_index(drop=True).copy()
        mask = df_proc[circle_column_name].notna() & df_proc[square_column_name].notna()
        skipped = int((~mask).sum())
        if skipped:
            print(
                f"Skipping {skipped} rows with NaN in {circle_column_name} or {square_column_name}"
            )
        # Filtering with an all-True mask is a no-op, so doing it
        # unconditionally is equivalent.
        df_proc = df_proc[mask].reset_index(drop=True)

        circles = df_proc[circle_column_name].tolist()
        squares = df_proc[square_column_name].tolist()
        pairs = list(zip(circles, squares))

        warn_if_modality_mismatch(circles, self.cfg.modality, column_name=circle_column_name)
        warn_if_modality_mismatch(squares, self.cfg.modality, column_name=square_column_name)

        if not pairs:
            # Nothing left to compare: return the same empty frame that the
            # record-building step below would produce, without an API call.
            return pd.DataFrame()

        # ids[i] is the identifier for row i (repeats for duplicate pairs);
        # unique_ids/unique_pairs hold one entry per distinct pair, in first-
        # seen order. Only the distinct pairs are rendered and sent.
        owners: Dict[str, Any] = {}
        ids: List[str] = []
        unique_ids: List[str] = []
        unique_pairs: List[Any] = []
        id_to_circle_first: Dict[str, bool] = {}
        for circle, square in pairs:
            ident = self._make_identifier(circle, square, owners)
            ids.append(ident)
            if ident in id_to_circle_first:
                continue  # duplicate pair: reuse the prompt already queued
            id_to_circle_first[ident] = (
                self.cfg.circle_first
                if self.cfg.circle_first is not None
                else random.random() < 0.5
            )
            unique_ids.append(ident)
            unique_pairs.append((circle, square))

        announce_prompt_rendering("Compare", len(unique_ids))

        inline_text = self.cfg.modality in self._TEXT_MODALITIES
        prompts: List[str] = [
            self.template.render(
                entry_circle=circle if inline_text else "",
                entry_square=square if inline_text else "",
                differentiate=self.cfg.differentiate,
                additional_instructions=self.cfg.additional_instructions or "",
                modality=self.cfg.modality,
                circle_first=id_to_circle_first[ident],
            )
            for ident, (circle, square) in zip(unique_ids, unique_pairs)
        ]

        prompt_images: Optional[Dict[str, List[str]]] = None
        prompt_audio: Optional[Dict[str, List[Dict[str, str]]]] = None
        prompt_pdfs: Optional[Dict[str, List[Dict[str, str]]]] = None
        if self.cfg.modality == "image":
            prompt_images = self._collect_media(
                load_image_inputs, unique_ids, unique_pairs, id_to_circle_first
            )
        elif self.cfg.modality == "audio":
            prompt_audio = self._collect_media(
                load_audio_inputs, unique_ids, unique_pairs, id_to_circle_first
            )
        elif self.cfg.modality == "pdf":
            prompt_pdfs = self._collect_media(
                load_pdf_inputs, unique_ids, unique_pairs, id_to_circle_first
            )

        csv_path = os.path.join(self.cfg.save_dir, self.cfg.file_name)

        kwargs.setdefault("web_search", self.cfg.modality == "web")

        df_resp_all = await get_all_responses(
            prompts=prompts,
            identifiers=unique_ids,
            prompt_images=prompt_images,
            prompt_audio=prompt_audio,
            prompt_pdfs=prompt_pdfs,
            n_parallels=self.cfg.n_parallels,
            model=self.cfg.model,
            save_path=csv_path,
            use_dummy=self.cfg.use_dummy,
            max_timeout=self.cfg.max_timeout,
            # JSON mode is requested for every modality except audio.
            json_mode=self.cfg.modality != "audio",
            reset_files=reset_files,
            reasoning_effort=self.cfg.reasoning_effort,
            reasoning_summary=self.cfg.reasoning_summary,
            **kwargs,
        )
        if not isinstance(df_resp_all, pd.DataFrame):
            raise RuntimeError("get_all_responses returned no DataFrame")

        resp_map = dict(zip(df_resp_all.Identifier, df_resp_all.Response))
        parsed_list = await asyncio.gather(
            *(self._parse(resp_map.get(ident, "")) for ident in unique_ids)
        )
        parsed: Dict[str, Dict[str, str]] = dict(zip(unique_ids, parsed_list))

        records: List[Dict[str, str]] = [
            {
                circle_column_name: circle,
                square_column_name: square,
                "attribute": attr,
                "explanation": expl,
            }
            for (circle, square), ident in zip(pairs, ids)
            for attr, expl in parsed[ident].items()
        ]

        out_df = pd.DataFrame(records)
        if not out_df.empty:
            out_df.set_index([circle_column_name, square_column_name], inplace=True)
        return out_df