cat-stack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,277 @@
1
+ """
2
+ Embedding-based centroid tiebreaker for ensemble classification.
3
+
4
+ Resolves true ties in ensemble consensus by building per-category centroids
5
+ from unanimously-agreed rows, then comparing tied texts to those centroids.
6
+
7
+ This module is called after ensemble classification completes but before
8
+ building output DataFrames. It mutates all_results in place, updating
9
+ consensus values for tied rows and adding tiebreaker_resolved metadata.
10
+
11
+ How it works:
12
+ 1. Identify "confident" rows — where every model unanimously agrees
13
+ 2. Embed all texts, compute mean embedding per category from confident rows
14
+ 3. For true ties only (positive_rate == threshold exactly), compare text
15
+ embedding to positive vs negative centroid
16
+ 4. Resolve: pick whichever centroid is closer; if only positive centroid
17
+ exists, use an absolute cosine-similarity cut-off (0.5 by default)
18
+
19
+ Requires: pip install cat-llm[embeddings]
20
+ """
21
+
22
+ import numpy as np
23
+ import pandas as pd
24
+
25
+
26
def _compute_centroid(embeddings_matrix):
    """
    Compute the L2-normalized centroid (mean embedding) of a set of embeddings.

    Args:
        embeddings_matrix: Array-like of shape (N, D) holding N embeddings.

    Returns:
        Centroid vector of shape (D,). The vector has unit L2 norm unless the
        mean is the zero vector, in which case it is returned unscaled. When
        the input holds no embeddings (N == 0) a zero vector is returned.
    """
    matrix = np.asarray(embeddings_matrix)

    # np.mean over an empty axis emits a RuntimeWarning and yields NaNs, which
    # would then turn every similarity computed against this centroid into
    # NaN. A zero vector gives a well-defined similarity of 0 instead.
    if matrix.size == 0:
        return np.zeros(matrix.shape[1:])

    mean_vec = np.mean(matrix, axis=0)
    norm = np.linalg.norm(mean_vec)
    # Guard against division by zero when the embeddings cancel out exactly.
    if norm > 0:
        mean_vec = mean_vec / norm
    return mean_vec
41
+
42
+
43
def _find_confident_and_tied_rows(all_results, category_key, threshold):
    """
    Bucket rows into confident-positive, confident-negative, and true-tied.

    A row is "confident" when every model's vote for the category agrees
    (all 1s or all 0s). A row is "true-tied" when it is not unanimous and its
    positive_rate equals the threshold (e.g., a 2-2 split with majority vote
    threshold=0.5). Rows that are neither confident nor tied are not returned.

    Unanimity is checked before the tie condition so that a threshold of 1.0
    or 0.0 does not reclassify unanimous rows as ties.

    Rows flagged "skipped", rows whose aggregated result has an "error", and
    rows with no per-model results are ignored.

    Args:
        all_results: List of result dicts from classify_ensemble.
        category_key: String key like "1", "2", etc.
        threshold: Numeric consensus threshold (0-1).

    Returns:
        (confident_positive_indices, confident_negative_indices, tied_indices)
        Each is a list of row indices into all_results.
    """
    confident_pos = []
    confident_neg = []
    tied = []

    for row_idx, result in enumerate(all_results):
        if result.get("skipped"):
            continue

        aggregated = result["aggregated"]
        if aggregated.get("error"):
            continue

        per_model = aggregated.get("per_model", {})
        if not per_model:
            continue

        # Collect one vote per model; a missing or unparseable vote counts as 0.
        votes = []
        for parsed in per_model.values():
            vote_str = parsed.get(category_key, "0")
            try:
                votes.append(int(vote_str))
            except (ValueError, TypeError):
                votes.append(0)

        num_models = len(votes)
        positive_count = sum(votes)

        if positive_count == num_models:
            # All models agree positive
            confident_pos.append(row_idx)
        elif positive_count == 0:
            # All models agree negative
            confident_neg.append(row_idx)
        elif abs(positive_count / num_models - threshold) < 1e-9:
            # Genuine split that lands exactly on the threshold
            tied.append(row_idx)

    return confident_pos, confident_neg, tied
103
+
104
+
105
def resolve_ties_with_centroids(
    all_results,
    categories,
    embedding_model,
    consensus_threshold,
    min_centroid_size=3,
    similarity_threshold=0.5,
):
    """
    Resolve true ties in ensemble consensus using embedding centroids.

    Builds per-category centroids from texts where ALL models unanimously agree,
    then uses cosine similarity to those centroids to break ties.

    Only the consensus of true-tied rows (positive_rate == threshold exactly)
    is changed. Rows with clear majorities keep their consensus. Every
    non-skipped, non-error row additionally receives a "tiebreaker_resolved"
    dict mapping each category key to "centroid" or "vote".

    Args:
        all_results: List of result dicts from classify_ensemble (mutated in place).
        categories: List of category name strings.
        embedding_model: Model exposing encode(texts, normalize_embeddings=...,
            show_progress_bar=...), e.g. a loaded SentenceTransformer.
        consensus_threshold: Numeric threshold (already resolved to float).
        min_centroid_size: Minimum confident rows needed to build a centroid.
        similarity_threshold: Cosine-similarity cut-off used when only a
            positive centroid is available; a tied text at or above it is
            resolved to "1", otherwise "0".

    Returns:
        Dict with summary stats:
        - total_ties: number of true ties found across all categories
        - resolved: number of ties resolved by centroid
        - skipped_categories: categories skipped (insufficient confident data)
    """
    threshold = consensus_threshold

    # Bucket rows once per category. Bucketing only reads per-model votes and
    # the skipped/error flags, none of which are modified below, so the result
    # can safely be reused by the resolution loop.
    buckets = [
        _find_confident_and_tied_rows(all_results, str(cat_idx + 1), threshold)
        for cat_idx in range(len(categories))
    ]
    total_ties = sum(len(tied) for _, _, tied in buckets)

    if total_ties == 0:
        print("[CatLLM] Embedding tiebreaker: no true ties found — skipping.")
        # Still mark all rows as resolved by vote for consistent output columns
        _mark_all_as_vote(all_results, categories)
        return {"total_ties": 0, "resolved": 0, "skipped_categories": []}

    # Encode all non-skipped texts in one batch
    texts = []
    row_to_embed_idx = {}  # row_idx -> index in texts list
    for row_idx, result in enumerate(all_results):
        if result.get("skipped"):
            continue
        raw_text = result.get("_original_item", result.get("response", ""))
        if pd.notna(raw_text) and str(raw_text).strip():
            row_to_embed_idx[row_idx] = len(texts)
            texts.append(str(raw_text))

    if not texts:
        # Nothing can be embedded, so every tie keeps its vote-based consensus.
        # Rows are still marked so the output columns stay consistent with the
        # other exit paths.
        print("[CatLLM] Embedding tiebreaker: no texts to embed — skipping.")
        _mark_all_as_vote(all_results, categories)
        return {"total_ties": total_ties, "resolved": 0, "skipped_categories": []}

    print(f"[CatLLM] Embedding tiebreaker: encoding {len(texts)} texts...")
    # np.asarray is a no-op for ndarray output and makes list-of-vector output
    # indexable with a list of row positions.
    all_embeddings = np.asarray(
        embedding_model.encode(
            texts, normalize_embeddings=True,
            show_progress_bar=len(texts) > 100,
        )
    )

    total_resolved = 0
    skipped_categories = []

    for cat_idx, cat_name in enumerate(categories):
        cat_key = str(cat_idx + 1)
        confident_pos, confident_neg, tied = buckets[cat_idx]

        if not tied:
            continue

        # Map confident rows to their embedding positions; rows without
        # usable text have no embedding and are left out.
        pos_embed_indices = [
            row_to_embed_idx[r] for r in confident_pos if r in row_to_embed_idx
        ]
        neg_embed_indices = [
            row_to_embed_idx[r] for r in confident_neg if r in row_to_embed_idx
        ]

        if len(pos_embed_indices) < min_centroid_size:
            skipped_categories.append(cat_name)
            # Mark these tied rows as resolved by vote (unchanged)
            for row_idx in tied:
                agg = all_results[row_idx]["aggregated"]
                agg.setdefault("tiebreaker_resolved", {})[cat_key] = "vote"
            continue

        pos_centroid = _compute_centroid(all_embeddings[pos_embed_indices])

        # Build negative centroid if enough data
        neg_centroid = None
        if len(neg_embed_indices) >= min_centroid_size:
            neg_centroid = _compute_centroid(all_embeddings[neg_embed_indices])

        # Resolve each tied row
        for row_idx in tied:
            embed_idx = row_to_embed_idx.get(row_idx)
            if embed_idx is None:
                # No text to embed; the final pass marks this row as "vote".
                continue

            text_embedding = all_embeddings[embed_idx]

            # Cosine similarity (embeddings are already normalized, so dot product)
            sim_to_pos = float(np.dot(text_embedding, pos_centroid))

            if neg_centroid is not None:
                sim_to_neg = float(np.dot(text_embedding, neg_centroid))
                new_consensus = "1" if sim_to_pos >= sim_to_neg else "0"
            else:
                # Only positive centroid — fall back to an absolute cut-off
                new_consensus = "1" if sim_to_pos >= similarity_threshold else "0"

            agg = all_results[row_idx]["aggregated"]
            agg["consensus"][cat_key] = new_consensus
            agg.setdefault("tiebreaker_resolved", {})[cat_key] = "centroid"
            total_resolved += 1

    # Mark all non-tied rows as resolved by vote
    _mark_all_as_vote(all_results, categories)

    if skipped_categories:
        print(
            f"[CatLLM] Embedding tiebreaker: skipped {len(skipped_categories)} "
            f"categor{'y' if len(skipped_categories) == 1 else 'ies'} "
            f"(fewer than {min_centroid_size} confident rows): "
            f"{', '.join(skipped_categories)}"
        )

    print(
        f"[CatLLM] Embedding tiebreaker: {total_ties} true ties found, "
        f"{total_resolved} resolved by centroid."
    )

    return {
        "total_ties": total_ties,
        "resolved": total_resolved,
        "skipped_categories": skipped_categories,
    }
262
+
263
+
264
def _mark_all_as_vote(all_results, categories):
    """
    Record "vote" as the resolution method wherever none is recorded yet.

    For every row that is neither skipped nor errored, ensures the aggregated
    dict has a "tiebreaker_resolved" mapping and fills in "vote" for each
    category key ("1" .. str(len(categories))) that is still missing.
    Existing entries are left untouched. Mutates all_results in place.
    """
    category_keys = [str(position) for position in range(1, len(categories) + 1)]

    for entry in all_results:
        if entry.get("skipped"):
            continue

        aggregated = entry["aggregated"]
        if aggregated.get("error"):
            continue

        resolution = aggregated.setdefault("tiebreaker_resolved", {})
        for key in category_keys:
            # setdefault keeps any marker already written for this category.
            resolution.setdefault(key, "vote")