openai-gabriel 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gabriel/__init__.py +61 -0
- gabriel/_version.py +1 -0
- gabriel/api.py +2284 -0
- gabriel/cli/__main__.py +60 -0
- gabriel/core/__init__.py +7 -0
- gabriel/core/llm_client.py +34 -0
- gabriel/core/pipeline.py +18 -0
- gabriel/core/prompt_template.py +152 -0
- gabriel/prompts/__init__.py +1 -0
- gabriel/prompts/bucket_prompt.jinja2 +113 -0
- gabriel/prompts/classification_prompt.jinja2 +50 -0
- gabriel/prompts/codify_prompt.jinja2 +95 -0
- gabriel/prompts/comparison_prompt.jinja2 +60 -0
- gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
- gabriel/prompts/deidentification_prompt.jinja2 +112 -0
- gabriel/prompts/extraction_prompt.jinja2 +61 -0
- gabriel/prompts/filter_prompt.jinja2 +31 -0
- gabriel/prompts/ideation_prompt.jinja2 +80 -0
- gabriel/prompts/merge_prompt.jinja2 +47 -0
- gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
- gabriel/prompts/rankings_prompt.jinja2 +49 -0
- gabriel/prompts/ratings_prompt.jinja2 +50 -0
- gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
- gabriel/prompts/seed.jinja2 +43 -0
- gabriel/prompts/snippets.jinja2 +117 -0
- gabriel/tasks/__init__.py +63 -0
- gabriel/tasks/_attribute_utils.py +69 -0
- gabriel/tasks/bucket.py +432 -0
- gabriel/tasks/classify.py +562 -0
- gabriel/tasks/codify.py +1033 -0
- gabriel/tasks/compare.py +235 -0
- gabriel/tasks/debias.py +1460 -0
- gabriel/tasks/deduplicate.py +341 -0
- gabriel/tasks/deidentify.py +316 -0
- gabriel/tasks/discover.py +524 -0
- gabriel/tasks/extract.py +455 -0
- gabriel/tasks/filter.py +169 -0
- gabriel/tasks/ideate.py +782 -0
- gabriel/tasks/merge.py +464 -0
- gabriel/tasks/paraphrase.py +531 -0
- gabriel/tasks/rank.py +2041 -0
- gabriel/tasks/rate.py +347 -0
- gabriel/tasks/seed.py +465 -0
- gabriel/tasks/whatever.py +344 -0
- gabriel/utils/__init__.py +64 -0
- gabriel/utils/audio_utils.py +42 -0
- gabriel/utils/file_utils.py +464 -0
- gabriel/utils/image_utils.py +22 -0
- gabriel/utils/jinja.py +31 -0
- gabriel/utils/logging.py +86 -0
- gabriel/utils/mapmaker.py +304 -0
- gabriel/utils/media_utils.py +78 -0
- gabriel/utils/modality_utils.py +148 -0
- gabriel/utils/openai_utils.py +5470 -0
- gabriel/utils/parsing.py +282 -0
- gabriel/utils/passage_viewer.py +2557 -0
- gabriel/utils/pdf_utils.py +20 -0
- gabriel/utils/plot_utils.py +2881 -0
- gabriel/utils/prompt_utils.py +42 -0
- gabriel/utils/word_matching.py +158 -0
- openai_gabriel-1.0.1.dist-info/METADATA +443 -0
- openai_gabriel-1.0.1.dist-info/RECORD +67 -0
- openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
- openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
- openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
- openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
- openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
gabriel/tasks/seed.py
ADDED
@@ -0,0 +1,465 @@
from __future__ import annotations

import asyncio
import math
import os
import random
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple

import pandas as pd

from gabriel.core.prompt_template import PromptTemplate, resolve_template
from gabriel.tasks.deduplicate import Deduplicate, DeduplicateConfig
from gabriel.utils import safest_json
from gabriel.utils.openai_utils import get_all_responses
from gabriel.utils.logging import announce_prompt_rendering


@dataclass
class SeedConfig:
    """Configuration options for :class:`Seed`."""

    instructions: str
    save_dir: str = os.path.expanduser("~/Documents/runs")
    file_name: str = "seed_entities.csv"
    model: str = "gpt-5.2"
    n_parallels: int = 650
    num_entities: int = 1000
    entities_per_generation: int = 50
    entity_batch_frac: float = 0.25
    existing_entities_cap: int = 100
    use_dummy: bool = False
    deduplicate: bool = False
    deduplicate_sample_seed: int = 42
    max_timeout: Optional[float] = None
    reasoning_effort: Optional[str] = None
    reasoning_summary: Optional[str] = None


class Seed:
    """Generate structured entity seeds via batched language-model calls."""

    def __init__(
        self,
        cfg: SeedConfig,
        template: Optional[PromptTemplate] = None,
        template_path: Optional[str] = None,
    ) -> None:
        expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
        expanded.mkdir(parents=True, exist_ok=True)
        cfg.save_dir = str(expanded)
        self.cfg = cfg
        if cfg.num_entities <= 0:
            raise ValueError("num_entities must be positive")
        if cfg.entities_per_generation <= 0:
            raise ValueError("entities_per_generation must be positive")
        if not 0 < cfg.entity_batch_frac <= 1:
            raise ValueError("entity_batch_frac must be between 0 and 1")
        if cfg.existing_entities_cap < 0:
            raise ValueError("existing_entities_cap must be non-negative")
        self.template = resolve_template(
            template=template,
            template_path=template_path,
            reference_filename="seed.jinja2",
        )

    async def run(
        self,
        *,
        existing_entities: Optional[Sequence[str]] = None,
        reset_files: bool = False,
        **response_kwargs: Any,
    ) -> pd.DataFrame:
        """Generate ``num_entities`` unique seed entities."""
        if self.cfg.deduplicate:
            return await self._run_with_deduplication(
                existing_entities=existing_entities,
                reset_files=reset_files,
                **response_kwargs,
            )
        return await self._run_standard(
            existing_entities=existing_entities,
            reset_files=reset_files,
            **response_kwargs,
        )

    async def _run_standard(
        self,
        *,
        existing_entities: Optional[Sequence[str]],
        reset_files: bool,
        **response_kwargs: Any,
    ) -> pd.DataFrame:
        normalized_existing = self._prepare_initial_existing(existing_entities)
        seen: Dict[str, str] = {}
        for ent in normalized_existing:
            norm = self._normalize_entity(ent)
            if norm and norm not in seen:
                seen[norm] = ent

        batch_target = max(
            self.cfg.entities_per_generation,
            math.ceil(self.cfg.num_entities * self.cfg.entity_batch_frac),
        )
        raw_save = os.path.join(self.cfg.save_dir, "seed_raw_responses.csv")
        batch_index = 0
        request_index = 0
        reset_next = reset_files
        while len(seen) < self.cfg.num_entities:
            remaining = self.cfg.num_entities - len(seen)
            batch_goal = min(batch_target, remaining)
            batch_goal = max(batch_goal, self.cfg.entities_per_generation)
            target_count = min(self.cfg.num_entities, len(seen) + batch_goal)
            batch_added = 0
            while len(seen) < target_count:
                remaining_in_batch = target_count - len(seen)
                current_goal = max(remaining_in_batch, self.cfg.entities_per_generation)
                prompts, identifiers = self._build_prompts(
                    current_goal,
                    request_index,
                    list(seen.values()),
                )
                if not prompts:
                    break
                print(
                    f"[Seed] Requesting {len(prompts)} prompts (batch {batch_index}, "
                    f"targeting {current_goal} entities; batch target {batch_goal})."
                )
                df_resp = await self._request_entities(
                    prompts,
                    identifiers,
                    raw_save=raw_save,
                    reset_files=reset_next,
                    **response_kwargs,
                )
                resp_lookup = dict(zip(df_resp.Identifier, df_resp.Response))
                parsed = await asyncio.gather(
                    *[
                        self._parse_entities(resp_lookup.get(identifier, ""))
                        for identifier in identifiers
                    ]
                )
                added = 0
                for entity_list in parsed:
                    for entity in entity_list:
                        norm = self._normalize_entity(entity)
                        if not norm or norm in seen:
                            continue
                        seen[norm] = entity
                        added += 1
                batch_added += added
                reset_next = False
                request_index += 1
                if added == 0 and not any(parsed):
                    break
            print(
                f"[Seed] Added {batch_added} new entities in batch {batch_index}. Total so far: {len(seen)}."
            )
            batch_index += 1
            if batch_added == 0:
                break

        ordered = [seen[norm] for norm in seen]
        trimmed = ordered[: self.cfg.num_entities]
        return self._finalize_entities(trimmed)

    async def _run_with_deduplication(
        self,
        *,
        existing_entities: Optional[Sequence[str]],
        reset_files: bool,
        **response_kwargs: Any,
    ) -> pd.DataFrame:
        normalized_existing = self._prepare_initial_existing(existing_entities)
        all_entities: List[str] = list(normalized_existing)
        seen_norm: Set[str] = set()
        for ent in all_entities:
            norm = self._normalize_entity(ent)
            if norm:
                seen_norm.add(norm)

        dedup_cycle = 0
        deduped = await self._deduplicate_entities(
            all_entities,
            cycle_index=dedup_cycle,
            reset_files=reset_files,
            response_kwargs=response_kwargs,
        )
        dedup_cycle += 1
        raw_target = self.cfg.num_entities * 4
        batch_target = max(
            self.cfg.entities_per_generation,
            math.ceil(raw_target * self.cfg.entity_batch_frac),
        )
        raw_save = os.path.join(self.cfg.save_dir, "seed_raw_responses.csv")
        batch_index = 0
        request_index = 0
        reset_next = reset_files
        while len(deduped) < self.cfg.num_entities:
            cycle_target = max(raw_target, len(all_entities))
            while len(all_entities) < cycle_target:
                remaining_in_cycle = cycle_target - len(all_entities)
                current_goal = min(batch_target, remaining_in_cycle)
                current_goal = max(current_goal, self.cfg.entities_per_generation)
                prompts, identifiers = self._build_prompts(
                    current_goal,
                    request_index,
                    deduped,
                )
                if not prompts:
                    break
                print(
                    f"[Seed] Requesting {len(prompts)} prompts (batch {batch_index}, "
                    f"targeting {current_goal} entities before deduplication)."
                )
                df_resp = await self._request_entities(
                    prompts,
                    identifiers,
                    raw_save=raw_save,
                    reset_files=reset_next,
                    **response_kwargs,
                )
                resp_lookup = dict(zip(df_resp.Identifier, df_resp.Response))
                parsed = await asyncio.gather(
                    *[
                        self._parse_entities(resp_lookup.get(identifier, ""))
                        for identifier in identifiers
                    ]
                )
                added = 0
                for entity_list in parsed:
                    for entity in entity_list:
                        norm = self._normalize_entity(entity)
                        if not norm or norm in seen_norm:
                            continue
                        all_entities.append(entity)
                        seen_norm.add(norm)
                        added += 1
                reset_next = False
                request_index += 1
                if added == 0 and not any(parsed):
                    break

            deduped = await self._deduplicate_entities(
                all_entities,
                cycle_index=dedup_cycle,
                reset_files=False,
                response_kwargs=response_kwargs,
            )
            dedup_cycle += 1
            print(
                f"[Seed] Unique after deduplication: {len(deduped)}."
            )
            if len(deduped) >= self.cfg.num_entities:
                break

            all_entities = list(deduped)
            seen_norm = {
                self._normalize_entity(entity)
                for entity in all_entities
                if self._normalize_entity(entity)
            }
            batch_index += 1
            if not all_entities:
                break

        trimmed = self._sample_to_target(deduped)
        return self._finalize_entities(trimmed)

    async def _request_entities(
        self,
        prompts: List[str],
        identifiers: List[str],
        *,
        raw_save: str,
        reset_files: bool,
        **response_kwargs: Any,
    ) -> pd.DataFrame:
        kwargs = dict(response_kwargs)
        kwargs.setdefault("model", self.cfg.model)
        kwargs.setdefault("n_parallels", self.cfg.n_parallels)
        kwargs.setdefault("use_dummy", self.cfg.use_dummy)
        kwargs.setdefault("max_timeout", self.cfg.max_timeout)
        kwargs.setdefault("reasoning_effort", self.cfg.reasoning_effort)
        kwargs.setdefault("reasoning_summary", self.cfg.reasoning_summary)
        kwargs.setdefault("json_mode", True)
        kwargs.setdefault("save_path", raw_save)
        kwargs.setdefault("reset_files", reset_files)
        df_resp = await get_all_responses(
            prompts=prompts,
            identifiers=identifiers,
            **kwargs,
        )
        if not isinstance(df_resp, pd.DataFrame):
            raise RuntimeError("get_all_responses returned no DataFrame")
        return df_resp

    def _finalize_entities(self, entities: List[str]) -> pd.DataFrame:
        df = pd.DataFrame(
            {
                "entity": entities,
                "entity_id": [f"entity-{idx:05d}" for idx in range(len(entities))],
            }
        )
        df["source_batch"] = df.index // max(self.cfg.entities_per_generation, 1)
        df["source_identifier"] = ["seed" for _ in range(len(entities))]
        final_path = os.path.join(self.cfg.save_dir, self.cfg.file_name)
        df.to_csv(final_path, index=False)
        print(
            f"[Seed] Generated {len(df)} entities. Saved aggregated seeds to {final_path}."
        )
        return df

    async def _parse_entities(self, raw: Any) -> List[str]:
        obj = await safest_json(raw)
        results: List[str] = []
        if isinstance(obj, dict):
            for key in sorted(obj.keys()):
                value = obj.get(key)
                if value is None:
                    continue
                text = str(value).strip()
                if text:
                    results.append(text)
        elif isinstance(obj, list):
            for item in obj:
                if item is None:
                    continue
                text = str(item).strip()
                if text:
                    results.append(text)
        elif isinstance(obj, str):
            text = obj.strip()
            if text:
                results.append(text)
        return results

    def _prepare_initial_existing(
        self, entries: Optional[Sequence[str]]
    ) -> List[str]:
        if not entries:
            return []
        unique: List[str] = []
        seen: Set[str] = set()
        for entry in entries:
            text = str(entry).strip()
            if not text:
                continue
            norm = self._normalize_entity(text)
            if norm and norm not in seen:
                seen.add(norm)
                unique.append(text)
        return unique

    def _build_prompts(
        self,
        goal: int,
        batch_index: int,
        seen_entities: Sequence[str],
    ) -> Tuple[List[str], List[str]]:
        prompts: List[str] = []
        identifiers: List[str] = []
        per_call = self.cfg.entities_per_generation
        prompt_count = math.ceil(goal / per_call)
        existing_sample = self._sample_existing(seen_entities)
        existing_blob = "\n".join(existing_sample) if existing_sample else None
        announce_prompt_rendering("Seed", prompt_count)
        for call_index in range(prompt_count):
            identifiers.append(f"seed|{batch_index}|{call_index}")
            prompts.append(
                self.template.render(
                    instructions=self.cfg.instructions,
                    entities_per_generation=per_call,
                    existing_entities=existing_blob,
                )
            )
        return prompts, identifiers

    def _sample_existing(self, seen_entities: Sequence[str]) -> List[str]:
        if not seen_entities:
            return []
        cap = max(0, self.cfg.existing_entities_cap)
        if cap == 0:
            return []
        if len(seen_entities) <= cap:
            return list(seen_entities)
        return random.sample(list(seen_entities), cap)

    async def _deduplicate_entities(
        self,
        entities: Sequence[str],
        *,
        cycle_index: int,
        reset_files: bool,
        response_kwargs: Dict[str, Any],
    ) -> List[str]:
        cleaned: List[str] = []
        for entity in entities:
            text = str(entity).strip()
            if text:
                cleaned.append(text)
        if not cleaned:
            return []

        df = pd.DataFrame({"entity": cleaned})
        dedup_cfg = DeduplicateConfig(
            save_dir=os.path.join(self.cfg.save_dir, "seed_deduplicate"),
            file_name=f"seed_deduplicate_cycle{cycle_index}.csv",
            model=self.cfg.model,
            n_parallels=self.cfg.n_parallels,
            n_runs=4,
            use_dummy=self.cfg.use_dummy,
            max_timeout=self.cfg.max_timeout,
            group_size=100,
        )
        dedup = Deduplicate(dedup_cfg)
        dedup_df = await dedup.run(
            df,
            column_name="entity",
            reset_files=reset_files,
            **self._filter_dedup_response_kwargs(response_kwargs),
        )
        mapped_col = (
            "mapped_entity_final"
            if "mapped_entity_final" in dedup_df.columns
            else "mapped_entity"
        )
        unique: List[str] = []
        seen: Set[str] = set()
        for val in dedup_df[mapped_col]:
            text = str(val).strip()
            if not text or text in seen:
                continue
            seen.add(text)
            unique.append(text)
        return unique

    @staticmethod
    def _filter_dedup_response_kwargs(response_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        blocked = {
            "model",
            "n_parallels",
            "save_path",
            "use_dummy",
            "max_timeout",
            "json_mode",
            "reset_files",
            "prompts",
            "identifiers",
        }
        return {key: value for key, value in response_kwargs.items() if key not in blocked}

    def _sample_to_target(self, entities: List[str]) -> List[str]:
        if len(entities) <= self.cfg.num_entities:
            return entities[: self.cfg.num_entities]
        rng = random.Random(self.cfg.deduplicate_sample_seed)
        selected = set(rng.sample(entities, self.cfg.num_entities))
        return [entity for entity in entities if entity in selected][: self.cfg.num_entities]

    @staticmethod
    def _normalize_entity(text: str) -> str:
        collapsed = re.sub(r"\s+", " ", text.strip()).lower()
        return collapsed