cat-stack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,348 @@
1
+ """
2
+ Category analysis utilities for CatLLM.
3
+
4
+ Provides functions for analyzing user-provided category lists,
5
+ such as detecting whether an "Other" catch-all category exists.
6
+ """
7
+
8
+ import json
9
+ import re
10
+
11
+ from .text_functions import UnifiedLLMClient, detect_provider
12
+
13
+ __all__ = ["has_other_category", "check_category_verbosity"]
14
+
15
# Max words for a category to be checked against broad phrase patterns.
# Real catch-all categories are short ("Other", "None of the above", "Does not fit").
# Longer categories using these words ("Does not fit the clinical profile") are
# specific descriptive labels, not catch-alls.
_MAX_HEURISTIC_WORDS = 4

# Tier 1: Anchored patterns — safe at any category length.
# These only match when the keyword IS the category label itself.
_ANCHORED_PATTERNS = [
    re.compile(r"^other\s*$", re.IGNORECASE),  # exact "Other"
    re.compile(r"^other\s*[:(]", re.IGNORECASE),  # "Other: ...", "Other (..."
    re.compile(r"^n/?a\s*$", re.IGNORECASE),  # exact "N/A", "NA"
    re.compile(r"^miscellaneous\s*$", re.IGNORECASE),  # exact "Miscellaneous"
    re.compile(r"^catch[\s-]?all\s*$", re.IGNORECASE),  # exact "catch-all"
]

# Tier 2: Phrase patterns — only applied to short categories (≤ _MAX_HEURISTIC_WORDS).
# Multi-word phrases that clearly signal a catch-all when they dominate the category name.
# Contractions accept both the ASCII apostrophe (') and the typographic one
# (U+2019 ’), because category lists pasted from word processors or web pages
# frequently contain curly quotes.
_SHORT_ONLY_PATTERNS = [
    re.compile(r"\bnone of the above\b", re.IGNORECASE),
    re.compile(r"\bdoes not fit\b", re.IGNORECASE),
    re.compile(r"\bdoesn['\u2019]t fit\b", re.IGNORECASE),
    re.compile(r"\bnot applicable\b", re.IGNORECASE),
    re.compile(r"\bnone apply\b", re.IGNORECASE),
    re.compile(r"\bnone of these\b", re.IGNORECASE),
]

# Default model per provider for the LLM fallback, used when the caller does
# not pass an explicit model (see _resolve_provider_and_model).
# NOTE(review): these are pinned model IDs; confirm they are still served by
# each provider when upgrading.
_TOP_TIER_MODELS = {
    "openai": "gpt-4o",
    "anthropic": "claude-sonnet-4-5-20250929",
    "google": "gemini-2.5-flash",
    "mistral": "mistral-large-latest",
    "xai": "grok-2",
    "perplexity": "sonar-pro",
    "huggingface": "meta-llama/Llama-3.3-70B-Instruct",
}
52
+
53
+
54
def _heuristic_check(categories: list) -> bool:
    """
    Return True if any category looks like a catch-all, using regexes only.

    Makes no API calls. Two tiers of patterns keep false positives down:

    * ``_ANCHORED_PATTERNS`` describe the whole label (exact "Other", "N/A",
      an "Other:" / "Other (" prefix, ...) and are tried on every category.
    * ``_SHORT_ONLY_PATTERNS`` are free-floating phrases ("does not fit",
      "none of the above", ...) and are only tried when the category has at
      most ``_MAX_HEURISTIC_WORDS`` whitespace-separated words. A long,
      specific label such as "Does not fit the clinical profile" is therefore
      not mistaken for a catch-all.

    Each category is converted with ``str()`` and stripped before matching.
    """

    def _looks_like_catch_all(label: str) -> bool:
        # Anchored patterns are specific enough to trust at any length.
        if any(rx.search(label) for rx in _ANCHORED_PATTERNS):
            return True
        # Phrase patterns only count when the label is short.
        is_short = len(label.split()) <= _MAX_HEURISTIC_WORDS
        return is_short and any(rx.search(label) for rx in _SHORT_ONLY_PATTERNS)

    return any(_looks_like_catch_all(str(item).strip()) for item in categories)
82
+
83
+
84
def _llm_check(categories: list, api_key: str, model: str, provider: str) -> bool:
    """
    Ask an LLM whether ``categories`` already contains a catch-all category.

    Sends a single chat request through ``UnifiedLLMClient`` with
    ``force_json=False``, ``max_retries=2`` and ``creativity=0.0``. The reply is
    then interpreted as a yes/no verdict.

    Args:
        categories: Category labels, listed one per line in the prompt.
        api_key: API key passed to ``UnifiedLLMClient``.
        model: Model name passed to ``UnifiedLLMClient``.
        provider: Provider name passed to ``UnifiedLLMClient``.

    Returns:
        True only if the reply's leading word is "yes" or "true". Matching is
        case-insensitive and ignores leading quotes, markdown emphasis and
        whitespace, and any trailing text.
        False if the client reports an error, the reply is empty, or the reply
        does not start with a recognisable yes/no word.
    """
    cat_list = "\n".join(f"- {c}" for c in categories)
    messages = [
        {
            "role": "system",
            "content": (
                "You are a helpful assistant. Answer with ONLY 'yes' or 'no', "
                "nothing else."
            ),
        },
        {
            "role": "user",
            "content": (
                "Does the following list of categories contain a catch-all or "
                "'Other' category — i.e., a category meant to capture responses "
                "that don't fit any of the specific categories?\n\n"
                f"Categories:\n{cat_list}\n\n"
                "Answer 'yes' or 'no'."
            ),
        },
    ]

    client = UnifiedLLMClient(provider=provider, api_key=api_key, model=model)
    response_text, error = client.complete(
        messages=messages,
        force_json=False,
        max_retries=2,
        creativity=0.0,
    )

    if error or not response_text:
        return False

    # Models often decorate the one-word answer ("**Yes**", "'yes'",
    # "Yes, it does."), so inspect only the leading word rather than requiring
    # the whole reply to equal "yes". The trailing \b rejects words such as
    # "yesterday".
    verdict = re.match(r"[\W_]*(yes|no|true|false)\b", response_text.strip().lower())
    return verdict is not None and verdict.group(1) in ("yes", "true")
127
+
128
+
129
def _resolve_provider_and_model(user_model, model_source):
    """
    Work out which ``(provider, model)`` pair to use for an LLM call.

    * If ``user_model`` is given, it is used unchanged. The provider comes from
      ``detect_provider(user_model, provider=model_source)``.
    * Otherwise the provider is ``model_source`` lower-cased. It falls back to
      ``"openai"`` when ``model_source`` is falsy or equals ``"auto"``
      (case-insensitive). The model is then that provider's entry in
      ``_TOP_TIER_MODELS``, or ``"gpt-4o"`` if the provider has no entry.

    Returns:
        A ``(provider, model)`` tuple.
    """
    if user_model is not None:
        return detect_provider(user_model, provider=model_source), user_model

    use_explicit_source = model_source and model_source.lower() != "auto"
    chosen_provider = model_source.lower() if use_explicit_source else "openai"
    return chosen_provider, _TOP_TIER_MODELS.get(chosen_provider, "gpt-4o")
141
+
142
+
143
def has_other_category(
    categories: list,
    api_key: str = None,
    user_model: str = None,
    model_source: str = "auto",
) -> bool:
    """
    Detect whether a list of categories contains a catch-all / "Other" category.

    Uses a two-stage approach:
    1. **Heuristic** (free, instant) — regex checks for common labels such as
       "Other", "N/A", "Miscellaneous", "None of the above", etc.
    2. **LLM fallback** (1 API call) — runs only if the heuristic finds nothing
       and a non-empty ``api_key`` is provided. It asks an LLM to judge
       whether a catch-all exists.

    Args:
        categories: List of category strings to analyze. An empty list (or
            ``None``) returns ``False``.
        api_key: Optional API key for the LLM fallback. If it is ``None`` or an
            empty string and the heuristic doesn't match, the function returns
            ``False`` without making an API call.
        user_model: Optional model name for the LLM fallback. If not provided,
            a top-tier default model is selected based on the provider.
        model_source: Provider to use for the LLM fallback (e.g. "openai",
            "anthropic", "google"). Defaults to "auto", which auto-detects
            from ``user_model``, or falls back to "openai" when no model
            is specified.

    Returns:
        ``True`` if a catch-all / "Other" category is detected, ``False``
        otherwise. This includes the case where the LLM call fails.

    Examples:
        >>> has_other_category(["Positive", "Negative", "Other"])
        True

        >>> has_other_category(["Positive", "Negative"])
        False

        "Everything else" matches no heuristic pattern, so the LLM decides:

        >>> has_other_category(  # doctest: +SKIP
        ...     ["Happy", "Sad", "Everything else"],
        ...     api_key="sk-...",
        ... )
        True
    """
    if not categories:
        return False

    # Stage 1: heuristic
    if _heuristic_check(categories):
        return True

    # Stage 2: LLM fallback. An empty-string key (e.g. from
    # os.environ.get("KEY", "")) is treated like a missing key, so a request
    # with no usable credentials is not sent (with retries).
    if not api_key:
        return False

    provider, model = _resolve_provider_and_model(user_model, model_source)
    return _llm_check(categories, api_key, model, provider)
198
+
199
+
200
+ # =============================================================================
201
+ # Category Verbosity Check
202
+ # =============================================================================
203
+
204
def check_category_verbosity(
    categories: list,
    api_key: str,
    user_model: str = None,
    model_source: str = "auto",
) -> list:
    """
    Ask an LLM whether each category has a description and concrete examples.

    All categories are evaluated in one request. The request goes through
    ``UnifiedLLMClient`` with ``force_json=True``, ``max_retries=3`` and
    ``creativity=0.0``. The reply is parsed by ``_parse_verbosity_response``.

    Args:
        categories: List of category strings to analyze. An empty list returns
            ``[]`` without any API call.
        api_key: API key for the LLM provider (required).
        user_model: Model name to use. If not provided, a top-tier default is
            selected based on the provider.
        model_source: Provider (e.g. "openai", "anthropic", "google").
            Defaults to "auto".

    Returns:
        A list with one dict per category, in input order::

            {
                "category": str,          # the original category text
                "has_description": bool,  # explains more than a bare label
                "has_examples": bool,     # includes concrete examples
                "is_verbose": bool,       # True only if BOTH are present
            }

    Example (illustrative; requires network access)::

        check_category_verbosity(
            ["Positive", "Negative: expresses dissatisfaction (e.g., 'I hate this')"],
            api_key="sk-...",
        )
        # -> [
        #   {"category": "Positive", "has_description": False,
        #    "has_examples": False, "is_verbose": False},
        #   {"category": "Negative: ...", "has_description": True,
        #    "has_examples": True, "is_verbose": True},
        # ]
    """
    if not categories:
        return []

    provider, model = _resolve_provider_and_model(user_model, model_source)

    # 1-based numbering lets the model refer back to each category by index.
    numbered = "\n".join(
        f"{position}. {label}" for position, label in enumerate(categories, start=1)
    )

    system_prompt = (
        "You are an expert at evaluating classification category definitions. "
        "Return ONLY valid JSON, no other text."
    )
    criteria = (
        "For each category below, assess two things:\n"
        "1. **has_description**: Does it include an explanation or clarification "
        "beyond just a bare label? (e.g., 'Positive: the response expresses "
        "satisfaction or approval' has a description, but just 'Positive' does not)\n"
        "2. **has_examples**: Does it include concrete examples of what belongs "
        "in the category? (e.g., 'such as rent increases, pay cuts' or "
        "'e.g., I love this product')\n\n"
    )
    format_spec = (
        'Return a JSON object with a "results" array containing one object per '
        "category (in the same order), each with:\n"
        '- "category_number": the 1-based index\n'
        '- "has_description": true or false\n'
        '- "has_examples": true or false\n\n'
        "Example response format:\n"
        '{"results": [{"category_number": 1, "has_description": false, '
        '"has_examples": false}, ...]}'
    )
    user_prompt = criteria + f"Categories:\n{numbered}\n\n" + format_spec

    client = UnifiedLLMClient(provider=provider, api_key=api_key, model=model)
    response_text, error = client.complete(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        force_json=True,
        max_retries=3,
        creativity=0.0,
    )

    return _parse_verbosity_response(response_text, error, categories)
294
+
295
+
296
def _coerce_llm_bool(value) -> bool:
    """Interpret an LLM-supplied flag; string forms like "false" map to False."""
    if isinstance(value, str):
        # bool("false") is True, so strings must be interpreted, not truth-tested.
        return value.strip().lower() in ("true", "yes", "1")
    return bool(value)


def _llm_category_number(entry: dict):
    """Return ``entry["category_number"]`` as an int, or None if missing/invalid."""
    number = entry.get("category_number")
    if isinstance(number, bool):  # bool is an int subclass; never a valid index
        return None
    if isinstance(number, int):
        return number
    if isinstance(number, float) and number.is_integer():
        return int(number)
    if isinstance(number, str) and number.strip().isdigit():
        return int(number.strip())
    return None


def _parse_verbosity_response(response_text, error, categories):
    """
    Parse an LLM response into per-category verbosity flags.

    The expected payload is
    ``{"results": [{"category_number": int, "has_description": bool,
    "has_examples": bool}, ...]}``. A bare top-level array of such entries is
    also accepted. If the text is not valid JSON, the outermost ``{...}``
    span is tried, which handles markdown code fences.

    Entries are matched to categories by 1-based ``category_number``. The
    first entry wins on duplicates. An entry without a usable number is
    matched by its position in the list. An entry numbered for a different
    category is never reused positionally.

    Args:
        response_text: Raw text returned by the LLM (may be None/empty).
        error: Error value from the client; any truthy value means failure.
        categories: The original category list, in prompt order.

    Returns:
        One dict per category with keys ``category``, ``has_description``,
        ``has_examples`` and ``is_verbose`` (both flags present). Any category
        not covered by the response gets all flags False. If the call failed
        or the response cannot be parsed, every category gets all flags False.
    """

    def _row(cat, has_desc=False, has_ex=False):
        return {
            "category": cat,
            "has_description": has_desc,
            "has_examples": has_ex,
            "is_verbose": has_desc and has_ex,
        }

    if error or not response_text:
        return [_row(cat) for cat in categories]

    try:
        data = json.loads(response_text)
    except json.JSONDecodeError:
        # Try extracting JSON from the response (may have markdown wrapping)
        match = re.search(r'\{.*\}', response_text, re.DOTALL)
        if not match:
            return [_row(cat) for cat in categories]
        try:
            data = json.loads(match.group())
        except json.JSONDecodeError:
            return [_row(cat) for cat in categories]

    # Accept {"results": [...]} or a bare list; anything else yields no entries.
    llm_results = data.get("results", []) if isinstance(data, dict) else data
    if not isinstance(llm_results, list):
        llm_results = []

    # Index entries by category number once (first occurrence wins).
    by_number = {}
    for entry in llm_results:
        if isinstance(entry, dict):
            number = _llm_category_number(entry)
            if number is not None:
                by_number.setdefault(number, entry)

    output = []
    for i, cat in enumerate(categories):
        llm_entry = by_number.get(i + 1)
        if llm_entry is None and i < len(llm_results):
            # Positional fallback only for entries that carry no number;
            # an entry numbered for another category is not reused here.
            candidate = llm_results[i]
            if isinstance(candidate, dict) and _llm_category_number(candidate) is None:
                llm_entry = candidate

        if llm_entry is None:
            output.append(_row(cat))
            continue

        output.append(
            _row(
                cat,
                _coerce_llm_bool(llm_entry.get("has_description", False)),
                _coerce_llm_bool(llm_entry.get("has_examples", False)),
            )
        )

    return output