cat-stack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,169 @@
1
+ """
2
+ JSON formatter fallback for CatLLM.
3
+
4
+ Uses a fine-tuned Qwen2.5-0.5B model to convert messy LLM classification
5
+ output into valid cat-llm JSON format: {"1":"0","2":"1",...}
6
+
7
+ The formatter is opt-in via json_formatter=True on classify(). It only runs
8
+ when extract_json() produces invalid output — zero cost on the happy path.
9
+
10
+ Requires: pip install cat-llm[formatter]
11
+ """
12
+
13
+ import sys
14
+
15
# HuggingFace Hub repository that hosts the merged (fine-tuned) formatter model.
_MERGED_MODEL_REPO = "chrissoria/catllm-json-formatter"

# System prompt sent to the formatter model, assembled one output line per
# list entry. NOTE(review): presumably this must match the prompt used during
# fine-tuning character-for-character — confirm before rewording anything.
_SYSTEM_PROMPT = "\n".join(
    [
        "You are a JSON formatter for a text classification pipeline. "
        "You will receive a list of categories (numbered 1 to N) and a raw "
        "classification output from another model. Your job is to convert that "
        "output into the exact JSON format required:",
        '{"1":"0","2":"1","3":"0",...}',
        "",
        "Rules:",
        '- Keys are 1-indexed strings: "1", "2", ..., "N"',
        '- Values are ONLY "0" (category absent) or "1" (category present)',
        "- Include ALL N categories, even if absent",
        "- Output ONLY the JSON object — no explanation, no markdown, no extra text",
        "- If a category's presence is ambiguous, default to \"0\"",
    ]
)
30
+
31
+
32
def _check_dependencies():
    """Check that torch and transformers are installed.

    Raises:
        ImportError: If either ``torch`` or ``transformers`` cannot be
            imported. The message tells the user which extra to install, and
            the underlying import error is chained as ``__cause__`` so the
            name of the actually-missing module stays in the traceback.
    """
    try:
        import torch  # noqa: F401
        import transformers  # noqa: F401
    except ImportError as err:
        raise ImportError(
            "The JSON formatter requires additional dependencies.\n"
            "Install them with: pip install cat-llm[formatter]\n"
            " (requires: torch, transformers, accelerate)"
        ) from err
43
+
44
+
45
def _is_model_cached() -> bool:
    """Check if the merged model is already in the HuggingFace cache.

    Probes the local cache for the repo's ``config.json`` via
    ``huggingface_hub.try_to_load_from_cache``. That function returns a
    filesystem path (``str``) when the file is cached, ``None`` when nothing
    is cached, and a non-``None`` sentinel when the file is known *not* to
    exist on the Hub. Only a real path is treated as "cached".

    Note: only ``config.json`` is probed, so a model whose config is cached
    but whose weights are missing still reports as cached.

    Returns:
        True if the config file is cached locally, otherwise False. Any
        failure during the lookup (e.g. ``huggingface_hub`` not installed)
        also yields False, so the caller falls back to asking the user.
    """
    try:
        from huggingface_hub import try_to_load_from_cache

        cached_path = try_to_load_from_cache(_MERGED_MODEL_REPO, "config.json")
    except Exception:
        # Deliberate best-effort: a broken cache probe must not abort the run;
        # treating it as "not cached" just leads to the download prompt.
        return False
    # Reject both None and the "known missing" sentinel object.
    return isinstance(cached_path, str)
53
+
54
+
55
def ensure_formatter_available() -> bool:
    """
    Make sure the formatter model can be used, asking before any download.

    Dependencies are verified first (``_check_dependencies`` raises
    ImportError when they are missing). If the model is already in the local
    HuggingFace cache, nothing else happens. Otherwise a notice is printed and
    the user is asked on stdin whether to proceed. An empty reply, "y" or
    "yes" (case-insensitive, surrounding whitespace ignored) accepts. Any
    other reply declines, and so does EOF or Ctrl-C at the prompt.

    Returns:
        True if the formatter may be used, False if the user declined.
    """
    _check_dependencies()

    if not _is_model_cached():
        notice = (
            "\n[CatLLM] The JSON formatter model (~1GB) will be downloaded from\n"
            f" HuggingFace Hub ({_MERGED_MODEL_REPO}).\n"
            " This is a one-time download — the model is cached locally after."
        )
        print(notice)

        try:
            reply = input(" Continue? (Y/n): ")
        except (EOFError, KeyboardInterrupt):
            # No usable answer (closed stdin / interrupted): treat as "no".
            reply = "n"

        if reply.strip().lower() not in {"", "y", "yes"}:
            print(" -> JSON formatter disabled for this run.\n")
            return False

    return True
82
+
83
+
84
def load_formatter(device=None):
    """
    Load the merged formatter model and tokenizer.

    Args:
        device: Target device: a string such as "cpu", "cuda" or "cuda:1",
            or a ``torch.device``. None = auto-detect (CUDA if available,
            otherwise CPU; MPS is deliberately never auto-selected).

    Returns:
        Tuple of (model, tokenizer, device). ``device`` is "cuda" or "cpu"
        when auto-detected; otherwise it is the value that was passed in.

    Raises:
        ImportError: If torch or transformers is not installed.
    """
    _check_dependencies()

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    if device is None:
        # Skip MPS — known PEFT/generation crash issues
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Use half precision on *any* CUDA device, including indexed ones such as
    # "cuda:1" or torch.device("cuda", 1). A plain `device == "cuda"` test
    # would silently fall back to float32 for those. Full precision elsewhere.
    on_cuda = torch.device(device).type == "cuda"
    dtype = torch.float16 if on_cuda else torch.float32

    print(f"[CatLLM] Loading JSON formatter on {device}...")
    # NOTE(review): trust_remote_code=True lets the Hub repo execute its own
    # Python code on load. The base model (Qwen2.5) may be supported natively
    # by transformers, in which case this flag could be dropped and/or a
    # fixed `revision=` pinned — confirm before changing.
    tokenizer = AutoTokenizer.from_pretrained(
        _MERGED_MODEL_REPO, trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        # Generation needs a pad token; reuse EOS when none is defined.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        _MERGED_MODEL_REPO, dtype=dtype, trust_remote_code=True
    )
    model = model.to(device)
    model.eval()  # inference only

    print("[CatLLM] JSON formatter ready.")
    return model, tokenizer, device
123
+
124
+
125
def run_formatter(
    raw_output, categories, model, tokenizer, device, *, max_new_tokens=None
):
    """
    Run the formatter model to fix malformed classification JSON.

    Args:
        raw_output: The raw (messy) output from the classification LLM.
        categories: Iterable of category names; they are numbered 1..N in the
            prompt, matching the JSON keys the model is asked to emit.
        model: The loaded formatter model.
        tokenizer: The loaded tokenizer.
        device: Device the inputs are moved to (e.g. "cuda" or "cpu").
        max_new_tokens: Upper bound on generated tokens. None (default)
            chooses a budget of at least 128 that grows with the number of
            categories, so long category lists are not cut off mid-JSON.

    Returns:
        The formatter's output string, stripped of surrounding whitespace
        (caller should run extract_json on it).
    """
    import torch

    # Materialize once: the categories are both enumerated and counted.
    categories = list(categories)

    # Build category list
    cat_lines = "\n".join(
        f"{number}. {cat}" for number, cat in enumerate(categories, start=1)
    )
    user_msg = f"Categories:\n{cat_lines}\n\nRaw classification output:\n{raw_output}"

    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": user_msg},
    ]

    if max_new_tokens is None:
        # Each `"N":"V",` entry costs several tokens (assumed generous bound
        # of 10 per category plus slack for braces). Never go below the
        # previous fixed budget of 128. Greedy decoding still stops at EOS,
        # so a larger ceiling only matters when the output is long.
        max_new_tokens = max(128, 10 * len(categories) + 16)

    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(text, return_tensors="pt").to(device)

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            # Explicitly unset sampling params so greedy decoding does not
            # inherit them from the model's generation config.
            temperature=None,
            top_p=None,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only newly generated tokens (drop the echoed prompt).
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()