cat-stack 1.3.0__tar.gz → 1.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. {cat_stack-1.3.0 → cat_stack-1.4.1}/PKG-INFO +1 -1
  2. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/__about__.py +1 -1
  3. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/__init__.py +2 -0
  4. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_wrapper_helpers.py +98 -0
  5. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/classify.py +7 -2
  6. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/summarize.py +2 -2
  7. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/text_functions_ensemble.py +59 -32
  8. {cat_stack-1.3.0 → cat_stack-1.4.1}/.gitignore +0 -0
  9. {cat_stack-1.3.0 → cat_stack-1.4.1}/LICENSE +0 -0
  10. {cat_stack-1.3.0 → cat_stack-1.4.1}/README.md +0 -0
  11. {cat_stack-1.3.0 → cat_stack-1.4.1}/pyproject.toml +0 -0
  12. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/cat_stack/__init__.py +0 -0
  13. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_batch.py +0 -0
  14. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_category_analysis.py +0 -0
  15. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_chunked.py +0 -0
  16. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_embeddings.py +0 -0
  17. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_formatter.py +0 -0
  18. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_pilot_test.py +0 -0
  19. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_prompts.py +0 -0
  20. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_providers.py +0 -0
  21. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_review_ui.py +0 -0
  22. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_tiebreaker.py +0 -0
  23. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_utils.py +0 -0
  24. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/_web_fetch.py +0 -0
  25. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/calls/CoVe.py +0 -0
  26. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/calls/__init__.py +0 -0
  27. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/calls/all_calls.py +0 -0
  28. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/calls/image_CoVe.py +0 -0
  29. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/calls/image_stepback.py +0 -0
  30. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/calls/pdf_CoVe.py +0 -0
  31. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/calls/pdf_stepback.py +0 -0
  32. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/calls/stepback.py +0 -0
  33. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/calls/top_n.py +0 -0
  34. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/explore.py +0 -0
  35. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/extract.py +0 -0
  36. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/image_functions.py +0 -0
  37. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/images/circle.png +0 -0
  38. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/images/cube.png +0 -0
  39. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/images/diamond.png +0 -0
  40. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/images/overlapping_pentagons.png +0 -0
  41. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/images/rectangles.png +0 -0
  42. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/model_reference_list.py +0 -0
  43. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/pdf_functions.py +0 -0
  44. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/prompt_tune.py +0 -0
  45. {cat_stack-1.3.0 → cat_stack-1.4.1}/src/catstack/text_functions.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cat-stack
3
- Version: 1.3.0
3
+ Version: 1.4.1
4
4
  Summary: Domain-agnostic text, image, PDF, and DOCX classification engine powered by LLMs
5
5
  Project-URL: Documentation, https://github.com/chrissoria/cat-stack#readme
6
6
  Project-URL: Issues, https://github.com/chrissoria/cat-stack/issues
@@ -1,7 +1,7 @@
1
1
  # SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
2
2
  #
3
3
  # SPDX-License-Identifier: GPL-3.0-or-later
4
- __version__ = "1.3.0"
4
+ __version__ = "1.4.1"
5
5
  __author__ = "Chris Soria"
6
6
  __email__ = "chrissoria@berkeley.edu"
7
7
  __title__ = "cat-stack"
@@ -92,6 +92,7 @@ from ._wrapper_helpers import (
92
92
  parse_models_string,
93
93
  short_label,
94
94
  classify_labels,
95
+ classify_indicators,
95
96
  )
96
97
 
97
98
  # Define public API
@@ -144,4 +145,5 @@ __all__ = [
144
145
  "parse_models_string",
145
146
  "short_label",
146
147
  "classify_labels",
148
+ "classify_indicators",
147
149
  ]
@@ -328,3 +328,101 @@ def classify_labels(
328
328
  if return_full:
329
329
  return labels_per_row, df
330
330
  return labels_per_row
331
+
332
+
333
def classify_indicators(
    input_data,
    categories,
    *,
    short_labels: bool = True,
    return_full: bool = False,
    **kwargs,
):
    """Convenience wrapper around `classify()` returning per-category indicators.

    Like `classify_labels`, but instead of collapsing the wide DataFrame to
    one assigned label per row, it returns a dict mapping each category to
    a list of 0/1 indicators with one entry per row of the DataFrame.

    This is the right shape for language wrappers that want one indicator
    variable per category (Stata's wide mode, future R `as_indicators=TRUE`
    mode) instead of a single label per row.

    Args:
        input_data: Same as `classify()`.
        categories: Same as `classify()` — list of category strings.
        short_labels: If True (default), use `short_label()` on each
            category to produce dict keys (`"Positive: defn"` → `"Positive"`).
            If False, the dict keys are the full category strings.
        return_full: If True, return `(indicators_dict, df)` so callers also
            have access to the underlying DataFrame. Default False.
        **kwargs: All other kwargs are forwarded to `classify()` (via
            `classify_labels`).

    Returns:
        dict[str, list[int]]: keys are category labels (short or full),
        values are 0/1 lists, one entry per DataFrame row. In ensemble mode
        the indicators come from the `category_N_consensus` columns; in
        single-model mode from `category_N`. A cell counts as 1 only when
        `int(cell) == 1`; anything else, including values that cannot be
        converted to an integer (None, NaN, infinity, non-numeric strings),
        counts as 0.
        Or `(dict, df)` tuple if `return_full=True`.

    Raises:
        RuntimeError: if `classify()` returns a DataFrame that contains
            neither `category_N` nor `category_N_consensus` columns
            (centralized schema canary, same trigger as `classify_labels`).

    Example:
        >>> indicators = classify_indicators(
        ...     ["I moved for the job and to be near family.",
        ...      "Lower cost of living was the only reason."],
        ...     ["Job: career", "Family: relationships", "Cost: affordability"],
        ...     api_key="...", user_model="gpt-4o-mini",
        ... )
        >>> indicators
        {'Job': [1, 0], 'Family': [1, 0], 'Cost': [0, 1]}
    """

    def _as_indicator(value) -> int:
        # OverflowError covers int(float("inf")); without it a single
        # infinite cell would abort the whole call instead of mapping to 0
        # like every other non-convertible value.
        try:
            return 1 if int(value) == 1 else 0
        except (ValueError, TypeError, OverflowError):
            return 0

    def _numbered_columns(pattern, columns):
        # Pair each matching column with the category number captured by
        # the pattern's first group.
        found = []
        for name in columns:
            m = pattern.match(name)
            if m:
                found.append((int(m.group(1)), name))
        return found

    # Reuse classify_labels for the df + centralized schema canary. We
    # pass short_labels=False because we want the raw df; we apply our own
    # short_label() to the dict keys below.
    _labels, df = classify_labels(
        input_data,
        categories,
        short_labels=False,
        return_full=True,
        **kwargs,
    )

    cols = list(df.columns)
    # Consensus columns win when present; otherwise fall back to the
    # single-model columns.
    indexed: List[Tuple[int, str]] = _numbered_columns(_CONSENSUS_COL_PAT, cols)
    if not indexed:
        indexed = _numbered_columns(_SINGLE_COL_PAT, cols)
    # classify_labels already raised RuntimeError if neither family is
    # present, so we know `indexed` is non-empty here.
    indexed.sort(key=lambda t: t[0])

    keys = [short_label(c) if short_labels else c for c in categories]

    out: Dict[str, List[int]] = {}
    for n, col in indexed:
        # Column numbers are 1-based; columns that do not correspond to a
        # supplied category are ignored.
        cat_idx = n - 1
        if not (0 <= cat_idx < len(keys)):
            continue
        # NOTE: if two categories resolve to the same key, the
        # higher-numbered column overwrites the earlier one.
        out[str(keys[cat_idx])] = [_as_indicator(v) for v in df[col]]

    if return_full:
        return out, df
    return out
@@ -84,7 +84,8 @@ def classify(
84
84
  parallel: bool = None,
85
85
  fail_strategy: str = "partial",
86
86
  max_retries: int = 5,
87
- batch_retries: int = 2,
87
+ batch_retries: int = 1,
88
+ json_retries: int = 2,
88
89
  retry_delay: float = 1.0,
89
90
  row_delay: float = 0.0,
90
91
  pdf_dpi: int = 150,
@@ -183,7 +184,9 @@ def classify(
183
184
  (e.g., Ollama on limited hardware) or debugging.
184
185
  fail_strategy (str): How to handle failures - "partial" (default) or "strict".
185
186
  max_retries (int): Max retries per API call. Default 5.
186
- batch_retries (int): Max retries for batch-level failures. Default 2.
187
+ batch_retries (int): Max retries for batch-level failures. Default 1.
188
+ Note: composes multiplicatively with json_retries — a row can hit
189
+ the LLM up to (1 + json_retries) * (1 + batch_retries) times.
187
190
  retry_delay (float): Delay between retries in seconds. Default 1.0.
188
191
  row_delay (float): Delay in seconds between processing each row. Useful
189
192
  when multiple models share the same API provider/key to avoid rate
@@ -407,6 +410,7 @@ def classify(
407
410
  fail_strategy=fail_strategy,
408
411
  max_retries=max_retries,
409
412
  batch_retries=batch_retries,
413
+ json_retries=json_retries,
410
414
  retry_delay=retry_delay,
411
415
  row_delay=row_delay,
412
416
  auto_download=auto_download,
@@ -849,6 +853,7 @@ def classify(
849
853
  fail_strategy=fail_strategy,
850
854
  max_retries=max_retries,
851
855
  batch_retries=batch_retries,
856
+ json_retries=json_retries,
852
857
  retry_delay=retry_delay,
853
858
  row_delay=row_delay,
854
859
  auto_download=auto_download,
@@ -55,7 +55,7 @@ def summarize(
55
55
  # Robustness parameters
56
56
  safety: bool = False,
57
57
  max_retries: int = 5,
58
- batch_retries: int = 2,
58
+ batch_retries: int = 1,
59
59
  retry_delay: float = 1.0,
60
60
  row_delay: float = 0.0,
61
61
  fail_strategy: str = "partial",
@@ -131,7 +131,7 @@ def summarize(
131
131
  auto_download (bool): Auto-download missing Ollama models. Default False.
132
132
  safety (bool): If True, saves progress after each item. Requires filename.
133
133
  max_retries (int): Max retries per API call. Default 5.
134
- batch_retries (int): Max retries for batch-level failures. Default 2.
134
+ batch_retries (int): Max retries for batch-level failures. Default 1.
135
135
  retry_delay (float): Delay between retries in seconds. Default 1.0.
136
136
  row_delay (float): Delay in seconds between processing each row. Default 0.0.
137
137
  fail_strategy (str): How to handle failures - "partial" (default) or "strict".
@@ -2277,7 +2277,8 @@ def classify_ensemble(
2277
2277
  fail_strategy: str = "partial",
2278
2278
  safety: bool = False,
2279
2279
  max_retries: int = 5,
2280
- batch_retries: int = 2,
2280
+ batch_retries: int = 1,
2281
+ json_retries: int = 2,
2281
2282
  retry_delay: float = 1.0,
2282
2283
  row_delay: float = 0.0,
2283
2284
  filename: str = None,
@@ -2368,8 +2369,10 @@ def classify_ensemble(
2368
2369
  max_retries: Maximum retry attempts for each API call (handles rate limits,
2369
2370
  server errors, timeouts). Default 5.
2370
2371
  batch_retries: Maximum retry passes for failed (row, model) pairs after
2371
- the batch completes. Default 2 means up to 3 total attempts. Set to 0
2372
- to disable batch-level retries.
2372
+ the batch completes. Default 1 means up to 2 total attempts. Set to 0
2373
+ to disable batch-level retries. Note: composes multiplicatively with
2374
+ json_retries — a row can hit the LLM up to
2375
+ (1 + json_retries) * (1 + batch_retries) times.
2373
2376
  retry_delay: Seconds to wait between batch retry passes.
2374
2377
 
2375
2378
  # Output parameters:
@@ -3009,35 +3012,59 @@ Categorize text responses {cove_categorize}:
3009
3012
  multi_label=multi_label,
3010
3013
  system_prompt=system_prompt,
3011
3014
  )
3012
- reply, error = client.complete(
3013
- messages=messages,
3014
- json_schema=json_schemas[cfg["model"]],
3015
- creativity=effective_creativity,
3016
- thinking_budget=thinking_budget if cfg["provider"] in ("google", "openai", "anthropic", "huggingface", "huggingface-together") else None,
3017
- max_retries=max_retries,
3018
- )
3019
- if error:
3020
- json_result = '{"1":"e"}'
3021
- else:
3015
+
3016
+ json_result = '{"1":"e"}'
3017
+ error = None
3018
+ _n_cats = len(categories)
3019
+
3020
+ for _json_attempt in range(json_retries + 1):
3021
+ # On retries, nudge the model toward plain JSON output
3022
+ if _json_attempt > 0:
3023
+ _nudge = "\n\nRespond with ONLY valid JSON, no explanation or additional text."
3024
+ _last = messages[-1]
3025
+ _content = _last.get("content", "")
3026
+ _retry_messages = messages[:-1] + [{**_last, "content": _content + _nudge}]
3027
+ else:
3028
+ _retry_messages = messages
3029
+
3030
+ reply, error = client.complete(
3031
+ messages=_retry_messages,
3032
+ json_schema=json_schemas[cfg["model"]],
3033
+ creativity=effective_creativity,
3034
+ thinking_budget=thinking_budget if cfg["provider"] in ("google", "openai", "anthropic", "huggingface", "huggingface-together") else None,
3035
+ max_retries=max_retries,
3036
+ )
3037
+
3038
+ if error:
3039
+ json_result = '{"1":"e"}'
3040
+ break # API-level failure already retried by max_retries
3041
+
3022
3042
  json_result = extract_json(reply)
3023
- json_result = _try_formatter_fallback(json_result, reply)
3043
+ _json_valid, _ = validate_classification_json(json_result, _n_cats)
3024
3044
 
3025
- # Run Chain of Verification if enabled
3026
- if chain_of_verification and not error:
3027
- step2, step3, step4 = build_cove_prompts(
3028
- cove_original_task, response_text
3029
- )
3030
- json_result = run_chain_of_verification(
3031
- client=client,
3032
- initial_reply=json_result,
3033
- step2_prompt=step2,
3034
- step3_prompt=step3,
3035
- step4_prompt=step4,
3036
- json_schema=json_schemas[cfg["model"]],
3037
- creativity=effective_creativity,
3038
- max_retries=max_retries,
3039
- )
3040
- json_result = _try_formatter_fallback(json_result, json_result)
3045
+ if _json_valid:
3046
+ break
3047
+
3048
+ # Final attempt: invoke formatter before giving up
3049
+ if _json_attempt == json_retries:
3050
+ json_result = _try_formatter_fallback(json_result, reply)
3051
+
3052
+ # Run Chain of Verification if enabled
3053
+ if chain_of_verification and not error:
3054
+ step2, step3, step4 = build_cove_prompts(
3055
+ cove_original_task, response_text
3056
+ )
3057
+ json_result = run_chain_of_verification(
3058
+ client=client,
3059
+ initial_reply=json_result,
3060
+ step2_prompt=step2,
3061
+ step3_prompt=step3,
3062
+ step4_prompt=step4,
3063
+ json_schema=json_schemas[cfg["model"]],
3064
+ creativity=effective_creativity,
3065
+ max_retries=max_retries,
3066
+ )
3067
+ json_result = _try_formatter_fallback(json_result, json_result)
3041
3068
 
3042
3069
  return (cfg["sanitized_name"], json_result, error)
3043
3070
 
@@ -3760,7 +3787,7 @@ def summarize_ensemble(
3760
3787
  context_prompt: bool = False,
3761
3788
  step_back_prompt: bool = False,
3762
3789
  max_retries: int = 5,
3763
- batch_retries: int = 2,
3790
+ batch_retries: int = 1,
3764
3791
  retry_delay: float = 1.0,
3765
3792
  row_delay: float = 0.0,
3766
3793
  fail_strategy: str = "partial",
@@ -3806,7 +3833,7 @@ def summarize_ensemble(
3806
3833
  context_prompt: Add expert context prefix
3807
3834
  step_back_prompt: Enable step-back prompting
3808
3835
  max_retries: Max retries per API call
3809
- batch_retries: Number of batch retry passes for failed items
3836
+ batch_retries: Number of batch retry passes for failed items (default 1)
3810
3837
  retry_delay: Delay between retries in seconds
3811
3838
  row_delay: Delay in seconds between processing each row (default 0.0)
3812
3839
  fail_strategy: How to handle failures - "partial" (default) or "strict"
File without changes
File without changes
File without changes
File without changes