cat-stack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cat_stack/__about__.py +10 -0
- cat_stack/__init__.py +128 -0
- cat_stack/_batch.py +1388 -0
- cat_stack/_category_analysis.py +348 -0
- cat_stack/_chunked.py +424 -0
- cat_stack/_embeddings.py +189 -0
- cat_stack/_formatter.py +169 -0
- cat_stack/_providers.py +1048 -0
- cat_stack/_tiebreaker.py +277 -0
- cat_stack/_utils.py +512 -0
- cat_stack/_web_fetch.py +194 -0
- cat_stack/calls/CoVe.py +287 -0
- cat_stack/calls/__init__.py +25 -0
- cat_stack/calls/all_calls.py +622 -0
- cat_stack/calls/image_CoVe.py +386 -0
- cat_stack/calls/image_stepback.py +210 -0
- cat_stack/calls/pdf_CoVe.py +386 -0
- cat_stack/calls/pdf_stepback.py +210 -0
- cat_stack/calls/stepback.py +180 -0
- cat_stack/calls/top_n.py +217 -0
- cat_stack/classify.py +682 -0
- cat_stack/explore.py +111 -0
- cat_stack/extract.py +218 -0
- cat_stack/image_functions.py +2078 -0
- cat_stack/images/circle.png +0 -0
- cat_stack/images/cube.png +0 -0
- cat_stack/images/diamond.png +0 -0
- cat_stack/images/overlapping_pentagons.png +0 -0
- cat_stack/images/rectangles.png +0 -0
- cat_stack/model_reference_list.py +94 -0
- cat_stack/pdf_functions.py +2087 -0
- cat_stack/summarize.py +290 -0
- cat_stack/text_functions.py +1358 -0
- cat_stack/text_functions_ensemble.py +3644 -0
- cat_stack-0.1.0.dist-info/METADATA +150 -0
- cat_stack-0.1.0.dist-info/RECORD +38 -0
- cat_stack-0.1.0.dist-info/WHEEL +4 -0
- cat_stack-0.1.0.dist-info/licenses/LICENSE +672 -0
cat_stack/_chunked.py
ADDED
|
@@ -0,0 +1,424 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Chunked classification for CatLLM.
|
|
3
|
+
|
|
4
|
+
When users have large category lists, this module splits them into smaller
|
|
5
|
+
chunks, runs a separate LLM call per chunk with local 1..N numbering, and
|
|
6
|
+
merges the results back into global numbering so downstream code
|
|
7
|
+
(aggregate_results, build_output_dataframes) sees a single merged JSON dict.
|
|
8
|
+
|
|
9
|
+
Each chunk automatically gets a temporary "Other" catch-all category appended
|
|
10
|
+
(unless one is already present in the chunk). This gives the LLM an escape
|
|
11
|
+
hatch for ambiguous responses, improving classification accuracy. The "Other"
|
|
12
|
+
column is dropped before merging back to global keys, so the final output
|
|
13
|
+
only contains the user's real categories.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import math
|
|
18
|
+
|
|
19
|
+
from .text_functions import (
|
|
20
|
+
build_json_schema,
|
|
21
|
+
extract_json,
|
|
22
|
+
validate_classification_json,
|
|
23
|
+
ollama_two_step_classify,
|
|
24
|
+
)
|
|
25
|
+
from ._category_analysis import has_other_category
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def run_chunked_classification(
    *,
    client,
    cfg,
    item,
    categories,
    categories_str,
    example_json,
    json_schema,
    cove_original_task,
    effective_creativity,
    use_json_schema,
    survey_question,
    survey_question_context,
    examples_text,
    chain_of_thought,
    context_prompt,
    step_back_prompt,
    stepback_insights,
    chain_of_verification,
    thinking_budget,
    max_retries,
    multi_label,
    categories_per_call,
    add_unified_other=False,
    formatter_fallback_fn,
    # Mode-specific
    is_pdf_mode,
    is_image_mode,
    pdf_mode=None,
    pdf_dpi=150,
    input_description="",
    # Prompt builders (passed in to avoid circular imports)
    build_text_prompt_fn=None,
    build_pdf_prompt_fn=None,
    build_image_prompt_fn=None,
    google_multimodal_fn=None,
    prepare_page_data_fn=None,
    prepare_image_data_fn=None,
    build_cove_prompts_fn=None,
    run_cove_fn=None,
):
    """
    Run chunked classification for one item across category chunks.

    Splits the full category list into chunks of `categories_per_call`,
    runs one LLM call per chunk (numbered locally 1..N), and merges the
    results into a single dict keyed by global 1-based category number.

    Args:
        categories: Full list of category names; sliced into chunks.
        categories_per_call: Maximum number of real categories per chunk.
        add_unified_other: When true, one extra key ``len(categories) + 1``
            is added to the merged result. It is "1" only if every merged
            "0"/"1" value is 0 and at least one chunk's temporary "Other"
            fired; otherwise "0".
        use_json_schema: When true, a per-chunk schema is built with
            ``build_json_schema``; otherwise no schema is passed.
        categories_str, example_json, json_schema, cove_original_task:
            Accepted for signature compatibility; this function rebuilds
            the chunk-local equivalents and does not read these values.

        All remaining arguments are forwarded unchanged to
        ``_run_single_chunk_call`` for every chunk.

    Returns:
        tuple: (json_result_str, error) — same contract as a single LLM call.
        On a chunk error the JSON merged so far is returned (or '{"1":"e"}'
        if nothing has been merged yet) together with that chunk's error.
        If a chunk result is not a JSON object, ('{"1":"e"}', message) is
        returned.
    """
    # Build chunks: list of (chunk_categories, global_offset)
    chunks = [
        (categories[start : start + categories_per_call], start)
        for start in range(0, len(categories), categories_per_call)
    ]

    merged_json = {}
    chunk_other_values = []  # Per-chunk temporary "Other" values, for unification

    for chunk_cats, global_offset in chunks:
        # Add a temporary "Other" catch-all if the chunk doesn't already have
        # one, so the LLM has an escape hatch for responses that fit nothing
        # in this chunk. Its value is captured separately and never merged.
        num_real_cats = len(chunk_cats)
        added_other = not has_other_category(chunk_cats)
        chunk_cats_for_call = (
            list(chunk_cats) + ["Other"] if added_other else chunk_cats
        )

        # Build chunk-local prompt components (with "Other" if added)
        chunk_categories_str = "\n".join(
            f"{j+1}. {cat}" for j, cat in enumerate(chunk_cats_for_call)
        )
        chunk_example_json = json.dumps(
            {str(j + 1): "0" for j in range(len(chunk_cats_for_call))}, indent=2
        )
        chunk_json_schema = (
            build_json_schema(
                chunk_cats_for_call,
                include_additional_properties=(cfg["provider"] != "google"),
            )
            if use_json_schema
            else None
        )

        # Rebuild the CoVe task text for this chunk if CoVe is enabled
        chunk_cove_task = ""
        if chain_of_verification:
            if multi_label:
                cove_categorize = "into the following categories"
                cove_json = 'Provide your answer in JSON format where the category number is the key and "1" if present, "0" if not.'
            else:
                cove_categorize = "into the single most appropriate category"
                cove_json = 'Provide your answer in JSON format where the category number is the key. Assign "1" to the single best matching category and "0" to all others.'
            chunk_cove_task = (
                f"{survey_question_context}\n"
                f"Categorize text responses {cove_categorize}:\n"
                f"{chunk_categories_str}\n"
                f"{cove_json}"
            )

        # Run one LLM call for this chunk (with "Other" included)
        chunk_result, chunk_error = _run_single_chunk_call(
            client=client,
            cfg=cfg,
            item=item,
            chunk_cats=chunk_cats_for_call,
            chunk_categories_str=chunk_categories_str,
            chunk_json_schema=chunk_json_schema,
            chunk_example_json=chunk_example_json,
            chunk_cove_task=chunk_cove_task,
            effective_creativity=effective_creativity,
            survey_question=survey_question,
            survey_question_context=survey_question_context,
            examples_text=examples_text,
            chain_of_thought=chain_of_thought,
            context_prompt=context_prompt,
            step_back_prompt=step_back_prompt,
            stepback_insights=stepback_insights,
            chain_of_verification=chain_of_verification,
            thinking_budget=thinking_budget,
            max_retries=max_retries,
            multi_label=multi_label,
            formatter_fallback_fn=formatter_fallback_fn,
            is_pdf_mode=is_pdf_mode,
            is_image_mode=is_image_mode,
            pdf_mode=pdf_mode,
            pdf_dpi=pdf_dpi,
            input_description=input_description,
            build_text_prompt_fn=build_text_prompt_fn,
            build_pdf_prompt_fn=build_pdf_prompt_fn,
            build_image_prompt_fn=build_image_prompt_fn,
            google_multimodal_fn=google_multimodal_fn,
            prepare_page_data_fn=prepare_page_data_fn,
            prepare_image_data_fn=prepare_image_data_fn,
            build_cove_prompts_fn=build_cove_prompts_fn,
            run_cove_fn=run_cove_fn,
        )

        if chunk_error:
            return (json.dumps(merged_json) if merged_json else '{"1":"e"}', chunk_error)

        # Parse the chunk result; anything other than a JSON object (bad JSON,
        # None, a list, a bare number) cannot be remapped and is an error.
        try:
            chunk_parsed = json.loads(chunk_result)
        except (json.JSONDecodeError, TypeError):
            chunk_parsed = None
        if not isinstance(chunk_parsed, dict):
            return ('{"1":"e"}', f"Failed to parse chunk result: {chunk_result}")

        # The "Other" key (if added) is the last one: num_real_cats + 1
        other_local_index = num_real_cats + 1 if added_other else None

        # Remap chunk-local keys (1..N) to global keys, dropping "Other"
        for local_key, value in chunk_parsed.items():
            try:
                local_index = int(local_key)
            except (ValueError, TypeError):
                # Non-numeric key — skip (shouldn't happen with proper schemas)
                continue

            # Capture the temporary "Other" value, don't merge it
            if local_index == other_local_index:
                try:
                    chunk_other_values.append(int(value))
                except (ValueError, TypeError):
                    chunk_other_values.append(0)
                continue

            # Only indices that belong to this chunk may be merged; an
            # out-of-range index would otherwise land on (and overwrite)
            # a global key owned by a neighbouring chunk.
            if 1 <= local_index <= num_real_cats:
                merged_json[str(global_offset + local_index)] = value

    # Unified "Other": if all real categories are 0 but at least one chunk's
    # "Other" fired, the response genuinely doesn't fit any category.
    if add_unified_other:
        real_sum = sum(
            int(v) for v in merged_json.values()
            if str(v).strip() in ("0", "1")
        )
        other_sum = sum(chunk_other_values)
        unified_other = "1" if real_sum == 0 and other_sum > 0 else "0"
        merged_json[str(len(categories) + 1)] = unified_other

    return (json.dumps(merged_json), None)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _run_single_chunk_call(
    *,
    client,
    cfg,
    item,
    chunk_cats,
    chunk_categories_str,
    chunk_json_schema,
    chunk_example_json,
    chunk_cove_task,
    effective_creativity,
    survey_question,
    survey_question_context,
    examples_text,
    chain_of_thought,
    context_prompt,
    step_back_prompt,
    stepback_insights,
    chain_of_verification,
    thinking_budget,
    max_retries,
    multi_label,
    formatter_fallback_fn,
    is_pdf_mode,
    is_image_mode,
    pdf_mode,
    pdf_dpi,
    input_description,
    build_text_prompt_fn,
    build_pdf_prompt_fn,
    build_image_prompt_fn,
    google_multimodal_fn,
    prepare_page_data_fn,
    prepare_image_data_fn,
    build_cove_prompts_fn,
    run_cove_fn,
):
    """
    Classify a single item against a single chunk of categories.

    Dispatches on the input mode:

    * PDF mode (``is_pdf_mode`` and ``item`` is a tuple): ``item`` is
      unpacked as ``(pdf_path, page_index, page_label)``.
    * Image mode (``is_image_mode`` and ``item`` is a tuple): ``item`` is
      unpacked as ``(image_path, image_label)``.
    * Text mode (everything else): ``item`` is used as the response text.
      When ``cfg["use_two_step"]`` is truthy the Ollama two-step classifier
      is used; otherwise a regular completion is made, followed by Chain of
      Verification when ``chain_of_verification`` is set.

    Returns:
        tuple: (json_result_str, error). On failure the JSON string is
        '{"1":"e"}' and ``error`` carries the reported error, except in the
        two-step path, which returns whatever ``ollama_two_step_classify``
        produced.
    """
    failure_json = '{"1":"e"}'

    def _client_thinking_budget():
        # client.complete only receives a thinking budget for these providers.
        supported = ("google", "openai", "anthropic", "huggingface", "huggingface-together")
        if cfg["provider"] in supported:
            return thinking_budget
        return None

    def _complete_with_client(prompt_messages):
        return client.complete(
            messages=prompt_messages,
            json_schema=chunk_json_schema,
            creativity=effective_creativity,
            thinking_budget=_client_thinking_budget(),
            max_retries=max_retries,
        )

    def _complete_multimodal(prompt_messages):
        # Google multimodal requests go through the dedicated helper.
        if cfg["provider"] == "google":
            return google_multimodal_fn(
                client=client,
                messages=prompt_messages,
                json_schema=chunk_json_schema,
                creativity=effective_creativity,
                thinking_budget=thinking_budget,
                max_retries=max_retries,
            )
        return _complete_with_client(prompt_messages)

    def _reply_to_json(raw_reply):
        # Extract JSON from the raw reply, then let the fallback formatter
        # repair or replace it using the original reply text.
        extracted = extract_json(raw_reply)
        return formatter_fallback_fn(extracted, raw_reply, chunk_cats)

    item_is_tuple = isinstance(item, tuple)

    # ---------------------------- PDF mode ----------------------------
    if is_pdf_mode and item_is_tuple:
        pdf_path, page_index, page_label = item

        page_data = prepare_page_data_fn(
            pdf_path=pdf_path,
            page_index=page_index,
            page_label=page_label,
            pdf_mode=pdf_mode,
            provider=cfg["provider"],
            pdf_dpi=pdf_dpi,
        )
        if page_data.get("error"):
            return (failure_json, page_data["error"])

        messages = build_pdf_prompt_fn(
            page_data=page_data,
            categories_str=chunk_categories_str,
            input_description=input_description,
            provider=cfg["provider"],
            pdf_mode=pdf_mode,
            chain_of_thought=chain_of_thought,
            context_prompt=context_prompt,
            step_back_prompt=step_back_prompt,
            stepback_insights=stepback_insights,
            model_name=cfg["model"],
            example_json=chunk_example_json,
            multi_label=multi_label,
        )

        reply, error = _complete_multimodal(messages)
        if error:
            return (failure_json, error)
        return (_reply_to_json(reply), None)

    # --------------------------- Image mode ---------------------------
    if is_image_mode and item_is_tuple:
        image_path, image_label = item

        image_data = prepare_image_data_fn(image_path, image_label)
        if image_data.get("error"):
            return (failure_json, image_data["error"])

        messages = build_image_prompt_fn(
            image_data=image_data,
            categories_str=chunk_categories_str,
            input_description=input_description,
            provider=cfg["provider"],
            chain_of_thought=chain_of_thought,
            context_prompt=context_prompt,
            step_back_prompt=step_back_prompt,
            stepback_insights=stepback_insights,
            model_name=cfg["model"],
            example_json=chunk_example_json,
            multi_label=multi_label,
        )

        reply, error = _complete_multimodal(messages)
        if error:
            return (failure_json, error)
        return (_reply_to_json(reply), None)

    # ---------------------------- Text mode ---------------------------
    response_text = item

    if cfg["use_two_step"]:  # Ollama
        two_step_json, two_step_error = ollama_two_step_classify(
            client=client,
            response_text=response_text,
            categories=chunk_cats,
            categories_str=chunk_categories_str,
            survey_question=survey_question,
            creativity=effective_creativity,
            max_retries=max_retries,
        )
        if not two_step_error:
            two_step_json = formatter_fallback_fn(two_step_json, two_step_json, chunk_cats)
        return (two_step_json, two_step_error)

    messages = build_text_prompt_fn(
        response_text=response_text,
        categories_str=chunk_categories_str,
        survey_question_context=survey_question_context,
        examples_text=examples_text,
        chain_of_thought=chain_of_thought,
        context_prompt=context_prompt,
        step_back_prompt=step_back_prompt,
        stepback_insights=stepback_insights,
        model_name=cfg["model"],
        multi_label=multi_label,
    )
    reply, error = _complete_with_client(messages)
    if error:
        return (failure_json, error)

    json_result = _reply_to_json(reply)

    # Optional Chain of Verification pass over the initial answer
    if chain_of_verification:
        verify_prompt, answer_prompt, revise_prompt = build_cove_prompts_fn(
            chunk_cove_task, response_text
        )
        json_result = run_cove_fn(
            client=client,
            initial_reply=json_result,
            step2_prompt=verify_prompt,
            step3_prompt=answer_prompt,
            step4_prompt=revise_prompt,
            json_schema=chunk_json_schema,
            creativity=effective_creativity,
            max_retries=max_retries,
        )
        json_result = formatter_fallback_fn(json_result, json_result, chunk_cats)

    return (json_result, None)
|
cat_stack/_embeddings.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Embedding-based similarity scores for CatLLM.
|
|
3
|
+
|
|
4
|
+
Uses a local sentence-transformer model (BAAI/bge-small-en-v1.5, 33M params,
|
|
5
|
+
~130MB) to compute cosine similarity between each input text and each category.
|
|
6
|
+
Scores are independent per (text, category) pair — no softmax across categories,
|
|
7
|
+
since this is multi-label classification.
|
|
8
|
+
|
|
9
|
+
The embeddings feature is opt-in via embeddings=True on classify(). It adds
|
|
10
|
+
`_similarity` columns alongside the existing binary 0/1 classification columns.
|
|
11
|
+
|
|
12
|
+
Requires: pip install cat-llm[embeddings]
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import pandas as pd
|
|
16
|
+
|
|
17
|
+
_EMBEDDING_MODEL_NAME = "BAAI/bge-small-en-v1.5"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _check_dependencies():
    """
    Check that sentence-transformers is installed.

    Raises:
        ImportError: If ``sentence_transformers`` cannot be imported. The
            message contains install instructions and the original import
            failure is attached as the exception's cause.
    """
    try:
        import sentence_transformers  # noqa: F401
    except ImportError as exc:
        # Chain explicitly so the underlying import failure (which may be a
        # broken transitive dependency rather than a missing package) stays
        # visible in the traceback.
        raise ImportError(
            "The embeddings feature requires sentence-transformers.\n"
            "Install with: pip install cat-llm[embeddings]\n"
            " (requires: sentence-transformers, which pulls in torch and transformers)"
        ) from exc
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _is_model_cached() -> bool:
    """
    Check if the embedding model is already in the HuggingFace cache.

    Looks up the model's ``config.json`` in the local cache via
    ``huggingface_hub.try_to_load_from_cache``.

    Returns:
        True only when the lookup yields a file path. Any failure (including
        ``huggingface_hub`` not being importable) is treated as "not cached".
    """
    try:
        from huggingface_hub import try_to_load_from_cache

        cached = try_to_load_from_cache(_EMBEDDING_MODEL_NAME, "config.json")
    except Exception:
        # Best-effort probe: if the cache cannot be inspected, report
        # "not cached" and let the caller decide what to do next.
        return False
    # The lookup returns a str path on a hit, None when nothing is cached,
    # or a non-None sentinel when the file is known not to exist; only a
    # real path means the model is available locally.
    return isinstance(cached, str)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def ensure_embeddings_available() -> bool:
    """
    Make sure the embedding model can be used, asking before any download.

    Verifies the sentence-transformers dependency first. If the model is
    not already cached, prints a notice about the download and asks for
    confirmation on stdin; an empty answer, "y" or "yes" (case-insensitive)
    counts as consent. EOF or Ctrl-C at the prompt counts as "no".

    Returns:
        True if the model is ready to use (or the user agreed to download
        it), False if the user declined the download.
    """
    _check_dependencies()

    # Already cached: nothing to ask.
    if _is_model_cached():
        return True

    notice = (
        "\n[CatLLM] The embedding model (~130MB) will be downloaded from\n"
        f" HuggingFace Hub ({_EMBEDDING_MODEL_NAME}).\n"
        " This is a one-time download — the model is cached locally after."
    )
    print(notice)

    try:
        reply = input(" Continue? (Y/n): ").strip().lower()
    except (EOFError, KeyboardInterrupt):
        # No usable stdin or the user interrupted: treat as a refusal.
        reply = "n"

    if reply not in ("", "y", "yes"):
        print(" -> Embedding scores disabled for this run.\n")
        return False
    return True
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def load_embedding_model():
    """
    Instantiate the sentence-transformer model used for similarity scores.

    Checks the dependency first, then constructs the model named by
    ``_EMBEDDING_MODEL_NAME``, printing a status line before and after.

    Returns:
        SentenceTransformer model instance.
    """
    _check_dependencies()

    # Imported here so the module itself loads without the optional extra.
    from sentence_transformers import SentenceTransformer

    print(f"[CatLLM] Loading embedding model ({_EMBEDDING_MODEL_NAME})...")
    embedder = SentenceTransformer(_EMBEDDING_MODEL_NAME)
    print("[CatLLM] Embedding model ready.")
    return embedder
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def compute_embedding_scores(texts, categories, model, category_descriptions=None):
    """
    Compute cosine similarity scores between texts and categories.

    Each (text, category) score is independent — no softmax across categories.
    Raw cosine similarity is rescaled from [-1, 1] to [0, 1] via (sim + 1) / 2
    and rounded to 4 decimal places.

    Args:
        texts: List of input text strings. Missing values (as judged by
            ``pd.notna``) are embedded as the empty string.
        categories: List of category name strings.
        model: Loaded SentenceTransformer model.
        category_descriptions: Optional dict mapping category names to richer
            descriptions for embedding (e.g., {"Past_Support": "References to
            help received from family in the past"}). When a category has a
            description, "<name>: <description>" is embedded instead of the
            bare name.

    Returns:
        Dict mapping "category_N_similarity" -> list of float scores, where N
        is 1-indexed to match the existing classification column naming.
        With no texts every list is empty; with no categories the dict is
        empty. In both cases the model is not invoked.
    """
    # Convert NaN/None to empty string
    clean_texts = [str(t) if pd.notna(t) else "" for t in texts]
    column_names = [f"category_{i + 1}_similarity" for i, _cat in enumerate(categories)]

    # Nothing to compare: skip encoding entirely rather than handing empty
    # batches to the model and the similarity routine.
    if not clean_texts or not column_names:
        return {name: [] for name in column_names}

    from sentence_transformers import util

    # Build category strings for embedding
    cat_strings = [
        f"{cat}: {category_descriptions[cat]}"
        if category_descriptions and cat in category_descriptions
        else cat
        for cat in categories
    ]

    # Encode all texts and categories
    text_embeddings = model.encode(clean_texts, normalize_embeddings=True,
                                   show_progress_bar=len(clean_texts) > 100)
    cat_embeddings = model.encode(cat_strings, normalize_embeddings=True)

    # Compute cosine similarity matrix: (num_texts, num_categories)
    sim_matrix = util.cos_sim(text_embeddings, cat_embeddings)

    # Rescale from [-1, 1] to [0, 1], then convert once to nested Python
    # lists instead of indexing the matrix element by element.
    score_rows = ((sim_matrix + 1) / 2).tolist()

    return {
        name: [round(float(row[col]), 4) for row in score_rows]
        for col, name in enumerate(column_names)
    }
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def apply_embedding_scores(df, categories, embedding_model, category_descriptions=None):
    """
    Insert embedding similarity columns into a result DataFrame.

    For each category N, a `category_N_similarity` column is inserted after the
    last existing column that belongs to that category number (a column named
    exactly `category_N` or starting with `category_N_`). If no such column
    exists the similarity column is appended at the end. If the similarity
    column already exists (for example when this function is applied twice),
    its values are overwritten in place instead of inserting a duplicate.

    Args:
        df: Result DataFrame from classify (single-model or ensemble). The
            text to embed is read from the "input_data" column when present,
            otherwise from the first column.
        categories: List of category name strings.
        embedding_model: Loaded SentenceTransformer model.
        category_descriptions: Optional dict mapping category names to descriptions.

    Returns:
        A copy of `df` with the `_similarity` columns added; `df` itself is
        not modified.
    """
    # Find the text column to use for embedding
    if "input_data" in df.columns:
        texts = df["input_data"].tolist()
    else:
        # Fallback: use first column
        texts = df.iloc[:, 0].tolist()

    scores = compute_embedding_scores(texts, categories, embedding_model,
                                      category_descriptions)

    result_df = df.copy()
    for i in range(len(categories)):
        sim_col = f"category_{i + 1}_similarity"
        sim_values = scores[sim_col]

        # DataFrame.insert refuses an existing label, so refresh the values
        # in place when the column is already there.
        if sim_col in result_df.columns:
            result_df[sim_col] = sim_values
            continue

        # Match "category_{N}" exactly or the "category_{N}_" prefix; the
        # trailing underscore keeps category_1 from matching category_10.
        cat_prefix = f"category_{i + 1}_"
        cat_exact = f"category_{i + 1}"

        last_pos = -1
        for col_idx, col_name in enumerate(result_df.columns):
            # Non-string labels cannot belong to a category column.
            if not isinstance(col_name, str):
                continue
            if col_name == cat_exact or col_name.startswith(cat_prefix):
                last_pos = col_idx

        if last_pos >= 0:
            # Insert after the last matching column
            result_df.insert(last_pos + 1, sim_col, sim_values)
        else:
            # No matching column found — append at the end
            result_df[sim_col] = sim_values

    return result_df
|