cat-llm 0.0.24__py3-none-any.whl → 0.0.26__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
catllm/cat_llm.py CHANGED
@@ -393,1011 +393,3 @@ Provide your work in JSON format where the number belonging to each category is
 
  return categorized_data
 
396
- # image multi-class (binary) function
397
- def extract_image_multi_class(
398
- image_description,
399
- image_input,
400
- categories,
401
- api_key,
402
- columns="numbered",
403
- user_model="gpt-4o-2024-11-20",
404
- creativity=0,
405
- to_csv=False,
406
- safety=False,
407
- filename="categorized_data.csv",
408
- save_directory=None,
409
- model_source="OpenAI"
410
- ):
411
- import os
412
- import json
413
- import pandas as pd
414
- import regex
415
- from tqdm import tqdm
416
- import glob
417
- import base64
418
- from pathlib import Path
419
-
420
- if save_directory is not None and not os.path.isdir(save_directory):
421
- # Directory doesn't exist - raise an exception to halt execution
422
- raise FileNotFoundError(f"Directory {save_directory} doesn't exist")
423
-
424
- image_extensions = [
425
- '*.png', '*.jpg', '*.jpeg',
426
- '*.gif', '*.webp', '*.svg', '*.svgz', '*.avif', '*.apng',
427
- '*.tif', '*.tiff', '*.bmp',
428
- '*.heif', '*.heic', '*.ico',
429
- '*.psd'
430
- ]
431
-
432
- if not isinstance(image_input, list):
433
- # If image_input is a filepath (string)
434
- image_files = []
435
- for ext in image_extensions:
436
- image_files.extend(glob.glob(os.path.join(image_input, ext)))
437
-
438
- print(f"Found {len(image_files)} images.")
439
- else:
440
- # If image_files is already a list
441
- image_files = image_input
442
- print(f"Provided a list of {len(image_input)} images.")
443
-
444
- categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
445
- cat_num = len(categories)
446
- category_dict = {str(i+1): "0" for i in range(cat_num)}
447
- example_JSON = json.dumps(category_dict, indent=4)
448
-
449
- # ensure number of categories is what user wants
450
- print("Categories to classify:")
451
- for i, cat in enumerate(categories, 1):
452
- print(f"{i}. {cat}")
453
-
454
- link1 = []
455
- extracted_jsons = []
456
-
457
- for i, img_path in enumerate(tqdm(image_files, desc="Categorising images"), start=0):
458
- # Check validity first
459
- if img_path is None or not os.path.exists(img_path):
460
- link1.append("Skipped NaN input or invalid path")
461
- extracted_jsons.append("""{"no_valid_image": 1}""")
462
- continue # Skip the rest of the loop iteration
463
-
464
- # Only open the file if path is valid
465
- with open(img_path, "rb") as f:
466
- encoded = base64.b64encode(f.read()).decode("utf-8")
467
-
468
- # Handle extension safely
469
- ext = Path(img_path).suffix.lstrip(".").lower()
470
- encoded_image = f"data:image/{ext};base64,{encoded}"
471
-
472
- prompt = [
473
- {
474
- "type": "text",
475
- "text": (
476
- f"You are an image-tagging assistant.\n"
477
- f"Task ► Examine the attached image and decide, **for each category below**, "
478
- f"whether it is PRESENT (1) or NOT PRESENT (0).\n\n"
479
- f"Image is expected to show: {image_description}\n\n"
480
- f"Categories:\n{categories_str}\n\n"
481
- f"Output format ► Respond with **only** a JSON object whose keys are the "
482
- f"quoted category numbers ('1', '2', …) and whose values are 1 or 0. "
483
- f"No additional keys, comments, or text.\n\n"
484
- f"Example (three categories):\n"
485
- f"{example_JSON}"
486
- ),
487
- },
488
- {
489
- "type": "image_url",
490
- "image_url": {"url": encoded_image, "detail": "high"},
491
- },
492
- ]
493
- if model_source == "OpenAI":
494
- from openai import OpenAI
495
- client = OpenAI(api_key=api_key)
496
- try:
497
- response_obj = client.chat.completions.create(
498
- model=user_model,
499
- messages=[{'role': 'user', 'content': prompt}],
500
- temperature=creativity
501
- )
502
- reply = response_obj.choices[0].message.content
503
- link1.append(reply)
504
- except Exception as e:
505
- print(f"An error occurred: {e}")
506
- link1.append(f"Error processing input: {e}")
507
-
508
- elif model_source == "Perplexity":
509
- from openai import OpenAI
510
- client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
511
- try:
512
- response_obj = client.chat.completions.create(
513
- model=user_model,
514
- messages=[{'role': 'user', 'content': prompt}],
515
- temperature=creativity
516
- )
517
- reply = response_obj.choices[0].message.content
518
- link1.append(reply)
519
- except Exception as e:
520
- print(f"An error occurred: {e}")
521
- link1.append(f"Error processing input: {e}")
522
- elif model_source == "Anthropic":
523
- import anthropic
524
- client = anthropic.Anthropic(api_key=api_key)
525
- try:
526
- message = client.messages.create(
527
- model=user_model,
528
- max_tokens=1024,
529
- temperature=creativity,
530
- messages=[{"role": "user", "content": prompt}]
531
- )
532
- reply = message.content[0].text # Anthropic returns content as list
533
- link1.append(reply)
534
- except Exception as e:
535
- print(f"An error occurred: {e}")
536
- link1.append(f"Error processing input: {e}")
537
- elif model_source == "Mistral":
538
- from mistralai import Mistral
539
- client = Mistral(api_key=api_key)
540
- try:
541
- response = client.chat.complete(
542
- model=user_model,
543
- messages=[
544
- {'role': 'user', 'content': prompt}
545
- ],
546
- temperature=creativity
547
- )
548
- reply = response.choices[0].message.content
549
- link1.append(reply)
550
- except Exception as e:
551
- print(f"An error occurred: {e}")
552
- link1.append(f"Error processing input: {e}")
553
- else:
554
- raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")
555
- # in situation that no JSON is found
556
- if reply is not None:
557
- extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
558
- if extracted_json:
559
- cleaned_json = extracted_json[0].replace('[', '').replace(']', '').replace('\n', '').replace(" ", '').replace(" ", '')
560
- extracted_jsons.append(cleaned_json)
561
- #print(cleaned_json)
562
- else:
563
- error_message = """{"1":"e"}"""
564
- extracted_jsons.append(error_message)
565
- print(error_message)
566
- else:
567
- error_message = """{"1":"e"}"""
568
- extracted_jsons.append(error_message)
569
- #print(error_message)
570
-
571
- # --- Safety Save ---
572
- if safety:
573
- #print(f"Saving CSV to: {save_directory}")
574
- # Save progress so far
575
- temp_df = pd.DataFrame({
576
- 'image_input': image_files[:i+1],
577
- 'link1': link1,
578
- 'json': extracted_jsons
579
- })
580
- # Normalize processed jsons so far
581
- normalized_data_list = []
582
- for json_str in extracted_jsons:
583
- try:
584
- parsed_obj = json.loads(json_str)
585
- normalized_data_list.append(pd.json_normalize(parsed_obj))
586
- except json.JSONDecodeError:
587
- normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
588
- normalized_data = pd.concat(normalized_data_list, ignore_index=True)
589
- temp_df = pd.concat([temp_df, normalized_data], axis=1)
590
- # Save to CSV
591
- if save_directory is None:
592
- save_directory = os.getcwd()
593
- temp_df.to_csv(os.path.join(save_directory, filename), index=False)
594
-
595
- # --- Final DataFrame ---
596
- normalized_data_list = []
597
- for json_str in extracted_jsons:
598
- try:
599
- parsed_obj = json.loads(json_str)
600
- normalized_data_list.append(pd.json_normalize(parsed_obj))
601
- except json.JSONDecodeError:
602
- normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
603
- normalized_data = pd.concat(normalized_data_list, ignore_index=True)
604
-
605
- categorized_data = pd.DataFrame({
606
- 'image_input': image_files,
607
- 'link1': pd.Series(link1).reset_index(drop=True),
608
- 'json': pd.Series(extracted_jsons).reset_index(drop=True)
609
- })
610
- categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
611
-
612
- if columns != "numbered": #if user wants text columns
613
- categorized_data.columns = list(categorized_data.columns[:3]) + categories[:len(categorized_data.columns) - 3]
614
-
615
- if to_csv:
616
- if save_directory is None:
617
- save_directory = os.getcwd()
618
- categorized_data.to_csv(os.path.join(save_directory, filename), index=False)
619
-
620
- return categorized_data
621
-
622
- #image score function
623
- def extract_image_score(
624
- reference_image_description,
625
- image_input,
626
- reference_image,
627
- api_key,
628
- columns="numbered",
629
- user_model="gpt-4o-2024-11-20",
630
- creativity=0,
631
- to_csv=False,
632
- safety=False,
633
- filename="categorized_data.csv",
634
- save_directory=None,
635
- model_source="OpenAI"
636
- ):
637
- import os
638
- import json
639
- import pandas as pd
640
- import regex
641
- from tqdm import tqdm
642
- import glob
643
- import base64
644
- from pathlib import Path
645
-
646
- if save_directory is not None and not os.path.isdir(save_directory):
647
- # Directory doesn't exist - raise an exception to halt execution
648
- raise FileNotFoundError(f"Directory {save_directory} doesn't exist")
649
-
650
- image_extensions = [
651
- '*.png', '*.jpg', '*.jpeg',
652
- '*.gif', '*.webp', '*.svg', '*.svgz', '*.avif', '*.apng',
653
- '*.tif', '*.tiff', '*.bmp',
654
- '*.heif', '*.heic', '*.ico',
655
- '*.psd'
656
- ]
657
-
658
- if not isinstance(image_input, list):
659
- # If image_input is a filepath (string)
660
- image_files = []
661
- for ext in image_extensions:
662
- image_files.extend(glob.glob(os.path.join(image_input, ext)))
663
-
664
- print(f"Found {len(image_files)} images.")
665
- else:
666
- # If image_files is already a list
667
- image_files = image_input
668
- print(f"Provided a list of {len(image_input)} images.")
669
-
670
- with open(reference_image, 'rb') as f:
671
- reference_image = f"data:image/{reference_image.split('.')[-1]};base64,{base64.b64encode(f.read()).decode('utf-8')}"
672
-
673
- link1 = []
674
- extracted_jsons = []
675
-
676
- for i, img_path in enumerate(tqdm(image_files, desc="Categorising images"), start=0):
677
- # Check validity first
678
- if img_path is None or not os.path.exists(img_path):
679
- link1.append("Skipped NaN input or invalid path")
680
- extracted_jsons.append("""{"no_valid_image": 1}""")
681
- continue # Skip the rest of the loop iteration
682
-
683
- # Only open the file if path is valid
684
- with open(img_path, "rb") as f:
685
- encoded = base64.b64encode(f.read()).decode("utf-8")
686
-
687
- # Handle extension safely
688
- ext = Path(img_path).suffix.lstrip(".").lower()
689
- encoded_image = f"data:image/{ext};base64,{encoded}"
690
-
691
- prompt = [
692
- {
693
- "type": "text",
694
- "text": (
695
- f"You are a visual similarity assessment system.\n"
696
- f"Task ► Compare these two images:\n"
697
- f"1. REFERENCE (left): {reference_image_description}\n"
698
- f"2. INPUT (right): User-provided drawing\n\n"
699
- f"Rating criteria:\n"
700
- f"1: No meaningful similarity (fundamentally different)\n"
701
- f"2: Barely recognizable similarity (25% match)\n"
702
- f"3: Partial match (50% key features)\n"
703
- f"4: Strong alignment (75% features)\n"
704
- f"5: Near-perfect match (90%+ similarity)\n\n"
705
- f"Output format ► Return ONLY:\n"
706
- "{\n"
707
- ' "score": [1-5],\n'
708
- ' "summary": "reason you scored"\n'
709
- "}\n\n"
710
- f"Critical rules:\n"
711
- f"- Score must reflect shape, proportions, and key details\n"
712
- f"- List only concrete matching elements from reference\n"
713
- f"- No markdown or additional text"
714
- ),
715
- },
716
- {"type": "image_url",
717
- "image_url": {"url": reference_image, "detail": "high"}
718
- },
719
- {
720
- "type": "image_url",
721
-
722
- "image_url": {"url": encoded_image, "detail": "high"},
723
- },
724
- ]
725
- if model_source == "OpenAI":
726
- from openai import OpenAI
727
- client = OpenAI(api_key=api_key)
728
- try:
729
- response_obj = client.chat.completions.create(
730
- model=user_model,
731
- messages=[{'role': 'user', 'content': prompt}],
732
- temperature=creativity
733
- )
734
- reply = response_obj.choices[0].message.content
735
- link1.append(reply)
736
- except Exception as e:
737
- print(f"An error occurred: {e}")
738
- link1.append(f"Error processing input: {e}")
739
-
740
- elif model_source == "Perplexity":
741
- from openai import OpenAI
742
- client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
743
- try:
744
- response_obj = client.chat.completions.create(
745
- model=user_model,
746
- messages=[{'role': 'user', 'content': prompt}],
747
- temperature=creativity
748
- )
749
- reply = response_obj.choices[0].message.content
750
- link1.append(reply)
751
- except Exception as e:
752
- print(f"An error occurred: {e}")
753
- link1.append(f"Error processing input: {e}")
754
- elif model_source == "Anthropic":
755
- import anthropic
756
- client = anthropic.Anthropic(api_key=api_key)
757
- try:
758
- message = client.messages.create(
759
- model=user_model,
760
- max_tokens=1024,
761
- temperature=creativity,
762
- messages=[{"role": "user", "content": prompt}]
763
- )
764
- reply = message.content[0].text # Anthropic returns content as list
765
- link1.append(reply)
766
- except Exception as e:
767
- print(f"An error occurred: {e}")
768
- link1.append(f"Error processing input: {e}")
769
- elif model_source == "Mistral":
770
- from mistralai import Mistral
771
- client = Mistral(api_key=api_key)
772
- try:
773
- response = client.chat.complete(
774
- model=user_model,
775
- messages=[
776
- {'role': 'user', 'content': prompt}
777
- ],
778
- temperature=creativity
779
- )
780
- reply = response.choices[0].message.content
781
- link1.append(reply)
782
- except Exception as e:
783
- print(f"An error occurred: {e}")
784
- link1.append(f"Error processing input: {e}")
785
- else:
786
- raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")
787
- # in situation that no JSON is found
788
- if reply is not None:
789
- extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
790
- if extracted_json:
791
- cleaned_json = extracted_json[0].replace('[', '').replace(']', '').replace('\n', '').replace(" ", '').replace(" ", '')
792
- extracted_jsons.append(cleaned_json)
793
- #print(cleaned_json)
794
- else:
795
- error_message = """{"1":"e"}"""
796
- extracted_jsons.append(error_message)
797
- print(error_message)
798
- else:
799
- error_message = """{"1":"e"}"""
800
- extracted_jsons.append(error_message)
801
- #print(error_message)
802
-
803
- # --- Safety Save ---
804
- if safety:
805
- # Save progress so far
806
- temp_df = pd.DataFrame({
807
- 'image_input': image_files[:i+1],
808
- 'link1': link1,
809
- 'json': extracted_jsons
810
- })
811
- # Normalize processed jsons so far
812
- normalized_data_list = []
813
- for json_str in extracted_jsons:
814
- try:
815
- parsed_obj = json.loads(json_str)
816
- normalized_data_list.append(pd.json_normalize(parsed_obj))
817
- except json.JSONDecodeError:
818
- normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
819
- normalized_data = pd.concat(normalized_data_list, ignore_index=True)
820
- temp_df = pd.concat([temp_df, normalized_data], axis=1)
821
- # Save to CSV
822
- if save_directory is None:
823
- save_directory = os.getcwd()
824
- temp_df.to_csv(os.path.join(save_directory, filename), index=False)
825
-
826
- # --- Final DataFrame ---
827
- normalized_data_list = []
828
- for json_str in extracted_jsons:
829
- try:
830
- parsed_obj = json.loads(json_str)
831
- normalized_data_list.append(pd.json_normalize(parsed_obj))
832
- except json.JSONDecodeError:
833
- normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
834
- normalized_data = pd.concat(normalized_data_list, ignore_index=True)
835
-
836
- categorized_data = pd.DataFrame({
837
- 'image_input': image_files,
838
- 'link1': pd.Series(link1).reset_index(drop=True),
839
- 'json': pd.Series(extracted_jsons).reset_index(drop=True)
840
- })
841
- categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
842
-
843
- if to_csv:
844
- if save_directory is None:
845
- save_directory = os.getcwd()
846
- categorized_data.to_csv(os.path.join(save_directory, filename), index=False)
847
-
848
- return categorized_data
849
-
850
- # image features function
851
- def extract_image_features(
852
- image_description,
853
- image_input,
854
- features_to_extract,
855
- api_key,
856
- columns="numbered",
857
- user_model="gpt-4o-2024-11-20",
858
- creativity=0,
859
- to_csv=False,
860
- safety=False,
861
- filename="categorized_data.csv",
862
- save_directory=None,
863
- model_source="OpenAI"
864
- ):
865
- import os
866
- import json
867
- import pandas as pd
868
- import regex
869
- from tqdm import tqdm
870
- import glob
871
- import base64
872
- from pathlib import Path
873
-
874
- if save_directory is not None and not os.path.isdir(save_directory):
875
- # Directory doesn't exist - raise an exception to halt execution
876
- raise FileNotFoundError(f"Directory {save_directory} doesn't exist")
877
-
878
- image_extensions = [
879
- '*.png', '*.jpg', '*.jpeg',
880
- '*.gif', '*.webp', '*.svg', '*.svgz', '*.avif', '*.apng',
881
- '*.tif', '*.tiff', '*.bmp',
882
- '*.heif', '*.heic', '*.ico',
883
- '*.psd'
884
- ]
885
-
886
- if not isinstance(image_input, list):
887
- # If image_input is a filepath (string)
888
- image_files = []
889
- for ext in image_extensions:
890
- image_files.extend(glob.glob(os.path.join(image_input, ext)))
891
-
892
- print(f"Found {len(image_files)} images.")
893
- else:
894
- # If image_files is already a list
895
- image_files = image_input
896
- print(f"Provided a list of {len(image_input)} images.")
897
-
898
- categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(features_to_extract))
899
- cat_num = len(features_to_extract)
900
- category_dict = {str(i+1): "0" for i in range(cat_num)}
901
- example_JSON = json.dumps(category_dict, indent=4)
902
-
903
- # ensure number of categories is what user wants
904
- print("\nThe image features to be extracted are:")
905
- for i, cat in enumerate(features_to_extract, 1):
906
- print(f"{i}. {cat}")
907
-
908
- link1 = []
909
- extracted_jsons = []
910
-
911
- for i, img_path in enumerate(
912
- tqdm(image_files, desc="Categorising images"), start=0):
913
- # encode this specific image once
914
- with open(img_path, "rb") as f:
915
- encoded = base64.b64encode(f.read()).decode("utf-8")
916
- ext = Path(img_path).suffix.lstrip(".").lower()
917
- encoded_image = f"data:image/{ext};base64,{encoded}"
918
-
919
- prompt = [
920
- {
921
- "type": "text",
922
- "text": (
923
- f"You are a visual question answering assistant.\n"
924
- f"Task ► Analyze the attached image and answer these specific questions:\n\n"
925
- f"Image context: {image_description}\n\n"
926
- f"Questions to answer:\n{categories_str}\n\n"
927
- f"Output format ► Return **only** a JSON object where:\n"
928
- f"- Keys are question numbers ('1', '2', ...)\n"
929
- f"- Values are concise answers (numbers, short phrases)\n\n"
930
- f"Example for 3 questions:\n"
931
- "{\n"
932
- ' "1": "4",\n'
933
- ' "2": "blue",\n'
934
- ' "3": "yes"\n'
935
- "}\n\n"
936
- f"Important rules:\n"
937
- f"1. Answer directly - no explanations\n"
938
- f"2. Use exact numerical values when possible\n"
939
- f"3. For yes/no questions, use 'yes' or 'no'\n"
940
- f"4. Never add extra keys or formatting"
941
- ),
942
- },
943
- {
944
- "type": "image_url",
945
- "image_url": {"url": encoded_image, "detail": "high"},
946
- },
947
- ]
948
- if model_source == "OpenAI":
949
- from openai import OpenAI
950
- client = OpenAI(api_key=api_key)
951
- try:
952
- response_obj = client.chat.completions.create(
953
- model=user_model,
954
- messages=[{'role': 'user', 'content': prompt}],
955
- temperature=creativity
956
- )
957
- reply = response_obj.choices[0].message.content
958
- link1.append(reply)
959
- except Exception as e:
960
- print(f"An error occurred: {e}")
961
- link1.append(f"Error processing input: {e}")
962
-
963
- elif model_source == "Perplexity":
964
- from openai import OpenAI
965
- client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
966
- try:
967
- response_obj = client.chat.completions.create(
968
- model=user_model,
969
- messages=[{'role': 'user', 'content': prompt}],
970
- temperature=creativity
971
- )
972
- reply = response_obj.choices[0].message.content
973
- link1.append(reply)
974
- except Exception as e:
975
- print(f"An error occurred: {e}")
976
- link1.append(f"Error processing input: {e}")
977
- elif model_source == "Anthropic":
978
- import anthropic
979
- client = anthropic.Anthropic(api_key=api_key)
980
- try:
981
- message = client.messages.create(
982
- model=user_model,
983
- max_tokens=1024,
984
- temperature=creativity,
985
- messages=[{"role": "user", "content": prompt}]
986
- )
987
- reply = message.content[0].text # Anthropic returns content as list
988
- link1.append(reply)
989
- except Exception as e:
990
- print(f"An error occurred: {e}")
991
- link1.append(f"Error processing input: {e}")
992
- elif model_source == "Mistral":
993
- from mistralai import Mistral
994
- client = Mistral(api_key=api_key)
995
- try:
996
- response = client.chat.complete(
997
- model=user_model,
998
- messages=[
999
- {'role': 'user', 'content': prompt}
1000
- ],
1001
- temperature=creativity
1002
- )
1003
- reply = response.choices[0].message.content
1004
- link1.append(reply)
1005
- except Exception as e:
1006
- print(f"An error occurred: {e}")
1007
- link1.append(f"Error processing input: {e}")
1008
- else:
1009
- raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")
1010
- # in situation that no JSON is found
1011
- if reply is not None:
1012
- extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
1013
- if extracted_json:
1014
- cleaned_json = extracted_json[0].replace('[', '').replace(']', '').replace('\n', '').replace(" ", '').replace(" ", '')
1015
- extracted_jsons.append(cleaned_json)
1016
- #print(cleaned_json)
1017
- else:
1018
- error_message = """{"1":"e"}"""
1019
- extracted_jsons.append(error_message)
1020
- print(error_message)
1021
- else:
1022
- error_message = """{"1":"e"}"""
1023
- extracted_jsons.append(error_message)
1024
- #print(error_message)
1025
-
1026
- # --- Safety Save ---
1027
- if safety:
1028
- #print(f"Saving CSV to: {save_directory}")
1029
- # Save progress so far
1030
- temp_df = pd.DataFrame({
1031
- 'image_input': image_files[:i+1],
1032
- 'link1': link1,
1033
- 'json': extracted_jsons
1034
- })
1035
- # Normalize processed jsons so far
1036
- normalized_data_list = []
1037
- for json_str in extracted_jsons:
1038
- try:
1039
- parsed_obj = json.loads(json_str)
1040
- normalized_data_list.append(pd.json_normalize(parsed_obj))
1041
- except json.JSONDecodeError:
1042
- normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
1043
- normalized_data = pd.concat(normalized_data_list, ignore_index=True)
1044
- temp_df = pd.concat([temp_df, normalized_data], axis=1)
1045
- # Save to CSV
1046
- if save_directory is None:
1047
- save_directory = os.getcwd()
1048
- temp_df.to_csv(os.path.join(save_directory, filename), index=False)
1049
-
1050
- # --- Final DataFrame ---
1051
- normalized_data_list = []
1052
- for json_str in extracted_jsons:
1053
- try:
1054
- parsed_obj = json.loads(json_str)
1055
- normalized_data_list.append(pd.json_normalize(parsed_obj))
1056
- except json.JSONDecodeError:
1057
- normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
1058
- normalized_data = pd.concat(normalized_data_list, ignore_index=True)
1059
-
1060
- categorized_data = pd.DataFrame({
1061
- 'image_input': image_files,
1062
- 'link1': pd.Series(link1).reset_index(drop=True),
1063
- 'json': pd.Series(extracted_jsons).reset_index(drop=True)
1064
- })
1065
- categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
1066
-
1067
- if to_csv:
1068
- if save_directory is None:
1069
- save_directory = os.getcwd()
1070
- categorized_data.to_csv(os.path.join(save_directory, filename), index=False)
1071
-
1072
- return categorized_data
1073
-
1074
- # image multi-class (binary) function
1075
- def cerad_score(
1076
- shape,
1077
- image_input,
1078
- api_key,
1079
- user_model="gpt-4o-2024-11-20",
1080
- creativity=0,
1081
- safety=False,
1082
- filename="categorized_data.csv",
1083
- model_source="OpenAI"
1084
- ):
1085
- import os
1086
- import json
1087
- import pandas as pd
1088
- import regex
1089
- from tqdm import tqdm
1090
- import glob
1091
- import base64
1092
- from pathlib import Path
1093
-
1094
- shape = shape.lower()
1095
-
1096
- if shape == "circle":
1097
- categories = ["The image contains a drawing that clearly represents a circle",
1098
- "The drawing does not resemble a circle",
1099
- "The drawing resembles a circle",
1100
- "The circle is closed",
1101
- "The circle is almost closed",
1102
- "The circle is circular",
1103
- "The circle is almost circular",
1104
- "None of the above descriptions apply"]
1105
- elif shape == "diamond":
1106
- categories = ["The image contains a drawing that clearly represents a diamond shape",
1107
- "It has a drawing of a square",
1108
- "A drawn shape DOES NOT resemble a diamond",
1109
- "A drawn shape resembles a diamond",
1110
- "The drawn shape has 4 sides",
1111
- "The drawn shape sides are about equal",
1112
- "If a diamond is drawn it's more elaborate than a simple diamond (such as overlapping diamonds or a diamond with an extras lines inside)",
1113
- "None of the above descriptions apply"]
1114
- elif shape == "rectangles" or shape == "overlapping rectangles":
1115
- categories = ["The image contains a drawing that clearly represents overlapping rectangles",
1116
- "A drawn shape DOES NOT resemble a overlapping rectangles",
1117
- "A drawn shape resembles a overlapping rectangles",
1118
- "Rectangle 1 has 4 sides",
1119
- "Rectangle 2 has 4 sides",
1120
- "The rectangles are overlapping",
1121
- "The rectangles overlap contains a longer vertical rectangle with top and bottom portruding",
1122
- "None of the above descriptions apply"]
1123
- elif shape == "cube":
1124
- categories = ["The image contains a drawing that clearly represents a cube (3D box shape)",
1125
- "The image does NOT contain any drawing that resembles a cube or 3D box",
1126
- "The image contains a WELL-DRAWN recognizable cube with proper 3D perspective",
1127
- "If a cube is present: the front face appears as a square or diamond shape",
1128
- "If a cube is present: internal/hidden edges are visible (showing 3D depth, not just an outline)",
1129
- "If a cube is present: the front and back faces appear parallel to each other",
1130
- "The image contains only a 2D square (flat shape, no 3D appearance)",
1131
- "None of the above descriptions apply"]
1132
- else:
1133
- raise ValueError("Invalid shape! Choose from 'circle', 'diamond', 'rectangles', or 'cube'.")
1134
-
1135
- image_extensions = [
1136
- '*.png', '*.jpg', '*.jpeg',
1137
- '*.gif', '*.webp', '*.svg', '*.svgz', '*.avif', '*.apng',
1138
- '*.tif', '*.tiff', '*.bmp',
1139
- '*.heif', '*.heic', '*.ico',
1140
- '*.psd'
1141
- ]
1142
-
1143
- if not isinstance(image_input, list):
1144
- # If image_input is a filepath (string)
1145
- image_files = []
1146
- for ext in image_extensions:
1147
- image_files.extend(glob.glob(os.path.join(image_input, ext)))
1148
-
1149
- print(f"Found {len(image_files)} images.")
1150
- else:
1151
- # If image_files is already a list
1152
- image_files = image_input
1153
- print(f"Provided a list of {len(image_input)} images.")
1154
-
1155
- categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
1156
- cat_num = len(categories)
1157
- category_dict = {str(i+1): "0" for i in range(cat_num)}
1158
- example_JSON = json.dumps(category_dict, indent=4)
1159
-
1160
- link1 = []
1161
- extracted_jsons = []
1162
-
1163
- for i, img_path in enumerate(tqdm(image_files, desc="Categorising images"), start=0):
1164
- # Check validity first
1165
- if img_path is None or not os.path.exists(img_path):
1166
- link1.append("Skipped NaN input or invalid path")
1167
- extracted_jsons.append("""{"no_valid_image": 1}""")
1168
- continue # Skip the rest of the loop iteration
1169
-
1170
- # Only open the file if path is valid
1171
- with open(img_path, "rb") as f:
1172
- encoded = base64.b64encode(f.read()).decode("utf-8")
1173
-
1174
- # Handle extension safely
1175
- ext = Path(img_path).suffix.lstrip(".").lower()
1176
- encoded_image = f"data:image/{ext};base64,{encoded}"
1177
-
1178
- prompt = [
1179
- {
1180
- "type": "text",
1181
- "text": (
1182
- f"You are an image-tagging assistant trained in the CERAD Constructional Praxis test.\n"
1183
- f"Task ► Examine the attached image and decide, **for each category below**, "
1184
- f"whether it is PRESENT (1) or NOT PRESENT (0).\n\n"
1185
- f"Image is expected to show within it a drawing of a {shape}.\n\n"
1186
- f"Categories:\n{categories_str}\n\n"
1187
- f"Output format ► Respond with **only** a JSON object whose keys are the "
1188
- f"quoted category numbers ('1', '2', …) and whose values are 1 or 0. "
1189
- f"No additional keys, comments, or text.\n\n"
1190
- f"Example:\n"
1191
- f"{example_JSON}"
1192
- ),
1193
- },
1194
- {
1195
- "type": "image_url",
1196
- "image_url": {"url": encoded_image, "detail": "high"},
1197
- },
1198
- ]
1199
- if model_source == "OpenAI":
1200
- from openai import OpenAI
1201
- client = OpenAI(api_key=api_key)
1202
- try:
1203
- response_obj = client.chat.completions.create(
1204
- model=user_model,
1205
- messages=[{'role': 'user', 'content': prompt}],
1206
- temperature=creativity
1207
- )
1208
- reply = response_obj.choices[0].message.content
1209
- link1.append(reply)
1210
- except Exception as e:
1211
- print(f"An error occurred: {e}")
1212
- link1.append(f"Error processing input: {e}")
1213
-
1214
- elif model_source == "Perplexity":
1215
- from openai import OpenAI
1216
- client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
1217
- try:
1218
- response_obj = client.chat.completions.create(
1219
- model=user_model,
1220
- messages=[{'role': 'user', 'content': prompt}],
1221
- temperature=creativity
1222
- )
1223
- reply = response_obj.choices[0].message.content
1224
- link1.append(reply)
1225
- except Exception as e:
1226
- print(f"An error occurred: {e}")
1227
- link1.append(f"Error processing input: {e}")
1228
- elif model_source == "Anthropic":
1229
- import anthropic
1230
- client = anthropic.Anthropic(api_key=api_key)
1231
- try:
1232
- message = client.messages.create(
1233
- model=user_model,
1234
- max_tokens=1024,
1235
- temperature=creativity,
1236
- messages=[{"role": "user", "content": prompt}]
1237
- )
1238
- reply = message.content[0].text # Anthropic returns content as list
1239
- link1.append(reply)
1240
- except Exception as e:
1241
- print(f"An error occurred: {e}")
1242
- link1.append(f"Error processing input: {e}")
1243
- elif model_source == "Mistral":
1244
- from mistralai import Mistral
1245
- client = Mistral(api_key=api_key)
1246
- try:
1247
- response = client.chat.complete(
1248
- model=user_model,
1249
- messages=[
1250
- {'role': 'user', 'content': prompt}
1251
- ],
1252
- temperature=creativity
1253
- )
1254
- reply = response.choices[0].message.content
1255
- link1.append(reply)
1256
- except Exception as e:
1257
- print(f"An error occurred: {e}")
1258
- link1.append(f"Error processing input: {e}")
1259
- else:
1260
- raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")
1261
- # in situation that no JSON is found
1262
- if reply is not None:
1263
- extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
1264
- if extracted_json:
1265
- cleaned_json = extracted_json[0].replace('[', '').replace(']', '').replace('\n', '').replace(" ", '').replace(" ", '')
1266
- extracted_jsons.append(cleaned_json)
1267
- #print(cleaned_json)
1268
- else:
1269
- error_message = """{"1":"e"}"""
1270
- extracted_jsons.append(error_message)
1271
- print(error_message)
1272
- else:
1273
- error_message = """{"1":"e"}"""
1274
- extracted_jsons.append(error_message)
1275
- #print(error_message)
1276
-
1277
- # --- Safety Save ---
1278
- if safety:
1279
- #print(f"Saving CSV to: {save_directory}")
1280
- # Save progress so far
1281
- temp_df = pd.DataFrame({
1282
- 'image_input': image_files[:i+1],
1283
- 'link1': link1,
1284
- 'json': extracted_jsons
1285
- })
1286
- # Normalize processed jsons so far
1287
- normalized_data_list = []
1288
- for json_str in extracted_jsons:
1289
- try:
1290
- parsed_obj = json.loads(json_str)
1291
- normalized_data_list.append(pd.json_normalize(parsed_obj))
1292
- except json.JSONDecodeError:
1293
- normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
1294
- normalized_data = pd.concat(normalized_data_list, ignore_index=True)
1295
- temp_df = pd.concat([temp_df, normalized_data], axis=1)
1296
- # Save to CSV
1297
- if filename is None:
1298
- filepath = os.path.join(os.getcwd(), 'catllm_data.csv')
1299
- else:
1300
- filepath = filename
1301
- temp_df.to_csv(filepath, index=False)
1302
-
1303
- # --- Final DataFrame ---
1304
- normalized_data_list = []
1305
- for json_str in extracted_jsons:
1306
- try:
1307
- parsed_obj = json.loads(json_str)
1308
- normalized_data_list.append(pd.json_normalize(parsed_obj))
1309
- except json.JSONDecodeError:
1310
- normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
1311
- normalized_data = pd.concat(normalized_data_list, ignore_index=True)
1312
-
1313
- categorized_data = pd.DataFrame({
1314
- 'image_input': image_files,
1315
- 'link1': pd.Series(link1).reset_index(drop=True),
1316
- 'json': pd.Series(extracted_jsons).reset_index(drop=True)
1317
- })
1318
- categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
1319
- columns_to_convert = ["1", "2", "3", "4", "5", "6", "7"]
1320
- categorized_data[columns_to_convert] = categorized_data[columns_to_convert].apply(pd.to_numeric, errors='coerce').fillna(0).astype(int)
1321
-
1322
- if shape == "circle":
1323
-
1324
- categorized_data = categorized_data.rename(columns={
1325
- "1": "drawing_present",
1326
- "2": "not_similar",
1327
- "3": "similar",
1328
- "4": "cir_closed",
1329
- "5": "cir_almost_closed",
1330
- "6": "cir_round",
1331
- "7": "cir_almost_round",
1332
- "8": "none"
1333
- })
1334
-
1335
- categorized_data['score'] = categorized_data['cir_almost_closed'] + categorized_data['cir_closed'] + categorized_data['cir_round'] + categorized_data['cir_almost_round']
1336
- categorized_data.loc[categorized_data['none'] == 1, 'score'] = 0
1337
- categorized_data.loc[(categorized_data['drawing_present'] == 0) & (categorized_data['score'] == 0), 'score'] = 0
1338
-
1339
- elif shape == "diamond":
1340
-
1341
- categorized_data = categorized_data.rename(columns={
1342
- "1": "drawing_present",
1343
- "2": "diamond_square",
1344
- "3": "not_similar",
1345
- "4": "similar",
1346
- "5": "diamond_4_sides",
1347
- "6": "diamond_equal_sides",
1348
- "7": "complex_diamond",
1349
- "8": "none"
1350
- })
1351
-
1352
- categorized_data['score'] = categorized_data['diamond_4_sides'] + categorized_data['diamond_equal_sides'] + categorized_data['similar']
1353
-
1354
- categorized_data.loc[categorized_data['none'] == 1, 'score'] = 0
1355
- categorized_data.loc[(categorized_data['diamond_square'] == 1) & (categorized_data['score'] == 0), 'score'] = 2
1356
-
1357
- elif shape == "rectangles" or shape == "overlapping rectangles":
1358
-
1359
- categorized_data = categorized_data.rename(columns={
1360
- "1":"drawing_present",
1361
- "2": "not_similar",
1362
- "3": "similar",
1363
- "4": "r1_4_sides",
1364
- "5": "r2_4_sides",
1365
- "6": "rectangles_overlap",
1366
- "7": "rectangles_cross",
1367
- "8": "none"
1368
- })
1369
-
1370
- categorized_data['score'] = 0
1371
- categorized_data.loc[(categorized_data['r1_4_sides'] == 1) & (categorized_data['r2_4_sides'] == 1), 'score'] = 1
1372
- categorized_data.loc[(categorized_data['rectangles_overlap'] == 1) & (categorized_data['rectangles_cross'] == 1), 'score'] += 1
1373
- categorized_data.loc[categorized_data['none'] == 1, 'score'] = 0
1374
-
1375
- elif shape == "cube":
1376
-
1377
- categorized_data = categorized_data.rename(columns={
1378
- "1": "drawing_present",
1379
- "2": "not_similar",
1380
- "3": "similar",
1381
- "4": "cube_front_face",
1382
- "5": "cube_internal_lines",
1383
- "6": "cube_opposite_sides",
1384
- "7": "square_only",
1385
- "8": "none"
1386
- })
1387
-
1388
- categorized_data['score'] = categorized_data['cube_front_face'] + categorized_data['cube_internal_lines'] + categorized_data['cube_opposite_sides'] + categorized_data['similar']
1389
- categorized_data.loc[categorized_data['similar'] == 1, 'score'] = categorized_data['score'] + 1
1390
- categorized_data.loc[categorized_data['none'] == 1, 'score'] = 0
1391
- categorized_data.loc[(categorized_data['drawing_present'] == 0) & (categorized_data['score'] == 0), 'score'] = 0
1392
- categorized_data.loc[(categorized_data['not_similar'] == 1) & (categorized_data['score'] == 0), 'score'] = 0
1393
- categorized_data.loc[categorized_data['score'] > 4, 'score'] = 4
1394
-
1395
- else:
1396
- raise ValueError("Invalid shape! Choose from 'circle', 'diamond', 'rectangles', or 'cube'.")
1397
-
1398
- categorized_data.loc[categorized_data['no_valid_image'] == 1, 'score'] = None
1399
-
1400
- if filename is not None:
1401
- categorized_data.to_csv(filename, index=False)
1402
-
1403
- return categorized_data
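
For orientation, the hunk above removes the image-handling helpers that shipped in 0.0.24 (extract_image_multi_class, extract_image_score, extract_image_features, cerad_score), leaving only the three context lines shown at the top of the hunk. The sketch below shows how the first of these could be called, based solely on the signature visible in this diff; the import path follows the file path above, and the directory, description, categories, and API key are illustrative placeholders, not values from the package.

# Illustrative call, assuming cat-llm 0.0.24 is installed; all argument values are placeholders.
from catllm.cat_llm import extract_image_multi_class

results = extract_image_multi_class(
    image_description="hand-drawn shapes scanned from survey forms",  # placeholder
    image_input="/path/to/image/folder",   # a directory is globbed for *.png, *.jpg, etc.
    categories=["contains a circle", "contains a square"],            # placeholder labels
    api_key="YOUR_API_KEY",                # key for the chosen provider
    user_model="gpt-4o-2024-11-20",        # default vision model in 0.0.24
    model_source="OpenAI",                 # or "Anthropic", "Perplexity", "Mistral"
    to_csv=True,                           # also writes categorized_data.csv
)
print(results.head())                      # one row per image; one 0/1 column per category

The removed functions all recover the model's answer the same way: a recursive match over balanced braces using the third-party regex module (the stdlib re module does not support the (?R) construct). A minimal illustration, with reply standing in for a model response:

import regex  # pip install regex

reply = 'Sure, here you go: {"1": 1, "2": 0, "3": 1} Anything else?'
matches = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
print(matches[0])  # -> {"1": 1, "2": 0, "3": 1}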