openai-gabriel 1.0.1__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gabriel/__init__.py +61 -0
- gabriel/_version.py +1 -0
- gabriel/api.py +2284 -0
- gabriel/cli/__main__.py +60 -0
- gabriel/core/__init__.py +7 -0
- gabriel/core/llm_client.py +34 -0
- gabriel/core/pipeline.py +18 -0
- gabriel/core/prompt_template.py +152 -0
- gabriel/prompts/__init__.py +1 -0
- gabriel/prompts/bucket_prompt.jinja2 +113 -0
- gabriel/prompts/classification_prompt.jinja2 +50 -0
- gabriel/prompts/codify_prompt.jinja2 +95 -0
- gabriel/prompts/comparison_prompt.jinja2 +60 -0
- gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
- gabriel/prompts/deidentification_prompt.jinja2 +112 -0
- gabriel/prompts/extraction_prompt.jinja2 +61 -0
- gabriel/prompts/filter_prompt.jinja2 +31 -0
- gabriel/prompts/ideation_prompt.jinja2 +80 -0
- gabriel/prompts/merge_prompt.jinja2 +47 -0
- gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
- gabriel/prompts/rankings_prompt.jinja2 +49 -0
- gabriel/prompts/ratings_prompt.jinja2 +50 -0
- gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
- gabriel/prompts/seed.jinja2 +43 -0
- gabriel/prompts/snippets.jinja2 +117 -0
- gabriel/tasks/__init__.py +63 -0
- gabriel/tasks/_attribute_utils.py +69 -0
- gabriel/tasks/bucket.py +432 -0
- gabriel/tasks/classify.py +562 -0
- gabriel/tasks/codify.py +1033 -0
- gabriel/tasks/compare.py +235 -0
- gabriel/tasks/debias.py +1460 -0
- gabriel/tasks/deduplicate.py +341 -0
- gabriel/tasks/deidentify.py +316 -0
- gabriel/tasks/discover.py +524 -0
- gabriel/tasks/extract.py +455 -0
- gabriel/tasks/filter.py +169 -0
- gabriel/tasks/ideate.py +782 -0
- gabriel/tasks/merge.py +464 -0
- gabriel/tasks/paraphrase.py +531 -0
- gabriel/tasks/rank.py +2041 -0
- gabriel/tasks/rate.py +347 -0
- gabriel/tasks/seed.py +465 -0
- gabriel/tasks/whatever.py +344 -0
- gabriel/utils/__init__.py +64 -0
- gabriel/utils/audio_utils.py +42 -0
- gabriel/utils/file_utils.py +464 -0
- gabriel/utils/image_utils.py +22 -0
- gabriel/utils/jinja.py +31 -0
- gabriel/utils/logging.py +86 -0
- gabriel/utils/mapmaker.py +304 -0
- gabriel/utils/media_utils.py +78 -0
- gabriel/utils/modality_utils.py +148 -0
- gabriel/utils/openai_utils.py +5470 -0
- gabriel/utils/parsing.py +282 -0
- gabriel/utils/passage_viewer.py +2557 -0
- gabriel/utils/pdf_utils.py +20 -0
- gabriel/utils/plot_utils.py +2881 -0
- gabriel/utils/prompt_utils.py +42 -0
- gabriel/utils/word_matching.py +158 -0
- openai_gabriel-1.0.1.dist-info/METADATA +443 -0
- openai_gabriel-1.0.1.dist-info/RECORD +67 -0
- openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
- openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
- openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
- openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
- openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
gabriel/tasks/merge.py
ADDED
@@ -0,0 +1,464 @@
from __future__ import annotations

import asyncio
import os
import re
import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import html
import unicodedata

import numpy as np
import pandas as pd
from scipy.cluster.vq import kmeans2

from ..core.prompt_template import PromptTemplate, resolve_template
from ..utils.openai_utils import get_all_responses
from ..utils import safest_json, safe_json, get_all_embeddings
from ..utils.logging import announce_prompt_rendering


@dataclass
class MergeConfig:
    """Configuration options for :class:`Merge`."""

    save_dir: str = "merge"
    file_name: str = "merge_responses.csv"
    model: str = "gpt-5-nano"
    n_parallels: int = 650
    n_runs: int = 1
    use_dummy: bool = False
    max_timeout: Optional[float] = None
    additional_instructions: Optional[str] = None
    use_embeddings: bool = True
    short_list_len: int = 16
    long_list_len: int = 256
    max_attempts: int = 4
    short_list_multiplier: float = 0.5
    auto_match_threshold: float = 0.75
    use_best_auto_match: bool = False
    candidate_scan_chunks: int = 5

    def __post_init__(self) -> None:
        if self.additional_instructions is not None:
            cleaned = str(self.additional_instructions).strip()
            self.additional_instructions = cleaned or None


class Merge:
    """Fuzzy merge between two DataFrames using LLM assistance."""

    def __init__(
        self,
        cfg: MergeConfig,
        template: Optional[PromptTemplate] = None,
        template_path: Optional[str] = None,
    ) -> None:
        expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
        expanded.mkdir(parents=True, exist_ok=True)
        cfg.save_dir = str(expanded)
        self.cfg = cfg
        self.template = resolve_template(
            template=template,
            template_path=template_path,
            reference_filename="merge_prompt.jinja2",
        )

    # ------------------------------------------------------------------
    @staticmethod
    def _normalize(val: str) -> str:
        """Normalize strings for fuzzy matching."""
        # Convert HTML entities and strip accents before keeping alphanumeric
        txt = html.unescape(val).lower()
        txt = unicodedata.normalize("NFKD", txt)
        return "".join(ch for ch in txt if ch.isalnum())

    @classmethod
    def _deduplicate(
        cls, series: pd.Series
    ) -> Tuple[List[str], Dict[str, List[str]], Dict[str, str]]:
        """Return (unique_values, rep_to_group, norm_to_rep) for a Series."""
        norm_map: Dict[str, str] = {}
        groups: Dict[str, List[str]] = {}
        for val in series.dropna().astype(str):
            norm = cls._normalize(val)
            if norm in norm_map:
                rep = norm_map[norm]
                groups[rep].append(val)
            else:
                norm_map[norm] = val
                groups[val] = [val]
        uniques = list(groups.keys())
        return uniques, groups, norm_map

    # ------------------------------------------------------------------
    async def run(
        self,
        df_left: pd.DataFrame,
        df_right: pd.DataFrame,
        *,
        on: Optional[str] = None,
        left_on: Optional[str] = None,
        right_on: Optional[str] = None,
        how: str = "left",
        reset_files: bool = False,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Perform an LLM-assisted merge between two DataFrames."""

        if on:
            left_key = right_key = on
        elif left_on and right_on:
            left_key, right_key = left_on, right_on
        else:
            raise ValueError("Specify `on` or both `left_on` and `right_on`.")

        how = how.lower()
        if how not in {"left", "right"}:
            raise ValueError("`how` must be either 'left' or 'right'.")

        if how == "left":
            short_df, long_df = df_left.reset_index(drop=True), df_right.reset_index(drop=True)
            short_key, long_key = left_key, right_key
        else:  # right merge
            short_df, long_df = df_right.reset_index(drop=True), df_left.reset_index(drop=True)
            short_key, long_key = right_key, left_key

        # Deduplicate keys and track normalized maps
        short_uniques, short_groups, short_norm_map = self._deduplicate(short_df[short_key])
        long_uniques, long_groups, long_norm_map = self._deduplicate(long_df[long_key])

        # Build a global norm→representative map for the left-hand keys.
        global_short_norm_map = {self._normalize(s): s for s in short_uniques}

        use_embeddings = self.cfg.use_embeddings and len(long_uniques) >= self.cfg.long_list_len

        if reset_files:
            for p in Path(self.cfg.save_dir).glob("merge_groups_attempt*.json"):
                try:
                    p.unlink()
                except OSError:
                    pass

        short_emb: Dict[str, List[float]] = {}
        long_emb: Dict[str, List[float]] = {}
        if use_embeddings:
            short_emb = await get_all_embeddings(
                texts=short_uniques,
                identifiers=short_uniques,
                save_path=os.path.join(self.cfg.save_dir, "short_embeddings.pkl"),
                reset_file=reset_files,
                use_dummy=self.cfg.use_dummy,
                verbose=False,
            )
            long_emb = await get_all_embeddings(
                texts=long_uniques,
                identifiers=long_uniques,
                save_path=os.path.join(self.cfg.save_dir, "long_embeddings.pkl"),
                reset_file=reset_files,
                use_dummy=self.cfg.use_dummy,
                verbose=False,
            )

        matches: Dict[str, str] = {}
        remaining = short_uniques[:]
        if use_embeddings and self.cfg.auto_match_threshold > 0:
            short_matrix = np.array([short_emb[s] for s in remaining], dtype=float)
            long_matrix = np.array([long_emb[t] for t in long_uniques], dtype=float)
            short_norms = np.linalg.norm(short_matrix, axis=1) + 1e-8
            long_norms = np.linalg.norm(long_matrix, axis=1) + 1e-8
            sims = (short_matrix @ long_matrix.T) / (short_norms[:, None] * long_norms[None, :])
            for i, s in enumerate(remaining):
                row = sims[i]
                above = np.where(row >= self.cfg.auto_match_threshold)[0]
                if len(above) == 1:
                    matches[s] = long_uniques[above[0]]
                elif len(above) > 1 and self.cfg.use_best_auto_match:
                    best_idx = above[np.argmax(row[above])]
                    matches[s] = long_uniques[best_idx]
            remaining = [s for s in remaining if s not in matches]

        def _build_groups(
            remaining_short: List[str], short_len: int, extra_scans: int = 1
        ) -> Tuple[List[List[str]], List[List[str]]]:
            clusters_out: List[List[str]] = []
            candidates: List[List[str]] = []
            if use_embeddings and short_emb:
                arr = np.array([short_emb[s] for s in remaining_short], dtype=float)
                k = max(1, int(np.ceil(len(remaining_short) / short_len)))
                _, labels = kmeans2(arr, k, minit="points")
                cluster_sets: List[List[str]] = []
                for cluster_id in range(k):
                    members = [remaining_short[i] for i, lbl in enumerate(labels) if lbl == cluster_id]
                    if not members:
                        continue
                    for j in range(0, len(members), short_len):
                        subset = members[j : j + short_len]
                        cluster_sets.append(subset)

                long_matrix = np.array([long_emb[t] for t in long_uniques], dtype=float)
                long_norms = np.linalg.norm(long_matrix, axis=1) + 1e-8
                for subset in cluster_sets:
                    short_vecs = [np.array(short_emb[s], dtype=float) for s in subset]
                    short_norms = [np.linalg.norm(vec) + 1e-8 for vec in short_vecs]
                    orders: List[np.ndarray] = []
                    sims_list: List[np.ndarray] = []
                    for vec, norm in zip(short_vecs, short_norms):
                        sims = long_matrix @ vec / (long_norms * norm)
                        order = np.argsort(sims)[::-1]
                        orders.append(order)
                        sims_list.append(sims)

                    per_term = max(1, math.ceil(self.cfg.long_list_len / len(subset)))
                    for scan in range(extra_scans):
                        combined: Dict[int, float] = {}
                        start = scan * per_term
                        end = start + per_term
                        for order, sims in zip(orders, sims_list):
                            if start >= len(order):
                                continue
                            idx_slice = order[start:end]
                            for idx in idx_slice:
                                score = float(sims[idx])
                                if idx not in combined or score > combined[idx]:
                                    combined[idx] = score
                        if not combined:
                            continue
                        sorted_idx = sorted(combined.keys(), key=lambda i: combined[i], reverse=True)
                        candidate = [long_uniques[i] for i in sorted_idx[: self.cfg.long_list_len]]
                        candidates.append(candidate)
                        clusters_out.append(subset)
            else:
                short_sorted = sorted(remaining_short, key=lambda x: x.lower())
                long_sorted = sorted(long_uniques, key=lambda x: x.lower())
                if len(long_sorted) <= self.cfg.long_list_len:
                    base_candidate = list(long_sorted)
                    for i in range(0, len(short_sorted), short_len):
                        subset = short_sorted[i : i + short_len]
                        for _ in range(extra_scans):
                            clusters_out.append(subset)
                            candidates.append(base_candidate)
                else:
                    import bisect

                    lower_long = [s.lower() for s in long_sorted]
                    for i in range(0, len(short_sorted), short_len):
                        subset = short_sorted[i : i + short_len]
                        mid = subset[len(subset) // 2].lower()
                        idx = bisect.bisect_left(lower_long, mid)
                        start = max(0, idx - self.cfg.long_list_len // 2)
                        for scan in range(extra_scans):
                            scan_start = start + scan * self.cfg.long_list_len
                            scan_end = scan_start + self.cfg.long_list_len
                            if scan_end > len(long_sorted):
                                scan_end = len(long_sorted)
                                scan_start = max(0, scan_end - self.cfg.long_list_len)
                            clusters_out.append(subset)
                            candidates.append(list(long_sorted[scan_start:scan_end]))
            return clusters_out, candidates

        def _parse_response(res: Any) -> Dict[str, str]:
            """Normalize raw model output into a dictionary."""
            if isinstance(res, list):
                combined: Dict[str, str] = {}
                for item in res:
                    if isinstance(item, dict):
                        for k, v in item.items():
                            if isinstance(k, str) and isinstance(v, str):
                                combined[k] = v
                    elif isinstance(item, str):
                        inner = safe_json(item)
                        if isinstance(inner, dict):
                            for k, v in inner.items():
                                if isinstance(k, str) and isinstance(v, str):
                                    combined[k] = v
                res = combined
            elif isinstance(res, str):
                res = safe_json(res)

            if isinstance(res, dict):
                return {k: v for k, v in res.items() if isinstance(k, str) and isinstance(v, str)}
            return {}

        save_path = os.path.join(self.cfg.save_dir, self.cfg.file_name)
        progress_path = os.path.join(self.cfg.save_dir, "merge_progress.csv")
        if reset_files and os.path.exists(progress_path):
            try:
                os.remove(progress_path)
            except OSError:
                pass
        for attempt in range(self.cfg.max_attempts):
            if not remaining:
                break
            prev_total = len(matches)
            cur_short_len = max(1, int(self.cfg.short_list_len * (self.cfg.short_list_multiplier ** attempt)))
            group_path = os.path.join(self.cfg.save_dir, f"merge_groups_attempt{attempt}.json")
            if os.path.exists(group_path) and not reset_files:
                with open(group_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                clusters = data.get("clusters", [])
                candidates = data.get("candidates", [])
            else:
                extra = self.cfg.candidate_scan_chunks if attempt >= 1 else 1
                clusters, candidates = _build_groups(remaining, cur_short_len, extra)
                with open(group_path, "w", encoding="utf-8") as f:
                    json.dump({"clusters": clusters, "candidates": candidates}, f)

            prompts: List[str] = []
            identifiers: List[str] = []
            base_ids: List[str] = []
            announce_prompt_rendering("Merge", len(clusters) * max(1, self.cfg.n_runs))
            for idx, (short_terms, long_terms) in enumerate(zip(clusters, candidates)):
                short_dict = {s: "" for s in short_terms}
                prompt = self.template.render(
                    short_list=short_dict,
                    long_list=list(long_terms),
                    additional_instructions=self.cfg.additional_instructions or "",
                )
                base_id = f"merge_{attempt:02d}_{idx:05d}"
                base_ids.append(base_id)
                if self.cfg.n_runs > 1:
                    for run in range(self.cfg.n_runs):
                        prompts.append(prompt)
                        identifiers.append(f"{base_id}_run{run}")
                else:
                    prompts.append(prompt)
                    identifiers.append(base_id)

            if prompts:
                resp_df = await get_all_responses(
                    prompts=prompts,
                    identifiers=identifiers,
                    n_parallels=self.cfg.n_parallels,
                    model=self.cfg.model,
                    save_path=save_path,
                    use_dummy=self.cfg.use_dummy,
                    max_timeout=self.cfg.max_timeout,
                    json_mode=True,
                    reset_files=reset_files if attempt == 0 else False,
                    **kwargs,
                )
            else:
                resp_df = pd.DataFrame(columns=["Identifier", "Response"])

            resp_map = dict(zip(resp_df.get("Identifier", []), resp_df.get("Response", [])))
            parsed = await asyncio.gather(
                *[safest_json(resp_map.get(i, "")) for i in identifiers]
            )

            responses_by_base: Dict[str, List[Dict[str, str]]] = {bid: [] for bid in base_ids}
            for ident, res in zip(identifiers, parsed):
                base_id = ident.rsplit("_run", 1)[0] if self.cfg.n_runs > 1 else ident
                responses_by_base.setdefault(base_id, []).append(_parse_response(res))

            for clus, base_id in zip(clusters, base_ids):
                results = responses_by_base.get(base_id, [])
                normalized_results = [
                    {
                        self._normalize(k): v
                        for k, v in res.items()
                        if isinstance(k, str) and isinstance(v, str)
                    }
                    for res in results
                ]
                for s in clus:
                    counts: Dict[str, int] = {}
                    s_norm = self._normalize(s)
                    for res_map in normalized_results:
                        val = res_map.get(s_norm)
                        if val and self._normalize(val) != "nocertainmatch":
                            counts[val] = counts.get(val, 0) + 1
                    if not counts:
                        continue
                    max_count = max(counts.values())
                    top_candidates = [v for v, c in counts.items() if c == max_count]
                    chosen: Optional[str] = None
                    if len(top_candidates) == 1:
                        chosen = top_candidates[0]
                    elif use_embeddings and short_emb and long_emb:
                        s_vec = np.array(short_emb.get(s, []), dtype=float)
                        if s_vec.size:
                            sims: Dict[str, float] = {}
                            s_norm_val = np.linalg.norm(s_vec) + 1e-8
                            for cand in top_candidates:
                                l_vec = np.array(long_emb.get(cand, []), dtype=float)
                                if l_vec.size:
                                    sims[cand] = float(
                                        s_vec @ l_vec / (s_norm_val * (np.linalg.norm(l_vec) + 1e-8))
                                    )
                            if sims:
                                chosen = max(sims, key=sims.get)
                    if chosen:
                        k_norm = s_norm
                        v_norm = self._normalize(chosen)
                        if k_norm in global_short_norm_map and v_norm in long_norm_map:
                            short_rep = global_short_norm_map[k_norm]
                            long_rep = long_norm_map[v_norm]
                            matches[short_rep] = long_rep

            remaining = [s for s in remaining if s not in matches]
            round_matches = len(matches) - prev_total
            total_matches = len(matches)
            missing = len(remaining)
            print(
                f"[Merge] Attempt {attempt}: {round_matches} matches this round, "
                f"{total_matches} total, {missing} remaining"
            )
            progress_df = pd.DataFrame(
                [
                    {
                        "attempt": attempt,
                        "matches_this_round": round_matches,
                        "total_matches": total_matches,
                        "remaining": missing,
                    }
                ]
            )
            progress_df.to_csv(
                progress_path,
                mode="a",
                header=not os.path.exists(progress_path),
                index=False,
            )

        records: List[Dict[str, str]] = []
        if short_key == long_key:
            temp_col = f"{long_key}_match"
            for short_rep, long_rep in matches.items():
                for s in short_groups.get(short_rep, []):
                    for l in long_groups.get(long_rep, []):
                        records.append({short_key: s, temp_col: l})
            map_df = pd.DataFrame(records, columns=[short_key, temp_col])
            map_df[short_key] = map_df[short_key].astype(object)
            map_df[temp_col] = map_df[temp_col].astype(object)
            merged = short_df.merge(map_df, how="left", on=short_key)
            merged = merged.merge(
                long_df,
                how="left",
                left_on=temp_col,
                right_on=long_key,
                suffixes=("", "_y"),
            )
            merged = merged.drop(columns=[temp_col])
        else:
            for short_rep, long_rep in matches.items():
                for s in short_groups.get(short_rep, []):
                    for l in long_groups.get(long_rep, []):
                        records.append({short_key: s, long_key: l})
            map_df = pd.DataFrame(records, columns=[short_key, long_key])
            map_df[short_key] = map_df[short_key].astype(object)
            map_df[long_key] = map_df[long_key].astype(object)
            merged = short_df.merge(map_df, how="left", on=short_key)
            merged = merged.merge(
                long_df,
                how="left",
                left_on=long_key,
                right_on=long_key,
                suffixes=("", "_y"),
            )
        merged = merged.drop_duplicates(subset=[short_key])
        return merged