cat-stack 1.3.0__tar.gz → 1.4.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cat_stack-1.3.0 → cat_stack-1.4.1}/PKG-INFO +1 -1
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/__about__.py +1 -1
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/__init__.py +2 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_wrapper_helpers.py +98 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/classify.py +7 -2
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/summarize.py +2 -2
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/text_functions_ensemble.py +59 -32
- {cat_stack-1.3.0 → cat_stack-1.4.1}/.gitignore +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/LICENSE +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/README.md +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/pyproject.toml +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/cat_stack/__init__.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_batch.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_category_analysis.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_chunked.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_embeddings.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_formatter.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_pilot_test.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_prompts.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_providers.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_review_ui.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_tiebreaker.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_utils.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_web_fetch.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/calls/CoVe.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/calls/__init__.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/calls/all_calls.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/calls/image_CoVe.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/calls/image_stepback.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/calls/pdf_CoVe.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/calls/pdf_stepback.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/calls/stepback.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/calls/top_n.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/explore.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/extract.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/image_functions.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/images/circle.png +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/images/cube.png +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/images/diamond.png +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/images/overlapping_pentagons.png +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/images/rectangles.png +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/model_reference_list.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/pdf_functions.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/prompt_tune.py +0 -0
- {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/text_functions.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cat-stack
|
|
3
|
-
Version: 1.3.0
|
|
3
|
+
Version: 1.4.1
|
|
4
4
|
Summary: Domain-agnostic text, image, PDF, and DOCX classification engine powered by LLMs
|
|
5
5
|
Project-URL: Documentation, https://github.com/chrissoria/cat-stack#readme
|
|
6
6
|
Project-URL: Issues, https://github.com/chrissoria/cat-stack/issues
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
|
|
2
2
|
#
|
|
3
3
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
|
4
|
-
__version__ = "1.3.0"
|
|
4
|
+
__version__ = "1.4.1"
|
|
5
5
|
__author__ = "Chris Soria"
|
|
6
6
|
__email__ = "chrissoria@berkeley.edu"
|
|
7
7
|
__title__ = "cat-stack"
|
|
@@ -92,6 +92,7 @@ from ._wrapper_helpers import (
|
|
|
92
92
|
parse_models_string,
|
|
93
93
|
short_label,
|
|
94
94
|
classify_labels,
|
|
95
|
+
classify_indicators,
|
|
95
96
|
)
|
|
96
97
|
|
|
97
98
|
# Define public API
|
|
@@ -144,4 +145,5 @@ __all__ = [
|
|
|
144
145
|
"parse_models_string",
|
|
145
146
|
"short_label",
|
|
146
147
|
"classify_labels",
|
|
148
|
+
"classify_indicators",
|
|
147
149
|
]
|
|
@@ -328,3 +328,101 @@ def classify_labels(
|
|
|
328
328
|
if return_full:
|
|
329
329
|
return labels_per_row, df
|
|
330
330
|
return labels_per_row
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def classify_indicators(
|
|
334
|
+
input_data,
|
|
335
|
+
categories,
|
|
336
|
+
*,
|
|
337
|
+
short_labels: bool = True,
|
|
338
|
+
return_full: bool = False,
|
|
339
|
+
**kwargs,
|
|
340
|
+
):
|
|
341
|
+
"""Convenience wrapper around `classify()` returning per-category indicators.
|
|
342
|
+
|
|
343
|
+
Like `classify_labels`, but instead of collapsing the wide DataFrame to
|
|
344
|
+
one assigned label per row, it returns a dict mapping each category to
|
|
345
|
+
a list of 0/1 indicators of length `len(input_data)`.
|
|
346
|
+
|
|
347
|
+
This is the right shape for language wrappers that want one indicator
|
|
348
|
+
variable per category (Stata's wide mode, future R `as_indicators=TRUE`
|
|
349
|
+
mode) instead of a single label per row.
|
|
350
|
+
|
|
351
|
+
Args:
|
|
352
|
+
input_data: Same as `classify()`.
|
|
353
|
+
categories: Same as `classify()` — list of category strings.
|
|
354
|
+
short_labels: If True (default), use `short_label()` on each
|
|
355
|
+
category to produce dict keys (`"Positive: defn"` → `"Positive"`).
|
|
356
|
+
If False, the dict keys are the full category strings.
|
|
357
|
+
return_full: If True, return `(indicators_dict, df)` so callers also
|
|
358
|
+
have access to the underlying DataFrame. Default False.
|
|
359
|
+
**kwargs: All other kwargs are forwarded to `classify()`.
|
|
360
|
+
|
|
361
|
+
Returns:
|
|
362
|
+
dict[str, list[int]]: keys are category labels (short or full),
|
|
363
|
+
values are 0/1 lists of length `len(input_data)`. In ensemble mode
|
|
364
|
+
the indicators come from the `category_N_consensus` columns; in
|
|
365
|
+
single-model mode from `category_N`.
|
|
366
|
+
Or `(dict, df)` tuple if `return_full=True`.
|
|
367
|
+
|
|
368
|
+
Raises:
|
|
369
|
+
RuntimeError: if `classify()` returns a DataFrame that contains
|
|
370
|
+
neither `category_N` nor `category_N_consensus` columns
|
|
371
|
+
(centralized schema canary, same trigger as `classify_labels`).
|
|
372
|
+
|
|
373
|
+
Example:
|
|
374
|
+
>>> indicators = classify_indicators(
|
|
375
|
+
... ["I moved for the job and to be near family.",
|
|
376
|
+
... "Lower cost of living was the only reason."],
|
|
377
|
+
... ["Job: career", "Family: relationships", "Cost: affordability"],
|
|
378
|
+
... api_key="...", user_model="gpt-4o-mini",
|
|
379
|
+
... )
|
|
380
|
+
>>> indicators
|
|
381
|
+
{'Job': [1, 0], 'Family': [1, 0], 'Cost': [0, 1]}
|
|
382
|
+
"""
|
|
383
|
+
# Reuse classify_labels for the df + centralized schema canary. We
|
|
384
|
+
# pass short_labels=False because we want the raw df; we apply our own
|
|
385
|
+
# short_label() to the dict keys below.
|
|
386
|
+
_labels, df = classify_labels(
|
|
387
|
+
input_data,
|
|
388
|
+
categories,
|
|
389
|
+
short_labels=False,
|
|
390
|
+
return_full=True,
|
|
391
|
+
**kwargs,
|
|
392
|
+
)
|
|
393
|
+
|
|
394
|
+
cols = list(df.columns)
|
|
395
|
+
indexed: List[Tuple[int, str]] = []
|
|
396
|
+
for c in cols:
|
|
397
|
+
m = _CONSENSUS_COL_PAT.match(c)
|
|
398
|
+
if m:
|
|
399
|
+
indexed.append((int(m.group(1)), c))
|
|
400
|
+
if not indexed:
|
|
401
|
+
for c in cols:
|
|
402
|
+
m = _SINGLE_COL_PAT.match(c)
|
|
403
|
+
if m:
|
|
404
|
+
indexed.append((int(m.group(1)), c))
|
|
405
|
+
# classify_labels already raised RuntimeError if neither family is
|
|
406
|
+
# present, so we know `indexed` is non-empty here.
|
|
407
|
+
indexed.sort(key=lambda t: t[0])
|
|
408
|
+
|
|
409
|
+
keys = [short_label(c) if short_labels else c for c in categories]
|
|
410
|
+
|
|
411
|
+
out: Dict[str, List[int]] = {}
|
|
412
|
+
for n, col in indexed:
|
|
413
|
+
cat_idx = n - 1
|
|
414
|
+
if not (0 <= cat_idx < len(keys)):
|
|
415
|
+
continue
|
|
416
|
+
key = str(keys[cat_idx])
|
|
417
|
+
series = df[col]
|
|
418
|
+
values: List[int] = []
|
|
419
|
+
for v in series:
|
|
420
|
+
try:
|
|
421
|
+
values.append(1 if int(v) == 1 else 0)
|
|
422
|
+
except (ValueError, TypeError):
|
|
423
|
+
values.append(0)
|
|
424
|
+
out[key] = values
|
|
425
|
+
|
|
426
|
+
if return_full:
|
|
427
|
+
return out, df
|
|
428
|
+
return out
|
|
@@ -84,7 +84,8 @@ def classify(
|
|
|
84
84
|
parallel: bool = None,
|
|
85
85
|
fail_strategy: str = "partial",
|
|
86
86
|
max_retries: int = 5,
|
|
87
|
-
batch_retries: int =
|
|
87
|
+
batch_retries: int = 1,
|
|
88
|
+
json_retries: int = 2,
|
|
88
89
|
retry_delay: float = 1.0,
|
|
89
90
|
row_delay: float = 0.0,
|
|
90
91
|
pdf_dpi: int = 150,
|
|
@@ -183,7 +184,9 @@ def classify(
|
|
|
183
184
|
(e.g., Ollama on limited hardware) or debugging.
|
|
184
185
|
fail_strategy (str): How to handle failures - "partial" (default) or "strict".
|
|
185
186
|
max_retries (int): Max retries per API call. Default 5.
|
|
186
|
-
batch_retries (int): Max retries for batch-level failures. Default
|
|
187
|
+
batch_retries (int): Max retries for batch-level failures. Default 1.
|
|
188
|
+
Note: composes multiplicatively with json_retries — a row can hit
|
|
189
|
+
the LLM up to (1 + json_retries) * (1 + batch_retries) times.
|
|
187
190
|
retry_delay (float): Delay between retries in seconds. Default 1.0.
|
|
188
191
|
row_delay (float): Delay in seconds between processing each row. Useful
|
|
189
192
|
when multiple models share the same API provider/key to avoid rate
|
|
@@ -407,6 +410,7 @@ def classify(
|
|
|
407
410
|
fail_strategy=fail_strategy,
|
|
408
411
|
max_retries=max_retries,
|
|
409
412
|
batch_retries=batch_retries,
|
|
413
|
+
json_retries=json_retries,
|
|
410
414
|
retry_delay=retry_delay,
|
|
411
415
|
row_delay=row_delay,
|
|
412
416
|
auto_download=auto_download,
|
|
@@ -849,6 +853,7 @@ def classify(
|
|
|
849
853
|
fail_strategy=fail_strategy,
|
|
850
854
|
max_retries=max_retries,
|
|
851
855
|
batch_retries=batch_retries,
|
|
856
|
+
json_retries=json_retries,
|
|
852
857
|
retry_delay=retry_delay,
|
|
853
858
|
row_delay=row_delay,
|
|
854
859
|
auto_download=auto_download,
|
|
@@ -55,7 +55,7 @@ def summarize(
|
|
|
55
55
|
# Robustness parameters
|
|
56
56
|
safety: bool = False,
|
|
57
57
|
max_retries: int = 5,
|
|
58
|
-
batch_retries: int =
|
|
58
|
+
batch_retries: int = 1,
|
|
59
59
|
retry_delay: float = 1.0,
|
|
60
60
|
row_delay: float = 0.0,
|
|
61
61
|
fail_strategy: str = "partial",
|
|
@@ -131,7 +131,7 @@ def summarize(
|
|
|
131
131
|
auto_download (bool): Auto-download missing Ollama models. Default False.
|
|
132
132
|
safety (bool): If True, saves progress after each item. Requires filename.
|
|
133
133
|
max_retries (int): Max retries per API call. Default 5.
|
|
134
|
-
batch_retries (int): Max retries for batch-level failures. Default
|
|
134
|
+
batch_retries (int): Max retries for batch-level failures. Default 1.
|
|
135
135
|
retry_delay (float): Delay between retries in seconds. Default 1.0.
|
|
136
136
|
row_delay (float): Delay in seconds between processing each row. Default 0.0.
|
|
137
137
|
fail_strategy (str): How to handle failures - "partial" (default) or "strict".
|
|
@@ -2277,7 +2277,8 @@ def classify_ensemble(
|
|
|
2277
2277
|
fail_strategy: str = "partial",
|
|
2278
2278
|
safety: bool = False,
|
|
2279
2279
|
max_retries: int = 5,
|
|
2280
|
-
batch_retries: int =
|
|
2280
|
+
batch_retries: int = 1,
|
|
2281
|
+
json_retries: int = 2,
|
|
2281
2282
|
retry_delay: float = 1.0,
|
|
2282
2283
|
row_delay: float = 0.0,
|
|
2283
2284
|
filename: str = None,
|
|
@@ -2368,8 +2369,10 @@ def classify_ensemble(
|
|
|
2368
2369
|
max_retries: Maximum retry attempts for each API call (handles rate limits,
|
|
2369
2370
|
server errors, timeouts). Default 5.
|
|
2370
2371
|
batch_retries: Maximum retry passes for failed (row, model) pairs after
|
|
2371
|
-
the batch completes. Default
|
|
2372
|
-
to disable batch-level retries.
|
|
2372
|
+
the batch completes. Default 1 means up to 2 total attempts. Set to 0
|
|
2373
|
+
to disable batch-level retries. Note: composes multiplicatively with
|
|
2374
|
+
json_retries — a row can hit the LLM up to
|
|
2375
|
+
(1 + json_retries) * (1 + batch_retries) times.
|
|
2373
2376
|
retry_delay: Seconds to wait between batch retry passes.
|
|
2374
2377
|
|
|
2375
2378
|
# Output parameters:
|
|
@@ -3009,35 +3012,59 @@ Categorize text responses {cove_categorize}:
|
|
|
3009
3012
|
multi_label=multi_label,
|
|
3010
3013
|
system_prompt=system_prompt,
|
|
3011
3014
|
)
|
|
3012
|
-
|
|
3013
|
-
|
|
3014
|
-
|
|
3015
|
-
|
|
3016
|
-
|
|
3017
|
-
|
|
3018
|
-
|
|
3019
|
-
|
|
3020
|
-
|
|
3021
|
-
|
|
3015
|
+
|
|
3016
|
+
json_result = '{"1":"e"}'
|
|
3017
|
+
error = None
|
|
3018
|
+
_n_cats = len(categories)
|
|
3019
|
+
|
|
3020
|
+
for _json_attempt in range(json_retries + 1):
|
|
3021
|
+
# On retries, nudge the model toward plain JSON output
|
|
3022
|
+
if _json_attempt > 0:
|
|
3023
|
+
_nudge = "\n\nRespond with ONLY valid JSON, no explanation or additional text."
|
|
3024
|
+
_last = messages[-1]
|
|
3025
|
+
_content = _last.get("content", "")
|
|
3026
|
+
_retry_messages = messages[:-1] + [{**_last, "content": _content + _nudge}]
|
|
3027
|
+
else:
|
|
3028
|
+
_retry_messages = messages
|
|
3029
|
+
|
|
3030
|
+
reply, error = client.complete(
|
|
3031
|
+
messages=_retry_messages,
|
|
3032
|
+
json_schema=json_schemas[cfg["model"]],
|
|
3033
|
+
creativity=effective_creativity,
|
|
3034
|
+
thinking_budget=thinking_budget if cfg["provider"] in ("google", "openai", "anthropic", "huggingface", "huggingface-together") else None,
|
|
3035
|
+
max_retries=max_retries,
|
|
3036
|
+
)
|
|
3037
|
+
|
|
3038
|
+
if error:
|
|
3039
|
+
json_result = '{"1":"e"}'
|
|
3040
|
+
break # API-level failure already retried by max_retries
|
|
3041
|
+
|
|
3022
3042
|
json_result = extract_json(reply)
|
|
3023
|
-
|
|
3043
|
+
_json_valid, _ = validate_classification_json(json_result, _n_cats)
|
|
3024
3044
|
|
|
3025
|
-
|
|
3026
|
-
|
|
3027
|
-
|
|
3028
|
-
|
|
3029
|
-
|
|
3030
|
-
json_result =
|
|
3031
|
-
|
|
3032
|
-
|
|
3033
|
-
|
|
3034
|
-
|
|
3035
|
-
|
|
3036
|
-
|
|
3037
|
-
|
|
3038
|
-
|
|
3039
|
-
|
|
3040
|
-
|
|
3045
|
+
if _json_valid:
|
|
3046
|
+
break
|
|
3047
|
+
|
|
3048
|
+
# Final attempt: invoke formatter before giving up
|
|
3049
|
+
if _json_attempt == json_retries:
|
|
3050
|
+
json_result = _try_formatter_fallback(json_result, reply)
|
|
3051
|
+
|
|
3052
|
+
# Run Chain of Verification if enabled
|
|
3053
|
+
if chain_of_verification and not error:
|
|
3054
|
+
step2, step3, step4 = build_cove_prompts(
|
|
3055
|
+
cove_original_task, response_text
|
|
3056
|
+
)
|
|
3057
|
+
json_result = run_chain_of_verification(
|
|
3058
|
+
client=client,
|
|
3059
|
+
initial_reply=json_result,
|
|
3060
|
+
step2_prompt=step2,
|
|
3061
|
+
step3_prompt=step3,
|
|
3062
|
+
step4_prompt=step4,
|
|
3063
|
+
json_schema=json_schemas[cfg["model"]],
|
|
3064
|
+
creativity=effective_creativity,
|
|
3065
|
+
max_retries=max_retries,
|
|
3066
|
+
)
|
|
3067
|
+
json_result = _try_formatter_fallback(json_result, json_result)
|
|
3041
3068
|
|
|
3042
3069
|
return (cfg["sanitized_name"], json_result, error)
|
|
3043
3070
|
|
|
@@ -3760,7 +3787,7 @@ def summarize_ensemble(
|
|
|
3760
3787
|
context_prompt: bool = False,
|
|
3761
3788
|
step_back_prompt: bool = False,
|
|
3762
3789
|
max_retries: int = 5,
|
|
3763
|
-
batch_retries: int =
|
|
3790
|
+
batch_retries: int = 1,
|
|
3764
3791
|
retry_delay: float = 1.0,
|
|
3765
3792
|
row_delay: float = 0.0,
|
|
3766
3793
|
fail_strategy: str = "partial",
|
|
@@ -3806,7 +3833,7 @@ def summarize_ensemble(
|
|
|
3806
3833
|
context_prompt: Add expert context prefix
|
|
3807
3834
|
step_back_prompt: Enable step-back prompting
|
|
3808
3835
|
max_retries: Max retries per API call
|
|
3809
|
-
batch_retries: Number of batch retry passes for failed items
|
|
3836
|
+
batch_retries: Number of batch retry passes for failed items (default 1)
|
|
3810
3837
|
retry_delay: Delay between retries in seconds
|
|
3811
3838
|
row_delay: Delay in seconds between processing each row (default 0.0)
|
|
3812
3839
|
fail_strategy: How to handle failures - "partial" (default) or "strict"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|