openai_gabriel-1.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gabriel/__init__.py +61 -0
- gabriel/_version.py +1 -0
- gabriel/api.py +2284 -0
- gabriel/cli/__main__.py +60 -0
- gabriel/core/__init__.py +7 -0
- gabriel/core/llm_client.py +34 -0
- gabriel/core/pipeline.py +18 -0
- gabriel/core/prompt_template.py +152 -0
- gabriel/prompts/__init__.py +1 -0
- gabriel/prompts/bucket_prompt.jinja2 +113 -0
- gabriel/prompts/classification_prompt.jinja2 +50 -0
- gabriel/prompts/codify_prompt.jinja2 +95 -0
- gabriel/prompts/comparison_prompt.jinja2 +60 -0
- gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
- gabriel/prompts/deidentification_prompt.jinja2 +112 -0
- gabriel/prompts/extraction_prompt.jinja2 +61 -0
- gabriel/prompts/filter_prompt.jinja2 +31 -0
- gabriel/prompts/ideation_prompt.jinja2 +80 -0
- gabriel/prompts/merge_prompt.jinja2 +47 -0
- gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
- gabriel/prompts/rankings_prompt.jinja2 +49 -0
- gabriel/prompts/ratings_prompt.jinja2 +50 -0
- gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
- gabriel/prompts/seed.jinja2 +43 -0
- gabriel/prompts/snippets.jinja2 +117 -0
- gabriel/tasks/__init__.py +63 -0
- gabriel/tasks/_attribute_utils.py +69 -0
- gabriel/tasks/bucket.py +432 -0
- gabriel/tasks/classify.py +562 -0
- gabriel/tasks/codify.py +1033 -0
- gabriel/tasks/compare.py +235 -0
- gabriel/tasks/debias.py +1460 -0
- gabriel/tasks/deduplicate.py +341 -0
- gabriel/tasks/deidentify.py +316 -0
- gabriel/tasks/discover.py +524 -0
- gabriel/tasks/extract.py +455 -0
- gabriel/tasks/filter.py +169 -0
- gabriel/tasks/ideate.py +782 -0
- gabriel/tasks/merge.py +464 -0
- gabriel/tasks/paraphrase.py +531 -0
- gabriel/tasks/rank.py +2041 -0
- gabriel/tasks/rate.py +347 -0
- gabriel/tasks/seed.py +465 -0
- gabriel/tasks/whatever.py +344 -0
- gabriel/utils/__init__.py +64 -0
- gabriel/utils/audio_utils.py +42 -0
- gabriel/utils/file_utils.py +464 -0
- gabriel/utils/image_utils.py +22 -0
- gabriel/utils/jinja.py +31 -0
- gabriel/utils/logging.py +86 -0
- gabriel/utils/mapmaker.py +304 -0
- gabriel/utils/media_utils.py +78 -0
- gabriel/utils/modality_utils.py +148 -0
- gabriel/utils/openai_utils.py +5470 -0
- gabriel/utils/parsing.py +282 -0
- gabriel/utils/passage_viewer.py +2557 -0
- gabriel/utils/pdf_utils.py +20 -0
- gabriel/utils/plot_utils.py +2881 -0
- gabriel/utils/prompt_utils.py +42 -0
- gabriel/utils/word_matching.py +158 -0
- openai_gabriel-1.0.1.dist-info/METADATA +443 -0
- openai_gabriel-1.0.1.dist-info/RECORD +67 -0
- openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
- openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
- openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
- openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
- openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
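Note: based only on this manifest, the wheel ships a single top-level package, gabriel (per top_level.txt), with the public task API re-exported from gabriel/tasks. A minimal import sketch under that assumption (the pip distribution name is inferred from the dist-info directory, not verified documentation):

# Hypothetical quick start inferred from the manifest above; names are assumptions.
# pip install openai-gabriel
import gabriel                                    # top-level package per top_level.txt
from gabriel.tasks import Bucket, BucketConfig    # task classes re-exported lazily (see tasks/__init__.py below)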
gabriel/prompts/snippets.jinja2
ADDED
@@ -0,0 +1,117 @@
{% macro single_entry(modality, text) %}
{% if modality == "text" %}
BEGIN TEXT ENTRY
{{ text }}
END TEXT ENTRY
Read entire text content carefully—start, middle, end. Do not skim; comprehend whole text deeply, including subtleties buried deep in the content.
{% elif modality == "image" %}
Carefully consider and analyze the provided image content. Comprehend it fully, both big picture and relevant subtleties.
{% elif modality == "audio" %}
Carefully consider and analyze the provided audio content. Comprehend it fully, both big picture and relevant subtleties. Pay close attention to both the content as well as the tone/style of the audio, and the sonic qualities of the content's presentation.
{% elif modality == "pdf" %}
Carefully consider and analyze the provided PDF content. Comprehend it fully, both big picture and relevant subtleties across the entire document.
{% elif modality == "entity" %}
Entity: {{ text }}
Use your prodigious internal knowledge on the item/entity "{{ text }}".
Comprehensively consider every relevant detail, both big picture and subtleties.
Consider "{{ text }}" as the content to be analyzed.
{% elif modality == "web" %}
Entity: {{ text }}
Comprehensively explore the web for all relevant information on the item/entity "{{ text }}".
Ensure you collect a highly truthful, accurate, comprehensive, and representative picture of "{{ text }}" from the web.
Consider "{{ text }}" as the content to be analyzed, fully grounded in the info you find on the web.
Thoroughly research "{{ text }}" on the web; creatively use different search approaches; gather as much info as possible to deeply characterize the relevant aspects of "{{ text }}".
Crucial: ONLY use the direct info you find on the web to perform your analysis. Collect the info then dispassionately analyze it with no preconceived notions.
{% endif %}
{% endmacro %}

{% macro pair_entries(modality, entry_square, entry_circle, circle_first=False) %}
{% if modality == "text" %}
Consider the following two distinct text entries:
{% if circle_first %}
BEGIN ENTRY CIRCLE
{{ entry_circle }}
END ENTRY CIRCLE
Above entry ("circle") is distinct from below entry ("square").
BEGIN ENTRY SQUARE
{{ entry_square }}
END ENTRY SQUARE
{% else %}
BEGIN ENTRY SQUARE
{{ entry_square }}
END ENTRY SQUARE
Above entry ("square") is distinct from below entry ("circle").
BEGIN ENTRY CIRCLE
{{ entry_circle }}
END ENTRY CIRCLE
{% endif %}

Read both text entries separately from one another. Process each by itself and fully-start, middle, end.
Do not skim; comprehend each whole text deeply, including subtleties buried deep in each text entry.
Understand every nuance before comparing.
{% elif modality == "image" %}
{% if circle_first %}
Two images are provided. **CIRCLE = first image. SQUARE = second image.** Comprehend both fully (overall + fine detail).
Again, "square" references the second, later image entry; "circle" references the first, earlier image entry.
Be very clear in your mind that "circle" is the first image and "square" is the second image.
{% else %}
Two images are provided. **SQUARE = first image. CIRCLE = second image.** Comprehend both fully (overall + fine detail).
Again, "circle" references the second, later image entry; "square" references the first, earlier image entry.
Be very clear in your mind that "square" is the first image and "circle" is the second image.
{% endif %}
{% elif modality == "audio" %}
{% if circle_first %}
Two audio files are provided. **CIRCLE = first audio recording. SQUARE = second audio recording.** Comprehend both fully (content, style, tone, sonic qualities).
Again, "square" references the second, later audio entry; "circle" references the first, earlier audio entry.
Be very clear in your mind that "circle" is the first audio recording and "square" is the second audio recording.
{% else %}
Two audio files are provided. **SQUARE = first audio recording. CIRCLE = second audio recording.** Comprehend both fully (content, style, tone, sonic qualities).
Again, "circle" references the second, later audio entry; "square" references the first, earlier audio entry.
Be very clear in your mind that "square" is the first audio recording and "circle" is the second audio recording.
{% endif %}
{% elif modality == "pdf" %}
{% if circle_first %}
Two PDFs are provided. **CIRCLE = first PDF. SQUARE = second PDF.** Comprehend both fully (overall content + fine detail).
Again, "square" references the second, later PDF entry; "circle" references the first, earlier PDF entry.
Be very clear in your mind that "circle" is the first PDF and "square" is the second PDF.
{% else %}
Two PDFs are provided. **SQUARE = first PDF. CIRCLE = second PDF.** Comprehend both fully (overall content + fine detail).
Again, "circle" references the second, later PDF entry; "square" references the first, earlier PDF entry.
Be very clear in your mind that "square" is the first PDF and "circle" is the second PDF.
{% endif %}
{% elif modality == "entity" %}
{% if circle_first %}
Entity circle: {{ entry_circle }}
Entity square: {{ entry_square }}
Use your prodigious internal knowledge on both item/entity square ("{{ entry_square }}") and item/entity circle ("{{ entry_circle }}").
Comprehensively consider every relevant detail you know about each, both big picture and subtleties.
Consider "{{ entry_square }}" ("square") and "{{ entry_circle }}" ("circle") as the entries/content to be compared, using your internal knowledge.
{% else %}
Entity square: {{ entry_square }}
Entity circle: {{ entry_circle }}
Use your prodigious internal knowledge on both item/entity circle ("{{ entry_circle }}") and item/entity square ("{{ entry_square }}").
Comprehensively consider every relevant detail you know about each, both big picture and subtleties.
Consider "{{ entry_circle }}" ("circle") and "{{ entry_square }}" ("square") as the entries/content to be compared, using your internal knowledge.
{% endif %}
{% elif modality == "web" %}
{% if circle_first %}
Entity circle: {{ entry_circle }}
Entity square: {{ entry_square }}
Comprehensively explore the web for all relevant information on both item/entity square ("{{ entry_square }}") and item/entity circle ("{{ entry_circle }}").
Ensure you collect a highly truthful, accurate, comprehensive, and representative picture of each from the web.
Critical that you search on each entity separately and with equal thoroughness.
Consider "{{ entry_square }}" ("square") and "{{ entry_circle }}" ("circle") as the entries/content to be compared, fully grounded in the info you find on the web.
Thoroughly research each entity on the web; creatively use different search approaches; gather as much info as possible to deeply characterize the relevant aspects of both square and circle.
Crucial: ONLY use the direct info you find on the web to perform your analysis. Collect the info then dispassionately analyze it with no preconceived notions.
{% else %}
Entity square: {{ entry_square }}
Entity circle: {{ entry_circle }}
Comprehensively explore the web for all relevant information on both item/entity circle ("{{ entry_circle }}") and item/entity square ("{{ entry_square }}").
Ensure you collect a highly truthful, accurate, comprehensive, and representative picture of each from the web.
Critical that you search on each entity separately and with equal thoroughness.
Consider "{{ entry_circle }}" ("circle") and "{{ entry_square }}" ("square") as the entries/content to be compared, fully grounded in the info you find on the web.
Thoroughly research each entity on the web; creatively use different search approaches; gather as much info as possible to deeply characterize the relevant aspects of both circle and square.
Crucial: ONLY use the direct info you find on the web to perform your analysis. Collect the info then dispassionately analyze it with no preconceived notions.
{% endif %}
{% endif %}
{% endmacro %}
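Note: single_entry and pair_entries are shared snippets that the other templates in gabriel/prompts can pull in. As a rough, standalone illustration of calling such a macro from Python via Jinja2's module API (the loader path and inputs here are assumptions, not how the package itself wires its templates):

# Illustrative only: render the single_entry macro directly with Jinja2.
from jinja2 import Environment, FileSystemLoader

env = Environment(loader=FileSystemLoader("gabriel/prompts"))
snippets = env.get_template("snippets.jinja2").module   # macros are exposed as attributes
piece = snippets.single_entry("text", "Example passage to analyze.")
print(piece)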
gabriel/tasks/__init__.py
ADDED
@@ -0,0 +1,63 @@
"""Task implementations for GABRIEL."""

from importlib import import_module
from typing import List

_lazy_imports = {
    "Rate": ".rate",
    "RateConfig": ".rate",
    "Deidentifier": ".deidentify",
    "DeidentifyConfig": ".deidentify",
    "EloRater": ".elo",
    "EloConfig": ".elo",
    "Classify": ".classify",
    "ClassifyConfig": ".classify",
    "Rank": ".rank",
    "RankConfig": ".rank",
    "Codify": ".codify",
    "CodifyConfig": ".codify",
    "Paraphrase": ".paraphrase",
    "ParaphraseConfig": ".paraphrase",
    "Extract": ".extract",
    "ExtractConfig": ".extract",
    "Regional": ".regional",
    "RegionalConfig": ".regional",
    "CountyCounter": ".county_counter",
    "Compare": ".compare",
    "CompareConfig": ".compare",
    "Merge": ".merge",
    "MergeConfig": ".merge",
    "Deduplicate": ".deduplicate",
    "DeduplicateConfig": ".deduplicate",
    "Bucket": ".bucket",
    "BucketConfig": ".bucket",
    "Discover": ".discover",
    "DiscoverConfig": ".discover",
    "Seed": ".seed",
    "SeedConfig": ".seed",
    "Filter": ".filter",
    "FilterConfig": ".filter",
    "Whatever": ".whatever",
    "WhateverConfig": ".whatever",
    "Ideate": ".ideate",
    "IdeateConfig": ".ideate",
    "DebiasPipeline": ".debias",
    "DebiasConfig": ".debias",
    "DebiasResult": ".debias",
    "DebiasRegressionResult": ".debias",
    "MeasurementMode": ".debias",
    "RemovalMethod": ".debias",
}

__all__ = list(_lazy_imports.keys())


def __getattr__(name: str):
    if name in _lazy_imports:
        module = import_module(_lazy_imports[name], __name__)
        return getattr(module, name)
    raise AttributeError(name)


def __dir__() -> List[str]:
    return __all__
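Note: the __getattr__/__dir__ pair above implements module-level lazy imports (PEP 562): a submodule such as .bucket is only imported the first time one of its exported names is accessed. A small usage sketch (class names come from the mapping above; everything else is illustrative):

# Importing the package does not yet import any task submodule.
import gabriel.tasks as tasks

print("Bucket" in dir(tasks))         # True: __dir__ lists every lazily exported name
BucketConfig = tasks.BucketConfig     # first access triggers import_module(".bucket", "gabriel.tasks")
cfg = BucketConfig(bucket_count=8)    # dataclass defined in gabriel/tasks/bucket.py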
gabriel/tasks/_attribute_utils.py
ADDED
@@ -0,0 +1,69 @@
import json
import os
from typing import Any, Dict, Optional


def load_persisted_attributes(
    *,
    save_dir: str,
    incoming: Dict[str, Any],
    reset_files: bool,
    task_name: str,
    item_name: str = "attributes",
    legacy_filename: Optional[str] = None,
) -> Dict[str, Any]:
    """Load attributes/labels from disk for reproducibility.

    Preference order:
    1) ``attributes.json`` in ``save_dir``
    2) ``legacy_filename`` (e.g., ``ratings_attrs.json``) in ``save_dir``
    When neither exists, ``incoming`` is written to both paths (when
    applicable) for future runs.
    """

    primary_path = os.path.join(save_dir, "attributes.json")
    legacy_path = os.path.join(save_dir, legacy_filename) if legacy_filename else None
    candidate_paths = [primary_path] + ([legacy_path] if legacy_path else [])

    if reset_files:
        for path in candidate_paths:
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                except Exception:
                    pass

    loaded: Optional[Dict[str, Any]] = None
    source_path: Optional[str] = None
    for path in candidate_paths:
        if not path or not os.path.exists(path):
            continue
        try:
            with open(path) as f:
                loaded = json.load(f)
            source_path = path
            break
        except Exception:
            continue

    if loaded is not None:
        message = (
            f"[{task_name}] Found saved {item_name} in {source_path}. Using them for consistency."
        )
        if loaded != incoming:
            message += (
                f" The provided {item_name} differ; set reset_files=True or use a new save_dir to update them."
            )
        print(message)
        return loaded

    for path in candidate_paths:
        if not path:
            continue
        try:
            with open(path, "w") as f:
                json.dump(incoming, f, indent=2)
        except Exception:
            pass

    return incoming
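Note: a hedged sketch of how a task might call this helper; the calling pattern is an assumption based on the signature above, not a call site copied from the package. On the first run the provided mapping is written to attributes.json (and the legacy path, if given) inside an existing save_dir; later runs return the saved copy instead.

# Illustrative call only.
from gabriel.tasks._attribute_utils import load_persisted_attributes

attrs = {"clarity": "How clearly the passage communicates its point."}
effective = load_persisted_attributes(
    save_dir="ratings_run",                  # the helper expects the directory to already exist
    incoming=attrs,
    reset_files=False,
    task_name="Rate",
    legacy_filename="ratings_attrs.json",    # optional legacy fallback file
)
# effective == attrs on the first run; on later runs it is whatever was persisted.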
gabriel/tasks/bucket.py
ADDED
@@ -0,0 +1,432 @@
from __future__ import annotations

import asyncio
import hashlib
import json
import os
import random
import math
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

import pandas as pd

from ..core.prompt_template import PromptTemplate, resolve_template
from ..utils.openai_utils import get_all_responses
from ..utils import safest_json
from ..utils.logging import announce_prompt_rendering


@dataclass
class BucketConfig:
    """Configuration for :class:`Bucket`."""

    bucket_count: int = 10
    save_dir: str = "buckets"
    file_name: str = "bucket_definitions.csv"
    model: str = "gpt-5-mini"
    n_parallels: int = 650
    use_dummy: bool = False
    max_timeout: Optional[float] = None
    additional_instructions: Optional[str] = None
    differentiate: bool = False
    n_terms_per_prompt: int = 250
    repeat_bucketing: int = 5
    repeat_voting: int = 25
    next_round_frac: float = 0.25
    top_k_per_round: int = 1
    raw_term_definitions: bool = True
    reasoning_effort: Optional[str] = None
    reasoning_summary: Optional[str] = None

    def __post_init__(self) -> None:
        if self.additional_instructions is not None:
            cleaned = str(self.additional_instructions).strip()
            self.additional_instructions = cleaned or None


class Bucket:
    """Group raw terms into a smaller set of mutually exclusive buckets."""

    def __init__(
        self,
        cfg: BucketConfig,
        template: Optional[PromptTemplate] = None,
        template_path: Optional[str] = None,
    ) -> None:
        expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
        expanded.mkdir(parents=True, exist_ok=True)
        cfg.save_dir = str(expanded)
        self.cfg = cfg
        self.template = resolve_template(
            template=template,
            template_path=template_path,
            reference_filename="bucket_prompt.jinja2",
        )

    # ------------------------------------------------------------------
    # Helpers for persisting intermediate progress
    # ------------------------------------------------------------------
    def _state_path(self) -> str:
        return os.path.join(self.cfg.save_dir, "bucket_state.json")

    def _read_state(self) -> Dict[str, Any]:
        path = self._state_path()
        try:
            with open(path, "r", encoding="utf-8") as fh:
                data = json.load(fh)
            if isinstance(data, dict):
                return data
        except FileNotFoundError:
            return {}
        except Exception:
            return {}
        return {}

    def _write_state(self, state: Dict[str, Any]) -> None:
        path = self._state_path()
        payload = dict(state)
        payload["updated_at"] = datetime.utcnow().isoformat() + "Z"
        try:
            with open(path, "w", encoding="utf-8") as fh:
                json.dump(payload, fh, ensure_ascii=False, indent=2)
        except Exception:
            pass

    def _terms_signature(self, terms: List[str], term_map: Dict[str, str]) -> str:
        if self.cfg.raw_term_definitions:
            entries = [f"{t}::{term_map.get(t, '')}" for t in sorted(terms)]
        else:
            entries = sorted(terms)
        joined = "||".join(entries)
        return hashlib.sha1(joined.encode("utf-8")).hexdigest()

    async def _parse(self, raw: Any) -> Dict[str, str]:
        obj = await safest_json(raw)
        if isinstance(obj, list) and obj:
            obj = obj[0]
        if isinstance(obj, dict):
            return {str(k): str(v) if v is not None else "" for k, v in obj.items()}
        return {}

    async def run(
        self,
        df: pd.DataFrame,
        column_name: str,
        *,
        reset_files: bool = False,
        **kwargs: Any,
    ) -> pd.DataFrame:
        cache_path = os.path.join(self.cfg.save_dir, self.cfg.file_name)
        state_path = self._state_path()
        if reset_files and os.path.exists(state_path):
            try:
                os.remove(state_path)
            except Exception:
                pass
        if not reset_files and os.path.exists(cache_path):
            try:
                cached = pd.read_csv(cache_path)
                if {"bucket", "definition"}.issubset(cached.columns):
                    cols = ["bucket", "definition"]
                    return cached[cols]
            except Exception:
                pass

        state: Dict[str, Any] = {} if reset_files else self._read_state()

        df_proc = df.reset_index(drop=True).copy()
        raw_entries = df_proc[column_name].dropna().tolist()

        seen: Set[str] = set()
        terms: List[str] = []
        term_map: Dict[str, str] = {}
        for entry in raw_entries:
            if isinstance(entry, dict):
                for k, v in entry.items():
                    key = str(k)
                    if key not in seen:
                        seen.add(key)
                        terms.append(key)
                    term_map.setdefault(key, str(v) if v is not None else "")
            elif isinstance(entry, list):
                for item in entry:
                    key = str(item)
                    if key not in seen:
                        seen.add(key)
                        terms.append(key)
                    term_map.setdefault(key, "")
            else:
                key = str(entry)
                if key not in seen:
                    seen.add(key)
                    terms.append(key)
                term_map.setdefault(key, "")

        if not terms:
            return pd.DataFrame(columns=["bucket", "definition"])

        signature = self._terms_signature(terms, term_map)
        if state.get("terms_signature") != signature:
            state = {"terms_signature": signature}
        else:
            state["terms_signature"] = signature

        def persist_state() -> None:
            self._write_state(state)

        if state.get("finalized") and state.get("final_buckets") is not None:
            records = state.get("final_buckets") or []
            final_df = pd.DataFrame(records)
            if not final_df.empty and not {"bucket", "definition"}.issubset(final_df.columns):
                final_df = final_df.rename(columns={0: "bucket", 1: "definition"})
            if not final_df.empty and not os.path.exists(cache_path):
                try:
                    final_df.to_csv(cache_path, index=False)
                except Exception:
                    pass
            if not final_df.empty:
                cols = ["bucket", "definition"]
                return final_df[cols]
            return pd.DataFrame(columns=["bucket", "definition"])

        # ── 1: generate bucket candidates ───────────────────────────────
        candidate_defs: Dict[str, str] = {}
        candidates: List[str] = []
        if state.get("candidate_defs") and isinstance(state["candidate_defs"], dict):
            cached_defs = {
                str(k): str(v) if v is not None else ""
                for k, v in state["candidate_defs"].items()
            }
            candidate_defs.update(cached_defs)
            candidates = [
                c
                for c in state.get("candidates", list(candidate_defs.keys()))
                if c in candidate_defs
            ]
        if not candidate_defs:
            prompts: List[str] = []
            ids: List[str] = []
            chunks_per_rep = max(
                1, math.ceil(len(terms) / self.cfg.n_terms_per_prompt)
            )
            announce_prompt_rendering(
                "Bucket:generate",
                chunks_per_rep * self.cfg.repeat_bucketing,
            )
            for rep in range(self.cfg.repeat_bucketing):
                random.shuffle(terms)
                chunks = [
                    terms[i : i + self.cfg.n_terms_per_prompt]
                    for i in range(0, len(terms), self.cfg.n_terms_per_prompt)
                ]
                for ci, chunk in enumerate(chunks):
                    chunk_data = (
                        {t: term_map.get(t, "") for t in chunk}
                        if self.cfg.raw_term_definitions
                        else chunk
                    )
                    prompts.append(
                        self.template.render(
                            terms=chunk_data,
                            bucket_count=self.cfg.bucket_count,
                            differentiate=self.cfg.differentiate,
                            additional_instructions=self.cfg.additional_instructions or "",
                            voting=False,
                        )
                    )
                    ids.append(f"gen|{rep}|{ci}")

            gen_df = await get_all_responses(
                prompts=prompts,
                identifiers=ids,
                n_parallels=self.cfg.n_parallels,
                model=self.cfg.model,
                save_path=os.path.join(self.cfg.save_dir, "bucket_generation.csv"),
                use_dummy=self.cfg.use_dummy,
                max_timeout=self.cfg.max_timeout,
                json_mode=True,
                reset_files=reset_files,
                reasoning_effort=self.cfg.reasoning_effort,
                reasoning_summary=self.cfg.reasoning_summary,
                **kwargs,
            )
            if not isinstance(gen_df, pd.DataFrame):
                raise RuntimeError("get_all_responses returned no DataFrame")

            resp_map = dict(zip(gen_df.Identifier, gen_df.Response))
            parsed = await asyncio.gather(*[self._parse(resp_map.get(i, "")) for i in ids])
            for res in parsed:
                for b, j in res.items():
                    candidate_defs.setdefault(b, j)

            candidates = list(candidate_defs.keys())
            state["candidate_defs"] = candidate_defs
            state["candidates"] = candidates
            state["stage"] = "candidates"
            persist_state()
        elif not candidates:
            candidates = list(candidate_defs.keys())

        # helper to build voting prompts
        def _vote_prompts(opts: List[str], selected: List[str], tag: str):
            pr: List[str] = []
            idn: List[str] = []
            chunks_per_rep = max(
                1, math.ceil(len(opts) / self.cfg.n_terms_per_prompt)
            )
            announce_prompt_rendering(
                f"Bucket:{tag}",
                chunks_per_rep * self.cfg.repeat_voting,
            )
            for rep in range(self.cfg.repeat_voting):
                random.shuffle(opts)
                chunks = [
                    opts[i : i + self.cfg.n_terms_per_prompt]
                    for i in range(0, len(opts), self.cfg.n_terms_per_prompt)
                ]
                for ci, ch in enumerate(chunks):
                    sample_list = random.sample(
                        terms, min(len(terms), self.cfg.n_terms_per_prompt)
                    )
                    sample_terms = (
                        {t: term_map.get(t, "") for t in sample_list}
                        if self.cfg.raw_term_definitions
                        else sample_list
                    )
                    selected_map = {
                        b: candidate_defs.get(b, "") for b in selected
                    }
                    pr.append(
                        self.template.render(
                            terms=sample_terms,
                            bucket_count=self.cfg.bucket_count,
                            differentiate=self.cfg.differentiate,
                            additional_instructions=self.cfg.additional_instructions
                            or "",
                            voting=True,
                            bucket_candidates=ch,
                            selected_buckets=selected_map if selected_map else None,
                        )
                    )
                    idn.append(f"vote|{tag}|{rep}|{ci}")
            return pr, idn

        # ── 2: iterative reduction ─────────────────────────────────────
        current = candidates[:]
        if state.get("current_candidates"):
            saved_current = [
                c for c in state["current_candidates"] if c in candidate_defs
            ]
            if saved_current:
                current = saved_current
        round_idx = int(state.get("reduce_round", 0))
        while len(current) >= 3 * self.cfg.bucket_count:
            round_idx += 1
            pr, idn = _vote_prompts(current, [], f"reduce{round_idx}")
            vote_df = await get_all_responses(
                prompts=pr,
                identifiers=idn,
                n_parallels=self.cfg.n_parallels,
                model=self.cfg.model,
                save_path=os.path.join(
                    self.cfg.save_dir, f"vote_reduce{round_idx}.csv"
                ),
                use_dummy=self.cfg.use_dummy,
                max_timeout=self.cfg.max_timeout,
                json_mode=True,
                reset_files=reset_files,
                reasoning_effort=self.cfg.reasoning_effort,
                reasoning_summary=self.cfg.reasoning_summary,
                **kwargs,
            )
            vote_map = dict(zip(vote_df.Identifier, vote_df.Response))
            parsed_votes = await asyncio.gather(
                *[self._parse(vote_map.get(i, "")) for i in idn]
            )
            tallies: Dict[str, int] = defaultdict(int)
            for res in parsed_votes:
                for b in res.keys():
                    tallies[b] += 1
            current.sort(
                key=lambda x: (tallies.get(x, 0), random.random()), reverse=True
            )
            keep = max(
                self.cfg.bucket_count, int(len(current) * self.cfg.next_round_frac)
            )
            current = current[:keep]
            state["current_candidates"] = current
            state["reduce_round"] = round_idx
            state["stage"] = "reduce"
            persist_state()

        # ── 3: final selection ─────────────────────────────────────────
        selected: List[str] = [
            c for c in state.get("selected", []) if c in candidate_defs
        ]
        remaining = [o for o in current if o not in selected]
        loop_idx = int(state.get("final_loop", 0))
        while len(selected) < self.cfg.bucket_count and remaining:
            loop_idx += 1
            pr, idn = _vote_prompts(
                [o for o in remaining if o not in selected],
                selected,
                f"final{loop_idx}",
            )
            vote_df = await get_all_responses(
                prompts=pr,
                identifiers=idn,
                n_parallels=self.cfg.n_parallels,
                model=self.cfg.model,
                save_path=os.path.join(
                    self.cfg.save_dir, f"vote_final{loop_idx}.csv"
                ),
                use_dummy=self.cfg.use_dummy,
                max_timeout=self.cfg.max_timeout,
                json_mode=True,
                reset_files=reset_files,
                reasoning_effort=self.cfg.reasoning_effort,
                reasoning_summary=self.cfg.reasoning_summary,
                **kwargs,
            )
            vote_map = dict(zip(vote_df.Identifier, vote_df.Response))
            parsed_votes = await asyncio.gather(
                *[self._parse(vote_map.get(i, "")) for i in idn]
            )
            tallies: Dict[str, int] = defaultdict(int)
            for res in parsed_votes:
                for b in res.keys():
                    if b not in selected:
                        tallies[b] += 1
            remaining = [o for o in remaining if o not in selected]
            remaining.sort(
                key=lambda x: (tallies.get(x, 0), random.random()), reverse=True
            )
            n_pick = min(
                self.cfg.top_k_per_round,
                self.cfg.bucket_count - len(selected),
                len(remaining),
            )
            winners = remaining[:n_pick]
            selected.extend(winners)
            state["selected"] = selected
            state["remaining_candidates"] = remaining
            state["final_loop"] = loop_idx
            state["stage"] = "finalizing"
            persist_state()

        bucket_defs = {b: candidate_defs.get(b, "") for b in selected}
        out_df = pd.DataFrame(
            {"bucket": list(bucket_defs.keys()), "definition": list(bucket_defs.values())}
        )
        out_df.to_csv(
            os.path.join(self.cfg.save_dir, self.cfg.file_name), index=False
        )
        state["final_buckets"] = out_df.to_dict(orient="records")
        state["finalized"] = True
        state["stage"] = "complete"
        persist_state()
        return out_df
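Note: Bucket.run is a coroutine, so it needs an event loop. A minimal end-to-end sketch under stated assumptions (the input column and config values are illustrative; use_dummy=True exercises the pipeline without real model calls, as provided for by the config above):

# Illustrative driver only; data and settings are made up.
import asyncio
import pandas as pd

from gabriel.tasks import Bucket, BucketConfig

df = pd.DataFrame({"term": ["inflation", "unemployment", "interest rates", "housing costs"]})
cfg = BucketConfig(
    bucket_count=2,
    save_dir="bucket_run",      # expanded and created by Bucket.__init__
    use_dummy=True,             # skip real API calls while testing the flow
    repeat_bucketing=1,
    repeat_voting=1,
)
buckets = asyncio.run(Bucket(cfg).run(df, "term"))
print(buckets)                  # DataFrame with "bucket" and "definition" columns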