cat-stack 0.2.0__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. {cat_stack-0.2.0 → cat_stack-0.3.0}/.gitignore +3 -2
  2. {cat_stack-0.2.0 → cat_stack-0.3.0}/PKG-INFO +1 -1
  3. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/__about__.py +1 -1
  4. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/__init__.py +2 -0
  5. cat_stack-0.3.0/src/cat_stack/_pilot_test.py +338 -0
  6. cat_stack-0.3.0/src/cat_stack/_review_ui.py +366 -0
  7. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/classify.py +67 -0
  8. cat_stack-0.3.0/src/cat_stack/prompt_tune.py +785 -0
  9. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/text_functions_ensemble.py +10 -2
  10. {cat_stack-0.2.0 → cat_stack-0.3.0}/LICENSE +0 -0
  11. {cat_stack-0.2.0 → cat_stack-0.3.0}/README.md +0 -0
  12. {cat_stack-0.2.0 → cat_stack-0.3.0}/pyproject.toml +0 -0
  13. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/_batch.py +0 -0
  14. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/_category_analysis.py +0 -0
  15. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/_chunked.py +0 -0
  16. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/_embeddings.py +0 -0
  17. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/_formatter.py +0 -0
  18. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/_providers.py +0 -0
  19. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/_tiebreaker.py +0 -0
  20. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/_utils.py +0 -0
  21. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/_web_fetch.py +0 -0
  22. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/calls/CoVe.py +0 -0
  23. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/calls/__init__.py +0 -0
  24. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/calls/all_calls.py +0 -0
  25. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/calls/image_CoVe.py +0 -0
  26. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/calls/image_stepback.py +0 -0
  27. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/calls/pdf_CoVe.py +0 -0
  28. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/calls/pdf_stepback.py +0 -0
  29. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/calls/stepback.py +0 -0
  30. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/calls/top_n.py +0 -0
  31. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/explore.py +0 -0
  32. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/extract.py +0 -0
  33. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/image_functions.py +0 -0
  34. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/images/circle.png +0 -0
  35. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/images/cube.png +0 -0
  36. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/images/diamond.png +0 -0
  37. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/images/overlapping_pentagons.png +0 -0
  38. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/images/rectangles.png +0 -0
  39. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/model_reference_list.py +0 -0
  40. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/pdf_functions.py +0 -0
  41. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/summarize.py +0 -0
  42. {cat_stack-0.2.0 → cat_stack-0.3.0}/src/cat_stack/text_functions.py +0 -0
@@ -31,8 +31,9 @@ images/generate_classify_diagram.py
31
31
  src/catllm/old_code/
32
32
  src/catllm/circle_classifier.py
33
33
 
34
- # Test files and test directory (top-level only)
35
- /tests/
34
+ # Test output and data (keep test scripts tracked)
35
+ /tests/output/
36
+ /tests/data/
36
37
 
37
38
  # Survey summarizer (separate project)
38
39
  survey-summarizer
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cat-stack
3
- Version: 0.2.0
3
+ Version: 0.3.0
4
4
  Summary: Domain-agnostic text, image, PDF, and DOCX classification engine powered by LLMs
5
5
  Project-URL: Documentation, https://github.com/chrissoria/cat-stack#readme
6
6
  Project-URL: Issues, https://github.com/chrissoria/cat-stack/issues
@@ -1,7 +1,7 @@
1
1
  # SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
2
2
  #
3
3
  # SPDX-License-Identifier: GPL-3.0-or-later
4
- __version__ = "0.2.0"
4
+ __version__ = "0.3.0"
5
5
  __author__ = "Chris Soria"
6
6
  __email__ = "chrissoria@berkeley.edu"
7
7
  __title__ = "cat-stack"
@@ -20,6 +20,7 @@ from .extract import extract
20
20
  from .explore import explore
21
21
  from .classify import classify
22
22
  from .summarize import summarize
23
+ from .prompt_tune import prompt_tune
23
24
 
24
25
  # Category analysis
25
26
  from ._category_analysis import has_other_category, check_category_verbosity
@@ -92,6 +93,7 @@ __all__ = [
92
93
  "explore",
93
94
  "classify",
94
95
  "summarize",
96
+ "prompt_tune",
95
97
  # Category analysis
96
98
  "has_other_category",
97
99
  "check_category_verbosity",
@@ -0,0 +1,338 @@
1
+ """
2
+ Pilot test module for CatLLM.
3
+
4
+ Provides two capabilities:
5
+ 1. collect_corrections() — classify a small sample and collect category-level
6
+ user corrections via a browser UI. Used by prompt_tune() and classify(pilot_test=True).
7
+ 2. run_pilot_test() — wrapper that collects corrections and asks whether to proceed.
8
+ """
9
+
10
+ import random
11
+
12
+
13
def compute_metrics(corrections):
    """
    Score the model's pilot predictions against user-corrected ground truth.

    Every (item, category) pair counts as one cell: the model's original
    0/1 value is the prediction and the user's corrected value is truth.

    Confusion cells:
      - TP: model=1, truth=1 (correctly identified)
      - FP: model=1, truth=0 (false alarm — user flipped 1→0)
      - FN: model=0, truth=1 (missed — user flipped 0→1)
      - TN: model=0, truth=0 (correctly excluded)

    Returns:
        dict with "accuracy", "sensitivity", "precision" as floats in [0, 1].
        A metric whose denominator is zero (e.g. no positive cells at all)
        is reported as 1.0.
    """
    counts = {"tp": 0, "fp": 0, "fn": 0, "tn": 0}
    for entry in corrections:
        for category, predicted in entry["original"].items():
            actual = entry["corrected"][category]
            if predicted == 1:
                key = "tp" if actual == 1 else "fp"
            else:
                key = "fn" if actual == 1 else "tn"
            counts[key] += 1

    tp, fp, fn, tn = counts["tp"], counts["fp"], counts["fn"], counts["tn"]

    def _ratio(numerator, denominator):
        # Degenerate denominators (no relevant cells) score as perfect.
        return numerator / denominator if denominator > 0 else 1.0

    return {
        "accuracy": _ratio(tp + tn, tp + fp + fn + tn),
        "sensitivity": _ratio(tp, tp + fn),
        "precision": _ratio(tp, tp + fp),
    }
52
+
53
+
54
def collect_corrections(
    input_data,
    categories,
    models,
    classify_ensemble_fn,
    ensemble_kwargs,
    sample_size=10,
    system_prompt="",
    ui="browser",
):
    """
    Classify a random sample of the input and gather per-category user corrections.

    A random subset of ``input_data`` is classified via ``classify_ensemble_fn``,
    then a review UI (browser page with checkboxes, or a terminal fallback)
    lets the user flip individual category values before submitting.

    Args:
        input_data: Full input data (list or pandas Series).
        categories: List of category names.
        models: Models list (same format as classify()).
        classify_ensemble_fn: The classify_ensemble callable.
        ensemble_kwargs: Keyword arguments forwarded to classify_ensemble.
        sample_size: Number of random items to test. Default 10.
        system_prompt: Optional system prompt for this classification run.
        ui: "browser" (default) for the local web page, "terminal" for text input.

    Returns:
        dict with keys:
          - "corrections": list of dicts, each with "input" (str),
            "original" ({category: 0/1} as the model classified),
            "corrected" ({category: 0/1} after user edits), and
            "changed" (list of flipped category names)
          - "metrics": dict with "accuracy", "sensitivity", "precision"
            (0-1 floats) computed cell-wise over all (item, category) pairs
          - "total_flips": total number of category-level corrections
          - "sample_indices": sorted list of sampled indices
        Returns None if the user cancels (or classification fails).
    """
    import pandas as pd

    # Normalize the input into a plain list we can index into.
    if isinstance(input_data, pd.Series):
        items_list = input_data.tolist()
    else:
        items_list = list(input_data)

    n_total = len(items_list)
    if n_total == 0:
        print("[CatLLM] No items to test.")
        return {
            "corrections": [],
            "metrics": {"accuracy": 1.0, "sensitivity": 1.0, "precision": 1.0},
            "total_flips": 0,
            "sample_indices": [],
        }

    # Draw the pilot sample (without replacement, in ascending index order).
    actual_sample_size = min(sample_size, n_total)
    sample_indices = sorted(random.sample(range(n_total), actual_sample_size))
    sample_items = [items_list[i] for i in sample_indices]

    print(f"\n[CatLLM] Classifying {actual_sample_size} random item(s)...")
    print("=" * 60)

    # Build the kwargs for the pilot run: caller-provided ensemble settings,
    # with persistence/progress hooks disabled and the sample substituted in.
    pilot_kwargs = {
        **ensemble_kwargs,
        "filename": None,
        "save_directory": None,
        "progress_callback": None,
        "input_data": sample_items,
        "categories": categories,
        "models": models,
    }
    if system_prompt:
        pilot_kwargs["system_prompt"] = system_prompt

    try:
        pilot_result = classify_ensemble_fn(**pilot_kwargs)
    except Exception as e:
        print(f"\n[CatLLM] Classification failed: {e}")
        return None

    # Multi-model runs expose consensus columns; single-model runs plain ones.
    col_suffix = "_consensus" if len(models) > 1 else ""

    # Pull each item's 0/1 category values out of the result DataFrame.
    review_items = []
    for row_idx in range(len(pilot_result)):
        row = pilot_result.iloc[row_idx]

        cat_values = {}
        for cat_idx, cat in enumerate(categories, 1):
            col = f"category_{cat_idx}{col_suffix}"
            raw = row[col] if col in pilot_result.columns else None
            # Treat anything other than a literal "1" (or int 1) as 0.
            cat_values[cat] = 1 if raw is not None and str(raw) == "1" else 0

        review_items.append({
            "input": sample_items[row_idx],
            "values": cat_values,
        })

    # Hand off to the chosen review UI.
    collector = _collect_via_browser if ui == "browser" else _collect_via_terminal
    corrections = collector(review_items, categories)
    if corrections is None:
        return None

    return {
        "corrections": corrections,
        "metrics": compute_metrics(corrections),
        "total_flips": sum(len(c["changed"]) for c in corrections),
        "sample_indices": sample_indices,
    }
182
+
183
+
184
def _collect_via_browser(review_items, categories):
    """Launch the local browser review page and return the user's corrections."""
    from . import _review_ui
    return _review_ui.open_review_ui(review_items, categories)
188
+
189
+
190
+ def _collect_via_terminal(review_items, categories):
191
+ """Collect corrections via terminal text input (fallback)."""
192
+ corrections = []
193
+ n = len(review_items)
194
+
195
+ print(f"\n{'=' * 60}")
196
+ print("RESULTS — Review each classification")
197
+ print("Enter category numbers to flip (e.g. '1,3'), or press Enter if correct.")
198
+ print(f"{'=' * 60}\n")
199
+
200
+ for idx, item in enumerate(review_items):
201
+ input_text = item["input"]
202
+ cat_values = item["values"]
203
+
204
+ display_text = str(input_text)
205
+ if len(display_text) > 200:
206
+ display_text = display_text[:200] + "..."
207
+
208
+ print(f"--- Item {idx + 1}/{n} ---")
209
+ print(f" Input: {display_text}\n")
210
+
211
+ print(" Categories:")
212
+ for cat_idx, cat in enumerate(categories, 1):
213
+ val = cat_values[cat]
214
+ marker = "1" if val else "0"
215
+ cat_display = cat if len(cat) <= 60 else cat[:57] + "..."
216
+ print(f" {cat_idx}. {cat_display:<60s} = {marker}")
217
+ print()
218
+
219
+ try:
220
+ answer = input(
221
+ " Numbers to flip (e.g. '1,3'), Enter if correct, 'q' to quit: "
222
+ ).strip().lower()
223
+ except (EOFError, KeyboardInterrupt):
224
+ print("\n\n[CatLLM] Cancelled.")
225
+ return None
226
+
227
+ if answer in ("q", "quit", "exit"):
228
+ print("\n[CatLLM] Cancelled by user.")
229
+ return None
230
+
231
+ original = dict(cat_values)
232
+ corrected = dict(cat_values)
233
+ changed = []
234
+
235
+ if answer:
236
+ try:
237
+ nums = [int(x.strip()) for x in answer.split(",") if x.strip()]
238
+ except ValueError:
239
+ print(" (Could not parse input — treating as no corrections)")
240
+ nums = []
241
+
242
+ for num in nums:
243
+ if 1 <= num <= len(categories):
244
+ cat_name = categories[num - 1]
245
+ corrected[cat_name] = 1 - corrected[cat_name]
246
+ changed.append(cat_name)
247
+ else:
248
+ print(f" (Ignoring invalid number: {num})")
249
+
250
+ if changed:
251
+ print(f" Flipped: {', '.join(changed)}")
252
+
253
+ corrections.append({
254
+ "input": input_text,
255
+ "original": original,
256
+ "corrected": corrected,
257
+ "changed": changed,
258
+ })
259
+ print()
260
+
261
+ return corrections
262
+
263
+
264
def run_pilot_test(
    input_data,
    categories,
    models,
    classify_ensemble_fn,
    ensemble_kwargs,
    sample_size=10,
    ui="browser",
):
    """
    Pilot-classify a sample, summarize the metrics, and confirm continuation.

    Thin wrapper around collect_corrections(): prints an accuracy /
    sensitivity / precision summary with advice keyed to the average score,
    then asks the user whether to run the full classification.

    Returns:
        The collect_corrections() result dict extended with a boolean
        "proceed" key, or None if the pilot itself was cancelled.
    """
    result = collect_corrections(
        input_data=input_data,
        categories=categories,
        models=models,
        classify_ensemble_fn=classify_ensemble_fn,
        ensemble_kwargs=ensemble_kwargs,
        sample_size=sample_size,
        ui=ui,
    )
    if result is None:
        return None

    metrics = result["metrics"]
    divider = "=" * 60

    # Summary banner.
    print(f"{divider}")
    print("PILOT TEST SUMMARY")
    print(f"  Accuracy: {metrics['accuracy'] * 100:.1f}%")
    print(f"  Sensitivity: {metrics['sensitivity'] * 100:.1f}%")
    print(f"  Precision: {metrics['precision'] * 100:.1f}%")
    print(f"  Corrections: {result['total_flips']}")
    print(f"{divider}\n")

    # Advice tiers based on the unweighted mean of the three metrics.
    avg = (metrics["accuracy"] + metrics["sensitivity"] + metrics["precision"]) / 3
    if avg < 0.7:
        print(
            "  WARNING: Average score is below 70%.\n"
            "  Consider revising your categories — adding descriptions and examples\n"
            "  significantly improves accuracy. You can also use prompt_tune() to\n"
            "  automatically optimize the classification prompt.\n"
        )
    elif avg < 0.9:
        print(
            "  Some classifications needed corrections. Consider using prompt_tune()\n"
            "  to optimize the prompt before running the full classification.\n"
        )
    else:
        print("  Classifications look good!\n")

    # Confirmation prompt; interrupts count as "do not proceed".
    try:
        answer = input("  Proceed with full classification? (Y/n): ").strip().lower()
    except (EOFError, KeyboardInterrupt):
        print("\n[CatLLM] Classification cancelled.")
        result["proceed"] = False
        return result

    proceed = answer in ("", "y", "yes")
    result["proceed"] = proceed

    if proceed:
        print("\n[CatLLM] Proceeding with full classification...\n")
    else:
        print("\n[CatLLM] Classification cancelled. Adjust your categories and try again.\n")

    return result