cat-stack 1.6.8__tar.gz → 2.0.0b1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/.gitignore +8 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/PKG-INFO +1 -1
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/__about__.py +1 -1
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/__init__.py +2 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/_formatter.py +90 -8
- cat_stack-2.0.0b1/src/catstack/collapse_themes.py +364 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/LICENSE +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/README.md +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/pyproject.toml +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/cat_stack/__init__.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/_batch.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/_category_analysis.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/_chunked.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/_embeddings.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/_pilot_test.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/_prompts.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/_providers.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/_review_ui.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/_tiebreaker.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/_utils.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/_web_fetch.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/_wrapper_helpers.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/calls/CoVe.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/calls/__init__.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/calls/image_CoVe.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/calls/image_stepback.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/calls/pdf_CoVe.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/calls/pdf_stepback.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/calls/stepback.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/calls/top_n.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/classify.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/explore.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/extract.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/image_functions.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/images/circle.png +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/images/cube.png +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/images/diamond.png +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/images/overlapping_pentagons.png +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/images/rectangles.png +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/model_reference_list.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/pdf_functions.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/prompt_tune.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/summarize.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/text_functions.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0b1}/src/catstack/text_functions_ensemble.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cat-stack
|
|
3
|
-
Version:
|
|
3
|
+
Version: 2.0.0b1
|
|
4
4
|
Summary: Domain-agnostic text, image, PDF, and DOCX classification engine powered by LLMs
|
|
5
5
|
Project-URL: Documentation, https://github.com/chrissoria/cat-stack#readme
|
|
6
6
|
Project-URL: Issues, https://github.com/chrissoria/cat-stack/issues
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
|
|
2
2
|
#
|
|
3
3
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
|
4
|
-
__version__ = "
|
|
4
|
+
__version__ = "2.0.0b1"
|
|
5
5
|
__author__ = "Chris Soria"
|
|
6
6
|
__email__ = "chrissoria@berkeley.edu"
|
|
7
7
|
__title__ = "cat-stack"
|
|
@@ -18,6 +18,7 @@ from .__about__ import (
|
|
|
18
18
|
# Main entry points
|
|
19
19
|
from .extract import extract
|
|
20
20
|
from .explore import explore
|
|
21
|
+
from .collapse_themes import collapse_themes
|
|
21
22
|
from .classify import classify
|
|
22
23
|
from .summarize import summarize
|
|
23
24
|
from .prompt_tune import prompt_tune
|
|
@@ -103,6 +104,7 @@ __all__ = [
|
|
|
103
104
|
# Main entry points
|
|
104
105
|
"extract",
|
|
105
106
|
"explore",
|
|
107
|
+
"collapse_themes",
|
|
106
108
|
"classify",
|
|
107
109
|
"summarize",
|
|
108
110
|
"prompt_tune",
|
|
@@ -45,6 +45,11 @@ def _check_dependencies():
|
|
|
45
45
|
def _check_dependencies_installed() -> bool:
|
|
46
46
|
"""Pure check — returns True if all formatter deps import successfully.
|
|
47
47
|
No side effects, no install attempts."""
|
|
48
|
+
# If a dep was just pip-installed in this process's lifetime, the import
|
|
49
|
+
# system may have cached its earlier absence; clear that so the re-check
|
|
50
|
+
# actually sees the freshly-installed package.
|
|
51
|
+
import importlib
|
|
52
|
+
importlib.invalidate_caches()
|
|
48
53
|
try:
|
|
49
54
|
import torch # noqa: F401
|
|
50
55
|
import transformers # noqa: F401
|
|
@@ -165,7 +170,31 @@ def _ensure_dependencies(verbose: bool = True) -> bool:
|
|
|
165
170
|
" To skip this and disable the formatter, pass json_formatter=False."
|
|
166
171
|
)
|
|
167
172
|
|
|
168
|
-
|
|
173
|
+
ok = _install_dependencies(verbose=verbose)
|
|
174
|
+
if not ok:
|
|
175
|
+
# Freshly pip-installed packages (esp. compiled ones like torch) often
|
|
176
|
+
# cannot be imported by the SAME running process — but they ARE on disk
|
|
177
|
+
# now. Tell the user to re-run rather than silently degrading every row
|
|
178
|
+
# to an error.
|
|
179
|
+
if verbose and _deps_on_disk():
|
|
180
|
+
print(
|
|
181
|
+
"[CatLLM] Formatter dependencies were just installed but cannot "
|
|
182
|
+
"be imported into the already-running process. Please RE-RUN your "
|
|
183
|
+
"command — they will load on the next start. (Avoid this by "
|
|
184
|
+
"pre-installing: pip install 'cat-stack[formatter]'.)"
|
|
185
|
+
)
|
|
186
|
+
return ok
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def _deps_on_disk() -> bool:
    """Report whether the formatter packages can be located on disk.

    Checks torch, transformers and accelerate by looking up an import spec for
    each one, without importing the packages themselves. A package that was
    pip-installed after this interpreter started may fail to import in the
    current process yet still be found here (and load in a fresh process).
    Any ImportError/ValueError raised by the lookup is reported as False.
    """
    import importlib.util

    required = ("torch", "transformers", "accelerate")
    try:
        for module_name in required:
            if importlib.util.find_spec(module_name) is None:
                return False
        return True
    except (ImportError, ValueError):
        return False
|
|
169
198
|
|
|
170
199
|
|
|
171
200
|
def _is_model_cached() -> bool:
|
|
@@ -205,6 +234,51 @@ def ensure_formatter_available() -> bool:
|
|
|
205
234
|
return True # actual download happens in load_formatter()
|
|
206
235
|
|
|
207
236
|
|
|
237
|
+
def _copy_tokenizer_files(local_dir, cfg):
    """Create a writable copy of a model snapshot's tokenizer files.

    Copies every top-level regular file of `local_dir` into a new temporary
    directory, EXCEPT model-weight files. Weight files can be gigabytes and the
    tokenizer never reads them. Then writes `cfg` there as
    `tokenizer_config.json`, replacing any copied version.

    Args:
        local_dir: directory of a downloaded model snapshot.
        cfg: tokenizer config dict to write into the copy.

    Returns:
        str: path of the new directory (not deleted automatically).
    """
    import json
    import os
    import shutil
    import tempfile

    weight_suffixes = (
        ".safetensors", ".bin", ".pt", ".pth", ".ckpt",
        ".h5", ".msgpack", ".onnx", ".gguf",
    )
    patched = tempfile.mkdtemp(prefix="catllm_formatter_tok_")
    for fn in os.listdir(local_dir):
        src = os.path.join(local_dir, fn)
        # isfile() follows symlinks, so HF cache symlinks to blobs are copied.
        if os.path.isfile(src) and not fn.lower().endswith(weight_suffixes):
            shutil.copy(src, os.path.join(patched, fn))
    with open(os.path.join(patched, "tokenizer_config.json"), "w", encoding="utf-8") as f:
        json.dump(cfg, f)
    return patched


def _load_formatter_tokenizer(AutoTokenizer):
    """Load the formatter tokenizer, defending against a malformed
    `tokenizer_config.json`.

    Some published configs store `extra_special_tokens` as a LIST, but
    transformers 4.56–4.57.x expect a dict and crash in
    `_set_model_specific_special_tokens` with
    `'list' object has no attribute 'keys'`. On that failure the repo is
    snapshotted locally and a list-valued `extra_special_tokens` is normalized
    to `{}`. The tokens already live in `added_tokens`/`special_tokens_map`, so
    dropping the field is lossless. The tokenizer is then loaded from a patched
    copy of the non-weight snapshot files.

    Args:
        AutoTokenizer: the transformers `AutoTokenizer` class (passed in so the
            heavy import stays with the caller).

    Returns:
        The loaded tokenizer.

    Raises:
        AttributeError / TypeError: re-raised unchanged when the failure is not
            the known `extra_special_tokens` problem, or the config does not
            actually contain a list-valued `extra_special_tokens`.
    """
    try:
        return AutoTokenizer.from_pretrained(
            _MERGED_MODEL_REPO, trust_remote_code=True
        )
    except (AttributeError, TypeError) as e:
        msg = str(e)
        if "keys" not in msg and "extra_special_tokens" not in msg:
            raise
        import json
        import os
        from huggingface_hub import snapshot_download

        local_dir = snapshot_download(_MERGED_MODEL_REPO)
        cfg_path = os.path.join(local_dir, "tokenizer_config.json")
        # Tokenizer configs often contain non-ASCII special tokens; never rely
        # on the platform's default locale encoding.
        with open(cfg_path, encoding="utf-8") as f:
            cfg = json.load(f)
        if not isinstance(cfg.get("extra_special_tokens"), list):
            # Not the problem we know how to patch: surface the original error.
            raise
        cfg["extra_special_tokens"] = {}
        # Snapshot dirs are often read-only symlink caches; patch a copy.
        patched = _copy_tokenizer_files(local_dir, cfg)
        print("[CatLLM] Patched malformed extra_special_tokens in the "
              "formatter tokenizer config (list -> {}).")
        return AutoTokenizer.from_pretrained(patched, trust_remote_code=True)
|
|
280
|
+
|
|
281
|
+
|
|
208
282
|
def load_formatter(device=None):
|
|
209
283
|
"""
|
|
210
284
|
Load the merged formatter model and tokenizer.
|
|
@@ -230,15 +304,21 @@ def load_formatter(device=None):
|
|
|
230
304
|
dtype = torch.float16 if device == "cuda" else torch.float32
|
|
231
305
|
|
|
232
306
|
print(f"[CatLLM] Loading JSON formatter on {device}...")
|
|
233
|
-
tokenizer = AutoTokenizer
|
|
234
|
-
_MERGED_MODEL_REPO, trust_remote_code=True
|
|
235
|
-
)
|
|
307
|
+
tokenizer = _load_formatter_tokenizer(AutoTokenizer)
|
|
236
308
|
if tokenizer.pad_token is None:
|
|
237
309
|
tokenizer.pad_token = tokenizer.eos_token
|
|
238
310
|
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
311
|
+
# `dtype=` is the transformers >=4.56 kwarg; older versions only accept
|
|
312
|
+
# `torch_dtype=` and crash if `dtype=` leaks into the config. Try the new
|
|
313
|
+
# name, fall back to the old one.
|
|
314
|
+
try:
|
|
315
|
+
model = AutoModelForCausalLM.from_pretrained(
|
|
316
|
+
_MERGED_MODEL_REPO, dtype=dtype, trust_remote_code=True
|
|
317
|
+
)
|
|
318
|
+
except TypeError:
|
|
319
|
+
model = AutoModelForCausalLM.from_pretrained(
|
|
320
|
+
_MERGED_MODEL_REPO, torch_dtype=dtype, trust_remote_code=True
|
|
321
|
+
)
|
|
242
322
|
model = model.to(device)
|
|
243
323
|
model.eval()
|
|
244
324
|
|
|
@@ -281,7 +361,9 @@ def run_formatter(raw_output, categories, model, tokenizer, device):
|
|
|
281
361
|
with torch.no_grad():
|
|
282
362
|
out = model.generate(
|
|
283
363
|
**inputs,
|
|
284
|
-
|
|
364
|
+
# 512 (was 128): a large category set produces a long N-key JSON
|
|
365
|
+
# object; 128 tokens truncated it for 28/48-category tasks.
|
|
366
|
+
max_new_tokens=512,
|
|
285
367
|
do_sample=False,
|
|
286
368
|
temperature=None,
|
|
287
369
|
top_p=None,
|
|
@@ -0,0 +1,364 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Theme collapsing for CatLLM.
|
|
3
|
+
|
|
4
|
+
collapse_themes() takes an already-extracted list of category/theme strings (for
|
|
5
|
+
example the output of explore()) and iteratively consolidates near-duplicate /
|
|
6
|
+
synonymous labels into a smaller list. Each pass:
|
|
7
|
+
|
|
8
|
+
A. accept the list,
|
|
9
|
+
B. PRE-CLEAN before the model — normalize + Jaro-Winkler dedup (surface
|
|
10
|
+
variants) then embedding-merge (semantic near-duplicates),
|
|
11
|
+
C. split the cleaned list into batches of `batch_size`,
|
|
12
|
+
D. read every batch with one LLM call (extract-unique, or aggressive merge),
|
|
13
|
+
E. concatenate and dedupe into a single, smaller list.
|
|
14
|
+
|
|
15
|
+
`passes` iterations run in one call, randomizing batch composition each pass so
|
|
16
|
+
labels stranded in separate batches get fresh chances to meet and merge.
|
|
17
|
+
Provider-agnostic via the same dispatch classify()/explore() use.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
import random
|
|
21
|
+
import re
|
|
22
|
+
import sys
|
|
23
|
+
|
|
24
|
+
import numpy as np
|
|
25
|
+
import pandas as pd
|
|
26
|
+
from jellyfish import jaro_winkler_similarity
|
|
27
|
+
|
|
28
|
+
from ._providers import UnifiedLLMClient, detect_provider
|
|
29
|
+
from ._utils import _clean_label
|
|
30
|
+
|
|
31
|
+
__all__ = [
    "collapse_themes",
]

# One line of a numbered list such as "1. label", "2) label" or "3 - label":
# a number, optional spaces, one of ".", ")" or "-", then the label, which is
# captured as group 1 (everything after the separator and following spaces).
_LINE_PAT = re.compile(r"^\s*\d+\s*[\.\)\-]\s*(.+)$")
_EMB_MODEL = None  # cached embedding model (loaded once per process)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _strip_parens(label):
    """Remove parenthetical asides, e.g. "Jobs (e.g. layoffs)" -> "Jobs".

    A parenthesized example does not change which category a label names, so
    each "(...)" group is deleted together with the whitespace in front of it,
    and the result is trimmed.
    """
    without_asides = re.sub(r"\s*\([^)]*\)", "", label)
    return without_asides.strip()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _norm_key(label):
    """Build the canonical key used to detect duplicate labels.

    The key ignores parenthetical asides and letter case, and treats "&",
    " and " and "/" as one separator. It sorts the separated parts, so
    "Work & Family" and "family/work" both map to "family / work".
    """
    lowered = _strip_parens(label).lower().strip()
    unified = re.sub(r"\s*&\s*|\s+and\s+|\s*/\s*", " / ", lowered)
    pieces = [piece.strip() for piece in unified.split("/")]
    return " / ".join(sorted(piece for piece in pieces if piece))
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _jw_dedupe(items, threshold):
    """Order-preserving, surface-level dedup of category labels.

    Each label is reduced to a display form (parenthetical asides removed,
    lowercased, trimmed) and a canonical key from `_norm_key`. A label is
    dropped in two cases:
    - its key equals an already-kept key;
    - `threshold` is below 1.0 and the label's Jaro-Winkler similarity to any
      kept key is >= `threshold`.
    Labels whose display form or key is empty are skipped.

    Args:
        items: iterable of label strings.
        threshold: Jaro-Winkler similarity cut-off. 1.0 (or higher) or None
            means exact key matches only.

    Returns:
        list[str]: display forms of the kept labels, in first-seen order.
    """
    fuzzy = threshold is not None and threshold < 1.0
    kept_keys = []      # ordered, for the fuzzy scan
    seen_keys = set()   # O(1) exact-match check before any fuzzy work
    out = []
    for c in items:
        disp = _strip_parens(c).lower().strip()
        key = _norm_key(c)
        if not disp or not key:
            continue
        if key in seen_keys:
            continue
        if fuzzy and any(
            jaro_winkler_similarity(key, k) >= threshold for k in kept_keys
        ):
            continue
        seen_keys.add(key)
        kept_keys.append(key)
        out.append(disp)
    return out
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _greedy_cosine_reps(items, embs, threshold):
    """Greedily keep items whose embedding is not too close to an earlier keeper.

    Walks `items` in order. An item is dropped when the dot product of its row
    in `embs` with any already-kept row is >= `threshold`; with L2-normalized
    rows this is cosine similarity. Otherwise the item becomes a new
    representative.

    Args:
        items: sequence of labels, aligned row-for-row with `embs`.
        embs: 2-D array-like of embeddings, one row per item.
        threshold: similarity at/above which an item counts as a duplicate.

    Returns:
        list: the kept items, in their original order.
    """
    embs = np.asarray(embs)
    # Preallocated buffer of kept rows; only the first `n_kept` are valid.
    # This avoids rebuilding an array from a Python list on every iteration.
    kept = np.empty_like(embs)
    n_kept = 0
    reps = []
    for item, vec in zip(items, embs):
        if n_kept and float(np.max(kept[:n_kept] @ vec)) >= threshold:
            continue
        kept[n_kept] = vec
        n_kept += 1
        reps.append(item)
    return reps


def _embedding_merge(items, threshold):
    """Greedy embedding clustering of labels; keeps first-seen representatives.

    Drops labels whose cosine similarity to an already-kept label is
    >= `threshold`. Embeddings come from `load_embedding_model()` (cat-stack's
    canonical BAAI/bge-small model per the original docs). The model is loaded
    once and cached in the module-level `_EMB_MODEL`.

    Args:
        items: list of label strings.
        threshold: cosine cut-off. A falsy value (None or 0) or a value >= 1.0
            disables the step.

    Returns:
        list[str]: `items` itself when the step is skipped (disabled threshold
        or fewer than two items), otherwise the kept labels in original order.
    """
    global _EMB_MODEL
    if not threshold or threshold >= 1.0 or len(items) < 2:
        return items
    if _EMB_MODEL is None:
        from ._embeddings import load_embedding_model
        _EMB_MODEL = load_embedding_model()
    embs = _EMB_MODEL.encode(items, normalize_embeddings=True, show_progress_bar=False)
    return _greedy_cosine_reps(items, embs, threshold)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _collapse_batch(client, batch, description, creativity, mode="unique"):
    """One LLM call on a single batch -> list[str].

    Modes:
    - "unique": extract unique categories only (remove restatements, keep
      distinct ones). Gentle, near-idempotent, guaranteed to only remove.
    - "merge": aggressively consolidate related labels into broader concepts
      while retaining meaningful distinctions. Meant as a final compression step.

    The prompt demands a strict numbered list and parsing is strict, so any
    stray prose in the reply is ignored.

    Guardrails:
    - A failed call returns the batch unchanged, lowercased (no data loss).
    - A reply with no usable numbered labels returns the batch unchanged, in
      either mode.
    - In "unique" mode the output is forced to be a subset of the input
      (monotone, drift-free).

    Args:
        client: object exposing `complete(messages=..., creativity=...,
            force_json=...) -> (reply, error)`.
        batch: list of label strings.
        description: optional context injected into the prompt.
        creativity: temperature passed to the client.
        mode: "unique" or "merge".

    Returns:
        list[str]: the reduced labels.
    """
    # Returned whenever the model's answer cannot be used, so a bad reply
    # never drops the batch's categories.
    fallback = [str(x).strip().lower() for x in batch]
    items_blob = "; ".join(batch)
    context = f' about: "{description}"' if description else ""
    if mode == "merge":
        prompt = (
            f"You are consolidating a list of category labels{context} into a smaller set of "
            "broader categories. Group labels that describe the same underlying concept and give "
            "each group ONE clear representative label — actively merge near-synonyms and closely "
            "related labels into broader themes. BUT retain nuance: do NOT over-merge — keep labels "
            "separate when they capture a genuinely distinct concept, even if related, rather than "
            "collapsing them into one vague catch-all. Prefer fewer, cleaner categories without "
            f"losing real distinctions. Labels are separated by semicolons within triple backticks: "
            f"```{items_blob}```.\n\n"
            "Return ONLY a numbered list of the consolidated categories. Each line must follow this "
            "exact format, with no other text before or after the list:\n"
            "N. label\n\n"
            "Example:\n"
            "1. Employment\n"
            "2. Education\n"
            "3. Religion"
        )
    else:
        prompt = (
            f"You are given a list of category labels{context}. "
            "Return the UNIQUE categories. Remove ONLY exact duplicates and labels that "
            "restate the SAME category in different words — when two labels are the same "
            "category, keep one of them exactly as written. KEEP every genuinely distinct "
            "category. Do NOT merge categories that are merely related, do NOT invent or "
            "broaden labels, and do NOT drop a category just to make the list shorter. "
            "If all the labels are already distinct categories, return ALL of them unchanged. "
            f"Labels are separated by semicolons within triple backticks: ```{items_blob}```.\n\n"
            "Return ONLY a numbered list, using the labels exactly as they appear. Each line "
            "must follow this exact format, with no other text before or after the list:\n"
            "N. label\n\n"
            "Example:\n"
            "1. Employment\n"
            "2. Education\n"
            "3. Religion"
        )
    reply, error = client.complete(
        messages=[{"role": "user", "content": prompt}],
        creativity=creativity,
        force_json=False,
    )
    if error:
        # No data loss: keep the batch unchanged so its categories aren't dropped.
        sys.stderr.write(f"[collapse_themes] batch failed: {error} — keeping batch unchanged\n")
        return fallback

    out = []
    for line in (reply or "").splitlines():
        m = _LINE_PAT.match(line.strip())
        if m:
            label = _clean_label(m.group(1)).strip(" ;.,")
            if label:
                out.append(label)

    if mode == "unique":
        # Contraction guarantee: extract-unique must only REMOVE, never add or
        # mutate. Keep only outputs that map back to an input label (by
        # normalized key), as the original input string. This makes every pass
        # monotone and drift-free, immune to intermittent rephrasing/splitting.
        in_by_key = {}
        for x in batch:
            in_by_key.setdefault(_norm_key(x), str(x).strip().lower())
        seen, subset = set(), []
        for o in out:
            k = _norm_key(o)
            if k in in_by_key and k not in seen:
                seen.add(k)
                subset.append(in_by_key[k])
        out = subset

    if not out:
        # Nothing parseable (prose reply, wrong list format, empty reply) or,
        # in "unique" mode, nothing that maps back to an input label. In both
        # modes keep the batch rather than silently dropping its categories.
        sys.stderr.write("[collapse_themes] no usable labels in model reply — keeping batch unchanged\n")
        return fallback
    return out
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _to_counts(input_data):
    """Coerce the accepted input forms into a {category: count} dict.

    Accepted forms:
    - DataFrame: a "category" column (name matched case-insensitively) and an
      optional "count" column. Rows with a missing category are ignored, and
      categories are converted to strings. With a count column, counts are
      summed per category; otherwise rows are counted.
    - dict {category: count}: keys become str, values int.
    - Anything else (list, Series, ...): wrapped in a Series, missing values
      dropped, values converted to strings and counted.

    Returns:
        dict[str, int]

    Raises:
        ValueError: a DataFrame without a "category" column.
    """
    if isinstance(input_data, pd.DataFrame):
        # str() so non-string column labels (e.g. 0) don't break .lower().
        cols = {str(c).lower(): c for c in input_data.columns}
        cat_col = cols.get("category")
        cnt_col = cols.get("count")
        if cat_col is None:
            raise ValueError("DataFrame input must have a 'category' column.")
        keep = input_data[cat_col].notna()
        # Downstream dedup runs regex/string ops on the keys, so numeric or
        # other non-str categories must become strings (as in the Series path).
        labels = input_data.loc[keep, cat_col].astype("string")
        if cnt_col is not None:
            totals = input_data.loc[keep, cnt_col].groupby(labels.to_numpy()).sum()
            return {str(k): int(v) for k, v in totals.items()}
        return {str(k): int(v) for k, v in labels.value_counts().items()}
    if isinstance(input_data, dict):
        return {str(k): int(v) for k, v in input_data.items()}
    series = input_data if isinstance(input_data, pd.Series) else pd.Series(input_data)
    series = series.dropna().astype("string")
    return series.value_counts().to_dict()
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def _collapse_once(
    client,
    items,
    *,
    description,
    batch_size,
    dedupe_threshold,
    embedding_merge_threshold,
    mode,
    shuffle,
    random_state,
    creativity,
    max_workers,
):
    """Run a single collapse pass over `items` and return the reduced list.

    Steps:
    A. coerce `items` to {category: count};
    B. pre-clean, most frequent first: normalize + Jaro-Winkler dedup, then
       embedding-merge;
    optionally shuffle (seeded by `random_state`) so batch composition varies
    across passes;
    C. split into `batch_size` chunks;
    D. one LLM call per chunk, in parallel when `max_workers` > 1;
    E. surface-level dedup of the concatenated output.

    Raises:
        ValueError: `batch_size` < 1.
    """
    if batch_size < 1:
        # A zero step makes the slicing below raise; a negative one yields no
        # batches at all and would silently drop every label.
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}.")

    # A. accept -> {category: count}
    counts = _to_counts(items)

    # B. PRE-CLEAN before the model: normalize+JW dedup, then embedding-merge
    ordered = sorted(counts, key=counts.get, reverse=True)
    cleaned = _jw_dedupe(ordered, dedupe_threshold)
    cleaned = _embedding_merge(cleaned, embedding_merge_threshold)

    # Randomize order so batch composition varies across passes — gives near-
    # duplicates split across batches fresh chances to co-occur and merge.
    if shuffle:
        random.Random(random_state).shuffle(cleaned)

    # C. split into batches
    batches = [cleaned[i:i + batch_size] for i in range(0, len(cleaned), batch_size)]

    # D. one LLM call per batch; results stay in batch order either way.
    def run(batch):
        return _collapse_batch(client, batch, description, creativity, mode)

    if max_workers and max_workers > 1 and len(batches) > 1:
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=max_workers) as ex:
            results = list(ex.map(run, batches))
    else:
        results = [run(b) for b in batches]
    out = [label for r in results for label in (r or [])]

    # E. dedupe the concatenated output (surface-level)
    return _jw_dedupe(out, dedupe_threshold)
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def collapse_themes(
    input_data,
    api_key=None,
    description="",
    passes=1,
    batch_size=40,
    aggressive=False,
    dedupe_threshold=0.95,
    embedding_merge_threshold=0.92,
    shuffle=True,
    user_model="gpt-4o",
    model_source="auto",
    creativity=0,
    max_workers=1,
    random_state=None,
    filename=None,
    progress_callback=None,
):
    """
    Collapse a list of extracted themes into a smaller, deduplicated list.

    Iteratively consolidates near-duplicate / synonymous category labels (for
    example the output of explore()). Each pass:
    - PRE-CLEANS before the model (normalize + Jaro-Winkler dedup, then
      embedding-merge);
    - splits the labels into batches and sends each batch to the model;
    - dedupes the concatenated result.
    Runs `passes` iterations, randomizing batch composition each pass so labels
    stranded in separate batches get fresh chances to merge.

    Two modes:
      - aggressive=False (default): extract-unique — only removes duplicates /
        restatements, never invents or broadens. Each pass is guaranteed monotone
        (output is a subset of its input). Use to thin a noisy list faithfully.
      - aggressive=True: conceptual merge — actively consolidates related labels
        into broader categories while retaining meaningful distinctions. Use as a
        final compression step.

    Provider-agnostic (model_source: "auto", "openai", "huggingface", ...), via
    the same dispatch classify()/explore() use.

    Args:
        input_data: Themes to collapse. list[str] (duplicates allowed), pandas
            Series, dict {category: count}, or DataFrame with "category"
            [and optional "count"] columns.
        api_key (str): API key for the model provider.
        description (str): Data/question context, injected into the prompt — e.g.
            the survey question the categories came from. Helps the model judge
            which distinctions matter.
        passes (int): Number of collapse iterations to run (>= 1). Default 1.
        batch_size (int): Themes per LLM chunk (>= 1; ceil(n / batch_size) calls
            per pass). Default 40.
        aggressive (bool): If True, use the conceptual-merge prompt (compress);
            if False, extract-unique (faithful thinning). Default False.
        dedupe_threshold (float): Jaro-Winkler similarity at/above which two
            normalized labels are deduped. Default 0.95; 1.0 = exact only.
        embedding_merge_threshold (float): Cosine similarity at/above which labels
            are merged in the pre-LLM embedding step (BAAI/bge-small). Default
            0.92. None or >=1.0 skips embeddings.
        shuffle (bool): Randomize order each pass so batch composition varies.
            Default True (improves convergence stability).
        user_model (str): Model name. Default "gpt-4o". Use a capable model —
            small models can degenerate into repetition.
        model_source (str): Provider — "auto", "openai", "huggingface", etc.
        creativity (float): Temperature. Default 0 (deterministic).
        max_workers (int): Batches processed concurrently per pass. Default 1.
        random_state (int): Seed for shuffling (per-pass seed = random_state + p).
            None = nondeterministic.
        filename (str): Optional CSV path to save the final list.
        progress_callback (callable): Optional callback(pass, passes, label).

    Returns:
        list[str]: The collapsed category list after `passes` iterations.

    Raises:
        ValueError: missing `api_key`, `passes` < 1, or `batch_size` < 1.

    Examples:
        >>> import cat_stack as cat
        >>> themes = cat.explore(df['responses'], description="Why did you move?",
        ...                      api_key=key)
        >>> # 1) thin faithfully, then 2) compress
        >>> thinned = cat.collapse_themes(themes, api_key=key,
        ...     description="Why did you move?", passes=10, max_workers=8)
        >>> final = cat.collapse_themes(thinned, api_key=key,
        ...     description="Why did you move?", passes=2, aggressive=True)
    """
    if not api_key:
        raise ValueError("collapse_themes() needs an api_key for the LLM call.")
    # With passes < 1 the loop below never runs and the raw input (possibly a
    # dict or DataFrame) would be returned instead of the documented list[str].
    if passes < 1:
        raise ValueError(f"passes must be >= 1, got {passes!r}.")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}.")

    mode = "merge" if aggressive else "unique"
    provider = detect_provider(user_model, model_source)
    client = UnifiedLLMClient(provider=provider, api_key=api_key, model=user_model)

    current = input_data
    for p in range(passes):
        # Distinct but reproducible seed per pass when random_state is given.
        seed = None if random_state is None else random_state + p
        current = _collapse_once(
            client,
            current,
            description=description,
            batch_size=batch_size,
            dedupe_threshold=dedupe_threshold,
            embedding_merge_threshold=embedding_merge_threshold,
            mode=mode,
            shuffle=shuffle,
            random_state=seed,
            creativity=creativity,
            max_workers=max_workers,
        )
        if progress_callback:
            progress_callback(p + 1, passes, "collapse_themes")

    if filename:
        pd.DataFrame({"category": current}).to_csv(filename, index=False)
        print(f"Collapsed categories saved to {filename}")

    return current
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|