cat-llm 0.0.67__py3-none-any.whl → 0.0.69__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
catllm/text_functions.py CHANGED
@@ -1,3 +1,15 @@
+ from .calls.all_calls import (
+ get_stepback_insight_openai,
+ get_stepback_insight_anthropic,
+ get_stepback_insight_google,
+ get_stepback_insight_mistral,
+ chain_of_verification_openai,
+ chain_of_verification_google,
+ chain_of_verification_anthropic,
+ chain_of_verification_mistral
+ )
+
+
  #extract categories from corpus
  def explore_corpus(
  survey_question,
@@ -232,24 +244,27 @@ Return the top {top_n} categories as a numbered list sorted from the most to lea
  # GOAL: enable step-back prompting
  # GOAL 2: enable self-consistency
  def multi_class(
- survey_question,
  survey_input,
  categories,
  api_key,
  user_model="gpt-5",
  user_prompt = None,
+ survey_question = "",
  example1 = None,
  example2 = None,
  example3 = None,
  example4 = None,
  example5 = None,
  example6 = None,
- creativity=None,
- safety=False,
- to_csv=False,
- filename="categorized_data.csv",
- save_directory=None,
- model_source="OpenAI"
+ creativity = None,
+ safety = False,
+ to_csv = False,
+ chain_of_verification = False,
+ step_back_prompt = False,
+ context_prompt = False,
+ filename = "categorized_data.csv",
+ save_directory = None,
+ model_source = "auto"
  ):
  import os
  import json
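
For orientation, a hypothetical call using the updated signature above; the survey question, categories, and API key are invented placeholders, and the return value (the categorized DataFrame built later in this file) is not shown in this hunk:

    from catllm.text_functions import multi_class

    results = multi_class(
        survey_input=["I ride my bike to work", "Parking downtown is too expensive"],
        categories=["cost", "environment", "convenience"],
        api_key="YOUR_API_KEY",                 # placeholder
        user_model="gpt-5",
        survey_question="Why do you commute the way you do?",
        chain_of_verification=True,             # added in this diff
        step_back_prompt=True,                  # added in this diff
        context_prompt=True,                    # added in this diff
        model_source="auto",                    # new default; detected from user_model
    )
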
@@ -257,7 +272,54 @@ def multi_class(
  import regex
  from tqdm import tqdm

+ def remove_numbering(line):
+ line = line.strip()
+
+ # Handle bullet points
+ if line.startswith('- '):
+ return line[2:].strip()
+ if line.startswith('• '):
+ return line[2:].strip()
+
+ # Handle numbered lists "1.", "10.", etc.
+ if line and line[0].isdigit():
+ # Find where the number ends
+ i = 0
+ while i < len(line) and line[i].isdigit():
+ i += 1
+
+ # Check if followed by '.' or ')'
+ if i < len(line) and line[i] in '.':
+ return line[i+1:].strip()
+ elif i < len(line) and line[i] in ')':
+ return line[i+1:].strip()
+
+ return line
+
  model_source = model_source.lower() # eliminating case sensitivity
+
+ # auto-detect model source if not provided
+ if model_source is None or model_source == "auto":
+ user_model_lower = user_model.lower()
+
+ if "gpt" in user_model_lower:
+ model_source = "openai"
+ elif "claude" in user_model_lower:
+ model_source = "anthropic"
+ elif "gemini" in user_model_lower or "gemma" in user_model_lower:
+ model_source = "google"
+ elif "llama" in user_model_lower or "meta" in user_model_lower:
+ model_source = "huggingface"
+ elif "mistral" in user_model_lower or "mixtral" in user_model_lower:
+ model_source = "mistral"
+ elif "sonar" in user_model_lower or "pplx" in user_model_lower:
+ model_source = "perplexity"
+ elif "deepseek" in user_model_lower or "qwen" in user_model_lower:
+ model_source = "huggingface"
+ else:
+ raise ValueError(f"❌ Could not auto-detect model source from '{user_model}'. Please specify model_source explicitly: OpenAI, Anthropic, Perplexity, Google, Huggingface, or Mistral")
+ else:
+ model_source = model_source.lower()

  categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
  cat_num = len(categories)
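
The auto-detection above routes on substrings of the model name. Restated as a standalone helper purely for illustration (the example model names are not an official list):

    def detect_source(user_model: str) -> str:
        # Mirrors the substring checks added above, in the same order.
        m = user_model.lower()
        if "gpt" in m:
            return "openai"
        if "claude" in m:
            return "anthropic"
        if "gemini" in m or "gemma" in m:
            return "google"
        if "llama" in m or "meta" in m:
            return "huggingface"
        if "mistral" in m or "mixtral" in m:
            return "mistral"
        if "sonar" in m or "pplx" in m:
            return "perplexity"
        if "deepseek" in m or "qwen" in m:
            return "huggingface"
        raise ValueError(f"Could not auto-detect model source from '{user_model}'")

    assert detect_source("gpt-5") == "openai"
    assert detect_source("claude-sonnet-4-5") == "anthropic"
    assert detect_source("gemini-2.0-flash") == "google"
    assert detect_source("sonar-pro") == "perplexity"
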
@@ -265,17 +327,66 @@ def multi_class(
  example_JSON = json.dumps(category_dict, indent=4)

  # ensure number of categories is what user wants
- print("\nThe categories you entered:")
+ print(f"\nThe categories you entered to be coded by {model_source} {user_model}:")
  for i, cat in enumerate(categories, 1):
  print(f"{i}. {cat}")

  link1 = []
  extracted_jsons = []
+
  #handling example inputs
  examples = [example1, example2, example3, example4, example5, example6]
  examples_text = "\n".join(
  f"Example {i}: {ex}" for i, ex in enumerate(examples, 1) if ex is not None
  )
+ # allowing users to contextualize the survey question
+ if survey_question != None:
+ survey_question_context = f"A respondent was asked: {survey_question}."
+ else:
+ survey_question_context = ""
+
+ # step back insight initialization
+ if step_back_prompt:
+ if survey_question == "": # step back requires the survey question to function well
+ raise TypeError("survey_question is required when using step_back_prompt. Please provide the survey question you are analyzing.")
+
+ stepback = f"""What are the underlying factors or dimensions that explain how people typically answer "{survey_question}"?"""
+
+ if model_source in ["openai", "perplexity", "huggingface"]:
+ stepback_insight, step_back_added = get_stepback_insight_openai(
+ stepback=stepback,
+ api_key=api_key,
+ user_model=user_model,
+ model_source=model_source,
+ creativity=creativity
+ )
+ elif model_source == "anthropic":
+ stepback_insight, step_back_added = get_stepback_insight_anthropic(
+ stepback=stepback,
+ api_key=api_key,
+ user_model=user_model,
+ model_source=model_source,
+ creativity=creativity
+ )
+ elif model_source == "google":
+ stepback_insight, step_back_added = get_stepback_insight_google(
+ stepback=stepback,
+ api_key=api_key,
+ user_model=user_model,
+ model_source=model_source,
+ creativity=creativity
+ )
+ elif model_source == "mistral":
+ stepback_insight, step_back_added = get_stepback_insight_mistral(
+ stepback=stepback,
+ api_key=api_key,
+ user_model=user_model,
+ model_source=model_source,
+ creativity=creativity
+ )
+ else:
+ stepback_insight = None
+ step_back_added = False

  for idx, response in enumerate(tqdm(survey_input, desc="Categorizing responses")):
  reply = None
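
The step-back branch above builds a single "step-back" question from the survey question and asks the provider for background insight before categorizing. A small sketch of the question it constructs, using an invented survey item:

    # Hypothetical survey question; the template matches the f-string added above.
    survey_question = "Why did you choose your current mode of commuting?"
    stepback = (
        "What are the underlying factors or dimensions that explain how people "
        f'typically answer "{survey_question}"?'
    )
    print(stepback)

    # Note: multi_class(..., step_back_prompt=True) raises TypeError if
    # survey_question is left at its new default of "".
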
@@ -287,58 +398,166 @@ def multi_class(
  #print(f"Skipped NaN input.")
  else:

- prompt = f"""A respondent was asked: {survey_question}. \
+ prompt = f"""{survey_question_context} \
  Categorize this survey response "{response}" into the following categories that apply: \
  {categories_str}
  {examples_text}
- Provide your work in JSON format..."""
+ Provide your work in JSON format where the number belonging to each category is the key and a 1 if the category is present and a 0 if it is not present as key values."""
+
+ if context_prompt:
+ context = """You are an expert researcher in survey data categorization.
+ Apply multi-label classification and base decisions on explicit and implicit meanings.
+ When uncertain, prioritize precision over recall."""
+
+ prompt = context + prompt
+ print(prompt)
+
+ if chain_of_verification:
+ step2_prompt = f"""You provided this initial categorization:
+ <<INITIAL_REPLY>>
+
+ Original task: {prompt}
+
+ Generate a focused list of 3-5 verification questions to fact-check your categorization. Each question should:
+ - Be concise and specific (one sentence)
+ - Address a distinct aspect of the categorization
+ - Be answerable independently

- if model_source == ("openai"):
+ Focus on verifying:
+ - Whether each category assignment is accurate
+ - Whether the categories match the criteria in the original task
+ - Whether there are any logical inconsistencies
+
+ Provide only the verification questions as a numbered list."""
+
+ step3_prompt = f"""Answer the following verification question based on the survey response provided.
+
+ Survey response: {response}
+
+ Verification question: <<QUESTION>>
+
+ Provide a brief, direct answer (1-2 sentences maximum).
+
+ Answer:"""
+
+
+ step4_prompt = f"""Original task: {prompt}
+ Initial categorization:
+ <<INITIAL_REPLY>>
+ Verification questions and answers:
+ <<VERIFICATION_QA>>
+ If no categories are present, assign "0" to all categories.
+ Provide the final corrected categorization in the same JSON format:"""
+
+ # Main model interaction
+ if model_source in ["openai", "perplexity", "huggingface"]:
  from openai import OpenAI
- client = OpenAI(api_key=api_key)
+ from openai import OpenAI, BadRequestError, AuthenticationError
+ # conditional base_url setting based on model source
+ base_url = (
+ "https://api.perplexity.ai" if model_source == "perplexity"
+ else "https://router.huggingface.co/v1" if model_source == "huggingface"
+ else None # default
+ )
+
+ client = OpenAI(api_key=api_key, base_url=base_url)
+
  try:
+ messages = [
+ *([{'role': 'user', 'content': stepback}] if step_back_prompt and step_back_added else []), # only if step back is enabled and successful
+ *([{'role': 'assistant', 'content': stepback_insight}] if step_back_added else {}), # include insight if step back succeeded
+ {'role': 'user', 'content': prompt}
+ ]
+
  response_obj = client.chat.completions.create(
  model=user_model,
- messages=[{'role': 'user', 'content': prompt}],
+ messages=messages,
  **({"temperature": creativity} if creativity is not None else {})
- )
- reply = response_obj.choices[0].message.content
- link1.append(reply)
- except Exception as e:
- print(f"An error occurred: {e}")
- link1.append(f"Error processing input: {e}")
- elif model_source == "perplexity":
- from openai import OpenAI
- client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
- try:
- response_obj = client.chat.completions.create(
- model=user_model,
- messages=[{'role': 'user', 'content': prompt}],
- **({"temperature": creativity} if creativity is not None else {})
  )
+
  reply = response_obj.choices[0].message.content
- link1.append(reply)
+
+ if chain_of_verification:
+ reply = chain_of_verification_openai(
+ initial_reply=reply,
+ step2_prompt=step2_prompt,
+ step3_prompt=step3_prompt,
+ step4_prompt=step4_prompt,
+ client=client,
+ user_model=user_model,
+ creativity=creativity,
+ remove_numbering=remove_numbering
+ )
+
+ link1.append(reply)
+ else:
+ #if chain of verification is not enabled, just append initial reply
+ link1.append(reply)
+
+ except BadRequestError as e:
+ # Model doesn't exist - halt immediately
+ raise ValueError(f"❌ Model '{user_model}' on {model_source} not found. Please check the model name and try again.") from e
  except Exception as e:
  print(f"An error occurred: {e}")
  link1.append(f"Error processing input: {e}")
+
  elif model_source == "anthropic":
+
  import anthropic
  client = anthropic.Anthropic(api_key=api_key)
+
  try:
- message = client.messages.create(
+ response_obj = client.messages.create(
  model=user_model,
- max_tokens=1024,
- **({"temperature": creativity} if creativity is not None else {}),
- messages=[{"role": "user", "content": prompt}]
- )
- reply = message.content[0].text # Anthropic returns content as list
- link1.append(reply)
+ max_tokens=4096,
+ messages=[{'role': 'user', 'content': prompt}],
+ **({"temperature": creativity} if creativity is not None else {})
+ )
+
+ reply = response_obj.content[0].text
+
+ if chain_of_verification:
+ reply = chain_of_verification_anthropic(
+ initial_reply=reply,
+ step2_prompt=step2_prompt,
+ step3_prompt=step3_prompt,
+ step4_prompt=step4_prompt,
+ client=client,
+ user_model=user_model,
+ creativity=creativity,
+ remove_numbering=remove_numbering
+ )
+
+ link1.append(reply)
+ else:
+ #if chain of verification is not enabled, just append initial reply
+ link1.append(reply)
+
+ except anthropic.NotFoundError as e:
+ # Model doesn't exist - halt immediately
+ raise ValueError(f"❌ Model '{user_model}' on {model_source} not found. Please check the model name and try again.") from e
  except Exception as e:
  print(f"An error occurred: {e}")
  link1.append(f"Error processing input: {e}")

  elif model_source == "google":
  import requests
+
+ def make_google_request(url, headers, payload, max_retries=3):
+ """Make Google API request with exponential backoff on 429 errors"""
+ for attempt in range(max_retries):
+ try:
+ response = requests.post(url, headers=headers, json=payload)
+ response.raise_for_status()
+ return response.json()
+ except requests.exceptions.HTTPError as e:
+ if e.response.status_code == 429 and attempt < max_retries - 1:
+ wait_time = 10 * (2 ** attempt)
+ print(f"⚠️ Rate limited. Waiting {wait_time}s...")
+ time.sleep(wait_time)
+ else:
+ raise
+
  url = f"https://generativelanguage.googleapis.com/v1beta/models/{user_model}:generateContent"
  try:
  headers = {
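
The chain_of_verification_* helpers called in the branches above are imported from catllm/calls/all_calls.py and are not part of this diff. The sketch below is only a generic illustration, under assumption, of how the <<INITIAL_REPLY>>, <<QUESTION>>, and <<VERIFICATION_QA>> placeholders in the step prompts are typically filled in; ask is a hypothetical stand-in for a single model call, not a function in the package:

    def chain_of_verification_sketch(initial_reply, step2_prompt, step3_prompt,
                                     step4_prompt, ask, remove_numbering):
        # Step 2: ask the model to draft verification questions about its own answer.
        questions_text = ask(step2_prompt.replace("<<INITIAL_REPLY>>", initial_reply))
        questions = [remove_numbering(q) for q in questions_text.splitlines() if q.strip()]

        # Step 3: answer each verification question independently.
        qa_pairs = []
        for q in questions:
            answer = ask(step3_prompt.replace("<<QUESTION>>", q))
            qa_pairs.append(f"Q: {q}\nA: {answer}")

        # Step 4: request the final, corrected categorization in the same JSON format.
        final_prompt = (
            step4_prompt
            .replace("<<INITIAL_REPLY>>", initial_reply)
            .replace("<<VERIFICATION_QA>>", "\n".join(qa_pairs))
        )
        return ask(final_prompt)
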
@@ -353,22 +572,49 @@ def multi_class(
  **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
  }

- response = requests.post(url, headers=headers, json=payload)
- response.raise_for_status() # Raise exception for HTTP errors
- result = response.json()
+ result = make_google_request(url, headers, payload)

  if "candidates" in result and result["candidates"]:
  reply = result["candidates"][0]["content"]["parts"][0]["text"]
  else:
  reply = "No response generated"

- link1.append(reply)
+ if chain_of_verification:
+ reply = chain_of_verification_google(
+ initial_reply=reply,
+ prompt=prompt,
+ step2_prompt=step2_prompt,
+ step3_prompt=step3_prompt,
+ step4_prompt=step4_prompt,
+ url=url,
+ headers=headers,
+ creativity=creativity,
+ remove_numbering=remove_numbering,
+ make_google_request=make_google_request
+ )
+
+ link1.append(reply)
+
+ else:
+ # if chain of verification is not enabled, just append initial reply
+ link1.append(reply)
+
+ except requests.exceptions.HTTPError as e:
+ if e.response.status_code == 404:
+ raise ValueError(f"❌ Model '{user_model}' not found. Please check the model name and try again.") from e
+ elif e.response.status_code == 401 or e.response.status_code == 403:
+ raise ValueError(f"❌ Authentication failed. Please check your Google API key.") from e
+ else:
+ print(f"HTTP error occurred: {e}")
+ link1.append(f"Error processing input: {e}")
  except Exception as e:
  print(f"An error occurred: {e}")
  link1.append(f"Error processing input: {e}")

  elif model_source == "mistral":
  from mistralai import Mistral
+ from mistralai.models import SDKError
+
  client = Mistral(api_key=api_key)
  try:
  response = client.chat.complete(
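
The make_google_request helper introduced above waits 10 * (2 ** attempt) seconds between attempts and only retries HTTP 429 responses; with the default max_retries=3, the schedule works out to:

    # Waits before the second and third attempts; the final failure is re-raised.
    waits = [10 * (2 ** attempt) for attempt in range(3 - 1)]
    assert waits == [10, 20]
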
@@ -379,12 +625,40 @@ def multi_class(
  **({"temperature": creativity} if creativity is not None else {})
  )
  reply = response.choices[0].message.content
- link1.append(reply)
+
+ if chain_of_verification:
+ reply = chain_of_verification_mistral(
+ initial_reply=reply,
+ step2_prompt=step2_prompt,
+ step3_prompt=step3_prompt,
+ step4_prompt=step4_prompt,
+ client=client,
+ user_model=user_model,
+ creativity=creativity,
+ remove_numbering=remove_numbering
+ )
+
+ link1.append(reply)
+
+ else:
+ #if chain of verification is not enabled, just append initial reply
+ link1.append(reply)
+
+ except SDKError as e:
+ error_str = str(e).lower()
+ if "invalid_model" in error_str or "invalid model" in error_str:
+ raise ValueError(f"❌ Model '{user_model}' not found.") from e
+ elif "401" in str(e) or "unauthorized" in str(e).lower():
+ raise ValueError(f"❌ Authentication failed. Please check your Mistral API key.") from e
+ else:
+ print(f"An error occurred: {e}")
+ link1.append(f"Error processing input: {e}")
  except Exception as e:
- print(f"An error occurred: {e}")
+ print(f"An unexpected error occurred: {e}")
  link1.append(f"Error processing input: {e}")
+
  else:
- raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")
+ raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, Google, Huggingface, or Mistral")
  # in situation that no JSON is found
  if reply is not None:
  extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
@@ -442,6 +716,25 @@ def multi_class(
  'json': pd.Series(extracted_jsons).reset_index(drop=True)
  })
  categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
+ categorized_data = categorized_data.rename(columns=lambda x: f'category_{x}' if str(x).isdigit() else x)
+
+ #converting to numeric
+ cat_cols = [col for col in categorized_data.columns if col.startswith('category_')]
+
+ categorized_data['processing_status'] = np.where(
+ categorized_data[cat_cols].isna().all(axis=1),
+ 'error',
+ 'success'
+ )
+
+ categorized_data.loc[categorized_data[cat_cols].apply(pd.to_numeric, errors='coerce').isna().any(axis=1), cat_cols] = np.nan
+ categorized_data[cat_cols] = categorized_data[cat_cols].astype('Int64')
+
+ categorized_data['categories_present'] = categorized_data[cat_cols].apply(
+ lambda x: ','.join(x.dropna().astype(str)), axis=1
+ )
+
+ categorized_data['categories_counted'] = categorized_data[cat_cols].count(axis=1)

  if to_csv:
  if save_directory is None:
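
To see what the post-processing in this last hunk adds to the output, a small self-contained sketch applying the same column handling to an invented two-row frame (only the renaming, status flag, and count are reproduced here):

    import numpy as np
    import pandas as pd

    # Toy stand-in for the normalized JSON output: one column per category key.
    toy = pd.DataFrame({"1": [1, None], "2": [0, None]})
    toy = toy.rename(columns=lambda x: f"category_{x}" if str(x).isdigit() else x)
    cat_cols = [c for c in toy.columns if c.startswith("category_")]

    # Rows whose category columns are all missing are flagged as errors.
    toy["processing_status"] = np.where(toy[cat_cols].isna().all(axis=1), "error", "success")
    toy["categories_counted"] = toy[cat_cols].count(axis=1)

    # processing_status -> ["success", "error"]; categories_counted -> [2, 0]
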