cat-stack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cat_stack/__about__.py +10 -0
- cat_stack/__init__.py +128 -0
- cat_stack/_batch.py +1388 -0
- cat_stack/_category_analysis.py +348 -0
- cat_stack/_chunked.py +424 -0
- cat_stack/_embeddings.py +189 -0
- cat_stack/_formatter.py +169 -0
- cat_stack/_providers.py +1048 -0
- cat_stack/_tiebreaker.py +277 -0
- cat_stack/_utils.py +512 -0
- cat_stack/_web_fetch.py +194 -0
- cat_stack/calls/CoVe.py +287 -0
- cat_stack/calls/__init__.py +25 -0
- cat_stack/calls/all_calls.py +622 -0
- cat_stack/calls/image_CoVe.py +386 -0
- cat_stack/calls/image_stepback.py +210 -0
- cat_stack/calls/pdf_CoVe.py +386 -0
- cat_stack/calls/pdf_stepback.py +210 -0
- cat_stack/calls/stepback.py +180 -0
- cat_stack/calls/top_n.py +217 -0
- cat_stack/classify.py +682 -0
- cat_stack/explore.py +111 -0
- cat_stack/extract.py +218 -0
- cat_stack/image_functions.py +2078 -0
- cat_stack/images/circle.png +0 -0
- cat_stack/images/cube.png +0 -0
- cat_stack/images/diamond.png +0 -0
- cat_stack/images/overlapping_pentagons.png +0 -0
- cat_stack/images/rectangles.png +0 -0
- cat_stack/model_reference_list.py +94 -0
- cat_stack/pdf_functions.py +2087 -0
- cat_stack/summarize.py +290 -0
- cat_stack/text_functions.py +1358 -0
- cat_stack/text_functions_ensemble.py +3644 -0
- cat_stack-0.1.0.dist-info/METADATA +150 -0
- cat_stack-0.1.0.dist-info/RECORD +38 -0
- cat_stack-0.1.0.dist-info/WHEEL +4 -0
- cat_stack-0.1.0.dist-info/licenses/LICENSE +672 -0
cat_stack/_formatter.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
"""
|
|
2
|
+
JSON formatter fallback for CatLLM.
|
|
3
|
+
|
|
4
|
+
Uses a fine-tuned Qwen2.5-0.5B model to convert messy LLM classification
|
|
5
|
+
output into valid cat-llm JSON format: {"1":"0","2":"1",...}
|
|
6
|
+
|
|
7
|
+
The formatter is opt-in via json_formatter=True on classify(). It only runs
|
|
8
|
+
when extract_json() produces invalid output — zero cost on the happy path.
|
|
9
|
+
|
|
10
|
+
Requires: pip install cat-llm[formatter]
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import sys
|
|
14
|
+
|
|
15
|
+
# HuggingFace Hub repository that hosts the merged formatter model and its
# tokenizer; used both for the cache probe and for from_pretrained().
_MERGED_MODEL_REPO = "chrissoria/catllm-json-formatter"

# System prompt sent ahead of every formatter request.
# NOTE(review): the fine-tuned model was presumably trained against this exact
# wording, so treat any edit to the text as a behavior change.
_SYSTEM_PROMPT = "".join(
    (
        "You are a JSON formatter for a text classification pipeline. ",
        "You will receive a list of categories (numbered 1 to N) and a raw ",
        "classification output from another model. Your job is to convert that ",
        "output into the exact JSON format required:\n",
        '{"1":"0","2":"1","3":"0",...}\n\n',
        "Rules:\n",
        '- Keys are 1-indexed strings: "1", "2", ..., "N"\n',
        '- Values are ONLY "0" (category absent) or "1" (category present)\n',
        "- Include ALL N categories, even if absent\n",
        "- Output ONLY the JSON object — no explanation, no markdown, no extra text\n",
        "- If a category's presence is ambiguous, default to \"0\"",
    )
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _check_dependencies():
    """
    Check that torch and transformers are installed.

    Raises:
        ImportError: With installation instructions if either package cannot
            be imported. The original import error is chained as
            ``__cause__`` and its module name is kept in ``.name``, so the
            traceback shows which dependency is actually missing.
    """
    try:
        import torch  # noqa: F401
        import transformers  # noqa: F401
    except ImportError as err:
        # NOTE(review): this wheel is published as cat-stack; confirm the
        # install hint below (cat-llm[formatter]) points at the right package.
        raise ImportError(
            "The JSON formatter requires additional dependencies.\n"
            "Install them with: pip install cat-llm[formatter]\n"
            " (requires: torch, transformers, accelerate)",
            name=err.name,
        ) from err
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _is_model_cached() -> bool:
    """
    Check if the merged model is already in the HuggingFace cache.

    Only ``config.json`` is probed. A partially downloaded snapshot (config
    present, weights missing) therefore still counts as cached.

    Returns:
        True only when ``try_to_load_from_cache`` returns a local file path.
        False when the file is not cached, when its *absence* has been cached,
        or when the lookup fails for any reason (this is a best-effort check).
    """
    try:
        from huggingface_hub import try_to_load_from_cache

        result = try_to_load_from_cache(_MERGED_MODEL_REPO, "config.json")
    except Exception:
        # Best-effort: any failure (hub library missing, bad cache dir, ...)
        # just means "not known to be cached", so the caller shows its prompt.
        return False

    # try_to_load_from_cache returns a str path on a hit and None on a miss.
    # It returns the non-None sentinel _CACHED_NO_EXIST when the file is known
    # not to exist. The previous `is not None` test wrongly treated that
    # sentinel as a cache hit.
    return isinstance(result, str)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def ensure_formatter_available() -> bool:
    """
    Make sure the JSON formatter can run, asking the user before a first download.

    Verifies the optional dependencies, then checks the local HuggingFace
    cache. If the model is not cached, a notice is printed and the user is
    asked on stdin whether to continue. This function does not download
    anything itself.

    Returns:
        True when the model is already cached or the user accepted (an empty
        answer, "y" or "yes", case-insensitive). False when the user declined
        or no answer could be read (EOF / Ctrl-C).

    Raises:
        ImportError: If torch or transformers are not installed.
    """
    _check_dependencies()

    if _is_model_cached():
        return True

    notice = (
        "\n[CatLLM] The JSON formatter model (~1GB) will be downloaded from\n"
        f" HuggingFace Hub ({_MERGED_MODEL_REPO}).\n"
        " This is a one-time download — the model is cached locally after."
    )
    print(notice)

    try:
        reply = input(" Continue? (Y/n): ")
    except (EOFError, KeyboardInterrupt):
        # No usable stdin, or the user aborted: treat it as a refusal.
        reply = "n"

    accepted = reply.strip().lower() in {"", "y", "yes"}
    if not accepted:
        print(" -> JSON formatter disabled for this run.\n")
    return accepted
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def load_formatter(device=None):
    """
    Load the merged formatter model and tokenizer.

    Args:
        device: Target device, either a string such as "cpu", "cuda" or
            "cuda:1", or a ``torch.device``. None = auto-detect
            (CUDA > CPU; MPS skipped).

    Returns:
        Tuple of (model, tokenizer, device). ``device`` is the auto-detected
        string when None was passed; otherwise it is the value passed in.

    Raises:
        ImportError: If torch or transformers are not installed.
    """
    _check_dependencies()

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    if device is None:
        # MPS is deliberately not auto-selected — known PEFT/generation crash issues.
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Use fp16 only on CUDA. Compare the device *type* so that indexed devices
    # ("cuda:1") and torch.device objects also get half precision. A plain
    # `device == "cuda"` check silently loaded those in float32.
    on_cuda = torch.device(device).type == "cuda"
    dtype = torch.float16 if on_cuda else torch.float32

    print(f"[CatLLM] Loading JSON formatter on {device}...")
    # NOTE(review): trust_remote_code=True executes Python code shipped in the
    # Hub repo. Qwen2.5 is supported natively by transformers, so this is
    # presumably unnecessary. Consider dropping it once the repo is confirmed
    # to contain no custom modeling code.
    tokenizer = AutoTokenizer.from_pretrained(
        _MERGED_MODEL_REPO, trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        # Fall back to EOS when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): `dtype=` is the newer spelling of `torch_dtype` in
    # transformers. Confirm that the minimum supported version honours it.
    model = AutoModelForCausalLM.from_pretrained(
        _MERGED_MODEL_REPO, dtype=dtype, trust_remote_code=True
    )
    model = model.to(device)
    model.eval()  # disable training-only behaviour such as dropout

    print("[CatLLM] JSON formatter ready.")
    return model, tokenizer, device
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def run_formatter(raw_output, categories, model, tokenizer, device, *, max_new_tokens=None):
    """
    Run the formatter model to fix malformed classification JSON.

    Args:
        raw_output: The raw (messy) output from the classification LLM.
        categories: Category names (any iterable; it is read once into a list).
        model: The loaded formatter model.
        tokenizer: The loaded tokenizer.
        device: Device the model lives on ("cuda", "cpu", ...).
        max_new_tokens: Generation budget. None (default) scales it with the
            number of categories and never goes below 128, so long category
            lists are not cut off mid-JSON.

    Returns:
        The formatter's output string (caller should run extract_json on it).
    """
    import torch

    # Materialise once: the list is both enumerated and measured below.
    categories = list(categories)

    # Numbered category list; 1-based numbering matches the JSON keys.
    cat_lines = "\n".join(
        f"{i}. {cat}" for i, cat in enumerate(categories, start=1)
    )
    user_msg = f"Categories:\n{cat_lines}\n\nRaw classification output:\n{raw_output}"

    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": user_msg},
    ]

    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The rendered template already contains the special tokens it needs.
    # Letting the tokenizer add its own (e.g. a BOS) again could duplicate them.
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(device)

    if max_new_tokens is None:
        # Each entry looks like `"12":"0",`. Budget a generous ~10 tokens per
        # category plus slack for the braces and EOS. The old fixed 128 could
        # truncate the JSON once the list reached a few dozen categories.
        max_new_tokens = max(128, 10 * len(categories) + 16)

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            # Explicit None presumably silences warnings about sampling
            # parameters under greedy decoding.
            temperature=None,
            top_p=None,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only newly generated tokens (drop the echoed prompt).
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|