openai_gabriel-1.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. gabriel/__init__.py +61 -0
  2. gabriel/_version.py +1 -0
  3. gabriel/api.py +2284 -0
  4. gabriel/cli/__main__.py +60 -0
  5. gabriel/core/__init__.py +7 -0
  6. gabriel/core/llm_client.py +34 -0
  7. gabriel/core/pipeline.py +18 -0
  8. gabriel/core/prompt_template.py +152 -0
  9. gabriel/prompts/__init__.py +1 -0
  10. gabriel/prompts/bucket_prompt.jinja2 +113 -0
  11. gabriel/prompts/classification_prompt.jinja2 +50 -0
  12. gabriel/prompts/codify_prompt.jinja2 +95 -0
  13. gabriel/prompts/comparison_prompt.jinja2 +60 -0
  14. gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
  15. gabriel/prompts/deidentification_prompt.jinja2 +112 -0
  16. gabriel/prompts/extraction_prompt.jinja2 +61 -0
  17. gabriel/prompts/filter_prompt.jinja2 +31 -0
  18. gabriel/prompts/ideation_prompt.jinja2 +80 -0
  19. gabriel/prompts/merge_prompt.jinja2 +47 -0
  20. gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
  21. gabriel/prompts/rankings_prompt.jinja2 +49 -0
  22. gabriel/prompts/ratings_prompt.jinja2 +50 -0
  23. gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
  24. gabriel/prompts/seed.jinja2 +43 -0
  25. gabriel/prompts/snippets.jinja2 +117 -0
  26. gabriel/tasks/__init__.py +63 -0
  27. gabriel/tasks/_attribute_utils.py +69 -0
  28. gabriel/tasks/bucket.py +432 -0
  29. gabriel/tasks/classify.py +562 -0
  30. gabriel/tasks/codify.py +1033 -0
  31. gabriel/tasks/compare.py +235 -0
  32. gabriel/tasks/debias.py +1460 -0
  33. gabriel/tasks/deduplicate.py +341 -0
  34. gabriel/tasks/deidentify.py +316 -0
  35. gabriel/tasks/discover.py +524 -0
  36. gabriel/tasks/extract.py +455 -0
  37. gabriel/tasks/filter.py +169 -0
  38. gabriel/tasks/ideate.py +782 -0
  39. gabriel/tasks/merge.py +464 -0
  40. gabriel/tasks/paraphrase.py +531 -0
  41. gabriel/tasks/rank.py +2041 -0
  42. gabriel/tasks/rate.py +347 -0
  43. gabriel/tasks/seed.py +465 -0
  44. gabriel/tasks/whatever.py +344 -0
  45. gabriel/utils/__init__.py +64 -0
  46. gabriel/utils/audio_utils.py +42 -0
  47. gabriel/utils/file_utils.py +464 -0
  48. gabriel/utils/image_utils.py +22 -0
  49. gabriel/utils/jinja.py +31 -0
  50. gabriel/utils/logging.py +86 -0
  51. gabriel/utils/mapmaker.py +304 -0
  52. gabriel/utils/media_utils.py +78 -0
  53. gabriel/utils/modality_utils.py +148 -0
  54. gabriel/utils/openai_utils.py +5470 -0
  55. gabriel/utils/parsing.py +282 -0
  56. gabriel/utils/passage_viewer.py +2557 -0
  57. gabriel/utils/pdf_utils.py +20 -0
  58. gabriel/utils/plot_utils.py +2881 -0
  59. gabriel/utils/prompt_utils.py +42 -0
  60. gabriel/utils/word_matching.py +158 -0
  61. openai_gabriel-1.0.1.dist-info/METADATA +443 -0
  62. openai_gabriel-1.0.1.dist-info/RECORD +67 -0
  63. openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
  64. openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
  65. openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
  66. openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
  67. openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
gabriel/tasks/rate.py ADDED
@@ -0,0 +1,347 @@
+# src/gabriel/tasks/rate.py
+# ════════════════════════════════════════════════════════════════════
+# Robust passage-rating task with optional debug logging.
+# ════════════════════════════════════════════════════════════════════
+from __future__ import annotations
+
+import hashlib
+import asyncio
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, DefaultDict, Dict, List, Optional, Set
+import os
+from pathlib import Path
+
+import pandas as pd
+
+from ..core.prompt_template import PromptTemplate, resolve_template
+from ..utils.openai_utils import get_all_responses
+from ..utils import (
+    safest_json,
+    load_image_inputs,
+    load_audio_inputs,
+    load_pdf_inputs,
+    warn_if_modality_mismatch,
+)
+from ..utils.logging import announce_prompt_rendering
+from ._attribute_utils import load_persisted_attributes
+
+
+# ────────────────────────────
+# Configuration dataclass
+# ────────────────────────────
+@dataclass
+class RateConfig:
+    attributes: Dict[str, str]
+    save_dir: str = "ratings"
+    file_name: str = "ratings.csv"
+    model: str = "gpt-5-mini"
+    n_parallels: int = 650
+    n_runs: int = 1
+    use_dummy: bool = False
+    max_timeout: Optional[float] = None
+    rating_scale: Optional[str] = None
+    additional_instructions: Optional[str] = None
+    modality: str = "text"
+    n_attributes_per_run: int = 8
+    reasoning_effort: Optional[str] = None
+    reasoning_summary: Optional[str] = None
+    search_context_size: str = "medium"
+
+    def __post_init__(self) -> None:
+        if self.additional_instructions is not None:
+            cleaned = str(self.additional_instructions).strip()
+            self.additional_instructions = cleaned or None
+
+
+# ────────────────────────────
+# Main rating task
+# ────────────────────────────
+class Rate:
+    """Rate passages on specified attributes (0–100)."""
+
+
+    # -----------------------------------------------------------------
+    def __init__(
+        self,
+        cfg: RateConfig,
+        template: Optional[PromptTemplate] = None,
+        template_path: Optional[str] = None,
+    ) -> None:
+        expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
+        expanded.mkdir(parents=True, exist_ok=True)
+        cfg.save_dir = str(expanded)
+        self.cfg = cfg
+        self.template = resolve_template(
+            template=template,
+            template_path=template_path,
+            reference_filename="ratings_prompt.jinja2",
+        )
+
+    # -----------------------------------------------------------------
+    # Parse raw LLM output into {attribute: float}
+    # -----------------------------------------------------------------
+    async def _parse(self, raw: Any, attrs: List[str]) -> Dict[str, Optional[float]]:
+        obj = await safest_json(raw)
+        out: Dict[str, Optional[float]] = {}
+        if isinstance(obj, dict):
+            for attr in attrs:
+                try:
+                    out[attr] = float(obj.get(attr)) if obj.get(attr) is not None else None
+                except Exception:
+                    out[attr] = None
+            return out
+        return {attr: None for attr in attrs}

+    # -----------------------------------------------------------------
+    # Main entry point
+    # -----------------------------------------------------------------
+    async def run(
+        self,
+        df: pd.DataFrame,
+        column_name: str,
+        *,
+        debug: bool = False,
+        reset_files: bool = False,
+        **kwargs: Any,
+    ) -> pd.DataFrame:
+        """Return ``df`` with one column per attribute rating."""
+
+        df_proc = df.reset_index(drop=True).copy()
+        values = df_proc[column_name].tolist()
+        texts = [str(v) for v in values]
+        base_name = os.path.splitext(self.cfg.file_name)[0]
+
+        warn_if_modality_mismatch(values, self.cfg.modality, column_name=column_name)
+
+        base_ids: List[str] = []
+        id_to_rows: DefaultDict[str, List[int]] = defaultdict(list)
+        id_to_val: Dict[str, Any] = {}
+        prompt_texts: Dict[str, str] = {}
+        row_ids: List[str] = []
+
+        for row, (passage, orig) in enumerate(zip(texts, values)):
+            sha8 = hashlib.sha1(passage.encode()).hexdigest()[:8]
+            row_ids.append(sha8)
+            id_to_rows[sha8].append(row)
+            if len(id_to_rows[sha8]) > 1:
+                continue
+            id_to_val[sha8] = orig
+            prompt_texts[sha8] = passage if self.cfg.modality in {"text", "entity", "web"} else ""
+            base_ids.append(sha8)
+
+        df_proc["_gid"] = row_ids
+
+        self.cfg.attributes = load_persisted_attributes(
+            save_dir=self.cfg.save_dir,
+            incoming=self.cfg.attributes,
+            reset_files=reset_files,
+            task_name="Rate",
+            item_name="attributes",
+            legacy_filename=f"{base_name}_attrs.json",
+        )
+
+        attr_items = list(self.cfg.attributes.items())
+        attr_count = len(attr_items)
+        if attr_count > self.cfg.n_attributes_per_run:
+            batches = (attr_count + self.cfg.n_attributes_per_run - 1) // self.cfg.n_attributes_per_run
+            print(
+                f"[Rate] {attr_count} attributes provided. n_attributes_per_run={self.cfg.n_attributes_per_run}. "
+                f"Splitting into {batches} prompt batches. Increase n_attributes_per_run if you want all attributes "
+                "to be processed in the same prompt."
+            )
+        attr_batches: List[Dict[str, str]] = [
+            dict(attr_items[i : i + self.cfg.n_attributes_per_run])
+            for i in range(0, len(attr_items), self.cfg.n_attributes_per_run)
+        ]
+
+        prompts: List[str] = []
+        ids: List[str] = []
+        for batch_idx, batch_attrs in enumerate(attr_batches):
+            for ident in base_ids:
+                if batch_idx == 0 and not prompts and batch_attrs is attr_batches[0]:
+                    announce_prompt_rendering(
+                        "Rate",
+                        len(base_ids) * len(attr_batches),
+                    )
+                prompts.append(
+                    self.template.render(
+                        text=prompt_texts[ident],
+                        attributes=batch_attrs,
+                        scale=self.cfg.rating_scale,
+                        additional_instructions=self.cfg.additional_instructions,
+                        modality=self.cfg.modality,
+                    )
+                )
+                ids.append(f"{ident}_batch{batch_idx}")
+
+        prompt_images: Optional[Dict[str, List[str]]] = None
+        prompt_audio: Optional[Dict[str, List[Dict[str, str]]]] = None
+        prompt_pdfs: Optional[Dict[str, List[Dict[str, str]]]] = None
+
+        if self.cfg.modality == "image":
+            tmp: Dict[str, List[str]] = {}
+            for ident, rows in id_to_rows.items():
+                imgs = load_image_inputs(values[rows[0]])
+                if imgs:
+                    for batch_idx in range(len(attr_batches)):
+                        tmp[f"{ident}_batch{batch_idx}"] = imgs
+            prompt_images = tmp or None
+        elif self.cfg.modality == "audio":
+            tmp_a: Dict[str, List[Dict[str, str]]] = {}
+            for ident, rows in id_to_rows.items():
+                auds = load_audio_inputs(values[rows[0]])
+                if auds:
+                    for batch_idx in range(len(attr_batches)):
+                        tmp_a[f"{ident}_batch{batch_idx}"] = auds
+            prompt_audio = tmp_a or None
+        elif self.cfg.modality == "pdf":
+            tmp_p: Dict[str, List[Dict[str, str]]] = {}
+            for ident, rows in id_to_rows.items():
+                pdfs = load_pdf_inputs(values[rows[0]])
+                if pdfs:
+                    for batch_idx in range(len(attr_batches)):
+                        tmp_p[f"{ident}_batch{batch_idx}"] = pdfs
+            prompt_pdfs = tmp_p or None
+
+        csv_path = os.path.join(self.cfg.save_dir, f"{base_name}_raw_responses.csv")
+        kwargs.setdefault("web_search", self.cfg.modality == "web")
+        kwargs.setdefault("search_context_size", self.cfg.search_context_size)
+
+        if not isinstance(self.cfg.n_runs, int) or self.cfg.n_runs < 1:
+            raise ValueError("n_runs must be an integer >= 1")
+
+        existing_ids: Set[str] = set()
+        if not reset_files and os.path.exists(csv_path):
+            try:
+                existing_df = pd.read_csv(csv_path, usecols=["Identifier"])
+                existing_ids = set(existing_df["Identifier"].astype(str))
+            except Exception:
+                existing_ids = set()
+
+        run_identifier_lists: List[List[str]] = []
+        for run_idx in range(1, self.cfg.n_runs + 1):
+            run_ids: List[str] = []
+            for ident in ids:
+                if run_idx == 1:
+                    legacy_ident = f"{ident}_run1"
+                    run_ids.append(legacy_ident if legacy_ident in existing_ids else ident)
+                else:
+                    run_ids.append(f"{ident}_run{run_idx}")
+            run_identifier_lists.append(run_ids)
+
+        prompts_all: List[str] = []
+        ids_all: List[str] = []
+        for run_ids in run_identifier_lists:
+            prompts_all.extend(prompts)
+            ids_all.extend(run_ids)
+
+        prompt_images_all: Optional[Dict[str, List[str]]] = None
+        if prompt_images:
+            prompt_images_all = {}
+            for run_ids in run_identifier_lists:
+                for base_ident, run_ident in zip(ids, run_ids):
+                    imgs = prompt_images.get(base_ident)
+                    if imgs:
+                        prompt_images_all[run_ident] = imgs
+        prompt_audio_all: Optional[Dict[str, List[Dict[str, str]]]] = None
+        if prompt_audio:
+            prompt_audio_all = {}
+            for run_ids in run_identifier_lists:
+                for base_ident, run_ident in zip(ids, run_ids):
+                    auds = prompt_audio.get(base_ident)
+                    if auds:
+                        prompt_audio_all[run_ident] = auds
+        prompt_pdfs_all: Optional[Dict[str, List[Dict[str, str]]]] = None
+        if prompt_pdfs:
+            prompt_pdfs_all = {}
+            for run_ids in run_identifier_lists:
+                for base_ident, run_ident in zip(ids, run_ids):
+                    pdfs = prompt_pdfs.get(base_ident)
+                    if pdfs:
+                        prompt_pdfs_all[run_ident] = pdfs
+
+        df_resp_all = await get_all_responses(
+            prompts=prompts_all,
+            identifiers=ids_all,
+            prompt_images=prompt_images_all,
+            prompt_audio=prompt_audio_all,
+            prompt_pdfs=prompt_pdfs_all,
+            n_parallels=self.cfg.n_parallels,
+            model=self.cfg.model,
+            save_path=csv_path,
+            use_dummy=self.cfg.use_dummy,
+            max_timeout=self.cfg.max_timeout,
+            json_mode=self.cfg.modality != "audio",
+            reset_files=reset_files,
+            reasoning_effort=self.cfg.reasoning_effort,
+            reasoning_summary=self.cfg.reasoning_summary,
+            **kwargs,
+        )
+
+        if not isinstance(df_resp_all, pd.DataFrame):
+            raise RuntimeError("get_all_responses returned no DataFrame")
+
+        df_resps = []
+        for run_idx, run_ids in enumerate(run_identifier_lists, start=1):
+            suffix = f"_run{run_idx}"
+            sub = df_resp_all[df_resp_all.Identifier.isin(run_ids)].copy()
+            sub.Identifier = sub.Identifier.str.replace(suffix + "$", "", regex=True)
+            df_resps.append(sub)
+
+        if debug:
+            print("\n── raw LLM responses ──")
+            for run_idx, df_resp in enumerate(df_resps, start=1):
+                for ident, raw in zip(df_resp.Identifier, df_resp.Response):
+                    r = raw[0] if isinstance(raw, list) and raw else raw
+                    print(f"[run {run_idx}] {ident} →\n{r}\n")
+            print("────────────────────────\n")
+
+        # parse each run and build disaggregated records
+        full_records: List[Dict[str, Any]] = []
+        base_attrs = list(self.cfg.attributes.keys())
+        for run_idx, df_resp in enumerate(df_resps, start=1):
+            id_to_ratings: Dict[str, Dict[str, Optional[float]]] = {
+                ident: {attr: None for attr in base_attrs} for ident in base_ids
+            }
+            for ident_batch, raw in zip(df_resp.Identifier, df_resp.Response):
+                main = raw[0] if isinstance(raw, list) and raw else raw
+                try:
+                    base_ident, batch_part = ident_batch.rsplit("_batch", 1)
+                    batch_idx = int(batch_part)
+                    attrs = list(attr_batches[batch_idx].keys())
+                except (ValueError, IndexError):
+                    if debug:
+                        print(f"[Rate] Skipping malformed identifier {ident_batch}")
+                    continue
+                if base_ident not in id_to_ratings:
+                    if debug:
+                        print(f"[Rate] Skipping unknown identifier {base_ident}")
+                    continue
+                parsed = await self._parse(main, attrs)
+                for attr in attrs:
+                    id_to_ratings[base_ident][attr] = parsed.get(attr)
+            for ident in base_ids:
+                parsed = id_to_ratings.get(ident, {attr: None for attr in base_attrs})
+                rec = {"id": ident, "text": id_to_val[ident], "run": run_idx}
+                rec.update({attr: parsed.get(attr) for attr in base_attrs})
+                full_records.append(rec)
+
+        full_df = pd.DataFrame(full_records).set_index(["id", "run"])
+        if self.cfg.n_runs > 1:
+            disagg_path = os.path.join(
+                self.cfg.save_dir, f"{base_name}_full_disaggregated.csv"
+            )
+            full_df.to_csv(disagg_path, index_label=["id", "run"])
+
+        # aggregate across runs
+        agg_df = full_df.groupby("id")[list(self.cfg.attributes)].mean()
+
+        out_path = os.path.join(self.cfg.save_dir, f"{base_name}_cleaned.csv")
+        result = df_proc.merge(agg_df, left_on="_gid", right_index=True, how="left")
+        result = result.drop(columns=["_gid"])
+        result.to_csv(out_path, index=False)
+
+        # keep raw response files for reference
+
+        return result
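
For orientation, a minimal usage sketch, assuming Rate and RateConfig are imported exactly as defined above. The DataFrame contents and attribute descriptions are hypothetical, and the comment on use_dummy reflects its name and the flag passed through to get_all_responses rather than documented behavior:

    import asyncio

    import pandas as pd

    from gabriel.tasks.rate import Rate, RateConfig

    # Hypothetical passages to score.
    df = pd.DataFrame(
        {"text": ["The service was quick and friendly.", "Shipping took three weeks."]}
    )

    cfg = RateConfig(
        # attribute name -> description rendered into ratings_prompt.jinja2
        attributes={
            "positivity": "Overall positive sentiment of the passage.",
            "specificity": "How concrete and detailed the passage is.",
        },
        save_dir="ratings",
        n_runs=2,        # per-attribute ratings are averaged across runs
        use_dummy=True,  # presumably substitutes canned responses for API calls
    )

    # Rate.run is a coroutine; it returns the input frame plus one float
    # column per attribute (missing values where parsing failed).
    result = asyncio.run(Rate(cfg).run(df, "text"))
    print(result[["text", "positivity", "specificity"]])

Because each passage is keyed by a hash of its text, duplicate rows are prompted once and inherit the same scores, and the existing-identifier logic suggests a rerun with reset_files=False resumes from <save_dir>/ratings_raw_responses.csv instead of re-querying.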
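
The join-back bookkeeping rests on a three-part identifier scheme: an 8-hex-digit SHA-1 prefix per unique passage, a _batch{i} suffix per attribute batch, and a _run{n} suffix for runs after the first (run 1 stays bare unless a legacy _run1 identifier already exists in the raw-response CSV). A small sketch of the naming and the batch arithmetic, with illustrative values:

    import hashlib
    from math import ceil

    passage = "The service was quick and friendly."
    sha8 = hashlib.sha1(passage.encode()).hexdigest()[:8]  # stable per text

    # 20 attributes with n_attributes_per_run=8 yield 3 prompt batches,
    # matching the (a + n - 1) // n ceiling division in run().
    attr_count, per_run = 20, 8
    batches = ceil(attr_count / per_run)  # 3

    # Total requests = unique passages x attribute batches x n_runs.
    print(sha8)                   # 8 hex chars identifying the passage
    print(f"{sha8}_batch0")       # first attribute batch, run 1 (no suffix)
    print(f"{sha8}_batch2_run2")  # third batch, second sampling run

run() later strips the _run{n} suffix, splits identifiers on _batch to recover the attribute subset, and averages over the (id, run) index, so every batch and run for a passage collapses back into a single row of the cleaned CSV.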