cat-stack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1358 @@
1
+ """
2
+ Text classification functions for CatLLM.
3
+
4
+ This module provides multi-class text classification using a unified HTTP-based approach
5
+ that works with multiple LLM providers (OpenAI, Anthropic, Google, Mistral, xAI,
6
+ Perplexity, HuggingFace, and Ollama) without requiring provider-specific SDKs.
7
+ """
8
+
9
+ import json
10
+ import time
11
+ import warnings
12
+
13
+ # Exported names (excludes deprecated multi_class)
14
+ __all__ = [
15
+ "UnifiedLLMClient",
16
+ "detect_provider",
17
+ "set_ollama_endpoint",
18
+ "check_ollama_running",
19
+ "list_ollama_models",
20
+ "check_ollama_model",
21
+ "check_system_resources",
22
+ "get_ollama_model_size_estimate",
23
+ "pull_ollama_model",
24
+ "check_claude_cli_available",
25
+ "build_json_schema",
26
+ "extract_json",
27
+ "validate_classification_json",
28
+ "ollama_two_step_classify",
29
+ "explore_corpus",
30
+ "explore_common_categories",
31
+ # Internal utilities used by other modules
32
+ "_detect_model_source",
33
+ "_get_stepback_insight",
34
+ "_detect_huggingface_endpoint",
35
+ ]
36
+ import pandas as pd
37
+ import regex
38
+ from tqdm import tqdm
39
+
40
+ from .calls.stepback import (
41
+ get_stepback_insight_openai,
42
+ get_stepback_insight_anthropic,
43
+ get_stepback_insight_google,
44
+ get_stepback_insight_mistral
45
+ )
46
+ from .calls.CoVe import (
47
+ chain_of_verification_openai,
48
+ chain_of_verification_google,
49
+ chain_of_verification_anthropic,
50
+ chain_of_verification_mistral
51
+ )
52
+ from .calls.top_n import (
53
+ get_openai_top_n,
54
+ get_anthropic_top_n,
55
+ get_google_top_n,
56
+ get_mistral_top_n
57
+ )
58
+
59
+ from ._providers import (
60
+ UnifiedLLMClient,
61
+ PROVIDER_CONFIG,
62
+ detect_provider,
63
+ _detect_model_source,
64
+ _detect_huggingface_endpoint,
65
+ set_ollama_endpoint,
66
+ check_ollama_running,
67
+ list_ollama_models,
68
+ check_ollama_model,
69
+ check_system_resources,
70
+ get_ollama_model_size_estimate,
71
+ pull_ollama_model,
72
+ check_claude_cli_available,
73
+ OLLAMA_MODEL_SIZES,
74
+ )
75
+
76
+
77
+ # =============================================================================
78
+ # Helper Functions
79
+ # =============================================================================
80
+
81
def _get_stepback_insight(model_source, stepback, api_key, user_model, creativity):
    """Route a step-back insight request to the helper registered for a provider.

    The perplexity, huggingface, huggingface-together and xai sources share
    the OpenAI helper; anthropic, google and mistral each have their own.

    Returns:
        Whatever the selected helper returns, or ``(None, False)`` when no
        helper is registered for ``model_source``.
    """
    # Sources that are served by the OpenAI helper.
    dispatch = dict.fromkeys(
        ("openai", "perplexity", "huggingface", "huggingface-together", "xai"),
        get_stepback_insight_openai,
    )
    dispatch.update(
        anthropic=get_stepback_insight_anthropic,
        google=get_stepback_insight_google,
        mistral=get_stepback_insight_mistral,
    )

    try:
        handler = dispatch[model_source]
    except KeyError:
        # Unknown provider: signal "no insight" instead of raising.
        return None, False

    return handler(
        stepback=stepback,
        api_key=api_key,
        user_model=user_model,
        model_source=model_source,
        creativity=creativity,
    )
105
+
106
+
107
+
108
+ # =============================================================================
109
+ # JSON Schema Functions
110
+ # =============================================================================
111
+
112
def build_json_schema(categories: list, include_additional_properties: bool = True) -> dict:
    """Create the JSON schema describing one classification result.

    Every category becomes a required string property keyed by its 1-based
    position ("1", "2", ...), restricted to the values "0" or "1" and
    carrying the category name as its description.

    Args:
        categories: List of category names
        include_additional_properties: If True, includes additionalProperties: false
            (required by OpenAI strict mode, not supported by Google)

    Returns:
        The schema as a plain dict.
    """
    properties = {
        str(position): {
            "type": "string",
            "enum": ["0", "1"],
            "description": label,
        }
        for position, label in enumerate(categories, 1)
    }

    schema = {
        "type": "object",
        "properties": properties,
        "required": list(properties),
    }
    if include_additional_properties:
        schema["additionalProperties"] = False
    return schema
138
+
139
+
140
def extract_json(reply: str) -> str:
    """Pull the first brace-balanced JSON object out of a model reply.

    Square brackets are removed from the candidate text before parsing. When
    the candidate parses, it is re-serialized compactly, which normalizes
    structural whitespace while keeping spaces inside string values. When it
    does not parse, the candidate is returned with newlines removed.

    Returns:
        The JSON text, or the error marker '{"1":"e"}' when ``reply`` is None
        or contains no brace-balanced object.
    """
    error_marker = '{"1":"e"}'
    if reply is None:
        return error_marker

    # (?R) recurses into the whole pattern so nested objects stay balanced.
    candidates = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
    if not candidates:
        return error_marker

    candidate = candidates[0].translate({ord('['): None, ord(']'): None})
    try:
        return json.dumps(json.loads(candidate), separators=(',', ':'))
    except json.JSONDecodeError:
        return candidate.replace('\n', '')
157
+
158
+
159
+ def validate_classification_json(json_str: str, num_categories: int) -> tuple[bool, dict | None]:
160
+ """
161
+ Validate that a JSON string contains valid classification output.
162
+
163
+ Args:
164
+ json_str: The JSON string to validate
165
+ num_categories: Expected number of categories
166
+
167
+ Returns:
168
+ tuple: (is_valid, parsed_dict or None)
169
+ """
170
+ try:
171
+ parsed = json.loads(json_str)
172
+
173
+ if not isinstance(parsed, dict):
174
+ return False, None
175
+
176
+ # Check that all expected keys are present and values are "0" or "1"
177
+ for i in range(1, num_categories + 1):
178
+ key = str(i)
179
+ if key not in parsed:
180
+ return False, None
181
+ val = str(parsed[key]).strip()
182
+ if val not in ("0", "1"):
183
+ return False, None
184
+
185
+ # Normalize values to strings
186
+ normalized = {str(i): str(parsed[str(i)]).strip() for i in range(1, num_categories + 1)}
187
+ return True, normalized
188
+
189
+ except (json.JSONDecodeError, KeyError, TypeError):
190
+ return False, None
191
+
192
+
193
+ def ollama_two_step_classify(
194
+ client,
195
+ response_text: str,
196
+ categories: list,
197
+ categories_str: str,
198
+ survey_question: str = "",
199
+ creativity: float = None,
200
+ max_retries: int = 5,
201
+ ) -> tuple[str, str | None]:
202
+ """
203
+ Two-step classification for Ollama models.
204
+
205
+ Step 1: Classify the response (natural language output OK)
206
+ Step 2: Convert classification to strict JSON format
207
+
208
+ This approach is more reliable for local models that struggle with
209
+ simultaneous reasoning and JSON formatting.
210
+
211
+ Args:
212
+ client: UnifiedLLMClient instance
213
+ response_text: The text response to classify
214
+ categories: List of category names
215
+ categories_str: Pre-formatted category string
216
+ survey_question: Optional context
217
+ creativity: Temperature setting
218
+ max_retries: Number of retry attempts for JSON validation
219
+
220
+ Returns:
221
+ tuple: (json_string, error_message or None)
222
+ """
223
+ num_categories = len(categories)
224
+ survey_context = f"Context: {survey_question}." if survey_question else ""
225
+
226
+ # ==========================================================================
227
+ # Step 1: Classification (natural language - focus on accuracy)
228
+ # ==========================================================================
229
+ step1_messages = [
230
+ {
231
+ "role": "system",
232
+ "content": "You are an expert at categorizing text responses. Focus on accurate classification."
233
+ },
234
+ {
235
+ "role": "user",
236
+ "content": f"""{survey_context}
237
+
238
+ Analyze this text response and determine which categories apply:
239
+
240
+ Response: "{response_text}"
241
+
242
+ Categories:
243
+ {categories_str}
244
+
245
+ For each category, explain briefly whether it applies (YES) or not (NO) to this response.
246
+ Format your answer as:
247
+ 1. [Category name]: YES/NO - [brief reason]
248
+ 2. [Category name]: YES/NO - [brief reason]
249
+ ...and so on for all categories."""
250
+ }
251
+ ]
252
+
253
+ step1_reply, step1_error = client.complete(
254
+ messages=step1_messages,
255
+ json_schema=None, # No JSON requirement for step 1
256
+ creativity=creativity,
257
+ )
258
+
259
+ if step1_error:
260
+ return '{"1":"e"}', f"Step 1 failed: {step1_error}"
261
+
262
+ # ==========================================================================
263
+ # Step 2: JSON Formatting with validation and retry
264
+ # ==========================================================================
265
+ example_json = json.dumps({str(i): "0" for i in range(1, num_categories + 1)})
266
+
267
+ for attempt in range(max_retries):
268
+ step2_messages = [
269
+ {
270
+ "role": "system",
271
+ "content": "You convert classification results to JSON. Output ONLY valid JSON, nothing else."
272
+ },
273
+ {
274
+ "role": "user",
275
+ "content": f"""Convert this classification to JSON format.
276
+
277
+ Classification results:
278
+ {step1_reply}
279
+
280
+ Rules:
281
+ - Output ONLY a JSON object, no other text
282
+ - Use category numbers as keys (1, 2, 3, etc.)
283
+ - Use "1" if the category was marked YES, "0" if NO
284
+ - Include ALL {num_categories} categories
285
+
286
+ Example format:
287
+ {example_json}
288
+
289
+ Your JSON output:"""
290
+ }
291
+ ]
292
+
293
+ step2_reply, step2_error = client.complete(
294
+ messages=step2_messages,
295
+ json_schema=None, # Ollama doesn't support strict schema anyway
296
+ creativity=0.1, # Low temperature for formatting task
297
+ )
298
+
299
+ if step2_error:
300
+ if attempt < max_retries - 1:
301
+ continue
302
+ return '{"1":"e"}', f"Step 2 failed: {step2_error}"
303
+
304
+ # Extract and validate JSON
305
+ extracted = extract_json(step2_reply)
306
+ is_valid, normalized = validate_classification_json(extracted, num_categories)
307
+
308
+ if is_valid:
309
+ return json.dumps(normalized), None
310
+
311
+ # If invalid, try again with more explicit instructions
312
+ if attempt < max_retries - 1:
313
+ step1_reply = f"""Previous attempt produced invalid JSON.
314
+
315
+ Original classification:
316
+ {step1_reply}
317
+
318
+ Please be more careful to output EXACTLY {num_categories} categories numbered 1 through {num_categories}."""
319
+
320
+ # All retries exhausted - try to salvage what we can
321
+ extracted = extract_json(step2_reply) if step2_reply else '{"1":"e"}'
322
+ return extracted, f"JSON validation failed after {max_retries} attempts"
323
+
324
+
325
+ # =============================================================================
326
+ # Category Exploration Functions
327
+ # =============================================================================
328
+
329
def explore_corpus(
    survey_question,
    input_data,
    api_key: str = None,
    research_question=None,
    specificity="broad",
    categories_per_chunk=10,
    divisions=5,
    model: str = "gpt-4o",
    provider: str = "auto",
    creativity=None,
    filename="corpus_exploration.csv",
    focus: str = None,
):
    """
    Extract categories from text corpus using LLM.

    Uses raw HTTP requests via UnifiedLLMClient - supports all providers.
    The responses are sampled into random chunks, the model is asked for a
    numbered list of categories per chunk, and the lower-cased labels are
    counted across chunks.

    Args:
        survey_question: The survey question being analyzed
        input_data: Series or list of text responses
        api_key: API key for the LLM provider
        research_question: Optional research context
        specificity: "broad" or "specific" categories
        categories_per_chunk: Number of categories to extract per chunk
        divisions: Number of chunks to process (must be >= 1; reduced
            automatically for small datasets)
        model: Model name (e.g., "gpt-4o", "claude-3-haiku-20240307", "gemini-2.5-flash")
        provider: Provider name or "auto" to detect from model name
        creativity: Temperature setting
        filename: Output CSV filename (None to skip saving)
        focus: Optional focus instruction for category extraction (e.g., "decisions to move",
            "emotional responses", "financial considerations"). When provided, the model
            will prioritize extracting categories related to this focus.

    Returns:
        DataFrame with one row per distinct category ('Category') and how
        often it was extracted ('counts'), most frequent first

    Raises:
        ValueError: If divisions is below 1, the api_key is missing for a
            provider other than "ollama"/"claude-code", the input is empty
            after dropping NA, the model's context limit is exceeded, or no
            categories could be parsed from the replies.
    """
    # Reject a chunk count that would divide by zero or yield no chunks.
    if divisions < 1:
        raise ValueError(f"divisions must be at least 1, got {divisions}.")

    # Detect provider
    provider = detect_provider(model, provider)

    # Validate api_key
    if provider not in ("ollama", "claude-code") and not api_key:
        raise ValueError(f"api_key is required for provider '{provider}'")

    print(f"Exploring categories for question: '{survey_question}'")
    print(f"Using provider: {provider}, model: {model}")
    if focus:
        print(f"Focus: {focus}")

    # Input normalization
    if not isinstance(input_data, pd.Series):
        input_data = pd.Series(input_data)
    input_data = input_data.dropna()

    n = len(input_data)
    if n == 0:
        raise ValueError("input_data is empty after dropping NA.")

    # Auto-adjust divisions for small datasets
    requested_divisions = divisions
    divisions = min(divisions, max(1, n // 3))
    if divisions != requested_divisions:
        print(f"Auto-adjusted divisions from {requested_divisions} to {divisions} for {n} responses.")

    chunk_size = int(round(max(1, n / divisions), 0))

    if chunk_size < (categories_per_chunk / 2):
        requested_per_chunk = categories_per_chunk
        categories_per_chunk = max(3, chunk_size * 2)
        print(f"Auto-adjusted categories_per_chunk from {requested_per_chunk} to {categories_per_chunk} for chunk size {chunk_size}.")

    # Announce the workload only after auto-adjustment so the number is accurate.
    print(f" {categories_per_chunk * divisions} unique categories to be extracted.")
    print()

    # Initialize unified client
    client = UnifiedLLMClient(provider=provider, api_key=api_key, model=model)

    # Build system message
    if research_question:
        system_content = (
            f"You are a helpful assistant that extracts categories from text responses. "
            f"The specific task is to identify {specificity} categories of responses to a text prompt. "
            f"The research question is: {research_question}"
        )
    else:
        system_content = "You are a helpful assistant that extracts categories from text responses."

    # Sample chunks
    random_chunks = [input_data.sample(n=chunk_size).tolist() for _ in range(divisions)]

    focus_text = f" Focus specifically on {focus}." if focus else ""
    extracted_labels = []

    for chunk_number, chunk in enumerate(tqdm(random_chunks, desc="Processing chunks"), start=1):
        survey_participant_chunks = '; '.join(str(x) for x in chunk)
        prompt = (
            f'Identify {categories_per_chunk} {specificity} categories of responses to the question "{survey_question}" '
            f"in the following list of responses.{focus_text} Responses are each separated by a semicolon. "
            f"Responses are contained within triple backticks here: ```{survey_participant_chunks}``` "
            f"Number your categories from 1 through {categories_per_chunk} and be concise with the category labels and provide no description of the categories."
        )

        messages = [
            {"role": "system", "content": system_content},
            {"role": "user", "content": prompt}
        ]

        reply, error = client.complete(
            messages=messages,
            creativity=creativity,
            force_json=False,  # Text response, not JSON
        )

        if error:
            error_text = str(error)
            if "context_length_exceeded" in error_text or "maximum context length" in error_text:
                raise ValueError(
                    f"Token limit exceeded for model {model}. "
                    f"Try increasing the 'divisions' parameter to create smaller chunks."
                )
            # Any other API error only costs this chunk; keep going.
            print(f"API error on chunk {chunk_number}: {error}")
            reply = ""

        # Keep the text after the first ". " on each line (the label of "N. label").
        for line in (reply or "").split('\n'):
            _, separator, label = line.partition('. ')
            if separator:
                extracted_labels.append(label)

    flat_list = [label.lower() for label in extracted_labels]

    if not flat_list:
        raise ValueError("No categories were extracted from the model responses.")

    df = pd.DataFrame(flat_list, columns=['Category'])
    counts = pd.Series(flat_list).value_counts()
    df['counts'] = df['Category'].map(counts)
    df = df.sort_values(by='counts', ascending=False).reset_index(drop=True)
    df = df.drop_duplicates(subset='Category', keep='first').reset_index(drop=True)

    if filename is not None:
        df.to_csv(filename, index=False)
        print(f"Results saved to {filename}")

    return df
485
+
486
+
487
def _preflight_ollama_model(model: str, auto_download: bool) -> None:
    """Check that Ollama is running and that `model` is installed and runnable.

    Raises:
        ConnectionError: If check_ollama_running() reports Ollama is not running.
        RuntimeError: If the model is missing and pull_ollama_model() fails,
            if the user declines to continue after a resource warning, or if
            the confirmation prompt is interrupted.
    """
    rule = "=" * 60

    if not check_ollama_running():
        # Built from adjacent literals only: mixing implicit concatenation
        # with `"=" * 60` would multiply the neighbouring text as well.
        raise ConnectionError(
            f"\n{rule}\n"
            " OLLAMA NOT RUNNING\n"
            f"{rule}\n\n"
            "Ollama must be running to use local models.\n\n"
            "To start Ollama:\n"
            " macOS: Open the Ollama app, or run 'ollama serve'\n"
            " Linux: Run 'ollama serve' in terminal\n"
            " Windows: Open the Ollama app\n\n"
            "Don't have Ollama installed?\n"
            " Download from: https://ollama.ai/download\n\n"
            "After starting Ollama, run your code again.\n"
            f"{rule}"
        )

    # Check system resources before proceeding
    resources = check_system_resources(model)

    # Check if model needs to be downloaded
    if not check_ollama_model(model):
        if not pull_ollama_model(model, auto_confirm=auto_download):
            raise RuntimeError(
                f"Model '{model}' not available. "
                f"To download manually: ollama pull {model}"
            )
        return

    # Model is installed - still check if it can run
    if not resources["warnings"] and resources["can_run"]:
        return

    print(f"\n{rule}")
    print(f" Model '{model}' - System Resource Check")
    print(rule)
    size_estimate = get_ollama_model_size_estimate(model)
    print(f" Model size: {size_estimate}")
    details = resources["details"]
    if details.get("estimated_ram"):
        print(f" RAM required: ~{details['estimated_ram']}")
    if details.get("total_ram"):
        print(f" System RAM: {details['total_ram']}")

    if resources["warnings"]:
        bang_rule = "!" * 50
        print(f"\n {bang_rule}")
        for warning in resources["warnings"]:
            print(f" Warning: {warning}")
        print(f" {bang_rule}")

    if not resources["can_run"]:
        print("\n Warning: Model may not run well on this system.")
        print(f" Consider a smaller variant (e.g., '{model}:1b' or '{model}:3b').")
    print(rule)

    # auto_download doubles as "do not prompt": only ask in interactive mode.
    if not auto_download:
        try:
            response = input("\n Continue anyway? [y/N]: ").strip().lower()
            if response not in ['y', 'yes']:
                raise RuntimeError(
                    f"Model '{model}' may be too large for this system. "
                    f"Try a smaller variant like '{model}:3b' or '{model}:1b'."
                )
        except (EOFError, KeyboardInterrupt):
            raise RuntimeError("Operation cancelled by user.")

    print()


def explore_common_categories(
    input_data,
    api_key: str = None,
    survey_question: str = "",
    max_categories: int = 12,
    categories_per_chunk: int = 10,
    divisions: int = 5,
    model: str = "gpt-4o",
    provider: str = "auto",
    creativity: float = None,
    specificity: str = "broad",
    research_question: str = None,
    filename: str = None,
    iterations: int = 5,
    random_state: int = None,
    focus: str = None,
    progress_callback: callable = None,
    return_raw: bool = False,
    chunk_delay: float = 0.0,
    auto_download: bool = False,
    # Legacy parameter names for backward compatibility
    user_model: str = None,
    model_source: str = None,
):
    """
    Extract and rank common categories from survey corpus.

    Uses raw HTTP requests via UnifiedLLMClient - supports all providers.
    Each iteration samples `divisions` random chunks and asks the model for a
    numbered list of categories per chunk. The collected labels are counted,
    and a final model call merges semantically duplicate labels.

    Args:
        input_data: Series or list of text responses
        api_key: API key for the LLM provider
        survey_question: The survey question being analyzed
        max_categories: Maximum number of top categories to return
        categories_per_chunk: Number of categories to extract per chunk
        divisions: Number of chunks to process per iteration (must be >= 1;
            reduced automatically for small datasets)
        model: Model name (e.g., "gpt-4o", "claude-3-haiku-20240307", "gemini-2.5-flash")
        provider: Provider name or "auto" to detect from model name
        creativity: Temperature setting
        specificity: "broad" or "specific" categories
        research_question: Optional research context
        filename: Output CSV filename (None to skip saving)
        iterations: Number of passes over the data
        random_state: Random seed for reproducibility
        focus: Optional focus instruction for category extraction (e.g., "decisions to move",
            "emotional responses", "financial considerations"). When provided, the model
            will prioritize extracting categories related to this focus.
        progress_callback: Optional callback function for progress updates.
            Called as progress_callback(current_step, total_steps, step_label).
        return_raw: If True, return the list of extracted labels right after
            the chunk passes, skipping counting, merging and saving.
        chunk_delay: Seconds to sleep after each chunk to avoid rate limits
            (0 disables the pause).
        auto_download: If True, automatically download missing Ollama models
            without prompting. Default False (interactive prompt).
        user_model: Legacy name for `model`; overrides it when given.
        model_source: Legacy name for `provider`; overrides it when given.

    Returns:
        dict with 'counts_df', 'top_categories', and 'raw_top_text', or the
        raw list of extracted labels when return_raw is True

    Raises:
        ValueError: If divisions is below 1, the api_key is missing for a
            provider other than "ollama"/"claude-code", the input is empty
            after dropping NA, or no categories were extracted.
        ConnectionError: If provider is "ollama" and Ollama is not running.
        RuntimeError: If an Ollama pre-flight check fails or a chunk-level
            model call returns an error.
    """
    import re
    import numpy as np

    # Reject a chunk count that would divide by zero or yield no chunks.
    if divisions < 1:
        raise ValueError(f"divisions must be at least 1, got {divisions}.")

    # Handle legacy parameter names
    if user_model is not None:
        model = user_model
    if model_source is not None:
        provider = model_source

    # Detect provider
    provider = detect_provider(model, provider)

    # Validate api_key
    if provider not in ("ollama", "claude-code") and not api_key:
        raise ValueError(f"api_key is required for provider '{provider}'")

    # Ollama-specific checks
    if provider == "ollama":
        _preflight_ollama_model(model, auto_download)

    # Input normalization
    if not isinstance(input_data, pd.Series):
        input_data = pd.Series(input_data)
    input_data = input_data.dropna().astype("string")
    n = len(input_data)
    if n == 0:
        raise ValueError("input_data is empty after dropping NA.")

    # Auto-adjust divisions for small datasets
    requested_divisions = divisions
    divisions = min(divisions, max(1, n // 3))
    if divisions != requested_divisions:
        print(f"Auto-adjusted divisions from {requested_divisions} to {divisions} for {n} responses.")

    # Chunk sizing
    chunk_size = int(round(max(1, n / divisions), 0))
    if chunk_size < (categories_per_chunk / 2):
        requested_per_chunk = categories_per_chunk
        categories_per_chunk = max(3, chunk_size * 2)
        print(f"Auto-adjusted categories_per_chunk from {requested_per_chunk} to {categories_per_chunk} for chunk size {chunk_size}.")

    print(f"Exploring categories for question: '{survey_question}'")
    print(f"Using provider: {provider}, model: {model}")
    if focus:
        print(f"Focus: {focus}")
    print(f" {categories_per_chunk * divisions * iterations} total category extractions across {iterations} iterations.")
    print(f" Top {max_categories} categories will be identified.\n")

    # RNG for reproducible re-sampling across passes
    rng = np.random.default_rng(random_state)

    # Initialize unified client
    client = UnifiedLLMClient(provider=provider, api_key=api_key, model=model)

    # Build system message
    if research_question:
        system_content = (
            f"You are a helpful assistant that extracts categories from text responses. "
            f"The specific task is to identify {specificity} categories of responses to a text prompt. "
            f"The research question is: {research_question}"
        )
    else:
        system_content = "You are a helpful assistant that extracts categories from text responses."

    focus_text = f" Focus specifically on {focus}." if focus else ""

    def make_prompt(responses_blob: str) -> str:
        return (
            f'Identify {categories_per_chunk} {specificity} categories of responses to the question "{survey_question}" '
            f"in the following list of responses.{focus_text} Responses are separated by semicolons. "
            f"Responses are within triple backticks: ```{responses_blob}``` "
            f"Number your categories from 1 through {categories_per_chunk} and provide concise labels only (no descriptions)."
        )

    # Matches "1. label", "2) label" or "3 - label" and captures the label.
    line_pat = re.compile(r"^\s*\d+\s*[\.\)\-]\s*(.+)$")

    all_items = []

    # Calculate total steps for progress tracking: (iterations * divisions) + 1 for final merge
    total_steps = (iterations * divisions) + 1
    current_step = 0

    for pass_idx in range(iterations):
        # One RNG draw per chunk keeps the sampling reproducible for a given random_state.
        random_chunks = [
            input_data.sample(
                n=chunk_size,
                random_state=int(rng.integers(0, 2**32 - 1)),
            ).tolist()
            for _ in range(divisions)
        ]

        for i in tqdm(range(divisions), desc=f"Processing chunks (pass {pass_idx+1}/{iterations})"):
            survey_participant_chunks = "; ".join(str(x) for x in random_chunks[i])
            messages = [
                {"role": "system", "content": system_content},
                {"role": "user", "content": make_prompt(survey_participant_chunks)}
            ]

            reply, error = client.complete(
                messages=messages,
                creativity=creativity,
                force_json=False,  # Text response, not JSON
            )

            if error:
                raise RuntimeError(
                    f"Model call failed on pass {pass_idx+1}, chunk {i+1}: {error}"
                )

            stripped_lines = [raw_line.strip() for raw_line in (reply or "").splitlines()]
            matches = (line_pat.match(text) for text in stripped_lines)
            items = [m.group(1).strip() for m in matches if m]
            if not items:
                # No numbered lines at all: treat every non-empty line as a label.
                items = [text for text in stripped_lines if text]

            all_items.extend(items)

            # Progress callback
            current_step += 1
            if progress_callback:
                progress_callback(current_step, total_steps, f"Pass {pass_idx+1}/{iterations}, chunk {i+1}/{divisions}")

            # Per-chunk delay to avoid rate limits
            if chunk_delay > 0:
                time.sleep(chunk_delay)

    # Early return for raw output (used by explore())
    if return_raw:
        return all_items

    # Normalize and count
    def normalize_category(cat):
        # "B / a" and "a/b" count as the same category.
        terms = sorted([t.strip().lower() for t in str(cat).split("/")])
        return "/".join(terms)

    flat_list = [str(x).strip() for x in all_items if str(x).strip()]
    if not flat_list:
        raise ValueError("No categories were extracted from the model responses.")

    df = pd.DataFrame(flat_list, columns=["Category"])
    df["normalized"] = df["Category"].map(normalize_category)

    result = (
        df.groupby("normalized")
        .agg(Category=("Category", lambda x: x.value_counts().index[0]),
             counts=("Category", "size"))
        .sort_values("counts", ascending=False)
        .reset_index(drop=True)
    )

    # Second-pass semantic merge prompt
    seed_list = result["Category"].head(max_categories * 3).tolist()

    second_prompt = f"""
You are a data analyst reviewing categorized text data.

Task: From the provided categories, identify and return the top {max_categories} CONCEPTUALLY UNIQUE categories.

Critical Instructions:
1) Exact duplicates are already removed.
2) Merge SEMANTIC duplicates (same concept, different wording). Examples:
- "closer to work" = "commute/proximity to work"
- "breakup/household conflict" = "relationship problems"
3) When merging:
- Combine frequencies mentally
- Keep the most frequent OR clearest label
- Each concept appears ONLY ONCE
4) Keep category names {specificity}.
5) Return ONLY a numbered list of {max_categories} categories. No extra text.

Pre-processed Categories (sorted by frequency, top sample):
{seed_list}

Output:
1. category
2. category
...
{max_categories}. category
""".strip()

    # Second pass call
    reply2, error2 = client.complete(
        messages=[{"role": "user", "content": second_prompt}],
        creativity=creativity,
        force_json=False,  # Text response
    )

    # Final progress callback for the merge step
    if progress_callback:
        progress_callback(total_steps, total_steps, "Merging categories")

    if error2:
        print(f"Warning: Second pass failed: {error2}")
        top_categories_text = ""
    else:
        top_categories_text = reply2 or ""

    merged_lines = top_categories_text.splitlines()
    numbered = (line_pat.match(line.strip()) for line in merged_lines)
    final = [m.group(1).strip() for m in numbered if m]
    if not final:
        final = [line.strip("-* ").strip() for line in merged_lines if line.strip()]

    # Fallback to counts_df if second pass failed
    if not final:
        final = result["Category"].head(max_categories).tolist()

    top_categories = final[:max_categories]
    print("\nTop categories:\n" + "\n".join(f"{i+1}. {c}" for i, c in enumerate(top_categories)))

    if filename:
        result.to_csv(filename, index=False)
        print(f"\nResults saved to {filename}")

    return {
        "counts_df": result,
        "top_categories": top_categories,
        "raw_top_text": top_categories_text
    }
830
+
831
+
832
+ # =============================================================================
833
+ # Main Classification Function
834
+ # =============================================================================
835
+
836
def multi_class(
    input_data,
    categories,
    api_key: str = None,
    model: str = "gpt-4o",
    provider: str = "auto",
    survey_question: str = "",
    example1: str = None,
    example2: str = None,
    example3: str = None,
    example4: str = None,
    example5: str = None,
    example6: str = None,
    creativity: float = None,
    safety: bool = False,
    chain_of_verification: bool = False,
    chain_of_thought: bool = False,
    step_back_prompt: bool = False,
    context_prompt: bool = False,
    thinking_budget: int = 0,
    max_categories: int = 12,
    categories_per_chunk: int = 10,
    divisions: int = 10,
    research_question: str = None,
    use_json_schema: bool = True,
    filename: str = None,
    save_directory: str = None,
    auto_download: bool = False,
):
    """
    Multi-class text classification using a unified HTTP-based approach.

    This function uses raw HTTP requests for all providers, eliminating SDK dependencies.
    Supports multiple prompting strategies including chain-of-thought, chain-of-verification,
    step-back prompting, and context prompting.

    Args:
        input_data: List or Series of text responses to classify
        categories: List of category names, or "auto" to auto-detect categories
        api_key: API key for the LLM provider (not required for Ollama)
        model: Model name (e.g., "gpt-4o", "claude-sonnet-4-5-20250929", "gemini-2.5-flash",
               or any Ollama model like "llama3.2", "mistral", "phi3")
        provider: Provider name or "auto" to detect from model name.
                  For local models, use provider="ollama"
        survey_question: Optional context about what question was asked
        example1-6: Optional few-shot examples for classification
        creativity: Temperature setting (None for provider default)
        safety: If True, saves results incrementally during processing
        chain_of_verification: If True, uses 4-step CoVe prompting for verification
        chain_of_thought: If True, uses step-by-step reasoning in prompt
        step_back_prompt: If True, first asks about underlying factors before classifying
        context_prompt: If True, adds expert context prefix to prompts
        thinking_budget: Token budget for Google's extended thinking (0 to disable)
        max_categories: Maximum categories when using auto-detection
        categories_per_chunk: Categories per chunk for auto-detection
        divisions: Number of divisions for auto-detection
        research_question: Research context for auto-detection
        use_json_schema: Whether to use strict JSON schema (vs just json_object mode)
        filename: Optional CSV filename to save results
        save_directory: Optional directory for safety saves
        auto_download: If True, automatically download missing Ollama models

    Returns:
        DataFrame with one row per input. Columns: ``input_data``,
        ``model_response``, ``error``, ``json``, one ``category_<n>`` column
        per numeric JSON key, ``processing_status`` ('success' / 'error') and
        ``categories_id``. NaN inputs are not sent to the model; they are
        recorded with an error marker and end up with status 'error'.
        Empty input yields an empty DataFrame with the same column layout.

    Raises:
        ValueError: If ``api_key`` is missing for a provider other than
            "ollama" or "claude-code".
        TypeError: If ``survey_question`` is empty while ``categories="auto"``
            or ``step_back_prompt=True``, or if ``safety=True`` is used
            without ``filename``.
        ConnectionError: If provider is "ollama" and Ollama is not running.
        RuntimeError: If an Ollama model cannot be obtained, or the user
            declines/cancels the resource-check prompt.

    Example with Ollama (local):
        results = multi_class(
            input_data=["I moved for work"],
            categories=["Employment", "Family"],
            model="llama3.2",
            provider="ollama",
        )

    Example with cloud provider:
        results = multi_class(
            input_data=["I moved for work"],
            categories=["Employment", "Family"],
            api_key="your-api-key",
            model="gpt-4o",
        )

    Example with chain-of-verification:
        results = multi_class(
            input_data=["I moved for work"],
            categories=["Employment", "Family"],
            api_key="your-api-key",
            model="gpt-4o",
            chain_of_verification=True,
            survey_question="Why did you move?",
        )

    .. deprecated::
        Use :func:`cat_stack.classify` instead. This function will be removed in a future version.
    """
    # Function-scope import: only needed to resolve the safety-save path.
    import os

    warnings.warn(
        "multi_class() is deprecated and will be removed in a future version. "
        "Use cat_stack.classify() instead, which supports single and multi-model classification.",
        DeprecationWarning,
        stacklevel=2,
    )

    # Detect provider
    provider = detect_provider(model, provider)

    # Validate api_key requirement
    if provider not in ("ollama", "claude-code") and not api_key:
        raise ValueError(f"api_key is required for provider '{provider}'")

    # Fail fast: incremental saves need a target file. Checking here avoids
    # spending an API call before discovering the missing filename.
    if safety and filename is None:
        raise TypeError("filename is required when using safety=True. Please provide a filename to save to.")

    # Handle categories="auto" - auto-detect categories from the data.
    # The isinstance guard keeps array-like categories (ndarray / Series)
    # from producing an ambiguous element-wise comparison.
    if isinstance(categories, str) and categories == "auto":
        if survey_question == "":
            raise TypeError("survey_question is required when using categories='auto'. Please provide the survey question you are analyzing.")

        categories = explore_common_categories(
            survey_question=survey_question,
            input_data=input_data,
            research_question=research_question,
            api_key=api_key,
            model_source=provider,
            user_model=model,
            max_categories=max_categories,
            categories_per_chunk=categories_per_chunk,
            divisions=divisions
        )

    # Build examples text for few-shot prompting
    examples = [example1, example2, example3, example4, example5, example6]
    examples_text = "\n".join(
        f"Example {i}: {ex}" for i, ex in enumerate(examples, 1) if ex is not None
    )

    # Survey question context
    survey_question_context = f"Context: {survey_question}." if survey_question else ""

    # Step-back insight initialization
    stepback_insight = None
    step_back_added = False
    if step_back_prompt:
        if survey_question == "":
            raise TypeError("survey_question is required when using step_back_prompt. Please provide the survey question you are analyzing.")

        stepback_question = f'What are the underlying factors or dimensions that explain how people typically answer "{survey_question}"?'
        stepback_insight, step_back_added = _get_stepback_insight(
            provider, stepback_question, api_key, model, creativity
        )

    # Ollama-specific checks
    if provider == "ollama":
        if not check_ollama_running():
            # Build the rule once. Writing "="*60 between adjacent string
            # literals makes Python fuse the neighbouring literals first and
            # then repeat the fused text 60 times.
            rule = "=" * 60
            raise ConnectionError(
                f"\n{rule}\n"
                " OLLAMA NOT RUNNING\n"
                f"{rule}\n\n"
                "Ollama must be running to use local models.\n\n"
                "To start Ollama:\n"
                " macOS: Open the Ollama app, or run 'ollama serve'\n"
                " Linux: Run 'ollama serve' in terminal\n"
                " Windows: Open the Ollama app\n\n"
                "Don't have Ollama installed?\n"
                " Download from: https://ollama.ai/download\n\n"
                "After starting Ollama, run your code again.\n"
                f"{rule}"
            )

        # Check system resources before proceeding
        resources = check_system_resources(model)

        # Check if model needs to be downloaded
        model_installed = check_ollama_model(model)

        if not model_installed:
            if not pull_ollama_model(model, auto_confirm=auto_download):
                raise RuntimeError(
                    f"Model '{model}' not available. "
                    f"To download manually: ollama pull {model}"
                )
        else:
            # Model is installed - still check if it can run
            # NOTE(review): the nesting of the closing banner and the
            # confirmation prompt below was reconstructed from a view that
            # had lost indentation - confirm against the upstream file.
            if resources["warnings"] or not resources["can_run"]:
                print(f"\n{'='*60}")
                print(f" Model '{model}' - System Resource Check")
                print(f"{'='*60}")
                size_estimate = get_ollama_model_size_estimate(model)
                print(f" Model size: {size_estimate}")
                if resources["details"].get("estimated_ram"):
                    print(f" RAM required: ~{resources['details']['estimated_ram']}")
                if resources["details"].get("total_ram"):
                    print(f" System RAM: {resources['details']['total_ram']}")

                if resources["warnings"]:
                    print(f"\n {'!'*50}")
                    for warning in resources["warnings"]:
                        print(f" Warning: {warning}")
                    print(f" {'!'*50}")

                if not resources["can_run"]:
                    print("\n Warning: Model may not run well on this system.")
                    print(f" Consider a smaller variant (e.g., '{model}:1b' or '{model}:3b').")
                print(f"{'='*60}")

                if not auto_download:
                    try:
                        answer = input("\n Continue anyway? [y/N]: ").strip().lower()
                    except (EOFError, KeyboardInterrupt):
                        raise RuntimeError("Operation cancelled by user.")
                    if answer not in ['y', 'yes']:
                        raise RuntimeError(
                            f"Model '{model}' may be too large for this system. "
                            f"Try a smaller variant like '{model}:3b' or '{model}:1b'."
                        )

                print()

    print(f"Using provider: {provider}, model: {model}")

    # Initialize client
    client = UnifiedLLMClient(provider=provider, api_key=api_key, model=model)

    # Build category string and schema
    categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
    # Build JSON schema - Google doesn't support additionalProperties
    if use_json_schema:
        include_additional = (provider != "google")
        json_schema = build_json_schema(categories, include_additional_properties=include_additional)
    else:
        json_schema = None

    # Print categories
    print(f"\nCategories to classify ({len(categories)} total):")
    for i, cat in enumerate(categories, 1):
        print(f" {i}. {cat}")
    print()

    # Build prompt template
    def build_prompt(response_text: str) -> tuple:
        """Build the classification prompt for a single response.

        Returns:
            tuple: (messages list, user_prompt string for CoVe)
        """
        if chain_of_thought:
            user_prompt = f"""{survey_question_context}

Categorize this text response "{response_text}" into the following categories that apply:
{categories_str}

Let's think step by step:
1. First, identify the main themes mentioned in the response
2. Then, match each theme to the relevant categories
3. Finally, assign 1 to matching categories and 0 to non-matching categories

{examples_text}

Provide your answer in JSON format where the category number is the key and "1" if present, "0" if not."""
        else:
            user_prompt = f"""{survey_question_context}
Categorize this text response "{response_text}" into the following categories that apply:
{categories_str}
{examples_text}
Provide your answer in JSON format where the category number is the key and "1" if present, "0" if not."""

        # Add context prompt prefix if enabled
        if context_prompt:
            context = """You are an expert researcher in text data categorization.
Apply multi-label classification and base decisions on explicit and implicit meanings.
When uncertain, prioritize precision over recall.

"""
            user_prompt = context + user_prompt

        # Build messages list
        messages = []

        # Add step-back insight if available (replayed as a prior exchange)
        if step_back_prompt and step_back_added and stepback_insight:
            messages.append({"role": "user", "content": stepback_question})
            messages.append({"role": "assistant", "content": stepback_insight})

        messages.append({"role": "user", "content": user_prompt})

        return messages, user_prompt

    # Build chain of verification prompts
    def build_cove_prompts(prompt: str, response_text: str) -> tuple:
        """Build chain of verification prompts.

        The <<...>> markers are placeholders substituted later by
        run_chain_of_verification.
        """
        step2_prompt = f"""You provided this initial categorization:
<<INITIAL_REPLY>>

Original task: {prompt}

Generate a focused list of 3-5 verification questions to fact-check your categorization. Each question should:
- Be concise and specific (one sentence)
- Address a distinct aspect of the categorization
- Be answerable independently

Focus on verifying:
- Whether each category assignment is accurate
- Whether the categories match the criteria in the original task
- Whether there are any logical inconsistencies

Provide only the verification questions as a numbered list."""

        step3_prompt = f"""Answer the following verification question based on the text response provided.

Text response: {response_text}

Verification question: <<QUESTION>>

Provide a brief, direct answer (1-2 sentences maximum).

Answer:"""

        step4_prompt = f"""Original task: {prompt}
Initial categorization:
<<INITIAL_REPLY>>
Verification questions and answers:
<<VERIFICATION_QA>>
If no categories are present, assign "0" to all categories.
Provide the final corrected categorization in the same JSON format:"""

        return step2_prompt, step3_prompt, step4_prompt

    def remove_numbering(line: str) -> str:
        """Remove numbering/bullets from a line for CoVe question parsing."""
        line = line.strip()
        if line.startswith(('- ', '• ')):
            return line[2:].strip()
        if line and line[0].isdigit():
            i = 0
            while i < len(line) and line[i].isdigit():
                i += 1
            # Only strip the digits when they are followed by "." or ")"
            if i < len(line) and line[i] in '.)':
                return line[i+1:].strip()
        return line

    def run_chain_of_verification(initial_reply: str, step2_prompt: str, step3_prompt: str, step4_prompt: str) -> str:
        """Run chain of verification using the unified client.

        Returns the verified reply, or ``initial_reply`` unchanged whenever
        question generation or the final correction step fails.
        """
        # Step 2: Generate verification questions (text response, not JSON)
        step2_filled = step2_prompt.replace("<<INITIAL_REPLY>>", initial_reply)
        questions_reply, err = client.complete(
            messages=[{"role": "user", "content": step2_filled}],
            creativity=creativity,
            force_json=False,  # Text response
        )
        if err or not questions_reply:
            return initial_reply  # Fall back to initial reply on error

        # Parse questions, dropping lines that are empty once the
        # numbering/bullet is removed so no blank question is sent.
        questions = [
            question
            for question in (
                remove_numbering(line)
                for line in questions_reply.strip().split('\n')
                if line.strip()
            )
            if question
        ]

        # Step 3: Answer each verification question (text responses)
        qa_pairs = []
        for question in questions[:5]:  # Limit to 5 questions
            step3_filled = step3_prompt.replace("<<QUESTION>>", question)
            answer_reply, err = client.complete(
                messages=[{"role": "user", "content": step3_filled}],
                creativity=creativity,
                force_json=False,  # Text response
            )
            if not err and answer_reply:
                qa_pairs.append(f"Q: {question}\nA: {answer_reply.strip()}")

        verification_qa = "\n\n".join(qa_pairs)

        # Step 4: Final corrected categorization (JSON response)
        step4_filled = step4_prompt.replace("<<INITIAL_REPLY>>", initial_reply).replace("<<VERIFICATION_QA>>", verification_qa)
        final_reply, err = client.complete(
            messages=[{"role": "user", "content": step4_filled}],
            json_schema=json_schema,
            creativity=creativity,
        )

        if err:
            return initial_reply
        return final_reply

    def json_to_frame(json_str):
        """Turn one extracted JSON string into a single-row DataFrame.

        Anything that is not a JSON object (unparseable text, a non-string
        payload, a JSON list/scalar) becomes the error marker row so every
        response contributes exactly one row.
        """
        try:
            parsed = json.loads(json_str)
        except (json.JSONDecodeError, TypeError):
            parsed = None
        if not isinstance(parsed, dict):
            return pd.DataFrame({"1": ["e"]})
        return pd.json_normalize(parsed)

    # Process each response
    results = []
    extracted_jsons = []
    normalized_frames = []  # parsed cache, kept in step with extracted_jsons

    def assemble_frame(inputs):
        """Combine inputs, raw replies and parsed category columns."""
        # Parse only the JSON strings added since the previous call.
        for json_str in extracted_jsons[len(normalized_frames):]:
            normalized_frames.append(json_to_frame(json_str))

        frame = pd.DataFrame({
            'input_data': pd.Series(inputs).reset_index(drop=True),
            'model_response': [r[0] for r in results],
            'error': [r[1] for r in results],
            'json': pd.Series(extracted_jsons).reset_index(drop=True),
        })
        normalized = pd.concat(normalized_frames, ignore_index=True)
        frame = pd.concat([frame, normalized], axis=1)
        # Numeric JSON keys become category_<n> columns
        return frame.rename(columns=lambda x: f'category_{x}' if str(x).isdigit() else x)

    # Resolve the incremental-save target once, before the loop
    safety_save_path = None
    if safety:
        safety_save_path = filename
        if save_directory:
            os.makedirs(save_directory, exist_ok=True)
            safety_save_path = os.path.join(save_directory, filename)

    # Use two-step approach for Ollama (more reliable JSON output)
    use_two_step = (provider == "ollama")

    if use_two_step:
        print("Using two-step classification for Ollama (classify -> format JSON)")

    for response in tqdm(input_data, desc="Classifying responses"):
        if pd.isna(response):
            results.append(("Skipped NaN", "Skipped NaN input"))
            extracted_jsons.append('{"1":"e"}')
            # Skipped rows do not trigger an incremental save; they are
            # included the next time a save happens.
            continue

        if use_two_step:
            json_result, error = ollama_two_step_classify(
                client=client,
                response_text=response,
                categories=categories,
                categories_str=categories_str,
                survey_question=survey_question,
                creativity=creativity,
                max_retries=5,
            )

            # Falsy errors are normalised to None
            results.append((json_result, error if error else None))
            extracted_jsons.append(json_result)

        else:
            messages, user_prompt = build_prompt(response)
            reply, error = client.complete(
                messages=messages,
                json_schema=json_schema,
                creativity=creativity,
                thinking_budget=thinking_budget if provider == "google" else None,
            )

            if error:
                results.append((None, error))
                extracted_jsons.append('{"1":"e"}')
            else:
                # Apply chain of verification if enabled
                if chain_of_verification and reply:
                    step2, step3, step4 = build_cove_prompts(user_prompt, response)
                    reply = run_chain_of_verification(reply, step2, step3, step4)

                results.append((reply, None))
                extracted_jsons.append(extract_json(reply))

        # Safety incremental save: rewrite the CSV with everything so far
        if safety:
            partial_df = assemble_frame(input_data[:len(results)])
            partial_df.to_csv(safety_save_path, index=False)

    if not results:
        # Empty input: pd.concat cannot handle an empty list, so return an
        # empty frame that still carries the usual column layout.
        column_names = (
            ['input_data', 'model_response', 'error', 'json']
            + [f'category_{i}' for i in range(1, len(categories) + 1)]
            + ['processing_status', 'categories_id']
        )
        df = pd.DataFrame({
            name: pd.Series(dtype='Int64' if name.startswith('category_') else 'object')
            for name in column_names
        })
    else:
        # Build output DataFrame
        df = assemble_frame(input_data)

        # Process category columns
        cat_cols = [col for col in df.columns if col.startswith('category_')]

        # Identify invalid rows: a present value that is not numeric
        has_invalid = df[cat_cols].apply(
            lambda col: pd.to_numeric(col, errors='coerce').isna() & col.notna()
        ).any(axis=1)

        df['processing_status'] = (~has_invalid).map({True: 'success', False: 'error'})
        df.loc[has_invalid, cat_cols] = pd.NA

        # Convert to numeric
        for col in cat_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce')

        # Fill NaN with 0 for valid rows
        df.loc[~has_invalid, cat_cols] = df.loc[~has_invalid, cat_cols].fillna(0)

        # Convert to Int64 (nullable, so invalid rows keep <NA>)
        df[cat_cols] = df[cat_cols].astype('Int64')

        # Create categories_id
        df['categories_id'] = df[cat_cols].apply(
            lambda x: ','.join(x.dropna().astype(int).astype(str)), axis=1
        )

    # NOTE(review): the final CSV is written to `filename` as given, while
    # incremental safety saves honour `save_directory` - confirm intended.
    if filename:
        df.to_csv(filename, index=False)
        print(f"\nResults saved to {filename}")

    return df
1355
+
1356
+
1357
+ # Note: For the legacy implementation with chain_of_verification, step_back_prompt,
1358
+ # and other advanced features, see text_functions_old.py