cat-stack 1.1.1__tar.gz → 1.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cat_stack-1.1.1 → cat_stack-1.3.0}/PKG-INFO +1 -1
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/__about__.py +1 -1
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/__init__.py +17 -0
- cat_stack-1.3.0/src/catstack/_wrapper_helpers.py +330 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/prompt_tune.py +43 -14
- {cat_stack-1.1.1 → cat_stack-1.3.0}/.gitignore +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/LICENSE +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/README.md +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/pyproject.toml +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/cat_stack/__init__.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/_batch.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/_category_analysis.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/_chunked.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/_embeddings.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/_formatter.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/_pilot_test.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/_prompts.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/_providers.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/_review_ui.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/_tiebreaker.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/_utils.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/_web_fetch.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/calls/CoVe.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/calls/__init__.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/calls/all_calls.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/calls/image_CoVe.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/calls/image_stepback.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/calls/pdf_CoVe.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/calls/pdf_stepback.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/calls/stepback.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/calls/top_n.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/classify.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/explore.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/extract.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/image_functions.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/images/circle.png +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/images/cube.png +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/images/diamond.png +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/images/overlapping_pentagons.png +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/images/rectangles.png +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/model_reference_list.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/pdf_functions.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/summarize.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/text_functions.py +0 -0
- {cat_stack-1.1.1 → cat_stack-1.3.0}/src/catstack/text_functions_ensemble.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cat-stack
|
|
3
|
-
Version: 1.1.1
|
|
3
|
+
Version: 1.3.0
|
|
4
4
|
Summary: Domain-agnostic text, image, PDF, and DOCX classification engine powered by LLMs
|
|
5
5
|
Project-URL: Documentation, https://github.com/chrissoria/cat-stack#readme
|
|
6
6
|
Project-URL: Issues, https://github.com/chrissoria/cat-stack/issues
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
|
|
2
2
|
#
|
|
3
3
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
|
4
|
-
__version__ = "1.
|
|
4
|
+
__version__ = "1.3.0"
|
|
5
5
|
__author__ = "Chris Soria"
|
|
6
6
|
__email__ = "chrissoria@berkeley.edu"
|
|
7
7
|
__title__ = "cat-stack"
|
|
@@ -83,6 +83,17 @@ from .image_functions import (
|
|
|
83
83
|
image_features,
|
|
84
84
|
)
|
|
85
85
|
|
|
86
|
+
# =============================================================================
|
|
87
|
+
# Wrapper helpers (for thin language bindings: Stata, future Julia/CLI)
|
|
88
|
+
# =============================================================================
|
|
89
|
+
from ._wrapper_helpers import (
|
|
90
|
+
get_backend,
|
|
91
|
+
parse_kwargs_string,
|
|
92
|
+
parse_models_string,
|
|
93
|
+
short_label,
|
|
94
|
+
classify_labels,
|
|
95
|
+
)
|
|
96
|
+
|
|
86
97
|
# Define public API
|
|
87
98
|
__all__ = [
|
|
88
99
|
# Batch mode exceptions
|
|
@@ -127,4 +138,10 @@ __all__ = [
|
|
|
127
138
|
"build_json_schema",
|
|
128
139
|
"extract_json",
|
|
129
140
|
"validate_classification_json",
|
|
141
|
+
# Wrapper helpers (for thin language bindings)
|
|
142
|
+
"get_backend",
|
|
143
|
+
"parse_kwargs_string",
|
|
144
|
+
"parse_models_string",
|
|
145
|
+
"short_label",
|
|
146
|
+
"classify_labels",
|
|
130
147
|
]
|
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Convenience helpers for thin language wrappers (Stata, future Julia/CLI).
|
|
3
|
+
|
|
4
|
+
These functions exist so each language wrapper does not have to re-implement
|
|
5
|
+
the same string-parsing and output-shaping logic. They are thin layers
|
|
6
|
+
over the main `classify()` / `extract()` / `explore()` / `summarize()` API
|
|
7
|
+
— same kwargs, same behavior — plus a few parsers for the string formats
|
|
8
|
+
that wrappers tend to accept from their host languages.
|
|
9
|
+
|
|
10
|
+
R users typically pass native lists / tuples and do not need the string
|
|
11
|
+
parsers, but `classify_labels()` is useful for getting one label per row
|
|
12
|
+
without manually walking the DataFrame.
|
|
13
|
+
|
|
14
|
+
These helpers are intentionally side-effect free and import-safe: nothing
|
|
15
|
+
here imports a domain sub-package (cat-pol, cat-vader, etc.) until the user
|
|
16
|
+
calls `get_backend("pol")`, so importing `catstack` does not require any
|
|
17
|
+
domain package to be installed.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
import ast
|
|
23
|
+
import importlib
|
|
24
|
+
import re
|
|
25
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# -----------------------------------------------------------------------------
|
|
29
|
+
# Domain → module resolution
|
|
30
|
+
# -----------------------------------------------------------------------------
|
|
31
|
+
|
|
32
|
+
# Maps the user-facing short domain name to (python import name, pip package).
# Note: import names and pip names differ for the historical cat-vader,
# cat-ademic, and cat-web packages, which omit the underscore in their module
# name. This dict is the single source of truth across the ecosystem.
#
# Tuple layout: index 0 is the module name handed to importlib, index 1 is the
# distribution name shown to the user in the "not installed" error message.
# Insertion order matters: it is the order in which valid domains are listed
# in the "Unknown domain" error raised by get_backend().
_DOMAIN_PACKAGES: Dict[str, Tuple[str, str]] = {
    "pol": ("cat_pol", "cat-pol"),
    "vader": ("catvader", "cat-vader"),
    "ademic": ("catademic", "cat-ademic"),
    "survey": ("cat_survey", "cat-survey"),
    "cog": ("cat_cog", "cat-cog"),
    "web": ("catweb", "cat-web"),
}
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_backend(domain: Optional[str] = None):
    """Resolve a domain shortform to the Python module that implements it.

    A missing, empty, or whitespace-only `domain` selects the base
    `catstack` module. Any other value is trimmed and lower-cased, then
    looked up in `_DOMAIN_PACKAGES`; the matching sub-package is imported
    on demand and returned.

    Args:
        domain: One of "pol", "vader", "ademic", "survey", "cog", "web"
            (case-insensitive), or None / "" for the base package.

    Returns:
        The imported module object.

    Raises:
        ValueError: if `domain` is non-empty but not a known shortform.
        ImportError: if the domain package cannot be imported. The message
            names the pip package and the `catllm setup, domain(X)`
            command that installs it; the original error is chained.

    Example:
        >>> get_backend("").__name__
        'catstack'
        >>> get_backend(None).__name__
        'catstack'
        >>> # get_backend("pol") returns the cat_pol module if installed
    """
    cleaned = str(domain).strip() if domain else ""
    if not cleaned:
        # Imported here rather than at module level to avoid bootstrap cycles.
        import catstack

        return catstack

    key = cleaned.lower()
    entry = _DOMAIN_PACKAGES.get(key)
    if entry is None:
        valid = ", ".join(_DOMAIN_PACKAGES.keys())
        raise ValueError(
            f"Unknown domain: {domain!r}. Valid: {valid}."
        )

    module_name, pip_name = entry
    try:
        backend = importlib.import_module(module_name)
    except ImportError as e:
        raise ImportError(
            f"Domain package '{pip_name}' is not installed. "
            f"Run: catllm setup, domain({key})"
        ) from e
    return backend
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# -----------------------------------------------------------------------------
|
|
87
|
+
# String parsers (for wrappers whose host language passes options as strings)
|
|
88
|
+
# -----------------------------------------------------------------------------
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _strip_surrounding_quotes(s: str) -> str:
    """Strip one balanced pair of surrounding ' or " from `s`.

    This undoes the Stata `string asis` artifact where the host language
    hands over the value still wrapped in the user's quotes. Only the
    outermost matching pair is removed; inner quotes are left untouched.
    """
    s = s.strip()
    if len(s) >= 2 and s[0] == s[-1] and s[0] in ('"', "'"):
        return s[1:-1]
    return s


def parse_kwargs_string(s: Optional[str]) -> Dict[str, Any]:
    """Parse a `"key=val, key=val"` string into a Python kwargs dict.

    Each value is run through `ast.literal_eval` so numbers, booleans,
    strings, and lists all work naturally. Values that don't parse (for
    any reason `literal_eval` can reject them) fall back to the raw string.

    Commas inside quotes / brackets are respected (no naive split). An
    unmatched closing bracket is tolerated: it never pushes the nesting
    depth below zero, so later top-level commas still split.

    Pieces without an `=` or with an empty key are skipped. When a key
    repeats, the last occurrence wins.

    Args:
        s: The option string, optionally wrapped in one pair of quotes.

    Returns:
        The parsed kwargs; an empty dict for empty / None input.

    Example:
        >>> parse_kwargs_string("max_retries=3, retry_delay=0.5")
        {'max_retries': 3, 'retry_delay': 0.5}
        >>> parse_kwargs_string("format='bullets', research_question='Why did you move?'")
        {'format': 'bullets', 'research_question': 'Why did you move?'}
    """
    if not s:
        return {}
    s = _strip_surrounding_quotes(str(s))
    if not s.strip():
        return {}

    # Walk character-by-character to split on commas at the top level only
    # (not inside quotes or brackets).
    pieces: List[str] = []
    buf: List[str] = []
    depth = 0
    quote_char: Optional[str] = None
    for ch in s:
        if quote_char:
            if ch == quote_char:
                quote_char = None
        elif ch in ('"', "'"):
            quote_char = ch
        elif ch in "([{":
            depth += 1
        elif ch in ")]}":
            # Clamp at zero: a stray closer must not disable comma
            # splitting for the remainder of the string.
            depth = max(depth - 1, 0)
        elif ch == "," and depth == 0:
            pieces.append("".join(buf))
            buf = []
            continue  # the separator itself is not part of any piece
        buf.append(ch)
    if buf:
        pieces.append("".join(buf))

    kwargs: Dict[str, Any] = {}
    for p in pieces:
        if "=" not in p:
            continue
        k, _, v = p.partition("=")
        k = k.strip()
        v = v.strip()
        if not k:
            continue
        try:
            kwargs[k] = ast.literal_eval(v)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # literal_eval documents all of these for malformed input
            # (e.g. TypeError for an unhashable dict key such as {[1]: 2}).
            kwargs[k] = v
    return kwargs
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def parse_models_string(
    s: Optional[str],
    default_api_key: Optional[str] = None,
) -> Optional[List[Tuple[str, ...]]]:
    """Turn `"model provider key; model provider key"` into a list of tuples.

    The input is first unwrapped from one pair of surrounding quotes, then
    split on `;`. Every entry is split on whitespace:

    * three or more fields -> the first three become the tuple;
    * exactly two fields   -> `default_api_key` fills the third slot, but
      only when a default was supplied (handy when one API key powers
      several cloud models in an ensemble);
    * anything else is malformed and silently dropped.

    Returns None (never an empty list) when the input is empty / None or
    when no entry survives, so callers can write `if models: ...`.

    Example:
        >>> parse_models_string("gpt-4o openai sk-...; claude-haiku-4-5 anthropic sk-ant-...")
        [('gpt-4o', 'openai', 'sk-...'), ('claude-haiku-4-5', 'anthropic', 'sk-ant-...')]
        >>> parse_models_string("qwen2.5:7b ollama _")
        [('qwen2.5:7b', 'ollama', '_')]
    """
    if not s:
        return None
    spec = _strip_surrounding_quotes(str(s))
    if not spec.strip():
        return None

    models: List[Tuple[str, ...]] = []
    for chunk in spec.split(";"):
        fields = chunk.split()
        count = len(fields)
        if count >= 3:
            models.append(tuple(fields[:3]))
        elif count == 2 and default_api_key is not None:
            name, provider = fields
            models.append((name, provider, default_api_key))
    return models if models else None
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
# -----------------------------------------------------------------------------
|
|
203
|
+
# Output shaping
|
|
204
|
+
# -----------------------------------------------------------------------------
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def short_label(s: Any) -> Any:
    """Extract the short label from a "Label: definition..." string.

    Verbose category labels help classification accuracy but are unwieldy
    in a single output cell, so this keeps only the text before the first
    colon, with surrounding whitespace trimmed. For example,
    `short_label("Positive: The respondent expresses approval.")` gives
    `"Positive"`.

    The input is returned unchanged when it is not a string, contains no
    colon, or has nothing but whitespace before the first colon.
    """
    if not isinstance(s, str):
        return s
    prefix, colon, _definition = s.partition(":")
    if not colon:
        return s
    prefix = prefix.strip()
    return prefix if prefix else s
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
# Patterns used by classify_labels to find the per-category output columns.
# Both are anchored at start and end and capture the category number N, which
# classify_labels treats as a 1-based index into the caller's category list.
# `category_N_consensus` is the ensemble-mode column; `category_N` is the
# single-model column and is only consulted when no consensus column exists.
_CONSENSUS_COL_PAT = re.compile(r"^category_(\d+)_consensus$")
_SINGLE_COL_PAT = re.compile(r"^category_(\d+)$")
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def classify_labels(
    input_data,
    categories,
    *,
    short_labels: bool = True,
    multi_label_sep: str = "; ",
    return_full: bool = False,
    **kwargs,
):
    """Run `classify()` and collapse its wide output to one label per row.

    `classify()` returns a DataFrame with one indicator column per category:
    `category_1`, `category_2`, ... or, in ensemble mode,
    `category_1_consensus`, ... This helper reads those indicators and
    produces a `list[str]` with one entry per DataFrame row. Each entry is
    the name of the assigned category; when several categories are flagged
    for a row their names are joined with `multi_label_sep`, and when none
    is flagged the entry is an empty string.

    Thin language wrappers (Stata, simple CLI tools) that want a single
    labeled column per row should call this instead of `classify()`.

    Args:
        input_data: List of texts, paths, or otherwise — same as `classify()`.
        categories: List of category names — same as `classify()`.
        short_labels: If True (default), each category name is passed
            through `short_label()`, so `"Positive: definition..."` is
            reported as `"Positive"`. Pass False to keep the full text.
        multi_label_sep: Separator placed between multiple matched
            categories of one row. Default `"; "`. Irrelevant when only
            one category matches per row (the common case).
        return_full: If True, return `(labels, df)` so the caller also
            gets the underlying DataFrame. Default False.
        **kwargs: Forwarded unchanged to `classify()`.

    Returns:
        list[str] with one entry per row, or a `(labels, df)` tuple if
        `return_full=True`.

    Raises:
        RuntimeError: if the DataFrame from `classify()` has neither
            `category_N` nor `category_N_consensus` columns, which
            indicates that cat-stack's output schema has changed
            incompatibly.

    Example:
        >>> labels = classify_labels(
        ...     ["Great service", "Awful experience"],
        ...     ["Positive: approval", "Negative: criticism"],
        ...     api_key="...", user_model="gpt-4o-mini",
        ... )
        >>> labels
        ['Positive', 'Negative']
    """
    # Deferred import: pulling `classify` in at module load time would create
    # a circular import (classify.py imports from this package indirectly).
    from .classify import classify

    result_df = classify(input_data=input_data, categories=categories, **kwargs)
    column_names = list(result_df.columns)

    def _numbered_columns(pattern) -> List[Tuple[int, str]]:
        # Collect (category number, column name) for columns matching pattern.
        found: List[Tuple[int, str]] = []
        for name in column_names:
            hit = pattern.match(name)
            if hit:
                found.append((int(hit.group(1)), name))
        return found

    # Ensemble columns take precedence (more specific suffix); the plain
    # single-model columns are only used when no consensus column exists.
    numbered = _numbered_columns(_CONSENSUS_COL_PAT) or _numbered_columns(_SINGLE_COL_PAT)
    if not numbered:
        raise RuntimeError(
            "classify() returned no category_N or category_N_consensus "
            "columns. The output schema may have changed; this version of "
            "classify_labels cannot map the result back to user-provided "
            "category names. Got columns: " + ", ".join(column_names)
        )
    numbered.sort(key=lambda pair: pair[0])

    # Shorten the category names once up front instead of once per row.
    names = [short_label(c) for c in categories] if short_labels else list(categories)
    name_count = len(names)

    labels_per_row: List[str] = []
    for _, record in result_df.iterrows():
        hits: List[str] = []
        for number, column in numbered:
            try:
                if int(record[column]) != 1:
                    continue
                position = number - 1  # column numbers are 1-based
                if 0 <= position < name_count:
                    hits.append(str(names[position]))
            except (ValueError, TypeError, KeyError):
                # Cells that cannot be read as an integer count as "not assigned".
                continue
        labels_per_row.append(multi_label_sep.join(hits))

    return (labels_per_row, result_df) if return_full else labels_per_row
|
|
@@ -18,7 +18,7 @@ Categories are never modified — only the system prompt changes.
|
|
|
18
18
|
from typing import Union
|
|
19
19
|
|
|
20
20
|
from ._category_analysis import has_other_category
|
|
21
|
-
from ._pilot_test import collect_corrections
|
|
21
|
+
from ._pilot_test import collect_corrections, compute_metrics
|
|
22
22
|
from .text_functions_ensemble import classify_ensemble
|
|
23
23
|
from ._providers import UnifiedLLMClient, detect_provider
|
|
24
24
|
|
|
@@ -290,16 +290,15 @@ def prompt_tune(
|
|
|
290
290
|
corrections = result["corrections"]
|
|
291
291
|
metrics = result["metrics"]
|
|
292
292
|
total_flips = result["total_flips"]
|
|
293
|
-
|
|
293
|
+
sample_indices = result["sample_indices"]
|
|
294
294
|
|
|
295
295
|
# Save ground truth from user corrections for auto-scoring later iterations
|
|
296
|
-
sample_indices = result["sample_indices"]
|
|
297
296
|
ground_truth = {
|
|
298
297
|
i: c["corrected"] for i, c in zip(sample_indices, corrections)
|
|
299
298
|
}
|
|
300
299
|
|
|
301
|
-
# Per-category metrics from baseline
|
|
302
300
|
per_cat = _compute_per_category_metrics(corrections, categories)
|
|
301
|
+
baseline_target = _target_fn(metrics)
|
|
303
302
|
|
|
304
303
|
# Print baseline summary
|
|
305
304
|
_print_classification_summary("Baseline", metrics, per_cat, categories, total_flips)
|
|
@@ -334,6 +333,8 @@ def prompt_tune(
|
|
|
334
333
|
print(f" Category {cat_idx}/{len(cats_with_errors)}: {target_cat} ({cat_errors} errors)")
|
|
335
334
|
print(f" Up to {max_iterations} iteration(s)")
|
|
336
335
|
|
|
336
|
+
attempt_history = []
|
|
337
|
+
prev_score = baseline_target
|
|
337
338
|
prev_instruction = cat_instructions.get(target_cat, "")
|
|
338
339
|
|
|
339
340
|
for attempt in range(1, max_iterations + 1):
|
|
@@ -353,6 +354,7 @@ def prompt_tune(
|
|
|
353
354
|
meta_source=meta_source,
|
|
354
355
|
meta_key=meta_key,
|
|
355
356
|
max_retries=max_retries,
|
|
357
|
+
attempt_history=attempt_history,
|
|
356
358
|
)
|
|
357
359
|
|
|
358
360
|
if not instruction:
|
|
@@ -379,7 +381,6 @@ def prompt_tune(
|
|
|
379
381
|
|
|
380
382
|
if result is None:
|
|
381
383
|
print("\n[CatLLM] Re-classification failed.")
|
|
382
|
-
# Revert this category
|
|
383
384
|
if prev_instruction:
|
|
384
385
|
cat_instructions[target_cat] = prev_instruction
|
|
385
386
|
else:
|
|
@@ -395,6 +396,21 @@ def prompt_tune(
|
|
|
395
396
|
|
|
396
397
|
new_cat_errors = per_cat[target_cat]["fp"] + per_cat[target_cat]["fn"]
|
|
397
398
|
|
|
399
|
+
# Classify outcome
|
|
400
|
+
if target_score > prev_score + 0.001:
|
|
401
|
+
outcome = "improved"
|
|
402
|
+
elif target_score < prev_score - 0.001:
|
|
403
|
+
outcome = "regressed"
|
|
404
|
+
else:
|
|
405
|
+
outcome = "no_change"
|
|
406
|
+
|
|
407
|
+
attempt_history.append({
|
|
408
|
+
"instruction": instruction,
|
|
409
|
+
"outcome": outcome,
|
|
410
|
+
"score_before": prev_score,
|
|
411
|
+
"score_after": target_score,
|
|
412
|
+
})
|
|
413
|
+
|
|
398
414
|
_print_classification_summary(
|
|
399
415
|
f"{target_cat} attempt {attempt}", metrics, per_cat, categories, total_flips,
|
|
400
416
|
)
|
|
@@ -412,21 +428,18 @@ def prompt_tune(
|
|
|
412
428
|
best_target = target_score
|
|
413
429
|
best_prompt = current_prompt
|
|
414
430
|
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
431
|
+
if outcome == "improved":
|
|
432
|
+
print(f" Improved: {target_cat} errors {cat_errors} → {new_cat_errors}")
|
|
433
|
+
prev_score = target_score
|
|
418
434
|
prev_instruction = instruction
|
|
419
435
|
cat_errors = new_cat_errors
|
|
420
436
|
if new_cat_errors == 0:
|
|
421
437
|
print(f" {target_cat}: all errors fixed!")
|
|
422
438
|
break
|
|
423
|
-
|
|
424
|
-
elif new_cat_errors == cat_errors:
|
|
439
|
+
elif outcome == "no_change":
|
|
425
440
|
print(f" No change for {target_cat} ({cat_errors} errors)")
|
|
426
|
-
# Instruction didn't help — try again with a different one
|
|
427
441
|
else:
|
|
428
|
-
print(f" Regressed: {target_cat} errors {cat_errors}
|
|
429
|
-
# Revert this attempt
|
|
442
|
+
print(f" Regressed: {target_cat} errors {cat_errors} → {new_cat_errors} — reverting")
|
|
430
443
|
if prev_instruction:
|
|
431
444
|
cat_instructions[target_cat] = prev_instruction
|
|
432
445
|
else:
|
|
@@ -654,6 +667,7 @@ def _generate_category_instruction(
|
|
|
654
667
|
meta_source,
|
|
655
668
|
meta_key,
|
|
656
669
|
max_retries,
|
|
670
|
+
attempt_history=None,
|
|
657
671
|
):
|
|
658
672
|
"""
|
|
659
673
|
Generate a targeted instruction for one category, given full error context.
|
|
@@ -735,6 +749,21 @@ def _generate_category_instruction(
|
|
|
735
749
|
# Current instruction
|
|
736
750
|
current_text = f'\nCURRENT INSTRUCTION FOR THIS CATEGORY:\n"{current_instruction}"\n' if current_instruction else ""
|
|
737
751
|
|
|
752
|
+
# History of previous attempts — capped at last 3 to avoid prompt bloat.
|
|
753
|
+
# Format is deliberately simple (no score numbers) so smaller models can follow it.
|
|
754
|
+
history_text = ""
|
|
755
|
+
if attempt_history:
|
|
756
|
+
recent = attempt_history[-3:]
|
|
757
|
+
history_lines = [
|
|
758
|
+
f' - "{h["instruction"]}" [{h["outcome"]}]'
|
|
759
|
+
for h in recent
|
|
760
|
+
]
|
|
761
|
+
history_text = (
|
|
762
|
+
"\nPREVIOUS INSTRUCTIONS TRIED FOR THIS CATEGORY (already tested — write something different):\n"
|
|
763
|
+
+ "\n".join(history_lines)
|
|
764
|
+
+ "\n"
|
|
765
|
+
)
|
|
766
|
+
|
|
738
767
|
optimize_guidance = {
|
|
739
768
|
"balanced": "",
|
|
740
769
|
"precision": " Focus especially on reducing false positives.",
|
|
@@ -753,7 +782,7 @@ ALL ERRORS ACROSS ALL CATEGORIES (<<< marks errors involving your target):
|
|
|
753
782
|
{all_error_lines and chr(10).join(all_error_lines) or "(no errors)"}
|
|
754
783
|
|
|
755
784
|
{target_section}
|
|
756
|
-
{current_text}
|
|
785
|
+
{current_text}{history_text}
|
|
757
786
|
Write a 1-2 sentence instruction for the category "{target_category}" that tells
|
|
758
787
|
a classifier when to assign and when NOT to assign it. Use the full error context
|
|
759
788
|
above to understand how this category relates to others, but only output guidance
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|