codebook-lab 1.2.0__tar.gz → 1.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/PKG-INFO +1 -1
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/codebook_lab/annotate.py +323 -48
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/codebook_lab/conditions.py +16 -3
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/codebook_lab/experiments.py +11 -1
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/codebook_lab/metrics.py +317 -278
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/codebook_lab/prompts.py +4 -1
- codebook_lab-1.3.0/codebook_lab/span_metrics.py +236 -0
- codebook_lab-1.3.0/codebook_lab/span_value.py +108 -0
- codebook_lab-1.3.0/codebook_lab/tasks/discrete-emotions/codebook.json +52 -0
- codebook_lab-1.3.0/codebook_lab/tasks/discrete-emotions/ground-truth.csv +7 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/codebook_lab/types.py +8 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/codebook_lab.egg-info/PKG-INFO +1 -1
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/codebook_lab.egg-info/SOURCES.txt +11 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/pyproject.toml +1 -1
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/tests/test_conditions.py +6 -6
- codebook_lab-1.3.0/tests/test_example_tasks.py +78 -0
- codebook_lab-1.3.0/tests/test_invalid_response_handling.py +140 -0
- codebook_lab-1.3.0/tests/test_metrics_recording.py +151 -0
- codebook_lab-1.3.0/tests/test_metrics_span_integration.py +104 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/tests/test_prompts.py +1 -0
- codebook_lab-1.3.0/tests/test_span_extraction.py +69 -0
- codebook_lab-1.3.0/tests/test_span_metrics.py +101 -0
- codebook_lab-1.3.0/tests/test_span_value.py +115 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/LICENSE +0 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/README.md +0 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/codebook_lab/__init__.py +0 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/codebook_lab/examples.py +0 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/codebook_lab/human_reliability.py +0 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/codebook_lab/ollama.py +0 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/codebook_lab/py.typed +0 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/codebook_lab/tasks/__init__.py +0 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/codebook_lab/tasks/policy-sentiment/codebook.json +0 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/codebook_lab/tasks/policy-sentiment/ground-truth.csv +0 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/codebook_lab.egg-info/dependency_links.txt +0 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/codebook_lab.egg-info/requires.txt +0 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/codebook_lab.egg-info/top_level.txt +0 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/setup.cfg +0 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/tests/test_examples.py +0 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/tests/test_experiments.py +0 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/tests/test_human_reliability.py +0 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/tests/test_metrics_summary.py +0 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/tests/test_package_import.py +0 -0
- {codebook_lab-1.2.0 → codebook_lab-1.3.0}/tests/test_types.py +0 -0
|
@@ -3,6 +3,7 @@ import logging
|
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
import sys
|
|
5
5
|
import time
|
|
6
|
+
from typing import Any, Optional
|
|
6
7
|
|
|
7
8
|
import pandas as pd
|
|
8
9
|
import regex
|
|
@@ -18,13 +19,39 @@ from .conditions import (
|
|
|
18
19
|
normalize_annotation_response_value,
|
|
19
20
|
)
|
|
20
21
|
from .ollama import ensure_ollama_available
|
|
22
|
+
from .span_value import parse_span_value, serialize_span_value
|
|
21
23
|
|
|
22
24
|
|
|
23
25
|
class AnnotationResponse(BaseModel):
|
|
24
|
-
"""
|
|
26
|
+
"""Default schema for categorical/numeric/textbox annotation types.
|
|
27
|
+
|
|
28
|
+
Used by ChatOllama structured output to guarantee valid JSON for
|
|
29
|
+
annotation types whose payload is a single string-coercible value
|
|
30
|
+
(checkbox 0/1, likert integers, dropdown choices, textbox free text).
|
|
31
|
+
"""
|
|
25
32
|
response: str
|
|
26
33
|
|
|
27
34
|
|
|
35
|
+
class SpanItem(BaseModel):
    """A single highlighted region of text as returned by the model.

    ``start`` and ``end`` are required integers. ``text`` and ``label`` are
    optional and default to ``None`` when the model leaves them out.
    """

    # Required offsets delimiting the highlighted region.
    start: int
    end: int
    # Optional extras; a missing value is represented as None.
    text: Optional[str] = None
    label: Optional[str] = None
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class SpanAnnotationResponse(BaseModel):
    """Structured-output schema handed to ChatOllama for span annotations.

    The payload has the shape ``{"response": [...]}``, where every list entry
    is validated as a :class:`SpanItem`.
    """

    response: list[SpanItem]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _response_schema_for_type(annotation_type: str) -> type[BaseModel]:
    """Select the structured-output schema for ``annotation_type``.

    ``"span"`` maps to ``SpanAnnotationResponse``; every other value
    (including an empty string) maps to ``AnnotationResponse``.
    """
    is_span = annotation_type == "span"
    return SpanAnnotationResponse if is_span else AnnotationResponse
|
|
53
|
+
|
|
54
|
+
|
|
28
55
|
_PROMPT_TEMPLATE = ChatPromptTemplate.from_template("""{question}""")
|
|
29
56
|
from .prompts import PromptContext, get_prompt_type_name, render_prompt
|
|
30
57
|
from .types import AnnotationRunResult
|
|
@@ -77,11 +104,14 @@ class _AnnotationProgressBar:
|
|
|
77
104
|
self.total_steps = max(self.completed_steps, self.total_steps - count)
|
|
78
105
|
|
|
79
106
|
|
|
80
|
-
def _count_annotations(codebook, process_textbox=False):
|
|
107
|
+
def _count_annotations(codebook, process_textbox=False, process_span=False):
|
|
81
108
|
"""Count the maximum number of annotation prompts that could be issued for one row."""
|
|
82
109
|
count = 0
|
|
83
110
|
for _, _, _, annotation in get_annotation_entries(codebook):
|
|
84
|
-
|
|
111
|
+
ann_type = annotation.get("type")
|
|
112
|
+
if ann_type == "textbox" and not process_textbox:
|
|
113
|
+
continue
|
|
114
|
+
if ann_type == "span" and not process_span:
|
|
85
115
|
continue
|
|
86
116
|
count += 1
|
|
87
117
|
return count
|
|
@@ -182,7 +212,15 @@ def setup_model(model_name, temperature=None, top_p=None):
|
|
|
182
212
|
llm = ChatOllama(model=model_name, **model_kwargs)
|
|
183
213
|
return llm
|
|
184
214
|
|
|
185
|
-
def generate_response(
|
|
215
|
+
def generate_response(
|
|
216
|
+
chain,
|
|
217
|
+
prompt,
|
|
218
|
+
char_counts,
|
|
219
|
+
timing_data,
|
|
220
|
+
row_num=None,
|
|
221
|
+
annotation_name=None,
|
|
222
|
+
annotation_type=None,
|
|
223
|
+
):
|
|
186
224
|
"""Run one prompt through the model and update timing/count statistics.
|
|
187
225
|
|
|
188
226
|
Args:
|
|
@@ -192,10 +230,14 @@ def generate_response(chain, prompt, char_counts, timing_data, row_num=None, ann
|
|
|
192
230
|
timing_data: Mutable dict with inference timing counters.
|
|
193
231
|
row_num: Optional 1-based row number for progress logging.
|
|
194
232
|
annotation_name: Optional annotation label for progress logging.
|
|
233
|
+
annotation_type: Annotation type string used to pick the structured
|
|
234
|
+
output schema (``"span"`` uses ``SpanAnnotationResponse``; everything
|
|
235
|
+
else uses ``AnnotationResponse``).
|
|
195
236
|
|
|
196
237
|
Returns:
|
|
197
238
|
Raw model response string, or ``""`` if inference failed.
|
|
198
239
|
"""
|
|
240
|
+
response_schema = _response_schema_for_type(annotation_type or "")
|
|
199
241
|
try:
|
|
200
242
|
# Track input characters
|
|
201
243
|
char_counts['input_chars'] += len(prompt)
|
|
@@ -206,7 +248,7 @@ def generate_response(chain, prompt, char_counts, timing_data, row_num=None, ann
|
|
|
206
248
|
structured_chain = (
|
|
207
249
|
_PROMPT_TEMPLATE
|
|
208
250
|
| chain.with_structured_output(
|
|
209
|
-
|
|
251
|
+
response_schema, method="json_schema", include_raw=True
|
|
210
252
|
)
|
|
211
253
|
)
|
|
212
254
|
|
|
@@ -234,20 +276,94 @@ def generate_response(chain, prompt, char_counts, timing_data, row_num=None, ann
|
|
|
234
276
|
logger.warning("Error generating response: %s", e)
|
|
235
277
|
return ""
|
|
236
278
|
|
|
237
|
-
def
|
|
279
|
+
def _extract_span_response(response, label_options=None, text=None):
|
|
280
|
+
"""Parse a model response into a normalised list of span dicts.
|
|
281
|
+
|
|
282
|
+
Drops spans with missing/invalid offsets, out-of-range offsets, or labels
|
|
283
|
+
outside ``label_options`` (when provided). When ``text`` is available, the
|
|
284
|
+
``text`` field is filled from the offsets to keep the cell self-describing
|
|
285
|
+
even if the model omitted it.
|
|
286
|
+
"""
|
|
287
|
+
pattern = regex.compile(r'\{(?:[^{}]|(?R))*\}')
|
|
288
|
+
array_pattern = regex.compile(r'\[(?:[^\[\]]|(?R))*\]')
|
|
289
|
+
|
|
290
|
+
parsed_value = None
|
|
291
|
+
for json_string in array_pattern.findall(response):
|
|
292
|
+
try:
|
|
293
|
+
candidate = json.loads(json_string)
|
|
294
|
+
except json.JSONDecodeError:
|
|
295
|
+
continue
|
|
296
|
+
if isinstance(candidate, list):
|
|
297
|
+
parsed_value = candidate
|
|
298
|
+
break
|
|
299
|
+
|
|
300
|
+
if parsed_value is None:
|
|
301
|
+
for json_string in pattern.findall(response):
|
|
302
|
+
try:
|
|
303
|
+
candidate = json.loads(json_string)
|
|
304
|
+
except json.JSONDecodeError:
|
|
305
|
+
continue
|
|
306
|
+
if isinstance(candidate, dict) and isinstance(candidate.get("response"), list):
|
|
307
|
+
parsed_value = candidate["response"]
|
|
308
|
+
break
|
|
309
|
+
|
|
310
|
+
if not isinstance(parsed_value, list):
|
|
311
|
+
# No JSON array / {"response": [...]} structure was found at all: treat
|
|
312
|
+
# this as an invalid response (None) so callers can retry. An empty but
|
|
313
|
+
# successfully parsed list is a valid answer ("no spans apply") and is
|
|
314
|
+
# returned as [] by the cleaning loop below.
|
|
315
|
+
return None
|
|
316
|
+
|
|
317
|
+
text_length = len(text) if isinstance(text, str) else None
|
|
318
|
+
allowed_labels = (
|
|
319
|
+
{str(opt) for opt in label_options} if label_options else None
|
|
320
|
+
)
|
|
321
|
+
|
|
322
|
+
cleaned = []
|
|
323
|
+
for entry in parsed_value:
|
|
324
|
+
if not isinstance(entry, dict):
|
|
325
|
+
continue
|
|
326
|
+
try:
|
|
327
|
+
start = int(entry["start"])
|
|
328
|
+
end = int(entry["end"])
|
|
329
|
+
except (KeyError, TypeError, ValueError):
|
|
330
|
+
continue
|
|
331
|
+
if end <= start or start < 0:
|
|
332
|
+
continue
|
|
333
|
+
if text_length is not None and end > text_length:
|
|
334
|
+
continue
|
|
335
|
+
|
|
336
|
+
item = {"start": start, "end": end}
|
|
337
|
+
item["text"] = text[start:end] if text_length is not None else str(entry.get("text") or "")
|
|
338
|
+
label = entry.get("label")
|
|
339
|
+
if label:
|
|
340
|
+
label = str(label)
|
|
341
|
+
if allowed_labels is None or label in allowed_labels:
|
|
342
|
+
item["label"] = label
|
|
343
|
+
cleaned.append(item)
|
|
344
|
+
return cleaned
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def extract_json_response(response, annotation_type, min_value=None, max_value=None, options=None,
|
|
348
|
+
label_options=None, text=None):
|
|
238
349
|
"""
|
|
239
350
|
Extract and validate JSON response based on annotation type
|
|
240
|
-
|
|
351
|
+
|
|
241
352
|
Args:
|
|
242
353
|
response: Raw model response text that should contain a JSON object.
|
|
243
354
|
annotation_type: Annotation type string such as ``"dropdown"`` or ``"likert"``.
|
|
244
355
|
min_value: Optional integer lower bound for Likert annotations.
|
|
245
356
|
max_value: Optional integer upper bound for Likert annotations.
|
|
246
357
|
options: Optional dropdown option list used to normalize categorical labels.
|
|
358
|
+
label_options: Allowed labels for span annotations.
|
|
359
|
+
text: Source text for span annotations (used to validate offsets).
|
|
247
360
|
|
|
248
361
|
Returns:
|
|
249
|
-
Parsed response value coerced into the expected annotation format.
|
|
362
|
+
Parsed response value coerced into the expected annotation format. For
|
|
363
|
+
``annotation_type == "span"`` this is a list of span dicts.
|
|
250
364
|
"""
|
|
365
|
+
if annotation_type == "span":
|
|
366
|
+
return _extract_span_response(response, label_options=label_options, text=text)
|
|
251
367
|
pattern = regex.compile(r'\{(?:[^{}]|(?R))*\}')
|
|
252
368
|
json_strings = pattern.findall(response)
|
|
253
369
|
|
|
@@ -279,11 +395,13 @@ def extract_json_response(response, annotation_type, min_value=None, max_value=N
|
|
|
279
395
|
return 1
|
|
280
396
|
elif response_value.lower() in ["no", "false", "0"]:
|
|
281
397
|
return 0
|
|
282
|
-
#
|
|
283
|
-
|
|
398
|
+
# No recognizable boolean value: invalid, so callers can
|
|
399
|
+
# retry/record null rather than silently defaulting to "No".
|
|
400
|
+
return None
|
|
284
401
|
elif annotation_type == "textbox":
|
|
285
|
-
#
|
|
286
|
-
|
|
402
|
+
# Empty text counts as no answer (invalid -> retry/null).
|
|
403
|
+
stripped = str(response_value).strip()
|
|
404
|
+
return stripped or None
|
|
287
405
|
elif annotation_type == "likert":
|
|
288
406
|
# Validate is within range and convert to int
|
|
289
407
|
try:
|
|
@@ -292,10 +410,9 @@ def extract_json_response(response, annotation_type, min_value=None, max_value=N
|
|
|
292
410
|
return max(min_value, min(max_value, value)) # Clamp to range
|
|
293
411
|
return value
|
|
294
412
|
except (ValueError, TypeError):
|
|
295
|
-
#
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
return response_value
|
|
413
|
+
# Not a valid number: invalid, so callers can retry/record
|
|
414
|
+
# null rather than silently defaulting to the scale midpoint.
|
|
415
|
+
return None
|
|
299
416
|
|
|
300
417
|
# Fallback
|
|
301
418
|
return str(response_value).strip() if isinstance(response_value, str) else response_value
|
|
@@ -312,7 +429,7 @@ def extract_json_response(response, annotation_type, min_value=None, max_value=N
|
|
|
312
429
|
return 1
|
|
313
430
|
elif "no" in response.lower() or "false" in response.lower():
|
|
314
431
|
return 0
|
|
315
|
-
return
|
|
432
|
+
return None
|
|
316
433
|
elif annotation_type == "likert" and min_value is not None and max_value is not None:
|
|
317
434
|
# Try to find a number in the response
|
|
318
435
|
numbers = regex.findall(r'\d+', response)
|
|
@@ -323,24 +440,25 @@ def extract_json_response(response, annotation_type, min_value=None, max_value=N
|
|
|
323
440
|
return value
|
|
324
441
|
except ValueError:
|
|
325
442
|
continue
|
|
326
|
-
return
|
|
443
|
+
return None # No in-range number found: invalid -> retry/null
|
|
327
444
|
elif annotation_type == "textbox":
|
|
328
|
-
return stripped_response
|
|
329
|
-
|
|
445
|
+
return stripped_response or None
|
|
446
|
+
|
|
330
447
|
return None
|
|
331
448
|
|
|
332
|
-
def format_prompt(section_name, section_instruction, name, tooltip, annotation_type,
|
|
333
|
-
options=None, min_value=None, max_value=None, example=None,
|
|
334
|
-
text=None, prompt_type="standard", use_examples=False
|
|
449
|
+
def format_prompt(section_name, section_instruction, name, tooltip, annotation_type,
|
|
450
|
+
options=None, min_value=None, max_value=None, example=None,
|
|
451
|
+
text=None, prompt_type="standard", use_examples=False,
|
|
452
|
+
label_options=None):
|
|
335
453
|
"""
|
|
336
454
|
Format the prompt based on annotation type and specified prompt type
|
|
337
|
-
|
|
455
|
+
|
|
338
456
|
Args:
|
|
339
457
|
section_name: Codebook section name.
|
|
340
458
|
section_instruction: Optional section-level instructions.
|
|
341
459
|
name: Annotation name within the section.
|
|
342
460
|
tooltip: Optional guidance text for the annotation.
|
|
343
|
-
annotation_type: One of ``"dropdown"``, ``"checkbox"``, ``"likert"``, or ``"
|
|
461
|
+
annotation_type: One of ``"dropdown"``, ``"checkbox"``, ``"likert"``, ``"textbox"``, or ``"span"``.
|
|
344
462
|
options: Dropdown option list when applicable.
|
|
345
463
|
min_value: Minimum Likert value when applicable.
|
|
346
464
|
max_value: Maximum Likert value when applicable.
|
|
@@ -348,21 +466,22 @@ def format_prompt(section_name, section_instruction, name, tooltip, annotation_t
|
|
|
348
466
|
text: Raw source text being annotated.
|
|
349
467
|
prompt_type: Registered prompt wrapper name or callable wrapper.
|
|
350
468
|
use_examples: Whether examples should be included in the prompt.
|
|
469
|
+
label_options: Allowed labels for span annotations. Ignored for other types.
|
|
351
470
|
|
|
352
471
|
Returns:
|
|
353
472
|
Full prompt string ready to send to the model.
|
|
354
473
|
"""
|
|
355
474
|
# Get response instructions based on annotation type
|
|
356
475
|
response_instructions = _get_response_instructions(
|
|
357
|
-
annotation_type, options, min_value, max_value
|
|
476
|
+
annotation_type, options, min_value, max_value, label_options=label_options
|
|
358
477
|
)
|
|
359
|
-
|
|
478
|
+
|
|
360
479
|
# Build the core prompt that's common to all prompt types
|
|
361
480
|
core_prompt = _build_core_prompt(
|
|
362
|
-
section_name, section_instruction, name, tooltip,
|
|
481
|
+
section_name, section_instruction, name, tooltip,
|
|
363
482
|
response_instructions, example, use_examples
|
|
364
483
|
)
|
|
365
|
-
|
|
484
|
+
|
|
366
485
|
context = PromptContext(
|
|
367
486
|
section_name=section_name,
|
|
368
487
|
section_instruction=section_instruction,
|
|
@@ -372,6 +491,7 @@ def format_prompt(section_name, section_instruction, name, tooltip, annotation_t
|
|
|
372
491
|
options=options,
|
|
373
492
|
min_value=min_value,
|
|
374
493
|
max_value=max_value,
|
|
494
|
+
label_options=label_options,
|
|
375
495
|
example=example or "",
|
|
376
496
|
text=text or "",
|
|
377
497
|
use_examples=use_examples,
|
|
@@ -381,7 +501,13 @@ def format_prompt(section_name, section_instruction, name, tooltip, annotation_t
|
|
|
381
501
|
return render_prompt(prompt_type, context)
|
|
382
502
|
|
|
383
503
|
|
|
384
|
-
def _get_response_instructions(
|
|
504
|
+
def _get_response_instructions(
|
|
505
|
+
annotation_type,
|
|
506
|
+
options=None,
|
|
507
|
+
min_value=None,
|
|
508
|
+
max_value=None,
|
|
509
|
+
label_options=None,
|
|
510
|
+
):
|
|
385
511
|
"""Generate type-specific response instructions for a prompt.
|
|
386
512
|
|
|
387
513
|
Args:
|
|
@@ -389,6 +515,8 @@ def _get_response_instructions(annotation_type, options=None, min_value=None, ma
|
|
|
389
515
|
options: Dropdown options when ``annotation_type`` is ``"dropdown"``.
|
|
390
516
|
min_value: Likert minimum when applicable.
|
|
391
517
|
max_value: Likert maximum when applicable.
|
|
518
|
+
label_options: Allowed labels when ``annotation_type`` is ``"span"`` and
|
|
519
|
+
the annotation is labelled. ``None`` or empty for plain highlights.
|
|
392
520
|
|
|
393
521
|
Returns:
|
|
394
522
|
Instruction string describing the expected response format.
|
|
@@ -402,6 +530,22 @@ def _get_response_instructions(annotation_type, options=None, min_value=None, ma
|
|
|
402
530
|
return f"Respond with a whole number from {min_value} to {max_value} (inclusive), where {min_value} means lowest and {max_value} means highest."
|
|
403
531
|
elif annotation_type == "textbox":
|
|
404
532
|
return "Respond with a brief text explanation."
|
|
533
|
+
elif annotation_type == "span":
|
|
534
|
+
if label_options:
|
|
535
|
+
labels_str = ', or '.join(f'"{option}"' for option in label_options)
|
|
536
|
+
return (
|
|
537
|
+
"Respond with a JSON array of objects, each shaped like "
|
|
538
|
+
'{"start": <int>, "end": <int>, "text": "<quoted span>", '
|
|
539
|
+
f'"label": <one of {labels_str}>}}. '
|
|
540
|
+
"Use 0-indexed character offsets into the text. "
|
|
541
|
+
"Return [] if no spans apply."
|
|
542
|
+
)
|
|
543
|
+
return (
|
|
544
|
+
"Respond with a JSON array of objects, each shaped like "
|
|
545
|
+
'{"start": <int>, "end": <int>, "text": "<quoted span>"}. '
|
|
546
|
+
"Use 0-indexed character offsets into the text. "
|
|
547
|
+
"Return [] if no spans apply."
|
|
548
|
+
)
|
|
405
549
|
return ""
|
|
406
550
|
|
|
407
551
|
|
|
@@ -473,9 +617,99 @@ def _normalize_optional_parameter(value):
|
|
|
473
617
|
return None
|
|
474
618
|
return value
|
|
475
619
|
|
|
620
|
+
RETRY_STRATEGIES = ("identical", "reprompt", "temperature")
|
|
621
|
+
DEFAULT_RETRY_TEMPERATURE = 0.3
|
|
622
|
+
_RETRY_REMINDER = (
|
|
623
|
+
"\n\nIMPORTANT: A previous attempt could not be parsed. Respond with ONLY the "
|
|
624
|
+
"JSON described above, in exactly that format, with no extra commentary."
|
|
625
|
+
)
|
|
626
|
+
|
|
627
|
+
|
|
628
|
+
def normalize_retry_strategy(strategy):
    """Map ``strategy`` onto one of ``RETRY_STRATEGIES``.

    The value is stringified, stripped and lower-cased before the lookup.
    Falsy input (``None``, ``""``) and unrecognised names both resolve to
    ``"identical"``.
    """
    if not strategy:
        return "identical"
    candidate = str(strategy).strip().lower()
    if candidate in RETRY_STRATEGIES:
        return candidate
    return "identical"
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
def _generate_and_extract(
    *,
    chain,
    retry_chain,
    prompt,
    char_counts,
    timing_data,
    row_num,
    annotation_full_name,
    annotation_type,
    min_value,
    max_value,
    options,
    label_options,
    text,
    retries,
    retry_strategy,
):
    """Ask the model for one annotation, re-asking while the answer is invalid.

    An answer counts as invalid when :func:`extract_json_response` returns
    ``None`` for it. The first attempt always uses ``chain`` and the
    unmodified ``prompt``; later attempts depend on ``retry_strategy``:

    * ``"identical"`` (default): same prompt, same model.
    * ``"reprompt"``: the prompt with ``_RETRY_REMINDER`` appended.
    * ``"temperature"``: the same prompt sent to ``retry_chain`` when one was
      supplied; otherwise it behaves like ``"identical"``.

    At least one attempt is always made; ``retries`` adds further attempts.

    Returns:
        The first extracted value that is not ``None``, or ``None`` once
        every attempt has produced an invalid answer.
    """
    strategy = normalize_retry_strategy(retry_strategy)
    attempts = max(1, 1 + int(retries))

    for attempt_number in range(1, attempts + 1):
        is_retry = attempt_number > 1

        # Retry-only adjustments; the two strategies are mutually exclusive.
        if is_retry and strategy == "reprompt":
            active_prompt = prompt + _RETRY_REMINDER
        else:
            active_prompt = prompt
        swap_model = (
            is_retry and strategy == "temperature" and retry_chain is not None
        )
        active_chain = retry_chain if swap_model else chain

        raw_answer = generate_response(
            active_chain,
            active_prompt,
            char_counts,
            timing_data,
            row_num=row_num,
            annotation_name=annotation_full_name,
            annotation_type=annotation_type,
        )
        extracted = extract_json_response(
            raw_answer,
            annotation_type,
            min_value,
            max_value,
            options=options,
            label_options=label_options,
            text=text,
        )
        if extracted is not None:
            return extracted

        if attempt_number < attempts:
            logger.info(
                "Invalid response for %s (attempt %d/%d); retrying with strategy '%s'.",
                annotation_full_name, attempt_number, attempts, strategy,
            )

    logger.warning(
        "No valid response for %s after %d attempt(s); recording null.",
        annotation_full_name, attempts,
    )
    return None
|
|
707
|
+
|
|
708
|
+
|
|
476
709
|
def classify_text(chain, text, codebook, prompt_type="standard", use_examples=False,
|
|
477
710
|
char_counts=None, timing_data=None, process_textbox=False, row_num=None,
|
|
478
|
-
progress_bar=None, total_rows=None
|
|
711
|
+
progress_bar=None, total_rows=None, process_span=False,
|
|
712
|
+
retries=1, retry_strategy="identical", retry_chain=None):
|
|
479
713
|
"""Annotate one text row across all sections in a codebook.
|
|
480
714
|
|
|
481
715
|
Args:
|
|
@@ -517,6 +751,11 @@ def classify_text(chain, text, codebook, prompt_type="standard", use_examples=Fa
|
|
|
517
751
|
progress_bar.skip()
|
|
518
752
|
continue
|
|
519
753
|
|
|
754
|
+
if annotation_type == "span" and not process_span:
|
|
755
|
+
if progress_bar is not None:
|
|
756
|
+
progress_bar.skip()
|
|
757
|
+
continue
|
|
758
|
+
|
|
520
759
|
if not is_annotation_applicable(codebook, section_key, annotation_key, responses):
|
|
521
760
|
responses[column_name] = None
|
|
522
761
|
if progress_bar is not None:
|
|
@@ -529,12 +768,15 @@ def classify_text(chain, text, codebook, prompt_type="standard", use_examples=Fa
|
|
|
529
768
|
options = None
|
|
530
769
|
min_value = None
|
|
531
770
|
max_value = None
|
|
771
|
+
label_options = None
|
|
532
772
|
|
|
533
773
|
if annotation_type == "dropdown":
|
|
534
774
|
options = annotation.get('options', [])
|
|
535
775
|
elif annotation_type == "likert":
|
|
536
776
|
min_value = annotation.get('min_value')
|
|
537
777
|
max_value = annotation.get('max_value')
|
|
778
|
+
elif annotation_type == "span":
|
|
779
|
+
label_options = annotation.get('label_options', []) or None
|
|
538
780
|
|
|
539
781
|
prompt = format_prompt(
|
|
540
782
|
section_name,
|
|
@@ -548,34 +790,46 @@ def classify_text(chain, text, codebook, prompt_type="standard", use_examples=Fa
|
|
|
548
790
|
example,
|
|
549
791
|
text,
|
|
550
792
|
prompt_type=prompt_type,
|
|
551
|
-
use_examples=use_examples
|
|
793
|
+
use_examples=use_examples,
|
|
794
|
+
label_options=label_options,
|
|
552
795
|
)
|
|
553
796
|
|
|
554
|
-
|
|
555
|
-
chain,
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
797
|
+
response_value = _generate_and_extract(
|
|
798
|
+
chain=chain,
|
|
799
|
+
retry_chain=retry_chain,
|
|
800
|
+
prompt=prompt,
|
|
801
|
+
char_counts=char_counts,
|
|
802
|
+
timing_data=timing_data,
|
|
559
803
|
row_num=row_num,
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
annotation_type,
|
|
565
|
-
min_value,
|
|
566
|
-
max_value,
|
|
804
|
+
annotation_full_name=annotation_full_name,
|
|
805
|
+
annotation_type=annotation_type,
|
|
806
|
+
min_value=min_value,
|
|
807
|
+
max_value=max_value,
|
|
567
808
|
options=options,
|
|
809
|
+
label_options=label_options,
|
|
810
|
+
text=text,
|
|
811
|
+
retries=retries,
|
|
812
|
+
retry_strategy=retry_strategy,
|
|
568
813
|
)
|
|
569
814
|
|
|
570
|
-
|
|
815
|
+
if annotation_type == "span":
|
|
816
|
+
# Spans round-trip through CSV as JSON-encoded strings so the file
|
|
817
|
+
# survives standard CSV tooling (the Studio annotation page uses the
|
|
818
|
+
# same convention). A None result (no valid response) serializes to "".
|
|
819
|
+
responses[column_name] = serialize_span_value(response_value)
|
|
820
|
+
else:
|
|
821
|
+
# response_value is None when no valid response was extracted, which
|
|
822
|
+
# is stored as a blank cell rather than a fabricated default.
|
|
823
|
+
responses[column_name] = response_value
|
|
571
824
|
|
|
572
825
|
if progress_bar is not None and row_num is not None and total_rows is not None:
|
|
573
826
|
progress_bar.update(row_num, total_rows, annotation_full_name)
|
|
574
827
|
|
|
575
828
|
return responses, char_counts, timing_data
|
|
576
829
|
|
|
577
|
-
def apply_classification_to_csv(csv_path, output_path, codebook, chain, prompt_type="standard",
|
|
578
|
-
use_examples=False, process_textbox=False
|
|
830
|
+
def apply_classification_to_csv(csv_path, output_path, codebook, chain, prompt_type="standard",
|
|
831
|
+
use_examples=False, process_textbox=False, process_span=False,
|
|
832
|
+
retries=1, retry_strategy="identical", retry_chain=None):
|
|
579
833
|
"""Run annotation over every row in an input CSV and write incremental results.
|
|
580
834
|
|
|
581
835
|
Args:
|
|
@@ -594,7 +848,7 @@ def apply_classification_to_csv(csv_path, output_path, codebook, chain, prompt_t
|
|
|
594
848
|
|
|
595
849
|
logger.info("Starting classification of %d rows", len(df))
|
|
596
850
|
|
|
597
|
-
annotations_per_row = _count_annotations(codebook, process_textbox)
|
|
851
|
+
annotations_per_row = _count_annotations(codebook, process_textbox, process_span)
|
|
598
852
|
total_steps = len(df) * annotations_per_row
|
|
599
853
|
progress_bar = _AnnotationProgressBar(total_steps)
|
|
600
854
|
|
|
@@ -627,6 +881,10 @@ def apply_classification_to_csv(csv_path, output_path, codebook, chain, prompt_t
|
|
|
627
881
|
row_num=row_num,
|
|
628
882
|
progress_bar=progress_bar,
|
|
629
883
|
total_rows=len(df),
|
|
884
|
+
process_span=process_span,
|
|
885
|
+
retries=retries,
|
|
886
|
+
retry_strategy=retry_strategy,
|
|
887
|
+
retry_chain=retry_chain,
|
|
630
888
|
)
|
|
631
889
|
|
|
632
890
|
# Add annotations to row data
|
|
@@ -667,8 +925,12 @@ def run_annotation(
|
|
|
667
925
|
temperature=None,
|
|
668
926
|
top_p=None,
|
|
669
927
|
process_textbox=False,
|
|
928
|
+
process_span=False,
|
|
670
929
|
country_iso_code="USA",
|
|
671
930
|
start_ollama_if_needed=True,
|
|
931
|
+
retries=1,
|
|
932
|
+
retry_strategy="identical",
|
|
933
|
+
retry_temperature=DEFAULT_RETRY_TEMPERATURE,
|
|
672
934
|
):
|
|
673
935
|
"""Run one annotation job and persist its outputs to disk.
|
|
674
936
|
|
|
@@ -710,8 +972,11 @@ def run_annotation(
|
|
|
710
972
|
"prompt_type": prompt_type_name,
|
|
711
973
|
"use_examples": bool(use_examples),
|
|
712
974
|
"process_textbox": bool(process_textbox),
|
|
975
|
+
"process_span": bool(process_span),
|
|
713
976
|
"country_iso_code": country_iso_code,
|
|
714
977
|
"task_name": task_name,
|
|
978
|
+
"retries": int(retries),
|
|
979
|
+
"retry_strategy": normalize_retry_strategy(retry_strategy),
|
|
715
980
|
}
|
|
716
981
|
if temperature is not None:
|
|
717
982
|
config["temperature"] = temperature
|
|
@@ -740,6 +1005,12 @@ def run_annotation(
|
|
|
740
1005
|
|
|
741
1006
|
try:
|
|
742
1007
|
chain = setup_model(model, temperature, top_p)
|
|
1008
|
+
# For the "temperature" retry strategy, build a second chain at a higher
|
|
1009
|
+
# temperature so retries can diverge from an otherwise deterministic run.
|
|
1010
|
+
retry_strategy_name = normalize_retry_strategy(retry_strategy)
|
|
1011
|
+
retry_chain = None
|
|
1012
|
+
if retry_strategy_name == "temperature":
|
|
1013
|
+
retry_chain = setup_model(model, retry_temperature, top_p)
|
|
743
1014
|
classified_df, char_counts, timing_data = apply_classification_to_csv(
|
|
744
1015
|
str(csv_path),
|
|
745
1016
|
str(output_path),
|
|
@@ -748,6 +1019,10 @@ def run_annotation(
|
|
|
748
1019
|
prompt_type,
|
|
749
1020
|
bool(use_examples),
|
|
750
1021
|
bool(process_textbox),
|
|
1022
|
+
bool(process_span),
|
|
1023
|
+
retries=retries,
|
|
1024
|
+
retry_strategy=retry_strategy_name,
|
|
1025
|
+
retry_chain=retry_chain,
|
|
751
1026
|
)
|
|
752
1027
|
finally:
|
|
753
1028
|
emissions = tracker.stop()
|
|
@@ -4,6 +4,8 @@ from typing import Any
|
|
|
4
4
|
|
|
5
5
|
import pandas as pd
|
|
6
6
|
|
|
7
|
+
from .span_value import parse_span_value
|
|
8
|
+
|
|
7
9
|
|
|
8
10
|
def get_sorted_annotation_keys(section_content: dict[str, Any]) -> list[str]:
|
|
9
11
|
"""Return annotation keys in the same stable order used by CodeBook Studio."""
|
|
@@ -64,10 +66,21 @@ def get_annotation_condition(annotation: dict[str, Any]) -> dict[str, Any] | Non
|
|
|
64
66
|
|
|
65
67
|
def normalize_annotation_response_value(annotation: dict[str, Any], value: Any) -> Any:
|
|
66
68
|
"""Coerce stored responses into stable comparable values."""
|
|
67
|
-
if pd.isna(value):
|
|
68
|
-
return None
|
|
69
|
-
|
|
70
69
|
annotation_type = annotation.get("type", "dropdown")
|
|
70
|
+
if annotation_type == "span":
|
|
71
|
+
# Spans round-trip as a list of dicts; preserve the structure so
|
|
72
|
+
# downstream code (metrics, conditions) can reason about them. Spans
|
|
73
|
+
# are not valid condition triggers, so the value is mostly inert here.
|
|
74
|
+
if isinstance(value, list):
|
|
75
|
+
return value
|
|
76
|
+
return parse_span_value(value)
|
|
77
|
+
|
|
78
|
+
try:
|
|
79
|
+
if pd.isna(value):
|
|
80
|
+
return None
|
|
81
|
+
except (TypeError, ValueError):
|
|
82
|
+
return value
|
|
83
|
+
|
|
71
84
|
if annotation_type == "dropdown":
|
|
72
85
|
normalized = str(value).strip().strip("`").strip()
|
|
73
86
|
if normalized == "":
|