openai_gabriel-1.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. gabriel/__init__.py +61 -0
  2. gabriel/_version.py +1 -0
  3. gabriel/api.py +2284 -0
  4. gabriel/cli/__main__.py +60 -0
  5. gabriel/core/__init__.py +7 -0
  6. gabriel/core/llm_client.py +34 -0
  7. gabriel/core/pipeline.py +18 -0
  8. gabriel/core/prompt_template.py +152 -0
  9. gabriel/prompts/__init__.py +1 -0
  10. gabriel/prompts/bucket_prompt.jinja2 +113 -0
  11. gabriel/prompts/classification_prompt.jinja2 +50 -0
  12. gabriel/prompts/codify_prompt.jinja2 +95 -0
  13. gabriel/prompts/comparison_prompt.jinja2 +60 -0
  14. gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
  15. gabriel/prompts/deidentification_prompt.jinja2 +112 -0
  16. gabriel/prompts/extraction_prompt.jinja2 +61 -0
  17. gabriel/prompts/filter_prompt.jinja2 +31 -0
  18. gabriel/prompts/ideation_prompt.jinja2 +80 -0
  19. gabriel/prompts/merge_prompt.jinja2 +47 -0
  20. gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
  21. gabriel/prompts/rankings_prompt.jinja2 +49 -0
  22. gabriel/prompts/ratings_prompt.jinja2 +50 -0
  23. gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
  24. gabriel/prompts/seed.jinja2 +43 -0
  25. gabriel/prompts/snippets.jinja2 +117 -0
  26. gabriel/tasks/__init__.py +63 -0
  27. gabriel/tasks/_attribute_utils.py +69 -0
  28. gabriel/tasks/bucket.py +432 -0
  29. gabriel/tasks/classify.py +562 -0
  30. gabriel/tasks/codify.py +1033 -0
  31. gabriel/tasks/compare.py +235 -0
  32. gabriel/tasks/debias.py +1460 -0
  33. gabriel/tasks/deduplicate.py +341 -0
  34. gabriel/tasks/deidentify.py +316 -0
  35. gabriel/tasks/discover.py +524 -0
  36. gabriel/tasks/extract.py +455 -0
  37. gabriel/tasks/filter.py +169 -0
  38. gabriel/tasks/ideate.py +782 -0
  39. gabriel/tasks/merge.py +464 -0
  40. gabriel/tasks/paraphrase.py +531 -0
  41. gabriel/tasks/rank.py +2041 -0
  42. gabriel/tasks/rate.py +347 -0
  43. gabriel/tasks/seed.py +465 -0
  44. gabriel/tasks/whatever.py +344 -0
  45. gabriel/utils/__init__.py +64 -0
  46. gabriel/utils/audio_utils.py +42 -0
  47. gabriel/utils/file_utils.py +464 -0
  48. gabriel/utils/image_utils.py +22 -0
  49. gabriel/utils/jinja.py +31 -0
  50. gabriel/utils/logging.py +86 -0
  51. gabriel/utils/mapmaker.py +304 -0
  52. gabriel/utils/media_utils.py +78 -0
  53. gabriel/utils/modality_utils.py +148 -0
  54. gabriel/utils/openai_utils.py +5470 -0
  55. gabriel/utils/parsing.py +282 -0
  56. gabriel/utils/passage_viewer.py +2557 -0
  57. gabriel/utils/pdf_utils.py +20 -0
  58. gabriel/utils/plot_utils.py +2881 -0
  59. gabriel/utils/prompt_utils.py +42 -0
  60. gabriel/utils/word_matching.py +158 -0
  61. openai_gabriel-1.0.1.dist-info/METADATA +443 -0
  62. openai_gabriel-1.0.1.dist-info/RECORD +67 -0
  63. openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
  64. openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
  65. openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
  66. openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
  67. openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
gabriel/tasks/seed.py ADDED
@@ -0,0 +1,465 @@
+from __future__ import annotations
+
+import asyncio
+import math
+import os
+import random
+import re
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
+
+import pandas as pd
+
+from gabriel.core.prompt_template import PromptTemplate, resolve_template
+from gabriel.tasks.deduplicate import Deduplicate, DeduplicateConfig
+from gabriel.utils import safest_json
+from gabriel.utils.openai_utils import get_all_responses
+from gabriel.utils.logging import announce_prompt_rendering
+
+
+@dataclass
+class SeedConfig:
+    """Configuration options for :class:`Seed`."""
+
+    instructions: str
+    save_dir: str = os.path.expanduser("~/Documents/runs")
+    file_name: str = "seed_entities.csv"
+    model: str = "gpt-5.2"
+    n_parallels: int = 650
+    num_entities: int = 1000
+    entities_per_generation: int = 50
+    entity_batch_frac: float = 0.25
+    existing_entities_cap: int = 100
+    use_dummy: bool = False
+    deduplicate: bool = False
+    deduplicate_sample_seed: int = 42
+    max_timeout: Optional[float] = None
+    reasoning_effort: Optional[str] = None
+    reasoning_summary: Optional[str] = None
+
+
+class Seed:
+    """Generate structured entity seeds via batched language-model calls."""
+
+    def __init__(
+        self,
+        cfg: SeedConfig,
+        template: Optional[PromptTemplate] = None,
+        template_path: Optional[str] = None,
+    ) -> None:
+        expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
+        expanded.mkdir(parents=True, exist_ok=True)
+        cfg.save_dir = str(expanded)
+        self.cfg = cfg
+        if cfg.num_entities <= 0:
+            raise ValueError("num_entities must be positive")
+        if cfg.entities_per_generation <= 0:
+            raise ValueError("entities_per_generation must be positive")
+        if not 0 < cfg.entity_batch_frac <= 1:
+            raise ValueError("entity_batch_frac must be between 0 and 1")
+        if cfg.existing_entities_cap < 0:
+            raise ValueError("existing_entities_cap must be non-negative")
+        self.template = resolve_template(
+            template=template,
+            template_path=template_path,
+            reference_filename="seed.jinja2",
+        )
+
+    async def run(
+        self,
+        *,
+        existing_entities: Optional[Sequence[str]] = None,
+        reset_files: bool = False,
+        **response_kwargs: Any,
+    ) -> pd.DataFrame:
+        """Generate ``num_entities`` unique seed entities."""
+        if self.cfg.deduplicate:
+            return await self._run_with_deduplication(
+                existing_entities=existing_entities,
+                reset_files=reset_files,
+                **response_kwargs,
+            )
+        return await self._run_standard(
+            existing_entities=existing_entities,
+            reset_files=reset_files,
+            **response_kwargs,
+        )
+
+    async def _run_standard(
+        self,
+        *,
+        existing_entities: Optional[Sequence[str]],
+        reset_files: bool,
+        **response_kwargs: Any,
+    ) -> pd.DataFrame:
+        normalized_existing = self._prepare_initial_existing(existing_entities)
+        seen: Dict[str, str] = {}
+        for ent in normalized_existing:
+            norm = self._normalize_entity(ent)
+            if norm and norm not in seen:
+                seen[norm] = ent
+
+        batch_target = max(
+            self.cfg.entities_per_generation,
+            math.ceil(self.cfg.num_entities * self.cfg.entity_batch_frac),
+        )
+        raw_save = os.path.join(self.cfg.save_dir, "seed_raw_responses.csv")
+        batch_index = 0
+        request_index = 0
+        reset_next = reset_files
+        while len(seen) < self.cfg.num_entities:
+            remaining = self.cfg.num_entities - len(seen)
+            batch_goal = min(batch_target, remaining)
+            batch_goal = max(batch_goal, self.cfg.entities_per_generation)
+            target_count = min(self.cfg.num_entities, len(seen) + batch_goal)
+            batch_added = 0
+            while len(seen) < target_count:
+                remaining_in_batch = target_count - len(seen)
+                current_goal = max(remaining_in_batch, self.cfg.entities_per_generation)
+                prompts, identifiers = self._build_prompts(
+                    current_goal,
+                    request_index,
+                    list(seen.values()),
+                )
+                if not prompts:
+                    break
+                print(
+                    f"[Seed] Requesting {len(prompts)} prompts (batch {batch_index}, "
+                    f"targeting {current_goal} entities; batch target {batch_goal})."
+                )
+                df_resp = await self._request_entities(
+                    prompts,
+                    identifiers,
+                    raw_save=raw_save,
+                    reset_files=reset_next,
+                    **response_kwargs,
+                )
+                resp_lookup = dict(zip(df_resp.Identifier, df_resp.Response))
+                parsed = await asyncio.gather(
+                    *[
+                        self._parse_entities(resp_lookup.get(identifier, ""))
+                        for identifier in identifiers
+                    ]
+                )
+                added = 0
+                for entity_list in parsed:
+                    for entity in entity_list:
+                        norm = self._normalize_entity(entity)
+                        if not norm or norm in seen:
+                            continue
+                        seen[norm] = entity
+                        added += 1
+                batch_added += added
+                reset_next = False
+                request_index += 1
+                if added == 0 and not any(parsed):
+                    break
+            print(
+                f"[Seed] Added {batch_added} new entities in batch {batch_index}. Total so far: {len(seen)}."
+            )
+            batch_index += 1
+            if batch_added == 0:
+                break
+
+        ordered = [seen[norm] for norm in seen]
+        trimmed = ordered[: self.cfg.num_entities]
+        return self._finalize_entities(trimmed)
+
+    async def _run_with_deduplication(
+        self,
+        *,
+        existing_entities: Optional[Sequence[str]],
+        reset_files: bool,
+        **response_kwargs: Any,
+    ) -> pd.DataFrame:
+        normalized_existing = self._prepare_initial_existing(existing_entities)
+        all_entities: List[str] = list(normalized_existing)
+        seen_norm: Set[str] = set()
+        for ent in all_entities:
+            norm = self._normalize_entity(ent)
+            if norm:
+                seen_norm.add(norm)
+
+        dedup_cycle = 0
+        deduped = await self._deduplicate_entities(
+            all_entities,
+            cycle_index=dedup_cycle,
+            reset_files=reset_files,
+            response_kwargs=response_kwargs,
+        )
+        dedup_cycle += 1
+        raw_target = self.cfg.num_entities * 4
+        batch_target = max(
+            self.cfg.entities_per_generation,
+            math.ceil(raw_target * self.cfg.entity_batch_frac),
+        )
+        raw_save = os.path.join(self.cfg.save_dir, "seed_raw_responses.csv")
+        batch_index = 0
+        request_index = 0
+        reset_next = reset_files
+        while len(deduped) < self.cfg.num_entities:
+            cycle_target = max(raw_target, len(all_entities))
+            while len(all_entities) < cycle_target:
+                remaining_in_cycle = cycle_target - len(all_entities)
+                current_goal = min(batch_target, remaining_in_cycle)
+                current_goal = max(current_goal, self.cfg.entities_per_generation)
+                prompts, identifiers = self._build_prompts(
+                    current_goal,
+                    request_index,
+                    deduped,
+                )
+                if not prompts:
+                    break
+                print(
+                    f"[Seed] Requesting {len(prompts)} prompts (batch {batch_index}, "
+                    f"targeting {current_goal} entities before deduplication)."
+                )
+                df_resp = await self._request_entities(
+                    prompts,
+                    identifiers,
+                    raw_save=raw_save,
+                    reset_files=reset_next,
+                    **response_kwargs,
+                )
+                resp_lookup = dict(zip(df_resp.Identifier, df_resp.Response))
+                parsed = await asyncio.gather(
+                    *[
+                        self._parse_entities(resp_lookup.get(identifier, ""))
+                        for identifier in identifiers
+                    ]
+                )
+                added = 0
+                for entity_list in parsed:
+                    for entity in entity_list:
+                        norm = self._normalize_entity(entity)
+                        if not norm or norm in seen_norm:
+                            continue
+                        all_entities.append(entity)
+                        seen_norm.add(norm)
+                        added += 1
+                reset_next = False
+                request_index += 1
+                if added == 0 and not any(parsed):
+                    break
+
+            deduped = await self._deduplicate_entities(
+                all_entities,
+                cycle_index=dedup_cycle,
+                reset_files=False,
+                response_kwargs=response_kwargs,
+            )
+            dedup_cycle += 1
+            print(
+                f"[Seed] Unique after deduplication: {len(deduped)}."
+            )
+            if len(deduped) >= self.cfg.num_entities:
+                break
+
+            all_entities = list(deduped)
+            seen_norm = {
+                self._normalize_entity(entity)
+                for entity in all_entities
+                if self._normalize_entity(entity)
+            }
+            batch_index += 1
+            if not all_entities:
+                break
+
+        trimmed = self._sample_to_target(deduped)
+        return self._finalize_entities(trimmed)
+
+    async def _request_entities(
+        self,
+        prompts: List[str],
+        identifiers: List[str],
+        *,
+        raw_save: str,
+        reset_files: bool,
+        **response_kwargs: Any,
+    ) -> pd.DataFrame:
+        kwargs = dict(response_kwargs)
+        kwargs.setdefault("model", self.cfg.model)
+        kwargs.setdefault("n_parallels", self.cfg.n_parallels)
+        kwargs.setdefault("use_dummy", self.cfg.use_dummy)
+        kwargs.setdefault("max_timeout", self.cfg.max_timeout)
+        kwargs.setdefault("reasoning_effort", self.cfg.reasoning_effort)
+        kwargs.setdefault("reasoning_summary", self.cfg.reasoning_summary)
+        kwargs.setdefault("json_mode", True)
+        kwargs.setdefault("save_path", raw_save)
+        kwargs.setdefault("reset_files", reset_files)
+        df_resp = await get_all_responses(
+            prompts=prompts,
+            identifiers=identifiers,
+            **kwargs,
+        )
+        if not isinstance(df_resp, pd.DataFrame):
+            raise RuntimeError("get_all_responses returned no DataFrame")
+        return df_resp
+
+    def _finalize_entities(self, entities: List[str]) -> pd.DataFrame:
+        df = pd.DataFrame(
+            {
+                "entity": entities,
+                "entity_id": [f"entity-{idx:05d}" for idx in range(len(entities))],
+            }
+        )
+        df["source_batch"] = df.index // max(self.cfg.entities_per_generation, 1)
+        df["source_identifier"] = ["seed" for _ in range(len(entities))]
+        final_path = os.path.join(self.cfg.save_dir, self.cfg.file_name)
+        df.to_csv(final_path, index=False)
+        print(
+            f"[Seed] Generated {len(df)} entities. Saved aggregated seeds to {final_path}."
+        )
+        return df
+
+    async def _parse_entities(self, raw: Any) -> List[str]:
+        obj = await safest_json(raw)
+        results: List[str] = []
+        if isinstance(obj, dict):
+            for key in sorted(obj.keys()):
+                value = obj.get(key)
+                if value is None:
+                    continue
+                text = str(value).strip()
+                if text:
+                    results.append(text)
+        elif isinstance(obj, list):
+            for item in obj:
+                if item is None:
+                    continue
+                text = str(item).strip()
+                if text:
+                    results.append(text)
+        elif isinstance(obj, str):
+            text = obj.strip()
+            if text:
+                results.append(text)
+        return results
+
+    def _prepare_initial_existing(
+        self, entries: Optional[Sequence[str]]
+    ) -> List[str]:
+        if not entries:
+            return []
+        unique: List[str] = []
+        seen: Set[str] = set()
+        for entry in entries:
+            text = str(entry).strip()
+            if not text:
+                continue
+            norm = self._normalize_entity(text)
+            if norm and norm not in seen:
+                seen.add(norm)
+                unique.append(text)
+        return unique
+
+    def _build_prompts(
+        self,
+        goal: int,
+        batch_index: int,
+        seen_entities: Sequence[str],
+    ) -> Tuple[List[str], List[str]]:
+        prompts: List[str] = []
+        identifiers: List[str] = []
+        per_call = self.cfg.entities_per_generation
+        prompt_count = math.ceil(goal / per_call)
+        existing_sample = self._sample_existing(seen_entities)
+        existing_blob = "\n".join(existing_sample) if existing_sample else None
+        announce_prompt_rendering("Seed", prompt_count)
+        for call_index in range(prompt_count):
+            identifiers.append(f"seed|{batch_index}|{call_index}")
+            prompts.append(
+                self.template.render(
+                    instructions=self.cfg.instructions,
+                    entities_per_generation=per_call,
+                    existing_entities=existing_blob,
+                )
+            )
+        return prompts, identifiers
+
+    def _sample_existing(self, seen_entities: Sequence[str]) -> List[str]:
+        if not seen_entities:
+            return []
+        cap = max(0, self.cfg.existing_entities_cap)
+        if cap == 0:
+            return []
+        if len(seen_entities) <= cap:
+            return list(seen_entities)
+        return random.sample(list(seen_entities), cap)
+
+    async def _deduplicate_entities(
+        self,
+        entities: Sequence[str],
+        *,
+        cycle_index: int,
+        reset_files: bool,
+        response_kwargs: Dict[str, Any],
+    ) -> List[str]:
+        cleaned: List[str] = []
+        for entity in entities:
+            text = str(entity).strip()
+            if text:
+                cleaned.append(text)
+        if not cleaned:
+            return []
+
+        df = pd.DataFrame({"entity": cleaned})
+        dedup_cfg = DeduplicateConfig(
+            save_dir=os.path.join(self.cfg.save_dir, "seed_deduplicate"),
+            file_name=f"seed_deduplicate_cycle{cycle_index}.csv",
+            model=self.cfg.model,
+            n_parallels=self.cfg.n_parallels,
+            n_runs=4,
+            use_dummy=self.cfg.use_dummy,
+            max_timeout=self.cfg.max_timeout,
+            group_size=100,
+        )
+        dedup = Deduplicate(dedup_cfg)
+        dedup_df = await dedup.run(
+            df,
+            column_name="entity",
+            reset_files=reset_files,
+            **self._filter_dedup_response_kwargs(response_kwargs),
+        )
+        mapped_col = (
+            "mapped_entity_final"
+            if "mapped_entity_final" in dedup_df.columns
+            else "mapped_entity"
+        )
+        unique: List[str] = []
+        seen: Set[str] = set()
+        for val in dedup_df[mapped_col]:
+            text = str(val).strip()
+            if not text or text in seen:
+                continue
+            seen.add(text)
+            unique.append(text)
+        return unique
+
+    @staticmethod
+    def _filter_dedup_response_kwargs(response_kwargs: Dict[str, Any]) -> Dict[str, Any]:
+        blocked = {
+            "model",
+            "n_parallels",
+            "save_path",
+            "use_dummy",
+            "max_timeout",
+            "json_mode",
+            "reset_files",
+            "prompts",
+            "identifiers",
+        }
+        return {key: value for key, value in response_kwargs.items() if key not in blocked}
+
+    def _sample_to_target(self, entities: List[str]) -> List[str]:
+        if len(entities) <= self.cfg.num_entities:
+            return entities[: self.cfg.num_entities]
+        rng = random.Random(self.cfg.deduplicate_sample_seed)
+        selected = set(rng.sample(entities, self.cfg.num_entities))
+        return [entity for entity in entities if entity in selected][: self.cfg.num_entities]
+
+    @staticmethod
+    def _normalize_entity(text: str) -> str:
+        collapsed = re.sub(r"\s+", " ", text.strip()).lower()
+        return collapsed
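
Usage sketch (not part of the package): a minimal way to drive the Seed task above, inferred only from the code in this diff. The instruction text and the reduced num_entities/entities_per_generation/n_parallels values are illustrative assumptions, not package defaults; use_dummy=True is assumed to exercise the pipeline without live API calls, based on the flag's name and how it is forwarded to get_all_responses.

import asyncio

from gabriel.tasks.seed import Seed, SeedConfig

cfg = SeedConfig(
    instructions="Name distinct classic board games.",  # hypothetical instruction text
    num_entities=200,             # smaller than the 1000 default, for a quick run
    entities_per_generation=25,   # entities requested per prompt
    n_parallels=20,               # far below the 650 default
    use_dummy=True,               # assumption: avoids real model calls
)


async def main() -> None:
    seed = Seed(cfg)
    df = await seed.run(reset_files=True)
    # _finalize_entities returns columns: entity, entity_id,
    # source_batch, source_identifier, and writes them to
    # save_dir/file_name as CSV.
    print(df.head())


asyncio.run(main())

With these settings the first wave targets max(25, ceil(200 * 0.25)) = 50 entities, so _build_prompts renders two prompts of 25 entities each; the loop then repeats, feeding a capped sample of already-seen entities back into the template, until 200 unique normalized entities accumulate or a batch adds nothing new.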