cat-llm 0.0.67__py3-none-any.whl → 0.0.69__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cat_llm-0.0.67.dist-info → cat_llm-0.0.69.dist-info}/METADATA +2 -2
- {cat_llm-0.0.67.dist-info → cat_llm-0.0.69.dist-info}/RECORD +10 -6
- catllm/__about__.py +1 -1
- catllm/calls/CoVe.py +304 -0
- catllm/calls/__init__.py +25 -0
- catllm/calls/all_calls.py +433 -0
- catllm/model_reference_list.py +94 -0
- catllm/text_functions.py +335 -42
- {cat_llm-0.0.67.dist-info → cat_llm-0.0.69.dist-info}/WHEEL +0 -0
- {cat_llm-0.0.67.dist-info → cat_llm-0.0.69.dist-info}/licenses/LICENSE +0 -0
catllm/text_functions.py
CHANGED
@@ -1,3 +1,15 @@
|
|
1
|
+
from .calls.all_calls import (
|
2
|
+
get_stepback_insight_openai,
|
3
|
+
get_stepback_insight_anthropic,
|
4
|
+
get_stepback_insight_google,
|
5
|
+
get_stepback_insight_mistral,
|
6
|
+
chain_of_verification_openai,
|
7
|
+
chain_of_verification_google,
|
8
|
+
chain_of_verification_anthropic,
|
9
|
+
chain_of_verification_mistral
|
10
|
+
)
|
11
|
+
|
12
|
+
|
1
13
|
#extract categories from corpus
|
2
14
|
def explore_corpus(
|
3
15
|
survey_question,
|
@@ -232,24 +244,27 @@ Return the top {top_n} categories as a numbered list sorted from the most to lea
|
|
232
244
|
# GOAL: enable step-back prompting
|
233
245
|
# GOAL 2: enable self-consistency
|
234
246
|
def multi_class(
|
235
|
-
survey_question,
|
236
247
|
survey_input,
|
237
248
|
categories,
|
238
249
|
api_key,
|
239
250
|
user_model="gpt-5",
|
240
251
|
user_prompt = None,
|
252
|
+
survey_question = "",
|
241
253
|
example1 = None,
|
242
254
|
example2 = None,
|
243
255
|
example3 = None,
|
244
256
|
example4 = None,
|
245
257
|
example5 = None,
|
246
258
|
example6 = None,
|
247
|
-
creativity=None,
|
248
|
-
safety=False,
|
249
|
-
to_csv=False,
|
250
|
-
|
251
|
-
|
252
|
-
|
259
|
+
creativity = None,
|
260
|
+
safety = False,
|
261
|
+
to_csv = False,
|
262
|
+
chain_of_verification = False,
|
263
|
+
step_back_prompt = False,
|
264
|
+
context_prompt = False,
|
265
|
+
filename = "categorized_data.csv",
|
266
|
+
save_directory = None,
|
267
|
+
model_source = "auto"
|
253
268
|
):
|
254
269
|
import os
|
255
270
|
import json
|
@@ -257,7 +272,54 @@ def multi_class(
|
|
257
272
|
import regex
|
258
273
|
from tqdm import tqdm
|
259
274
|
|
275
|
+
def remove_numbering(line):
    """Strip a leading list marker from a single line of text.

    Recognised markers are a dash or bullet followed by a space ("- ", "• ")
    and a run of digits immediately followed by "." or ")" ("1.", "10)").
    Surrounding whitespace is always trimmed; a line without a recognised
    marker is returned trimmed but otherwise unchanged.
    """
    text = line.strip()

    # Bullet markers are exactly two characters: the symbol and one space.
    if text.startswith(('- ', '• ')):
        return text[2:].strip()

    # Count how many leading characters are digits.
    digit_count = 0
    for ch in text:
        if not ch.isdigit():
            break
        digit_count += 1

    # Only a digit run directly followed by '.' or ')' counts as numbering.
    if digit_count and digit_count < len(text) and text[digit_count] in ('.', ')'):
        return text[digit_count + 1:].strip()

    return text
|
298
|
+
|
260
299
|
model_source = model_source.lower() # eliminating case sensitivity
|
300
|
+
|
301
|
+
# auto-detect model source if not provided
|
302
|
+
if model_source is None or model_source == "auto":
|
303
|
+
user_model_lower = user_model.lower()
|
304
|
+
|
305
|
+
if "gpt" in user_model_lower:
|
306
|
+
model_source = "openai"
|
307
|
+
elif "claude" in user_model_lower:
|
308
|
+
model_source = "anthropic"
|
309
|
+
elif "gemini" in user_model_lower or "gemma" in user_model_lower:
|
310
|
+
model_source = "google"
|
311
|
+
elif "llama" in user_model_lower or "meta" in user_model_lower:
|
312
|
+
model_source = "huggingface"
|
313
|
+
elif "mistral" in user_model_lower or "mixtral" in user_model_lower:
|
314
|
+
model_source = "mistral"
|
315
|
+
elif "sonar" in user_model_lower or "pplx" in user_model_lower:
|
316
|
+
model_source = "perplexity"
|
317
|
+
elif "deepseek" in user_model_lower or "qwen" in user_model_lower:
|
318
|
+
model_source = "huggingface"
|
319
|
+
else:
|
320
|
+
raise ValueError(f"❌ Could not auto-detect model source from '{user_model}'. Please specify model_source explicitly: OpenAI, Anthropic, Perplexity, Google, Huggingface, or Mistral")
|
321
|
+
else:
|
322
|
+
model_source = model_source.lower()
|
261
323
|
|
262
324
|
categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
|
263
325
|
cat_num = len(categories)
|
@@ -265,17 +327,66 @@ def multi_class(
|
|
265
327
|
example_JSON = json.dumps(category_dict, indent=4)
|
266
328
|
|
267
329
|
# ensure number of categories is what user wants
|
268
|
-
print("\nThe categories you entered:")
|
330
|
+
print(f"\nThe categories you entered to be coded by {model_source} {user_model}:")
|
269
331
|
for i, cat in enumerate(categories, 1):
|
270
332
|
print(f"{i}. {cat}")
|
271
333
|
|
272
334
|
link1 = []
|
273
335
|
extracted_jsons = []
|
336
|
+
|
274
337
|
#handling example inputs
|
275
338
|
examples = [example1, example2, example3, example4, example5, example6]
|
276
339
|
examples_text = "\n".join(
|
277
340
|
f"Example {i}: {ex}" for i, ex in enumerate(examples, 1) if ex is not None
|
278
341
|
)
|
342
|
+
# allowing users to contextualize the survey question
|
343
|
+
if survey_question != None:
|
344
|
+
survey_question_context = f"A respondent was asked: {survey_question}."
|
345
|
+
else:
|
346
|
+
survey_question_context = ""
|
347
|
+
|
348
|
+
# step back insight initialization
|
349
|
+
if step_back_prompt:
|
350
|
+
if survey_question == "": # step back requires the survey question to function well
|
351
|
+
raise TypeError("survey_question is required when using step_back_prompt. Please provide the survey question you are analyzing.")
|
352
|
+
|
353
|
+
stepback = f"""What are the underlying factors or dimensions that explain how people typically answer "{survey_question}"?"""
|
354
|
+
|
355
|
+
if model_source in ["openai", "perplexity", "huggingface"]:
|
356
|
+
stepback_insight, step_back_added = get_stepback_insight_openai(
|
357
|
+
stepback=stepback,
|
358
|
+
api_key=api_key,
|
359
|
+
user_model=user_model,
|
360
|
+
model_source=model_source,
|
361
|
+
creativity=creativity
|
362
|
+
)
|
363
|
+
elif model_source == "anthropic":
|
364
|
+
stepback_insight, step_back_added = get_stepback_insight_anthropic(
|
365
|
+
stepback=stepback,
|
366
|
+
api_key=api_key,
|
367
|
+
user_model=user_model,
|
368
|
+
model_source=model_source,
|
369
|
+
creativity=creativity
|
370
|
+
)
|
371
|
+
elif model_source == "google":
|
372
|
+
stepback_insight, step_back_added = get_stepback_insight_google(
|
373
|
+
stepback=stepback,
|
374
|
+
api_key=api_key,
|
375
|
+
user_model=user_model,
|
376
|
+
model_source=model_source,
|
377
|
+
creativity=creativity
|
378
|
+
)
|
379
|
+
elif model_source == "mistral":
|
380
|
+
stepback_insight, step_back_added = get_stepback_insight_mistral(
|
381
|
+
stepback=stepback,
|
382
|
+
api_key=api_key,
|
383
|
+
user_model=user_model,
|
384
|
+
model_source=model_source,
|
385
|
+
creativity=creativity
|
386
|
+
)
|
387
|
+
else:
|
388
|
+
stepback_insight = None
|
389
|
+
step_back_added = False
|
279
390
|
|
280
391
|
for idx, response in enumerate(tqdm(survey_input, desc="Categorizing responses")):
|
281
392
|
reply = None
|
@@ -287,58 +398,166 @@ def multi_class(
|
|
287
398
|
#print(f"Skipped NaN input.")
|
288
399
|
else:
|
289
400
|
|
290
|
-
prompt = f"""
|
401
|
+
prompt = f"""{survey_question_context} \
|
291
402
|
Categorize this survey response "{response}" into the following categories that apply: \
|
292
403
|
{categories_str}
|
293
404
|
{examples_text}
|
294
|
-
Provide your work in JSON format
|
405
|
+
Provide your work in JSON format where the number belonging to each category is the key and a 1 if the category is present and a 0 if it is not present as key values."""
|
406
|
+
|
407
|
+
if context_prompt:
|
408
|
+
context = """You are an expert researcher in survey data categorization.
|
409
|
+
Apply multi-label classification and base decisions on explicit and implicit meanings.
|
410
|
+
When uncertain, prioritize precision over recall."""
|
411
|
+
|
412
|
+
prompt = context + prompt
|
413
|
+
print(prompt)
|
414
|
+
|
415
|
+
if chain_of_verification:
|
416
|
+
step2_prompt = f"""You provided this initial categorization:
|
417
|
+
<<INITIAL_REPLY>>
|
418
|
+
|
419
|
+
Original task: {prompt}
|
420
|
+
|
421
|
+
Generate a focused list of 3-5 verification questions to fact-check your categorization. Each question should:
|
422
|
+
- Be concise and specific (one sentence)
|
423
|
+
- Address a distinct aspect of the categorization
|
424
|
+
- Be answerable independently
|
295
425
|
|
296
|
-
|
426
|
+
Focus on verifying:
|
427
|
+
- Whether each category assignment is accurate
|
428
|
+
- Whether the categories match the criteria in the original task
|
429
|
+
- Whether there are any logical inconsistencies
|
430
|
+
|
431
|
+
Provide only the verification questions as a numbered list."""
|
432
|
+
|
433
|
+
step3_prompt = f"""Answer the following verification question based on the survey response provided.
|
434
|
+
|
435
|
+
Survey response: {response}
|
436
|
+
|
437
|
+
Verification question: <<QUESTION>>
|
438
|
+
|
439
|
+
Provide a brief, direct answer (1-2 sentences maximum).
|
440
|
+
|
441
|
+
Answer:"""
|
442
|
+
|
443
|
+
|
444
|
+
step4_prompt = f"""Original task: {prompt}
|
445
|
+
Initial categorization:
|
446
|
+
<<INITIAL_REPLY>>
|
447
|
+
Verification questions and answers:
|
448
|
+
<<VERIFICATION_QA>>
|
449
|
+
If no categories are present, assign "0" to all categories.
|
450
|
+
Provide the final corrected categorization in the same JSON format:"""
|
451
|
+
|
452
|
+
# Main model interaction
|
453
|
+
if model_source in ["openai", "perplexity", "huggingface"]:
|
297
454
|
from openai import OpenAI
|
298
|
-
|
455
|
+
from openai import OpenAI, BadRequestError, AuthenticationError
|
456
|
+
# conditional base_url setting based on model source
|
457
|
+
base_url = (
|
458
|
+
"https://api.perplexity.ai" if model_source == "perplexity"
|
459
|
+
else "https://router.huggingface.co/v1" if model_source == "huggingface"
|
460
|
+
else None # default
|
461
|
+
)
|
462
|
+
|
463
|
+
client = OpenAI(api_key=api_key, base_url=base_url)
|
464
|
+
|
299
465
|
try:
|
466
|
+
messages = [
|
467
|
+
*([{'role': 'user', 'content': stepback}] if step_back_prompt and step_back_added else []), # only if step back is enabled and successful
|
468
|
+
*([{'role': 'assistant', 'content': stepback_insight}] if step_back_added else {}), # include insight if step back succeeded
|
469
|
+
{'role': 'user', 'content': prompt}
|
470
|
+
]
|
471
|
+
|
300
472
|
response_obj = client.chat.completions.create(
|
301
473
|
model=user_model,
|
302
|
-
messages=
|
474
|
+
messages=messages,
|
303
475
|
**({"temperature": creativity} if creativity is not None else {})
|
304
|
-
)
|
305
|
-
reply = response_obj.choices[0].message.content
|
306
|
-
link1.append(reply)
|
307
|
-
except Exception as e:
|
308
|
-
print(f"An error occurred: {e}")
|
309
|
-
link1.append(f"Error processing input: {e}")
|
310
|
-
elif model_source == "perplexity":
|
311
|
-
from openai import OpenAI
|
312
|
-
client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
|
313
|
-
try:
|
314
|
-
response_obj = client.chat.completions.create(
|
315
|
-
model=user_model,
|
316
|
-
messages=[{'role': 'user', 'content': prompt}],
|
317
|
-
**({"temperature": creativity} if creativity is not None else {})
|
318
476
|
)
|
477
|
+
|
319
478
|
reply = response_obj.choices[0].message.content
|
320
|
-
|
479
|
+
|
480
|
+
if chain_of_verification:
|
481
|
+
reply = chain_of_verification_openai(
|
482
|
+
initial_reply=reply,
|
483
|
+
step2_prompt=step2_prompt,
|
484
|
+
step3_prompt=step3_prompt,
|
485
|
+
step4_prompt=step4_prompt,
|
486
|
+
client=client,
|
487
|
+
user_model=user_model,
|
488
|
+
creativity=creativity,
|
489
|
+
remove_numbering=remove_numbering
|
490
|
+
)
|
491
|
+
|
492
|
+
link1.append(reply)
|
493
|
+
else:
|
494
|
+
#if chain of verification is not enabled, just append initial reply
|
495
|
+
link1.append(reply)
|
496
|
+
|
497
|
+
except BadRequestError as e:
|
498
|
+
# Model doesn't exist - halt immediately
|
499
|
+
raise ValueError(f"❌ Model '{user_model}' on {model_source} not found. Please check the model name and try again.") from e
|
321
500
|
except Exception as e:
|
322
501
|
print(f"An error occurred: {e}")
|
323
502
|
link1.append(f"Error processing input: {e}")
|
503
|
+
|
324
504
|
elif model_source == "anthropic":
|
505
|
+
|
325
506
|
import anthropic
|
326
507
|
client = anthropic.Anthropic(api_key=api_key)
|
508
|
+
|
327
509
|
try:
|
328
|
-
|
510
|
+
response_obj = client.messages.create(
|
329
511
|
model=user_model,
|
330
|
-
max_tokens=
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
512
|
+
max_tokens=4096,
|
513
|
+
messages=[{'role': 'user', 'content': prompt}],
|
514
|
+
**({"temperature": creativity} if creativity is not None else {})
|
515
|
+
)
|
516
|
+
|
517
|
+
reply = response_obj.content[0].text
|
518
|
+
|
519
|
+
if chain_of_verification:
|
520
|
+
reply = chain_of_verification_anthropic(
|
521
|
+
initial_reply=reply,
|
522
|
+
step2_prompt=step2_prompt,
|
523
|
+
step3_prompt=step3_prompt,
|
524
|
+
step4_prompt=step4_prompt,
|
525
|
+
client=client,
|
526
|
+
user_model=user_model,
|
527
|
+
creativity=creativity,
|
528
|
+
remove_numbering=remove_numbering
|
529
|
+
)
|
530
|
+
|
531
|
+
link1.append(reply)
|
532
|
+
else:
|
533
|
+
#if chain of verification is not enabled, just append initial reply
|
534
|
+
link1.append(reply)
|
535
|
+
|
536
|
+
except anthropic.NotFoundError as e:
|
537
|
+
# Model doesn't exist - halt immediately
|
538
|
+
raise ValueError(f"❌ Model '{user_model}' on {model_source} not found. Please check the model name and try again.") from e
|
336
539
|
except Exception as e:
|
337
540
|
print(f"An error occurred: {e}")
|
338
541
|
link1.append(f"Error processing input: {e}")
|
339
542
|
|
340
543
|
elif model_source == "google":
|
341
544
|
import requests
|
545
|
+
|
546
|
+
def make_google_request(url, headers, payload, max_retries=3):
    """POST ``payload`` as JSON to ``url`` and return the decoded JSON reply.

    A 429 (rate limited) response is retried after an exponentially growing
    pause (10s, 20s, 40s, ...) until ``max_retries`` attempts have been made.

    Args:
        url: Endpoint to POST to.
        headers: HTTP headers sent with the request.
        payload: JSON-serialisable request body.
        max_retries: Total number of attempts, including the first one.

    Returns:
        The parsed JSON body of the first successful response. If
        ``max_retries`` is 0 or negative no request is made and ``None`` is
        returned implicitly.

    Raises:
        requests.exceptions.HTTPError: For any non-429 HTTP error status, or
            for a 429 on the final attempt. Exceptions other than
            ``HTTPError`` are not caught here and propagate unchanged.
    """
    # Imported locally so the back-off path cannot fail with a NameError when
    # the enclosing scope has not imported ``time``.
    import time

    for attempt in range(max_retries):
        try:
            response = requests.post(url, headers=headers, json=payload)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            # Only rate limiting is retried, and only while attempts remain.
            if e.response.status_code == 429 and attempt < max_retries - 1:
                wait_time = 10 * (2 ** attempt)
                print(f"⚠️ Rate limited. Waiting {wait_time}s...")
                time.sleep(wait_time)
            else:
                raise
|
560
|
+
|
342
561
|
url = f"https://generativelanguage.googleapis.com/v1beta/models/{user_model}:generateContent"
|
343
562
|
try:
|
344
563
|
headers = {
|
@@ -353,22 +572,49 @@ def multi_class(
|
|
353
572
|
**({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
|
354
573
|
}
|
355
574
|
|
356
|
-
|
357
|
-
response.raise_for_status() # Raise exception for HTTP errors
|
358
|
-
result = response.json()
|
575
|
+
result = make_google_request(url, headers, payload)
|
359
576
|
|
360
577
|
if "candidates" in result and result["candidates"]:
|
361
578
|
reply = result["candidates"][0]["content"]["parts"][0]["text"]
|
362
579
|
else:
|
363
580
|
reply = "No response generated"
|
364
581
|
|
365
|
-
|
582
|
+
if chain_of_verification:
|
583
|
+
reply = chain_of_verification_google(
|
584
|
+
initial_reply=reply,
|
585
|
+
prompt=prompt,
|
586
|
+
step2_prompt=step2_prompt,
|
587
|
+
step3_prompt=step3_prompt,
|
588
|
+
step4_prompt=step4_prompt,
|
589
|
+
url=url,
|
590
|
+
headers=headers,
|
591
|
+
creativity=creativity,
|
592
|
+
remove_numbering=remove_numbering,
|
593
|
+
make_google_request=make_google_request
|
594
|
+
)
|
595
|
+
|
596
|
+
link1.append(reply)
|
597
|
+
|
598
|
+
else:
|
599
|
+
# if chain of verification is not enabled, just append initial reply
|
600
|
+
link1.append(reply)
|
601
|
+
|
602
|
+
except requests.exceptions.HTTPError as e:
|
603
|
+
if e.response.status_code == 404:
|
604
|
+
raise ValueError(f"❌ Model '{user_model}' not found. Please check the model name and try again.") from e
|
605
|
+
elif e.response.status_code == 401 or e.response.status_code == 403:
|
606
|
+
raise ValueError(f"❌ Authentication failed. Please check your Google API key.") from e
|
607
|
+
else:
|
608
|
+
print(f"HTTP error occurred: {e}")
|
609
|
+
link1.append(f"Error processing input: {e}")
|
366
610
|
except Exception as e:
|
367
611
|
print(f"An error occurred: {e}")
|
368
612
|
link1.append(f"Error processing input: {e}")
|
369
613
|
|
370
614
|
elif model_source == "mistral":
|
371
615
|
from mistralai import Mistral
|
616
|
+
from mistralai.models import SDKError
|
617
|
+
|
372
618
|
client = Mistral(api_key=api_key)
|
373
619
|
try:
|
374
620
|
response = client.chat.complete(
|
@@ -379,12 +625,40 @@ def multi_class(
|
|
379
625
|
**({"temperature": creativity} if creativity is not None else {})
|
380
626
|
)
|
381
627
|
reply = response.choices[0].message.content
|
382
|
-
|
628
|
+
|
629
|
+
if chain_of_verification:
|
630
|
+
reply = chain_of_verification_mistral(
|
631
|
+
initial_reply=reply,
|
632
|
+
step2_prompt=step2_prompt,
|
633
|
+
step3_prompt=step3_prompt,
|
634
|
+
step4_prompt=step4_prompt,
|
635
|
+
client=client,
|
636
|
+
user_model=user_model,
|
637
|
+
creativity=creativity,
|
638
|
+
remove_numbering=remove_numbering
|
639
|
+
)
|
640
|
+
|
641
|
+
link1.append(reply)
|
642
|
+
|
643
|
+
else:
|
644
|
+
#if chain of verification is not enabled, just append initial reply
|
645
|
+
link1.append(reply)
|
646
|
+
|
647
|
+
except SDKError as e:
|
648
|
+
error_str = str(e).lower()
|
649
|
+
if "invalid_model" in error_str or "invalid model" in error_str:
|
650
|
+
raise ValueError(f"❌ Model '{user_model}' not found.") from e
|
651
|
+
elif "401" in str(e) or "unauthorized" in str(e).lower():
|
652
|
+
raise ValueError(f"❌ Authentication failed. Please check your Mistral API key.") from e
|
653
|
+
else:
|
654
|
+
print(f"An error occurred: {e}")
|
655
|
+
link1.append(f"Error processing input: {e}")
|
383
656
|
except Exception as e:
|
384
|
-
print(f"An error occurred: {e}")
|
657
|
+
print(f"An unexpected error occurred: {e}")
|
385
658
|
link1.append(f"Error processing input: {e}")
|
659
|
+
|
386
660
|
else:
|
387
|
-
raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")
|
661
|
+
raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, Google, Huggingface, or Mistral")
|
388
662
|
# in situation that no JSON is found
|
389
663
|
if reply is not None:
|
390
664
|
extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
|
@@ -442,6 +716,25 @@ def multi_class(
|
|
442
716
|
'json': pd.Series(extracted_jsons).reset_index(drop=True)
|
443
717
|
})
|
444
718
|
categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
|
719
|
+
categorized_data = categorized_data.rename(columns=lambda x: f'category_{x}' if str(x).isdigit() else x)
|
720
|
+
|
721
|
+
#converting to numeric
|
722
|
+
cat_cols = [col for col in categorized_data.columns if col.startswith('category_')]
|
723
|
+
|
724
|
+
categorized_data['processing_status'] = np.where(
|
725
|
+
categorized_data[cat_cols].isna().all(axis=1),
|
726
|
+
'error',
|
727
|
+
'success'
|
728
|
+
)
|
729
|
+
|
730
|
+
categorized_data.loc[categorized_data[cat_cols].apply(pd.to_numeric, errors='coerce').isna().any(axis=1), cat_cols] = np.nan
|
731
|
+
categorized_data[cat_cols] = categorized_data[cat_cols].astype('Int64')
|
732
|
+
|
733
|
+
categorized_data['categories_present'] = categorized_data[cat_cols].apply(
|
734
|
+
lambda x: ','.join(x.dropna().astype(str)), axis=1
|
735
|
+
)
|
736
|
+
|
737
|
+
categorized_data['categories_counted'] = categorized_data[cat_cols].count(axis=1)
|
445
738
|
|
446
739
|
if to_csv:
|
447
740
|
if save_directory is None:
|
File without changes
|
File without changes
|