cat-stack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cat_stack/__about__.py +10 -0
- cat_stack/__init__.py +128 -0
- cat_stack/_batch.py +1388 -0
- cat_stack/_category_analysis.py +348 -0
- cat_stack/_chunked.py +424 -0
- cat_stack/_embeddings.py +189 -0
- cat_stack/_formatter.py +169 -0
- cat_stack/_providers.py +1048 -0
- cat_stack/_tiebreaker.py +277 -0
- cat_stack/_utils.py +512 -0
- cat_stack/_web_fetch.py +194 -0
- cat_stack/calls/CoVe.py +287 -0
- cat_stack/calls/__init__.py +25 -0
- cat_stack/calls/all_calls.py +622 -0
- cat_stack/calls/image_CoVe.py +386 -0
- cat_stack/calls/image_stepback.py +210 -0
- cat_stack/calls/pdf_CoVe.py +386 -0
- cat_stack/calls/pdf_stepback.py +210 -0
- cat_stack/calls/stepback.py +180 -0
- cat_stack/calls/top_n.py +217 -0
- cat_stack/classify.py +682 -0
- cat_stack/explore.py +111 -0
- cat_stack/extract.py +218 -0
- cat_stack/image_functions.py +2078 -0
- cat_stack/images/circle.png +0 -0
- cat_stack/images/cube.png +0 -0
- cat_stack/images/diamond.png +0 -0
- cat_stack/images/overlapping_pentagons.png +0 -0
- cat_stack/images/rectangles.png +0 -0
- cat_stack/model_reference_list.py +94 -0
- cat_stack/pdf_functions.py +2087 -0
- cat_stack/summarize.py +290 -0
- cat_stack/text_functions.py +1358 -0
- cat_stack/text_functions_ensemble.py +3644 -0
- cat_stack-0.1.0.dist-info/METADATA +150 -0
- cat_stack-0.1.0.dist-info/RECORD +38 -0
- cat_stack-0.1.0.dist-info/WHEEL +4 -0
- cat_stack-0.1.0.dist-info/licenses/LICENSE +672 -0
|
@@ -0,0 +1,1358 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Text classification functions for CatLLM.
|
|
3
|
+
|
|
4
|
+
This module provides multi-class text classification using a unified HTTP-based approach
|
|
5
|
+
that works with multiple LLM providers (OpenAI, Anthropic, Google, Mistral, xAI,
|
|
6
|
+
Perplexity, HuggingFace, and Ollama) without requiring provider-specific SDKs.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import time
|
|
11
|
+
import warnings
|
|
12
|
+
|
|
13
|
+
# Exported names for ``from ... import *`` (excludes the deprecated
# ``multi_class`` name).
__all__ = [
    # Re-exported from ._providers: client and provider detection
    "UnifiedLLMClient",
    "detect_provider",
    # Re-exported from ._providers: local runtime helpers (Ollama, Claude CLI)
    "set_ollama_endpoint",
    "check_ollama_running",
    "list_ollama_models",
    "check_ollama_model",
    "check_system_resources",
    "get_ollama_model_size_estimate",
    "pull_ollama_model",
    "check_claude_cli_available",
    # JSON schema / reply-parsing helpers defined in this module
    "build_json_schema",
    "extract_json",
    "validate_classification_json",
    "ollama_two_step_classify",
    # Category exploration entry points defined in this module
    "explore_corpus",
    "explore_common_categories",
    # Internal utilities used by other modules
    "_detect_model_source",
    "_get_stepback_insight",
    "_detect_huggingface_endpoint",
]
|
|
36
|
+
import pandas as pd
|
|
37
|
+
import regex
|
|
38
|
+
from tqdm import tqdm
|
|
39
|
+
|
|
40
|
+
from .calls.stepback import (
|
|
41
|
+
get_stepback_insight_openai,
|
|
42
|
+
get_stepback_insight_anthropic,
|
|
43
|
+
get_stepback_insight_google,
|
|
44
|
+
get_stepback_insight_mistral
|
|
45
|
+
)
|
|
46
|
+
from .calls.CoVe import (
|
|
47
|
+
chain_of_verification_openai,
|
|
48
|
+
chain_of_verification_google,
|
|
49
|
+
chain_of_verification_anthropic,
|
|
50
|
+
chain_of_verification_mistral
|
|
51
|
+
)
|
|
52
|
+
from .calls.top_n import (
|
|
53
|
+
get_openai_top_n,
|
|
54
|
+
get_anthropic_top_n,
|
|
55
|
+
get_google_top_n,
|
|
56
|
+
get_mistral_top_n
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
from ._providers import (
|
|
60
|
+
UnifiedLLMClient,
|
|
61
|
+
PROVIDER_CONFIG,
|
|
62
|
+
detect_provider,
|
|
63
|
+
_detect_model_source,
|
|
64
|
+
_detect_huggingface_endpoint,
|
|
65
|
+
set_ollama_endpoint,
|
|
66
|
+
check_ollama_running,
|
|
67
|
+
list_ollama_models,
|
|
68
|
+
check_ollama_model,
|
|
69
|
+
check_system_resources,
|
|
70
|
+
get_ollama_model_size_estimate,
|
|
71
|
+
pull_ollama_model,
|
|
72
|
+
check_claude_cli_available,
|
|
73
|
+
OLLAMA_MODEL_SIZES,
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# =============================================================================
|
|
78
|
+
# Helper Functions
|
|
79
|
+
# =============================================================================
|
|
80
|
+
|
|
81
|
+
def _get_stepback_insight(model_source, stepback, api_key, user_model, creativity):
    """Route a step-back request to the helper for ``model_source``.

    OpenAI, Perplexity, HuggingFace, HuggingFace-Together and xAI are all
    served by the OpenAI step-back helper; Anthropic, Google and Mistral each
    have a dedicated helper.

    Returns:
        The selected helper's return value, or ``(None, False)`` when
        ``model_source`` has no step-back helper.
    """
    routed_to_openai = ("openai", "perplexity", "huggingface", "huggingface-together", "xai")

    if model_source in routed_to_openai:
        handler = get_stepback_insight_openai
    elif model_source == "anthropic":
        handler = get_stepback_insight_anthropic
    elif model_source == "google":
        handler = get_stepback_insight_google
    elif model_source == "mistral":
        handler = get_stepback_insight_mistral
    else:
        # Provider without a step-back implementation.
        return None, False

    return handler(
        stepback=stepback,
        api_key=api_key,
        user_model=user_model,
        model_source=model_source,
        creativity=creativity,
    )
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# =============================================================================
|
|
109
|
+
# JSON Schema Functions
|
|
110
|
+
# =============================================================================
|
|
111
|
+
|
|
112
|
+
def build_json_schema(categories: list, include_additional_properties: bool = True) -> dict:
    """Return a JSON schema with one required "0"/"1" flag per category.

    Categories are keyed by their 1-based position as strings ("1", "2", ...);
    each property is a string restricted to "0" or "1" and carries the
    category name as its description.

    Args:
        categories: List of category names
        include_additional_properties: If True, includes additionalProperties: false
            (required by OpenAI strict mode, not supported by Google)

    Returns:
        dict: The JSON schema object.
    """
    per_category = {
        str(position): {"type": "string", "enum": ["0", "1"], "description": label}
        for position, label in enumerate(categories, start=1)
    }

    schema = {
        "type": "object",
        "properties": per_category,
        "required": [*per_category],
    }
    if include_additional_properties:
        schema["additionalProperties"] = False
    return schema
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def extract_json(reply: str) -> str:
    """Pull the first JSON object out of a model reply and return it compactly.

    The first balanced ``{...}`` span is found with a recursive pattern (via the
    ``regex`` module). Every square bracket in that span is removed, then the
    span is parsed and re-serialized with compact separators, which normalizes
    structural whitespace while keeping spaces inside string values. If the
    span is still not valid JSON, it is returned with newlines removed.

    Args:
        reply: Raw model output; may be None.

    Returns:
        str: The extracted JSON text, or the sentinel '{"1":"e"}' when ``reply``
        is None or contains no brace-delimited object.
    """
    error_sentinel = '{"1":"e"}'
    if reply is None:
        return error_sentinel

    candidates = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
    if not candidates:
        return error_sentinel

    # All brackets are dropped (including any inside string values); presumably
    # this flattens list-wrapped values such as {"1": ["1"]} -> {"1": "1"}.
    cleaned = candidates[0].replace('[', '').replace(']', '')
    try:
        return json.dumps(json.loads(cleaned), separators=(',', ':'))
    except json.JSONDecodeError:
        return cleaned.replace('\n', '')
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def validate_classification_json(json_str: str, num_categories: int) -> tuple[bool, dict | None]:
|
|
160
|
+
"""
|
|
161
|
+
Validate that a JSON string contains valid classification output.
|
|
162
|
+
|
|
163
|
+
Args:
|
|
164
|
+
json_str: The JSON string to validate
|
|
165
|
+
num_categories: Expected number of categories
|
|
166
|
+
|
|
167
|
+
Returns:
|
|
168
|
+
tuple: (is_valid, parsed_dict or None)
|
|
169
|
+
"""
|
|
170
|
+
try:
|
|
171
|
+
parsed = json.loads(json_str)
|
|
172
|
+
|
|
173
|
+
if not isinstance(parsed, dict):
|
|
174
|
+
return False, None
|
|
175
|
+
|
|
176
|
+
# Check that all expected keys are present and values are "0" or "1"
|
|
177
|
+
for i in range(1, num_categories + 1):
|
|
178
|
+
key = str(i)
|
|
179
|
+
if key not in parsed:
|
|
180
|
+
return False, None
|
|
181
|
+
val = str(parsed[key]).strip()
|
|
182
|
+
if val not in ("0", "1"):
|
|
183
|
+
return False, None
|
|
184
|
+
|
|
185
|
+
# Normalize values to strings
|
|
186
|
+
normalized = {str(i): str(parsed[str(i)]).strip() for i in range(1, num_categories + 1)}
|
|
187
|
+
return True, normalized
|
|
188
|
+
|
|
189
|
+
except (json.JSONDecodeError, KeyError, TypeError):
|
|
190
|
+
return False, None
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def ollama_two_step_classify(
|
|
194
|
+
client,
|
|
195
|
+
response_text: str,
|
|
196
|
+
categories: list,
|
|
197
|
+
categories_str: str,
|
|
198
|
+
survey_question: str = "",
|
|
199
|
+
creativity: float = None,
|
|
200
|
+
max_retries: int = 5,
|
|
201
|
+
) -> tuple[str, str | None]:
|
|
202
|
+
"""
|
|
203
|
+
Two-step classification for Ollama models.
|
|
204
|
+
|
|
205
|
+
Step 1: Classify the response (natural language output OK)
|
|
206
|
+
Step 2: Convert classification to strict JSON format
|
|
207
|
+
|
|
208
|
+
This approach is more reliable for local models that struggle with
|
|
209
|
+
simultaneous reasoning and JSON formatting.
|
|
210
|
+
|
|
211
|
+
Args:
|
|
212
|
+
client: UnifiedLLMClient instance
|
|
213
|
+
response_text: The text response to classify
|
|
214
|
+
categories: List of category names
|
|
215
|
+
categories_str: Pre-formatted category string
|
|
216
|
+
survey_question: Optional context
|
|
217
|
+
creativity: Temperature setting
|
|
218
|
+
max_retries: Number of retry attempts for JSON validation
|
|
219
|
+
|
|
220
|
+
Returns:
|
|
221
|
+
tuple: (json_string, error_message or None)
|
|
222
|
+
"""
|
|
223
|
+
num_categories = len(categories)
|
|
224
|
+
survey_context = f"Context: {survey_question}." if survey_question else ""
|
|
225
|
+
|
|
226
|
+
# ==========================================================================
|
|
227
|
+
# Step 1: Classification (natural language - focus on accuracy)
|
|
228
|
+
# ==========================================================================
|
|
229
|
+
step1_messages = [
|
|
230
|
+
{
|
|
231
|
+
"role": "system",
|
|
232
|
+
"content": "You are an expert at categorizing text responses. Focus on accurate classification."
|
|
233
|
+
},
|
|
234
|
+
{
|
|
235
|
+
"role": "user",
|
|
236
|
+
"content": f"""{survey_context}
|
|
237
|
+
|
|
238
|
+
Analyze this text response and determine which categories apply:
|
|
239
|
+
|
|
240
|
+
Response: "{response_text}"
|
|
241
|
+
|
|
242
|
+
Categories:
|
|
243
|
+
{categories_str}
|
|
244
|
+
|
|
245
|
+
For each category, explain briefly whether it applies (YES) or not (NO) to this response.
|
|
246
|
+
Format your answer as:
|
|
247
|
+
1. [Category name]: YES/NO - [brief reason]
|
|
248
|
+
2. [Category name]: YES/NO - [brief reason]
|
|
249
|
+
...and so on for all categories."""
|
|
250
|
+
}
|
|
251
|
+
]
|
|
252
|
+
|
|
253
|
+
step1_reply, step1_error = client.complete(
|
|
254
|
+
messages=step1_messages,
|
|
255
|
+
json_schema=None, # No JSON requirement for step 1
|
|
256
|
+
creativity=creativity,
|
|
257
|
+
)
|
|
258
|
+
|
|
259
|
+
if step1_error:
|
|
260
|
+
return '{"1":"e"}', f"Step 1 failed: {step1_error}"
|
|
261
|
+
|
|
262
|
+
# ==========================================================================
|
|
263
|
+
# Step 2: JSON Formatting with validation and retry
|
|
264
|
+
# ==========================================================================
|
|
265
|
+
example_json = json.dumps({str(i): "0" for i in range(1, num_categories + 1)})
|
|
266
|
+
|
|
267
|
+
for attempt in range(max_retries):
|
|
268
|
+
step2_messages = [
|
|
269
|
+
{
|
|
270
|
+
"role": "system",
|
|
271
|
+
"content": "You convert classification results to JSON. Output ONLY valid JSON, nothing else."
|
|
272
|
+
},
|
|
273
|
+
{
|
|
274
|
+
"role": "user",
|
|
275
|
+
"content": f"""Convert this classification to JSON format.
|
|
276
|
+
|
|
277
|
+
Classification results:
|
|
278
|
+
{step1_reply}
|
|
279
|
+
|
|
280
|
+
Rules:
|
|
281
|
+
- Output ONLY a JSON object, no other text
|
|
282
|
+
- Use category numbers as keys (1, 2, 3, etc.)
|
|
283
|
+
- Use "1" if the category was marked YES, "0" if NO
|
|
284
|
+
- Include ALL {num_categories} categories
|
|
285
|
+
|
|
286
|
+
Example format:
|
|
287
|
+
{example_json}
|
|
288
|
+
|
|
289
|
+
Your JSON output:"""
|
|
290
|
+
}
|
|
291
|
+
]
|
|
292
|
+
|
|
293
|
+
step2_reply, step2_error = client.complete(
|
|
294
|
+
messages=step2_messages,
|
|
295
|
+
json_schema=None, # Ollama doesn't support strict schema anyway
|
|
296
|
+
creativity=0.1, # Low temperature for formatting task
|
|
297
|
+
)
|
|
298
|
+
|
|
299
|
+
if step2_error:
|
|
300
|
+
if attempt < max_retries - 1:
|
|
301
|
+
continue
|
|
302
|
+
return '{"1":"e"}', f"Step 2 failed: {step2_error}"
|
|
303
|
+
|
|
304
|
+
# Extract and validate JSON
|
|
305
|
+
extracted = extract_json(step2_reply)
|
|
306
|
+
is_valid, normalized = validate_classification_json(extracted, num_categories)
|
|
307
|
+
|
|
308
|
+
if is_valid:
|
|
309
|
+
return json.dumps(normalized), None
|
|
310
|
+
|
|
311
|
+
# If invalid, try again with more explicit instructions
|
|
312
|
+
if attempt < max_retries - 1:
|
|
313
|
+
step1_reply = f"""Previous attempt produced invalid JSON.
|
|
314
|
+
|
|
315
|
+
Original classification:
|
|
316
|
+
{step1_reply}
|
|
317
|
+
|
|
318
|
+
Please be more careful to output EXACTLY {num_categories} categories numbered 1 through {num_categories}."""
|
|
319
|
+
|
|
320
|
+
# All retries exhausted - try to salvage what we can
|
|
321
|
+
extracted = extract_json(step2_reply) if step2_reply else '{"1":"e"}'
|
|
322
|
+
return extracted, f"JSON validation failed after {max_retries} attempts"
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
# =============================================================================
|
|
326
|
+
# Category Exploration Functions
|
|
327
|
+
# =============================================================================
|
|
328
|
+
|
|
329
|
+
def explore_corpus(
    survey_question,
    input_data,
    api_key: str = None,
    research_question=None,
    specificity="broad",
    categories_per_chunk=10,
    divisions=5,
    model: str = "gpt-4o",
    provider: str = "auto",
    creativity=None,
    filename="corpus_exploration.csv",
    focus: str = None,
    random_state: int = None,
):
    """
    Extract categories from text corpus using LLM.

    Uses raw HTTP requests via UnifiedLLMClient - supports all providers.
    The responses are drawn into ``divisions`` random samples; for each sample
    the model is asked for ``categories_per_chunk`` numbered category labels.
    Labels are parsed from numbered lines ("1. x", "2) x", "3 - x"), with a
    fallback to every non-blank line when a reply has no numbered lines, then
    lower-cased and counted across all chunks.

    Args:
        survey_question: The survey question being analyzed
        input_data: Series or list of text responses
        api_key: API key for the LLM provider
        research_question: Optional research context
        specificity: "broad" or "specific" categories
        categories_per_chunk: Number of categories to extract per chunk
        divisions: Number of chunks to process (must be >= 1; reduced
            automatically for small datasets)
        model: Model name (e.g., "gpt-4o", "claude-3-haiku-20240307", "gemini-2.5-flash")
        provider: Provider name or "auto" to detect from model name
        creativity: Temperature setting
        filename: Output CSV filename (None to skip saving)
        focus: Optional focus instruction for category extraction (e.g., "decisions to move",
            "emotional responses", "financial considerations"). When provided, the model
            will prioritize extracting categories related to this focus.
        random_state: Optional seed for reproducible chunk sampling. When None,
            sampling uses pandas' default randomness (the global NumPy state).

    Returns:
        DataFrame with columns 'Category' and 'counts': one row per distinct
        lower-cased label, sorted by descending count.

    Raises:
        ValueError: if api_key is missing for a hosted provider, divisions < 1,
            input_data is empty after dropping NA, the model's context length is
            exceeded, or no categories could be parsed from the replies.
    """
    import re

    import numpy as np

    # Detect provider
    provider = detect_provider(model, provider)

    # Validate arguments
    if provider not in ("ollama", "claude-code") and not api_key:
        raise ValueError(f"api_key is required for provider '{provider}'")
    if divisions < 1:
        # 0 used to crash with ZeroDivisionError; negatives silently produced
        # no chunks and ended in "No categories were extracted".
        raise ValueError("divisions must be at least 1")

    print(f"Exploring categories for question: '{survey_question}'")
    print(f"Using provider: {provider}, model: {model}")
    if focus:
        print(f"Focus: {focus}")

    # Input normalization
    if not isinstance(input_data, pd.Series):
        input_data = pd.Series(input_data)
    input_data = input_data.dropna()

    n = len(input_data)
    if n == 0:
        raise ValueError("input_data is empty after dropping NA.")

    # Auto-adjust divisions for small datasets (roughly >= 3 responses per chunk)
    original_divisions = divisions
    divisions = min(divisions, max(1, n // 3))
    if divisions != original_divisions:
        print(f"Auto-adjusted divisions from {original_divisions} to {divisions} for {n} responses.")

    chunk_size = int(round(max(1, n / divisions), 0))

    if chunk_size < (categories_per_chunk / 2):
        old_categories_per_chunk = categories_per_chunk
        categories_per_chunk = max(3, chunk_size * 2)
        print(f"Auto-adjusted categories_per_chunk from {old_categories_per_chunk} to {categories_per_chunk} for chunk size {chunk_size}.")

    # Reported after the auto-adjustments so it matches what is requested.
    print(f" {categories_per_chunk * divisions} unique categories to be extracted.")
    print()

    # Initialize unified client
    client = UnifiedLLMClient(provider=provider, api_key=api_key, model=model)

    # Build system message
    if research_question:
        system_content = (
            f"You are a helpful assistant that extracts categories from text responses. "
            f"The specific task is to identify {specificity} categories of responses to a text prompt. "
            f"The research question is: {research_question}"
        )
    else:
        system_content = "You are a helpful assistant that extracts categories from text responses."

    # Sample chunks. With a random_state, per-chunk integer seeds are drawn from
    # a seeded generator (as explore_common_categories does); without one,
    # random_state=None keeps pandas' default global-RNG behaviour.
    rng = None if random_state is None else np.random.default_rng(random_state)
    random_chunks = []
    for _ in range(divisions):
        seed = None if rng is None else int(rng.integers(0, 2**32 - 1))
        random_chunks.append(input_data.sample(n=chunk_size, random_state=seed).tolist())

    focus_text = f" Focus specifically on {focus}." if focus else ""
    # Same numbered-list pattern as explore_common_categories: "1. x", "2) x", "3 - x".
    numbered_line = re.compile(r"^\s*\d+\s*[\.\)\-]\s*(.+)$")
    responses_list = []

    for i in tqdm(range(divisions), desc="Processing chunks"):
        survey_participant_chunks = '; '.join(str(x) for x in random_chunks[i])
        prompt = (
            f'Identify {categories_per_chunk} {specificity} categories of responses to the question "{survey_question}" '
            f"in the following list of responses.{focus_text} Responses are each separated by a semicolon. "
            f"Responses are contained within triple backticks here: ```{survey_participant_chunks}``` "
            f"Number your categories from 1 through {categories_per_chunk} and be concise with the category labels and provide no description of the categories."
        )

        messages = [
            {"role": "system", "content": system_content},
            {"role": "user", "content": prompt}
        ]

        reply, error = client.complete(
            messages=messages,
            creativity=creativity,
            force_json=False,  # Text response, not JSON
        )

        if error:
            if "context_length_exceeded" in str(error) or "maximum context length" in str(error):
                raise ValueError(
                    f"Token limit exceeded for model {model}. "
                    f"Try increasing the 'divisions' parameter to create smaller chunks."
                )
            # Best effort: a single failed chunk is reported and skipped.
            print(f"API error on chunk {i+1}: {error}")
            reply = ""

        # Keep only numbered-list entries so preamble sentences such as
        # "Sure. Here are the categories:" are not counted as categories.
        lines = [line.strip() for line in (reply or "").splitlines()]
        items = [m.group(1).strip() for m in map(numbered_line.match, lines) if m]
        if not items:
            # No numbered lines: fall back to every non-blank line.
            items = [line for line in lines if line]

        responses_list.append(items)

    flat_list = [item.lower() for sublist in responses_list for item in sublist]

    if not flat_list:
        raise ValueError("No categories were extracted from the model responses.")

    # One row per distinct label with its frequency, most frequent first.
    counts = pd.Series(flat_list).value_counts()
    df = pd.DataFrame({'Category': counts.index.tolist(), 'counts': counts.to_numpy()})

    if filename is not None:
        df.to_csv(filename, index=False)
        print(f"Results saved to {filename}")

    return df
|
|
485
|
+
|
|
486
|
+
|
|
487
|
+
def explore_common_categories(
|
|
488
|
+
input_data,
|
|
489
|
+
api_key: str = None,
|
|
490
|
+
survey_question: str = "",
|
|
491
|
+
max_categories: int = 12,
|
|
492
|
+
categories_per_chunk: int = 10,
|
|
493
|
+
divisions: int = 5,
|
|
494
|
+
model: str = "gpt-4o",
|
|
495
|
+
provider: str = "auto",
|
|
496
|
+
creativity: float = None,
|
|
497
|
+
specificity: str = "broad",
|
|
498
|
+
research_question: str = None,
|
|
499
|
+
filename: str = None,
|
|
500
|
+
iterations: int = 5,
|
|
501
|
+
random_state: int = None,
|
|
502
|
+
focus: str = None,
|
|
503
|
+
progress_callback: callable = None,
|
|
504
|
+
return_raw: bool = False,
|
|
505
|
+
chunk_delay: float = 0.0,
|
|
506
|
+
auto_download: bool = False,
|
|
507
|
+
# Legacy parameter names for backward compatibility
|
|
508
|
+
user_model: str = None,
|
|
509
|
+
model_source: str = None,
|
|
510
|
+
):
|
|
511
|
+
"""
|
|
512
|
+
Extract and rank common categories from survey corpus.
|
|
513
|
+
|
|
514
|
+
Uses raw HTTP requests via UnifiedLLMClient - supports all providers.
|
|
515
|
+
|
|
516
|
+
Args:
|
|
517
|
+
input_data: Series or list of text responses
|
|
518
|
+
api_key: API key for the LLM provider
|
|
519
|
+
survey_question: The survey question being analyzed
|
|
520
|
+
max_categories: Maximum number of top categories to return
|
|
521
|
+
categories_per_chunk: Number of categories to extract per chunk
|
|
522
|
+
divisions: Number of chunks to process per iteration
|
|
523
|
+
model: Model name (e.g., "gpt-4o", "claude-3-haiku-20240307", "gemini-2.5-flash")
|
|
524
|
+
provider: Provider name or "auto" to detect from model name
|
|
525
|
+
creativity: Temperature setting
|
|
526
|
+
specificity: "broad" or "specific" categories
|
|
527
|
+
research_question: Optional research context
|
|
528
|
+
filename: Output CSV filename (None to skip saving)
|
|
529
|
+
iterations: Number of passes over the data
|
|
530
|
+
random_state: Random seed for reproducibility
|
|
531
|
+
focus: Optional focus instruction for category extraction (e.g., "decisions to move",
|
|
532
|
+
"emotional responses", "financial considerations"). When provided, the model
|
|
533
|
+
will prioritize extracting categories related to this focus.
|
|
534
|
+
progress_callback: Optional callback function for progress updates.
|
|
535
|
+
Called as progress_callback(current_step, total_steps, step_label).
|
|
536
|
+
auto_download: If True, automatically download missing Ollama models
|
|
537
|
+
without prompting. Default False (interactive prompt).
|
|
538
|
+
|
|
539
|
+
Returns:
|
|
540
|
+
dict with 'counts_df', 'top_categories', and 'raw_top_text'
|
|
541
|
+
"""
|
|
542
|
+
import re
|
|
543
|
+
import numpy as np
|
|
544
|
+
|
|
545
|
+
# Handle legacy parameter names
|
|
546
|
+
if user_model is not None:
|
|
547
|
+
model = user_model
|
|
548
|
+
if model_source is not None:
|
|
549
|
+
provider = model_source
|
|
550
|
+
|
|
551
|
+
# Detect provider
|
|
552
|
+
provider = detect_provider(model, provider)
|
|
553
|
+
|
|
554
|
+
# Validate api_key
|
|
555
|
+
if provider not in ("ollama", "claude-code") and not api_key:
|
|
556
|
+
raise ValueError(f"api_key is required for provider '{provider}'")
|
|
557
|
+
|
|
558
|
+
# Ollama-specific checks
|
|
559
|
+
if provider == "ollama":
|
|
560
|
+
if not check_ollama_running():
|
|
561
|
+
raise ConnectionError(
|
|
562
|
+
"\n" + "="*60 + "\n"
|
|
563
|
+
" OLLAMA NOT RUNNING\n"
|
|
564
|
+
"="*60 + "\n\n"
|
|
565
|
+
"Ollama must be running to use local models.\n\n"
|
|
566
|
+
"To start Ollama:\n"
|
|
567
|
+
" macOS: Open the Ollama app, or run 'ollama serve'\n"
|
|
568
|
+
" Linux: Run 'ollama serve' in terminal\n"
|
|
569
|
+
" Windows: Open the Ollama app\n\n"
|
|
570
|
+
"Don't have Ollama installed?\n"
|
|
571
|
+
" Download from: https://ollama.ai/download\n\n"
|
|
572
|
+
"After starting Ollama, run your code again.\n"
|
|
573
|
+
+ "="*60
|
|
574
|
+
)
|
|
575
|
+
|
|
576
|
+
# Check system resources before proceeding
|
|
577
|
+
resources = check_system_resources(model)
|
|
578
|
+
|
|
579
|
+
# Check if model needs to be downloaded
|
|
580
|
+
model_installed = check_ollama_model(model)
|
|
581
|
+
|
|
582
|
+
if not model_installed:
|
|
583
|
+
if not pull_ollama_model(model, auto_confirm=auto_download):
|
|
584
|
+
raise RuntimeError(
|
|
585
|
+
f"Model '{model}' not available. "
|
|
586
|
+
f"To download manually: ollama pull {model}"
|
|
587
|
+
)
|
|
588
|
+
else:
|
|
589
|
+
# Model is installed - still check if it can run
|
|
590
|
+
if resources["warnings"] or not resources["can_run"]:
|
|
591
|
+
print(f"\n{'='*60}")
|
|
592
|
+
print(f" Model '{model}' - System Resource Check")
|
|
593
|
+
print(f"{'='*60}")
|
|
594
|
+
size_estimate = get_ollama_model_size_estimate(model)
|
|
595
|
+
print(f" Model size: {size_estimate}")
|
|
596
|
+
if resources["details"].get("estimated_ram"):
|
|
597
|
+
print(f" RAM required: ~{resources['details']['estimated_ram']}")
|
|
598
|
+
if resources["details"].get("total_ram"):
|
|
599
|
+
print(f" System RAM: {resources['details']['total_ram']}")
|
|
600
|
+
|
|
601
|
+
if resources["warnings"]:
|
|
602
|
+
print(f"\n {'!'*50}")
|
|
603
|
+
for warning in resources["warnings"]:
|
|
604
|
+
print(f" Warning: {warning}")
|
|
605
|
+
print(f" {'!'*50}")
|
|
606
|
+
|
|
607
|
+
if not resources["can_run"]:
|
|
608
|
+
print(f"\n Warning: Model may not run well on this system.")
|
|
609
|
+
print(f" Consider a smaller variant (e.g., '{model}:1b' or '{model}:3b').")
|
|
610
|
+
print(f"{'='*60}")
|
|
611
|
+
|
|
612
|
+
if not auto_download:
|
|
613
|
+
try:
|
|
614
|
+
response = input(f"\n Continue anyway? [y/N]: ").strip().lower()
|
|
615
|
+
if response not in ['y', 'yes']:
|
|
616
|
+
raise RuntimeError(
|
|
617
|
+
f"Model '{model}' may be too large for this system. "
|
|
618
|
+
f"Try a smaller variant like '{model}:3b' or '{model}:1b'."
|
|
619
|
+
)
|
|
620
|
+
except (EOFError, KeyboardInterrupt):
|
|
621
|
+
raise RuntimeError("Operation cancelled by user.")
|
|
622
|
+
|
|
623
|
+
print()
|
|
624
|
+
|
|
625
|
+
# Input normalization
|
|
626
|
+
if not isinstance(input_data, pd.Series):
|
|
627
|
+
input_data = pd.Series(input_data)
|
|
628
|
+
input_data = input_data.dropna().astype("string")
|
|
629
|
+
n = len(input_data)
|
|
630
|
+
if n == 0:
|
|
631
|
+
raise ValueError("input_data is empty after dropping NA.")
|
|
632
|
+
|
|
633
|
+
# Auto-adjust divisions for small datasets
|
|
634
|
+
original_divisions = divisions
|
|
635
|
+
divisions = min(divisions, max(1, n // 3))
|
|
636
|
+
if divisions != original_divisions:
|
|
637
|
+
print(f"Auto-adjusted divisions from {original_divisions} to {divisions} for {n} responses.")
|
|
638
|
+
|
|
639
|
+
# Chunk sizing
|
|
640
|
+
chunk_size = int(round(max(1, n / divisions), 0))
|
|
641
|
+
if chunk_size < (categories_per_chunk / 2):
|
|
642
|
+
old_categories_per_chunk = categories_per_chunk
|
|
643
|
+
categories_per_chunk = max(3, chunk_size * 2)
|
|
644
|
+
print(f"Auto-adjusted categories_per_chunk from {old_categories_per_chunk} to {categories_per_chunk} for chunk size {chunk_size}.")
|
|
645
|
+
|
|
646
|
+
print(f"Exploring categories for question: '{survey_question}'")
|
|
647
|
+
print(f"Using provider: {provider}, model: {model}")
|
|
648
|
+
if focus:
|
|
649
|
+
print(f"Focus: {focus}")
|
|
650
|
+
print(f" {categories_per_chunk * divisions * iterations} total category extractions across {iterations} iterations.")
|
|
651
|
+
print(f" Top {max_categories} categories will be identified.\n")
|
|
652
|
+
|
|
653
|
+
# RNG for reproducible re-sampling across passes
|
|
654
|
+
rng = np.random.default_rng(random_state)
|
|
655
|
+
|
|
656
|
+
# Initialize unified client
|
|
657
|
+
client = UnifiedLLMClient(provider=provider, api_key=api_key, model=model)
|
|
658
|
+
|
|
659
|
+
# Build system message
|
|
660
|
+
if research_question:
|
|
661
|
+
system_content = (
|
|
662
|
+
f"You are a helpful assistant that extracts categories from text responses. "
|
|
663
|
+
f"The specific task is to identify {specificity} categories of responses to a text prompt. "
|
|
664
|
+
f"The research question is: {research_question}"
|
|
665
|
+
)
|
|
666
|
+
else:
|
|
667
|
+
system_content = "You are a helpful assistant that extracts categories from text responses."
|
|
668
|
+
|
|
669
|
+
def make_prompt(responses_blob: str) -> str:
|
|
670
|
+
focus_text = f" Focus specifically on {focus}." if focus else ""
|
|
671
|
+
return (
|
|
672
|
+
f'Identify {categories_per_chunk} {specificity} categories of responses to the question "{survey_question}" '
|
|
673
|
+
f"in the following list of responses.{focus_text} Responses are separated by semicolons. "
|
|
674
|
+
f"Responses are within triple backticks: ```{responses_blob}``` "
|
|
675
|
+
f"Number your categories from 1 through {categories_per_chunk} and provide concise labels only (no descriptions)."
|
|
676
|
+
)
|
|
677
|
+
|
|
678
|
+
# Parse numbered list
|
|
679
|
+
line_pat = re.compile(r"^\s*\d+\s*[\.\)\-]\s*(.+)$")
|
|
680
|
+
|
|
681
|
+
all_items = []
|
|
682
|
+
|
|
683
|
+
# Calculate total steps for progress tracking: (iterations * divisions) + 1 for final merge
|
|
684
|
+
total_steps = (iterations * divisions) + 1
|
|
685
|
+
current_step = 0
|
|
686
|
+
|
|
687
|
+
for pass_idx in range(iterations):
|
|
688
|
+
random_chunks = []
|
|
689
|
+
for _ in range(divisions):
|
|
690
|
+
seed = int(rng.integers(0, 2**32 - 1))
|
|
691
|
+
chunk = input_data.sample(n=chunk_size, random_state=seed).tolist()
|
|
692
|
+
random_chunks.append(chunk)
|
|
693
|
+
|
|
694
|
+
for i in tqdm(range(divisions), desc=f"Processing chunks (pass {pass_idx+1}/{iterations})"):
|
|
695
|
+
survey_participant_chunks = "; ".join(str(x) for x in random_chunks[i])
|
|
696
|
+
prompt = make_prompt(survey_participant_chunks)
|
|
697
|
+
|
|
698
|
+
messages = [
|
|
699
|
+
{"role": "system", "content": system_content},
|
|
700
|
+
{"role": "user", "content": prompt}
|
|
701
|
+
]
|
|
702
|
+
|
|
703
|
+
reply, error = client.complete(
|
|
704
|
+
messages=messages,
|
|
705
|
+
creativity=creativity,
|
|
706
|
+
force_json=False, # Text response, not JSON
|
|
707
|
+
)
|
|
708
|
+
|
|
709
|
+
if error:
|
|
710
|
+
raise RuntimeError(
|
|
711
|
+
f"Model call failed on pass {pass_idx+1}, chunk {i+1}: {error}"
|
|
712
|
+
)
|
|
713
|
+
|
|
714
|
+
items = []
|
|
715
|
+
for raw_line in (reply or "").splitlines():
|
|
716
|
+
m = line_pat.match(raw_line.strip())
|
|
717
|
+
if m:
|
|
718
|
+
items.append(m.group(1).strip())
|
|
719
|
+
if not items:
|
|
720
|
+
for raw_line in (reply or "").splitlines():
|
|
721
|
+
s = raw_line.strip()
|
|
722
|
+
if s:
|
|
723
|
+
items.append(s)
|
|
724
|
+
|
|
725
|
+
all_items.extend(items)
|
|
726
|
+
|
|
727
|
+
# Progress callback
|
|
728
|
+
current_step += 1
|
|
729
|
+
if progress_callback:
|
|
730
|
+
progress_callback(current_step, total_steps, f"Pass {pass_idx+1}/{iterations}, chunk {i+1}/{divisions}")
|
|
731
|
+
|
|
732
|
+
# Per-chunk delay to avoid rate limits
|
|
733
|
+
if chunk_delay > 0:
|
|
734
|
+
time.sleep(chunk_delay)
|
|
735
|
+
|
|
736
|
+
# Early return for raw output (used by explore())
|
|
737
|
+
if return_raw:
|
|
738
|
+
return all_items
|
|
739
|
+
|
|
740
|
+
# Normalize and count
|
|
741
|
+
def normalize_category(cat):
|
|
742
|
+
terms = sorted([t.strip().lower() for t in str(cat).split("/")])
|
|
743
|
+
return "/".join(terms)
|
|
744
|
+
|
|
745
|
+
flat_list = [str(x).strip() for x in all_items if str(x).strip()]
|
|
746
|
+
if not flat_list:
|
|
747
|
+
raise ValueError("No categories were extracted from the model responses.")
|
|
748
|
+
|
|
749
|
+
df = pd.DataFrame(flat_list, columns=["Category"])
|
|
750
|
+
df["normalized"] = df["Category"].map(normalize_category)
|
|
751
|
+
|
|
752
|
+
result = (
|
|
753
|
+
df.groupby("normalized")
|
|
754
|
+
.agg(Category=("Category", lambda x: x.value_counts().index[0]),
|
|
755
|
+
counts=("Category", "size"))
|
|
756
|
+
.sort_values("counts", ascending=False)
|
|
757
|
+
.reset_index(drop=True)
|
|
758
|
+
)
|
|
759
|
+
|
|
760
|
+
# Second-pass semantic merge prompt
|
|
761
|
+
seed_list = result["Category"].head(max_categories * 3).tolist()
|
|
762
|
+
|
|
763
|
+
second_prompt = f"""
|
|
764
|
+
You are a data analyst reviewing categorized text data.
|
|
765
|
+
|
|
766
|
+
Task: From the provided categories, identify and return the top {max_categories} CONCEPTUALLY UNIQUE categories.
|
|
767
|
+
|
|
768
|
+
Critical Instructions:
|
|
769
|
+
1) Exact duplicates are already removed.
|
|
770
|
+
2) Merge SEMANTIC duplicates (same concept, different wording). Examples:
|
|
771
|
+
- "closer to work" = "commute/proximity to work"
|
|
772
|
+
- "breakup/household conflict" = "relationship problems"
|
|
773
|
+
3) When merging:
|
|
774
|
+
- Combine frequencies mentally
|
|
775
|
+
- Keep the most frequent OR clearest label
|
|
776
|
+
- Each concept appears ONLY ONCE
|
|
777
|
+
4) Keep category names {specificity}.
|
|
778
|
+
5) Return ONLY a numbered list of {max_categories} categories. No extra text.
|
|
779
|
+
|
|
780
|
+
Pre-processed Categories (sorted by frequency, top sample):
|
|
781
|
+
{seed_list}
|
|
782
|
+
|
|
783
|
+
Output:
|
|
784
|
+
1. category
|
|
785
|
+
2. category
|
|
786
|
+
...
|
|
787
|
+
{max_categories}. category
|
|
788
|
+
""".strip()
|
|
789
|
+
|
|
790
|
+
# Second pass call
|
|
791
|
+
reply2, error2 = client.complete(
|
|
792
|
+
messages=[{"role": "user", "content": second_prompt}],
|
|
793
|
+
creativity=creativity,
|
|
794
|
+
force_json=False, # Text response
|
|
795
|
+
)
|
|
796
|
+
|
|
797
|
+
# Final progress callback for the merge step
|
|
798
|
+
if progress_callback:
|
|
799
|
+
progress_callback(total_steps, total_steps, "Merging categories")
|
|
800
|
+
|
|
801
|
+
if error2:
|
|
802
|
+
print(f"Warning: Second pass failed: {error2}")
|
|
803
|
+
top_categories_text = ""
|
|
804
|
+
else:
|
|
805
|
+
top_categories_text = reply2 or ""
|
|
806
|
+
|
|
807
|
+
final = []
|
|
808
|
+
for line in top_categories_text.splitlines():
|
|
809
|
+
m = line_pat.match(line.strip())
|
|
810
|
+
if m:
|
|
811
|
+
final.append(m.group(1).strip())
|
|
812
|
+
if not final:
|
|
813
|
+
final = [l.strip("-* ").strip() for l in top_categories_text.splitlines() if l.strip()]
|
|
814
|
+
|
|
815
|
+
# Fallback to counts_df if second pass failed
|
|
816
|
+
if not final:
|
|
817
|
+
final = result["Category"].head(max_categories).tolist()
|
|
818
|
+
|
|
819
|
+
print("\nTop categories:\n" + "\n".join(f"{i+1}. {c}" for i, c in enumerate(final[:max_categories])))
|
|
820
|
+
|
|
821
|
+
if filename:
|
|
822
|
+
result.to_csv(filename, index=False)
|
|
823
|
+
print(f"\nResults saved to {filename}")
|
|
824
|
+
|
|
825
|
+
return {
|
|
826
|
+
"counts_df": result,
|
|
827
|
+
"top_categories": final[:max_categories],
|
|
828
|
+
"raw_top_text": top_categories_text
|
|
829
|
+
}
|
|
830
|
+
|
|
831
|
+
|
|
832
|
+
# =============================================================================
|
|
833
|
+
# Main Classification Function
|
|
834
|
+
# =============================================================================
|
|
835
|
+
|
|
836
|
+
def _classification_json_to_row(json_str):
    """Parse one extracted JSON string into a single-row DataFrame.

    Anything that is not a JSON object becomes the error sentinel row
    ``{"1": "e"}``. That covers invalid JSON, ``None`` (e.g. a failed model
    call), a JSON list and a JSON scalar. The sentinel guarantees exactly one
    row per classified input, so the normalized rows stay aligned with
    ``input_data``.
    """
    try:
        parsed = json.loads(json_str)
    except (json.JSONDecodeError, TypeError):
        parsed = None
    if not isinstance(parsed, dict):
        return pd.DataFrame({"1": ["e"]})
    return pd.json_normalize(parsed)


def _assemble_classification_frame(inputs, results, extracted_jsons):
    """Combine inputs, raw replies, errors and parsed JSON into one DataFrame.

    Args:
        inputs: Sequence/Series of the inputs processed so far (same length as
            ``results``).
        results: List of ``(model_response, error)`` tuples.
        extracted_jsons: List of JSON strings, one per entry in ``results``.

    Returns:
        DataFrame with ``input_data``, ``model_response``, ``error`` and
        ``json`` columns. These are followed by one column per JSON key, and
        purely numeric keys are renamed to ``category_<key>``.
    """
    frame = pd.DataFrame({
        'input_data': pd.Series(inputs).reset_index(drop=True),
        'model_response': [r[0] for r in results],
        'error': [r[1] for r in results],
        'json': pd.Series(extracted_jsons).reset_index(drop=True),
    })
    rows = [_classification_json_to_row(s) for s in extracted_jsons]
    if rows:
        frame = pd.concat([frame, pd.concat(rows, ignore_index=True)], axis=1)
    return frame.rename(columns=lambda x: f'category_{x}' if str(x).isdigit() else x)


def multi_class(
    input_data,
    categories,
    api_key: str = None,
    model: str = "gpt-4o",
    provider: str = "auto",
    survey_question: str = "",
    example1: str = None,
    example2: str = None,
    example3: str = None,
    example4: str = None,
    example5: str = None,
    example6: str = None,
    creativity: float = None,
    safety: bool = False,
    chain_of_verification: bool = False,
    chain_of_thought: bool = False,
    step_back_prompt: bool = False,
    context_prompt: bool = False,
    thinking_budget: int = 0,
    max_categories: int = 12,
    categories_per_chunk: int = 10,
    divisions: int = 10,
    research_question: str = None,
    use_json_schema: bool = True,
    filename: str = None,
    save_directory: str = None,
    auto_download: bool = False,
):
    """
    Multi-class text classification using a unified HTTP-based approach.

    This function uses raw HTTP requests for all providers, eliminating SDK dependencies.
    Supports multiple prompting strategies including chain-of-thought, chain-of-verification,
    step-back prompting, and context prompting.

    Args:
        input_data: List or Series of text responses to classify
        categories: List of category names, or "auto" to auto-detect categories
        api_key: API key for the LLM provider (not required for Ollama)
        model: Model name (e.g., "gpt-4o", "claude-sonnet-4-5-20250929", "gemini-2.5-flash",
               or any Ollama model like "llama3.2", "mistral", "phi3")
        provider: Provider name or "auto" to detect from model name.
                  For local models, use provider="ollama"
        survey_question: Optional context about what question was asked
                         (required for categories="auto" and step_back_prompt)
        example1-6: Optional few-shot examples for classification
        creativity: Temperature setting (None for provider default)
        safety: If True, saves results incrementally during processing
                (requires filename)
        chain_of_verification: If True, uses 4-step CoVe prompting for verification
        chain_of_thought: If True, uses step-by-step reasoning in prompt
        step_back_prompt: If True, first asks about underlying factors before classifying
        context_prompt: If True, adds expert context prefix to prompts
        thinking_budget: Token budget for Google's extended thinking (0 to disable)
        max_categories: Maximum categories when using auto-detection
        categories_per_chunk: Categories per chunk for auto-detection
        divisions: Number of divisions for auto-detection
        research_question: Research context for auto-detection
        use_json_schema: Whether to use strict JSON schema (vs just json_object mode)
        filename: Optional CSV filename to save results
        save_directory: Optional directory for safety saves
        auto_download: If True, automatically download missing Ollama models

    Note:
        With provider="ollama" a two-step classify/format pipeline is used.
        That path does not apply chain_of_verification, chain_of_thought,
        step_back_prompt, context_prompt or the few-shot examples; a warning
        is emitted when any of them is set.

    Returns:
        DataFrame with classification results. Columns are: input_data,
        model_response, error, json, category_<n> (nullable Int64),
        processing_status ('success'/'error'), and categories_id (the
        comma-joined category values of the row).

    Raises:
        ValueError: missing api_key for a hosted provider, or auto-detection
            returned no categories.
        TypeError: survey_question missing when required, or filename missing
            with safety=True.
        ConnectionError / RuntimeError: Ollama not running / model unavailable.

    Example with Ollama (local):
        results = multi_class(
            input_data=["I moved for work"],
            categories=["Employment", "Family"],
            model="llama3.2",
            provider="ollama",
        )

    Example with cloud provider:
        results = multi_class(
            input_data=["I moved for work"],
            categories=["Employment", "Family"],
            api_key="your-api-key",
            model="gpt-4o",
        )

    Example with chain-of-verification:
        results = multi_class(
            input_data=["I moved for work"],
            categories=["Employment", "Family"],
            api_key="your-api-key",
            model="gpt-4o",
            chain_of_verification=True,
            survey_question="Why did you move?",
        )

    .. deprecated::
        Use :func:`cat_stack.classify` instead. This function will be removed in a future version.
    """
    import os

    warnings.warn(
        "multi_class() is deprecated and will be removed in a future version. "
        "Use cat_stack.classify() instead, which supports single and multi-model classification.",
        DeprecationWarning,
        stacklevel=2,
    )

    # Detect provider
    provider = detect_provider(model, provider)

    # Validate api_key requirement
    if provider not in ("ollama", "claude-code") and not api_key:
        raise ValueError(f"api_key is required for provider '{provider}'")

    # Validate the safety-save configuration up front, before any API call is spent.
    if safety and filename is None:
        raise TypeError("filename is required when using safety=True. Please provide a filename to save to.")

    # Normalise the input once to a 0..n-1 indexed Series. Lists and Series
    # then behave the same everywhere below: auto-detection needs Series
    # methods such as .sample(), and positional slicing is used for safety saves.
    inputs = pd.Series(input_data).reset_index(drop=True)

    # Handle categories="auto" - auto-detect categories from the data
    if isinstance(categories, str) and categories == "auto":
        if not survey_question:
            raise TypeError("survey_question is required when using categories='auto'. Please provide the survey question you are analyzing.")

        explored = explore_common_categories(
            survey_question=survey_question,
            input_data=inputs,
            research_question=research_question,
            api_key=api_key,
            model_source=provider,
            user_model=model,
            max_categories=max_categories,
            categories_per_chunk=categories_per_chunk,
            divisions=divisions
        )
        # The exploration helper may return a summary dict
        # ({"counts_df", "top_categories", "raw_top_text"}) rather than a
        # plain list. Only the category names are wanted here; iterating the
        # dict itself would yield its keys as "categories".
        if isinstance(explored, dict):
            explored = explored.get("top_categories") or []
        categories = list(explored)
        if not categories:
            raise ValueError("Automatic category detection returned no categories.")

    # Build examples text for few-shot prompting
    examples = [example1, example2, example3, example4, example5, example6]
    examples_text = "\n".join(
        f"Example {i}: {ex}" for i, ex in enumerate(examples, 1) if ex is not None
    )

    # Survey question context
    survey_question_context = f"Context: {survey_question}." if survey_question else ""

    # Step-back insight initialization
    stepback_question = None
    stepback_insight = None
    step_back_added = False
    if step_back_prompt:
        if not survey_question:
            raise TypeError("survey_question is required when using step_back_prompt. Please provide the survey question you are analyzing.")

        stepback_question = f'What are the underlying factors or dimensions that explain how people typically answer "{survey_question}"?'
        stepback_insight, step_back_added = _get_stepback_insight(
            provider, stepback_question, api_key, model, creativity
        )

    # Ollama-specific checks
    if provider == "ollama":
        if not check_ollama_running():
            banner = "=" * 60
            # Explicit "+" is used throughout. With implicit literal
            # concatenation, `"..." "=" * 60` would repeat the whole joined
            # literal 60 times.
            raise ConnectionError(
                "\n" + banner + "\n"
                + " OLLAMA NOT RUNNING\n"
                + banner + "\n\n"
                + "Ollama must be running to use local models.\n\n"
                + "To start Ollama:\n"
                + " macOS: Open the Ollama app, or run 'ollama serve'\n"
                + " Linux: Run 'ollama serve' in terminal\n"
                + " Windows: Open the Ollama app\n\n"
                + "Don't have Ollama installed?\n"
                + " Download from: https://ollama.ai/download\n\n"
                + "After starting Ollama, run your code again.\n"
                + banner
            )

        # Check system resources before proceeding
        resources = check_system_resources(model)

        # Check if model needs to be downloaded
        model_installed = check_ollama_model(model)

        if not model_installed:
            if not pull_ollama_model(model, auto_confirm=auto_download):
                raise RuntimeError(
                    f"Model '{model}' not available. "
                    f"To download manually: ollama pull {model}"
                )
        else:
            # Model is installed - still check if it can run
            if resources["warnings"] or not resources["can_run"]:
                print(f"\n{'='*60}")
                print(f" Model '{model}' - System Resource Check")
                print(f"{'='*60}")
                size_estimate = get_ollama_model_size_estimate(model)
                print(f" Model size: {size_estimate}")
                if resources["details"].get("estimated_ram"):
                    print(f" RAM required: ~{resources['details']['estimated_ram']}")
                if resources["details"].get("total_ram"):
                    print(f" System RAM: {resources['details']['total_ram']}")

                if resources["warnings"]:
                    print(f"\n {'!'*50}")
                    for warning in resources["warnings"]:
                        print(f" Warning: {warning}")
                    print(f" {'!'*50}")

                if not resources["can_run"]:
                    print(f"\n Warning: Model may not run well on this system.")
                    print(f" Consider a smaller variant (e.g., '{model}:1b' or '{model}:3b').")
                print(f"{'='*60}")

                if not auto_download:
                    try:
                        response = input(f"\n Continue anyway? [y/N]: ").strip().lower()
                        if response not in ['y', 'yes']:
                            raise RuntimeError(
                                f"Model '{model}' may be too large for this system. "
                                f"Try a smaller variant like '{model}:3b' or '{model}:1b'."
                            )
                    except (EOFError, KeyboardInterrupt):
                        raise RuntimeError("Operation cancelled by user.")

                print()

    print(f"Using provider: {provider}, model: {model}")

    # Initialize client
    client = UnifiedLLMClient(provider=provider, api_key=api_key, model=model)

    # Build category string and schema
    categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
    # Build JSON schema - Google doesn't support additionalProperties
    if use_json_schema:
        include_additional = (provider != "google")
        json_schema = build_json_schema(categories, include_additional_properties=include_additional)
    else:
        json_schema = None

    # Print categories
    print(f"\nCategories to classify ({len(categories)} total):")
    for i, cat in enumerate(categories, 1):
        print(f" {i}. {cat}")
    print()

    # Build prompt template
    def build_prompt(response_text: str) -> tuple:
        """Build the classification prompt for a single response.

        Returns:
            tuple: (messages list, user_prompt string for CoVe)
        """
        if chain_of_thought:
            user_prompt = f"""{survey_question_context}

Categorize this text response "{response_text}" into the following categories that apply:
{categories_str}

Let's think step by step:
1. First, identify the main themes mentioned in the response
2. Then, match each theme to the relevant categories
3. Finally, assign 1 to matching categories and 0 to non-matching categories

{examples_text}

Provide your answer in JSON format where the category number is the key and "1" if present, "0" if not."""
        else:
            user_prompt = f"""{survey_question_context}
Categorize this text response "{response_text}" into the following categories that apply:
{categories_str}
{examples_text}
Provide your answer in JSON format where the category number is the key and "1" if present, "0" if not."""

        # Add context prompt prefix if enabled
        if context_prompt:
            context = """You are an expert researcher in text data categorization.
Apply multi-label classification and base decisions on explicit and implicit meanings.
When uncertain, prioritize precision over recall.

"""
            user_prompt = context + user_prompt

        # Build messages list
        messages = []

        # Replay the step-back exchange as prior conversation turns, when available
        if step_back_prompt and step_back_added and stepback_insight:
            messages.append({"role": "user", "content": stepback_question})
            messages.append({"role": "assistant", "content": stepback_insight})

        messages.append({"role": "user", "content": user_prompt})

        return messages, user_prompt

    # Build chain of verification prompts
    def build_cove_prompts(prompt: str, response_text: str) -> tuple:
        """Build the step 2-4 chain-of-verification prompt templates.

        The <<INITIAL_REPLY>>, <<QUESTION>> and <<VERIFICATION_QA>> placeholders
        are filled in later by run_chain_of_verification.
        """
        step2_prompt = f"""You provided this initial categorization:
<<INITIAL_REPLY>>

Original task: {prompt}

Generate a focused list of 3-5 verification questions to fact-check your categorization. Each question should:
- Be concise and specific (one sentence)
- Address a distinct aspect of the categorization
- Be answerable independently

Focus on verifying:
- Whether each category assignment is accurate
- Whether the categories match the criteria in the original task
- Whether there are any logical inconsistencies

Provide only the verification questions as a numbered list."""

        step3_prompt = f"""Answer the following verification question based on the text response provided.

Text response: {response_text}

Verification question: <<QUESTION>>

Provide a brief, direct answer (1-2 sentences maximum).

Answer:"""

        step4_prompt = f"""Original task: {prompt}
Initial categorization:
<<INITIAL_REPLY>>
Verification questions and answers:
<<VERIFICATION_QA>>
If no categories are present, assign "0" to all categories.
Provide the final corrected categorization in the same JSON format:"""

        return step2_prompt, step3_prompt, step4_prompt

    def remove_numbering(line: str) -> str:
        """Remove numbering/bullets from a line for CoVe question parsing."""
        line = line.strip()
        if line.startswith('- '):
            return line[2:].strip()
        if line.startswith('• '):
            return line[2:].strip()
        if line and line[0].isdigit():
            i = 0
            while i < len(line) and line[i].isdigit():
                i += 1
            if i < len(line) and line[i] in '.)':
                return line[i+1:].strip()
        return line

    def run_chain_of_verification(initial_reply: str, step2_prompt: str, step3_prompt: str, step4_prompt: str) -> str:
        """Run chain of verification using the unified client.

        Falls back to ``initial_reply`` if generating the questions or the
        final categorization fails.
        """
        # Step 2: Generate verification questions (text response, not JSON)
        step2_filled = step2_prompt.replace("<<INITIAL_REPLY>>", initial_reply)
        questions_reply, err = client.complete(
            messages=[{"role": "user", "content": step2_filled}],
            creativity=creativity,
            force_json=False,  # Text response
        )
        if err:
            return initial_reply  # Fall back to initial reply on error

        # Parse questions (a None reply is treated as "no questions")
        questions = [remove_numbering(line) for line in (questions_reply or "").strip().split('\n') if line.strip()]

        # Step 3: Answer each verification question (text responses)
        qa_pairs = []
        for question in questions[:5]:  # Limit to 5 questions
            step3_filled = step3_prompt.replace("<<QUESTION>>", question)
            answer_reply, err = client.complete(
                messages=[{"role": "user", "content": step3_filled}],
                creativity=creativity,
                force_json=False,  # Text response
            )
            if not err:
                qa_pairs.append(f"Q: {question}\nA: {(answer_reply or '').strip()}")

        verification_qa = "\n\n".join(qa_pairs)

        # Step 4: Final corrected categorization (JSON response)
        step4_filled = step4_prompt.replace("<<INITIAL_REPLY>>", initial_reply).replace("<<VERIFICATION_QA>>", verification_qa)
        final_reply, err = client.complete(
            messages=[{"role": "user", "content": step4_filled}],
            json_schema=json_schema,
            creativity=creativity,
        )

        if err or final_reply is None:
            return initial_reply
        return final_reply

    # Process each response
    results = []
    extracted_jsons = []
    error_json = '{"1":"e"}'  # sentinel row: marks a response as failed

    # Use two-step approach for Ollama (more reliable JSON output)
    use_two_step = (provider == "ollama")

    if use_two_step:
        print("Using two-step classification for Ollama (classify -> format JSON)")
        # ollama_two_step_classify is not given these options, so tell the
        # caller rather than silently dropping them.
        ignored = [
            name for name, enabled in (
                ("chain_of_verification", chain_of_verification),
                ("chain_of_thought", chain_of_thought),
                ("step_back_prompt", step_back_prompt),
                ("context_prompt", context_prompt),
                ("example1-6", bool(examples_text)),
            ) if enabled
        ]
        if ignored:
            warnings.warn(
                "The following options are not applied in Ollama two-step "
                f"classification: {', '.join(ignored)}",
                UserWarning,
                stacklevel=2,
            )

    # Resolve the safety-save path once instead of on every iteration.
    safety_path = filename
    if safety and save_directory:
        os.makedirs(save_directory, exist_ok=True)
        safety_path = os.path.join(save_directory, filename)

    for response in tqdm(inputs, desc="Classifying responses"):
        if pd.isna(response):
            results.append(("Skipped NaN", "Skipped NaN input"))
            extracted_jsons.append(error_json)
            continue

        if use_two_step:
            json_result, error = ollama_two_step_classify(
                client=client,
                response_text=response,
                categories=categories,
                categories_str=categories_str,
                survey_question=survey_question,
                creativity=creativity,
                max_retries=5,
            )
            results.append((json_result, error if error else None))
            # json_result may be None on failure; the parser maps that to the error row.
            extracted_jsons.append(json_result)

        else:
            messages, user_prompt = build_prompt(response)
            reply, error = client.complete(
                messages=messages,
                json_schema=json_schema,
                creativity=creativity,
                thinking_budget=thinking_budget if provider == "google" else None,
            )

            if error:
                results.append((None, error))
                extracted_jsons.append(error_json)
            else:
                # Apply chain of verification if enabled
                if chain_of_verification and reply:
                    step2, step3, step4 = build_cove_prompts(user_prompt, response)
                    reply = run_chain_of_verification(reply, step2, step3, step4)

                results.append((reply, None))
                extracted_jsons.append(extract_json(reply))

        # Safety incremental save: rewrite the partial results after every response
        if safety:
            partial_df = _assemble_classification_frame(
                inputs.iloc[:len(results)], results, extracted_jsons
            )
            partial_df.to_csv(safety_path, index=False)

    # Build output DataFrame
    df = _assemble_classification_frame(inputs, results, extracted_jsons)

    # Process category columns
    cat_cols = [col for col in df.columns if str(col).startswith('category_')]

    # Identify invalid rows: any non-null category value that is not numeric
    if cat_cols:
        has_invalid = df[cat_cols].apply(
            lambda col: pd.to_numeric(col, errors='coerce').isna() & col.notna()
        ).any(axis=1)
    else:
        has_invalid = pd.Series(False, index=df.index)

    df['processing_status'] = (~has_invalid).map({True: 'success', False: 'error'})

    if cat_cols:
        df.loc[has_invalid, cat_cols] = pd.NA

        # Convert to numeric
        for col in cat_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce')

        # Fill NaN with 0 for valid rows (a missing key means "not present")
        df.loc[~has_invalid, cat_cols] = df.loc[~has_invalid, cat_cols].fillna(0)

        # Convert to Int64 (nullable, so invalid rows keep <NA>)
        df[cat_cols] = df[cat_cols].astype('Int64')

        # Create categories_id
        df['categories_id'] = df[cat_cols].apply(
            lambda x: ','.join(x.dropna().astype(int).astype(str)), axis=1
        )
    else:
        df['categories_id'] = ''

    if filename:
        df.to_csv(filename, index=False)
        print(f"\nResults saved to {filename}")

    return df
|
|
1355
|
+
|
|
1356
|
+
|
|
1357
|
+
# Note: chain_of_verification, chain_of_thought, step_back_prompt and
# context_prompt are implemented directly in multi_class() above (they are not
# applied on the Ollama two-step path). There is no text_functions_old.py
# module in this package.
|