cat-stack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cat_stack/_utils.py ADDED
@@ -0,0 +1,512 @@
1
+ """
2
+ Shared utilities for CatLLM.
3
+
4
+ This module provides utility functions for JSON handling, file loading,
5
+ encoding, and other common operations used across the package.
6
+ """
7
+
8
+ import json
9
+ import regex
10
+
11
# Public surface of this module for star-imports. Several entries are
# underscore-prefixed helpers; listing them here is what makes
# ``from <this module> import *`` expose them despite the leading underscore.
__all__ = [
    # JSON utilities
    "build_json_schema",
    "extract_json",
    "validate_classification_json",
    "ollama_two_step_classify",
    # Stepback utilities
    "_get_stepback_insight",
    # Image utilities
    "_load_image_files",
    "_encode_image",
    # PDF utilities
    "_anthropic_supports_pdf",
    "_load_pdf_files",
    "_get_pdf_pages",
    "_extract_page_as_pdf_bytes",
    "_extract_page_as_image_bytes",
    "_encode_bytes_to_base64",
    "_extract_page_text",
]
31
+
32
+
33
+ # =============================================================================
34
+ # JSON Schema Functions
35
+ # =============================================================================
36
+
37
def build_json_schema(categories: list, include_additional_properties: bool = True) -> dict:
    """Build the JSON schema describing a classification result.

    Each category becomes a required property keyed by its 1-based position
    ("1", "2", ...), constrained to the string values "0" or "1", with the
    category name as its description.

    Args:
        categories: Category names, in display order.
        include_additional_properties: When True, adds
            ``"additionalProperties": False`` (required by OpenAI strict mode,
            not supported by Google).

    Returns:
        dict: A JSON-schema object.
    """
    properties = {
        str(position): {"type": "string", "enum": ["0", "1"], "description": name}
        for position, name in enumerate(categories, start=1)
    }

    schema = {"type": "object", "properties": properties, "required": [*properties]}

    if include_additional_properties:
        schema["additionalProperties"] = False

    return schema
63
+
64
+
65
def extract_json(reply: str) -> str:
    """Pull the first balanced ``{...}`` object out of a model reply.

    The match uses a recursive pattern, so nested braces are handled. The
    matched text is then flattened by deleting every '[', ']', newline and
    space character (intentionally lossy normalisation for "0"/"1" payloads).

    Returns the sentinel '{"1":"e"}' when reply is None or contains no
    brace-delimited object.
    """
    fallback = '{"1":"e"}'
    if reply is None:
        return fallback

    matches = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
    if not matches:
        return fallback

    candidate = matches[0]
    # Remove each noise character in turn; the order does not matter.
    for noise in ('[', ']', '\n', " "):
        candidate = candidate.replace(noise, '')
    return candidate
76
+
77
+
78
+ def validate_classification_json(json_str: str, num_categories: int) -> tuple[bool, dict | None]:
79
+ """
80
+ Validate that a JSON string contains valid classification output.
81
+
82
+ Args:
83
+ json_str: The JSON string to validate
84
+ num_categories: Expected number of categories
85
+
86
+ Returns:
87
+ tuple: (is_valid, parsed_dict or None)
88
+ """
89
+ try:
90
+ parsed = json.loads(json_str)
91
+
92
+ if not isinstance(parsed, dict):
93
+ return False, None
94
+
95
+ # Check that all expected keys are present and values are "0" or "1"
96
+ for i in range(1, num_categories + 1):
97
+ key = str(i)
98
+ if key not in parsed:
99
+ return False, None
100
+ val = str(parsed[key]).strip()
101
+ if val not in ("0", "1"):
102
+ return False, None
103
+
104
+ # Normalize values to strings
105
+ normalized = {str(i): str(parsed[str(i)]).strip() for i in range(1, num_categories + 1)}
106
+ return True, normalized
107
+
108
+ except (json.JSONDecodeError, KeyError, TypeError):
109
+ return False, None
110
+
111
+
112
+ def ollama_two_step_classify(
113
+ client,
114
+ response_text: str,
115
+ categories: list,
116
+ categories_str: str,
117
+ survey_question: str = "",
118
+ creativity: float = None,
119
+ max_retries: int = 5,
120
+ ) -> tuple[str, str | None]:
121
+ """
122
+ Two-step classification for Ollama models.
123
+
124
+ Step 1: Classify the response (natural language output OK)
125
+ Step 2: Convert classification to strict JSON format
126
+
127
+ This approach is more reliable for local models that struggle with
128
+ simultaneous reasoning and JSON formatting.
129
+
130
+ Args:
131
+ client: UnifiedLLMClient instance
132
+ response_text: The text response to classify
133
+ categories: List of category names
134
+ categories_str: Pre-formatted category string
135
+ survey_question: Optional context
136
+ creativity: Temperature setting
137
+ max_retries: Number of retry attempts for JSON validation
138
+
139
+ Returns:
140
+ tuple: (json_string, error_message or None)
141
+ """
142
+ num_categories = len(categories)
143
+ survey_context = f"Context: {survey_question}." if survey_question else ""
144
+
145
+ # ==========================================================================
146
+ # Step 1: Classification (natural language - focus on accuracy)
147
+ # ==========================================================================
148
+ step1_messages = [
149
+ {
150
+ "role": "system",
151
+ "content": "You are an expert at categorizing text responses. Focus on accurate classification."
152
+ },
153
+ {
154
+ "role": "user",
155
+ "content": f"""{survey_context}
156
+
157
+ Analyze this text response and determine which categories apply:
158
+
159
+ Response: "{response_text}"
160
+
161
+ Categories:
162
+ {categories_str}
163
+
164
+ For each category, explain briefly whether it applies (YES) or not (NO) to this response.
165
+ Format your answer as:
166
+ 1. [Category name]: YES/NO - [brief reason]
167
+ 2. [Category name]: YES/NO - [brief reason]
168
+ ...and so on for all categories."""
169
+ }
170
+ ]
171
+
172
+ step1_reply, step1_error = client.complete(
173
+ messages=step1_messages,
174
+ json_schema=None, # No JSON requirement for step 1
175
+ creativity=creativity,
176
+ )
177
+
178
+ if step1_error:
179
+ return '{"1":"e"}', f"Step 1 failed: {step1_error}"
180
+
181
+ # ==========================================================================
182
+ # Step 2: JSON Formatting with validation and retry
183
+ # ==========================================================================
184
+ example_json = json.dumps({str(i): "0" for i in range(1, num_categories + 1)})
185
+
186
+ for attempt in range(max_retries):
187
+ step2_messages = [
188
+ {
189
+ "role": "system",
190
+ "content": "You convert classification results to JSON. Output ONLY valid JSON, nothing else."
191
+ },
192
+ {
193
+ "role": "user",
194
+ "content": f"""Convert this classification to JSON format.
195
+
196
+ Classification results:
197
+ {step1_reply}
198
+
199
+ Rules:
200
+ - Output ONLY a JSON object, no other text
201
+ - Use category numbers as keys (1, 2, 3, etc.)
202
+ - Use "1" if the category was marked YES, "0" if NO
203
+ - Include ALL {num_categories} categories
204
+
205
+ Example format:
206
+ {example_json}
207
+
208
+ Your JSON output:"""
209
+ }
210
+ ]
211
+
212
+ step2_reply, step2_error = client.complete(
213
+ messages=step2_messages,
214
+ json_schema=None, # Ollama doesn't support strict schema anyway
215
+ creativity=0.1, # Low temperature for formatting task
216
+ )
217
+
218
+ if step2_error:
219
+ if attempt < max_retries - 1:
220
+ continue
221
+ return '{"1":"e"}', f"Step 2 failed: {step2_error}"
222
+
223
+ # Extract and validate JSON
224
+ extracted = extract_json(step2_reply)
225
+ is_valid, normalized = validate_classification_json(extracted, num_categories)
226
+
227
+ if is_valid:
228
+ return json.dumps(normalized), None
229
+
230
+ # If invalid, try again with more explicit instructions
231
+ if attempt < max_retries - 1:
232
+ step1_reply = f"""Previous attempt produced invalid JSON.
233
+
234
+ Original classification:
235
+ {step1_reply}
236
+
237
+ Please be more careful to output EXACTLY {num_categories} categories numbered 1 through {num_categories}."""
238
+
239
+ # All retries exhausted - try to salvage what we can
240
+ extracted = extract_json(step2_reply) if step2_reply else '{"1":"e"}'
241
+ return extracted, f"JSON validation failed after {max_retries} attempts"
242
+
243
+
244
+ # =============================================================================
245
+ # Stepback Insight Utility
246
+ # =============================================================================
247
+
248
def _get_stepback_insight(model_source, stepback, api_key, user_model, creativity):
    """Dispatch a step-back insight request to the provider-specific helper.

    OpenAI-compatible sources (openai, perplexity, huggingface,
    huggingface-together, xai) share the OpenAI helper. Anthropic, Google and
    Mistral have dedicated helpers.

    Returns:
        Whatever the selected helper returns, or ``(None, False)`` when
        model_source has no registered helper.
    """
    # Imported lazily (and before the lookup, as before) to avoid importing
    # provider code at module load time.
    from .calls.stepback import (
        get_stepback_insight_anthropic,
        get_stepback_insight_google,
        get_stepback_insight_mistral,
        get_stepback_insight_openai,
    )

    openai_compatible = ("openai", "perplexity", "huggingface", "huggingface-together", "xai")
    dispatch = dict.fromkeys(openai_compatible, get_stepback_insight_openai)
    dispatch.update(
        anthropic=get_stepback_insight_anthropic,
        google=get_stepback_insight_google,
        mistral=get_stepback_insight_mistral,
    )

    handler = dispatch.get(model_source)
    if handler is None:
        return None, False

    return handler(
        stepback=stepback,
        api_key=api_key,
        user_model=user_model,
        model_source=model_source,
        creativity=creativity,
    )
279
+
280
+
281
+ # =============================================================================
282
+ # Image File Utilities
283
+ # =============================================================================
284
+
285
+ def _load_image_files(image_input):
286
+ """Load image files from directory path, single file path, or return list as-is."""
287
+ import os
288
+ import glob
289
+
290
+ image_extensions = [
291
+ '*.png', '*.jpg', '*.jpeg',
292
+ '*.gif', '*.webp', '*.svg', '*.svgz', '*.avif', '*.apng',
293
+ '*.tif', '*.tiff', '*.bmp',
294
+ '*.heif', '*.heic', '*.ico',
295
+ '*.psd'
296
+ ]
297
+
298
+ if isinstance(image_input, list):
299
+ image_files = image_input
300
+ print(f"Provided a list of {len(image_input)} images.")
301
+ elif os.path.isfile(image_input):
302
+ # Single file path
303
+ image_files = [image_input]
304
+ print(f"Provided 1 image file.")
305
+ elif os.path.isdir(image_input):
306
+ # Directory path - glob for images
307
+ image_files = []
308
+ for ext in image_extensions:
309
+ image_files.extend(glob.glob(os.path.join(image_input, ext)))
310
+ print(f"Found {len(image_files)} images in directory.")
311
+ else:
312
+ raise FileNotFoundError(f"Image input not found: {image_input}")
313
+
314
+ return image_files
315
+
316
+
317
+ def _encode_image(img_path):
318
+ """Encode an image file to base64. Returns (encoded_data, extension, is_valid)."""
319
+ import os
320
+ import base64
321
+ from pathlib import Path
322
+
323
+ if img_path is None or not os.path.exists(img_path):
324
+ return None, None, False
325
+
326
+ if os.path.isdir(img_path):
327
+ return None, None, False
328
+
329
+ try:
330
+ with open(img_path, "rb") as f:
331
+ encoded = base64.b64encode(f.read()).decode("utf-8")
332
+ ext = Path(img_path).suffix.lstrip(".").lower()
333
+ return encoded, ext, True
334
+ except Exception as e:
335
+ print(f"Error encoding image: {e}")
336
+ return None, None, False
337
+
338
+
339
+ # =============================================================================
340
+ # PDF File Utilities
341
+ # =============================================================================
342
+
343
+ def _anthropic_supports_pdf(model_name):
344
+ """Check if the Anthropic model supports native PDF input.
345
+
346
+ PDF support is available for Claude 3.5 Sonnet, Claude 3 Opus, and Claude 3 Sonnet,
347
+ but NOT for Claude 3 Haiku.
348
+ """
349
+ model_lower = model_name.lower()
350
+ # Haiku models don't support PDF
351
+ if "haiku" in model_lower:
352
+ return False
353
+ # Sonnet, Opus support PDF
354
+ if any(x in model_lower for x in ["sonnet", "opus"]):
355
+ return True
356
+ # Default to False for unknown models to be safe
357
+ return False
358
+
359
+
360
+ def _load_pdf_files(pdf_input):
361
+ """Load PDF files from directory path, single file path, or return list as-is."""
362
+ import os
363
+ import glob
364
+
365
+ if isinstance(pdf_input, list):
366
+ pdf_files = pdf_input
367
+ print(f"Provided a list of {len(pdf_input)} PDFs.")
368
+ elif os.path.isfile(pdf_input):
369
+ # Single file path
370
+ pdf_files = [pdf_input]
371
+ print(f"Provided 1 PDF file.")
372
+ elif os.path.isdir(pdf_input):
373
+ # Directory path - glob for PDFs
374
+ pdf_files = glob.glob(os.path.join(pdf_input, '*.pdf'))
375
+ pdf_files.extend(glob.glob(os.path.join(pdf_input, '*.PDF')))
376
+ # Remove duplicates (case-insensitive systems)
377
+ seen = set()
378
+ unique_files = []
379
+ for f in pdf_files:
380
+ if f.lower() not in seen:
381
+ seen.add(f.lower())
382
+ unique_files.append(f)
383
+ pdf_files = unique_files
384
+ print(f"Found {len(pdf_files)} PDFs in directory.")
385
+ else:
386
+ raise FileNotFoundError(f"PDF input not found: {pdf_input}")
387
+
388
+ return pdf_files
389
+
390
+
391
def _get_pdf_pages(pdf_path):
    """
    List the pages of a PDF as lightweight page descriptors.

    Returns list of tuples: [(pdf_path, page_index, page_label), ...]

    For 'document.pdf' with 3 pages:
    [(pdf_path, 0, "document_p1"), (pdf_path, 1, "document_p2"), (pdf_path, 2, "document_p3")]

    Only the page count is read here; the actual page data is extracted later
    based on provider needs. An unreadable PDF is reported (printed) and gives
    ``[]``, as does a PDF with zero pages.

    Raises:
        ImportError: If PyMuPDF (``fitz``) is not installed.
    """
    from pathlib import Path

    try:
        import fitz  # PyMuPDF
    except ImportError as err:
        raise ImportError(
            "PyMuPDF is required for PDF processing. "
            "Install it with: pip install PyMuPDF"
        ) from err

    pdf_name = Path(pdf_path).stem  # filename without extension

    try:
        # The context manager closes the document even if reading the page
        # count fails (the previous code leaked it in that case).
        with fitz.open(pdf_path) as doc:
            page_count = len(doc)
    except Exception as e:
        print(f"Error reading PDF {pdf_path}: {e}")
        return []

    if page_count == 0:
        print(f"Warning: {pdf_path} has no pages")
        return []

    return [(pdf_path, i, f"{pdf_name}_p{i + 1}") for i in range(page_count)]
432
+
433
+
434
def _extract_page_as_pdf_bytes(pdf_path, page_index):
    """
    Extract a single page from a PDF as PDF bytes.
    Used for providers with native PDF support (Anthropic, Google).

    Returns:
        tuple: ``(pdf_bytes, True)`` on success. On any error the error is
        printed and ``(None, False)`` is returned.
    """
    import fitz  # PyMuPDF

    try:
        # Context managers close both documents even when copying or
        # serialising fails; the previous code leaked them on exceptions.
        with fitz.open(pdf_path) as doc:
            # Index the page first, as before, so an invalid page_index fails
            # here rather than during the copy. The page object is unused.
            doc[page_index]

            # Create a new PDF with just this page
            with fitz.open() as new_doc:
                new_doc.insert_pdf(doc, from_page=page_index, to_page=page_index)
                pdf_bytes = new_doc.tobytes()

        return pdf_bytes, True

    except Exception as e:
        print(f"Error extracting page {page_index} from {pdf_path}: {e}")
        return None, False
458
+
459
+
460
def _extract_page_as_image_bytes(pdf_path, page_index, dpi=150):
    """
    Render a single page from a PDF as PNG image bytes.
    Used for providers without native PDF support (OpenAI, Mistral, etc.).

    Args:
        pdf_path: Path to the PDF.
        page_index: Page index to render.
        dpi: Output resolution; the page is scaled by ``dpi / 72``.

    Returns:
        tuple: ``(png_bytes, True)`` on success. On any error the error is
        printed and ``(None, False)`` is returned.
    """
    import fitz  # PyMuPDF

    try:
        # The context manager closes the document even if rendering raises;
        # the previous code leaked it on exceptions.
        with fitz.open(pdf_path) as doc:
            page = doc[page_index]

            # Render page to image
            mat = fitz.Matrix(dpi / 72, dpi / 72)  # 72 is default PDF DPI
            pix = page.get_pixmap(matrix=mat)

            # Get PNG bytes
            image_bytes = pix.tobytes("png")

        return image_bytes, True

    except Exception as e:
        print(f"Error rendering page {page_index} from {pdf_path}: {e}")
        return None, False
484
+
485
+
486
+ def _encode_bytes_to_base64(data_bytes):
487
+ """Encode bytes to base64 string."""
488
+ import base64
489
+ return base64.b64encode(data_bytes).decode("utf-8")
490
+
491
+
492
def _extract_page_text(pdf_path, page_index):
    """
    Extract text content from a single PDF page.
    Used for text-based processing mode.

    Returns:
        tuple: ``(text, ok, error)``.
        - On success: the page text with surrounding whitespace stripped,
          True, and None.
        - When the page has only whitespace:
          ``(None, False, "Page contains no extractable text")``.
        - On any other error: the error is printed and
          ``(None, False, str(error))`` is returned.
    """
    import fitz  # PyMuPDF

    try:
        # The context manager closes the document even if page access or text
        # extraction raises; the previous code leaked it on exceptions.
        with fitz.open(pdf_path) as doc:
            text = doc[page_index].get_text("text")

        stripped = text.strip()
        if not stripped:
            return None, False, "Page contains no extractable text"

        return stripped, True, None

    except Exception as e:
        print(f"Error extracting text from page {page_index} of {pdf_path}: {e}")
        return None, False, str(e)
@@ -0,0 +1,194 @@
1
+ """
2
+ Web content fetching utilities for URL input type.
3
+
4
+ Provides URL detection, HTML text extraction, and batch URL fetching
5
+ for use as a preprocessing step before text classification/extraction/summarization.
6
+ """
7
+
8
+ import html as html_lib
9
+ import re
10
+
11
+ import requests
12
+
13
# Public API of the web-fetching helpers; a star-import of this module
# exposes exactly these names.
__all__ = [
    "is_url",
    "fetch_url_text",
    "fetch_urls",
    "detect_url_input",
    "strip_html_tags",
]
20
+
21
+ # Timeout for individual URL fetches (seconds)
22
+ _DEFAULT_TIMEOUT = 30
23
+
24
+ # Maximum characters to keep from fetched content
25
+ _MAX_CONTENT_CHARS = 50000
26
+
27
+ # User-Agent header for polite web scraping
28
+ _USER_AGENT = (
29
+ "Mozilla/5.0 (compatible; CatStack/1.0; "
30
+ "+https://github.com/chrissoria/cat-stack)"
31
+ )
32
+
33
+
34
def is_url(s) -> bool:
    """
    Check if a value looks like an HTTP(S) URL.

    Surrounding whitespace is ignored. The scheme is matched
    case-insensitively, because URI schemes are case-insensitive (RFC 3986
    section 3.1). So "HTTPS://example.com" is accepted as well as
    "https://example.com".

    Args:
        s: Value to check.

    Returns:
        True if the value is a string starting with http:// or https://
        (in any letter case); False otherwise, including for non-strings.
    """
    if not isinstance(s, str):
        return False
    # Only the first 8 characters ("https://") can matter, so lower-case just
    # that prefix rather than the whole string.
    return s.lstrip()[:8].lower().startswith(("http://", "https://"))
47
+
48
+
49
def detect_url_input(items) -> bool:
    """
    Decide whether the input data appears to be URLs.

    A plain string is checked directly. For any other iterable, the first
    item that is neither None nor pandas-missing (NaN/NaT/NA) is converted to
    ``str`` and checked. Empty or all-missing iterables, and non-iterables,
    give False.

    Args:
        items: A single string, list, pandas Series, or other iterable.

    Returns:
        True if the input appears to be URL data.
    """
    import pandas as pd

    if isinstance(items, str):
        return is_url(items)

    if not hasattr(items, "__iter__"):
        return False

    def _is_null(value):
        if value is None:
            return True
        try:
            return bool(pd.isna(value))
        except (TypeError, ValueError):
            # e.g. array-like values with no single truth value: treat as present
            return False

    no_item = object()
    first = next((value for value in items if not _is_null(value)), no_item)
    if first is no_item:
        return False
    return is_url(str(first))
78
+
79
+
80
def strip_html_tags(html: str) -> str:
    """
    Extract readable text from an HTML string.

    The steps are:
    - remove HTML comments;
    - remove non-content elements (navigation, headers, footers, sidebars,
      forms, scripts, styles, iframes, svg) together with their contents;
    - remove a few void tags and strip all remaining tags;
    - decode HTML entities;
    - collapse whitespace.

    This is a regex-based best-effort cleaner, not an HTML parser. Malformed
    markup, such as an unclosed <script>, is not handled specially.

    Args:
        html: Raw HTML string.

    Returns:
        Plain-text string with runs of whitespace (including decoded
        ``&nbsp;``) collapsed to single spaces and trimmed.
    """
    # Drop comments first. A '>' inside a comment would otherwise end the
    # generic tag match early and leak the rest of the comment as text.
    text = re.sub(r"<!--.*?-->", "", html, flags=re.DOTALL)

    # Remove non-content element blocks entirely
    _JUNK_TAGS = (
        "script", "style", "nav", "header", "footer", "aside",
        "noscript", "iframe", "form", "svg",
    )
    for tag in _JUNK_TAGS:
        # The lookahead requires the tag name to end there. So <nav> is
        # removed, but a custom <nav-menu> element is not mistaken for it
        # (which used to swallow text up to the next </nav>). "\s*" accepts
        # closing tags written like "</script >".
        text = re.sub(
            rf"<{tag}(?=[\s/>])[^>]*>.*?</{tag}\s*>",
            "",
            text,
            flags=re.DOTALL | re.IGNORECASE,
        )

    # Remove void / self-closing non-content tags
    for tag in ("input", "meta", "link", "img"):
        text = re.sub(rf"<{tag}(?=[\s/>])[^>]*>", "", text, flags=re.IGNORECASE)

    # Strip remaining HTML tags; a space keeps words in adjacent elements apart.
    text = re.sub(r"<[^>]+>", " ", text)
    # Decode all HTML entities (&#91; &#39; &amp; &nbsp; etc.) BEFORE
    # collapsing whitespace, so entity-encoded spaces are normalised too.
    text = html_lib.unescape(text)
    # Collapse whitespace
    return re.sub(r"\s+", " ", text).strip()
120
+
121
+
122
def fetch_url_text(url: str, timeout: int = _DEFAULT_TIMEOUT):
    """
    Fetch a single URL and extract its text content.

    How the body is handled depends on its Content-Type (compared
    case-insensitively):
    - HTML, plain text, or a missing Content-Type: passed through
      ``strip_html_tags``;
    - any other type: returned as the decoded response text.

    When the server declares no charset, ``requests`` falls back to
    ISO-8859-1 for ``text/*`` responses. In that case the body is decoded with
    the encoding ``requests`` detects instead, so UTF-8 pages are not
    garbled. Very long results are truncated to ``_MAX_CONTENT_CHARS``.

    Args:
        url: The URL to fetch.
        timeout: Request timeout in seconds.

    Returns:
        tuple: ``(text, error)`` where *text* is the extracted content and
        *error* is ``None`` on success or an error message string (with
        *text* set to ``""``).
    """
    try:
        headers = {"User-Agent": _USER_AGENT}
        try:
            response = requests.get(url.strip(), headers=headers, timeout=timeout)
        except requests.exceptions.SSLError:
            # SECURITY NOTE(review): this retries with certificate verification
            # disabled, so a man-in-the-middle could supply the content. It is
            # kept as a deliberate best-effort fallback; consider making it opt-in.
            response = requests.get(
                url.strip(), headers=headers, timeout=timeout, verify=False
            )
        response.raise_for_status()

        content_type = response.headers.get("Content-Type", "")
        content_type_lower = content_type.lower()

        # Without a declared charset, requests assumes ISO-8859-1 for text/*
        # responses, which garbles UTF-8 pages; prefer the detected encoding.
        if (
            "charset" not in content_type_lower
            and (response.encoding or "").lower() == "iso-8859-1"
        ):
            detected = response.apparent_encoding
            if detected:
                response.encoding = detected

        if (
            "text/html" in content_type_lower
            or "text/plain" in content_type_lower
            or not content_type
        ):
            text = strip_html_tags(response.text)
        else:
            text = response.text

        # Truncate very long content
        if len(text) > _MAX_CONTENT_CHARS:
            text = text[:_MAX_CONTENT_CHARS] + (
                f"\n\n[Content truncated at {_MAX_CONTENT_CHARS} characters]"
            )

        return text, None

    except requests.exceptions.Timeout:
        return "", f"Timeout after {timeout}s fetching {url}"
    except requests.exceptions.HTTPError as e:
        return "", f"HTTP {e.response.status_code} fetching {url}"
    except Exception as e:
        return "", f"Error fetching {url}: {e}"
172
+
173
+
174
def fetch_urls(urls, timeout: int = _DEFAULT_TIMEOUT):
    """
    Fetch the text content of each URL in turn.

    Every item is converted to ``str`` and stripped first. Items that do not
    look like HTTP(S) URLs are not requested; they get an error entry
    instead.

    Args:
        urls: Iterable of URL strings.
        timeout: Per-URL request timeout in seconds.

    Returns:
        list of ``(original_url, fetched_text, error)`` tuples, one per input
        item and in input order. On success *error* is ``None``; on failure
        *fetched_text* is ``""``.
    """
    def _fetch_one(raw):
        target = str(raw).strip()
        if not is_url(target):
            return (target, "", f"Not a valid URL: {target}")
        text, error = fetch_url_text(target, timeout=timeout)
        return (target, text, error)

    return [_fetch_one(entry) for entry in urls]