codebook-lab 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codebook_lab/__init__.py +69 -0
- codebook_lab/annotate.py +742 -0
- codebook_lab/examples.py +87 -0
- codebook_lab/experiments.py +319 -0
- codebook_lab/metrics.py +1422 -0
- codebook_lab/ollama.py +117 -0
- codebook_lab/prompts.py +146 -0
- codebook_lab/py.typed +0 -0
- codebook_lab/tasks/__init__.py +1 -0
- codebook_lab/tasks/policy-sentiment/codebook.json +42 -0
- codebook_lab/tasks/policy-sentiment/ground-truth.csv +21 -0
- codebook_lab/types.py +116 -0
- codebook_lab-1.0.0.dist-info/METADATA +338 -0
- codebook_lab-1.0.0.dist-info/RECORD +17 -0
- codebook_lab-1.0.0.dist-info/WHEEL +5 -0
- codebook_lab-1.0.0.dist-info/licenses/LICENSE +661 -0
- codebook_lab-1.0.0.dist-info/top_level.txt +1 -0
codebook_lab/annotate.py
ADDED
|
@@ -0,0 +1,742 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
import sys
|
|
5
|
+
import time
|
|
6
|
+
|
|
7
|
+
import pandas as pd
|
|
8
|
+
import regex
|
|
9
|
+
from codecarbon import OfflineEmissionsTracker
|
|
10
|
+
from langchain_core.prompts import ChatPromptTemplate
|
|
11
|
+
from langchain_ollama.llms import OllamaLLM
|
|
12
|
+
|
|
13
|
+
from .ollama import ensure_ollama_available
|
|
14
|
+
from .prompts import PromptContext, get_prompt_type_name, render_prompt
|
|
15
|
+
from .types import AnnotationRunResult
|
|
16
|
+
|
|
17
|
+
# Module-level logger named after this module. Per-row progress messages,
# warnings from failed model calls and end-of-run summaries are emitted here.
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class _AnnotationProgressBar:
|
|
21
|
+
"""Render a compact terminal progress bar for annotation runs."""
|
|
22
|
+
|
|
23
|
+
def __init__(self, total_steps: int, enabled: bool | None = None) -> None:
|
|
24
|
+
self.total_steps = max(total_steps, 0)
|
|
25
|
+
self.completed_steps = 0
|
|
26
|
+
self.enabled = sys.stderr.isatty() if enabled is None else enabled
|
|
27
|
+
self._last_message = ""
|
|
28
|
+
|
|
29
|
+
def update(self, row_num: int, total_rows: int, annotation_name: str) -> None:
|
|
30
|
+
"""Advance the bar by one annotation and redraw it."""
|
|
31
|
+
if self.total_steps == 0:
|
|
32
|
+
return
|
|
33
|
+
|
|
34
|
+
self.completed_steps += 1
|
|
35
|
+
if not self.enabled:
|
|
36
|
+
return
|
|
37
|
+
|
|
38
|
+
width = 28
|
|
39
|
+
progress = self.completed_steps / self.total_steps
|
|
40
|
+
filled = int(width * progress)
|
|
41
|
+
bar = "#" * filled + "-" * (width - filled)
|
|
42
|
+
message = (
|
|
43
|
+
f"\rAnnotating [{bar}] {self.completed_steps}/{self.total_steps} "
|
|
44
|
+
f"({progress:.0%}) row {row_num}/{total_rows} {annotation_name}"
|
|
45
|
+
)
|
|
46
|
+
message = message[:140]
|
|
47
|
+
padding = max(0, len(self._last_message) - len(message))
|
|
48
|
+
sys.stderr.write(message + (" " * padding))
|
|
49
|
+
sys.stderr.flush()
|
|
50
|
+
self._last_message = message
|
|
51
|
+
|
|
52
|
+
def finish(self) -> None:
|
|
53
|
+
"""Terminate the in-place progress bar cleanly."""
|
|
54
|
+
if self.enabled and self.total_steps > 0:
|
|
55
|
+
sys.stderr.write("\n")
|
|
56
|
+
sys.stderr.flush()
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _count_annotations(codebook, process_textbox=False):
    """Return the number of model prompts one input row will trigger.

    Only ``section_*`` entries are considered. Textbox annotations count only
    when ``process_textbox`` is true, and sections without an ``annotations``
    mapping contribute nothing.
    """
    return sum(
        1
        for key, section in codebook.items()
        if key.startswith("section_")
        for annotation in section.get("annotations", {}).values()
        if annotation.get("type") != "textbox" or process_textbox
    )
|
|
70
|
+
|
|
71
|
+
def load_codebook(codebook_path):
    """Load a CodeBook Studio/CodeBook Lab codebook JSON file.

    The file is decoded as UTF-8 explicitly instead of with the platform's
    locale default, so codebooks containing non-ASCII section names, tooltips
    or options load identically on every OS. ``utf-8-sig`` also accepts files
    that start with a UTF-8 byte-order mark (as some Windows editors write),
    which ``json.load`` would otherwise reject.

    Args:
        codebook_path: Path to a ``codebook.json`` file.

    Returns:
        Parsed codebook dictionary.
    """
    with open(codebook_path, "r", encoding="utf-8-sig") as file:
        return json.load(file)
|
|
83
|
+
|
|
84
|
+
def get_annotation_column_names(codebook):
    """List the output column names a codebook produces, in codebook order.

    Each name is ``<section_name>_<annotation_name>``. Textbox annotations are
    included. Sections without an ``annotations`` mapping add no columns, but
    every ``section_*`` entry must still provide ``section_name``.

    Args:
        codebook: Parsed codebook dictionary.

    Returns:
        List of column-name strings.
    """
    columns = []
    for key, section in codebook.items():
        if key.startswith("section_"):
            prefix = section["section_name"]
            columns.extend(
                f"{prefix}_{item['name']}"
                for item in section.get("annotations", {}).values()
            )
    return columns
|
|
106
|
+
|
|
107
|
+
def load_input_dataframe(csv_path, codebook):
    """Read the input CSV and strip any label columns the codebook would produce.

    Pre-existing annotation columns (for example ground-truth labels) are
    dropped so they are not carried into the LLM output. The codebook's
    ``text_column`` must still be present afterwards.

    Args:
        csv_path: Path to the input CSV containing the source text column.
        codebook: Parsed codebook dictionary describing annotation columns.

    Returns:
        Pandas DataFrame ready for annotation.

    Raises:
        ValueError: If the codebook's text column is absent from the CSV.
    """
    frame = pd.read_csv(csv_path)
    stale_columns = [
        name for name in get_annotation_column_names(codebook) if name in frame.columns
    ]

    if stale_columns:
        frame = frame.drop(columns=stale_columns)
        logger.info(
            "Dropping annotation label columns from input before LLM annotation: %s",
            ", ".join(stale_columns),
        )

    text_column = codebook["text_column"]
    if text_column in frame.columns:
        return frame
    raise ValueError(
        f"Text column '{text_column}' was not found in {csv_path} after preparing the input data."
    )
|
|
136
|
+
|
|
137
|
+
def normalize_country_iso_code(country_iso_code):
    """Validate and normalize an ISO 3166-1 alpha-3 country code.

    Surrounding whitespace is removed and the code is upper-cased. Only the
    ASCII letters A-Z are accepted. ``str.isalpha()`` alone would also accept
    letters such as ``"ä"``, which can never appear in an ISO alpha-3 code.

    Args:
        country_iso_code: Three-letter country code such as ``"USA"`` or ``"IRL"``.

    Returns:
        Uppercase three-letter country code.

    Raises:
        ValueError: If the value is not a string made of exactly three ASCII
            letters after stripping.
    """
    error_message = (
        "country_iso_code must be a 3-letter ISO 3166-1 alpha-3 country code, "
        "for example USA, IRL, or DEU."
    )
    # Reject non-strings (e.g. None) with the documented ValueError instead of
    # an incidental AttributeError from .strip().
    if not isinstance(country_iso_code, str):
        raise ValueError(error_message)

    normalized = country_iso_code.strip().upper()
    if len(normalized) != 3 or not (normalized.isascii() and normalized.isalpha()):
        raise ValueError(error_message)
    return normalized
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def setup_model(model_name, temperature=None, top_p=None):
    """Compose the prompt-template -> Ollama runnable used for every annotation.

    Sampling options are forwarded to ``OllamaLLM`` only when provided, and are
    converted to ``float`` first. The template consists solely of a
    ``{question}`` placeholder, so callers invoke the runnable with
    ``{"question": prompt}``.

    NOTE(review): a ``ChatPromptTemplate`` feeding a completion-style
    ``OllamaLLM`` may be stringified with a role prefix before it reaches the
    model. Confirm this is intended.

    Args:
        model_name: Ollama model identifier such as ``"gemma3:270m"``.
        temperature: Optional sampling temperature.
        top_p: Optional nucleus-sampling value.

    Returns:
        LangChain runnable that accepts ``{"question": prompt}``.
    """
    sampling_kwargs = {
        option: float(value)
        for option, value in (("temperature", temperature), ("top_p", top_p))
        if value is not None
    }
    llm = OllamaLLM(model=model_name, **sampling_kwargs)
    return ChatPromptTemplate.from_template("{question}") | llm
|
|
176
|
+
|
|
177
|
+
def generate_response(chain, prompt, char_counts, timing_data, row_num=None, annotation_name=None):
    """Run one prompt through the model and update timing/count statistics.

    Failures are deliberately non-fatal: any exception raised while invoking
    the chain is logged as a warning and ``""`` is returned, so a single bad
    call does not abort a whole annotation run.

    Args:
        chain: Runnable returned by :func:`setup_model`.
        prompt: Fully rendered prompt string.
        char_counts: Mutable dict with ``input_chars`` and ``output_chars`` integers.
        timing_data: Mutable dict with ``total_inference_time`` and
            ``inference_count`` counters.
        row_num: Optional 1-based row number for progress logging.
        annotation_name: Optional annotation label for progress logging.

    Returns:
        Raw model response string, or ``""`` if inference failed.
    """
    log_progress = bool(row_num and annotation_name)
    try:
        # Input characters are counted before the call, so failed calls
        # still contribute to ``input_chars``.
        char_counts['input_chars'] += len(prompt)

        if log_progress:
            logger.info("[Row %s] Sending request for: %s...", row_num, annotation_name)

        # perf_counter() is monotonic and high-resolution. time.time() follows
        # the wall clock and can jump (e.g. NTP adjustments), which would
        # corrupt or even negate measured durations.
        start_time = time.perf_counter()
        response = chain.invoke({"question": prompt})
        inference_time = time.perf_counter() - start_time

        timing_data['total_inference_time'] += inference_time
        timing_data['inference_count'] += 1
        char_counts['output_chars'] += len(response)

        if log_progress:
            logger.info("[Row %s] %s done (%.1fs)", row_num, annotation_name, inference_time)

        return response
    except Exception as e:
        logger.warning("Error generating response: %s", e)
        return ""
|
|
214
|
+
|
|
215
|
+
def _coerce_json_response_value(response_value, annotation_type, min_value, max_value):
    """Coerce the ``"response"`` field of a parsed JSON object by annotation type."""
    if annotation_type == "dropdown":
        return response_value

    if annotation_type == "checkbox":
        if isinstance(response_value, bool):
            return 1 if response_value else 0
        # JSON numbers may arrive as ints or floats (e.g. ``1.0``).
        if isinstance(response_value, (int, float)) and response_value in (0, 1):
            return int(response_value)
        if isinstance(response_value, str) and response_value.strip().lower() in ("yes", "true", "1"):
            return 1
        # "no"/"false"/"0" and anything unrecognised all map to 0.
        return 0

    if annotation_type == "textbox":
        return str(response_value)

    if annotation_type == "likert":
        has_range = min_value is not None and max_value is not None
        try:
            value = int(float(response_value))
            if has_range:
                return max(min_value, min(max_value, value))  # Clamp to range
            return value
        except (ValueError, TypeError, OverflowError):
            # OverflowError: ``float("inf")`` (or JSON ``Infinity``) cannot be
            # converted to int. Previously this escaped and aborted the run.
            if has_range:
                return (min_value + max_value) // 2
            return response_value

    # Unknown annotation type: pass the value through untouched.
    return response_value


def _recover_from_plain_text(response, annotation_type, min_value, max_value):
    """Best-effort value recovery when no JSON object in ``response`` parses."""
    if annotation_type == "checkbox":
        # The prompt asks for 1/0, so a bare "1" is an explicit yes. Whole-word
        # matching keeps e.g. "untrue" or "yesterday" from counting as yes.
        if response.strip() == "1" or regex.search(r"\b(?:yes|true)\b", response.lower()):
            return 1
        return 0

    if annotation_type == "likert" and min_value is not None and max_value is not None:
        # Keep a leading minus sign so scales with negative points (e.g. -3..3)
        # are read correctly instead of losing the sign.
        for token in regex.findall(r"-?\d+", response):
            value = int(token)
            if min_value <= value <= max_value:
                return value
        return (min_value + max_value) // 2  # Default to middle value

    return response  # Return raw response as fallback


def extract_json_response(response, annotation_type, min_value=None, max_value=None):
    """
    Extract and validate JSON response based on annotation type

    Scans ``response`` for balanced ``{...}`` blocks (a recursive pattern from
    the ``regex`` package) and uses the ``"response"`` key of the first block
    that parses as JSON. A missing key is treated as ``""``. If no block
    parses, a plain-text heuristic is applied instead.

    Args:
        response: Raw model response text that should contain a JSON object.
        annotation_type: Annotation type string such as ``"dropdown"`` or ``"likert"``.
        min_value: Optional integer lower bound for Likert annotations.
        max_value: Optional integer upper bound for Likert annotations.

    Returns:
        Parsed response value coerced into the expected annotation format:
        ``0``/``1`` for checkboxes, a clamped int for bounded Likert items,
        ``str`` for textboxes, and the raw value otherwise.
    """
    pattern = regex.compile(r'\{(?:[^{}]|(?R))*\}')

    for json_string in pattern.findall(response):
        try:
            parsed_json = json.loads(json_string)
        except json.JSONDecodeError as e:
            logger.debug("Error parsing JSON: %s", e)
            continue
        return _coerce_json_response_value(
            parsed_json.get("response", ""), annotation_type, min_value, max_value
        )

    # If no valid JSON, try to extract direct response
    return _recover_from_plain_text(response, annotation_type, min_value, max_value)
|
|
293
|
+
|
|
294
|
+
def format_prompt(section_name, section_instruction, name, tooltip, annotation_type,
                  options=None, min_value=None, max_value=None, example=None,
                  text=None, prompt_type="standard", use_examples=False):
    """
    Render the complete model prompt for one annotation of one text.

    The type-specific answer instructions and the wrapper-agnostic core prompt
    are built first. Everything is then packed into a ``PromptContext`` and
    handed to the selected prompt wrapper via ``render_prompt``.

    Args:
        section_name: Codebook section name.
        section_instruction: Optional section-level instructions.
        name: Annotation name within the section.
        tooltip: Optional guidance text for the annotation.
        annotation_type: One of ``"dropdown"``, ``"checkbox"``, ``"likert"``, or ``"textbox"``.
        options: Dropdown option list when applicable.
        min_value: Minimum Likert value when applicable.
        max_value: Maximum Likert value when applicable.
        example: Optional example block from the codebook.
        text: Raw source text being annotated.
        prompt_type: Registered prompt wrapper name or callable wrapper.
        use_examples: Whether examples should be included in the prompt.

    Returns:
        Full prompt string ready to send to the model.
    """
    response_instructions = _get_response_instructions(
        annotation_type,
        options=options,
        min_value=min_value,
        max_value=max_value,
    )
    core_prompt = _build_core_prompt(
        section_name=section_name,
        section_instruction=section_instruction,
        name=name,
        tooltip=tooltip,
        response_instructions=response_instructions,
        example=example,
        use_examples=use_examples,
    )

    # Falsy example/text values are normalised to empty strings for the wrapper.
    context = PromptContext(
        section_name=section_name,
        section_instruction=section_instruction,
        annotation_name=name,
        tooltip=tooltip,
        annotation_type=annotation_type,
        options=options,
        min_value=min_value,
        max_value=max_value,
        example=example if example else "",
        text=text if text else "",
        use_examples=use_examples,
        response_instructions=response_instructions,
        core_prompt=core_prompt,
    )
    return render_prompt(prompt_type, context)
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
def _get_response_instructions(annotation_type, options=None, min_value=None, max_value=None):
|
|
347
|
+
"""Generate type-specific response instructions for a prompt.
|
|
348
|
+
|
|
349
|
+
Args:
|
|
350
|
+
annotation_type: Annotation type string.
|
|
351
|
+
options: Dropdown options when ``annotation_type`` is ``"dropdown"``.
|
|
352
|
+
min_value: Likert minimum when applicable.
|
|
353
|
+
max_value: Likert maximum when applicable.
|
|
354
|
+
|
|
355
|
+
Returns:
|
|
356
|
+
Instruction string describing the expected response format.
|
|
357
|
+
"""
|
|
358
|
+
if annotation_type == "dropdown" and options:
|
|
359
|
+
options_str = ', or '.join(f'"{option}"' for option in options)
|
|
360
|
+
return f"Respond only with one of the following options: {options_str}."
|
|
361
|
+
elif annotation_type == "checkbox":
|
|
362
|
+
return "Respond with 1 if \"Yes\" or 0 if \"No\"."
|
|
363
|
+
elif annotation_type == "likert" and min_value is not None and max_value is not None:
|
|
364
|
+
return f"Respond with a whole number from {min_value} to {max_value} (inclusive), where {min_value} means lowest and {max_value} means highest."
|
|
365
|
+
elif annotation_type == "textbox":
|
|
366
|
+
return "Respond with a brief text explanation."
|
|
367
|
+
return ""
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def _build_core_prompt(section_name, section_instruction, name, tooltip,
                       response_instructions, example, use_examples):
    """Assemble the prompt body shared by every prompt wrapper.

    Layout: the section name, an optional section instruction line, a blank
    line, the annotation name, an optional tooltip line, optional response
    instructions, the fixed JSON-format request, and optionally the example.

    The example is appended whenever ``use_examples`` is set. When examples
    are disabled it is still appended unless it contains ``"Text:"``, so
    examples that look like worked text samples are withheld.

    Args:
        section_name: Codebook section name.
        section_instruction: Optional section-level instructions.
        name: Annotation name within the section.
        tooltip: Optional annotation guidance text.
        response_instructions: String describing the expected response format.
        example: Optional example block from the codebook.
        use_examples: Whether example blocks should be included.

    Returns:
        Core prompt string before a prompt wrapper is applied.
    """
    pieces = [f"{section_name}"]
    if section_instruction:
        pieces.append(f"\n{section_instruction}")
    pieces.append(f"\n\n{name}")
    if tooltip:
        pieces.append(f"\n{tooltip}")
    if response_instructions:
        pieces.append(f"\n\n{response_instructions}")
    pieces.append("\n\nReturn your response in JSON format, with the key \"response\".")
    if example and (use_examples or "Text:" not in example):
        pieces.append(f"\n\n{example}")
    return "".join(pieces)
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def _extract_task_name(csv_path):
    """Extract the task folder name from a CSV path when possible.

    The task name is the directory directly below the innermost ``tasks``
    directory, e.g. ``.../tasks/policy-sentiment/ground-truth.csv`` gives
    ``"policy-sentiment"``. Both ``/`` and ``\\`` count as separators, so
    Windows-style paths are handled on every platform. A path such as
    ``tasks/data.csv`` yields ``None``: the component after ``tasks`` is the
    CSV file itself, not a task folder.

    Args:
        csv_path: Input CSV path (``str`` or path-like), usually under
            ``tasks/<task_name>/``.

    Returns:
        Task name string if it can be inferred, otherwise ``None``.
    """
    # Empty components come from leading, trailing or doubled separators.
    parts = [part for part in str(csv_path).replace("\\", "/").split("/") if part]

    # A candidate ``tasks`` component needs at least two components after it
    # (task folder + file). Scanning from the end picks the innermost match.
    for index in reversed(range(len(parts) - 2)):
        if parts[index] == "tasks":
            return parts[index + 1]
    return None
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
def _normalize_optional_parameter(value):
    """Map CLI-style "unset" markers (``None``, ``""``, ``"None"``) to ``None``.

    Any other value, including ``0`` and lowercase ``"none"``, is returned
    unchanged.
    """
    unset_markers = (None, "", "None")
    return None if value in unset_markers else value
|
|
437
|
+
|
|
438
|
+
def classify_text(chain, text, codebook, prompt_type="standard", use_examples=False,
                  char_counts=None, timing_data=None, process_textbox=False, row_num=None,
                  progress_bar=None, total_rows=None):
    """Annotate one text row across all sections in a codebook.

    One prompt is sent per annotation in every ``section_*`` entry. Textbox
    annotations are skipped unless ``process_textbox`` is true. A section
    without an ``annotations`` mapping is treated as empty, which matches how
    the rest of this module counts and names annotation columns.

    Args:
        chain: Runnable returned by :func:`setup_model`.
        text: Raw source text to annotate.
        codebook: Parsed codebook dictionary.
        prompt_type: Registered prompt wrapper name or callable wrapper.
        use_examples: Whether codebook examples should be included in prompts.
        char_counts: Optional mutable counter dict for prompt/response characters.
        timing_data: Optional mutable timing dict for inference statistics.
        process_textbox: Whether textbox annotations should be generated.
        row_num: Optional 1-based row number for progress logging.
        progress_bar: Optional progress-bar helper updated after each annotation.
        total_rows: Optional total row count for progress rendering.

    Returns:
        Tuple of ``(responses, char_counts, timing_data)``. ``responses`` maps
        ``<section_name>_<annotation_name>`` to the extracted value; annotations
        whose extracted value is ``None`` are omitted.
    """
    responses = {}

    # Fresh counters when the caller is not accumulating across rows.
    if char_counts is None:
        char_counts = {'input_chars': 0, 'output_chars': 0}
    if timing_data is None:
        timing_data = {'total_inference_time': 0, 'inference_count': 0}

    report_progress = (
        progress_bar is not None and row_num is not None and total_rows is not None
    )

    for key, section in codebook.items():
        if not key.startswith('section_'):
            continue

        section_name = section['section_name']
        section_instruction = section.get('section_instruction', '')

        for annotation in section.get('annotations', {}).values():
            name = annotation['name']
            annotation_type = annotation['type']

            # Skip textbox type annotations if process_textbox is False
            if annotation_type == "textbox" and not process_textbox:
                continue

            tooltip = annotation.get('tooltip', '')
            example = annotation.get('example', '')

            # Type-specific parameters; irrelevant ones stay None.
            options = annotation.get('options', []) if annotation_type == "dropdown" else None
            is_likert = annotation_type == "likert"
            min_value = annotation.get('min_value') if is_likert else None
            max_value = annotation.get('max_value') if is_likert else None

            prompt = format_prompt(
                section_name,
                section_instruction,
                name,
                tooltip,
                annotation_type,
                options,
                min_value,
                max_value,
                example,
                text,
                prompt_type=prompt_type,
                use_examples=use_examples,
            )

            column_name = f"{section_name}_{name}"
            response_text = generate_response(
                chain,
                prompt,
                char_counts,
                timing_data,
                row_num=row_num,
                annotation_name=column_name,
            )
            response_value = extract_json_response(
                response_text,
                annotation_type,
                min_value,
                max_value,
            )

            if response_value is not None:
                responses[column_name] = response_value

            if report_progress:
                progress_bar.update(row_num, total_rows, column_name)

    return responses, char_counts, timing_data
|
|
538
|
+
|
|
539
|
+
def apply_classification_to_csv(csv_path, output_path, codebook, chain, prompt_type="standard",
                                use_examples=False, process_textbox=False):
    """Run annotation over every row in an input CSV and write incremental results.

    The output CSV is rewritten after every row, so an interrupted run keeps
    the rows finished so far. Rows are numbered by position (1-based), which
    does not depend on the DataFrame's index labels.

    Args:
        csv_path: Path to the input CSV file.
        output_path: Path where the annotated CSV should be written.
        codebook: Parsed codebook dictionary.
        chain: Runnable returned by :func:`setup_model`.
        prompt_type: Registered prompt wrapper name or callable wrapper.
        use_examples: Whether codebook examples should be included in prompts.
        process_textbox: Whether textbox annotations should be generated.

    Returns:
        Tuple of ``(classified_df, char_counts, timing_data)``. ``timing_data``
        additionally contains ``avg_inference_time``.
    """
    df = load_input_dataframe(csv_path, codebook)
    total_rows = len(df)
    text_column = codebook['text_column']

    logger.info("Starting classification of %d rows", total_rows)

    annotations_per_row = _count_annotations(codebook, process_textbox)
    progress_bar = _AnnotationProgressBar(total_rows * annotations_per_row)

    results = []
    char_counts = {'input_chars': 0, 'output_chars': 0}
    timing_data = {'total_inference_time': 0, 'inference_count': 0}

    try:
        # Plain record dicts keep each column's own type. ``iterrows`` builds one
        # Series per row and can upcast values (e.g. ints to floats).
        # ``enumerate`` gives positional row numbers regardless of index labels.
        for row_num, row_data in enumerate(df.to_dict(orient="records"), start=1):
            logger.info("[Row %d/%d] Starting annotations...", row_num, total_rows)

            annotations, char_counts, timing_data = classify_text(
                chain,
                row_data[text_column],
                codebook,
                prompt_type,
                use_examples,
                char_counts,
                timing_data,
                process_textbox,
                row_num=row_num,
                progress_bar=progress_bar,
                total_rows=total_rows,
            )

            row_data.update(annotations)
            results.append(row_data)

            # Save progress after each row
            pd.DataFrame(results).to_csv(output_path, index=False)

            count = timing_data['inference_count']
            avg_time = timing_data['total_inference_time'] / count if count > 0 else 0
            logger.info("[Row %d/%d] Complete! (avg: %.1fs per annotation)", row_num, total_rows, avg_time)
    finally:
        progress_bar.finish()

    # With no input rows, keep the input's columns instead of writing an
    # empty, header-less CSV.
    classified_df = pd.DataFrame(results) if results else df.copy()
    classified_df.to_csv(output_path, index=False)

    count = timing_data['inference_count']
    timing_data['avg_inference_time'] = (
        timing_data['total_inference_time'] / count if count > 0 else 0
    )

    return classified_df, char_counts, timing_data
|
|
619
|
+
|
|
620
|
+
def run_annotation(
    *,
    model,
    csv_path,
    codebook_path,
    output_path,
    experiment_directory,
    prompt_type="standard",
    use_examples=False,
    temperature=None,
    top_p=None,
    process_textbox=False,
    country_iso_code="USA",
    start_ollama_if_needed=True,
):
    """Annotate a CSV with one model configuration and write all run artifacts.

    Steps, in order:

    1. Validate the country code and normalise the sampling parameters.
    2. Ensure an Ollama server is reachable, optionally starting one.
    3. Create the experiment and output directories.
    4. Write ``config.json``, then load the codebook.
    5. Run annotation while a CodeCarbon ``OfflineEmissionsTracker`` measures
       emissions. The tracker is stopped even if annotation fails.
    6. Write ``char_counts.json`` and ``timing_data.json`` beside the config.

    Args:
        model: Ollama model identifier such as ``"gemma3:270m"``.
        csv_path: Path to the input CSV file to annotate.
        codebook_path: Path to the matching ``codebook.json`` file.
        output_path: Path where the annotated CSV should be written.
        experiment_directory: Directory for metadata and sidecar output files.
        prompt_type: Registered prompt wrapper name or callable wrapper.
        use_examples: Whether codebook examples should be included in prompts.
        temperature: Optional sampling temperature.
        top_p: Optional nucleus-sampling value.
        process_textbox: Whether textbox annotations should be generated.
        country_iso_code: Three-letter ISO 3166-1 alpha-3 country code for CodeCarbon.
        start_ollama_if_needed: If ``True`` (the default), try to start a local
            ``ollama serve`` process when the default local server is not
            already reachable.

    Returns:
        :class:`codebook_lab.types.AnnotationRunResult` describing the completed run.
    """
    country_iso_code = normalize_country_iso_code(country_iso_code)
    temperature, top_p = (
        _normalize_optional_parameter(temperature),
        _normalize_optional_parameter(top_p),
    )
    ollama_base_url = ensure_ollama_available(start_if_needed=start_ollama_if_needed)

    experiment_directory, output_path = Path(experiment_directory), Path(output_path)
    for directory in (experiment_directory, output_path.parent):
        directory.mkdir(parents=True, exist_ok=True)

    def write_json(filename, payload):
        """Write ``payload`` as indented JSON inside the experiment directory."""
        with open(experiment_directory / filename, "w") as handle:
            json.dump(payload, handle, indent=2)

    # (config key, project-name tag, normalised value). Unset values are skipped.
    sampling = (("temperature", "temp", temperature), ("top_p", "topp", top_p))

    task_name = _extract_task_name(csv_path)
    prompt_type_name = get_prompt_type_name(prompt_type)

    config = {
        "model": model,
        "prompt_type": prompt_type_name,
        "use_examples": bool(use_examples),
        "process_textbox": bool(process_textbox),
        "country_iso_code": country_iso_code,
        "task_name": task_name,
        **{key: value for key, _, value in sampling if value is not None},
    }
    write_json("config.json", config)

    codebook = load_codebook(codebook_path)

    examples_flag = str(bool(use_examples)).lower()
    project_name = f"{model}_{prompt_type_name}_examples{examples_flag}" + "".join(
        f"_{tag}{value}" for _, tag, value in sampling if value is not None
    )

    tracker = OfflineEmissionsTracker(
        country_iso_code=country_iso_code,
        output_dir=str(experiment_directory),
        project_name=project_name,
        allow_multiple_runs=True,
        log_level='error'
    )
    tracker.start()
    try:
        chain = setup_model(model, temperature, top_p)
        classified_df, char_counts, timing_data = apply_classification_to_csv(
            str(csv_path),
            str(output_path),
            codebook,
            chain,
            prompt_type,
            bool(use_examples),
            bool(process_textbox),
        )
    finally:
        emissions = tracker.stop()

    write_json("char_counts.json", char_counts)
    write_json("timing_data.json", timing_data)

    logger.info("Classification complete. Results saved to %s", output_path)
    logger.info("Configuration: %s", config)
    logger.info("Country for emissions factors: %s", country_iso_code)
    logger.info("Ollama server: %s", ollama_base_url)
    logger.info("Estimated emissions: %s kg CO2eq", emissions)
    logger.info("Total input characters: %s", char_counts['input_chars'])
    logger.info("Total output characters: %s", char_counts['output_chars'])
    logger.info("Total inference time: %.2f seconds", timing_data['total_inference_time'])
    logger.info("Average inference time: %.2f seconds per call", timing_data['avg_inference_time'])

    return AnnotationRunResult(
        model=model,
        output_path=output_path,
        experiment_directory=experiment_directory,
        config=config,
        char_counts=char_counts,
        timing_data=timing_data,
        emissions=emissions,
        dataframe=classified_df,
    )
|