cat-stack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cat_stack/__about__.py +10 -0
- cat_stack/__init__.py +128 -0
- cat_stack/_batch.py +1388 -0
- cat_stack/_category_analysis.py +348 -0
- cat_stack/_chunked.py +424 -0
- cat_stack/_embeddings.py +189 -0
- cat_stack/_formatter.py +169 -0
- cat_stack/_providers.py +1048 -0
- cat_stack/_tiebreaker.py +277 -0
- cat_stack/_utils.py +512 -0
- cat_stack/_web_fetch.py +194 -0
- cat_stack/calls/CoVe.py +287 -0
- cat_stack/calls/__init__.py +25 -0
- cat_stack/calls/all_calls.py +622 -0
- cat_stack/calls/image_CoVe.py +386 -0
- cat_stack/calls/image_stepback.py +210 -0
- cat_stack/calls/pdf_CoVe.py +386 -0
- cat_stack/calls/pdf_stepback.py +210 -0
- cat_stack/calls/stepback.py +180 -0
- cat_stack/calls/top_n.py +217 -0
- cat_stack/classify.py +682 -0
- cat_stack/explore.py +111 -0
- cat_stack/extract.py +218 -0
- cat_stack/image_functions.py +2078 -0
- cat_stack/images/circle.png +0 -0
- cat_stack/images/cube.png +0 -0
- cat_stack/images/diamond.png +0 -0
- cat_stack/images/overlapping_pentagons.png +0 -0
- cat_stack/images/rectangles.png +0 -0
- cat_stack/model_reference_list.py +94 -0
- cat_stack/pdf_functions.py +2087 -0
- cat_stack/summarize.py +290 -0
- cat_stack/text_functions.py +1358 -0
- cat_stack/text_functions_ensemble.py +3644 -0
- cat_stack-0.1.0.dist-info/METADATA +150 -0
- cat_stack-0.1.0.dist-info/RECORD +38 -0
- cat_stack-0.1.0.dist-info/WHEEL +4 -0
- cat_stack-0.1.0.dist-info/licenses/LICENSE +672 -0
cat_stack/_tiebreaker.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Embedding-based centroid tiebreaker for ensemble classification.
|
|
3
|
+
|
|
4
|
+
Resolves true ties in ensemble consensus by building per-category centroids
|
|
5
|
+
from unanimously-agreed rows, then comparing tied texts to those centroids.
|
|
6
|
+
|
|
7
|
+
This module is called after ensemble classification completes but before
|
|
8
|
+
building output DataFrames. It mutates all_results in place, updating
|
|
9
|
+
consensus values for tied rows and adding tiebreaker_resolved metadata.
|
|
10
|
+
|
|
11
|
+
How it works:
|
|
12
|
+
1. Identify "confident" rows — where every model unanimously agrees
|
|
13
|
+
2. Embed all texts, compute mean embedding per category from confident rows
|
|
14
|
+
3. For true ties only (positive_rate == threshold exactly), compare text
|
|
15
|
+
embedding to positive vs negative centroid
|
|
16
|
+
4. Resolve: pick whichever centroid is closer; if only positive centroid
|
|
17
|
+
exists, use absolute similarity threshold
|
|
18
|
+
|
|
19
|
+
Requires: pip install cat-llm[embeddings]
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import numpy as np
|
|
23
|
+
import pandas as pd
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _compute_centroid(embeddings_matrix):
    """
    Average a stack of embeddings and scale the result to unit length.

    Args:
        embeddings_matrix: Array-like of embeddings, one per row; the mean is
            taken along axis 0.

    Returns:
        The mean vector divided by its L2 norm. If the norm is not greater
        than zero the unscaled mean is returned as-is.
    """
    centroid = np.mean(embeddings_matrix, axis=0)
    length = np.linalg.norm(centroid)
    # A zero-length mean has no direction to normalize; return it untouched
    # rather than dividing by zero.
    return centroid / length if length > 0 else centroid
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _find_confident_and_tied_rows(all_results, category_key, threshold):
    """
    Bucket rows into confident-positive, confident-negative, and true-tied.

    A row is "confident" when every model entry in ``per_model`` casts the
    same vote (all 1s or all 0s). A row is "true-tied" when the models are
    split AND positive_rate equals the threshold (e.g., a 2-2 split with
    majority-vote threshold=0.5).

    Unanimity is checked before the tie test, so a unanimous row is never
    reported as tied, even when the threshold is 0.0 or 1.0 and therefore
    coincides with a unanimous positive_rate.

    Rows flagged ``skipped``, rows whose aggregated dict carries a truthy
    ``error``, and rows without any ``per_model`` entries are ignored. A
    missing or non-integer vote for the category counts as 0.

    Args:
        all_results: List of result dicts; each non-skipped entry must have
            an "aggregated" dict.
        category_key: String key like "1", "2", etc.
        threshold: Numeric consensus threshold (0-1).

    Returns:
        (confident_positive_indices, confident_negative_indices, tied_indices)
        Each is a list of row indices into all_results.
    """
    confident_pos = []
    confident_neg = []
    tied = []

    for row_idx, result in enumerate(all_results):
        if result.get("skipped"):
            continue

        aggregated = result["aggregated"]
        if aggregated.get("error"):
            continue

        per_model = aggregated.get("per_model", {})
        if not per_model:
            continue

        # One vote per model; anything that cannot be parsed as an int
        # is treated as a negative vote.
        votes = []
        for parsed in per_model.values():
            try:
                votes.append(int(parsed.get(category_key, "0")))
            except (ValueError, TypeError):
                votes.append(0)

        # per_model is non-empty here, so the division is always defined.
        positive_rate = sum(votes) / len(votes)

        if positive_rate == 1.0:
            # All models agree positive
            confident_pos.append(row_idx)
        elif positive_rate == 0.0:
            # All models agree negative
            confident_neg.append(row_idx)
        elif abs(positive_rate - threshold) < 1e-9:
            # Split vote landing on the threshold (float-tolerant compare)
            tied.append(row_idx)

    return confident_pos, confident_neg, tied
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def resolve_ties_with_centroids(
    all_results,
    categories,
    embedding_model,
    consensus_threshold,
    min_centroid_size=3,
    similarity_threshold=0.5,
):
    """
    Resolve true ties in ensemble consensus using embedding centroids.

    Builds per-category centroids from texts where ALL models unanimously agree,
    then uses cosine similarity to those centroids to break ties.

    Only the consensus of true-tied rows is rewritten. Every non-skipped,
    non-errored row additionally receives an
    ``aggregated["tiebreaker_resolved"]`` dict mapping each category key to
    "centroid" (decided here) or "vote" (left as the ensemble decided) on
    every exit path, so downstream output columns stay consistent.

    Args:
        all_results: List of result dicts from classify_ensemble (mutated in place).
        categories: List of category name strings; category i is addressed
            by the key str(i + 1).
        embedding_model: Object exposing ``encode(texts, normalize_embeddings=...,
            show_progress_bar=...)`` and returning one embedding per text
            (e.g. a loaded SentenceTransformer model).
        consensus_threshold: Numeric threshold (already resolved to float).
        min_centroid_size: Minimum confident rows needed to build a centroid.
            Values below 1 are treated as 1, since a centroid cannot be
            computed from zero rows.
        similarity_threshold: Cosine-similarity cut-off used when only the
            positive centroid is available; a tied text at or above it is
            resolved to "1", otherwise "0".

    Returns:
        Dict with summary stats:
        - total_ties: number of true ties found across all categories
        - resolved: number of ties resolved by centroid
        - skipped_categories: categories skipped (insufficient confident data)
    """
    threshold = consensus_threshold
    # Guard against building a centroid from an empty selection (NaN vector).
    required_rows = max(1, min_centroid_size)

    # Bucket each category once. The buckets depend only on per-model votes,
    # which this function never mutates, so they stay valid for the whole run.
    buckets = [
        _find_confident_and_tied_rows(all_results, str(cat_idx), threshold)
        for cat_idx in range(1, len(categories) + 1)
    ]

    if not any(tied for _, _, tied in buckets):
        print("[CatLLM] Embedding tiebreaker: no true ties found — skipping.")
        # Still mark all rows as resolved by vote for consistent output columns
        _mark_all_as_vote(all_results, categories)
        return {"total_ties": 0, "resolved": 0, "skipped_categories": []}

    # Encode all non-skipped, non-blank texts in one batch
    texts = []
    row_to_embed_idx = {}  # row_idx -> index in texts list
    for row_idx, result in enumerate(all_results):
        if result.get("skipped"):
            continue
        raw_text = result.get("_original_item", result.get("response", ""))
        if pd.notna(raw_text) and str(raw_text).strip():
            row_to_embed_idx[row_idx] = len(texts)
            texts.append(str(raw_text))

    if not texts:
        # Nothing to embed: leave every consensus untouched, but keep the
        # tiebreaker_resolved metadata consistent with the other exit paths.
        _mark_all_as_vote(all_results, categories)
        return {"total_ties": 0, "resolved": 0, "skipped_categories": []}

    print(f"[CatLLM] Embedding tiebreaker: encoding {len(texts)} texts...")
    # asarray is a no-op for ndarray output and makes list-based fancy
    # indexing below work if the model hands back a plain sequence.
    all_embeddings = np.asarray(
        embedding_model.encode(
            texts,
            normalize_embeddings=True,
            show_progress_bar=len(texts) > 100,
        )
    )

    total_ties = 0
    total_resolved = 0
    skipped_categories = []

    for cat_idx, cat_name in enumerate(categories, start=1):
        cat_key = str(cat_idx)
        confident_pos, confident_neg, tied = buckets[cat_idx - 1]

        if not tied:
            continue

        total_ties += len(tied)

        # Only confident rows that actually have an embedding can contribute
        pos_embed_indices = [
            row_to_embed_idx[r] for r in confident_pos if r in row_to_embed_idx
        ]
        neg_embed_indices = [
            row_to_embed_idx[r] for r in confident_neg if r in row_to_embed_idx
        ]

        if len(pos_embed_indices) < required_rows:
            skipped_categories.append(cat_name)
            # Mark these tied rows as resolved by vote (consensus unchanged)
            for row_idx in tied:
                agg = all_results[row_idx]["aggregated"]
                agg.setdefault("tiebreaker_resolved", {})[cat_key] = "vote"
            continue

        pos_centroid = _compute_centroid(all_embeddings[pos_embed_indices])

        # Build negative centroid if enough data
        neg_centroid = None
        if len(neg_embed_indices) >= required_rows:
            neg_centroid = _compute_centroid(all_embeddings[neg_embed_indices])

        # Resolve each tied row
        for row_idx in tied:
            embed_idx = row_to_embed_idx.get(row_idx)
            if embed_idx is None:
                # Blank text: no embedding, so the vote-based consensus stands
                continue

            text_embedding = all_embeddings[embed_idx]

            # Cosine similarity (embeddings are already normalized, so dot product)
            sim_to_pos = float(np.dot(text_embedding, pos_centroid))

            if neg_centroid is not None:
                sim_to_neg = float(np.dot(text_embedding, neg_centroid))
                new_consensus = "1" if sim_to_pos >= sim_to_neg else "0"
            else:
                # Only positive centroid — fall back to an absolute cut-off
                new_consensus = "1" if sim_to_pos >= similarity_threshold else "0"

            agg = all_results[row_idx]["aggregated"]
            agg["consensus"][cat_key] = new_consensus
            agg.setdefault("tiebreaker_resolved", {})[cat_key] = "centroid"
            total_resolved += 1

    # Mark all remaining row/category pairs as resolved by vote
    _mark_all_as_vote(all_results, categories)

    if skipped_categories:
        print(
            f"[CatLLM] Embedding tiebreaker: skipped {len(skipped_categories)} "
            f"categor{'y' if len(skipped_categories) == 1 else 'ies'} "
            f"(fewer than {min_centroid_size} confident rows): "
            f"{', '.join(map(str, skipped_categories))}"
        )

    print(
        f"[CatLLM] Embedding tiebreaker: {total_ties} true ties found, "
        f"{total_resolved} resolved by centroid."
    )

    return {
        "total_ties": total_ties,
        "resolved": total_resolved,
        "skipped_categories": skipped_categories,
    }
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def _mark_all_as_vote(all_results, categories):
    """Default every unmarked row/category pair to 'vote'.

    Rows flagged ``skipped`` and rows whose aggregated dict has a truthy
    ``error`` are left untouched. Entries already present in
    ``tiebreaker_resolved`` are never overwritten.
    """
    category_keys = [str(position) for position, _ in enumerate(categories, start=1)]

    for entry in all_results:
        if entry.get("skipped"):
            continue

        aggregated = entry["aggregated"]
        if aggregated.get("error"):
            continue

        # Create the metadata dict on first touch, then fill only the gaps.
        resolved_by = aggregated.setdefault("tiebreaker_resolved", {})
        for key in category_keys:
            resolved_by.setdefault(key, "vote")
|