cat-stack 1.0.22__tar.gz → 1.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cat_stack-1.0.22 → cat_stack-1.1.1}/PKG-INFO +1 -1
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/__about__.py +1 -1
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/_category_analysis.py +15 -6
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/_providers.py +9 -4
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/classify.py +41 -5
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/text_functions.py +44 -22
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/text_functions_ensemble.py +56 -8
- {cat_stack-1.0.22 → cat_stack-1.1.1}/.gitignore +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/LICENSE +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/README.md +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/pyproject.toml +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/cat_stack/__init__.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/__init__.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/_batch.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/_chunked.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/_embeddings.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/_formatter.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/_pilot_test.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/_prompts.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/_review_ui.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/_tiebreaker.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/_utils.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/_web_fetch.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/calls/CoVe.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/calls/__init__.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/calls/all_calls.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/calls/image_CoVe.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/calls/image_stepback.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/calls/pdf_CoVe.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/calls/pdf_stepback.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/calls/stepback.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/calls/top_n.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/explore.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/extract.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/image_functions.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/images/circle.png +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/images/cube.png +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/images/diamond.png +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/images/overlapping_pentagons.png +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/images/rectangles.png +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/model_reference_list.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/pdf_functions.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/prompt_tune.py +0 -0
- {cat_stack-1.0.22 → cat_stack-1.1.1}/src/catstack/summarize.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cat-stack
|
|
3
|
-
Version: 1.0.22
|
|
3
|
+
Version: 1.1.1
|
|
4
4
|
Summary: Domain-agnostic text, image, PDF, and DOCX classification engine powered by LLMs
|
|
5
5
|
Project-URL: Documentation, https://github.com/chrissoria/cat-stack#readme
|
|
6
6
|
Project-URL: Issues, https://github.com/chrissoria/cat-stack/issues
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
|
|
2
2
|
#
|
|
3
3
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
|
4
|
-
__version__ = "1.0.22"
|
|
4
|
+
__version__ = "1.1.1"
|
|
5
5
|
__author__ = "Chris Soria"
|
|
6
6
|
__email__ = "chrissoria@berkeley.edu"
|
|
7
7
|
__title__ = "cat-stack"
|
|
@@ -19,13 +19,22 @@ __all__ = ["has_other_category", "check_category_verbosity"]
|
|
|
19
19
|
_MAX_HEURISTIC_WORDS = 4
|
|
20
20
|
|
|
21
21
|
# Tier 1: Anchored patterns — safe at any category length.
|
|
22
|
-
# These only match when the keyword IS the category label itself.
|
|
22
|
+
# These only match when the keyword IS the category label itself (or its prefix).
|
|
23
23
|
_ANCHORED_PATTERNS = [
|
|
24
|
-
re.compile(r"^other\s*$", re.IGNORECASE),
|
|
25
|
-
re.compile(r"^other\s*[:(]", re.IGNORECASE),
|
|
26
|
-
re.compile(r"^n/?a\s*$", re.IGNORECASE),
|
|
27
|
-
re.compile(r"^miscellaneous\s*$", re.IGNORECASE),
|
|
28
|
-
re.compile(r"^catch[\s-]?all\s*$", re.IGNORECASE),
|
|
24
|
+
re.compile(r"^other\s*$", re.IGNORECASE), # exact "Other"
|
|
25
|
+
re.compile(r"^other\s*[:(]", re.IGNORECASE), # "Other: ...", "Other (..."
|
|
26
|
+
re.compile(r"^n/?a\s*$", re.IGNORECASE), # exact "N/A", "NA"
|
|
27
|
+
re.compile(r"^miscellaneous\s*$", re.IGNORECASE), # exact "Miscellaneous"
|
|
28
|
+
re.compile(r"^catch[\s-]?all\s*$", re.IGNORECASE), # exact "catch-all"
|
|
29
|
+
# Sentiment/opinion contexts — "Neutral" serves as the ambivalent catch-all
|
|
30
|
+
re.compile(r"^neutral\s*$", re.IGNORECASE), # exact "Neutral"
|
|
31
|
+
re.compile(r"^neutral\s*[:(]", re.IGNORECASE), # "Neutral: ...", "Neutral (..."
|
|
32
|
+
# Other common ambivalent labels
|
|
33
|
+
re.compile(r"^ambiguous\s*$", re.IGNORECASE), # exact "Ambiguous"
|
|
34
|
+
re.compile(r"^ambiguous\s*[:(]", re.IGNORECASE), # "Ambiguous: ..."
|
|
35
|
+
re.compile(r"^unclear\s*$", re.IGNORECASE), # exact "Unclear"
|
|
36
|
+
re.compile(r"^unclassified\s*$", re.IGNORECASE), # exact "Unclassified"
|
|
37
|
+
re.compile(r"^unknown\s*$", re.IGNORECASE), # exact "Unknown"
|
|
29
38
|
]
|
|
30
39
|
|
|
31
40
|
# Tier 2: Phrase patterns — only applied to short categories (≤ _MAX_HEURISTIC_WORDS).
|
|
@@ -772,12 +772,17 @@ def check_ollama_model(model: str, host: str = "localhost", port: int = 11434) -
|
|
|
772
772
|
True if model is available, False otherwise
|
|
773
773
|
"""
|
|
774
774
|
available_models = list_ollama_models(host, port)
|
|
775
|
-
# Check for exact match or partial match (e.g., "llama3.2" matches "llama3.2:latest")
|
|
776
775
|
model_lower = model.lower()
|
|
776
|
+
if ":" in model_lower:
|
|
777
|
+
# User specified an explicit tag (e.g. "qwen2.5:14b") — require exact
|
|
778
|
+
# match. An installed "qwen2.5:7b" must NOT satisfy a request for
|
|
779
|
+
# "qwen2.5:14b"; the previous prefix-match logic let this through,
|
|
780
|
+
# which caused silent classification failures downstream.
|
|
781
|
+
return any(m.lower() == model_lower for m in available_models)
|
|
782
|
+
# User specified just the family (e.g. "llama3.2") — any installed
|
|
783
|
+
# variant of that family counts (e.g. "llama3.2:latest", "llama3.2:7b").
|
|
777
784
|
return any(
|
|
778
|
-
|
|
779
|
-
m.lower().startswith(f"{model_lower}:") or
|
|
780
|
-
model_lower.startswith(m.lower().split(":")[0])
|
|
785
|
+
m.lower() == model_lower or m.lower().startswith(f"{model_lower}:")
|
|
781
786
|
for m in available_models
|
|
782
787
|
)
|
|
783
788
|
|
|
@@ -92,6 +92,7 @@ def classify(
|
|
|
92
92
|
add_other = "prompt",
|
|
93
93
|
check_verbosity: bool = True,
|
|
94
94
|
json_formatter: Optional[bool] = None,
|
|
95
|
+
two_step_classify: Optional[bool] = None,
|
|
95
96
|
embeddings: bool = False,
|
|
96
97
|
category_descriptions: dict = None,
|
|
97
98
|
embedding_tiebreaker: bool = False,
|
|
@@ -133,7 +134,11 @@ def classify(
|
|
|
133
134
|
- "image" (default): Render pages as images
|
|
134
135
|
- "text": Extract text only
|
|
135
136
|
- "both": Send both image and extracted text
|
|
136
|
-
creativity (float): Temperature setting. None uses model default
|
|
137
|
+
creativity (float): Temperature setting. None uses model default,
|
|
138
|
+
except for Ollama where it defaults to 0.0 (classification is not
|
|
139
|
+
creative generation; deterministic output reproduces across runs
|
|
140
|
+
and avoids high-entropy junk that throws off small local models).
|
|
141
|
+
Pass an explicit value to override.
|
|
137
142
|
safety (bool): If True, saves progress after each item.
|
|
138
143
|
chain_of_verification (bool): Enable Chain of Verification for accuracy.
|
|
139
144
|
chain_of_thought (bool): Enable step-by-step reasoning. Default False.
|
|
@@ -202,6 +207,19 @@ def classify(
|
|
|
202
207
|
produces invalid output — zero cost on the happy path. On first
|
|
203
208
|
use, the model (~1GB) is downloaded from HuggingFace Hub.
|
|
204
209
|
Requires: pip install cat-llm[formatter]. Default False.
|
|
210
|
+
Auto-enabled when two_step_classify is True (or when any model in
|
|
211
|
+
`models` uses the Ollama provider).
|
|
212
|
+
two_step_classify (bool): Split classification into two LLM calls:
|
|
213
|
+
(1) natural-language reasoning, then (2) JSON formatting. More
|
|
214
|
+
reliable for weaker models — local Ollama models, but also lower-
|
|
215
|
+
tier API models (gpt-4o-mini, claude-haiku, gemini-flash) that
|
|
216
|
+
struggle to produce strict per-category JSON in a single shot.
|
|
217
|
+
When enabled, the raw step-1 reasoning is routed through the
|
|
218
|
+
fine-tuned JSON formatter (json_formatter is auto-enabled).
|
|
219
|
+
Default None: auto-enable for Ollama models, disable otherwise.
|
|
220
|
+
Set True to force it on any provider; False to disable for Ollama.
|
|
221
|
+
Per-model override is also supported via the 4-tuple options dict:
|
|
222
|
+
("gpt-4o-mini", "openai", key, {"two_step_classify": True})
|
|
205
223
|
embeddings (bool): If True, add embedding-based similarity scores
|
|
206
224
|
alongside binary 0/1 classifications. Uses a local sentence-
|
|
207
225
|
transformer model (BAAI/bge-small-en-v1.5, ~130MB) to compute
|
|
@@ -552,14 +570,31 @@ def classify(
|
|
|
552
570
|
return True
|
|
553
571
|
return False
|
|
554
572
|
|
|
573
|
+
# Local Ollama models benefit enormously from temperature=0 on classification:
|
|
574
|
+
# in benchmarks, qwen2.5:7b accuracy jumped from 78% to 85% and produced
|
|
575
|
+
# bit-identical labels across runs (no more "{Negative: '.$/1234567890...'}"
|
|
576
|
+
# high-entropy junk). Classification is not creative generation; the user
|
|
577
|
+
# can still override by passing creativity= explicitly.
|
|
578
|
+
if creativity is None and _uses_ollama_provider():
|
|
579
|
+
creativity = 0.0
|
|
580
|
+
|
|
555
581
|
if json_formatter is None:
|
|
556
|
-
|
|
557
|
-
|
|
582
|
+
if two_step_classify is True:
|
|
583
|
+
json_formatter = True
|
|
558
584
|
print(
|
|
559
|
-
"\n[CatLLM]
|
|
560
|
-
" (
|
|
585
|
+
"\n[CatLLM] two_step_classify=True — auto-enabling JSON formatter\n"
|
|
586
|
+
" (the formatter receives the step-1 reasoning text and is what\n"
|
|
587
|
+
" makes the two-step path actually more accurate than one-shot).\n"
|
|
561
588
|
" Pass json_formatter=False to opt out."
|
|
562
589
|
)
|
|
590
|
+
else:
|
|
591
|
+
json_formatter = _uses_ollama_provider()
|
|
592
|
+
if json_formatter:
|
|
593
|
+
print(
|
|
594
|
+
"\n[CatLLM] Ollama detected — auto-enabling JSON formatter fallback\n"
|
|
595
|
+
" (small local models more often emit malformed JSON).\n"
|
|
596
|
+
" Pass json_formatter=False to opt out."
|
|
597
|
+
)
|
|
563
598
|
|
|
564
599
|
# The formatter MODEL is loaded lazily on the first parse failure (saves
|
|
565
600
|
# ~1 GB RAM + load time when no rows actually need rescuing). The dep
|
|
@@ -832,6 +867,7 @@ def classify(
|
|
|
832
867
|
save_directory=save_directory,
|
|
833
868
|
progress_callback=progress_callback,
|
|
834
869
|
formatter_state=_formatter_state,
|
|
870
|
+
two_step_classify=two_step_classify,
|
|
835
871
|
multi_label=multi_label,
|
|
836
872
|
categories_per_call=categories_per_call,
|
|
837
873
|
embedding_tiebreaker_state=_embedding_tiebreaker_state,
|
|
@@ -219,7 +219,7 @@ def ollama_two_step_classify(
|
|
|
219
219
|
survey_question: str = "",
|
|
220
220
|
creativity: float = None,
|
|
221
221
|
max_retries: int = 5,
|
|
222
|
-
) -> tuple[str, str | None]:
|
|
222
|
+
) -> tuple[str, str, str | None]:
|
|
223
223
|
"""
|
|
224
224
|
Two-step classification for Ollama models.
|
|
225
225
|
|
|
@@ -239,35 +239,42 @@ def ollama_two_step_classify(
|
|
|
239
239
|
max_retries: Number of retry attempts for JSON validation
|
|
240
240
|
|
|
241
241
|
Returns:
|
|
242
|
-
tuple: (json_string, error_message or None)
|
|
242
|
+
tuple: (json_string, step1_raw_reply, error_message or None)
|
|
243
|
+
step1_raw_reply is the unformatted step-1 output; callers can
|
|
244
|
+
pass it to the fine-tuned formatter even when step-2 produced
|
|
245
|
+
syntactically valid (but semantically empty) JSON.
|
|
243
246
|
"""
|
|
244
247
|
num_categories = len(categories)
|
|
245
248
|
survey_context = f"Context: {survey_question}." if survey_question else ""
|
|
246
249
|
|
|
247
250
|
# ==========================================================================
|
|
248
|
-
# Step 1: Classification (
|
|
251
|
+
# Step 1: Classification (simple list of applicable categories)
|
|
249
252
|
# ==========================================================================
|
|
253
|
+
# Weak models (local Ollama, lower-tier API models) can't reliably produce
|
|
254
|
+
# per-category YES/NO output OR strict JSON in one shot. Ask for the
|
|
255
|
+
# simplest possible output — just the names of the applicable categories,
|
|
256
|
+
# one per line — and let the fine-tuned formatter (or step 2) slot those
|
|
257
|
+
# names into the indexed JSON schema.
|
|
250
258
|
step1_messages = [
|
|
251
259
|
{
|
|
252
260
|
"role": "system",
|
|
253
|
-
"content": "You are an expert at categorizing text
|
|
261
|
+
"content": "You are an expert at categorizing text. You read a response and pick the categories that apply."
|
|
254
262
|
},
|
|
255
263
|
{
|
|
256
264
|
"role": "user",
|
|
257
265
|
"content": f"""{survey_context}
|
|
258
266
|
|
|
259
|
-
|
|
267
|
+
Read this text response:
|
|
260
268
|
|
|
261
|
-
|
|
269
|
+
"{response_text}"
|
|
270
|
+
|
|
271
|
+
Decide which of these categories apply to the response:
|
|
262
272
|
|
|
263
|
-
Categories:
|
|
264
273
|
{categories_str}
|
|
265
274
|
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
2. [Category name]: YES/NO - [brief reason]
|
|
270
|
-
...and so on for all categories."""
|
|
275
|
+
Output ONLY the names of the categories that apply, one per line.
|
|
276
|
+
Write nothing else — no numbering, no reasoning, no JSON, no markdown.
|
|
277
|
+
If none apply, write the single word: None"""
|
|
271
278
|
}
|
|
272
279
|
]
|
|
273
280
|
|
|
@@ -276,33 +283,48 @@ Format your answer as:
|
|
|
276
283
|
json_schema=None, # No JSON requirement for step 1
|
|
277
284
|
creativity=creativity,
|
|
278
285
|
)
|
|
286
|
+
# Preserve the original step-1 text; the retry loop below overwrites
|
|
287
|
+
# step1_reply with error context, but callers need the raw output so the
|
|
288
|
+
# fine-tuned formatter can extract the true classification signal from it
|
|
289
|
+
# even when step-2 later produces valid-but-all-zero JSON.
|
|
290
|
+
original_step1_reply = step1_reply
|
|
279
291
|
|
|
280
292
|
if step1_error:
|
|
281
|
-
return '{"1":"e"}', f"Step 1 failed: {step1_error}"
|
|
293
|
+
return '{"1":"e"}', "", f"Step 1 failed: {step1_error}"
|
|
282
294
|
|
|
283
295
|
# ==========================================================================
|
|
284
296
|
# Step 2: JSON Formatting with validation and retry
|
|
285
297
|
# ==========================================================================
|
|
286
298
|
example_json = json.dumps({str(i): "0" for i in range(1, num_categories + 1)})
|
|
287
299
|
|
|
300
|
+
# Numbered category list for step 2 — the formatter needs to map each
|
|
301
|
+
# name in step1_reply back to its position in the original list.
|
|
302
|
+
numbered_categories = "\n".join(
|
|
303
|
+
f"{i + 1}. {c}" for i, c in enumerate(categories)
|
|
304
|
+
)
|
|
305
|
+
|
|
288
306
|
for attempt in range(max_retries):
|
|
289
307
|
step2_messages = [
|
|
290
308
|
{
|
|
291
309
|
"role": "system",
|
|
292
|
-
"content": "You convert
|
|
310
|
+
"content": "You convert a list of category names to a JSON object marking which categories were selected. Output ONLY valid JSON, nothing else."
|
|
293
311
|
},
|
|
294
312
|
{
|
|
295
313
|
"role": "user",
|
|
296
|
-
"content": f"""
|
|
314
|
+
"content": f"""Categories (numbered 1 to {num_categories}):
|
|
315
|
+
{numbered_categories}
|
|
297
316
|
|
|
298
|
-
|
|
317
|
+
Selected categories (the names that were chosen — may be a subset, all, or none):
|
|
299
318
|
{step1_reply}
|
|
300
319
|
|
|
320
|
+
Output a JSON object where each key is a category number ("1" through "{num_categories}")
|
|
321
|
+
and each value is "1" if that category appears in the selected list, "0" if not.
|
|
322
|
+
|
|
301
323
|
Rules:
|
|
302
324
|
- Output ONLY a JSON object, no other text
|
|
303
|
-
-
|
|
304
|
-
-
|
|
305
|
-
-
|
|
325
|
+
- Include ALL {num_categories} categories as keys
|
|
326
|
+
- Match by category name (allow partial / case-insensitive matches)
|
|
327
|
+
- If the selected list says "None" or is empty, all values are "0"
|
|
306
328
|
|
|
307
329
|
Example format:
|
|
308
330
|
{example_json}
|
|
@@ -320,14 +342,14 @@ Your JSON output:"""
|
|
|
320
342
|
if step2_error:
|
|
321
343
|
if attempt < max_retries - 1:
|
|
322
344
|
continue
|
|
323
|
-
return '{"1":"e"}', f"Step 2 failed: {step2_error}"
|
|
345
|
+
return '{"1":"e"}', original_step1_reply, f"Step 2 failed: {step2_error}"
|
|
324
346
|
|
|
325
347
|
# Extract and validate JSON
|
|
326
348
|
extracted = extract_json(step2_reply)
|
|
327
349
|
is_valid, normalized = validate_classification_json(extracted, num_categories)
|
|
328
350
|
|
|
329
351
|
if is_valid:
|
|
330
|
-
return json.dumps(normalized), None
|
|
352
|
+
return json.dumps(normalized), original_step1_reply, None
|
|
331
353
|
|
|
332
354
|
# If invalid, try again with more explicit instructions
|
|
333
355
|
if attempt < max_retries - 1:
|
|
@@ -340,7 +362,7 @@ Please be more careful to output EXACTLY {num_categories} categories numbered 1
|
|
|
340
362
|
|
|
341
363
|
# All retries exhausted - try to salvage what we can
|
|
342
364
|
extracted = extract_json(step2_reply) if step2_reply else '{"1":"e"}'
|
|
343
|
-
return extracted, f"JSON validation failed after {max_retries} attempts"
|
|
365
|
+
return extracted, original_step1_reply, f"JSON validation failed after {max_retries} attempts"
|
|
344
366
|
|
|
345
367
|
|
|
346
368
|
# =============================================================================
|
|
@@ -575,7 +575,11 @@ def _format_creativity_suffix(creativity) -> str:
|
|
|
575
575
|
return f"_t{int(round(creativity * 100))}"
|
|
576
576
|
|
|
577
577
|
|
|
578
|
-
def prepare_model_configs(models: list, auto_download: bool = False) -> list:
|
|
578
|
+
def prepare_model_configs(
|
|
579
|
+
models: list,
|
|
580
|
+
auto_download: bool = False,
|
|
581
|
+
two_step_classify: Optional[bool] = None,
|
|
582
|
+
) -> list:
|
|
579
583
|
"""
|
|
580
584
|
Validate and prepare model configurations.
|
|
581
585
|
|
|
@@ -583,8 +587,14 @@ def prepare_model_configs(models: list, auto_download: bool = False) -> list:
|
|
|
583
587
|
models: List of tuples. Each tuple can be:
|
|
584
588
|
- (model, provider, api_key) — 3 elements
|
|
585
589
|
- (model, provider, api_key, options) — 4 elements, where options is a
|
|
586
|
-
dict with per-model overrides (e.g. {"creativity": 0.5
|
|
590
|
+
dict with per-model overrides (e.g. {"creativity": 0.5,
|
|
591
|
+
"two_step_classify": True})
|
|
587
592
|
auto_download: If True, automatically download missing Ollama models
|
|
593
|
+
two_step_classify: Global override for the two-step classify mode.
|
|
594
|
+
None (default) → auto-enable for Ollama, off for everything else.
|
|
595
|
+
True → enable for all models (useful for weaker API models that
|
|
596
|
+
also struggle with strict JSON). False → never use it.
|
|
597
|
+
Per-model overrides via the options dict take precedence.
|
|
588
598
|
|
|
589
599
|
Returns:
|
|
590
600
|
List of config dicts with validated settings
|
|
@@ -690,6 +700,18 @@ def prepare_model_configs(models: list, auto_download: bool = False) -> list:
|
|
|
690
700
|
# Per-model creativity override (None means use global)
|
|
691
701
|
per_model_creativity = options.get("creativity", None) if options else None
|
|
692
702
|
|
|
703
|
+
# Resolve two-step setting. Precedence:
|
|
704
|
+
# 1. per-model option override (options["two_step_classify"])
|
|
705
|
+
# 2. global parameter override (two_step_classify=)
|
|
706
|
+
# 3. auto-detect: True iff provider is Ollama
|
|
707
|
+
per_model_two_step = options.get("two_step_classify", None) if options else None
|
|
708
|
+
if per_model_two_step is not None:
|
|
709
|
+
effective_two_step = bool(per_model_two_step)
|
|
710
|
+
elif two_step_classify is not None:
|
|
711
|
+
effective_two_step = bool(two_step_classify)
|
|
712
|
+
else:
|
|
713
|
+
effective_two_step = (detected_provider == "ollama")
|
|
714
|
+
|
|
693
715
|
# Build sanitized column name
|
|
694
716
|
base_name = sanitize_model_name(model)
|
|
695
717
|
if is_ensemble:
|
|
@@ -699,7 +721,7 @@ def prepare_model_configs(models: list, auto_download: bool = False) -> list:
|
|
|
699
721
|
"model": model,
|
|
700
722
|
"provider": detected_provider,
|
|
701
723
|
"api_key": api_key,
|
|
702
|
-
"use_two_step":
|
|
724
|
+
"use_two_step": effective_two_step,
|
|
703
725
|
"sanitized_name": base_name,
|
|
704
726
|
"creativity": per_model_creativity,
|
|
705
727
|
})
|
|
@@ -2274,6 +2296,8 @@ def classify_ensemble(
|
|
|
2274
2296
|
auto_download: bool = False,
|
|
2275
2297
|
# JSON formatter fallback
|
|
2276
2298
|
formatter_state: dict = None,
|
|
2299
|
+
# Two-step classify (text-first then format). None = auto-detect for Ollama.
|
|
2300
|
+
two_step_classify: Optional[bool] = None,
|
|
2277
2301
|
# Label mode
|
|
2278
2302
|
multi_label: bool = True,
|
|
2279
2303
|
# Chunked classification
|
|
@@ -2483,7 +2507,11 @@ def classify_ensemble(
|
|
|
2483
2507
|
|
|
2484
2508
|
# Prepare model configurations
|
|
2485
2509
|
print(f"Validating {len(models)} model configuration(s)...")
|
|
2486
|
-
model_configs = prepare_model_configs(
|
|
2510
|
+
model_configs = prepare_model_configs(
|
|
2511
|
+
models,
|
|
2512
|
+
auto_download=auto_download,
|
|
2513
|
+
two_step_classify=two_step_classify,
|
|
2514
|
+
)
|
|
2487
2515
|
|
|
2488
2516
|
# Print model info
|
|
2489
2517
|
print(f"\nModels to use:")
|
|
@@ -2934,8 +2962,8 @@ Categorize text responses {cove_categorize}:
|
|
|
2934
2962
|
else:
|
|
2935
2963
|
response_text = item
|
|
2936
2964
|
|
|
2937
|
-
if cfg["use_two_step"]: # Ollama
|
|
2938
|
-
json_result, error = ollama_two_step_classify(
|
|
2965
|
+
if cfg["use_two_step"]: # Ollama (or two_step_classify=True)
|
|
2966
|
+
json_result, step1_raw, error = ollama_two_step_classify(
|
|
2939
2967
|
client=client,
|
|
2940
2968
|
response_text=response_text,
|
|
2941
2969
|
categories=categories,
|
|
@@ -2944,8 +2972,28 @@ Categorize text responses {cove_categorize}:
|
|
|
2944
2972
|
creativity=effective_creativity,
|
|
2945
2973
|
max_retries=max_retries,
|
|
2946
2974
|
)
|
|
2947
|
-
|
|
2948
|
-
|
|
2975
|
+
# Normal path: step 2 (qwen as formatter) usually maps the
|
|
2976
|
+
# step-1 list correctly. Only fall back to the fine-tuned
|
|
2977
|
+
# formatter when step 2 returned all-zeros AND step 1 said
|
|
2978
|
+
# something non-empty — that combination signals step 2
|
|
2979
|
+
# silently lost the classification signal (the original bug).
|
|
2980
|
+
# Overriding a confident, non-zero step-2 result with the
|
|
2981
|
+
# formatter's interpretation of messy step-1 text loses
|
|
2982
|
+
# accuracy in the common case.
|
|
2983
|
+
def _is_all_zero(js):
|
|
2984
|
+
try:
|
|
2985
|
+
d = json.loads(js)
|
|
2986
|
+
return all(str(v) == "0" for v in d.values())
|
|
2987
|
+
except Exception:
|
|
2988
|
+
return False
|
|
2989
|
+
|
|
2990
|
+
step1_meaningful = step1_raw and step1_raw.strip().lower() not in ("", "none")
|
|
2991
|
+
if step1_meaningful and _is_all_zero(json_result):
|
|
2992
|
+
fmt_result = _try_formatter_fallback('{"1":"e"}', step1_raw)
|
|
2993
|
+
if fmt_result != '{"1":"e"}':
|
|
2994
|
+
json_result = fmt_result
|
|
2995
|
+
elif error or not json_result:
|
|
2996
|
+
json_result = _try_formatter_fallback(json_result or '{"1":"e"}', step1_raw or "")
|
|
2949
2997
|
# CoVe not supported for Ollama two-step (already has verification)
|
|
2950
2998
|
else:
|
|
2951
2999
|
messages = build_text_classification_prompt(
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|