openai_gabriel-1.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. gabriel/__init__.py +61 -0
  2. gabriel/_version.py +1 -0
  3. gabriel/api.py +2284 -0
  4. gabriel/cli/__main__.py +60 -0
  5. gabriel/core/__init__.py +7 -0
  6. gabriel/core/llm_client.py +34 -0
  7. gabriel/core/pipeline.py +18 -0
  8. gabriel/core/prompt_template.py +152 -0
  9. gabriel/prompts/__init__.py +1 -0
  10. gabriel/prompts/bucket_prompt.jinja2 +113 -0
  11. gabriel/prompts/classification_prompt.jinja2 +50 -0
  12. gabriel/prompts/codify_prompt.jinja2 +95 -0
  13. gabriel/prompts/comparison_prompt.jinja2 +60 -0
  14. gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
  15. gabriel/prompts/deidentification_prompt.jinja2 +112 -0
  16. gabriel/prompts/extraction_prompt.jinja2 +61 -0
  17. gabriel/prompts/filter_prompt.jinja2 +31 -0
  18. gabriel/prompts/ideation_prompt.jinja2 +80 -0
  19. gabriel/prompts/merge_prompt.jinja2 +47 -0
  20. gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
  21. gabriel/prompts/rankings_prompt.jinja2 +49 -0
  22. gabriel/prompts/ratings_prompt.jinja2 +50 -0
  23. gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
  24. gabriel/prompts/seed.jinja2 +43 -0
  25. gabriel/prompts/snippets.jinja2 +117 -0
  26. gabriel/tasks/__init__.py +63 -0
  27. gabriel/tasks/_attribute_utils.py +69 -0
  28. gabriel/tasks/bucket.py +432 -0
  29. gabriel/tasks/classify.py +562 -0
  30. gabriel/tasks/codify.py +1033 -0
  31. gabriel/tasks/compare.py +235 -0
  32. gabriel/tasks/debias.py +1460 -0
  33. gabriel/tasks/deduplicate.py +341 -0
  34. gabriel/tasks/deidentify.py +316 -0
  35. gabriel/tasks/discover.py +524 -0
  36. gabriel/tasks/extract.py +455 -0
  37. gabriel/tasks/filter.py +169 -0
  38. gabriel/tasks/ideate.py +782 -0
  39. gabriel/tasks/merge.py +464 -0
  40. gabriel/tasks/paraphrase.py +531 -0
  41. gabriel/tasks/rank.py +2041 -0
  42. gabriel/tasks/rate.py +347 -0
  43. gabriel/tasks/seed.py +465 -0
  44. gabriel/tasks/whatever.py +344 -0
  45. gabriel/utils/__init__.py +64 -0
  46. gabriel/utils/audio_utils.py +42 -0
  47. gabriel/utils/file_utils.py +464 -0
  48. gabriel/utils/image_utils.py +22 -0
  49. gabriel/utils/jinja.py +31 -0
  50. gabriel/utils/logging.py +86 -0
  51. gabriel/utils/mapmaker.py +304 -0
  52. gabriel/utils/media_utils.py +78 -0
  53. gabriel/utils/modality_utils.py +148 -0
  54. gabriel/utils/openai_utils.py +5470 -0
  55. gabriel/utils/parsing.py +282 -0
  56. gabriel/utils/passage_viewer.py +2557 -0
  57. gabriel/utils/pdf_utils.py +20 -0
  58. gabriel/utils/plot_utils.py +2881 -0
  59. gabriel/utils/prompt_utils.py +42 -0
  60. gabriel/utils/word_matching.py +158 -0
  61. openai_gabriel-1.0.1.dist-info/METADATA +443 -0
  62. openai_gabriel-1.0.1.dist-info/RECORD +67 -0
  63. openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
  64. openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
  65. openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
  66. openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
  67. openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
gabriel/tasks/paraphrase.py
@@ -0,0 +1,531 @@
+ from __future__ import annotations
+
+ import os
+ import random
+ import re
+ import warnings
+ from dataclasses import dataclass
+ from typing import Optional, Dict, Any, List, Tuple
+
+ import pandas as pd
+ from pathlib import Path
+
+ from ..core.prompt_template import PromptTemplate, resolve_template
+ from ..utils.openai_utils import get_all_responses
+ from ..utils.logging import announce_prompt_rendering
+
+ # Import classifier utilities for recursive validation. Importing from
+ # ``gabriel.tasks.classify`` does not introduce a circular dependency
+ # because the classifier does not import the paraphrasing task.
+ from ..tasks.classify import Classify, ClassifyConfig
+
+
+ @dataclass
+ class ParaphraseConfig:
+     """Configuration for :class:`Paraphrase`."""
+
+     # Instruction passed to the paraphrase prompt. This should describe
+     # how the original passage should be rewritten.
+     instructions: str
+     # Optional name for the revised column in the output DataFrame. If
+     # unspecified, ``f"{column_name}_revised"`` will be used.
+     revised_column_name: Optional[str] = None
+     # Number of revisions to generate per passage. When greater than 1,
+     # additional columns will be appended to the output DataFrame with
+     # suffixes ``_1``, ``_2``, etc.
+     n_revisions: int = 1
+     # Directory where all paraphrase responses and intermediate files are
+     # persisted.
+     save_dir: str = "paraphrase"
+     # Base file name for the raw paraphrase responses. Cleaned and
+     # validated responses will be written using the same stem with
+     # ``_cleaned.csv`` appended.
+     file_name: str = "paraphrase_responses.csv"
+     # OpenAI model used for both paraphrasing and classification. The
+     # default matches the existing behaviour of GABRIEL.
+     model: str = "gpt-5-mini"
+     # When true, the model will be asked to output JSON only.
+     json_mode: bool = False
+     # When set, controls whether the underlying helper should use web search
+     # augmentation. ``None`` defers to downstream defaults while ``True`` and
+     # ``False`` explicitly enable or disable the feature.
+     web_search: Optional[bool] = None
+     # Maximum number of parallel requests that will be sent to the
+     # underlying API. Note that classification and paraphrasing share
+     # this value for simplicity.
+     n_parallels: int = 650
+     # Use dummy responses instead of real API calls. Exposed here
+     # primarily for testing.
+     use_dummy: bool = False
+     # Optional reasoning effort passed through to the LLM helper.
+     reasoning_effort: Optional[str] = None
+     # Optional reasoning summary passed through to the LLM helper.
+     reasoning_summary: Optional[str] = None
+     # Maximum number of paraphrase/validation rounds to run. A value of
+     # ``1`` preserves the historical behaviour of a single paraphrase
+     # generation pass with no recursive validation. Values greater than
+     # one enable recursive validation with an upper bound on the number of
+     # cycles.
+     n_rounds: int = 1
+     # Deprecated flag kept for backwards compatibility. When set to
+     # ``True`` while ``n_rounds`` is left at its default, ``n_rounds``
+     # is coerced to ``2``.
+     recursive_validation: Optional[bool] = None
+     # When greater than one, multiple paraphrases are generated for
+     # each passage in the very first round of generation. If at least
+     # one candidate passes the validation check, that candidate will be
+     # selected immediately without triggering further rounds. A value
+     # of one preserves the historical behaviour of producing a single
+     # paraphrase per passage at the outset.
+     n_initial_candidates: int = 1
+     # Base number of candidate paraphrases to generate for each failing
+     # passage in subsequent validation rounds; round ``k`` generates
+     # ``k * n_validation_candidates`` candidates. Generating multiple
+     # candidates in later rounds improves the probability of finding an
+     # acceptable paraphrase.
+     n_validation_candidates: int = 5
+     # Whether to feed the previously chosen paraphrase back into the
+     # generator during recursive validation. When ``False`` (the
+     # default), the original text is always used as the source for
+     # regeneration. When ``True``, the most recent paraphrase is
+     # provided as the input for further rewriting. This can be useful
+     # when incremental improvements are desired rather than starting
+     # over from the original each time. The option only has effect
+     # when ``n_rounds`` is greater than one.
+     use_modified_source: bool = False
+
+     def __post_init__(self) -> None:
+         try:
+             rounds = int(self.n_rounds)
+         except (TypeError, ValueError):
+             rounds = 1
+         if rounds < 1:
+             rounds = 1
+
+         if self.recursive_validation is not None:
+             warnings.warn(
+                 "recursive_validation is deprecated; use n_rounds instead.",
+                 DeprecationWarning,
+                 stacklevel=2,
+             )
+             if self.recursive_validation and rounds <= 1:
+                 rounds = 2
+
+         self.n_rounds = rounds
+
+
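+ # Example (illustrative values): the deprecation shim above coerces the
+ # legacy flag into the new round count.
+ #
+ #     cfg = ParaphraseConfig(
+ #         instructions="Use plain English.",
+ #         recursive_validation=True,  # deprecated; emits DeprecationWarning
+ #     )
+ #     cfg.n_rounds  # -> 2
+
+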
+ class Paraphrase:
+     """Paraphrase text columns in a DataFrame."""
+
+     def __init__(
+         self,
+         cfg: ParaphraseConfig,
+         template: Optional[PromptTemplate] = None,
+         template_path: Optional[str] = None,
+     ) -> None:
+         self.cfg = cfg
+         expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
+         expanded.mkdir(parents=True, exist_ok=True)
+         cfg.save_dir = str(expanded)
+         self.template = resolve_template(
+             template=template,
+             template_path=template_path,
+             reference_filename="paraphrase_prompt.jinja2",
+         )
+
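+     # Illustrative: ``save_dir`` is user/env expanded and created eagerly.
+     #
+     #     cfg = ParaphraseConfig(instructions="...", save_dir="~/para_runs")
+     #     Paraphrase(cfg)
+     #     cfg.save_dir  # now an absolute path; the directory exists
+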
+     async def run(
+         self,
+         df: pd.DataFrame,
+         column_name: str,
+         *,
+         reset_files: bool = False,
+         **kwargs: Any,
+     ) -> pd.DataFrame:
+         """Paraphrase ``df[column_name]`` and return a DataFrame with revisions.
+
+         This method orchestrates prompt construction, asynchronous API
+         calls, optional recursive validation using the classifier, and
+         persistence of results. The output DataFrame preserves the
+         order of the input and appends one or more revised columns.
+         """
+         # Ensure row indices are contiguous so we can map responses back
+         # deterministically. A copy is created to avoid mutating the
+         # caller's DataFrame.
+         df_proc = df.reset_index(drop=True).copy()
+         # Convert the target column into a list of strings. We coerce
+         # values to strings so that non-string columns (e.g. numbers) are
+         # handled gracefully.
+         texts: List[str] = df_proc[column_name].astype(str).tolist()
+         # Determine the base name for the revised column(s).
+         base_col = self.cfg.revised_column_name or f"{column_name}_revised"
+         # Determine how many paraphrases to produce per passage. A value
+         # less than 1 defaults to a single revision to align with
+         # existing behaviour.
+         n = self.cfg.n_revisions if self.cfg.n_revisions and self.cfg.n_revisions > 0 else 1
+
+         announce_prompt_rendering("Paraphrase", len(texts) * n)
+
+         # Resolve the number of recursive validation rounds to run. This
+         # value is consumed here (and not forwarded to
+         # ``get_all_responses``) to avoid unexpected keyword errors.
+         requested_rounds = kwargs.pop("n_rounds", self.cfg.n_rounds)
+         try:
+             max_rounds = int(requested_rounds)
+         except (TypeError, ValueError):
+             max_rounds = self.cfg.n_rounds
+         if max_rounds < 1:
+             max_rounds = 1
+
+         # Track whether each revision ultimately received validation
+         # approval. In the non-recursive case every output is treated as
+         # approved.
+         approval_map: Dict[Tuple[int, int], bool] = {}
+
+         # When recursive validation is disabled (or limited to a single
+         # round), follow the original behaviour: generate a single
+         # paraphrase per requested revision and skip classification.
+         # Otherwise, defer generation and validation to the recursive
+         # routine.
+         if max_rounds <= 1:
+             prompts: List[str] = []
+             identifiers: List[str] = []
+             for idx, text in enumerate(texts):
+                 for j in range(1, n + 1):
+                     prompts.append(
+                         self.template.render(text=text, instructions=self.cfg.instructions)
+                     )
+                     identifiers.append(f"row_{idx}_rev{j}")
+             save_path = os.path.join(self.cfg.save_dir, self.cfg.file_name)
+             resp_df = await get_all_responses(
+                 prompts=prompts,
+                 identifiers=identifiers,
+                 save_path=save_path,
+                 model=self.cfg.model,
+                 json_mode=self.cfg.json_mode,
+                 web_search=self.cfg.web_search,
+                 n_parallels=self.cfg.n_parallels,
+                 use_dummy=self.cfg.use_dummy,
+                 reset_files=reset_files,
+                 reasoning_effort=self.cfg.reasoning_effort,
+                 reasoning_summary=self.cfg.reasoning_summary,
+                 **kwargs,
+             )
+             resp_map: Dict[Tuple[int, int], str] = {}
+             for ident, resp in zip(resp_df["Identifier"], resp_df["Response"]):
+                 main = resp[0] if isinstance(resp, list) and resp else resp
+                 m = re.match(r"row_(\d+)_rev(\d+)", ident)
+                 if m:
+                     row = int(m.group(1))
+                     rev = int(m.group(2)) - 1
+                     resp_map[(row, rev)] = main
+                     approval_map[(row, rev)] = True
+         else:
+             # Initialise an empty response map. The recursive validation
+             # routine will populate this map with one paraphrase per
+             # (row, revision) key and record whether it passed
+             # validation.
+             resp_map: Dict[Tuple[int, int], str] = {}
+             await self._recursive_validate(
+                 texts,
+                 resp_map,
+                 approval_map,
+                 reset_files=reset_files,
+                 max_rounds=max_rounds,
+             )
+
+         # Assemble the final columns. When multiple revisions are
+         # requested, each revision will occupy its own column with a
+         # numeric suffix.
+         col_names = [base_col] if n == 1 else [f"{base_col}_{i}" for i in range(1, n + 1)]
+         approval_cols = (
+             [f"{base_col}_approved"]
+             if n == 1
+             else [f"{col}_approved" for col in col_names]
+         )
+         for j, col in enumerate(col_names):
+             df_proc[col] = [resp_map.get((i, j), "") for i in range(len(df_proc))]
+             df_proc[approval_cols[j]] = [
+                 bool(approval_map.get((i, j), True)) for i in range(len(df_proc))
+             ]
+
+         # Persist the cleaned and validated DataFrame to disk. This file
+         # excludes metadata columns such as ``Identifier`` or ``Response``
+         # and is intended for downstream analysis.
+         out_path = os.path.join(
+             self.cfg.save_dir,
+             f"{os.path.splitext(self.cfg.file_name)[0]}_cleaned.csv",
+         )
+         df_proc.to_csv(out_path, index=False)
+         return df_proc
+
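+     # Output shape (illustrative, for ``column_name="text"`` and the
+     # default ``revised_column_name``):
+     #     n_revisions == 1 -> ``text_revised``, ``text_revised_approved``
+     #     n_revisions == 2 -> ``text_revised_1``, ``text_revised_2`` plus
+     #         ``text_revised_1_approved`` and ``text_revised_2_approved``
+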
+     async def _recursive_validate(
+         self,
+         original_texts: List[str],
+         resp_map: Dict[Tuple[int, int], str],
+         approval_map: Dict[Tuple[int, int], bool],
+         *,
+         reset_files: bool = False,
+         max_rounds: int,
+     ) -> None:
+         """
+         Generate and validate paraphrases for each passage using a
+         classifier. This routine unifies initial and subsequent
+         candidate generation by allowing a configurable number of
+         candidates on the first round (``n_initial_candidates``) and a
+         separate number for later rounds (``n_validation_candidates``).
+         Candidates that pass validation are accepted immediately. For
+         candidates that fail, further paraphrases are generated until
+         a valid paraphrase is found or the ``max_rounds`` limit is
+         reached.
+
+         If ``use_modified_source`` is ``True``, subsequent rounds will
+         generate new paraphrases from the most recently chosen
+         paraphrase rather than starting from the original text. In
+         either case, the classifier always evaluates the modified
+         candidate against the original text to ensure the original
+         instruction has been followed.
+         """
+         # Determine the number of revisions (columns) to produce. At
+         # least one revision is always generated. This mirrors the logic
+         # in :meth:`run`.
+         n_revs = self.cfg.n_revisions if self.cfg.n_revisions and self.cfg.n_revisions > 0 else 1
+         # Build a list of keys for every passage/revision pair. Keys
+         # encode the row index and zero-based revision index.
+         all_keys: List[Tuple[int, int]] = [
+             (row_idx, rev_idx)
+             for row_idx in range(len(original_texts))
+             for rev_idx in range(n_revs)
+         ]
+
+         # We'll use this list to track which keys still require
+         # validation in each round. Initially, all keys are awaiting
+         # generation and validation.
+         to_check: List[Tuple[int, int]] = list(all_keys)
+         round_number = 0
+
+         # Create the classifier configuration once. A dedicated
+         # validation directory is used to store classification results.
+         validation_dir = os.path.join(self.cfg.save_dir, "validation")
+         os.makedirs(validation_dir, exist_ok=True)
+         # A single label is used to indicate whether the instructions
+         # were followed. The definition is intentionally phrased in a
+         # slightly more permissive way than before to reduce the false
+         # rejection rate.
+         labels = {
+             "instructions_followed": (
+                 "Return True if the instructions were largely (even if not perfectly) followed in turning the "
+                 "original text into the modified text (i.e. the modified text mostly exhibits the spirit of the instructions, "
+                 "even if not everything is exact). Be quite forgiving; understand that the modifications won't be perfect. "
+                 "Ensure the spirit of the instructions is followed, even if not word for word. "
+                 "Return False otherwise, if there are still important shortcomings in the modified text vis-à-vis the instructions."
+             )
+         }
+         classify_cfg = ClassifyConfig(
+             labels=labels,
+             save_dir=validation_dir,
+             model=self.cfg.model,
+             n_parallels=self.cfg.n_parallels,
+             n_runs=1,
+             use_dummy=self.cfg.use_dummy,
+             reasoning_effort=self.cfg.reasoning_effort,
+             reasoning_summary=self.cfg.reasoning_summary,
+         )
+         classifier = Classify(classify_cfg)
+
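+         # Assumption made by the loop below: ``classifier.run`` returns a
+         # DataFrame with one column per label (here ``instructions_followed``)
+         # holding True, False, or None, where None marks an uncertain verdict.
+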
+         # Continue looping until there are no passages left to validate
+         # or until we hit the round limit.
+         last_candidate_map: Dict[Tuple[int, int], List[str]] = {}
+         while to_check and round_number < max_rounds:
+             # Determine how many candidate paraphrases to generate per
+             # passage for this round. The first round uses
+             # ``n_initial_candidates``; later rounds scale
+             # ``n_validation_candidates`` by the round number.
+             if round_number == 0:
+                 candidates_per_key = max(self.cfg.n_initial_candidates, 1)
+             else:
+                 candidates_per_key = max(self.cfg.n_validation_candidates * round_number, 1)
+
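+             # With the defaults (1 initial candidate, 5 validation
+             # candidates) this yields 1 candidate in round 0, then 5, 10,
+             # 15, ... in later rounds, so retries broaden the search.
+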
+             # Build paraphrase prompts for every key still requiring
+             # validation. Each key may produce multiple candidates.
+             prompts: List[str] = []
+             identifiers: List[str] = []
+             for key in to_check:
+                 row_idx, rev_idx = key
+                 # Choose the base text according to whether we reuse
+                 # modified text in later rounds. On the first round we
+                 # always use the original text. On subsequent rounds,
+                 # if ``use_modified_source`` is true and a paraphrase
+                 # exists for this key, use that paraphrase as the base
+                 # for regeneration. Otherwise, continue to use the
+                 # original.
+                 if round_number > 0 and self.cfg.use_modified_source and key in resp_map:
+                     base_text = resp_map[key]
+                 else:
+                     base_text = original_texts[row_idx]
+                 for cand_idx in range(candidates_per_key):
+                     prompts.append(
+                         self.template.render(text=base_text, instructions=self.cfg.instructions)
+                     )
+                     # Encode row, revision, round and candidate index in
+                     # the identifier. Revision numbers are stored one-
+                     # based in the identifier for backwards compatibility.
+                     identifiers.append(
+                         f"row_{row_idx}_rev{rev_idx + 1}_round{round_number}_cand{cand_idx}"
+                     )
+
+             # Defensive: the ``max(..., 1)`` guards above mean
+             # ``candidates_per_key`` is never zero, but if no prompts
+             # were constructed, break to avoid an infinite loop.
+             if not prompts:
+                 break
+
+             announce_prompt_rendering("Paraphrase:validate", len(prompts))
+
+             # Send the prompts to the paraphrasing API. We construct a
+             # unique filename for each round to preserve intermediate
+             # results. The responses are appended to any existing
+             # files unless ``reset_files`` is true.
+             tmp_save_path = os.path.join(
+                 self.cfg.save_dir,
+                 f"{os.path.splitext(self.cfg.file_name)[0]}_round{round_number}.csv",
+             )
+             new_resp_df = await get_all_responses(
+                 prompts=prompts,
+                 identifiers=identifiers,
+                 save_path=tmp_save_path,
+                 model=self.cfg.model,
+                 json_mode=self.cfg.json_mode,
+                 web_search=self.cfg.web_search,
+                 n_parallels=self.cfg.n_parallels,
+                 use_dummy=self.cfg.use_dummy,
+                 reset_files=reset_files,
+                 reasoning_effort=self.cfg.reasoning_effort,
+                 reasoning_summary=self.cfg.reasoning_summary,
+             )
+
+             # Organise API responses by (row, revision) so that
+             # classification can be performed in bulk. Each key
+             # corresponds to a list of candidate paraphrases.
+             candidate_map: Dict[Tuple[int, int], List[str]] = {
+                 key: [] for key in to_check
+             }
+             for ident, resp in zip(new_resp_df["Identifier"], new_resp_df["Response"]):
+                 text = resp[0] if isinstance(resp, list) and resp else resp
+                 # Parse the identifier back into row and revision indices.
+                 m = re.match(r"row_(\d+)_rev(\d+)_round\d+_cand(\d+)", ident)
+                 if m:
+                     row_i = int(m.group(1))
+                     rev_i = int(m.group(2)) - 1
+                     candidate_map.setdefault((row_i, rev_i), []).append(text)
+
+             last_candidate_map = candidate_map
+
+             # Build classification prompts for every candidate across all
+             # keys. We keep a parallel list of (row, revision, candidate
+             # index) so we can map results back after classification.
+             cand_prompts: List[str] = []
+             cand_keys: List[Tuple[int, int, int]] = []
+             for key in to_check:
+                 row_idx, rev_idx = key
+                 orig_text = original_texts[row_idx]
+                 candidates = candidate_map.get(key, [])
+                 for cand_index, cand_text in enumerate(candidates):
+                     cand_text = cand_text or ""
+                     cls_prompt = (
+                         "INSTRUCTIONS:\n"
+                         f"{self.cfg.instructions.strip()}\n\n"
+                         "BEGIN ORIGINAL TEXT:\n"
+                         f"{orig_text.strip()}\n"
+                         "END ORIGINAL TEXT\n\n"
+                         "BEGIN MODIFIED TEXT:\n"
+                         f"{cand_text.strip()}\n"
+                         "END MODIFIED TEXT\n\n"
+                         "Previously, the original text was taken and modified, following the provided instructions, to create the modified text. "
+                         "Does the modified text faithfully apply the instructions as a transformation of "
+                         "the original text? Answer True if the modification follows "
+                         "the instructions to a satisfactory, though not necessarily perfect, degree. "
+                         "Tolerate some imperfection and inconsistency, as long as the spirit of the instructions is obeyed to the extent that is reasonable. Be reasonably forgiving. "
+                         "Again, the modification instructions that need to be validated are: "
+                         f"{self.cfg.instructions.strip()}\n\n"
+                     )
+                     cand_prompts.append(cls_prompt)
+                     cand_keys.append((row_idx, rev_idx, cand_index))
+
+             # Run the classifier on all candidate prompts. If there are
+             # no candidates (which should not occur), produce an empty
+             # DataFrame to avoid indexing errors.
+             if cand_prompts:
+                 cand_df = pd.DataFrame({"text": cand_prompts})
+                 cand_res_df = await classifier.run(
+                     cand_df, column_name="text", reset_files=reset_files
+                 )
+             else:
+                 cand_res_df = pd.DataFrame()
+
+             # Build a lookup from (row, revision) to a list of boolean
+             # classification results corresponding to each candidate. A
+             # candidate passes if the classifier returns True or None. A
+             # None value indicates uncertainty but is treated as a pass
+             # here to reduce the false rejection rate.
+             cand_results_map: Dict[Tuple[int, int], List[bool]] = {
+                 key: [False] * len(candidate_map.get(key, [])) for key in to_check
+             }
+             for idx, (row_idx, rev_idx, cand_index) in enumerate(cand_keys):
+                 if not cand_res_df.empty:
+                     flag = cand_res_df.loc[idx, "instructions_followed"]
+                 else:
+                     flag = None
+                 # Treat None (uncertain) as a pass and only count False
+                 # values as failures.
+                 cand_results_map[(row_idx, rev_idx)][cand_index] = bool(flag) or flag is None
+
+             # Determine which passages still require another round and
+             # select the best candidate for each key. For each key, the
+             # first passing candidate is chosen. If no candidate passed,
+             # the first candidate is chosen as a fallback and the key
+             # remains scheduled for another round.
+             next_to_check: List[Tuple[int, int]] = []
+             for key in to_check:
+                 row_idx, rev_idx = key
+                 candidates = candidate_map.get(key, [])
+                 results = cand_results_map.get(key, [])
+                 chosen_text: Optional[str] = None
+                 passed_flag = False
+                 for cand_text, passed in zip(candidates, results):
+                     if passed:
+                         chosen_text = cand_text
+                         passed_flag = True
+                         break
+                 # If no candidate passed and at least one candidate
+                 # exists, choose the first candidate as a fallback.
+                 if chosen_text is None and candidates:
+                     chosen_text = candidates[0]
+                     passed_flag = False
+                 # Update the response map with the chosen paraphrase.
+                 if chosen_text is not None:
+                     resp_map[(row_idx, rev_idx)] = chosen_text
+                     approval_map[(row_idx, rev_idx)] = passed_flag
+                     # If the candidate did not pass validation, schedule
+                     # another round.
+                     if not passed_flag:
+                         next_to_check.append(key)
+                 else:
+                     # If no candidates were produced (which should not
+                     # happen), keep the key for another round to avoid
+                     # losing the entry entirely.
+                     next_to_check.append(key)
+
+             # Prepare for the next round.
+             to_check = next_to_check
+             round_number += 1
+
+         # If we exited because we hit the round limit, ensure all
+         # remaining keys are assigned a paraphrase and marked as not
+         # approved. When candidates were produced in the last round, we
+         # randomly choose one of them as the final fallback.
+         if to_check:
+             for key in to_check:
+                 candidates = last_candidate_map.get(key, [])
+                 if candidates:
+                     resp_map[key] = random.choice(candidates)
+                     approval_map[key] = False
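
The task can be exercised end to end without network access via the dummy mode visible in the diff. A minimal sketch, assuming only the public names added above (ParaphraseConfig, Paraphrase) and an installed openai-gabriel; the instructions and column values are illustrative:

    import asyncio

    import pandas as pd

    from gabriel.tasks.paraphrase import Paraphrase, ParaphraseConfig

    # Dummy mode exercises prompt rendering, identifier bookkeeping and the
    # validation loop without real API calls.
    cfg = ParaphraseConfig(
        instructions="Rewrite each passage in formal English.",
        n_rounds=2,          # one generation pass plus one validation round
        use_dummy=True,      # no network traffic
        save_dir="paraphrase_demo",
    )

    df = pd.DataFrame({"text": ["hey there", "gotta run"]})
    out = asyncio.run(Paraphrase(cfg).run(df, "text"))
    print(out[["text", "text_revised", "text_revised_approved"]])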