py2ls 0.1.10.27__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
py2ls/ips.py CHANGED
@@ -793,6 +793,99 @@ def str2html(text_list, strict=False):
793
793
  return html_content
794
794
 
795
795
 
796
+ def cm2px(*cm, dpi=300) -> list:
797
+ # Case 1: When the user passes a single argument that is a list or tuple, such as cm2px([8, 5]) or cm2px((8, 5))
798
+ if len(cm) == 1 and isinstance(cm[0], (list, tuple)):
799
+ # If the input is a single list or tuple, we unpack its elements and convert each to pixels
800
+ return [i / 2.54 * dpi for i in cm[0]]
801
+ # Case 2: When the user passes multiple arguments directly, such as cm2px(8, 5)
802
+ else:
803
+ return [i / 2.54 * dpi for i in cm]
804
+
805
+
806
+ def px2cm(*px, dpi=300) -> list:
807
+ # Case 1: When the user passes a single argument that is a list or tuple, such as px2cm([8, 5]) or px2cm((8, 5))
808
+ if len(px) == 1 and isinstance(px[0], (list, tuple)):
809
+ # If the input is a single list or tuple, we unpack its elements and convert each to cm
810
+ return [i * 2.54 / dpi for i in px[0]]
811
+ # Case 2: When the user passes multiple arguments directly, such as px2cm(8, 5)
812
+ else:
813
+ # Here, we convert each individual argument directly to cm
814
+ return [i * 2.54 / dpi for i in px]
815
+
816
+
817
+ def px2inch(*px, dpi=300) -> list:
818
+ """
819
+ px2inch: converts pixel measurements to inches based on the given dpi.
820
+ Usage:
821
+ px2inch(300, 600, dpi=300); px2inch([300, 600], dpi=300)
822
+ Returns:
823
+ list: in inches
824
+ """
825
+ # Case 1: When the user passes a single argument that is a list or tuple, such as px2inch([300, 600]) or px2inch((300, 600))
826
+ if len(px) == 1 and isinstance(px[0], (list, tuple)):
827
+ # If the input is a single list or tuple, we unpack its elements and convert each to inches
828
+ return [i / dpi for i in px[0]]
829
+ # Case 2: When the user passes multiple arguments directly, such as px2inch(300, 600)
830
+ else:
831
+ # Here, we convert each individual argument directly to inches
832
+ return [i / dpi for i in px]
833
+
834
+
835
+ def cm2inch(*cm) -> list:
836
+ """
837
+ cm2inch: converts centimeter measurements to inches.
838
+ Usage:
839
+ cm2inch(10, 12.7); cm2inch((10, 12.7)); cm2inch([10, 12.7])
840
+ Returns:
841
+ list: in inches
842
+ """
843
+ # Case 1: When the user passes a single argument that is a list or tuple, such as cm2inch([10, 12.7]) or cm2inch((10, 12.7))
844
+ if len(cm) == 1 and isinstance(cm[0], (list, tuple)):
845
+ # If the input is a single list or tuple, we unpack its elements and convert each to inches
846
+ return [i * 2.54 for i in cm[0]]
847
+ # Case 2: When the user passes multiple arguments directly, such as cm2inch(10, 12.7)
848
+ else:
849
+ # Here, we convert each individual argument directly to inches
850
+ return [i * 2.54 for i in cm]
851
+
852
+
853
+ def inch2px(*inch, dpi=300) -> list:
854
+ """
855
+ inch2px: converts inch measurements to pixels based on the given dpi.
856
+ Usage:
857
+ inch2px(1, 2, dpi=300); inch2px([1, 2], dpi=300)
858
+ Returns:
859
+ list: in pixels
860
+ """
861
+ # Case 1: When the user passes a single argument that is a list or tuple, such as inch2px([1, 2]) or inch2px((1, 2))
862
+ if len(inch) == 1 and isinstance(inch[0], (list, tuple)):
863
+ # If the input is a single list or tuple, we unpack its elements and convert each to pixels
864
+ return [i * dpi for i in inch[0]]
865
+ # Case 2: When the user passes multiple arguments directly, such as inch2px(1, 2)
866
+ else:
867
+ # Here, we convert each individual argument directly to pixels
868
+ return [i * dpi for i in inch]
869
+
870
+
871
+ def inch2cm(*inch) -> list:
872
+ """
873
+ inch2cm: converts inch measurements to centimeters.
874
+ Usage:
875
+ inch2cm(8,5); inch2cm((8,5)); inch2cm([8,5])
876
+ Returns:
877
+ list: in centimeters
878
+ """
879
+ # Case 1: When the user passes a single argument that is a list or tuple, such as inch2cm([8, 5]) or inch2cm((8, 5))
880
+ if len(inch) == 1 and isinstance(inch[0], (list, tuple)):
881
+ # If the input is a single list or tuple, we unpack its elements and convert each to cm
882
+ return [i / 2.54 for i in inch[0]]
883
+ # Case 2: When the user passes multiple arguments directly, such as inch2cm(8, 5)
884
+ else:
885
+ # Here, we convert each individual argument directly to cm
886
+ return [i / 2.54 for i in inch]
887
+
888
+
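
Taken together, these helpers convert between centimeters, inches, and pixels through the 2.54 cm-per-inch factor and the dpi argument, always returning a list. A minimal usage sketch (assuming the functions are importable from py2ls.ips):

    from py2ls.ips import cm2px, px2cm, inch2px, px2inch

    cm2px(2.54, dpi=300)        # [300.0] -> 2.54 cm is one inch, i.e. 300 px at 300 dpi
    px2cm(300, dpi=300)         # [2.54]
    inch2px([1, 2], dpi=300)    # [300, 600] -> a list argument is unpacked element-wise
    px2inch(300, 600, dpi=300)  # [1.0, 2.0]
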
796
889
  def sreplace(*args, **kwargs):
797
890
  """
798
891
  sreplace(text, by=None, robust=True)
@@ -1276,7 +1369,7 @@ def unzip(dir_path, output_dir=None):
1276
1369
  os.remove(output_dir) # remove file
1277
1370
 
1278
1371
  # Handle .tar.gz files
1279
- if dir_path.endswith(".tar.gz"):
1372
+ if dir_path.endswith(".tar.gz") or dir_path.endswith(".tgz"):
1280
1373
  import tarfile
1281
1374
 
1282
1375
  with tarfile.open(dir_path, "r:gz") as tar_ref:
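
With the added extension check, .tgz archives are routed through the same gzip-compressed tar branch as .tar.gz files; a one-line sketch with a hypothetical archive name:

    extracted = unzip("bundle.tgz")  # handled like bundle.tar.gz via tarfile.open(..., "r:gz")
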
@@ -1363,6 +1456,80 @@ def unzip(dir_path, output_dir=None):
1363
1456
  # output_dir_7z = unzip('archive.7z')
1364
1457
 
1365
1458
 
1459
+ def is_df_abnormal(df: pd.DataFrame, verbose=False) -> bool:
1460
+ """
1461
+ Usage
1462
+ is_abnormal = is_df_abnormal(df, verbose=1)
1463
+
1464
+ """
1465
+ # Initialize a list to hold messages about abnormalities
1466
+ messages = []
1467
+ is_abnormal = False
1468
+ # Check the shape of the DataFrame
1469
+ actual_shape = df.shape
1470
+ messages.append(f"Shape of DataFrame: {actual_shape}")
1471
+
1472
+ # Check column names
1473
+ column_names = df.columns.tolist()
1474
+
1475
+ # Count of delimiters and their occurrences
1476
+ delimiter_counts = {"\t": 0, ",": 0, "\n": 0, "": 0} # Count of empty strings
1477
+
1478
+ for name in column_names:
1479
+ # Count occurrences of each delimiter
1480
+ delimiter_counts["\t"] += name.count("\t")
1481
+ delimiter_counts[","] += name.count(",")
1482
+ delimiter_counts["\n"] += name.count("\n")
1483
+ if name.strip() == "":
1484
+ delimiter_counts[""] += 1
1485
+
1486
+ # Check for abnormalities based on delimiter counts
1487
+ if len(column_names) == 1 and delimiter_counts["\t"] > 1:
1488
+ messages.append("Abnormal: Column names are not split correctly.")
1489
+ is_abnormal = True
1490
+
1491
+ if any(delimiter_counts[d] > 3 for d in delimiter_counts if d != ""):
1492
+ messages.append("Abnormal: Too many delimiters in column names.")
1493
+ is_abnormal = True
1494
+
1495
+ if delimiter_counts[""] > 3:
1496
+ messages.append("Abnormal: There are empty column names.")
1497
+ is_abnormal = True
1498
+
1499
+ if any(delimiter_counts[d] > 3 for d in ["\t", ",", "\n"]):
1500
+ messages.append("Abnormal: Some column names contain unexpected characters.")
1501
+ is_abnormal = True
1502
+
1503
+ # Check for missing values
1504
+ missing_values = df.isnull().sum()
1505
+ if missing_values.any():
1506
+ messages.append("Missing values in columns:")
1507
+ messages.append(missing_values[missing_values > 0].to_string())
1508
+ is_abnormal = True
1509
+
1510
+ # Check data types
1511
+ data_types = df.dtypes
1512
+ # messages.append(f"Data types of columns:\n{data_types}")
1513
+
1514
+ # Check for constant values across any column
1515
+ constant_columns = df.columns[df.nunique() == 1].tolist()
1516
+ if constant_columns:
1517
+ messages.append(f"Abnormal: Columns with constant values: {constant_columns}")
1518
+ is_abnormal = True
1519
+
1520
+ # Check for an unreasonable number of rows or columns
1521
+ if actual_shape[0] < 2 or actual_shape[1] < 2:
1522
+ messages.append(
1523
+ "Abnormal: DataFrame is too small (less than 2 rows or columns)."
1524
+ )
1525
+ is_abnormal = True
1526
+
1527
+ # Compile results
1528
+ if verbose:
1529
+ print("\n".join(messages))
1530
+ return is_abnormal  # True if any abnormality was detected
1531
+
1532
+
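
A quick check of the heuristic on a well-formed frame and a deliberately unsplit one (a hypothetical example; the single tab-ridden column name trips the first rule):

    import pandas as pd
    from py2ls.ips import is_df_abnormal  # assuming module-level import

    ok = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    bad = pd.DataFrame({"x\ty\tz\tw\tv": ["1\t2\t3", "4\t5\t6"]})  # one column holding tab-separated fields

    is_df_abnormal(ok)   # False
    is_df_abnormal(bad)  # True: single column name containing several tab delimiters
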
1366
1533
  def fload(fpath, kind=None, **kwargs):
1367
1534
  """
1368
1535
  Load content from a file with specified file type.
@@ -1399,13 +1566,185 @@ def fload(fpath, kind=None, **kwargs):
1399
1566
  root = tree.getroot()
1400
1567
  return etree.tostring(root, pretty_print=True).decode()
1401
1568
 
1402
- def load_csv(fpath, engine="pyarrow", **kwargs):
1403
- print(f"engine={engine}")
1404
- df = pd.read_csv(fpath, engine=engine, **kwargs)
1569
+ def get_comment(fpath, comment=None, encoding="utf-8", lines_to_check=5):
1570
+ """
1571
+ Detect comment characters in a file.
1572
+
1573
+ Parameters:
1574
+ - fpath: str, the file path of the CSV file.
1575
+ - encoding: str, the encoding of the file (default is 'utf-8').
1576
+ - lines_to_check: int, number of lines to check for comment characters (default is 5).
1577
+
1578
+ Returns:
1579
+ - str or None: the detected comment character, or None if no comment character is found.
1580
+ """
1581
+ comment_chars = [
1582
+ "#",
1583
+ "!",
1584
+ "//",
1585
+ ";",
1586
+ ] # can use any character or string as a comment
1587
+ try:
1588
+ with open(fpath, "r", encoding=encoding) as f:
1589
+ lines = [next(f) for _ in range(lines_to_check)]
1590
+ except (UnicodeDecodeError, ValueError):
1591
+ with open(fpath, "r", encoding=get_encoding(fpath)) as f:
1592
+ lines = [next(f) for _ in range(lines_to_check)]
1593
+ for line in lines:
1594
+ for char in comment_chars:
1595
+ if line.startswith(char):
1596
+ return char
1597
+ return None
1598
+
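
For example, get_comment lets load_csv below skip a commented preamble without the caller specifying the character; a sketch with a hypothetical file whose first lines start with '#':

    comment_char = get_comment("metadata.csv", lines_to_check=5)  # -> "#"
    df = pd.read_csv("metadata.csv", comment=comment_char)
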
1599
+ def load_csv(fpath, **kwargs):
1600
+ from pandas.errors import EmptyDataError
1601
+
1602
+ engine = kwargs.get("engine", "pyarrow")
1603
+ kwargs.pop("engine", None)
1604
+ sep = kwargs.get("sep", "\t")
1605
+ kwargs.pop("sep", None)
1606
+ index_col = kwargs.get("index_col", None)
1607
+ kwargs.pop("index_col", None)
1608
+ memory_map = kwargs.get("memory_map", True)
1609
+ kwargs.pop("memory_map", None)
1610
+ skipinitialspace = kwargs.get("skipinitialspace", True)
1611
+ kwargs.pop("skipinitialspace", None)
1612
+ encoding = kwargs.get("encoding", "utf-8")
1613
+ kwargs.pop("encoding", None)
1614
+ on_bad_lines = kwargs.get("on_bad_lines", "skip")
1615
+ kwargs.pop("on_bad_lines", None)
1616
+ comment = kwargs.get("comment", None)
1617
+ kwargs.pop("comment", None)
1618
+
1619
+ if comment is None:
1620
+ comment = get_comment(
1621
+ fpath, comment=None, encoding="utf-8", lines_to_check=5
1622
+ )
1623
+
1624
+ try:
1625
+ df = pd.read_csv(
1626
+ fpath,
1627
+ engine=engine,
1628
+ index_col=index_col,
1629
+ memory_map=memory_map,
1630
+ encoding=encoding,
1631
+ comment=comment,
1632
+ skipinitialspace=skipinitialspace,
1633
+ sep=sep,
1634
+ on_bad_lines=on_bad_lines,
1635
+ **kwargs,
1636
+ )
1637
+ except:
1638
+ try:
1639
+ try:
1640
+ if engine == "pyarrow":
1641
+ df = pd.read_csv(
1642
+ fpath,
1643
+ engine=engine,
1644
+ index_col=index_col,
1645
+ encoding=encoding,
1646
+ sep=sep,
1647
+ on_bad_lines=on_bad_lines,
1648
+ comment=comment,
1649
+ **kwargs,
1650
+ )
1651
+ else:
1652
+ df = pd.read_csv(
1653
+ fpath,
1654
+ engine=engine,
1655
+ index_col=index_col,
1656
+ memory_map=memory_map,
1657
+ encoding=encoding,
1658
+ sep=sep,
1659
+ skipinitialspace=skipinitialspace,
1660
+ on_bad_lines=on_bad_lines,
1661
+ comment=comment,
1662
+ **kwargs,
1663
+ )
1664
+
1665
+ if is_df_abnormal(df, verbose=0):
1666
+ raise ValueError("the df is abnormal")
1667
+ except (UnicodeDecodeError, ValueError):
1668
+ encoding = get_encoding(fpath)
1669
+ # print(f"utf-8 failed. Retrying with detected encoding: {encoding}")
1670
+ if engine == "pyarrow":
1671
+ df = pd.read_csv(
1672
+ fpath,
1673
+ engine=engine,
1674
+ index_col=index_col,
1675
+ encoding=encoding,
1676
+ sep=sep,
1677
+ on_bad_lines=on_bad_lines,
1678
+ comment=comment,
1679
+ **kwargs,
1680
+ )
1681
+ else:
1682
+ df = pd.read_csv(
1683
+ fpath,
1684
+ engine=engine,
1685
+ index_col=index_col,
1686
+ memory_map=memory_map,
1687
+ encoding=encoding,
1688
+ sep=sep,
1689
+ skipinitialspace=skipinitialspace,
1690
+ on_bad_lines=on_bad_lines,
1691
+ comment=comment,
1692
+ **kwargs,
1693
+ )
1694
+ if is_df_abnormal(df, verbose=0):
1695
+ raise ValueError("the df is abnormal")
1696
+ except Exception as e:
1697
+ separators = [",", "\t", ";", "|", " "]
1698
+ for sep in separators:
1699
+ sep2show = sep if sep != "\t" else "\\t"
1700
+ # print(f'trying with: engine=pyarrow, sep="{sep2show}"')
1701
+ try:
1702
+ df = pd.read_csv(
1703
+ fpath,
1704
+ engine=engine,
1705
+ skipinitialspace=skipinitialspace,
1706
+ sep=sep,
1707
+ on_bad_lines=on_bad_lines,
1708
+ comment=comment,
1709
+ **kwargs,
1710
+ )
1711
+ if not is_df_abnormal(df, verbose=0): # normal
1712
+ break
1713
+ else:
1714
+ if is_df_abnormal(df, verbose=0):
1715
+ pass
1716
+ except:
1717
+ pass
1718
+ else:
1719
+ engines = ["c", "python"]
1720
+ for engine in engines:
1721
+ # separators = [",", "\t", ";", "|", " "]
1722
+ for sep in separators:
1723
+ try:
1724
+ sep2show = sep if sep != "\t" else "\\t"
1725
+ # print(f"trying with: engine={engine}, sep='{sep2show}'")
1726
+ df = pd.read_csv(
1727
+ fpath,
1728
+ engine=engine,
1729
+ sep=sep,
1730
+ on_bad_lines=on_bad_lines,
1731
+ comment=comment,
1732
+ **kwargs,
1733
+ )
1734
+ if not is_df_abnormal(df, verbose=0):
1735
+ break
1736
+ except EmptyDataError as e:
1737
+ continue
1738
+ else:
1739
+ pass
1740
+ display(df.head(2))
1741
+ print(f"shape: {df.shape}")
1405
1742
  return df
1406
1743
 
1407
1744
  def load_xlsx(fpath, **kwargs):
1408
- df = pd.read_excel(fpath, **kwargs)
1745
+ engine = kwargs.get("engine", "openpyxl")
1746
+ kwargs.pop("engine", None)
1747
+ df = pd.read_excel(fpath, engine=engine, **kwargs)
1409
1748
  return df
1410
1749
 
1411
1750
  def load_ipynb(fpath, **kwargs):
@@ -1511,10 +1850,24 @@ def fload(fpath, kind=None, **kwargs):
1511
1850
  "pdf",
1512
1851
  "ipynb",
1513
1852
  ]
1514
- zip_types = ["gz", "zip", "7z", "tar", "tar.gz", "tar.bz2", "bz2", "xz", "rar"]
1515
- supported_types = [*doc_types, *img_types, *zip_types]
1853
+ zip_types = [
1854
+ "gz",
1855
+ "zip",
1856
+ "7z",
1857
+ "tar",
1858
+ "tar.gz",
1859
+ "tar.bz2",
1860
+ "bz2",
1861
+ "xz",
1862
+ "rar",
1863
+ "tgz",
1864
+ ]
1865
+ other_types = ["fcs"]
1866
+ supported_types = [*doc_types, *img_types, *zip_types, *other_types]
1516
1867
  if kind not in supported_types:
1517
- print(f'Error:\n"{kind}" is not in the supported list {supported_types}')
1868
+ print(
1869
+ f'Warning:\n"{kind}" is not in the supported list '
1870
+ ) # {supported_types}')
1518
1871
  # if os.path.splitext(fpath)[1][1:].lower() in zip_types:
1519
1872
  # keep=kwargs.get("keep", False)
1520
1873
  # ifile=kwargs.get("ifile",(0,0))
@@ -1544,9 +1897,22 @@ def fload(fpath, kind=None, **kwargs):
1544
1897
  elif kind == "xml":
1545
1898
  return load_xml(fpath)
1546
1899
  elif kind == "csv":
1547
- return load_csv(fpath, **kwargs)
1900
+ content = load_csv(fpath, **kwargs)
1901
+ return content
1902
+ elif kind in ["ods", "ods", "odt"]:
1903
+ engine = kwargs.get("engine", "odf")
1904
+ kwargs.pop("engine", None)
1905
+ return load_xlsx(fpath, engine=engine, **kwargs)
1906
+ elif kind == "xls":
1907
+ engine = kwargs.get("engine", "xlrd")
1908
+ kwargs.pop("engine", None)
1909
+ content = load_xlsx(fpath, engine=engine, **kwargs)
1910
+ display(content.head(2))
1911
+ return content
1548
1912
  elif kind == "xlsx":
1549
- return load_xlsx(fpath, **kwargs)
1913
+ content = load_xlsx(fpath, **kwargs)
1914
+ display(content.head(2))
1915
+ return content
1550
1916
  elif kind == "ipynb":
1551
1917
  return load_ipynb(fpath, **kwargs)
1552
1918
  elif kind == "pdf":
@@ -1558,17 +1924,62 @@ def fload(fpath, kind=None, **kwargs):
1558
1924
  elif kind.lower() in zip_types:
1559
1925
  keep = kwargs.get("keep", False)
1560
1926
  fpath_unzip = unzip(fpath)
1561
- content_unzip = fload(fpath_unzip, **kwargs)
1562
- if not keep:
1563
- os.remove(fpath_unzip)
1564
- return content_unzip
1927
+ if os.path.isdir(fpath_unzip):
1928
+ print(f"{fpath_unzip} is a folder. fload stoped.")
1929
+ fpath_list = os.listdir(fpath_unzip)
1930
+ print(f"{len(fpath_list)} files within the folder")
1931
+ if len(fpath_list) > 5:
1932
+ pp(fpath_list[:5])
1933
+ print("there are more ...")
1934
+ else:
1935
+ pp(fpath_list)
1936
+ return fpath_list
1937
+ elif os.path.isfile(fpath_unzip):
1938
+ print(f"{fpath_unzip} is a file.")
1939
+ content_unzip = fload(fpath_unzip, **kwargs)
1940
+ if not keep:
1941
+ os.remove(fpath_unzip)
1942
+ return content_unzip
1943
+ else:
1944
+ print(f"{fpath_unzip} does not exist or is a different type.")
1945
+
1946
+ elif kind.lower() == "gmt":
1947
+ import gseapy as gp
1948
+
1949
+ gene_sets = gp.read_gmt(fpath)
1950
+ return gene_sets
1951
+
1952
+ elif kind.lower() == "fcs":
1953
+ import fcsparser
1954
+
1955
+ # https://github.com/eyurtsev/fcsparser
1956
+ meta, data = fcsparser.parse(fpath, reformat_meta=True)
1957
+ return meta, data
1958
+
1565
1959
  else:
1960
+ # try:
1961
+ # content = load_csv(fpath, **kwargs)
1962
+ # except:
1566
1963
  try:
1567
- with open(fpath, "r") as f:
1568
- content = f.readlines()
1964
+ try:
1965
+ with open(fpath, "r", encoding="utf-8") as f:
1966
+ content = f.readlines()
1967
+ except UnicodeDecodeError:
1968
+ print("Failed to read as utf-8, trying different encoding...")
1969
+ with open(
1970
+ fpath, "r", encoding=get_encoding(fpath)
1971
+ ) as f: # Trying with a different encoding
1972
+ content = f.readlines()
1569
1973
  except:
1570
- with open(fpath, "r") as f:
1571
- content = f.read()
1974
+ try:
1975
+ with open(fpath, "r", encoding="utf-8") as f:
1976
+ content = f.read()
1977
+ except UnicodeDecodeError:
1978
+ print("Failed to read as utf-8, trying different encoding...")
1979
+ with open(
1980
+ fpath, "r", encoding=get_encoding(fpath)
1981
+ ) as f: # Trying with a different encoding
1982
+ content = f.read()
1572
1983
  return content
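
All of these loaders are reached through fload, which dispatches on kind as shown above (assuming, as in the unchanged top of fload, that kind is inferred from the file extension when not given). A hedged usage sketch with hypothetical file names:

    from py2ls.ips import fload

    df = fload("results.csv")          # csv -> load_csv with the engine/encoding/separator fallbacks
    tbl = fload("legacy.xls")          # xls -> load_xlsx(engine="xlrd"); a preview of the first rows is displayed
    meta, events = fload("panel.fcs")  # fcs -> fcsparser.parse(..., reformat_meta=True), returns (meta, data)
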
1573
1984
 
1574
1985
 
@@ -2021,13 +2432,23 @@ def get_os():
2021
2432
 
2022
2433
  def listdir(
2023
2434
  rootdir,
2024
- kind="folder",
2435
+ kind=None,
2025
2436
  sort_by="name",
2026
2437
  ascending=True,
2027
2438
  contains=None,
2028
2439
  orient="list",
2029
2440
  output="df", # 'list','dict','records','index','series'
2441
+ verbose=True,
2030
2442
  ):
2443
+ if kind is None:
2444
+ df_all = pd.DataFrame(
2445
+ {
2446
+ "fname": os.listdir(rootdir),
2447
+ "fpath": [os.path.join(rootdir, i) for i in os.listdir(rootdir)],
2448
+ }
2449
+ )
2450
+ display(df_all)
2451
+ return df_all
2031
2452
  if isinstance(kind, list):
2032
2453
  f_ = []
2033
2454
  for kind_ in kind:
@@ -2099,7 +2520,7 @@ def listdir(
2099
2520
 
2100
2521
  f["num"] = i
2101
2522
  f["rootdir"] = rootdir
2102
- f["os"] = os.uname().machine
2523
+ f["os"] = get_os() # os.uname().machine
2103
2524
  else:
2104
2525
  raise FileNotFoundError(
2105
2526
  'The directory "{}" does NOT exist. Please check the directory "rootdir".'.format(
@@ -2122,6 +2543,8 @@ def listdir(
2122
2543
  f = sort_kind(f, by="size", ascending=ascending)
2123
2544
 
2124
2545
  if "df" in output:
2546
+ if verbose:
2547
+ display(f.head())
2125
2548
  return f
2126
2549
  else:
2127
2550
  if "l" in orient.lower(): # list # default
@@ -2154,32 +2577,54 @@ def func_list(lib_name, opt="call"):
2154
2577
  return list_func(lib_name, opt=opt)
2155
2578
 
2156
2579
 
2157
- def mkdir(*args, **kwargs):
2580
+ def mkdir_nest(fpath: str) -> str:
2158
2581
  """
2159
- newfolder(pardir, chdir)
2160
- Args:
2161
- pardir (dir): parent dir
2162
- chdir (str): children dir
2163
- overwrite (bool): overwrite?
2582
+ Create nested directories based on the provided file path.
2583
+
2584
+ Parameters:
2585
+ - fpath (str): The full file path for which the directories should be created.
2586
+
2164
2587
  Returns:
2165
- mkdir, giving a option if exists_ok or not
2588
+ - str: The path of the created directory.
2166
2589
  """
2167
- overwrite = kwargs.get("overwrite", False)
2168
- for arg in args:
2169
- if isinstance(arg, (str, list)):
2170
- if "/" in arg or "\\" in arg:
2171
- pardir = arg
2172
- print(f'pardir: "{pardir}"')
2173
- else:
2174
- chdir = arg
2175
- print(f'chdir:"{chdir}"')
2176
- elif isinstance(arg, bool):
2177
- overwrite = arg
2178
- print(overwrite)
2179
- else:
2180
- print(f"{arg}Error: not support a {type(arg)} type")
2590
+ # Check if the directory already exists
2591
+ if os.path.isdir(fpath):
2592
+ return fpath
2593
+
2594
+ # Split the full path into directories
2595
+ f_slash = "/" if "mac" in get_os().lower() else "\\"
2596
+ dir_parts = fpath.split(f_slash) # Split the path by the OS-specific separator
2597
+
2598
+ # Start creating directories from the root to the desired path
2599
+ current_path = ""
2600
+ for part in dir_parts:
2601
+ if part:
2602
+ current_path = os.path.join(current_path, part)
2603
+ if not os.path.isdir(current_path):
2604
+ os.makedirs(current_path)
2605
+ if not current_path.endswith(f_slash):
2606
+ current_path += f_slash
2607
+ return current_path
2608
+
2609
+
2610
+ def mkdir(pardir: str = None, chdir: str | list = None, overwrite=False):
2611
+ """
2612
+ Create a directory.
2613
+
2614
+ Parameters:
2615
+ - pardir (str): Parent directory where the new directory will be created. If None, uses the current working directory.
2616
+ - chdir (str | list): Name of the new directory or a list of directories to create.
2617
+ If None, a default name 'new_directory' will be used.
2618
+ - overwrite (bool): If True, overwrite the directory if it already exists. Defaults to False.
2619
+
2620
+ Returns:
2621
+ - str: The path of the created directory or an error message.
2622
+ """
2623
+
2181
2624
  rootdir = []
2182
2625
  # Convert string to list
2626
+ if chdir is None:
2627
+ return mkdir_nest(pardir)
2183
2628
  if isinstance(chdir, str):
2184
2629
  chdir = [chdir]
2185
2630
  # Subfoldername should be unique
@@ -2226,54 +2671,111 @@ def mkdir(*args, **kwargs):
2226
2671
  return rootdir
2227
2672
 
2228
2673
 
2674
+ def split_path(fpath):
2675
+ f_slash = "/" if "mac" in get_os().lower() else "\\"
2676
+ dir_par = f_slash.join(fpath.split(f_slash)[:-1])
2677
+ dir_ch = "".join(fpath.split(f_slash)[-1:])
2678
+ return dir_par, dir_ch
2679
+
2680
+
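
mkdir therefore covers two call styles: a single nested path (delegated to mkdir_nest when chdir is None) or a parent plus one or more child folder names, while split_path separates a path into its parent and final component. A sketch of the intended calls, assuming POSIX-style separators and hypothetical paths:

    out = mkdir("./output/2024/figures/")             # nested creation via mkdir_nest
    dirs = mkdir("./output", ["raw", "processed"])    # several siblings under one parent
    parent, name = split_path("./output/report.pdf")  # ("./output", "report.pdf")
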
2229
2681
  def figsave(*args, dpi=300):
2230
2682
  dir_save = None
2231
2683
  fname = None
2684
+ img = None
2232
2685
  for arg in args:
2233
2686
  if isinstance(arg, str):
2234
2687
  if "/" in arg or "\\" in arg:
2235
2688
  dir_save = arg
2236
2689
  elif "/" not in arg and "\\" not in arg:
2237
2690
  fname = arg
2238
- # Backup original values
2239
- if "/" in dir_save:
2240
- if dir_save[-1] != "/":
2241
- dir_save = dir_save + "/"
2242
- elif "\\" in dir_save:
2243
- if dir_save[-1] != "\\":
2244
- dir_save = dir_save + "\\"
2245
- else:
2246
- raise ValueError("Check the Path of dir_save Directory")
2691
+ elif isinstance(arg, (Image.Image, np.ndarray)):
2692
+ img = arg # Store the PIL image if provided
2693
+
2694
+ f_slash = "/" if "mac" in get_os().lower() else "\\"
2695
+ dir_par = f_slash.join(dir_save.split(f_slash)[:-1])
2696
+ dir_ch = "".join(dir_save.split(f_slash)[-1:])
2697
+ if not dir_par.endswith(f_slash):
2698
+ dir_par += f_slash
2699
+ if fname is None:
2700
+ fname = dir_ch
2701
+ mkdir(dir_par)
2247
2702
  ftype = fname.split(".")[-1]
2248
2703
  if len(fname.split(".")) == 1:
2249
2704
  ftype = "nofmt"
2250
- fname = dir_save + fname + "." + ftype
2705
+ fname = dir_par + fname + "." + ftype
2251
2706
  else:
2252
- fname = dir_save + fname
2707
+ fname = dir_par + fname
2708
+
2253
2709
  # Save figure based on file type
2254
2710
  if ftype.lower() == "eps":
2255
2711
  plt.savefig(fname, format="eps", bbox_inches="tight")
2256
2712
  plt.savefig(
2257
- fname.replace(".eps", ".pdf"), format="pdf", bbox_inches="tight", dpi=dpi
2713
+ fname.replace(".eps", ".pdf"),
2714
+ format="pdf",
2715
+ bbox_inches="tight",
2716
+ dpi=dpi,
2717
+ pad_inches=0,
2258
2718
  )
2259
2719
  elif ftype.lower() == "nofmt": # default: both "tif" and "pdf"
2260
2720
  fname_corr = fname.replace("nofmt", "pdf")
2261
- plt.savefig(fname_corr, format="pdf", bbox_inches="tight", dpi=dpi)
2721
+ plt.savefig(
2722
+ fname_corr, format="pdf", bbox_inches="tight", dpi=dpi, pad_inches=0
2723
+ )
2262
2724
  fname = fname.replace("nofmt", "tif")
2263
- plt.savefig(fname, format="tiff", dpi=dpi, bbox_inches="tight")
2725
+ plt.savefig(fname, format="tiff", dpi=dpi, bbox_inches="tight", pad_inches=0)
2264
2726
  print(f"default saving filetype: both 'tif' and 'pdf")
2265
2727
  elif ftype.lower() == "pdf":
2266
- plt.savefig(fname, format="pdf", bbox_inches="tight", dpi=dpi)
2267
- elif ftype.lower() in ["jpg", "jpeg"]:
2268
- plt.savefig(fname, format="jpeg", dpi=dpi, bbox_inches="tight")
2269
- elif ftype.lower() == "png":
2270
- plt.savefig(fname, format="png", dpi=dpi, bbox_inches="tight", transparent=True)
2271
- elif ftype.lower() in ["tiff", "tif"]:
2272
- plt.savefig(fname, format="tiff", dpi=dpi, bbox_inches="tight")
2728
+ plt.savefig(fname, format="pdf", bbox_inches="tight", dpi=dpi, pad_inches=0)
2729
+ elif ftype.lower() in ["jpg", "jpeg", "png", "tiff", "tif"]:
2730
+ if img is not None: # If a PIL image is provided
2731
+ if isinstance(img, Image.Image):
2732
+ if img.mode == "RGBA":
2733
+ img = img.convert("RGB")
2734
+ img.save(fname, format=ftype.upper(), dpi=(dpi, dpi))
2735
+ elif isinstance(img, np.ndarray):
2736
+ import cv2
2737
+
2738
+ # Check the shape of the image to determine color mode
2739
+ if img.ndim == 2:
2740
+ # Grayscale image
2741
+ Image.fromarray(img).save(
2742
+ fname, format=ftype.upper(), dpi=(dpi, dpi)
2743
+ )
2744
+ elif img.ndim == 3:
2745
+ if img.shape[2] == 3:
2746
+ # RGB image
2747
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
2748
+ Image.fromarray(img).save(
2749
+ fname, format=ftype.upper(), dpi=(dpi, dpi)
2750
+ )
2751
+ elif img.shape[2] == 4:
2752
+ # RGBA image
2753
+ img = cv2.cvtColor(
2754
+ img, cv2.COLOR_BGRA2RGBA
2755
+ ) # Convert BGRA to RGBA
2756
+ Image.fromarray(img).save(
2757
+ fname, format=ftype.upper(), dpi=(dpi, dpi)
2758
+ )
2759
+ else:
2760
+ raise ValueError(
2761
+ "Unexpected number of channels in the image array."
2762
+ )
2763
+ else:
2764
+ raise ValueError(
2765
+ "Image array has an unexpected number of dimensions."
2766
+ )
2767
+ else:
2768
+ plt.savefig(
2769
+ fname, format=ftype.lower(), dpi=dpi, bbox_inches="tight", pad_inches=0
2770
+ )
2771
+ # elif ftype.lower() == "png":
2772
+ # plt.savefig(fname, format="png", dpi=dpi, bbox_inches="tight", transparent=True,pad_inches=0)
2773
+ # elif ftype.lower() in ["tiff", "tif"]:
2774
+ # plt.savefig(fname, format="tiff", dpi=dpi, bbox_inches="tight",pad_inches=0)
2273
2775
  elif ftype.lower() == "emf":
2274
- plt.savefig(fname, format="emf", dpi=dpi, bbox_inches="tight")
2776
+ plt.savefig(fname, format="emf", dpi=dpi, bbox_inches="tight", pad_inches=0)
2275
2777
  elif ftype.lower() == "fig":
2276
- plt.savefig(fname, format="pdf", bbox_inches="tight", dpi=dpi)
2778
+ plt.savefig(fname, format="pdf", bbox_inches="tight", dpi=dpi, pad_inches=0)
2277
2779
  print(f"\nSaved @: dpi={dpi}\n{fname}")
2278
2780
 
2279
2781
 
@@ -2405,6 +2907,8 @@ def load_img(fpath):
2405
2907
  FileNotFoundError: If the specified file is not found.
2406
2908
  OSError: If the specified file cannot be opened or is not a valid image file.
2407
2909
  """
2910
+ from PIL import Image
2911
+
2408
2912
  try:
2409
2913
  img = Image.open(fpath)
2410
2914
  return img
@@ -2545,393 +3049,6 @@ def apply_filter(img, *args):
2545
3049
  return img.filter(supported_filters[filter_name])
2546
3050
 
2547
3051
 
2548
- def imgsetss(
2549
- img,
2550
- sets=None,
2551
- show=True,
2552
- show_axis=False,
2553
- size=None,
2554
- dpi=100,
2555
- figsize=None,
2556
- auto=False,
2557
- filter_kws=None,
2558
- ):
2559
- """
2560
- Apply various enhancements and filters to an image using PIL's ImageEnhance and ImageFilter modules.
2561
-
2562
- Args:
2563
- img (PIL.Image): The input image.
2564
- sets (dict): A dictionary specifying the enhancements, filters, and their parameters.
2565
- show (bool): Whether to display the enhanced image.
2566
- show_axis (bool): Whether to display axes on the image plot.
2567
- size (tuple): The size of the thumbnail, cover, contain, or fit operation.
2568
- dpi (int): Dots per inch for the displayed image.
2569
- figsize (tuple): The size of the figure for displaying the image.
2570
- auto (bool): Whether to automatically enhance the image based on its characteristics.
2571
-
2572
- Returns:
2573
- PIL.Image: The enhanced image.
2574
-
2575
- Supported enhancements and filters:
2576
- - "sharpness": Adjusts the sharpness of the image. Values > 1 increase sharpness, while values < 1 decrease sharpness.
2577
- - "contrast": Adjusts the contrast of the image. Values > 1 increase contrast, while values < 1 decrease contrast.
2578
- - "brightness": Adjusts the brightness of the image. Values > 1 increase brightness, while values < 1 decrease brightness.
2579
- - "color": Adjusts the color saturation of the image. Values > 1 increase saturation, while values < 1 decrease saturation.
2580
- - "rotate": Rotates the image by the specified angle.
2581
- - "crop" or "cut": Crops the image. The value should be a tuple specifying the crop box as (left, upper, right, lower).
2582
- - "size": Resizes the image to the specified dimensions.
2583
- - "thumbnail": Resizes the image to fit within the given size while preserving aspect ratio.
2584
- - "cover": Resizes and crops the image to fill the specified size.
2585
- - "contain": Resizes the image to fit within the specified size, adding borders if necessary.
2586
- - "fit": Resizes and pads the image to fit within the specified size.
2587
- - "filter": Applies various filters to the image (e.g., BLUR, CONTOUR, EDGE_ENHANCE).
2588
-
2589
- Note:
2590
- The "color" and "enhance" enhancements are not implemented in this function.
2591
- """
2592
- supported_filters = [
2593
- "BLUR",
2594
- "CONTOUR",
2595
- "DETAIL",
2596
- "EDGE_ENHANCE",
2597
- "EDGE_ENHANCE_MORE",
2598
- "EMBOSS",
2599
- "FIND_EDGES",
2600
- "SHARPEN",
2601
- "SMOOTH",
2602
- "SMOOTH_MORE",
2603
- "MIN_FILTER",
2604
- "MAX_FILTER",
2605
- "MODE_FILTER",
2606
- "MULTIBAND_FILTER",
2607
- "GAUSSIAN_BLUR",
2608
- "BOX_BLUR",
2609
- "MEDIAN_FILTER",
2610
- ]
2611
- print(
2612
- "sets: a dict,'sharp:1.2','color','contrast:'auto' or 1.2','bright', 'crop: x_upperleft,y_upperleft, x_lowerright, y_lowerright','rotation','resize','rem or background'"
2613
- )
2614
- print(f"usage: filter_kws 'dict' below:")
2615
- pp([str(i).lower() for i in supported_filters])
2616
- print("\nlog:\n")
2617
-
2618
- def confirm_rembg_models(model_name):
2619
- models_support = [
2620
- "u2net",
2621
- "u2netp",
2622
- "u2net_human_seg",
2623
- "u2net_cloth_seg",
2624
- "silueta",
2625
- "isnet-general-use",
2626
- "isnet-anime",
2627
- "sam",
2628
- ]
2629
- if model_name in models_support:
2630
- print(f"model_name: {model_name}")
2631
- return model_name
2632
- else:
2633
- print(
2634
- f"{model_name} cannot be found, check the name:{models_support}, default('isnet-general-use') has been used"
2635
- )
2636
- return "isnet-general-use"
2637
-
2638
- def auto_enhance(img):
2639
- """
2640
- Automatically enhances the image based on its characteristics.
2641
- Args:
2642
- img (PIL.Image): The input image.
2643
- Returns:
2644
- dict: A dictionary containing the optimal enhancement values.
2645
- """
2646
- # Determine the bit depth based on the image mode
2647
- if img.mode in ["1", "L", "P", "RGB", "YCbCr", "LAB", "HSV"]:
2648
- # 8-bit depth per channel
2649
- bit_depth = 8
2650
- elif img.mode in ["RGBA", "CMYK"]:
2651
- # 8-bit depth per channel + alpha (RGBA) or additional channels (CMYK)
2652
- bit_depth = 8
2653
- elif img.mode in ["I", "F"]:
2654
- # 16-bit depth per channel (integer or floating-point)
2655
- bit_depth = 16
2656
- else:
2657
- raise ValueError("Unsupported image mode")
2658
- # Calculate the brightness and contrast for each channel
2659
- num_channels = len(img.getbands())
2660
- brightness_factors = []
2661
- contrast_factors = []
2662
- for channel in range(num_channels):
2663
- channel_histogram = img.split()[channel].histogram()
2664
- brightness = sum(i * w for i, w in enumerate(channel_histogram)) / sum(
2665
- channel_histogram
2666
- )
2667
- channel_min, channel_max = img.split()[channel].getextrema()
2668
- contrast = channel_max - channel_min
2669
- # Adjust calculations based on bit depth
2670
- normalization_factor = 2**bit_depth - 1 # Max value for the given bit depth
2671
- brightness_factor = (
2672
- 1.0 + (brightness - normalization_factor / 2) / normalization_factor
2673
- )
2674
- contrast_factor = (
2675
- 1.0 + (contrast - normalization_factor / 2) / normalization_factor
2676
- )
2677
- brightness_factors.append(brightness_factor)
2678
- contrast_factors.append(contrast_factor)
2679
- # Calculate the average brightness and contrast factors across channels
2680
- avg_brightness_factor = sum(brightness_factors) / num_channels
2681
- avg_contrast_factor = sum(contrast_factors) / num_channels
2682
- return {"brightness": avg_brightness_factor, "contrast": avg_contrast_factor}
2683
-
2684
- # Load image if input is a file path
2685
- if isinstance(img, str):
2686
- img = load_img(img)
2687
- img_update = img.copy()
2688
- # Auto-enhance image if requested
2689
- if auto:
2690
- auto_params = auto_enhance(img_update)
2691
- sets.update(auto_params)
2692
- if sets is None:
2693
- sets = {}
2694
- for k, value in sets.items():
2695
- if "shar" in k.lower():
2696
- enhancer = ImageEnhance.Sharpness(img_update)
2697
- img_update = enhancer.enhance(value)
2698
- elif "col" in k.lower() and "bg" not in k.lower():
2699
- enhancer = ImageEnhance.Color(img_update)
2700
- img_update = enhancer.enhance(value)
2701
- elif "contr" in k.lower():
2702
- if value and isinstance(value, (float, int)):
2703
- enhancer = ImageEnhance.Contrast(img_update)
2704
- img_update = enhancer.enhance(value)
2705
- else:
2706
- print("autocontrasted")
2707
- img_update = ImageOps.autocontrast(img_update)
2708
- elif "bri" in k.lower():
2709
- enhancer = ImageEnhance.Brightness(img_update)
2710
- img_update = enhancer.enhance(value)
2711
- elif "cro" in k.lower() or "cut" in k.lower():
2712
- img_update = img_update.crop(value)
2713
- elif "rota" in k.lower():
2714
- img_update = img_update.rotate(value)
2715
- elif "si" in k.lower():
2716
- img_update = img_update.resize(value)
2717
- elif "thum" in k.lower():
2718
- img_update.thumbnail(value)
2719
- elif "cover" in k.lower():
2720
- img_update = ImageOps.cover(img_update, size=value)
2721
- elif "contain" in k.lower():
2722
- img_update = ImageOps.contain(img_update, size=value)
2723
- elif "fit" in k.lower():
2724
- img_update = ImageOps.fit(img_update, size=value)
2725
- elif "pad" in k.lower():
2726
- img_update = ImageOps.pad(img_update, size=value)
2727
- elif "rem" in k.lower() or "rm" in k.lower() or "back" in k.lower():
2728
- if value and isinstance(value, (int, float, list)):
2729
- print(
2730
- 'example usage: {"rm":[alpha_matting_background_threshold(20),alpha_matting_foreground_threshold(270),alpha_matting_erode_sive(11)]}'
2731
- )
2732
- print("https://github.com/danielgatis/rembg/blob/main/USAGE.md")
2733
- # ### Parameters:
2734
- # data (Union[bytes, PILImage, np.ndarray]): The input image data.
2735
- # alpha_matting (bool, optional): Flag indicating whether to use alpha matting. Defaults to False.
2736
- # alpha_matting_foreground_threshold (int, optional): Foreground threshold for alpha matting. Defaults to 240.
2737
- # alpha_matting_background_threshold (int, optional): Background threshold for alpha matting. Defaults to 10.
2738
- # alpha_matting_erode_size (int, optional): Erosion size for alpha matting. Defaults to 10.
2739
- # session (Optional[BaseSession], optional): A session object for the 'u2net' model. Defaults to None.
2740
- # only_mask (bool, optional): Flag indicating whether to return only the binary masks. Defaults to False.
2741
- # post_process_mask (bool, optional): Flag indicating whether to post-process the masks. Defaults to False.
2742
- # bgcolor (Optional[Tuple[int, int, int, int]], optional): Background color for the cutout image. Defaults to None.
2743
- # ###
2744
- if isinstance(value, int):
2745
- value = [value]
2746
- if len(value) < 2:
2747
- img_update = remove(
2748
- img_update,
2749
- alpha_matting=True,
2750
- alpha_matting_background_threshold=value,
2751
- )
2752
- elif 2 <= len(value) < 3:
2753
- img_update = remove(
2754
- img_update,
2755
- alpha_matting=True,
2756
- alpha_matting_background_threshold=value[0],
2757
- alpha_matting_foreground_threshold=value[1],
2758
- )
2759
- elif 3 <= len(value) < 4:
2760
- img_update = remove(
2761
- img_update,
2762
- alpha_matting=True,
2763
- alpha_matting_background_threshold=value[0],
2764
- alpha_matting_foreground_threshold=value[1],
2765
- alpha_matting_erode_size=value[2],
2766
- )
2767
- if isinstance(value, tuple): # replace the background color
2768
- if len(value) == 3:
2769
- value += (255,)
2770
- img_update = remove(img_update, bgcolor=value)
2771
- if isinstance(value, str):
2772
- if confirm_rembg_models(value):
2773
- img_update = remove(img_update, session=new_session(value))
2774
- else:
2775
- img_update = remove(img_update)
2776
- elif "bgcolor" in k.lower():
2777
- if isinstance(value, list):
2778
- value = tuple(value)
2779
- if isinstance(value, tuple): # replace the background color
2780
- if len(value) == 3:
2781
- value += (255,)
2782
- img_update = remove(img_update, bgcolor=value)
2783
- if filter_kws:
2784
- for filter_name, filter_value in filter_kws.items():
2785
- img_update = apply_filter(img_update, filter_name, filter_value)
2786
- # Display the image if requested
2787
- if show:
2788
- if figsize is None:
2789
- plt.figure(dpi=dpi)
2790
- else:
2791
- plt.figure(figsize=figsize, dpi=dpi)
2792
- plt.imshow(img_update)
2793
- plt.axis("on") if show_axis else plt.axis("off")
2794
- return img_update
2795
-
2796
-
2797
- from sklearn.decomposition import PCA
2798
- from skimage import transform, feature, filters, measure
2799
- from skimage.color import rgb2gray
2800
- from scipy.fftpack import fftshift, fft2
2801
- import numpy as np
2802
- import cv2 # Used for template matching
2803
-
2804
-
2805
- def crop_black_borders(image):
2806
- """Crop the black borders from a rotated image."""
2807
- # Convert the image to grayscale if it's not already
2808
- if image.ndim == 3:
2809
- gray_image = color.rgb2gray(image)
2810
- else:
2811
- gray_image = image
2812
-
2813
- # Find all the non-black (non-zero) pixels
2814
- mask = gray_image > 0 # Mask for non-black pixels (assuming black is zero)
2815
- coords = np.column_stack(np.where(mask))
2816
-
2817
- # Get the bounding box of non-black pixels
2818
- if coords.any(): # Check if there are any non-black pixels
2819
- y_min, x_min = coords.min(axis=0)
2820
- y_max, x_max = coords.max(axis=0)
2821
-
2822
- # Crop the image to the bounding box
2823
- cropped_image = image[y_min : y_max + 1, x_min : x_max + 1]
2824
- else:
2825
- # If the image is completely black (which shouldn't happen), return the original image
2826
- cropped_image = image
2827
-
2828
- return cropped_image
2829
-
2830
-
2831
- def detect_angle(image, by="median", template=None):
2832
- """Detect the angle of rotation using various methods."""
2833
- # Convert to grayscale
2834
- gray_image = rgb2gray(image)
2835
-
2836
- # Detect edges using Canny edge detector
2837
- edges = feature.canny(gray_image, sigma=2)
2838
-
2839
- # Use Hough transform to detect lines
2840
- lines = transform.probabilistic_hough_line(edges)
2841
-
2842
- if not lines and any(["me" in by, "pca" in by]):
2843
- print("No lines detected. Adjust the edge detection parameters.")
2844
- return 0
2845
-
2846
- # Hough Transform-based angle detection (Median/Mean)
2847
- if "me" in by:
2848
- angles = []
2849
- for line in lines:
2850
- (x0, y0), (x1, y1) = line
2851
- angle = np.arctan2(y1 - y0, x1 - x0) * 180 / np.pi
2852
- if 80 < abs(angle) < 100:
2853
- angles.append(angle)
2854
- if not angles:
2855
- return 0
2856
- if "di" in by:
2857
- median_angle = np.median(angles)
2858
- rotation_angle = (
2859
- 90 - median_angle if median_angle > 0 else -90 - median_angle
2860
- )
2861
-
2862
- return rotation_angle
2863
- else:
2864
- mean_angle = np.mean(angles)
2865
- rotation_angle = 90 - mean_angle if mean_angle > 0 else -90 - mean_angle
2866
-
2867
- return rotation_angle
2868
-
2869
- # PCA-based angle detection
2870
- elif "pca" in by:
2871
- y, x = np.nonzero(edges)
2872
- if len(x) == 0:
2873
- return 0
2874
- pca = PCA(n_components=2)
2875
- pca.fit(np.vstack((x, y)).T)
2876
- angle = np.arctan2(pca.components_[0, 1], pca.components_[0, 0]) * 180 / np.pi
2877
- return angle
2878
-
2879
- # Gradient Orientation-based angle detection
2880
- elif "gra" in by:
2881
- gx, gy = np.gradient(gray_image)
2882
- angles = np.arctan2(gy, gx) * 180 / np.pi
2883
- hist, bin_edges = np.histogram(angles, bins=360, range=(-180, 180))
2884
- return bin_edges[np.argmax(hist)]
2885
-
2886
- # Template Matching-based angle detection
2887
- elif "temp" in by:
2888
- if template is None:
2889
- # Automatically extract a template from the center of the image
2890
- height, width = gray_image.shape
2891
- center_x, center_y = width // 2, height // 2
2892
- size = (
2893
- min(height, width) // 4
2894
- ) # Size of the template as a fraction of image size
2895
- template = gray_image[
2896
- center_y - size : center_y + size, center_x - size : center_x + size
2897
- ]
2898
- best_angle = None
2899
- best_corr = -1
2900
- for angle in range(0, 180, 1): # Checking every degree
2901
- rotated_template = transform.rotate(template, angle)
2902
- res = cv2.matchTemplate(gray_image, rotated_template, cv2.TM_CCOEFF)
2903
- _, max_val, _, _ = cv2.minMaxLoc(res)
2904
- if max_val > best_corr:
2905
- best_corr = max_val
2906
- best_angle = angle
2907
- return best_angle
2908
-
2909
- # Image Moments-based angle detection
2910
- elif "mo" in by:
2911
- moments = measure.moments_central(gray_image)
2912
- angle = (
2913
- 0.5
2914
- * np.arctan2(2 * moments[1, 1], moments[0, 2] - moments[2, 0])
2915
- * 180
2916
- / np.pi
2917
- )
2918
- return angle
2919
-
2920
- # Fourier Transform-based angle detection
2921
- elif "fft" in by:
2922
- f = fft2(gray_image)
2923
- fshift = fftshift(f)
2924
- magnitude_spectrum = np.log(np.abs(fshift) + 1)
2925
- rows, cols = magnitude_spectrum.shape
2926
- r, c = np.unravel_index(np.argmax(magnitude_spectrum), (rows, cols))
2927
- angle = np.arctan2(r - rows // 2, c - cols // 2) * 180 / np.pi
2928
- return angle
2929
-
2930
- else:
2931
- print(f"Unknown method {by}")
2932
- return 0
2933
-
2934
-
2935
3052
  def imgsets(img, **kwargs):
2936
3053
  """
2937
3054
  Apply various enhancements and filters to an image using PIL's ImageEnhance and ImageFilter modules.
@@ -3074,7 +3191,9 @@ def imgsets(img, **kwargs):
3074
3191
  if "shar" in k.lower():
3075
3192
  enhancer = ImageEnhance.Sharpness(img_update)
3076
3193
  img_update = enhancer.enhance(value)
3077
- elif "col" in k.lower() and "bg" not in k.lower():
3194
+ elif all(
3195
+ ["col" in k.lower(), "bg" not in k.lower(), "background" not in k.lower()]
3196
+ ):
3078
3197
  enhancer = ImageEnhance.Color(img_update)
3079
3198
  img_update = enhancer.enhance(value)
3080
3199
  elif "contr" in k.lower():
@@ -3096,6 +3215,9 @@ def imgsets(img, **kwargs):
3096
3215
  img_update = img_update.rotate(value)
3097
3216
 
3098
3217
  elif "si" in k.lower():
3218
+ if isinstance(value, tuple):
3219
+ value = list(value)
3220
+ value = [int(i) for i in value]
3099
3221
  img_update = img_update.resize(value)
3100
3222
  elif "thum" in k.lower():
3101
3223
  img_update.thumbnail(value)
@@ -3116,21 +3238,7 @@ def imgsets(img, **kwargs):
3116
3238
  session = new_session("isnet-general-use")
3117
3239
  img_update = remove(img_update, session=session)
3118
3240
  elif value and isinstance(value, (int, float, list)):
3119
- print(
3120
- 'example usage: {"rm":[alpha_matting_background_threshold(20),alpha_matting_foreground_threshold(270),alpha_matting_erode_sive(11)]}'
3121
- )
3122
3241
  print("https://github.com/danielgatis/rembg/blob/main/USAGE.md")
3123
- # ### Parameters:
3124
- # data (Union[bytes, PILImage, np.ndarray]): The input image data.
3125
- # alpha_matting (bool, optional): Flag indicating whether to use alpha matting. Defaults to False.
3126
- # alpha_matting_foreground_threshold (int, optional): Foreground threshold for alpha matting. Defaults to 240.
3127
- # alpha_matting_background_threshold (int, optional): Background threshold for alpha matting. Defaults to 10.
3128
- # alpha_matting_erode_size (int, optional): Erosion size for alpha matting. Defaults to 10.
3129
- # session (Optional[BaseSession], optional): A session object for the 'u2net' model. Defaults to None.
3130
- # only_mask (bool, optional): Flag indicating whether to return only the binary masks. Defaults to False.
3131
- # post_process_mask (bool, optional): Flag indicating whether to post-process the masks. Defaults to False.
3132
- # bgcolor (Optional[Tuple[int, int, int, int]], optional): Background color for the cutout image. Defaults to None.
3133
- # ###
3134
3242
  if isinstance(value, int):
3135
3243
  value = [value]
3136
3244
  if len(value) < 2:
@@ -3189,16 +3297,6 @@ def imgsets(img, **kwargs):
3189
3297
  return img_update
3190
3298
 
3191
3299
 
3192
- # # usage:
3193
- # img = imgsets(
3194
- # fpath,
3195
- # sets={"rota": -5},
3196
- # dpi=200,
3197
- # filter_kws={"EMBOSS": 5, "sharpen": 5, "EDGE_ENHANCE_MORE": 10},
3198
- # show_axis=True,
3199
- # )
3200
-
3201
-
3202
3300
  def thumbnail(dir_img_list, figsize=(10, 10), dpi=100, dir_save=None, kind=".png"):
3203
3301
  """
3204
3302
  Display a thumbnail figure of all images in the specified directory.
@@ -4143,7 +4241,7 @@ format_excel(
4143
4241
  print(f"Formatted Excel file saved as:\n{filename}")
4144
4242
 
4145
4243
 
4146
- from IPython.display import display, HTML, Markdown, Image
4244
+ from IPython.display import display, HTML, Markdown
4147
4245
 
4148
4246
 
4149
4247
  def preview(var):
@@ -4200,7 +4298,7 @@ def df_as_type(
4200
4298
  columns: Optional[Union[str, List[str]]] = None,
4201
4299
  astype: str = "datetime",
4202
4300
  format: Optional[str] = None,
4203
- inplace: bool = False,
4301
+ inplace: bool = True,
4204
4302
  errors: str = "coerce", # Can be "ignore", "raise", or "coerce"
4205
4303
  **kwargs,
4206
4304
  ) -> Optional[pd.DataFrame]:
@@ -4353,8 +4451,7 @@ def df_as_type(
4353
4451
  print(f"Error converting '{column}' to {astype}: {e}")
4354
4452
 
4355
4453
  # Return the modified DataFrame if inplace is False
4356
- if not inplace:
4357
- return df
4454
+ return df
4358
4455
 
4359
4456
 
4360
4457
  # ! DataFrame
@@ -4424,3 +4521,709 @@ def df_sort_values(df, column, by=None, ascending=True, inplace=False, **kwargs)
4424
4521
  # display(sorted_df_month)
4425
4522
  # df_sort_values(df_month, "month", month_order, ascending=True, inplace=True)
4426
4523
  # display(df_month)
4524
+
4525
+
4526
+ def df_cluster(
4527
+ data: pd.DataFrame,
4528
+ columns: Optional[list] = None,
4529
+ n_clusters: Optional[int] = None,
4530
+ range_n_clusters: Union[range, np.ndarray] = range(2, 11),
4531
+ scale: bool = True,
4532
+ plot: Union[str, list] = "all",
4533
+ inplace: bool = True,
4534
+ ax: Optional[plt.Axes] = None,
4535
+ ) -> tuple[pd.DataFrame, int, Optional[plt.Axes]]:
4536
+ from sklearn.preprocessing import StandardScaler
4537
+ from sklearn.cluster import KMeans
4538
+ from sklearn.metrics import silhouette_score, silhouette_samples
4539
+ import seaborn as sns
4540
+ import numpy as np
4541
+ import pandas as pd
4542
+ import matplotlib.pyplot as plt
4543
+ import seaborn as sns
4544
+
4545
+ """
4546
+ Performs clustering analysis on the provided feature matrix using K-Means.
4547
+
4548
+ Parameters:
4549
+ data (pd.DataFrame):
4550
+ A 2D numpy array or DataFrame containing numerical feature data,
4551
+ where each row corresponds to an observation and each column to a feature.
4552
+
4553
+ range_n_clusters (range):
4554
+ A range object specifying the number of clusters to evaluate for K-Means clustering.
4555
+ Default is range(2, 11), meaning it will evaluate from 2 to 10 clusters.
4556
+
4557
+ scale (bool):
4558
+ A flag indicating whether to standardize the features before clustering.
4559
+ Default is True, which scales the data to have a mean of 0 and variance of 1.
4560
+
4561
+ plot (str | list):
4562
+ A flag indicating whether to generate visualizations of the clustering analysis.
4563
+ Default is "all", which plots silhouette scores, inertia, and related diagnostics; pass a falsy value (e.g. None) to skip plotting.
4564
+ Returns:
4565
+ tuple:
4566
+ A tuple containing the modified DataFrame with cluster labels (None when inplace=True),
4567
+ the optimal number of clusters, the fitted KMeans model, and the Axes object (if any).
4568
+ """
4569
+ X = data[columns].values if columns is not None else data.values
4570
+
4571
+ silhouette_avg_scores = []
4572
+ inertia_scores = []
4573
+
4574
+ # Standardize the features
4575
+ if scale:
4576
+ scaler = StandardScaler()
4577
+ X = scaler.fit_transform(X)
4578
+
4579
+ for n_cluster in range_n_clusters:
4580
+ kmeans = KMeans(n_clusters=n_cluster, random_state=42)
4581
+ cluster_labels = kmeans.fit_predict(X)
4582
+
4583
+ silhouette_avg = silhouette_score(X, cluster_labels)
4584
+ silhouette_avg_scores.append(silhouette_avg)
4585
+ inertia_scores.append(kmeans.inertia_)
4586
+ print(
4587
+ f"For n_clusters = {n_cluster}, the average silhouette_score is : {silhouette_avg:.4f}"
4588
+ )
4589
+
4590
+ # Determine the optimal number of clusters based on the maximum silhouette score
4591
+ if n_clusters is None:
4592
+ n_clusters = range_n_clusters[np.argmax(silhouette_avg_scores)]
4593
+ print(f"n_clusters = {n_clusters}")
4594
+
4595
+ # Apply K-Means Clustering with Optimal Number of Clusters
4596
+ kmeans = KMeans(n_clusters=n_clusters, random_state=42)
4597
+ cluster_labels = kmeans.fit_predict(X)
4598
+
4599
+ if plot:
4600
+ # ! Interpreting the plots from your K-Means clustering analysis
4601
+ # ! 1. Silhouette Score and Inertia vs Number of Clusters
4602
+ # Description:
4603
+ # This plot has two y-axes: the left y-axis represents the Silhouette Score, and the right y-axis
4604
+ # represents Inertia.
4605
+ # The x-axis represents the number of clusters (k).
4606
+
4607
+ # Interpretation:
4608
+
4609
+ # Silhouette Score:
4610
+ # Ranges from -1 to 1, where a score close to 1 indicates that points are well-clustered, while a
4611
+ # score close to -1 indicates that points might be incorrectly clustered.
4612
+ # A higher silhouette score generally suggests that the data points are appropriately clustered.
4613
+ # Look for the highest value to determine the optimal number of clusters.
4614
+
4615
+ # Inertia:
4616
+ # Represents the sum of squared distances from each point to its assigned cluster center.
4617
+ # Lower inertia values indicate tighter clusters.
4618
+ # As the number of clusters increases, inertia typically decreases, but the rate of decrease
4619
+ # may slow down, indicating diminishing returns for additional clusters.
4620
+
4621
+ # Optimal Number of Clusters:
4622
+ # You can identify an optimal number of clusters where the silhouette score is maximized and
4623
+ # inertia starts to plateau (the "elbow" point).
4624
+ # This typically suggests that increasing the number of clusters further yields less meaningful
4625
+ # separations.
4626
+ if ax is None:
4627
+ _, ax = plt.subplots(figsize=inch2cm(10, 6))
4628
+ color = "tab:blue"
4629
+ ax.plot(
4630
+ range_n_clusters,
4631
+ silhouette_avg_scores,
4632
+ marker="o",
4633
+ color=color,
4634
+ label="Silhouette Score",
4635
+ )
4636
+ ax.set_xlabel("Number of Clusters")
4637
+ ax.set_ylabel("Silhouette Score", color=color)
4638
+ ax.tick_params(axis="y", labelcolor=color)
4639
+ # add right axis: inertia
4640
+ ax2 = ax.twinx()
4641
+ color = "tab:red"
4642
+ ax2.set_ylabel("Inertia", color=color)
4643
+ ax2.plot(
4644
+ range_n_clusters,
4645
+ inertia_scores,
4646
+ marker="x",
4647
+ color=color,
4648
+ label="Inertia",
4649
+ )
4650
+ ax2.tick_params(axis="y", labelcolor=color)
4651
+
4652
+ plt.title("Silhouette Score and Inertia vs Number of Clusters")
4653
+ plt.xticks(range_n_clusters)
4654
+ plt.grid()
4655
+ plt.axvline(x=n_clusters, linestyle="--", color="r", label="Optimal n_clusters")
4656
+ # ! 2. Elbow Method Plot
4657
+ # Description:
4658
+ # This plot shows the Inertia against the number of clusters.
4659
+
4660
+ # Interpretation:
4661
+ # The elbow point is where the inertia begins to decrease at a slower rate. This point suggests that
4662
+ # adding more clusters beyond this point does not significantly improve the clustering performance.
4663
+ # Look for a noticeable bend in the curve to identify the optimal number of clusters, indicated by the
4664
+ # vertical dashed line.
4665
+ # Inertia plot
4666
+ plt.figure(figsize=inch2cm(10, 6))
4667
+ plt.plot(range_n_clusters, inertia_scores, marker="o")
4668
+ plt.title("Elbow Method for Optimal k")
4669
+ plt.xlabel("Number of clusters")
4670
+ plt.ylabel("Inertia")
4671
+ plt.grid()
4672
+ plt.axvline(
4673
+ x=np.argmax(silhouette_avg_scores) + 2,
4674
+ linestyle="--",
4675
+ color="r",
4676
+ label="Optimal n_clusters",
4677
+ )
4678
+ plt.legend()
4679
+ # ! Silhouette Plots
4680
+ # 3. Silhouette Plot for Various Clusters
4681
+ # Description:
4682
+ # This horizontal bar plot shows the silhouette coefficient values for each sample, organized by cluster.
4683
+
4684
+ # Interpretation:
4685
+ # Each bar represents the silhouette score of a sample within a specific cluster. Longer bars indicate
4686
+ # that the samples are well-clustered.
4687
+ # The height of the bars shows how similar points within the same cluster are to one another compared to
4688
+ # points in other clusters.
4689
+ # The vertical red dashed line indicates the average silhouette score for all samples.
4690
+ # You want the majority of silhouette values to be above the average line, indicating that most points
4691
+ # are well-clustered.
4692
+
4693
+ # 以下代码不用再跑一次了
4694
+ # n_clusters = (
4695
+ # np.argmax(silhouette_avg_scores) + 2
4696
+ # ) # Optimal clusters based on max silhouette score
4697
+ # kmeans = KMeans(n_clusters=n_clusters, random_state=42)
4698
+ # cluster_labels = kmeans.fit_predict(X)
4699
+ silhouette_vals = silhouette_samples(X, cluster_labels)
4700
+
4701
+ plt.figure(figsize=inch2cm(10, 6))
4702
+ y_lower = 10
4703
+ for i in range(n_clusters):
4704
+ # Aggregate the silhouette scores for samples belonging to cluster i
4705
+ ith_cluster_silhouette_values = silhouette_vals[cluster_labels == i]
4706
+
4707
+ # Sort the values
4708
+ ith_cluster_silhouette_values.sort()
4709
+
4710
+ size_cluster_i = ith_cluster_silhouette_values.shape[0]
4711
+ y_upper = y_lower + size_cluster_i
4712
+
4713
+ # Create a horizontal bar plot for the silhouette scores
4714
+ plt.barh(range(y_lower, y_upper), ith_cluster_silhouette_values, height=0.5)
4715
+
4716
+ # Label the silhouette scores
4717
+ plt.text(-0.05, (y_lower + y_upper) / 2, str(i + 2))
4718
+ y_lower = y_upper + 10 # 10 for the 0 samples
4719
+
4720
+ plt.title("Silhouette Plot for the Various Clusters")
4721
+ plt.xlabel("Silhouette Coefficient Values")
4722
+ plt.ylabel("Cluster Label")
4723
+ plt.axvline(x=np.mean(silhouette_vals), color="red", linestyle="--")
4724
+
4725
+ df_clusters = pd.DataFrame(
4726
+ X, columns=[f"Feature {i+1}" for i in range(X.shape[1])]
4727
+ )
4728
+ df_clusters["Cluster"] = cluster_labels
4729
+ # ! pairplot of the clusters
4730
+ # Overview of the Pairplot
4731
+ # Axes and Grid:
4732
+ # The pairplot creates a grid of scatter plots for each pair of features in your dataset.
4733
+ # Each point in the scatter plots represents a sample from your dataset, colored according to its cluster assignment.
4734
+
4735
+ # Diagonal Elements:
4736
+ # The diagonal plots usually show the distribution of each feature. In this case, since X.shape[1] <= 4,
4737
+ # there will be a maximum of four features plotted against each other. The diagonal could display histograms or
4738
+ # kernel density estimates (KDE) for each feature.
4739
+
4740
+ # Interpretation of the Pairplot
4741
+
4742
+ # Feature Relationships:
4743
+ # Look at each scatter plot in the off-diagonal plots. Each plot shows the relationship between two features. Points that
4744
+ # are close together in the scatter plot suggest similar values for those features.
4745
+ # Cluster Separation: You want to see clusters of different colors (representing different clusters) that are visually distinct.
4746
+ # Good separation indicates that the clustering algorithm effectively identified different groups within your data.
4747
+ # Overlapping Points: If points from different clusters overlap significantly in any scatter plot, it indicates that those clusters
4748
+ # might not be distinct in terms of the two features being compared.
4749
+ # Cluster Characteristics:
4750
+ # Shape and Distribution: Observe the shape of the clusters. Are they spherical, elongated, or irregular? This can give insights
4751
+ # into how well the K-Means (or other clustering methods) has performed:
4752
+ # Spherical Clusters: Indicates that clusters are well defined and separated.
4753
+ # Elongated Clusters: May suggest that the algorithm is capturing variations along specific axes but could benefit from adjustments
4754
+ # in clustering parameters or methods.
4755
+ # Feature Influence: Identify which features contribute most to cluster separation. For instance, if you see that one feature
4756
+ # consistently separates two clusters, it may be a key factor for clustering.
4757
+ # Diagonal Histograms/KDE:
4758
+ # The diagonal plots show the distribution of individual features across all samples. Look for:
4759
+ # Distribution Shape: Is the distribution unimodal, bimodal, skewed, or uniform?
4760
+ # Concentration: Areas with a high density of points may indicate that certain values are more common among samples.
4761
+ # Differences Among Clusters: If you see distinct peaks in the histograms for different clusters, it suggests that those clusters are
4762
+ # characterized by specific ranges of feature values.
4763
+ # Example Observations
4764
+ # Feature 1 vs. Feature 2: If there are clear, well-separated clusters in this scatter plot, it suggests that these two features
4765
+ # effectively distinguish between the clusters.
4766
+ # Feature 3 vs. Feature 4: If you observe significant overlap between clusters in this plot, it may indicate that these features do not
4767
+ # provide a strong basis for clustering.
4768
+ # Diagonal Plots: If you notice that one cluster has a higher density of points at lower values for a specific feature, while another
4769
+ # cluster is concentrated at higher values, this suggests that this feature is critical for differentiating those clusters.
4770
+
4771
+ # Pairplot of the clusters
4772
+ # * Why limit the number of features (here, at most 6)?
4773
+ # 2 features=1 scatter plot # 3 features=3 scatter plots
4774
+ # 4 features=6 scatter plots # 5 features=10 scatter plots
4775
+ # 6 features=15 scatter plots # 10 features=45 scatter plots
4776
+ # Pairplot works well with low-dimensional data; with many features the number of subplots grows quickly and the plot loses its usefulness
4777
+ if X.shape[1] <= 6:
4778
+ plt.figure(figsize=(8, 4))
4779
+ sns.pairplot(df_clusters, hue="Cluster", palette="tab10")
4780
+ plt.suptitle("Pairplot of Clusters", y=1.02)
4781
+
4782
+ # Add cluster labels to the DataFrame or modify in-place
4783
+ if inplace: # replace the original data
4784
+ data["Cluster"] = cluster_labels
4785
+ return None, n_clusters, kmeans, ax # data slot is None when inplace is True
4786
+ else:
4787
+ data_copy = data.copy()
4788
+ data_copy["Cluster"] = cluster_labels
4789
+ return data_copy, n_clusters, kmeans, ax
4790
+
4791
+
4792
+ # example:
4793
+ # clustering_features = [marker + "_log" for marker in markers]
4794
+ # df_cluster(data, columns=clustering_features, n_clusters=3,range_n_clusters=np.arange(3, 7))
4795
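# Illustrative sketch (not from the package): a self-contained way the call above
# might look on synthetic data. The import path, the keyword arguments
# (columns, n_clusters, range_n_clusters, inplace) and the 4-tuple return
# (data, n_clusters, kmeans, ax) are assumptions inferred from the function body
# and the example comment; adjust them to the installed version.
import numpy as np
import pandas as pd
from sklearn.datasets import make_blobs
from py2ls.ips import df_cluster  # assumed import path (this module)

X_demo, _ = make_blobs(n_samples=300, centers=4, n_features=3, random_state=42)
df_demo = pd.DataFrame(X_demo, columns=["f1", "f2", "f3"])
clustered, best_k, km, ax = df_cluster(
    df_demo,
    columns=["f1", "f2", "f3"],
    n_clusters=3,
    range_n_clusters=np.arange(2, 7),
    inplace=False,  # assumed keyword, suggested by the "if inplace:" branch above
)
print(best_k, clustered["Cluster"].value_counts())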
+
4796
+ """
4797
+ # Notes on how PCA and clustering (like KMeans) work, especially
4798
+ # in the context of a dataset with 7 columns and 23,121 rows.
4799
+
4800
+ # Principal Component Analysis (PCA)
4801
+ # Purpose of PCA:
4802
+ # PCA is a dimensionality reduction technique. It transforms your dataset from a high-dimensional space
4803
+ # (in your case, 7 dimensions corresponding to your 7 columns) to a lower-dimensional space while
4804
+ # retaining as much variance (information) as possible.
4805
+ # How PCA Works:
4806
+ # PCA computes new features called "principal components" that are linear combinations of the original
4807
+ # features.
4808
+ # The first principal component captures the most variance, the second captures the next most variance
4809
+ # (orthogonal to the first), and so on.
4810
+ # If you set n_components=2, for example, PCA will reduce your dataset from 7 columns to 2 columns.
4811
+ # This helps in visualizing and analyzing the data with fewer dimensions.
4812
+ # Result of PCA:
4813
+ # After applying PCA, your original dataset with 7 columns will be transformed into a new dataset with
4814
+ # the specified number of components (e.g., 2 or 3).
4815
+ # The transformed dataset will have fewer columns but should capture most of the important information
4816
+ # from the original dataset.
4817
+
4818
+ # Clustering (KMeans)
4819
+ # Purpose of Clustering:
4820
+ # Clustering is used to group data points based on their similarities. KMeans, specifically, partitions
4821
+ # your data into a specified number of clusters (groups).
4822
+ # How KMeans Works:
4823
+ # KMeans assigns each data point to one of the k clusters based on the feature space (original or
4824
+ # PCA-transformed).
4825
+ # It aims to minimize the variance within each cluster while maximizing the variance between clusters.
4826
+ # It does not classify the data in each column independently; instead, it considers the overall similarity
4827
+ # between data points based on their features.
4828
+ # Result of KMeans:
4829
+ # The output will be cluster labels for each data point (e.g., which cluster a particular observation
4830
+ # belongs to).
4831
+ # You can visualize how many groups were formed and analyze the characteristics of each cluster.
4832
+
4833
+ # Summary
4834
+ # PCA reduces the number of features (columns) in your dataset, transforming it into a lower-dimensional
4835
+ # space.
4836
+ # KMeans then classifies data points based on the features of the transformed dataset (or the original
4837
+ # if you choose) into different subgroups (clusters).
4838
+ # By combining these techniques, you can simplify the complexity of your data and uncover patterns that
4839
+ # might not be visible in the original high-dimensional space.
4840
+ """
4841
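# Illustrative sketch (not from the package): the PCA-then-KMeans pipeline described
# in the note above, on random data with 7 columns (smaller than the 23,121-row
# scenario, for brevity). Uses only standard scikit-learn calls.
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X7 = rng.normal(size=(1000, 7))               # 1000 samples x 7 features
X7_scaled = StandardScaler().fit_transform(X7)

pca = PCA(n_components=2)                     # 7 columns -> 2 principal components
X2 = pca.fit_transform(X7_scaled)
print("cumulative explained variance:", pca.explained_variance_ratio_.sum())

labels = KMeans(n_clusters=3, random_state=42, n_init=10).fit_predict(X2)
print("cluster sizes:", np.bincount(labels))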
+
4842
+
4843
+ def df_reducer(
4844
+ data: pd.DataFrame,
4845
+ columns: Optional[List[str]] = None,
4846
+ method: str = "umap", # 'pca', 'umap'
4847
+ n_components: int = 2, # default 2; pass None with method="pca" to choose by explained-variance threshold
4848
+ umap_neighbors: int = 15, # Default
4849
+ umap_min_dist: float = 0.1, # Default
4850
+ scale: bool = True,
4851
+ fill_missing: bool = True,
4852
+ debug: bool = False,
4853
+ inplace: bool = True, # replace the original data
4854
+ ) -> pd.DataFrame:
4855
+ """
4856
+ Reduces the dimensionality of the selected DataFrame using PCA or UMAP.
4857
+
4858
+ Parameters:
4859
+ -----------
4860
+ data : pd.DataFrame
4861
+ The input DataFrame (samples x features).
4862
+
4863
+ columns : List[str], optional
4864
+ List of column names to reduce. If None, all columns are used.
4865
+
4866
+ method : str, optional, default="umap"
4867
+ Dimensionality reduction method, either "pca" or "umap".
4868
+
4869
+ n_components : int, optional, default=2
4870
+ Number of components for PCA or UMAP. If None and method="pca", it is chosen so that 95% of the cumulative variance is explained.
4871
+
4872
+ umap_neighbors : int, optional, default=15
4873
+ Number of neighbors considered for UMAP embedding.
4874
+
4875
+ umap_min_dist : float, optional, default=0.1
4876
+ Minimum distance between points in UMAP embedding.
4877
+
4878
+ scale : bool, optional, default=True
4879
+ Whether to scale the data using StandardScaler.
4880
+
4881
+ fill_missing : bool, optional, default=True
4882
+ Whether to fill missing values using the mean before applying PCA/UMAP.
4883
+
4884
+ Returns:
4885
+ --------
4886
+ reduced_df : pd.DataFrame
4887
+ DataFrame with the reduced dimensions, or None if inplace=True (the component columns are then added to data).
4888
+ """
4889
+ from sklearn.decomposition import PCA
4890
+ from sklearn.preprocessing import StandardScaler
4891
+ import umap
4892
+ from sklearn.impute import SimpleImputer
4893
+
4894
+ # Select columns if specified, else use all columns
4895
+ X = data[columns].values if columns else data.values
4896
+
4897
+ # Handle missing values
4898
+ if fill_missing:
4899
+ imputer = SimpleImputer(strategy="mean")
4900
+ X = imputer.fit_transform(X)
4901
+
4902
+ # Optionally scale the data
4903
+ if scale:
4904
+ scaler = StandardScaler()
4905
+ X = scaler.fit_transform(X)
4906
+
4907
+ # Check valid method input
4908
+ if method not in ["pca", "umap"]:
4909
+ raise ValueError(f"Invalid method '{method}'. Choose 'pca' or 'umap'.")
4910
+
4911
+ # Apply PCA if selected
4912
+ if method == "pca":
4913
+ if n_components is None:
4914
+ # to get the n_components with threshold method:
4915
+ pca = PCA()
4916
+ pca_result = pca.fit_transform(X)
4917
+
4918
+ # Calculate explained variance
4919
+ explained_variance = pca.explained_variance_ratio_
4920
+ # Cumulative explained variance
4921
+ cumulative_variance = np.cumsum(explained_variance)
4922
+ # Set a threshold for cumulative variance
4923
+ threshold = 0.95 # Example threshold
4924
+ n_components = (
4925
+ np.argmax(cumulative_variance >= threshold) + 1
4926
+ ) # Number of components to retain
4927
+ if debug:
4928
+ # debug:
4929
+ # Plot the cumulative explained variance
4930
+ plt.figure(figsize=(8, 5))
4931
+ plt.plot(
4932
+ range(1, len(cumulative_variance) + 1),
4933
+ cumulative_variance,
4934
+ marker="o",
4935
+ linestyle="-",
4936
+ )
4937
+ plt.title("Cumulative Explained Variance by Principal Components")
4938
+ plt.xlabel("Number of Principal Components")
4939
+ plt.ylabel("Cumulative Explained Variance")
4940
+ plt.xticks(range(1, len(cumulative_variance) + 1))
4941
+ # Add horizontal line for the threshold
4942
+ plt.axhline(
4943
+ y=threshold, color="r", linestyle="--", label="Threshold (95%)"
4944
+ )
4945
+ # Add vertical line for n_components
4946
+ plt.axvline(
4947
+ x=n_components,
4948
+ color="g",
4949
+ linestyle="--",
4950
+ label=f"n_components = {n_components}",
4951
+ )
4952
+ plt.legend()
4953
+ plt.grid()
4954
+ pca = PCA(n_components=n_components)
4955
+ X_reduced = pca.fit_transform(X)
4956
+ print(f"PCA completed: Reduced to {n_components} components.")
4957
+
4958
+ # Apply UMAP if selected
4959
+ elif method == "umap":
4960
+ umap_reducer = umap.UMAP(
4961
+ n_neighbors=umap_neighbors,
4962
+ min_dist=umap_min_dist,
4963
+ n_components=n_components,
4964
+ )
4965
+ X_reduced = umap_reducer.fit_transform(X)
4966
+ print(f"UMAP completed: Reduced to {n_components} components.")
4967
+
4968
+ # Return reduced data as a new DataFrame with the same index
4969
+ reduced_df = pd.DataFrame(X_reduced, index=data.index)
4970
+
4971
+ if inplace:
4972
+ # Replace or add new columns based on n_components
4973
+ for col_idx in range(n_components):
4974
+ data[f"Component_{col_idx+1}"] = reduced_df.iloc[:, col_idx]
4975
+ return None # No return when inplace=True
4976
+
4977
+ return reduced_df
4978
+
4979
+
4980
+ # example:
4981
+ # df_reducer(data=data_log, columns=markers, n_components=2)
4982
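# Illustrative sketch (not from the package): calling df_reducer with method="pca"
# and n_components=None exercises the explained-variance branch above (threshold 0.95).
# The import path is an assumption based on this file's location; note that the
# function imports umap unconditionally, so umap-learn must be installed even for PCA.
import numpy as np
import pandas as pd
from py2ls.ips import df_reducer  # assumed import path

rng = np.random.default_rng(1)
demo = pd.DataFrame(rng.normal(size=(200, 7)), columns=[f"m{i}" for i in range(7)])
reduced = df_reducer(demo, method="pca", n_components=None, inplace=False)
print(reduced.shape)  # (200, k), with k chosen to reach 95% cumulative variance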
+
4983
+
4984
+ def plot_cluster(
4985
+ data: pd.DataFrame,
4986
+ labels: np.ndarray,
4987
+ metrics: dict = None,
4988
+ cmap="tab20",
4989
+ true_labels: Optional[np.ndarray] = None,
4990
+ ) -> None:
4991
+ """
4992
+ Visualize clustering results with various plots.
4993
+
4994
+ Parameters:
4995
+ -----------
4996
+ data : pd.DataFrame
4997
+ The input data used for clustering.
4998
+ labels : np.ndarray
4999
+ Cluster labels assigned to each point.
5000
+ metrics : dict, optional
5001
+ Dictionary of evaluation metrics from evaluate_cluster; computed automatically if None.
+ cmap : str, default="tab20"
+ Matplotlib colormap used for the cluster scatter plot.
5002
+ true_labels : Optional[np.ndarray], default=None
5003
+ Ground truth labels, if available.
5004
+ """
5005
+ import seaborn as sns
5006
+ from sklearn.metrics import silhouette_samples
5007
+
5008
+ if metrics is None:
5009
+ metrics = evaluate_cluster(data=data, labels=labels, true_labels=true_labels)
5010
+
5011
+ # 1. Scatter Plot of Clusters
5012
+ plt.figure(figsize=(15, 6))
5013
+ plt.subplot(1, 3, 1)
5014
+ plt.scatter(data.iloc[:, 0], data.iloc[:, 1], c=labels, cmap=cmap, s=20)
5015
+ plt.title("Cluster Scatter Plot")
5016
+ plt.xlabel("Component 1")
5017
+ plt.ylabel("Component 2")
5018
+ plt.colorbar(label="Cluster Label")
5019
+ plt.grid()
5020
+
5021
+ # 2. Silhouette Plot
5022
+ if "Silhouette Score" in metrics:
5023
+ silhouette_vals = silhouette_samples(data, labels)
5024
+ plt.subplot(1, 3, 2)
5025
+ y_lower = 10
5026
+ for i in range(len(set(labels))):
5027
+ # Aggregate the silhouette scores for samples belonging to the current cluster
5028
+ cluster_silhouette_vals = silhouette_vals[labels == i]
5029
+ cluster_silhouette_vals.sort()
5030
+ size_cluster_i = cluster_silhouette_vals.shape[0]
5031
+ y_upper = y_lower + size_cluster_i
5032
+
5033
+ plt.fill_betweenx(np.arange(y_lower, y_upper), 0, cluster_silhouette_vals)
5034
+ plt.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
5035
+ y_lower = y_upper + 10 # 10 for the 0 samples
5036
+
5037
+ plt.title("Silhouette Plot")
5038
+ plt.xlabel("Silhouette Coefficient Values")
5039
+ plt.ylabel("Cluster Label")
5040
+ plt.axvline(x=metrics["Silhouette Score"], color="red", linestyle="--")
5041
+ plt.grid()
5042
+
5043
+ # 3. Metrics Plot
5044
+ plt.subplot(1, 3, 3)
5045
+ metric_names = ["Davies-Bouldin Index", "Calinski-Harabasz Index"]
5046
+ metric_values = [
5047
+ metrics["Davies-Bouldin Index"],
5048
+ metrics["Calinski-Harabasz Index"],
5049
+ ]
5050
+
5051
+ if true_labels is not None:
5052
+ metric_names += ["Homogeneity Score", "Completeness Score", "V-Measure"]
5053
+ metric_values += [
5054
+ metrics["Homogeneity Score"],
5055
+ metrics["Completeness Score"],
5056
+ metrics["V-Measure"],
5057
+ ]
5058
+
5059
+ plt.barh(metric_names, metric_values, color="lightblue")
5060
+ plt.title("Clustering Metrics")
5061
+ plt.xlabel("Score")
5062
+ plt.axvline(x=0, color="gray", linestyle="--")
5063
+ plt.grid()
5064
+ plt.tight_layout()
5065
+
5066
+
5067
+ def evaluate_cluster(
5068
+ data: pd.DataFrame, labels: np.ndarray, true_labels: Optional[np.ndarray] = None
5069
+ ) -> dict:
5070
+ """
5071
+ Evaluate clustering performance using various metrics.
5072
+
5073
+ Parameters:
5074
+ -----------
5075
+ data : pd.DataFrame
5076
+ The input data used for clustering.
5077
+ labels : np.ndarray
5078
+ Cluster labels assigned to each point.
5079
+ true_labels : Optional[np.ndarray], default=None
5080
+ Ground truth labels, if available.
5081
+
5082
+ Returns:
5083
+ --------
5084
+ metrics : dict
5085
+ Dictionary containing evaluation metrics.
5086
+
5087
+ 1. Silhouette Score:
5088
+ The silhouette score measures how similar an object is to its own cluster (cohesion) compared to
5089
+ how similar it is to other clusters (separation). The score ranges from -1 to +1:
5090
+ +1: Indicates that the data point is very far from the neighboring clusters and well clustered.
5091
+ 0: Indicates that the data point is on or very close to the decision boundary between two neighboring
5092
+ clusters.
5093
+ -1: Indicates that the data point might have been assigned to the wrong cluster.
5094
+
5095
+ Interpretation:
5096
+ A higher average silhouette score indicates better-defined clusters.
5097
+ If the score is consistently high (above 0.5), it suggests that the clusters are well separated.
5098
+ A score near 0 may indicate overlapping clusters, while negative scores suggest points may have
5099
+ been misclassified.
5100
+
5101
+ 2. Davies-Bouldin Index:
5102
+ The Davies-Bouldin Index (DBI) measures the average similarity ratio of each cluster with its
5103
+ most similar cluster. The index values range from 0 to ∞, with lower values indicating better clustering.
5104
+ It is defined as the ratio of within-cluster distances to between-cluster distances.
5105
+
5106
+ Interpretation:
5107
+ A lower DBI value indicates that the clusters are compact and well-separated.
5108
+ Ideally, you want to minimize the Davies-Bouldin Index. If your DBI value is above 1, this indicates
5109
+ that your clusters might not be well-separated.
5110
+
5111
+ 3. Adjusted Rand Index (ARI):
5112
+ The Adjusted Rand Index (ARI) is a measure of the similarity between two data clusterings. The ARI
5113
+ score ranges from -1 to +1:
5114
+ 1: Indicates perfect agreement between the two clusterings.
5115
+ 0: Indicates that the clusterings are no better than random.
5116
+ Negative values: Indicate less agreement than expected by chance.
5117
+
5118
+ Interpretation:
5119
+ A higher ARI score indicates better clustering, particularly if it's close to 1.
5120
+ An ARI score of 0 or lower suggests that the clustering results do not represent the true labels
5121
+ well, indicating a poor clustering performance.
5122
+
5123
+ 4. Calinski-Harabasz Index:
5124
+ The Calinski-Harabasz Index (also known as the Variance Ratio Criterion) evaluates the ratio of
5125
+ the sum of between-cluster dispersion to within-cluster dispersion. Higher values indicate better clustering.
5126
+
5127
+ Interpretation:
5128
+ A higher Calinski-Harabasz Index suggests that clusters are dense and well-separated. It is typically
5129
+ used to validate the number of clusters, with higher values favoring more distinct clusters.
5130
+
5131
+ 5. Homogeneity Score:
5132
+ The homogeneity score measures how much a cluster contains only members of a single class (if true labels are provided).
5133
+ A score of 1 indicates perfect homogeneity, where all points in a cluster belong to the same class.
5134
+
5135
+ Interpretation:
5136
+ A higher homogeneity score indicates that the clustering result is pure, meaning the clusters are composed
5137
+ of similar members. Lower values indicate mixed clusters, suggesting poor clustering performance.
5138
+
5139
+ 6. Completeness Score:
5140
+ The completeness score evaluates how well all members of a given class are assigned to the same cluster.
5141
+ A score of 1 indicates perfect completeness, meaning all points in a true class are assigned to a single cluster.
5142
+
5143
+ Interpretation:
5144
+ A higher completeness score indicates that the clustering effectively groups all instances of a class together.
5145
+ Lower values suggest that some instances of a class are dispersed among multiple clusters.
5146
+
5147
+ 7. V-Measure:
5148
+ The V-measure is the harmonic mean of homogeneity and completeness, giving a balanced measure of clustering performance.
5149
+
5150
+ Interpretation:
5151
+ A higher V-measure score indicates that the clusters are both homogenous (pure) and complete (cover all members of a class).
5152
+ Scores closer to 1 indicate better clustering quality.
5153
+ """
5154
+ from sklearn.metrics import (
5155
+ silhouette_score,
5156
+ davies_bouldin_score,
5157
+ adjusted_rand_score,
5158
+ calinski_harabasz_score,
5159
+ homogeneity_score,
5160
+ completeness_score,
5161
+ v_measure_score,
5162
+ )
5163
+
5164
+ metrics = {}
5165
+ unique_labels = set(labels)
5166
+ if len(unique_labels) > 1 and len(unique_labels) < len(data):
5167
+ # Calculate Silhouette Score
5168
+ try:
5169
+ metrics["Silhouette Score"] = silhouette_score(data, labels)
5170
+ except Exception as e:
5171
+ metrics["Silhouette Score"] = np.nan
5172
+ print(f"Silhouette Score calculation failed: {e}")
5173
+
5174
+ # Calculate Davies-Bouldin Index
5175
+ try:
5176
+ metrics["Davies-Bouldin Index"] = davies_bouldin_score(data, labels)
5177
+ except Exception as e:
5178
+ metrics["Davies-Bouldin Index"] = np.nan
5179
+ print(f"Davies-Bouldin Index calculation failed: {e}")
5180
+
5181
+ # Calculate Calinski-Harabasz Index
5182
+ try:
5183
+ metrics["Calinski-Harabasz Index"] = calinski_harabasz_score(data, labels)
5184
+ except Exception as e:
5185
+ metrics["Calinski-Harabasz Index"] = np.nan
5186
+ print(f"Calinski-Harabasz Index calculation failed: {e}")
5187
+
5188
+ # Calculate Adjusted Rand Index if true labels are provided
5189
+ if true_labels is not None:
5190
+ try:
5191
+ metrics["Adjusted Rand Index"] = adjusted_rand_score(
5192
+ true_labels, labels
5193
+ )
5194
+ except Exception as e:
5195
+ metrics["Adjusted Rand Index"] = np.nan
5196
+ print(f"Adjusted Rand Index calculation failed: {e}")
5197
+
5198
+ # Calculate Homogeneity Score
5199
+ try:
5200
+ metrics["Homogeneity Score"] = homogeneity_score(true_labels, labels)
5201
+ except Exception as e:
5202
+ metrics["Homogeneity Score"] = np.nan
5203
+ print(f"Homogeneity Score calculation failed: {e}")
5204
+
5205
+ # Calculate Completeness Score
5206
+ try:
5207
+ metrics["Completeness Score"] = completeness_score(true_labels, labels)
5208
+ except Exception as e:
5209
+ metrics["Completeness Score"] = np.nan
5210
+ print(f"Completeness Score calculation failed: {e}")
5211
+
5212
+ # Calculate V-Measure
5213
+ try:
5214
+ metrics["V-Measure"] = v_measure_score(true_labels, labels)
5215
+ except Exception as e:
5216
+ metrics["V-Measure"] = np.nan
5217
+ print(f"V-Measure calculation failed: {e}")
5218
+ else:
5219
+ # Metrics cannot be computed with 1 cluster or all points as noise
5220
+ metrics["Silhouette Score"] = np.nan
5221
+ metrics["Davies-Bouldin Index"] = np.nan
5222
+ metrics["Calinski-Harabasz Index"] = np.nan
5223
+ if true_labels is not None:
5224
+ metrics["Adjusted Rand Index"] = np.nan
5225
+ metrics["Homogeneity Score"] = np.nan
5226
+ metrics["Completeness Score"] = np.nan
5227
+ metrics["V-Measure"] = np.nan
5228
+
5229
+ return metrics
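
# Illustrative sketch (not from the package): evaluate_cluster and plot_cluster as
# defined above, exercised end to end on labelled blob data. The import path is an
# assumption based on this file's location.
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from py2ls.ips import evaluate_cluster, plot_cluster  # assumed import path

X_blob, y_true = make_blobs(n_samples=500, centers=3, n_features=2, random_state=0)
df_blob = pd.DataFrame(X_blob, columns=["Component 1", "Component 2"])
y_pred = KMeans(n_clusters=3, random_state=0, n_init=10).fit_predict(df_blob)

scores = evaluate_cluster(df_blob, y_pred, true_labels=y_true)
print({k: round(v, 3) for k, v in scores.items()})
plot_cluster(df_blob, y_pred, metrics=scores, true_labels=y_true)
plt.show()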