cat-stack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cat_stack/_chunked.py ADDED
@@ -0,0 +1,424 @@
1
+ """
2
+ Chunked classification for CatLLM.
3
+
4
+ When users have large category lists, this module splits them into smaller
5
+ chunks, runs a separate LLM call per chunk with local 1..N numbering, and
6
+ merges the results back into global numbering so downstream code
7
+ (aggregate_results, build_output_dataframes) sees a single merged JSON dict.
8
+
9
+ Each chunk automatically gets a temporary "Other" catch-all category appended
10
+ (unless one is already present in the chunk). This gives the LLM an escape
11
+ hatch for ambiguous responses, improving classification accuracy. The "Other"
12
+ column is dropped before merging back to global keys, so the final output
13
+ only contains the user's real categories.
14
+ """
15
+
16
+ import json
17
+ import math
18
+
19
+ from .text_functions import (
20
+ build_json_schema,
21
+ extract_json,
22
+ validate_classification_json,
23
+ ollama_two_step_classify,
24
+ )
25
+ from ._category_analysis import has_other_category
26
+
27
+
28
def run_chunked_classification(
    *,
    client,
    cfg,
    item,
    categories,
    categories_str,
    example_json,
    json_schema,
    cove_original_task,
    effective_creativity,
    use_json_schema,
    survey_question,
    survey_question_context,
    examples_text,
    chain_of_thought,
    context_prompt,
    step_back_prompt,
    stepback_insights,
    chain_of_verification,
    thinking_budget,
    max_retries,
    multi_label,
    categories_per_call,
    add_unified_other=False,
    formatter_fallback_fn,
    # Mode-specific
    is_pdf_mode,
    is_image_mode,
    pdf_mode=None,
    pdf_dpi=150,
    input_description="",
    # Prompt builders (passed in to avoid circular imports)
    build_text_prompt_fn=None,
    build_pdf_prompt_fn=None,
    build_image_prompt_fn=None,
    google_multimodal_fn=None,
    prepare_page_data_fn=None,
    prepare_image_data_fn=None,
    build_cove_prompts_fn=None,
    run_cove_fn=None,
):
    """
    Run chunked classification for one item across category chunks.

    Splits the full category list into chunks of ``categories_per_call``,
    runs one LLM call per chunk with chunk-local 1..N numbering (plus a
    temporary "Other" entry when the chunk has no Other-like category), and
    merges the answers back into global 1..len(categories) numbering.

    ``categories_str``, ``example_json``, ``json_schema`` and
    ``cove_original_task`` are accepted for signature parity with the
    single-call path but are not used here: per-chunk equivalents are rebuilt.

    Returns:
        tuple: (json_result_str, error) — same contract as a single LLM call.
        If a chunk call fails, the JSON contains whatever earlier chunks
        produced (or ``{"1":"e"}`` if none did). When ``add_unified_other`` is
        true, key ``len(categories) + 1`` is "1" only when no real category
        was set to 1 and at least one chunk's temporary "Other" fired.

    Raises:
        ValueError: if ``categories_per_call`` is less than 1 (zero would
            crash ``range`` and a negative step would silently skip every
            category).
    """
    if categories_per_call < 1:
        raise ValueError(
            "categories_per_call must be a positive integer, "
            f"got {categories_per_call!r}"
        )

    merged_json = {}
    chunk_other_values = []  # Per-chunk temporary "Other" values, for unification

    for global_offset in range(0, len(categories), categories_per_call):
        chunk_cats = categories[global_offset : global_offset + categories_per_call]
        num_real_cats = len(chunk_cats)

        # Add a temporary "Other" catch-all if the chunk doesn't already have
        # one. It gives the LLM an escape hatch for ambiguous responses; its
        # value is captured separately and never merged into global keys.
        added_other = not has_other_category(chunk_cats)
        chunk_cats_for_call = (
            list(chunk_cats) + ["Other"] if added_other else chunk_cats
        )

        # Chunk-local prompt components (including "Other" if added)
        chunk_categories_str = "\n".join(
            f"{j+1}. {cat}" for j, cat in enumerate(chunk_cats_for_call)
        )
        chunk_example_json = json.dumps(
            {str(j + 1): "0" for j in range(len(chunk_cats_for_call))}, indent=2
        )
        chunk_json_schema = (
            build_json_schema(
                chunk_cats_for_call,
                include_additional_properties=(cfg["provider"] != "google"),
            )
            if use_json_schema
            else None
        )

        # Rebuild the CoVe task so it lists only this chunk's categories
        chunk_cove_task = ""
        if chain_of_verification:
            if multi_label:
                cove_categorize = "into the following categories"
                cove_json = 'Provide your answer in JSON format where the category number is the key and "1" if present, "0" if not.'
            else:
                cove_categorize = "into the single most appropriate category"
                cove_json = 'Provide your answer in JSON format where the category number is the key. Assign "1" to the single best matching category and "0" to all others.'
            chunk_cove_task = "\n".join(
                [
                    f"{survey_question_context}",
                    f"Categorize text responses {cove_categorize}:",
                    f"{chunk_categories_str}",
                    f"{cove_json}",
                ]
            )

        chunk_result, chunk_error = _run_single_chunk_call(
            client=client,
            cfg=cfg,
            item=item,
            chunk_cats=chunk_cats_for_call,
            chunk_categories_str=chunk_categories_str,
            chunk_json_schema=chunk_json_schema,
            chunk_example_json=chunk_example_json,
            chunk_cove_task=chunk_cove_task,
            effective_creativity=effective_creativity,
            survey_question=survey_question,
            survey_question_context=survey_question_context,
            examples_text=examples_text,
            chain_of_thought=chain_of_thought,
            context_prompt=context_prompt,
            step_back_prompt=step_back_prompt,
            stepback_insights=stepback_insights,
            chain_of_verification=chain_of_verification,
            thinking_budget=thinking_budget,
            max_retries=max_retries,
            multi_label=multi_label,
            formatter_fallback_fn=formatter_fallback_fn,
            is_pdf_mode=is_pdf_mode,
            is_image_mode=is_image_mode,
            pdf_mode=pdf_mode,
            pdf_dpi=pdf_dpi,
            input_description=input_description,
            build_text_prompt_fn=build_text_prompt_fn,
            build_pdf_prompt_fn=build_pdf_prompt_fn,
            build_image_prompt_fn=build_image_prompt_fn,
            google_multimodal_fn=google_multimodal_fn,
            prepare_page_data_fn=prepare_page_data_fn,
            prepare_image_data_fn=prepare_image_data_fn,
            build_cove_prompts_fn=build_cove_prompts_fn,
            run_cove_fn=run_cove_fn,
        )

        if chunk_error:
            return (json.dumps(merged_json) if merged_json else '{"1":"e"}', chunk_error)

        try:
            chunk_parsed = json.loads(chunk_result)
        except (json.JSONDecodeError, TypeError):
            chunk_parsed = None
        if not isinstance(chunk_parsed, dict):
            return ('{"1":"e"}', f"Failed to parse chunk result: {chunk_result}")

        # The temporary "Other" (if added) is local key num_real_cats + 1
        other_local_idx = num_real_cats + 1 if added_other else None

        for local_key, value in chunk_parsed.items():
            try:
                local_idx = int(local_key)
            except (ValueError, TypeError):
                continue  # Non-numeric key — shouldn't happen with proper schemas

            if local_idx == other_local_idx:
                # Capture the temporary "Other" value, don't merge it
                try:
                    chunk_other_values.append(int(value))
                except (ValueError, TypeError):
                    chunk_other_values.append(0)
                continue

            # Only keys that belong to this chunk are remapped; anything else
            # (e.g. "0" or an overflow key) would land on a neighbouring
            # chunk's global key or beyond the category list.
            if 1 <= local_idx <= num_real_cats:
                merged_json[str(global_offset + local_idx)] = value

    # Unified "Other": if all real categories are 0 but at least one chunk's
    # "Other" fired, the response genuinely doesn't fit any category.
    if add_unified_other:
        real_sum = sum(
            int(v) for v in merged_json.values()
            if str(v).strip() in ("0", "1")
        )
        other_sum = sum(chunk_other_values)
        unified_other = "1" if real_sum == 0 and other_sum > 0 else "0"
        merged_json[str(len(categories) + 1)] = unified_other

    return (json.dumps(merged_json), None)
207
+
208
+
209
+ def _run_single_chunk_call(
210
+ *,
211
+ client,
212
+ cfg,
213
+ item,
214
+ chunk_cats,
215
+ chunk_categories_str,
216
+ chunk_json_schema,
217
+ chunk_example_json,
218
+ chunk_cove_task,
219
+ effective_creativity,
220
+ survey_question,
221
+ survey_question_context,
222
+ examples_text,
223
+ chain_of_thought,
224
+ context_prompt,
225
+ step_back_prompt,
226
+ stepback_insights,
227
+ chain_of_verification,
228
+ thinking_budget,
229
+ max_retries,
230
+ multi_label,
231
+ formatter_fallback_fn,
232
+ is_pdf_mode,
233
+ is_image_mode,
234
+ pdf_mode,
235
+ pdf_dpi,
236
+ input_description,
237
+ build_text_prompt_fn,
238
+ build_pdf_prompt_fn,
239
+ build_image_prompt_fn,
240
+ google_multimodal_fn,
241
+ prepare_page_data_fn,
242
+ prepare_image_data_fn,
243
+ build_cove_prompts_fn,
244
+ run_cove_fn,
245
+ ):
246
+ """
247
+ Run one LLM call for one chunk of categories on one item.
248
+
249
+ Returns:
250
+ tuple: (json_result_str, error)
251
+ """
252
+ thinking_providers = ("google", "openai", "anthropic", "huggingface", "huggingface-together")
253
+
254
+ # =================================================================
255
+ # PDF MODE
256
+ # =================================================================
257
+ if is_pdf_mode and isinstance(item, tuple):
258
+ pdf_path, page_index, page_label = item
259
+
260
+ page_data = prepare_page_data_fn(
261
+ pdf_path=pdf_path,
262
+ page_index=page_index,
263
+ page_label=page_label,
264
+ pdf_mode=pdf_mode,
265
+ provider=cfg["provider"],
266
+ pdf_dpi=pdf_dpi,
267
+ )
268
+
269
+ if page_data.get("error"):
270
+ return ('{"1":"e"}', page_data["error"])
271
+
272
+ messages = build_pdf_prompt_fn(
273
+ page_data=page_data,
274
+ categories_str=chunk_categories_str,
275
+ input_description=input_description,
276
+ provider=cfg["provider"],
277
+ pdf_mode=pdf_mode,
278
+ chain_of_thought=chain_of_thought,
279
+ context_prompt=context_prompt,
280
+ step_back_prompt=step_back_prompt,
281
+ stepback_insights=stepback_insights,
282
+ model_name=cfg["model"],
283
+ example_json=chunk_example_json,
284
+ multi_label=multi_label,
285
+ )
286
+
287
+ if cfg["provider"] == "google":
288
+ reply, error = google_multimodal_fn(
289
+ client=client,
290
+ messages=messages,
291
+ json_schema=chunk_json_schema,
292
+ creativity=effective_creativity,
293
+ thinking_budget=thinking_budget,
294
+ max_retries=max_retries,
295
+ )
296
+ else:
297
+ reply, error = client.complete(
298
+ messages=messages,
299
+ json_schema=chunk_json_schema,
300
+ creativity=effective_creativity,
301
+ thinking_budget=thinking_budget if cfg["provider"] in thinking_providers else None,
302
+ max_retries=max_retries,
303
+ )
304
+
305
+ if error:
306
+ return ('{"1":"e"}', error)
307
+
308
+ json_result = extract_json(reply)
309
+ json_result = formatter_fallback_fn(json_result, reply, chunk_cats)
310
+ return (json_result, None)
311
+
312
+ # =================================================================
313
+ # IMAGE MODE
314
+ # =================================================================
315
+ elif is_image_mode and isinstance(item, tuple):
316
+ image_path, image_label = item
317
+
318
+ image_data = prepare_image_data_fn(image_path, image_label)
319
+
320
+ if image_data.get("error"):
321
+ return ('{"1":"e"}', image_data["error"])
322
+
323
+ messages = build_image_prompt_fn(
324
+ image_data=image_data,
325
+ categories_str=chunk_categories_str,
326
+ input_description=input_description,
327
+ provider=cfg["provider"],
328
+ chain_of_thought=chain_of_thought,
329
+ context_prompt=context_prompt,
330
+ step_back_prompt=step_back_prompt,
331
+ stepback_insights=stepback_insights,
332
+ model_name=cfg["model"],
333
+ example_json=chunk_example_json,
334
+ multi_label=multi_label,
335
+ )
336
+
337
+ if cfg["provider"] == "google":
338
+ reply, error = google_multimodal_fn(
339
+ client=client,
340
+ messages=messages,
341
+ json_schema=chunk_json_schema,
342
+ creativity=effective_creativity,
343
+ thinking_budget=thinking_budget,
344
+ max_retries=max_retries,
345
+ )
346
+ else:
347
+ reply, error = client.complete(
348
+ messages=messages,
349
+ json_schema=chunk_json_schema,
350
+ creativity=effective_creativity,
351
+ thinking_budget=thinking_budget if cfg["provider"] in thinking_providers else None,
352
+ max_retries=max_retries,
353
+ )
354
+
355
+ if error:
356
+ return ('{"1":"e"}', error)
357
+
358
+ json_result = extract_json(reply)
359
+ json_result = formatter_fallback_fn(json_result, reply, chunk_cats)
360
+ return (json_result, None)
361
+
362
+ # =================================================================
363
+ # TEXT MODE
364
+ # =================================================================
365
+ else:
366
+ response_text = item
367
+
368
+ if cfg["use_two_step"]: # Ollama
369
+ json_result, error = ollama_two_step_classify(
370
+ client=client,
371
+ response_text=response_text,
372
+ categories=chunk_cats,
373
+ categories_str=chunk_categories_str,
374
+ survey_question=survey_question,
375
+ creativity=effective_creativity,
376
+ max_retries=max_retries,
377
+ )
378
+ if not error:
379
+ json_result = formatter_fallback_fn(json_result, json_result, chunk_cats)
380
+ return (json_result, error)
381
+ else:
382
+ messages = build_text_prompt_fn(
383
+ response_text=response_text,
384
+ categories_str=chunk_categories_str,
385
+ survey_question_context=survey_question_context,
386
+ examples_text=examples_text,
387
+ chain_of_thought=chain_of_thought,
388
+ context_prompt=context_prompt,
389
+ step_back_prompt=step_back_prompt,
390
+ stepback_insights=stepback_insights,
391
+ model_name=cfg["model"],
392
+ multi_label=multi_label,
393
+ )
394
+ reply, error = client.complete(
395
+ messages=messages,
396
+ json_schema=chunk_json_schema,
397
+ creativity=effective_creativity,
398
+ thinking_budget=thinking_budget if cfg["provider"] in thinking_providers else None,
399
+ max_retries=max_retries,
400
+ )
401
+ if error:
402
+ return ('{"1":"e"}', error)
403
+
404
+ json_result = extract_json(reply)
405
+ json_result = formatter_fallback_fn(json_result, reply, chunk_cats)
406
+
407
+ # Run Chain of Verification if enabled
408
+ if chain_of_verification:
409
+ step2, step3, step4 = build_cove_prompts_fn(
410
+ chunk_cove_task, response_text
411
+ )
412
+ json_result = run_cove_fn(
413
+ client=client,
414
+ initial_reply=json_result,
415
+ step2_prompt=step2,
416
+ step3_prompt=step3,
417
+ step4_prompt=step4,
418
+ json_schema=chunk_json_schema,
419
+ creativity=effective_creativity,
420
+ max_retries=max_retries,
421
+ )
422
+ json_result = formatter_fallback_fn(json_result, json_result, chunk_cats)
423
+
424
+ return (json_result, None)
@@ -0,0 +1,189 @@
1
+ """
2
+ Embedding-based similarity scores for CatLLM.
3
+
4
+ Uses a local sentence-transformer model (BAAI/bge-small-en-v1.5, 33M params,
5
+ ~130MB) to compute cosine similarity between each input text and each category.
6
+ Scores are independent per (text, category) pair — no softmax across categories,
7
+ since this is multi-label classification.
8
+
9
+ The embeddings feature is opt-in via embeddings=True on classify(). It adds
10
+ `_similarity` columns alongside the existing binary 0/1 classification columns.
11
+
12
+ Requires: pip install cat-llm[embeddings]
13
+ """
14
+
15
+ import pandas as pd
16
+
17
# Hugging Face Hub repo id of the sentence-transformer used for similarity
# scores; it is both loaded by SentenceTransformer and looked up in the local
# HF cache when deciding whether a download prompt is needed.
_EMBEDDING_MODEL_NAME: str = "BAAI/bge-small-en-v1.5"
18
+
19
+
20
+ def _check_dependencies():
21
+ """Check that sentence-transformers is installed."""
22
+ try:
23
+ import sentence_transformers # noqa: F401
24
+ except ImportError:
25
+ raise ImportError(
26
+ "The embeddings feature requires sentence-transformers.\n"
27
+ "Install with: pip install cat-llm[embeddings]\n"
28
+ " (requires: sentence-transformers, which pulls in torch and transformers)"
29
+ )
30
+
31
+
32
+ def _is_model_cached() -> bool:
33
+ """Check if the embedding model is already in the HuggingFace cache."""
34
+ try:
35
+ from huggingface_hub import try_to_load_from_cache
36
+ result = try_to_load_from_cache(_EMBEDDING_MODEL_NAME, "config.json")
37
+ return result is not None and not isinstance(result, type(None))
38
+ except Exception:
39
+ return False
40
+
41
+
42
def ensure_embeddings_available() -> bool:
    """
    Make sure the embedding model can be used, asking before a first download.

    Raises ImportError (via _check_dependencies) when sentence-transformers is
    missing. If the model is not cached yet, the user is asked to confirm the
    download; an empty answer, "y" or "yes" accepts, and EOF/Ctrl-C declines.

    Returns:
        True when the model may be used, False when the user declined.
    """
    _check_dependencies()

    if _is_model_cached():
        return True

    print(
        "\n[CatLLM] The embedding model (~130MB) will be downloaded from\n"
        f" HuggingFace Hub ({_EMBEDDING_MODEL_NAME}).\n"
        " This is a one-time download — the model is cached locally after."
    )
    try:
        reply = input(" Continue? (Y/n): ")
    except (EOFError, KeyboardInterrupt):
        reply = "n"

    accepted = reply.strip().lower() in ("", "y", "yes")
    if not accepted:
        print(" -> Embedding scores disabled for this run.\n")
    return accepted
69
+
70
+
71
def load_embedding_model():
    """
    Instantiate the sentence-transformer used for similarity scores.

    Verifies the optional dependency first, then builds the model from
    _EMBEDDING_MODEL_NAME, printing progress messages around the load.

    Returns:
        SentenceTransformer model instance.
    """
    _check_dependencies()
    from sentence_transformers import SentenceTransformer

    print(f"[CatLLM] Loading embedding model ({_EMBEDDING_MODEL_NAME})...")
    embedder = SentenceTransformer(_EMBEDDING_MODEL_NAME)
    print("[CatLLM] Embedding model ready.")
    return embedder
86
+
87
+
88
def compute_embedding_scores(texts, categories, model, category_descriptions=None):
    """
    Compute cosine similarity scores between texts and categories.

    Each (text, category) score is independent — no softmax across categories.
    Raw cosine similarity is rescaled from [-1, 1] to [0, 1] via (sim + 1) / 2
    and rounded to 4 decimals.

    Args:
        texts: Iterable of input texts; NaN/None entries are embedded as "".
        categories: Iterable of category name strings.
        model: Loaded SentenceTransformer model.
        category_descriptions: Optional dict mapping category names to richer
            descriptions for embedding (e.g., {"Past_Support": "References to
            help received from family in the past"}).

    Returns:
        Dict mapping "category_N_similarity" -> list of float scores (one per
        text), where N is 1-indexed to match the classification column naming.
        With no texts, each list is empty; with no categories, the dict is
        empty. In both cases the model is not called.
    """
    # Convert NaN/None to empty string
    clean_texts = [str(t) if pd.notna(t) else "" for t in texts]
    categories = list(categories)

    # Nothing to compare: skip encoding/cos_sim entirely for empty batches.
    if len(clean_texts) == 0 or len(categories) == 0:
        return {f"category_{i + 1}_similarity": [] for i in range(len(categories))}

    from sentence_transformers import util

    # Category strings for embedding, optionally enriched with a description
    descriptions = category_descriptions or {}
    cat_strings = [
        f"{cat}: {descriptions[cat]}" if cat in descriptions else cat
        for cat in categories
    ]

    text_embeddings = model.encode(clean_texts, normalize_embeddings=True,
                                   show_progress_bar=len(clean_texts) > 100)
    cat_embeddings = model.encode(cat_strings, normalize_embeddings=True)

    # Similarity matrix: (num_texts, num_categories)
    sim_matrix = util.cos_sim(text_embeddings, cat_embeddings)

    # Rescale to [0, 1] and convert to nested Python floats once, instead of
    # indexing the tensor element by element.
    score_rows = ((sim_matrix + 1) / 2).tolist()

    return {
        f"category_{i + 1}_similarity": [round(row[i], 4) for row in score_rows]
        for i in range(len(categories))
    }
138
+
139
+
140
def apply_embedding_scores(df, categories, embedding_model, category_descriptions=None):
    """
    Insert embedding similarity columns into a result DataFrame.

    For each category N, a `category_N_similarity` column is inserted after the
    last existing column that belongs to that category number (a column named
    exactly `category_N` or starting with `category_N_`). If none exists, it is
    appended at the end. If the similarity column is already present (e.g. the
    function is applied twice), its values are overwritten in place.

    Args:
        df: Result DataFrame from classify (single-model or ensemble). Texts are
            taken from its "input_data" column, or the first column otherwise.
        categories: List of category name strings.
        embedding_model: Loaded SentenceTransformer model.
        category_descriptions: Optional dict mapping category names to descriptions.

    Returns:
        A copy of ``df`` with `_similarity` columns added; ``df`` is not modified.
    """
    if "input_data" in df.columns:
        texts = df["input_data"].tolist()
    else:
        # Fallback: use first column
        texts = df.iloc[:, 0].tolist()

    scores = compute_embedding_scores(texts, categories, embedding_model,
                                      category_descriptions)

    result_df = df.copy()
    for i in range(len(categories)):
        sim_col = f"category_{i + 1}_similarity"
        sim_values = scores[sim_col]

        if sim_col in result_df.columns:
            # DataFrame.insert would raise on an existing label; refresh it instead.
            result_df[sim_col] = sim_values
            continue

        # Exact number match so category_1 does not match category_10
        cat_exact = f"category_{i + 1}"
        cat_prefix = f"{cat_exact}_"

        last_pos = -1
        for col_idx, col_name in enumerate(result_df.columns):
            # Column labels are not guaranteed to be strings; skip non-str ones.
            if isinstance(col_name, str) and (
                col_name == cat_exact or col_name.startswith(cat_prefix)
            ):
                last_pos = col_idx

        if last_pos >= 0:
            result_df.insert(last_pos + 1, sim_col, sim_values)
        else:
            # No matching column found — append at the end
            result_df[sim_col] = sim_values

    return result_df