openai-gabriel 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gabriel/__init__.py +61 -0
- gabriel/_version.py +1 -0
- gabriel/api.py +2284 -0
- gabriel/cli/__main__.py +60 -0
- gabriel/core/__init__.py +7 -0
- gabriel/core/llm_client.py +34 -0
- gabriel/core/pipeline.py +18 -0
- gabriel/core/prompt_template.py +152 -0
- gabriel/prompts/__init__.py +1 -0
- gabriel/prompts/bucket_prompt.jinja2 +113 -0
- gabriel/prompts/classification_prompt.jinja2 +50 -0
- gabriel/prompts/codify_prompt.jinja2 +95 -0
- gabriel/prompts/comparison_prompt.jinja2 +60 -0
- gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
- gabriel/prompts/deidentification_prompt.jinja2 +112 -0
- gabriel/prompts/extraction_prompt.jinja2 +61 -0
- gabriel/prompts/filter_prompt.jinja2 +31 -0
- gabriel/prompts/ideation_prompt.jinja2 +80 -0
- gabriel/prompts/merge_prompt.jinja2 +47 -0
- gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
- gabriel/prompts/rankings_prompt.jinja2 +49 -0
- gabriel/prompts/ratings_prompt.jinja2 +50 -0
- gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
- gabriel/prompts/seed.jinja2 +43 -0
- gabriel/prompts/snippets.jinja2 +117 -0
- gabriel/tasks/__init__.py +63 -0
- gabriel/tasks/_attribute_utils.py +69 -0
- gabriel/tasks/bucket.py +432 -0
- gabriel/tasks/classify.py +562 -0
- gabriel/tasks/codify.py +1033 -0
- gabriel/tasks/compare.py +235 -0
- gabriel/tasks/debias.py +1460 -0
- gabriel/tasks/deduplicate.py +341 -0
- gabriel/tasks/deidentify.py +316 -0
- gabriel/tasks/discover.py +524 -0
- gabriel/tasks/extract.py +455 -0
- gabriel/tasks/filter.py +169 -0
- gabriel/tasks/ideate.py +782 -0
- gabriel/tasks/merge.py +464 -0
- gabriel/tasks/paraphrase.py +531 -0
- gabriel/tasks/rank.py +2041 -0
- gabriel/tasks/rate.py +347 -0
- gabriel/tasks/seed.py +465 -0
- gabriel/tasks/whatever.py +344 -0
- gabriel/utils/__init__.py +64 -0
- gabriel/utils/audio_utils.py +42 -0
- gabriel/utils/file_utils.py +464 -0
- gabriel/utils/image_utils.py +22 -0
- gabriel/utils/jinja.py +31 -0
- gabriel/utils/logging.py +86 -0
- gabriel/utils/mapmaker.py +304 -0
- gabriel/utils/media_utils.py +78 -0
- gabriel/utils/modality_utils.py +148 -0
- gabriel/utils/openai_utils.py +5470 -0
- gabriel/utils/parsing.py +282 -0
- gabriel/utils/passage_viewer.py +2557 -0
- gabriel/utils/pdf_utils.py +20 -0
- gabriel/utils/plot_utils.py +2881 -0
- gabriel/utils/prompt_utils.py +42 -0
- gabriel/utils/word_matching.py +158 -0
- openai_gabriel-1.0.1.dist-info/METADATA +443 -0
- openai_gabriel-1.0.1.dist-info/RECORD +67 -0
- openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
- openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
- openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
- openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
- openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
gabriel/tasks/codify.py
ADDED
|
@@ -0,0 +1,1033 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import re
|
|
5
|
+
import warnings
|
|
6
|
+
from collections import defaultdict
|
|
7
|
+
from dataclasses import dataclass, InitVar
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
|
10
|
+
|
|
11
|
+
import pandas as pd
|
|
12
|
+
|
|
13
|
+
from ..core.prompt_template import PromptTemplate, resolve_template
|
|
14
|
+
from ..tasks.classify import Classify, ClassifyConfig
|
|
15
|
+
from ..utils import (
|
|
16
|
+
get_all_responses,
|
|
17
|
+
letters_only,
|
|
18
|
+
load_audio_inputs,
|
|
19
|
+
load_image_inputs,
|
|
20
|
+
load_pdf_inputs,
|
|
21
|
+
normalize_text_aggressive,
|
|
22
|
+
normalize_text_generous,
|
|
23
|
+
normalize_whitespace,
|
|
24
|
+
robust_find_improved,
|
|
25
|
+
safe_json,
|
|
26
|
+
strict_find,
|
|
27
|
+
warn_if_modality_mismatch,
|
|
28
|
+
)
|
|
29
|
+
from ..utils.logging import announce_prompt_rendering
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class CodifyConfig:
    """Configuration for :class:`Codify`.

    ``save_dir`` is the only required field; everything else has a sensible
    default.  ``completion_max_rounds`` is accepted only as a deprecated
    init-time alias for ``n_rounds`` and is not stored on the instance.
    """

    save_dir: str
    file_name: str = "coding_results.csv"
    model: str = "gpt-5-mini"
    n_parallels: int = 650
    max_words_per_call: int = 1000
    max_categories_per_call: int = 8
    debug_print: bool = False
    use_dummy: bool = False
    reasoning_effort: Optional[str] = None
    reasoning_summary: Optional[str] = None
    modality: str = "text"
    json_mode: bool = True
    max_timeout: Optional[float] = None
    # Total Codify passes including the initial run; 1 disables completion sweeps.
    n_rounds: int = 2
    completion_classifier_instructions: Optional[str] = None
    # Deprecated alias for n_rounds; consumed by __post_init__, never stored.
    completion_max_rounds: InitVar[Optional[int]] = None

    def __post_init__(self, completion_max_rounds: Optional[int]) -> None:
        # Honour the deprecated alias first so it overrides n_rounds, then
        # normalise n_rounds to an int >= 1 (anything unparseable becomes 1).
        if completion_max_rounds is not None:
            warnings.warn(
                "completion_max_rounds is deprecated; use n_rounds instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            self.n_rounds = completion_max_rounds

        try:
            normalized = int(self.n_rounds)
        except (TypeError, ValueError):
            normalized = 1
        self.n_rounds = max(normalized, 1)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass
class ChunkResult:
    """Container holding the parsed response for a single chunk.

    Produced once per model call; ``consolidate_snippets`` consumes these to
    recover verbatim snippets from the original document.
    """

    # Unique id of the request that produced this chunk's response.
    identifier: str
    # The exact chunk text that was sent to the model (searched first when
    # locating excerpts, before falling back to the full document).
    chunk_text: str
    # Parsed JSON payload: expected to map category name -> list of
    # {"beginning excerpt", "ending excerpt"} dicts (see consolidate_snippets).
    data: Dict[str, Any]
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@dataclass
class PromptRequest:
    """Metadata for a prompt dispatched to the model."""

    # Unique id for this prompt (used to join responses back to rows).
    identifier: str
    # Fully rendered prompt text sent to the model.
    prompt: str
    # Index of the source DataFrame row this prompt belongs to.
    row_index: int
    # The chunk of the row's text covered by this prompt.
    chunk_text: str
+
|
|
89
|
+
|
|
90
|
+
class Codify:
|
|
91
|
+
"""Pipeline for coding passages of text according to specified categories."""
|
|
92
|
+
|
|
93
|
+
def __init__(
    self,
    cfg: CodifyConfig,
    template: Optional[PromptTemplate] = None,
    template_path: Optional[str] = None,
) -> None:
    """Create a new :class:`Codify` instance.

    Parameters
    ----------
    cfg:
        Runtime configuration.  ``cfg.save_dir`` is expanded (``~`` and
        environment variables), created on disk if missing, and mutated
        in place to hold the resulting absolute path.
    template:
        Optional preconstructed :class:`PromptTemplate`.
    template_path:
        Path to a custom Jinja2 template on disk. The template is
        validated to ensure it exposes the same variables as the
        built-in ``codify_prompt.jinja2`` template.
    """
    # Expand ~ and $VARS, then make sure the output directory exists.
    expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
    expanded.mkdir(parents=True, exist_ok=True)
    cfg.save_dir = str(expanded)
    self.cfg = cfg
    # Track excerpt-matching hit rates across all texts; populated by
    # consolidate_snippets and reported by print_final_hit_rates.
    # (The original code initialised this twice — once before and once
    # after resolve_template(); the redundant second assignment is removed.)
    self.hit_rate_stats: Dict[str, Dict[str, Any]] = {}
    self.template = resolve_template(
        template=template,
        template_path=template_path,
        reference_filename="codify_prompt.jinja2",
    )
|
|
122
|
+
|
|
123
|
+
@staticmethod
def view(
    df: pd.DataFrame,
    column_name: str,
    attributes: Optional[Union[List[str], str]] = None,
    **viewer_kwargs: Any,
):
    """Convenience wrapper around :func:`gabriel.view`.

    Visualise coding results produced by :class:`Codify`.  When
    ``attributes`` is empty or ``None``, the default ``"coded_passages"``
    column is shown.  All extra keyword arguments are passed straight
    through to :func:`gabriel.view` (Colab viewer, attribute chips,
    custom metadata headers, ...).
    """
    # Imported lazily to avoid a circular import at module load time.
    from ..utils import view as view_results

    chosen_attributes = "coded_passages" if not attributes else attributes
    return view_results(
        df,
        column_name,
        attributes=chosen_attributes,
        **viewer_kwargs,
    )
|
|
148
|
+
|
|
149
|
+
def parse_json(self, response_text: Any) -> Optional[dict]:
    """Robust JSON parsing using :func:`safe_json`.

    Returns a dict when one can be recovered, otherwise ``None``.  A
    non-empty list result is treated as a wrapper: its first element is
    re-parsed and returned if it yields a dict.
    """
    candidate = safe_json(response_text)
    if isinstance(candidate, dict):
        return candidate
    # Some responses arrive as a one-element list wrapping the payload.
    if isinstance(candidate, list) and candidate:
        unwrapped = safe_json(candidate[0])
        if isinstance(unwrapped, dict):
            return unwrapped
    return None
|
|
160
|
+
|
|
161
|
+
def chunk_by_words(self, text: str, max_words: int) -> List[str]:
    """Split *text* into chunks of at most *max_words* whitespace-separated words.

    Texts that already fit in one chunk are returned unmodified (original
    whitespace preserved); longer texts are re-joined with single spaces.
    """
    tokens = text.split()
    if len(tokens) <= max_words:
        return [text]
    return [
        " ".join(tokens[start : start + max_words])
        for start in range(0, len(tokens), max_words)
    ]
|
|
167
|
+
|
|
168
|
+
def find_snippet_in_text(self, text: str, beginning_excerpt: str, ending_excerpt: str) -> Optional[str]:
    """Fast snippet finding that returns actual text from the original document.

    Given model-reported ``beginning_excerpt``/``ending_excerpt`` strings,
    locate the corresponding span in ``text`` and return it verbatim
    (stripped), or ``None`` when nothing usable can be located.  Matching
    delegates to robust_find_improved for a yes/no signal and to
    _find_actual_position_with_type for character positions; fallback match
    types widen the span via _expand_fallback_match.
    """
    if not beginning_excerpt:
        return None

    # Handle short excerpts (no ending)
    if not ending_excerpt:
        match = robust_find_improved(text, beginning_excerpt)
        if match:
            # Find the actual position in the original text
            start_pos, end_pos, match_type = self._find_actual_position_with_type(text, beginning_excerpt)
            if start_pos is not None:
                # Always ensure word boundaries and add some context
                start_pos = self._find_word_start(text, start_pos)
                end_pos = self._find_word_end(text, end_pos)

                # If using fallback matching, expand to include more context
                if match_type in ['first_20', 'last_20', 'first_half', 'second_half']:
                    start_pos, end_pos = self._expand_fallback_match(text, start_pos, end_pos, beginning_excerpt)
                else:
                    # Add minimal context for exact matches
                    words_after = self._get_n_words_after(text, end_pos, 5)
                    end_pos = min(len(text), end_pos + len(words_after))

                return text[start_pos:end_pos].strip()
        return None

    # Handle longer snippets with both beginning and ending
    begin_match = robust_find_improved(text, beginning_excerpt)
    end_match = robust_find_improved(text, ending_excerpt)

    if not begin_match and not end_match:
        return None
    elif begin_match and not end_match:
        # Beginning found but ending not found - include beginning + 20 words after
        begin_start, begin_end, _ = self._find_actual_position_with_type(text, beginning_excerpt)
        if begin_start is not None:
            # Find word boundary at beginning end
            word_end = self._find_word_end(text, begin_end)
            words_after = self._get_n_words_after(text, word_end, 20)

            # Calculate end position for the 20 words after
            # NOTE: len(words_after) counts characters of the joined words,
            # which approximates (not exactly equals) 20 words of the text.
            after_end = min(len(text), word_end + len(words_after))
            result = text[begin_start:after_end].strip()
            return result if result else None
        return None
    elif not begin_match and end_match:
        # Ending found but beginning not found - include 20 words before + ending
        end_start, end_end, _ = self._find_actual_position_with_type(text, ending_excerpt)
        if end_start is not None:
            # Find word boundary at ending start
            word_start = self._find_word_start(text, end_start)
            words_before = self._get_n_words_before(text, word_start, 20)

            # Calculate start position for the 20 words before
            before_start = max(0, word_start - len(words_before))
            result = text[before_start:end_end].strip()
            return result if result else None
        return None
    else:
        # Both beginning and ending found - extract the actual snippet from original text
        begin_start, begin_end, begin_type = self._find_actual_position_with_type(text, beginning_excerpt)
        end_start, end_end, end_type = self._find_actual_position_with_type(text, ending_excerpt)

        if begin_start is not None and end_start is not None:
            # Expand fallback matches
            if begin_type in ['first_20', 'last_20', 'first_half', 'second_half']:
                begin_start, begin_end = self._expand_fallback_match(text, begin_start, begin_end, beginning_excerpt)
            if end_type in ['first_20', 'last_20', 'first_half', 'second_half']:
                end_start, end_end = self._expand_fallback_match(text, end_start, end_end, ending_excerpt)

            # Make sure ending comes after beginning
            if end_start >= begin_start:
                return text[begin_start:end_end].strip()
            else:
                # Ending comes before beginning, just return beginning snippet
                # (recursive call with empty ending takes the short-excerpt path).
                return self.find_snippet_in_text(text, beginning_excerpt, "")
        elif begin_start is not None:
            # Only beginning found
            return self.find_snippet_in_text(text, beginning_excerpt, "")

    return None
|
|
250
|
+
|
|
251
|
+
def _find_actual_position(self, text: str, excerpt: str, _recursion_depth: int = 0) -> tuple:
|
|
252
|
+
"""Find the actual character positions of an excerpt in the original text."""
|
|
253
|
+
result = self._find_actual_position_with_type(text, excerpt, _recursion_depth)
|
|
254
|
+
return result[0], result[1] # Return just position, not match type
|
|
255
|
+
|
|
256
|
+
def _find_actual_position_with_type(self, text: str, excerpt: str, _recursion_depth: int = 0) -> tuple:
    """Find the actual character positions and match type using the SAME permissive strategies as robust_find_improved.

    Returns ``(start, end, match_type)`` where ``match_type`` is one of
    'exact', 'normalized', 'letters_only', 'first_20', 'last_20',
    'first_last_10', 'first_half', or ``(None, None, None)`` on failure.
    Strategies 3-7 estimate positions by ratio, so the returned span is
    approximate — callers snap it to word boundaries afterwards.
    """
    if not excerpt.strip():
        return None, None, None

    # Prevent infinite recursion
    if _recursion_depth > 1:
        return None, None, None

    # Strategy 1: Try direct matching first (fastest)
    text_lower = text.lower()
    excerpt_lower = excerpt.lower().strip()
    idx = text_lower.find(excerpt_lower)
    if idx != -1:
        return idx, idx + len(excerpt_lower), 'exact'

    # Strategy 2: Try with our aggressive normalization
    text_norm = normalize_text_aggressive(text)
    excerpt_norm = normalize_text_aggressive(excerpt)

    idx = text_norm.lower().find(excerpt_norm.lower())
    if idx != -1:
        # Map back to original text position approximately
        start_pos, end_pos = self._map_normalized_to_original(text, text_norm, idx, len(excerpt_norm))
        return start_pos, end_pos, 'normalized'

    # Strategy 3: Letters-only matching (same as robust_find_improved)
    # Position is estimated by proportional ratio, not exact offset.
    text_letters = letters_only(text)
    excerpt_letters = letters_only(excerpt)

    if excerpt_letters and excerpt_letters in text_letters:
        letters_idx = text_letters.find(excerpt_letters)
        ratio = letters_idx / len(text_letters) if text_letters else 0
        approx_start = int(ratio * len(text))
        return approx_start, approx_start + len(excerpt), 'letters_only'

    # Strategy 4: First 20 characters fallback (same as robust_find_improved)
    if len(excerpt_letters) >= 20:
        excerpt_first_20 = excerpt_letters[:20]
        if excerpt_first_20 in text_letters:
            letters_idx = text_letters.find(excerpt_first_20)
            ratio = letters_idx / len(text_letters) if text_letters else 0
            approx_start = int(ratio * len(text))
            return approx_start, approx_start + len(excerpt), 'first_20'

    # Strategy 5: Last 20 characters fallback (same as robust_find_improved)
    if len(excerpt_letters) >= 20:
        excerpt_last_20 = excerpt_letters[-20:]
        if excerpt_last_20 in text_letters:
            letters_idx = text_letters.find(excerpt_last_20)
            ratio = letters_idx / len(text_letters) if text_letters else 0
            approx_start = int(ratio * len(text))
            return approx_start, approx_start + len(excerpt), 'last_20'

    # Strategy 6: First + Last 10 fallback (same as robust_find_improved)
    if len(excerpt_letters) >= 20:
        excerpt_first_10 = excerpt_letters[:10]
        excerpt_last_10 = excerpt_letters[-10:]
        if excerpt_first_10 in text_letters and excerpt_last_10 in text_letters:
            letters_idx = text_letters.find(excerpt_first_10)
            ratio = letters_idx / len(text_letters) if text_letters else 0
            approx_start = int(ratio * len(text))
            return approx_start, approx_start + len(excerpt), 'first_last_10'

    # Strategy 7: Half matching for shorter excerpts (same as robust_find_improved)
    if 10 <= len(excerpt_letters) < 20:
        excerpt_first_half = excerpt_letters[:len(excerpt_letters)//2]
        excerpt_second_half = excerpt_letters[len(excerpt_letters)//2:]
        if len(excerpt_first_half) >= 5 and len(excerpt_second_half) >= 5:
            if excerpt_first_half in text_letters and excerpt_second_half in text_letters:
                letters_idx = text_letters.find(excerpt_first_half)
                ratio = letters_idx / len(text_letters) if text_letters else 0
                approx_start = int(ratio * len(text))
                return approx_start, approx_start + len(excerpt), 'first_half'

    return None, None, None
|
|
332
|
+
|
|
333
|
+
def _map_normalized_to_original(self, original: str, normalized: str, norm_start: int, norm_length: int) -> tuple:
    """Map a position in normalized text back to original text.

    Estimates the original-text span for a match found at
    ``normalized[norm_start:norm_start + norm_length]`` by proportional
    scaling, then refines with a brute-force scan of a window around the
    estimate; falls back to the raw estimate when no exact re-match exists.
    Returns ``(start, end)`` or ``(None, None)`` for empty input.
    """
    # This is an approximation - we'll search around the estimated area
    if len(normalized) == 0:
        return None, None

    # Estimate the ratio
    ratio_start = norm_start / len(normalized)
    ratio_end = (norm_start + norm_length) / len(normalized)

    # Estimate positions in original text
    orig_start_est = int(ratio_start * len(original))
    orig_end_est = int(ratio_end * len(original))

    # Expand search window
    window_size = max(50, norm_length * 2)
    search_start = max(0, orig_start_est - window_size)
    search_end = min(len(original), orig_end_est + window_size)

    # Try to find the best match in this window
    search_text = original[search_start:search_end]
    excerpt_to_find = normalized[norm_start:norm_start + norm_length]

    # Simple substring search in the window
    # NOTE(review): this re-normalizes every candidate window, which is
    # O(window * excerpt) — acceptable because the window is small.
    for i in range(len(search_text) - len(excerpt_to_find) + 1):
        window = search_text[i:i + len(excerpt_to_find)]
        if normalize_text_aggressive(window).lower() == excerpt_to_find.lower():
            return search_start + i, search_start + i + len(window)

    # Fallback: return estimated positions
    return max(0, orig_start_est), min(len(original), orig_end_est)
|
|
364
|
+
|
|
365
|
+
def _expand_fallback_match(self, text: str, start_pos: int, end_pos: int, original_excerpt: str) -> tuple:
|
|
366
|
+
"""Expand a fallback match to include proper word boundaries and context."""
|
|
367
|
+
# Find word boundaries around the match
|
|
368
|
+
new_start = self._find_word_start(text, start_pos)
|
|
369
|
+
new_end = self._find_word_end(text, end_pos)
|
|
370
|
+
|
|
371
|
+
# Add some context words for better snippet quality
|
|
372
|
+
words_before = self._get_n_words_before(text, new_start, 3)
|
|
373
|
+
words_after = self._get_n_words_after(text, new_end, 3)
|
|
374
|
+
|
|
375
|
+
# Calculate final boundaries
|
|
376
|
+
final_start = max(0, new_start - len(words_before))
|
|
377
|
+
final_end = min(len(text), new_end + len(words_after))
|
|
378
|
+
|
|
379
|
+
return final_start, final_end
|
|
380
|
+
|
|
381
|
+
def _find_word_start(self, text: str, pos: int) -> int:
|
|
382
|
+
"""Find the start of the word containing the given position."""
|
|
383
|
+
if pos <= 0:
|
|
384
|
+
return 0
|
|
385
|
+
# Move backwards to find word boundary
|
|
386
|
+
while pos > 0 and text[pos-1].isalnum():
|
|
387
|
+
pos -= 1
|
|
388
|
+
return pos
|
|
389
|
+
|
|
390
|
+
def _find_word_end(self, text: str, pos: int) -> int:
|
|
391
|
+
"""Find the end of the word containing the given position."""
|
|
392
|
+
if pos >= len(text):
|
|
393
|
+
return len(text)
|
|
394
|
+
# Move forwards to find word boundary
|
|
395
|
+
while pos < len(text) and text[pos].isalnum():
|
|
396
|
+
pos += 1
|
|
397
|
+
return pos
|
|
398
|
+
|
|
399
|
+
def _get_n_words_before(self, text: str, pos: int, n: int) -> str:
|
|
400
|
+
"""Get n words before the given position."""
|
|
401
|
+
if pos <= 0:
|
|
402
|
+
return ""
|
|
403
|
+
|
|
404
|
+
# Look backwards from position to find word boundaries
|
|
405
|
+
before_text = text[:pos]
|
|
406
|
+
words = before_text.split()
|
|
407
|
+
|
|
408
|
+
if len(words) <= n:
|
|
409
|
+
return before_text
|
|
410
|
+
else:
|
|
411
|
+
return " ".join(words[-n:]) + " "
|
|
412
|
+
|
|
413
|
+
def _get_n_words_after(self, text: str, pos: int, n: int) -> str:
|
|
414
|
+
"""Get n words after the given position."""
|
|
415
|
+
if pos >= len(text):
|
|
416
|
+
return ""
|
|
417
|
+
|
|
418
|
+
# Look forwards from position to find word boundaries
|
|
419
|
+
after_text = text[pos:]
|
|
420
|
+
words = after_text.split()
|
|
421
|
+
|
|
422
|
+
if len(words) <= n:
|
|
423
|
+
return after_text
|
|
424
|
+
else:
|
|
425
|
+
return " " + " ".join(words[:n])
|
|
426
|
+
|
|
427
|
+
def consolidate_snippets(
    self,
    original_text: str,
    chunk_results: List[ChunkResult],
    category: str,
    *,
    debug_print: bool = False,
) -> List[str]:
    """Convert per-chunk responses into verbatim snippets for ``category``.

    Collects every (beginning, ending) excerpt pair the model reported for
    ``category``, resolves each to real text (chunk first, then the full
    document), deduplicates the snippets, and accumulates match statistics
    into ``self.hit_rate_stats`` for later reporting.
    """

    # Gather excerpt pairs; chunk_texts stays index-aligned with all_excerpts
    # because both are appended together.
    all_excerpts: List[Tuple[str, str]] = []
    chunk_texts: List[str] = []
    for chunk_result in chunk_results:
        payload = chunk_result.data
        if not isinstance(payload, dict):
            continue
        if category in payload and isinstance(payload[category], list):
            for item in payload[category]:
                if isinstance(item, dict):
                    beginning_raw = item.get("beginning excerpt", "")
                    ending_raw = item.get("ending excerpt", "")
                    beginning = "" if beginning_raw is None else str(beginning_raw)
                    ending = "" if ending_raw is None else str(ending_raw)
                    if beginning:
                        all_excerpts.append((beginning, ending))
                        chunk_texts.append(chunk_result.chunk_text)

    found = 0
    snippets: List[str] = []
    failed: List[Tuple[str, str]] = []
    begin_fail_count = 0
    end_fail_count = 0
    strict_matches = 0

    for idx, (beginning, ending) in enumerate(all_excerpts):
        chunk_text = chunk_texts[idx] if idx < len(chunk_texts) else ""
        snippet: Optional[str] = None

        # Strict matching is tracked purely as a quality metric; it does not
        # gate snippet extraction below.
        strict_begin = strict_find(original_text, beginning)
        strict_end = strict_find(original_text, ending) if ending and ending.strip() else True
        if strict_begin and strict_end:
            strict_matches += 1

        # Prefer the chunk the excerpt came from — smaller search space.
        if chunk_text:
            snippet = self.find_snippet_in_text(chunk_text, beginning, ending)
            if debug_print and snippet:
                print(f"[DEBUG] Found in chunk: '{beginning[:50]}...'")

        # Fall back to the full document.
        if not snippet:
            snippet = self.find_snippet_in_text(original_text, beginning, ending)
            if debug_print and snippet:
                print(f"[DEBUG] Found in full text: '{beginning[:50]}...'")
            elif debug_print:
                print(f"[DEBUG] FAILED to find: '{beginning[:50]}...'")
                letters_begin = letters_only(beginning)
                letters_text = letters_only(original_text)
                print(f"[DEBUG] Letters-only excerpt: '{letters_begin[:50]}...'")
                print(f"[DEBUG] Letters-only contains: {letters_begin in letters_text}")
                begin_match = robust_find_improved(original_text, beginning)
                end_match = robust_find_improved(original_text, ending) if ending else True
                print(f"[DEBUG] Failure analysis for '{beginning[:30]}...':")
                print(f"[DEBUG] Begin match: {begin_match is not None}")
                print(
                    f"[DEBUG] End match: {end_match is not None} (ending: '{ending[:20]}...' if ending else 'None')"
                )

        if snippet:
            # Deduplicate snippets but still count every successful excerpt.
            if snippet not in snippets:
                snippets.append(snippet)
            found += 1
        else:
            # Attribute the failure to the beginning and/or ending excerpt.
            begin_match = strict_find(original_text, beginning)
            if ending and ending.strip():
                end_match = strict_find(original_text, ending)
                if not end_match:
                    end_fail_count += 1
            if not begin_match:
                begin_fail_count += 1
            failed.append((beginning, ending))

    # Accumulate per-category statistics for print_final_hit_rates.
    total = len(all_excerpts)
    if total:
        stats = self.hit_rate_stats.setdefault(
            category,
            {
                "found": 0,
                "total": 0,
                "failed_examples": [],
                "begin_failures": 0,
                "end_failures": 0,
                "strict_matches": 0,
            },
        )
        stats["found"] += found
        stats["total"] += total
        stats["begin_failures"] += begin_fail_count
        stats["end_failures"] += end_fail_count
        stats["strict_matches"] += strict_matches
        # Keep only a handful of failure examples for diagnostics.
        if failed and len(stats["failed_examples"]) < 3:
            stats["failed_examples"].extend(failed[:2])

    if debug_print and total:
        rate = 100.0 * found / total if total else 0.0
        strict_rate = 100.0 * strict_matches / total if total else 0.0
        print(
            f"[DEBUG] Category '{category}': {found}/{total} matched ({rate:.1f}%)"
            f" | Strict: {strict_matches} ({strict_rate:.1f}%)"
            f" | Begin failures: {begin_fail_count} | End failures: {end_fail_count}"
        )

    return snippets
|
|
538
|
+
|
|
539
|
+
def print_final_hit_rates(self) -> None:
|
|
540
|
+
"""Print aggregated hit-rate statistics for debugging."""
|
|
541
|
+
|
|
542
|
+
if not self.hit_rate_stats:
|
|
543
|
+
return
|
|
544
|
+
|
|
545
|
+
print("\n" + "=" * 80)
|
|
546
|
+
print("FINAL MATCHING STATISTICS")
|
|
547
|
+
print("=" * 80)
|
|
548
|
+
|
|
549
|
+
total_found = 0
|
|
550
|
+
total_excerpts = 0
|
|
551
|
+
total_begin_failures = 0
|
|
552
|
+
total_end_failures = 0
|
|
553
|
+
total_strict_matches = 0
|
|
554
|
+
|
|
555
|
+
for category in sorted(self.hit_rate_stats.keys()):
|
|
556
|
+
stats = self.hit_rate_stats[category]
|
|
557
|
+
found = stats.get("found", 0)
|
|
558
|
+
total = stats.get("total", 0)
|
|
559
|
+
begin_fail = stats.get("begin_failures", 0)
|
|
560
|
+
end_fail = stats.get("end_failures", 0)
|
|
561
|
+
strict_match = stats.get("strict_matches", 0)
|
|
562
|
+
hit_rate = 100.0 * found / total if total else 0.0
|
|
563
|
+
strict_rate = 100.0 * strict_match / total if total else 0.0
|
|
564
|
+
begin_fail_pct = 100.0 * begin_fail / total if total else 0.0
|
|
565
|
+
end_fail_pct = 100.0 * end_fail / total if total else 0.0
|
|
566
|
+
|
|
567
|
+
print(
|
|
568
|
+
f"{category:25s}: {found:3d}/{total:3d} ({hit_rate:4.1f}%) | "
|
|
569
|
+
f"Strict: {strict_match:3d} ({strict_rate:4.1f}%) | "
|
|
570
|
+
f"Begin fails: {begin_fail:2d} ({begin_fail_pct:4.1f}%) | "
|
|
571
|
+
f"End fails: {end_fail:2d} ({end_fail_pct:4.1f}%)"
|
|
572
|
+
)
|
|
573
|
+
|
|
574
|
+
total_found += found
|
|
575
|
+
total_excerpts += total
|
|
576
|
+
total_begin_failures += begin_fail
|
|
577
|
+
total_end_failures += end_fail
|
|
578
|
+
total_strict_matches += strict_match
|
|
579
|
+
|
|
580
|
+
overall_rate = 100.0 * total_found / total_excerpts if total_excerpts else 0.0
|
|
581
|
+
overall_strict_rate = (
|
|
582
|
+
100.0 * total_strict_matches / total_excerpts if total_excerpts else 0.0
|
|
583
|
+
)
|
|
584
|
+
overall_begin_fail_rate = (
|
|
585
|
+
100.0 * total_begin_failures / total_excerpts if total_excerpts else 0.0
|
|
586
|
+
)
|
|
587
|
+
overall_end_fail_rate = (
|
|
588
|
+
100.0 * total_end_failures / total_excerpts if total_excerpts else 0.0
|
|
589
|
+
)
|
|
590
|
+
|
|
591
|
+
print("-" * 80)
|
|
592
|
+
print(
|
|
593
|
+
f"{'OVERALL':25s}: {total_found:3d}/{total_excerpts:3d} ({overall_rate:4.1f}%) | "
|
|
594
|
+
f"Strict: {total_strict_matches:3d} ({overall_strict_rate:4.1f}%) | "
|
|
595
|
+
f"Begin fails: {total_begin_failures:2d} ({overall_begin_fail_rate:4.1f}%) | "
|
|
596
|
+
f"End fails: {total_end_failures:2d} ({overall_end_fail_rate:4.1f}%)"
|
|
597
|
+
)
|
|
598
|
+
print("=" * 80)
|
|
599
|
+
|
|
600
|
+
def _iteration_file_name(self, iteration: int) -> str:
|
|
601
|
+
if iteration == 0:
|
|
602
|
+
return self.cfg.file_name
|
|
603
|
+
stem, ext = os.path.splitext(self.cfg.file_name)
|
|
604
|
+
return f"{stem}_iter{iteration}{ext}"
|
|
605
|
+
|
|
606
|
+
def _strip_snippets(self, text: str, snippets_by_category: Dict[str, List[str]]) -> str:
|
|
607
|
+
remaining = text
|
|
608
|
+
for snippets in snippets_by_category.values():
|
|
609
|
+
for snippet in snippets:
|
|
610
|
+
if snippet:
|
|
611
|
+
remaining = remaining.replace(snippet, " ", 1)
|
|
612
|
+
return re.sub(r"\s+", " ", remaining).strip()
|
|
613
|
+
|
|
614
|
+
def _merge_snippet_results(
|
|
615
|
+
self,
|
|
616
|
+
destination: Dict[int, Dict[str, List[str]]],
|
|
617
|
+
source: Dict[int, Dict[str, List[str]]],
|
|
618
|
+
) -> bool:
|
|
619
|
+
added = False
|
|
620
|
+
for row_idx, cat_map in source.items():
|
|
621
|
+
dest_row = destination.setdefault(row_idx, {})
|
|
622
|
+
for category, snippets in cat_map.items():
|
|
623
|
+
dest_list = dest_row.setdefault(category, [])
|
|
624
|
+
for snippet in snippets:
|
|
625
|
+
if snippet and snippet not in dest_list:
|
|
626
|
+
dest_list.append(snippet)
|
|
627
|
+
added = True
|
|
628
|
+
return added
|
|
629
|
+
|
|
630
|
+
async def _gather_iteration(
    self,
    row_texts: Dict[int, str],
    *,
    original_texts: List[str],
    raw_values: List[Any],
    categories: Optional[Dict[str, str]],
    additional_instructions: Optional[str],
    iteration: int,
    dynamic_mode: bool,
    reset_files: bool,
    category_subset: Optional[Set[str]] = None,
    **kwargs: Any,
) -> Dict[int, Dict[str, List[str]]]:
    """Run one full prompting pass over *row_texts* and consolidate snippets.

    Parameters
    ----------
    row_texts:
        Mapping of row index -> text to code in this pass (may be leftover
        text on follow-up passes rather than the full original).
    original_texts:
        Full original text per row, used when consolidating snippets back
        against the source text.
    raw_values:
        Raw column values per row; used to load image/audio/pdf attachments
        when ``cfg.modality`` requires them.
    categories:
        Fixed label -> description mapping, or None in dynamic mode where
        the model invents categories itself.
    iteration:
        Pass number; feeds identifiers and the per-iteration save file.
    dynamic_mode:
        When True, prompts carry no category list and categories are read
        back from the response keys.
    category_subset:
        If given (fixed mode only), restrict prompting to these categories.
    **kwargs:
        Forwarded verbatim to ``get_all_responses``.

    Returns
    -------
    Mapping of row index -> {category: [snippets]} for this pass only.
    """
    if not row_texts:
        return {}

    debug = self.cfg.debug_print
    # In fixed mode, narrow the category list to the requested subset; an
    # empty selection means there is nothing to ask for.
    selected_categories: List[str] = []
    if not dynamic_mode and categories:
        selected_categories = [
            cat
            for cat in categories.keys()
            if category_subset is None or cat in category_subset
        ]
        if not selected_categories:
            return {}

    requests: List[PromptRequest] = []
    # Per-identifier multimodal attachments, keyed the same as the prompts.
    prompt_images: Dict[str, List[str]] = {}
    prompt_audio: Dict[str, List[Dict[str, str]]] = {}
    prompt_pdfs: Dict[str, List[Dict[str, str]]] = {}
    pending_requests: List[Dict[str, Any]] = []

    # Fixed mode batches categories so each call stays under
    # cfg.max_categories_per_call; dynamic mode sends no category list.
    if not dynamic_mode and categories:
        category_batches = [
            selected_categories[i : i + self.cfg.max_categories_per_call]
            for i in range(0, len(selected_categories), self.cfg.max_categories_per_call)
        ]
    else:
        category_batches = []

    for row_idx, text in row_texts.items():
        text_str = str(text or "")
        if not text_str.strip():
            continue
        # Long texts are split into word-bounded chunks; each chunk becomes
        # its own prompt (crossed with category batches in fixed mode).
        chunks = self.chunk_by_words(text_str, self.cfg.max_words_per_call)
        images = (
            load_image_inputs(raw_values[row_idx])
            if self.cfg.modality == "image"
            else None
        )
        audio_inputs = (
            load_audio_inputs(raw_values[row_idx])
            if self.cfg.modality == "audio"
            else None
        )
        pdf_inputs = (
            load_pdf_inputs(raw_values[row_idx])
            if self.cfg.modality == "pdf"
            else None
        )

        for chunk_idx, chunk in enumerate(chunks):
            if dynamic_mode:
                # One request per chunk; identifier encodes row/iter/chunk
                # so responses can be routed back to their row.
                identifier = f"row{row_idx}_iter{iteration}_chunk{chunk_idx}"
                pending_requests.append(
                    {
                        "identifier": identifier,
                        "chunk": chunk,
                        "row_index": row_idx,
                        "batch_categories": None,
                        "images": images,
                        "audio": audio_inputs,
                        "pdfs": pdf_inputs,
                    }
                )
            else:
                # One request per (chunk, category batch) pair.
                for batch_idx, batch_keys in enumerate(category_batches):
                    assert categories is not None
                    batch_categories = {k: categories[k] for k in batch_keys}
                    identifier = (
                        f"row{row_idx}_iter{iteration}_chunk{chunk_idx}_batch{batch_idx}"
                    )
                    pending_requests.append(
                        {
                            "identifier": identifier,
                            "chunk": chunk,
                            "row_index": row_idx,
                            "batch_categories": batch_categories,
                            "images": images,
                            "audio": audio_inputs,
                            "pdfs": pdf_inputs,
                        }
                    )

    if not pending_requests:
        return {}

    announce_prompt_rendering("Codify", len(pending_requests))

    # Render the Jinja template for every pending request and register any
    # attachments under the same identifier.
    for req in pending_requests:
        prompt = self.template.render(
            text=req["chunk"],
            categories=req["batch_categories"],
            additional_instructions=additional_instructions,
            modality=self.cfg.modality,
        )
        requests.append(
            PromptRequest(
                identifier=req["identifier"],
                prompt=prompt,
                row_index=req["row_index"],
                chunk_text=req["chunk"],
            )
        )
        if req["images"]:
            prompt_images[req["identifier"]] = list(req["images"])
        if req["audio"]:
            prompt_audio[req["identifier"]] = list(req["audio"])
        if req.get("pdfs"):
            prompt_pdfs[req["identifier"]] = list(req["pdfs"])

    prompts = [req.prompt for req in requests]
    identifiers = [req.identifier for req in requests]
    id_to_request = {req.identifier: req for req in requests}

    # Fire all prompts through the shared batch-response helper; results are
    # also checkpointed to a per-iteration CSV in cfg.save_dir.
    batch_df = await get_all_responses(
        prompts=prompts,
        identifiers=identifiers,
        n_parallels=self.cfg.n_parallels,
        save_path=os.path.join(self.cfg.save_dir, self._iteration_file_name(iteration)),
        reset_files=reset_files,
        use_dummy=self.cfg.use_dummy,
        json_mode=self.cfg.json_mode,
        model=self.cfg.model,
        max_timeout=self.cfg.max_timeout,
        print_example_prompt=True,
        reasoning_effort=self.cfg.reasoning_effort,
        reasoning_summary=self.cfg.reasoning_summary,
        prompt_images=prompt_images or None,
        prompt_audio=prompt_audio or None,
        prompt_pdfs=prompt_pdfs or None,
        **kwargs,
    )

    # Parse each response into a ChunkResult, grouped by originating row.
    chunk_results_by_row: Dict[int, List[ChunkResult]] = defaultdict(list)
    for ident, resp in zip(batch_df["Identifier"], batch_df["Response"]):
        request = id_to_request.get(ident)
        if request is None:
            # Response for an identifier we did not issue (e.g. stale saved
            # rows) — ignore it.
            continue
        # Responses may come back as a list; only the first element is used.
        main = resp[0] if isinstance(resp, list) and resp else resp
        parsed = self.parse_json(main) or {}
        if debug:
            if not parsed:
                print(f"[DEBUG] Failed to parse response for {ident}")
            else:
                print(f"[DEBUG] Parsed response for {ident} with keys: {list(parsed.keys())}")
        chunk_results_by_row[request.row_index].append(
            ChunkResult(identifier=ident, chunk_text=request.chunk_text, data=parsed)
        )

    iteration_results: Dict[int, Dict[str, List[str]]] = {}
    if dynamic_mode:
        # Dynamic mode: categories are whatever string keys with list values
        # the model returned; only rows with usable results are emitted.
        for row_idx in row_texts.keys():
            chunk_results = chunk_results_by_row.get(row_idx, [])
            if not chunk_results:
                continue
            categories_seen: Set[str] = set()
            for chunk_result in chunk_results:
                for key, value in chunk_result.data.items():
                    if isinstance(key, str) and isinstance(value, list):
                        categories_seen.add(key)
            if not categories_seen:
                continue
            row_map: Dict[str, List[str]] = {}
            # sorted() keeps category order deterministic across runs.
            for category in sorted(categories_seen):
                snippets = self.consolidate_snippets(
                    original_texts[row_idx],
                    chunk_results,
                    category,
                    debug_print=debug,
                )
                if snippets:
                    row_map[category] = snippets
            if row_map:
                iteration_results[row_idx] = row_map
    else:
        # Fixed mode: every selected category gets an entry (possibly empty)
        # for every row, even when no chunk produced results.
        for row_idx in row_texts.keys():
            chunk_results = chunk_results_by_row.get(row_idx, [])
            row_map: Dict[str, List[str]] = {}
            for category in selected_categories:
                snippets = self.consolidate_snippets(
                    original_texts[row_idx],
                    chunk_results,
                    category,
                    debug_print=debug,
                )
                row_map[category] = snippets
            iteration_results[row_idx] = row_map

    return iteration_results
|
|
832
|
+
|
|
833
|
+
async def _classify_remaining(
    self,
    aggregated: Dict[int, Dict[str, List[str]]],
    original_texts: List[str],
    categories: Dict[str, str],
    additional_instructions: Optional[str],
    iteration: int,
    reset_files: bool,
) -> Dict[int, Set[str]]:
    """Check whether un-coded leftover text still warrants any label.

    Strips every snippet already recorded in *aggregated* from each row's
    text and runs a boolean classifier over the residue.  Returns a mapping
    of row index -> set of category names flagged True; rows with no
    leftover text or no positive label are omitted.
    """
    # Keep only rows that still have text once coded snippets are removed.
    leftovers: Dict[int, str] = {}
    for idx, full_text in enumerate(original_texts):
        residue = self._strip_snippets(full_text, aggregated.get(idx, {}))
        if residue:
            leftovers[idx] = residue

    if not leftovers:
        return {}

    row_indices = list(leftovers.keys())
    remaining_texts = list(leftovers.values())

    # Classifier artifacts live in their own subdirectory, one file per pass.
    validation_dir = os.path.join(self.cfg.save_dir, "completion_checks")
    os.makedirs(validation_dir, exist_ok=True)
    file_name = f"{Path(self.cfg.file_name).stem}_completion_iter{iteration}.csv"

    # Assemble the classifier guidance: base prompt plus configured extras.
    base_instruction = (
        "These passages contain the remaining text after previously extracted snippets were removed. "
        "Return True for a label only if the remaining text still contains a clear, distinct snippet "
        "that should be coded for that label. Default to False when unsure."
    )
    extra = self.cfg.completion_classifier_instructions
    if extra:
        base_instruction += "\n" + extra.strip()
    if additional_instructions:
        base_instruction += "\nOriginal coding instructions:\n" + additional_instructions.strip()

    classifier = Classify(
        ClassifyConfig(
            labels=categories,
            save_dir=validation_dir,
            file_name=file_name,
            model=self.cfg.model,
            n_parallels=self.cfg.n_parallels,
            n_runs=1,
            use_dummy=self.cfg.use_dummy,
            additional_instructions=base_instruction,
            modality=self.cfg.modality,
            n_attributes_per_run=self.cfg.max_categories_per_call,
            reasoning_effort=self.cfg.reasoning_effort,
            reasoning_summary=self.cfg.reasoning_summary,
        )
    )

    # Only the very first completion pass may reset cached classifier files.
    results_df = await classifier.run(
        pd.DataFrame({"text": remaining_texts}),
        column_name="text",
        reset_files=reset_files and iteration == 0,
    )

    # Translate classifier booleans back to the original row indices.
    flagged: Dict[int, Set[str]] = {}
    for position, row_idx in enumerate(row_indices):
        hits: Set[str] = set()
        for category in categories.keys():
            if bool(results_df.at[position, category]):
                hits.add(category)
        if hits:
            flagged[row_idx] = hits

    return flagged
|
|
903
|
+
|
|
904
|
+
async def _completion_loop(
    self,
    aggregated: Dict[int, Dict[str, List[str]]],
    original_texts: List[str],
    raw_values: List[Any],
    categories: Dict[str, str],
    additional_instructions: Optional[str],
    reset_files: bool,
    **kwargs: Any,
) -> Dict[int, Dict[str, List[str]]]:
    """Run follow-up coding passes until no leftover text is flagged.

    Each pass asks the completion classifier which rows/labels still have
    uncoded material, re-runs extraction on just that residue, and merges
    any new snippets into *aggregated* (mutated in place and also
    returned).  Stops early when nothing is flagged, no residue remains,
    or a pass adds no new snippets.
    """
    # cfg.n_rounds counts the initial pass, so follow-ups are n_rounds - 1.
    extra_passes = max(0, max(1, int(self.cfg.n_rounds)) - 1)

    for round_no in range(1, extra_passes + 1):
        flagged = await self._classify_remaining(
            aggregated,
            original_texts,
            categories,
            additional_instructions,
            iteration=round_no - 1,
            reset_files=reset_files,
        )
        if not flagged:
            break

        # Only re-prompt for labels that were flagged on at least one row.
        subset: Set[str] = set()
        for labels in flagged.values():
            subset |= labels

        # Re-code only the text left over after stripping known snippets.
        residual_texts: Dict[int, str] = {}
        for idx in flagged.keys():
            leftover = self._strip_snippets(original_texts[idx], aggregated.get(idx, {}))
            if leftover:
                residual_texts[idx] = leftover
        if not residual_texts:
            break

        pass_results = await self._gather_iteration(
            residual_texts,
            original_texts=original_texts,
            raw_values=raw_values,
            categories=categories,
            additional_instructions=additional_instructions,
            iteration=round_no,
            dynamic_mode=False,
            reset_files=False,
            category_subset=subset,
            **kwargs,
        )
        # A pass that contributes nothing new means we have converged.
        if not self._merge_snippet_results(aggregated, pass_results):
            break

    return aggregated
|
|
955
|
+
|
|
956
|
+
async def run(
    self,
    df: pd.DataFrame,
    column_name: str,
    *,
    categories: Optional[Dict[str, str]] = None,
    additional_instructions: str = "",
    reset_files: bool = False,
    **kwargs: Any,
) -> pd.DataFrame:
    """Code every row of ``df[column_name]`` into categorized text snippets.

    Parameters
    ----------
    df:
        Input data; a positional copy is processed, the caller's frame is
        not mutated.
    column_name:
        Column containing the text (or media reference) to code.
    categories:
        Label -> description mapping.  When None, dynamic mode is used:
        the model derives categories itself and results land in a single
        ``coded_passages`` column instead of one column per category.
    additional_instructions:
        Extra prompt guidance; REQUIRED (non-empty) in dynamic mode.
    reset_files:
        When True, discard previously saved responses and start fresh.
    **kwargs:
        Forwarded to the underlying response-gathering call.

    Returns
    -------
    The processed DataFrame with snippet columns added; also written to
    ``<save_dir>/coded_passages.csv``.

    Raises
    ------
    ValueError
        If ``categories`` is None and no additional instructions are given.
    """
    df_proc = df.reset_index(drop=True).copy()
    # Reset per-run diagnostics accumulated elsewhere in the class.
    self.hit_rate_stats = {}

    raw_values = df_proc[column_name].tolist()
    # NaN cells become empty strings so downstream text handling is uniform.
    original_texts = ["" if pd.isna(val) else str(val) for val in raw_values]
    warn_if_modality_mismatch(raw_values, self.cfg.modality, column_name=column_name)

    # Normalize instructions: blank/whitespace-only collapses to None.
    additional = (additional_instructions or "").strip() or None
    dynamic_mode = categories is None
    if dynamic_mode and not additional:
        raise ValueError(
            "additional_instructions must be provided when categories is None"
        )

    categories_dict = categories or {}
    aggregated: Dict[int, Dict[str, List[str]]] = {}
    # Fixed mode pre-seeds every row with an empty list per category so the
    # output columns exist even for rows that yield nothing.
    if not dynamic_mode:
        for idx in range(len(df_proc)):
            aggregated[idx] = {cat: [] for cat in categories_dict.keys()}

    # Initial pass over the full text of every row.
    row_texts = {idx: original_texts[idx] for idx in range(len(df_proc))}
    initial_results = await self._gather_iteration(
        row_texts,
        original_texts=original_texts,
        raw_values=raw_values,
        categories=None if dynamic_mode else categories_dict,
        additional_instructions=additional,
        iteration=0,
        dynamic_mode=dynamic_mode,
        reset_files=reset_files,
        **kwargs,
    )
    self._merge_snippet_results(aggregated, initial_results)

    # Optional completion passes (fixed mode only) re-code leftover text.
    if not dynamic_mode and categories_dict and self.cfg.n_rounds > 1:
        aggregated = await self._completion_loop(
            aggregated,
            original_texts,
            raw_values,
            categories_dict,
            additional,
            reset_files=reset_files,
            **kwargs,
        )

    if dynamic_mode:
        # Dynamic mode: one dict-valued column holding all coded passages.
        coded_passages: List[Dict[str, List[str]]] = []
        for idx in range(len(df_proc)):
            row_map = aggregated.get(idx, {})
            coded_passages.append(
                {cat: list(snippets) for cat, snippets in row_map.items()}
            )
        df_proc["coded_passages"] = coded_passages
    else:
        # Fixed mode: one list-valued column per category.
        for category in categories_dict.keys():
            df_proc[category] = [
                list(aggregated.get(idx, {}).get(category, []))
                for idx in range(len(df_proc))
            ]

    output_path = os.path.join(self.cfg.save_dir, "coded_passages.csv")
    df_proc.to_csv(output_path, index=False)

    if self.cfg.debug_print:
        print(f"\n[DEBUG] Processing complete. Results saved to: {self.cfg.save_dir}")
        self.print_final_hit_rates()

    return df_proc
|