py2ls 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py2ls/.DS_Store +0 -0
- py2ls/data/sns_info.json +74 -0
- py2ls/data/usages_pd.json +56 -0
- py2ls/data/usages_sns.json +25 -0
- py2ls/ips.py +1253 -517
- py2ls/plot.py +746 -30
- py2ls/stats.py +18 -9
- py2ls/update2usage.py +126 -0
- {py2ls-0.2.1.dist-info → py2ls-0.2.3.dist-info}/METADATA +1 -1
- {py2ls-0.2.1.dist-info → py2ls-0.2.3.dist-info}/RECORD +11 -7
- {py2ls-0.2.1.dist-info → py2ls-0.2.3.dist-info}/WHEEL +0 -0
py2ls/ips.py
CHANGED
@@ -793,6 +793,99 @@ def str2html(text_list, strict=False):
     return html_content


+def cm2px(*cm, dpi=300) -> list:
+    # Case 1: When the user passes a single argument that is a list or tuple, such as cm2px([8, 5]) or inch2cm((8, 5))
+    if len(cm) == 1 and isinstance(cm[0], (list, tuple)):
+        # If the input is a single list or tuple, we unpack its elements and convert each to cm
+        return [i / 2.54 * dpi for i in cm[0]]
+    # Case 2: When the user passes multiple arguments directly, such as cm2px(8, 5)
+    else:
+        return [i / 2.54 * dpi for i in cm]
+
+
+def px2cm(*px, dpi=300) -> list:
+    # Case 1: When the user passes a single argument that is a list or tuple, such as px2cm([8, 5]) or inch2cm((8, 5))
+    if len(px) == 1 and isinstance(px[0], (list, tuple)):
+        # If the input is a single list or tuple, we unpack its elements and convert each to cm
+        return [i * 2.54 / dpi for i in px[0]]
+    # Case 2: When the user passes multiple arguments directly, such as px2cm(8, 5)
+    else:
+        # Here, we convert each individual argument directly to cm
+        return [i * 2.54 / dpi for i in px]
+
+
+def px2inch(*px, dpi=300) -> list:
+    """
+    px2inch: converts pixel measurements to inches based on the given dpi.
+    Usage:
+    px2inch(300, 600, dpi=300); px2inch([300, 600], dpi=300)
+    Returns:
+        list: in inches
+    """
+    # Case 1: When the user passes a single argument that is a list or tuple, such as px2inch([300, 600]) or px2inch((300, 600))
+    if len(px) == 1 and isinstance(px[0], (list, tuple)):
+        # If the input is a single list or tuple, we unpack its elements and convert each to inches
+        return [i / dpi for i in px[0]]
+    # Case 2: When the user passes multiple arguments directly, such as px2inch(300, 600)
+    else:
+        # Here, we convert each individual argument directly to inches
+        return [i / dpi for i in px]
+
+
+def cm2inch(*cm) -> list:
+    """
+    cm2inch: converts centimeter measurements to inches.
+    Usage:
+    cm2inch(10, 12.7); cm2inch((10, 12.7)); cm2inch([10, 12.7])
+    Returns:
+        list: in inches
+    """
+    # Case 1: When the user passes a single argument that is a list or tuple, such as cm2inch([10, 12.7]) or cm2inch((10, 12.7))
+    if len(cm) == 1 and isinstance(cm[0], (list, tuple)):
+        # If the input is a single list or tuple, we unpack its elements and convert each to inches
+        return [i * 2.54 for i in cm[0]]
+    # Case 2: When the user passes multiple arguments directly, such as cm2inch(10, 12.7)
+    else:
+        # Here, we convert each individual argument directly to inches
+        return [i * 2.54 for i in cm]
+
+
+def inch2px(*inch, dpi=300) -> list:
+    """
+    inch2px: converts inch measurements to pixels based on the given dpi.
+    Usage:
+    inch2px(1, 2, dpi=300); inch2px([1, 2], dpi=300)
+    Returns:
+        list: in pixels
+    """
+    # Case 1: When the user passes a single argument that is a list or tuple, such as inch2px([1, 2]) or inch2px((1, 2))
+    if len(inch) == 1 and isinstance(inch[0], (list, tuple)):
+        # If the input is a single list or tuple, we unpack its elements and convert each to pixels
+        return [i * dpi for i in inch[0]]
+    # Case 2: When the user passes multiple arguments directly, such as inch2px(1, 2)
+    else:
+        # Here, we convert each individual argument directly to pixels
+        return [i * dpi for i in inch]
+
+
+def inch2cm(*inch) -> list:
+    """
+    inch2cm: converts inch measurements to centimeters.
+    Usage:
+    inch2cm(8,5); inch2cm((8,5)); inch2cm([8,5])
+    Returns:
+        list: in centimeters
+    """
+    # Case 1: When the user passes a single argument that is a list or tuple, such as inch2cm([8, 5]) or inch2cm((8, 5))
+    if len(inch) == 1 and isinstance(inch[0], (list, tuple)):
+        # If the input is a single list or tuple, we unpack its elements and convert each to cm
+        return [i / 2.54 for i in inch[0]]
+    # Case 2: When the user passes multiple arguments directly, such as inch2cm(8, 5)
+    else:
+        # Here, we convert each individual argument directly to cm
+        return [i / 2.54 for i in inch]
+
+
 def sreplace(*args, **kwargs):
     """
     sreplace(text, by=None, robust=True)
@@ -1363,6 +1456,80 @@ def unzip(dir_path, output_dir=None):
     # output_dir_7z = unzip('archive.7z')


+def is_df_abnormal(df: pd.DataFrame, verbose=False) -> bool:
+    """
+    Usage
+    is_abnormal = is_df_abnormal(df, verbose=1)
+
+    """
+    # Initialize a list to hold messages about abnormalities
+    messages = []
+    is_abnormal = False
+    # Check the shape of the DataFrame
+    actual_shape = df.shape
+    messages.append(f"Shape of DataFrame: {actual_shape}")
+
+    # Check column names
+    column_names = df.columns.tolist()
+
+    # Count of delimiters and their occurrences
+    delimiter_counts = {"\t": 0, ",": 0, "\n": 0, "": 0}  # Count of empty strings
+
+    for name in column_names:
+        # Count occurrences of each delimiter
+        delimiter_counts["\t"] += name.count("\t")
+        delimiter_counts[","] += name.count(",")
+        delimiter_counts["\n"] += name.count("\n")
+        if name.strip() == "":
+            delimiter_counts[""] += 1
+
+    # Check for abnormalities based on delimiter counts
+    if len(column_names) == 1 and delimiter_counts["\t"] > 1:
+        messages.append("Abnormal: Column names are not split correctly.")
+        is_abnormal = True
+
+    if any(delimiter_counts[d] > 3 for d in delimiter_counts if d != ""):
+        messages.append("Abnormal: Too many delimiters in column names.")
+        is_abnormal = True
+
+    if delimiter_counts[""] > 3:
+        messages.append("Abnormal: There are empty column names.")
+        is_abnormal = True
+
+    if any(delimiter_counts[d] > 3 for d in ["\t", ",", "\n"]):
+        messages.append("Abnormal: Some column names contain unexpected characters.")
+        is_abnormal = True
+
+    # Check for missing values
+    missing_values = df.isnull().sum()
+    if missing_values.any():
+        messages.append("Missing values in columns:")
+        messages.append(missing_values[missing_values > 0].to_string())
+        is_abnormal = True
+
+    # Check data types
+    data_types = df.dtypes
+    # messages.append(f"Data types of columns:\n{data_types}")
+
+    # Check for constant values across any column
+    constant_columns = df.columns[df.nunique() == 1].tolist()
+    if constant_columns:
+        messages.append(f"Abnormal: Columns with constant values: {constant_columns}")
+        is_abnormal = True
+
+    # Check for an unreasonable number of rows or columns
+    if actual_shape[0] < 2 or actual_shape[1] < 2:
+        messages.append(
+            "Abnormal: DataFrame is too small (less than 2 rows or columns)."
+        )
+        is_abnormal = True
+
+    # Compile results
+    if verbose:
+        print("\n".join(messages))
+    return is_abnormal  # Data is abnormal
+
+
 def fload(fpath, kind=None, **kwargs):
     """
     Load content from a file with specified file type.
@@ -1384,10 +1551,14 @@ def fload(fpath, kind=None, **kwargs):
            content = file.read()
        return content

-    def load_json(fpath):
-
-
-
+    def load_json(fpath, **kwargs):
+        output=kwargs.pop("output","json")
+        if output=='json':
+            with open(fpath, "r") as file:
+                content = json.load(file)
+            return content
+        else:
+            return pd.read_json(fpath,**kwargs)

    def load_yaml(fpath):
        with open(fpath, "r") as file:
@@ -1399,9 +1570,43 @@ def fload(fpath, kind=None, **kwargs):
        root = tree.getroot()
        return etree.tostring(root, pretty_print=True).decode()

+    def get_comment(fpath, comment=None, encoding="utf-8", lines_to_check=5):
+        """
+        Detect comment characters in a file.
+
+        Parameters:
+        - fpath: str, the file path of the CSV file.
+        - encoding: str, the encoding of the file (default is 'utf-8').
+        - lines_to_check: int, number of lines to check for comment characters (default is 5).
+
+        Returns:
+        - str or None: the detected comment character, or None if no comment character is found.
+        """
+        comment_chars = [
+            "#",
+            "!",
+            "//",
+            ";",
+        ]  # can use any character or string as a comment
+        try:
+            with open(fpath, "r", encoding=encoding) as f:
+                lines = [next(f) for _ in range(lines_to_check)]
+        except (UnicodeDecodeError, ValueError):
+            with open(fpath, "r", encoding=get_encoding(fpath)) as f:
+                lines = [next(f) for _ in range(lines_to_check)]
+        for line in lines:
+            for char in comment_chars:
+                if line.startswith(char):
+                    return char
+        return None
+
    def load_csv(fpath, **kwargs):
+        from pandas.errors import EmptyDataError
+
        engine = kwargs.get("engine", "pyarrow")
        kwargs.pop("engine", None)
+        sep = kwargs.get("sep", "\t")
+        kwargs.pop("sep", None)
        index_col = kwargs.get("index_col", None)
        kwargs.pop("index_col", None)
        memory_map = kwargs.get("memory_map", True)
@@ -1410,53 +1615,146 @@ def fload(fpath, kind=None, **kwargs):
        kwargs.pop("skipinitialspace", None)
        encoding = kwargs.get("encoding", "utf-8")
        kwargs.pop("encoding", None)
+        on_bad_lines = kwargs.get("on_bad_lines", "skip")
+        kwargs.pop("on_bad_lines", None)
+        comment = kwargs.get("comment", None)
+        kwargs.pop("comment", None)
+
+        fmt=kwargs.pop("fmt",False)
+        if verbose:
+            print_pd_usage("read_csv", verbose=verbose)
+            return
+
+        if comment is None:
+            comment = get_comment(
+                fpath, comment=None, encoding="utf-8", lines_to_check=5
+            )
+
        try:
- (41 lines of the previous load_csv implementation removed; their content is not shown in this view)
+            df = pd.read_csv(
+                fpath,
+                engine=engine,
+                index_col=index_col,
+                memory_map=memory_map,
+                encoding=encoding,
+                comment=comment,
+                skipinitialspace=skipinitialspace,
+                sep=sep,
+                on_bad_lines=on_bad_lines,
+                **kwargs,
+            )
+        except:
+            try:
+                try:
+                    if engine == "pyarrow":
+                        df = pd.read_csv(
+                            fpath,
+                            engine=engine,
+                            index_col=index_col,
+                            encoding=encoding,
+                            sep=sep,
+                            on_bad_lines=on_bad_lines,
+                            comment=comment,
+                            **kwargs,
+                        )
+                    else:
+                        df = pd.read_csv(
+                            fpath,
+                            engine=engine,
+                            index_col=index_col,
+                            memory_map=memory_map,
+                            encoding=encoding,
+                            sep=sep,
+                            skipinitialspace=skipinitialspace,
+                            on_bad_lines=on_bad_lines,
+                            comment=comment,
+                            **kwargs,
+                        )
+
+                    if is_df_abnormal(df, verbose=0):
+                        raise ValueError("the df is abnormal")
+                except (UnicodeDecodeError, ValueError):
+                    encoding = get_encoding(fpath)
+                    # print(f"utf-8 failed. Retrying with detected encoding: {encoding}")
+                    if engine == "pyarrow":
+                        df = pd.read_csv(
+                            fpath,
+                            engine=engine,
+                            index_col=index_col,
+                            encoding=encoding,
+                            sep=sep,
+                            on_bad_lines=on_bad_lines,
+                            comment=comment,
+                            **kwargs,
+                        )
+                    else:
+                        df = pd.read_csv(
+                            fpath,
+                            engine=engine,
+                            index_col=index_col,
+                            memory_map=memory_map,
+                            encoding=encoding,
+                            sep=sep,
+                            skipinitialspace=skipinitialspace,
+                            on_bad_lines=on_bad_lines,
+                            comment=comment,
+                            **kwargs,
+                        )
+                    if is_df_abnormal(df, verbose=0):
+                        raise ValueError("the df is abnormal")
+            except Exception as e:
+                separators = [",", "\t", ";", "|", " "]
+                for sep in separators:
+                    sep2show = sep if sep != "\t" else "\\t"
+                    # print(f'trying with: engine=pyarrow, sep="{sep2show}"')
+                    try:
+                        df = pd.read_csv(
+                            fpath,
+                            engine=engine,
+                            skipinitialspace=skipinitialspace,
+                            sep=sep,
+                            on_bad_lines=on_bad_lines,
+                            comment=comment,
+                            **kwargs,
+                        )
+                        if not is_df_abnormal(df, verbose=0):  # normal
+                            break
+                        else:
+                            if is_df_abnormal(df, verbose=0):
+                                pass
+                    except:
+                        pass
+                else:
+                    engines = ["c", "python"]
+                    for engine in engines:
+                        # separators = [",", "\t", ";", "|", " "]
+                        for sep in separators:
+                            try:
+                                sep2show = sep if sep != "\t" else "\\t"
+                                # print(f"trying with: engine={engine}, sep='{sep2show}'")
+                                df = pd.read_csv(
+                                    fpath,
+                                    engine=engine,
+                                    sep=sep,
+                                    on_bad_lines=on_bad_lines,
+                                    comment=comment,
+                                    **kwargs,
+                                )
+                                if not is_df_abnormal(df, verbose=0):
+                                    break
+                            except EmptyDataError as e:
+                                continue
+                            else:
+                                pass
+        display(df.head(2))
+        print(f"shape: {df.shape}")
        return df

    def load_xlsx(fpath, **kwargs):
        engine = kwargs.get("engine", "openpyxl")
-        kwargs.pop("
+        verbose=kwargs.pop("verbose",False)
+        if verbose:
+            print_pd_usage("read_excel", verbose=verbose)
        df = pd.read_excel(fpath, engine=engine, **kwargs)
        return df

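The retry cascade above reduces to one idea: try the default engine and separator first, then re-detect the encoding, then sweep candidate separators and engines until is_df_abnormal stops flagging the result. A compact standalone sketch of that pattern, not the library code itself:

    import pandas as pd

    def read_csv_robust(fpath, seps=(",", "\t", ";", "|", " "), engines=("pyarrow", "c", "python")):
        # try each engine/separator combination until a plausible DataFrame appears
        for engine in engines:
            for sep in seps:
                try:
                    df = pd.read_csv(fpath, engine=engine, sep=sep)
                except Exception:
                    continue
                if not is_df_abnormal(df, verbose=0):
                    return df
        raise ValueError(f"could not parse {fpath} with any separator/engine")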
@@ -1526,7 +1824,6 @@ def fload(fpath, kind=None, **kwargs):
    if kind is None:
        _, kind = os.path.splitext(fpath)
        kind = kind.lower()
-
    kind = kind.lstrip(".").lower()
    img_types = [
        "bmp",
@@ -1575,9 +1872,12 @@ def fload(fpath, kind=None, **kwargs):
        "rar",
        "tgz",
    ]
-
+    other_types = ["fcs"]
+    supported_types = [*doc_types, *img_types, *zip_types, *other_types]
    if kind not in supported_types:
-        print(
+        print(
+            f'Warning:\n"{kind}" is not in the supported list '
+        )  # {supported_types}')
    # if os.path.splitext(fpath)[1][1:].lower() in zip_types:
    #     keep=kwargs.get("keep", False)
    #     ifile=kwargs.get("ifile",(0,0))
@@ -1607,7 +1907,8 @@ def fload(fpath, kind=None, **kwargs):
    elif kind == "xml":
        return load_xml(fpath)
    elif kind == "csv":
-
+        content = load_csv(fpath, **kwargs)
+        return content
    elif kind in ["ods", "ods", "odt"]:
        engine = kwargs.get("engine", "odf")
        kwargs.pop("engine", None)
@@ -1615,9 +1916,13 @@ def fload(fpath, kind=None, **kwargs):
    elif kind == "xls":
        engine = kwargs.get("engine", "xlrd")
        kwargs.pop("engine", None)
-
+        content = load_xlsx(fpath, engine=engine, **kwargs)
+        display(content.head(2))
+        return content
    elif kind == "xlsx":
-
+        content = load_xlsx(fpath, **kwargs)
+        display(content.head(2))
+        return content
    elif kind == "ipynb":
        return load_ipynb(fpath, **kwargs)
    elif kind == "pdf":
@@ -1653,7 +1958,18 @@ def fload(fpath, kind=None, **kwargs):

        gene_sets = gp.read_gmt(fpath)
        return gene_sets
+
+    elif kind.lower() == "fcs":
+        import fcsparser
+
+        # https://github.com/eyurtsev/fcsparser
+        meta, data = fcsparser.parse(fpath, reformat_meta=True)
+        return meta, data
+
    else:
+        # try:
+        #     content = load_csv(fpath, **kwargs)
+        # except:
        try:
            try:
                with open(fpath, "r", encoding="utf-8") as f:
@@ -2126,13 +2442,23 @@ def get_os():

def listdir(
    rootdir,
-    kind=
+    kind=None,
    sort_by="name",
    ascending=True,
    contains=None,
    orient="list",
    output="df",  # 'list','dict','records','index','series'
+    verbose=True,
):
+    if kind is None:
+        df_all = pd.DataFrame(
+            {
+                "fname": os.listdir(rootdir),
+                "fpath": [os.path.join(rootdir, i) for i in os.listdir(rootdir)],
+            }
+        )
+        display(df_all)
+        return df_all
    if isinstance(kind, list):
        f_ = []
        for kind_ in kind:
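With the new default kind=None, listdir now tabulates everything in the directory instead of requiring an extension filter. A hedged usage sketch with hypothetical paths:

    # list everything (a DataFrame with 'fname' and 'fpath' columns is displayed and returned)
    df_all = listdir("/tmp")
    # previous-style call: filter by extension and sort
    df_csv = listdir("/tmp", kind=".csv", sort_by="size", ascending=False)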
@@ -2204,7 +2530,7 @@ def listdir(

        f["num"] = i
        f["rootdir"] = rootdir
-        f["os"] = os.uname().machine
+        f["os"] = get_os()  # os.uname().machine
    else:
        raise FileNotFoundError(
            'The directory "{}" does NOT exist. Please check the directory "rootdir".'.format(
@@ -2227,6 +2553,8 @@ def listdir(
        f = sort_kind(f, by="size", ascending=ascending)

    if "df" in output:
+        if verbose:
+            display(f.head())
        return f
    else:
        if "l" in orient.lower():  # list  # default
@@ -2259,32 +2587,54 @@ def func_list(lib_name, opt="call"):
    return list_func(lib_name, opt=opt)


-def
+def mkdir_nest(fpath: str) -> str:
    """
- (5 removed lines not shown in this view)
+    Create nested directories based on the provided file path.
+
+    Parameters:
+    - fpath (str): The full file path for which the directories should be created.
+
    Returns:
-
+    - str: The path of the created directory.
    """
- (14 removed lines not shown in this view)
+    # Check if the directory already exists
+    if os.path.isdir(fpath):
+        return fpath
+
+    # Split the full path into directories
+    f_slash = "/" if "mac" in get_os().lower() else "\\"
+    dir_parts = fpath.split(f_slash)  # Split the path by the OS-specific separator
+
+    # Start creating directories from the root to the desired path
+    current_path = ""
+    for part in dir_parts:
+        if part:
+            current_path = os.path.join(current_path, part)
+            if not os.path.isdir(current_path):
+                os.makedirs(current_path)
+    if not current_path.endswith(f_slash):
+        current_path += f_slash
+    return current_path
+
+
+def mkdir(pardir: str = None, chdir: str | list = None, overwrite=False):
+    """
+    Create a directory.
+
+    Parameters:
+    - pardir (str): Parent directory where the new directory will be created. If None, uses the current working directory.
+    - chdir (str | list): Name of the new directory or a list of directories to create.
+                          If None, a default name 'new_directory' will be used.
+    - overwrite (bool): If True, overwrite the directory if it already exists. Defaults to False.
+
+    Returns:
+    - str: The path of the created directory or an error message.
+    """
+
    rootdir = []
    # Convert string to list
+    if chdir is None:
+        return mkdir_nest(pardir)
    if isinstance(chdir, str):
        chdir = [chdir]
    # Subfoldername should be unique
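mkdir with chdir=None now delegates to mkdir_nest, which walks the path segment by segment and creates whatever is missing. A hedged sketch with hypothetical paths:

    # create a nested tree in one call; the created path comes back with a trailing separator
    out_dir = mkdir("/tmp/project/results/figs")
    # equivalent to calling the new helper directly
    out_dir = mkdir_nest("/tmp/project/results/figs")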
@@ -2331,54 +2681,111 @@ def mkdir(*args, **kwargs):
    return rootdir


+def split_path(fpath):
+    f_slash = "/" if "mac" in get_os().lower() else "\\"
+    dir_par = f_slash.join(fpath.split(f_slash)[:-1])
+    dir_ch = "".join(fpath.split(f_slash)[-1:])
+    return dir_par, dir_ch
+
+
def figsave(*args, dpi=300):
    dir_save = None
    fname = None
+    img = None
    for arg in args:
        if isinstance(arg, str):
            if "/" in arg or "\\" in arg:
                dir_save = arg
            elif "/" not in arg and "\\" not in arg:
                fname = arg
- (9 removed lines not shown in this view)
+        elif isinstance(arg, (Image.Image, np.ndarray)):
+            img = arg  # Store the PIL image if provided
+
+    f_slash = "/" if "mac" in get_os().lower() else "\\"
+    dir_par = f_slash.join(dir_save.split(f_slash)[:-1])
+    dir_ch = "".join(dir_save.split(f_slash)[-1:])
+    if not dir_par.endswith(f_slash):
+        dir_par += f_slash
+    if fname is None:
+        fname = dir_ch
+    mkdir(dir_par)
    ftype = fname.split(".")[-1]
    if len(fname.split(".")) == 1:
        ftype = "nofmt"
-        fname =
+        fname = dir_par + fname + "." + ftype
    else:
-        fname =
+        fname = dir_par + fname
+
    # Save figure based on file type
    if ftype.lower() == "eps":
        plt.savefig(fname, format="eps", bbox_inches="tight")
        plt.savefig(
-            fname.replace(".eps", ".pdf"),
+            fname.replace(".eps", ".pdf"),
+            format="pdf",
+            bbox_inches="tight",
+            dpi=dpi,
+            pad_inches=0,
        )
    elif ftype.lower() == "nofmt":  # default: both "tif" and "pdf"
        fname_corr = fname.replace("nofmt", "pdf")
-        plt.savefig(
+        plt.savefig(
+            fname_corr, format="pdf", bbox_inches="tight", dpi=dpi, pad_inches=0
+        )
        fname = fname.replace("nofmt", "tif")
-        plt.savefig(fname, format="tiff", dpi=dpi, bbox_inches="tight")
+        plt.savefig(fname, format="tiff", dpi=dpi, bbox_inches="tight", pad_inches=0)
        print(f"default saving filetype: both 'tif' and 'pdf")
    elif ftype.lower() == "pdf":
-        plt.savefig(fname, format="pdf", bbox_inches="tight", dpi=dpi)
-    elif ftype.lower() in ["jpg", "jpeg"]:
- (5 removed lines not shown in this view)
+        plt.savefig(fname, format="pdf", bbox_inches="tight", dpi=dpi, pad_inches=0)
+    elif ftype.lower() in ["jpg", "jpeg", "png", "tiff", "tif"]:
+        if img is not None:  # If a PIL image is provided
+            if isinstance(img, Image.Image):
+                if img.mode == "RGBA":
+                    img = img.convert("RGB")
+                img.save(fname, format=ftype.upper(), dpi=(dpi, dpi))
+            elif isinstance(img, np.ndarray):
+                import cv2
+
+                # Check the shape of the image to determine color mode
+                if img.ndim == 2:
+                    # Grayscale image
+                    Image.fromarray(img).save(
+                        fname, format=ftype.upper(), dpi=(dpi, dpi)
+                    )
+                elif img.ndim == 3:
+                    if img.shape[2] == 3:
+                        # RGB image
+                        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+                        Image.fromarray(img).save(
+                            fname, format=ftype.upper(), dpi=(dpi, dpi)
+                        )
+                    elif img.shape[2] == 4:
+                        # RGBA image
+                        img = cv2.cvtColor(
+                            img, cv2.COLOR_BGRA2RGBA
+                        )  # Convert BGRA to RGBA
+                        Image.fromarray(img).save(
+                            fname, format=ftype.upper(), dpi=(dpi, dpi)
+                        )
+                    else:
+                        raise ValueError(
+                            "Unexpected number of channels in the image array."
+                        )
+                else:
+                    raise ValueError(
+                        "Image array has an unexpected number of dimensions."
+                    )
+        else:
+            plt.savefig(
+                fname, format=ftype.lower(), dpi=dpi, bbox_inches="tight", pad_inches=0
+            )
+    # elif ftype.lower() == "png":
+    #     plt.savefig(fname, format="png", dpi=dpi, bbox_inches="tight", transparent=True,pad_inches=0)
+    # elif ftype.lower() in ["tiff", "tif"]:
+    #     plt.savefig(fname, format="tiff", dpi=dpi, bbox_inches="tight",pad_inches=0)
    elif ftype.lower() == "emf":
-        plt.savefig(fname, format="emf", dpi=dpi, bbox_inches="tight")
+        plt.savefig(fname, format="emf", dpi=dpi, bbox_inches="tight", pad_inches=0)
    elif ftype.lower() == "fig":
-        plt.savefig(fname, format="pdf", bbox_inches="tight", dpi=dpi)
+        plt.savefig(fname, format="pdf", bbox_inches="tight", dpi=dpi, pad_inches=0)
    print(f"\nSaved @: dpi={dpi}\n{fname}")

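figsave now accepts an optional PIL image or NumPy array alongside the directory and file name; for raster formats it writes the image directly (converting BGR(A) to RGB(A)), otherwise it falls back to plt.savefig. A hedged usage sketch with hypothetical paths:

    import numpy as np
    arr = np.zeros((100, 100, 3), dtype=np.uint8)  # dummy BGR image array
    figsave("/tmp/figs/", "example.png", arr, dpi=300)
    # with no image argument, the current matplotlib figure is saved instead
    figsave("/tmp/figs/", "current_plot.pdf")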
@@ -2510,6 +2917,8 @@ def load_img(fpath):
        FileNotFoundError: If the specified file is not found.
        OSError: If the specified file cannot be opened or is not a valid image file.
    """
+    from PIL import Image
+
    try:
        img = Image.open(fpath)
        return img
@@ -2650,393 +3059,6 @@ def apply_filter(img, *args):
    return img.filter(supported_filters[filter_name])


-def imgsetss(
-    img,
-    sets=None,
-    show=True,
-    show_axis=False,
-    size=None,
-    dpi=100,
-    figsize=None,
-    auto=False,
-    filter_kws=None,
-):
-    """
-    Apply various enhancements and filters to an image using PIL's ImageEnhance and ImageFilter modules.
-
-    Args:
-        img (PIL.Image): The input image.
-        sets (dict): A dictionary specifying the enhancements, filters, and their parameters.
-        show (bool): Whether to display the enhanced image.
-        show_axis (bool): Whether to display axes on the image plot.
-        size (tuple): The size of the thumbnail, cover, contain, or fit operation.
-        dpi (int): Dots per inch for the displayed image.
-        figsize (tuple): The size of the figure for displaying the image.
-        auto (bool): Whether to automatically enhance the image based on its characteristics.
-
-    Returns:
-        PIL.Image: The enhanced image.
-
-    Supported enhancements and filters:
-        - "sharpness": Adjusts the sharpness of the image. Values > 1 increase sharpness, while values < 1 decrease sharpness.
-        - "contrast": Adjusts the contrast of the image. Values > 1 increase contrast, while values < 1 decrease contrast.
-        - "brightness": Adjusts the brightness of the image. Values > 1 increase brightness, while values < 1 decrease brightness.
-        - "color": Adjusts the color saturation of the image. Values > 1 increase saturation, while values < 1 decrease saturation.
-        - "rotate": Rotates the image by the specified angle.
-        - "crop" or "cut": Crops the image. The value should be a tuple specifying the crop box as (left, upper, right, lower).
-        - "size": Resizes the image to the specified dimensions.
-        - "thumbnail": Resizes the image to fit within the given size while preserving aspect ratio.
-        - "cover": Resizes and crops the image to fill the specified size.
-        - "contain": Resizes the image to fit within the specified size, adding borders if necessary.
-        - "fit": Resizes and pads the image to fit within the specified size.
-        - "filter": Applies various filters to the image (e.g., BLUR, CONTOUR, EDGE_ENHANCE).
-
-    Note:
-        The "color" and "enhance" enhancements are not implemented in this function.
-    """
-    supported_filters = [
-        "BLUR",
-        "CONTOUR",
-        "DETAIL",
-        "EDGE_ENHANCE",
-        "EDGE_ENHANCE_MORE",
-        "EMBOSS",
-        "FIND_EDGES",
-        "SHARPEN",
-        "SMOOTH",
-        "SMOOTH_MORE",
-        "MIN_FILTER",
-        "MAX_FILTER",
-        "MODE_FILTER",
-        "MULTIBAND_FILTER",
-        "GAUSSIAN_BLUR",
-        "BOX_BLUR",
-        "MEDIAN_FILTER",
-    ]
-    print(
-        "sets: a dict,'sharp:1.2','color','contrast:'auto' or 1.2','bright', 'crop: x_upperleft,y_upperleft, x_lowerright, y_lowerright','rotation','resize','rem or background'"
-    )
-    print(f"usage: filter_kws 'dict' below:")
-    pp([str(i).lower() for i in supported_filters])
-    print("\nlog:\n")
-
-    def confirm_rembg_models(model_name):
-        models_support = [
-            "u2net",
-            "u2netp",
-            "u2net_human_seg",
-            "u2net_cloth_seg",
-            "silueta",
-            "isnet-general-use",
-            "isnet-anime",
-            "sam",
-        ]
-        if model_name in models_support:
-            print(f"model_name: {model_name}")
-            return model_name
-        else:
-            print(
-                f"{model_name} cannot be found, check the name:{models_support}, default('isnet-general-use') has been used"
-            )
-            return "isnet-general-use"
-
-    def auto_enhance(img):
-        """
-        Automatically enhances the image based on its characteristics.
-        Args:
-            img (PIL.Image): The input image.
-        Returns:
-            dict: A dictionary containing the optimal enhancement values.
-        """
-        # Determine the bit depth based on the image mode
-        if img.mode in ["1", "L", "P", "RGB", "YCbCr", "LAB", "HSV"]:
-            # 8-bit depth per channel
-            bit_depth = 8
-        elif img.mode in ["RGBA", "CMYK"]:
-            # 8-bit depth per channel + alpha (RGBA) or additional channels (CMYK)
-            bit_depth = 8
-        elif img.mode in ["I", "F"]:
-            # 16-bit depth per channel (integer or floating-point)
-            bit_depth = 16
-        else:
-            raise ValueError("Unsupported image mode")
-        # Calculate the brightness and contrast for each channel
-        num_channels = len(img.getbands())
-        brightness_factors = []
-        contrast_factors = []
-        for channel in range(num_channels):
-            channel_histogram = img.split()[channel].histogram()
-            brightness = sum(i * w for i, w in enumerate(channel_histogram)) / sum(
-                channel_histogram
-            )
-            channel_min, channel_max = img.split()[channel].getextrema()
-            contrast = channel_max - channel_min
-            # Adjust calculations based on bit depth
-            normalization_factor = 2**bit_depth - 1  # Max value for the given bit depth
-            brightness_factor = (
-                1.0 + (brightness - normalization_factor / 2) / normalization_factor
-            )
-            contrast_factor = (
-                1.0 + (contrast - normalization_factor / 2) / normalization_factor
-            )
-            brightness_factors.append(brightness_factor)
-            contrast_factors.append(contrast_factor)
-        # Calculate the average brightness and contrast factors across channels
-        avg_brightness_factor = sum(brightness_factors) / num_channels
-        avg_contrast_factor = sum(contrast_factors) / num_channels
-        return {"brightness": avg_brightness_factor, "contrast": avg_contrast_factor}
-
-    # Load image if input is a file path
-    if isinstance(img, str):
-        img = load_img(img)
-    img_update = img.copy()
-    # Auto-enhance image if requested
-    if auto:
-        auto_params = auto_enhance(img_update)
-        sets.update(auto_params)
-    if sets is None:
-        sets = {}
-    for k, value in sets.items():
-        if "shar" in k.lower():
-            enhancer = ImageEnhance.Sharpness(img_update)
-            img_update = enhancer.enhance(value)
-        elif "col" in k.lower() and "bg" not in k.lower():
-            enhancer = ImageEnhance.Color(img_update)
-            img_update = enhancer.enhance(value)
-        elif "contr" in k.lower():
-            if value and isinstance(value, (float, int)):
-                enhancer = ImageEnhance.Contrast(img_update)
-                img_update = enhancer.enhance(value)
-            else:
-                print("autocontrasted")
-                img_update = ImageOps.autocontrast(img_update)
-        elif "bri" in k.lower():
-            enhancer = ImageEnhance.Brightness(img_update)
-            img_update = enhancer.enhance(value)
-        elif "cro" in k.lower() or "cut" in k.lower():
-            img_update = img_update.crop(value)
-        elif "rota" in k.lower():
-            img_update = img_update.rotate(value)
-        elif "si" in k.lower():
-            img_update = img_update.resize(value)
-        elif "thum" in k.lower():
-            img_update.thumbnail(value)
-        elif "cover" in k.lower():
-            img_update = ImageOps.cover(img_update, size=value)
-        elif "contain" in k.lower():
-            img_update = ImageOps.contain(img_update, size=value)
-        elif "fit" in k.lower():
-            img_update = ImageOps.fit(img_update, size=value)
-        elif "pad" in k.lower():
-            img_update = ImageOps.pad(img_update, size=value)
-        elif "rem" in k.lower() or "rm" in k.lower() or "back" in k.lower():
-            if value and isinstance(value, (int, float, list)):
-                print(
-                    'example usage: {"rm":[alpha_matting_background_threshold(20),alpha_matting_foreground_threshold(270),alpha_matting_erode_sive(11)]}'
-                )
-                print("https://github.com/danielgatis/rembg/blob/main/USAGE.md")
-                # ### Parameters:
-                #     data (Union[bytes, PILImage, np.ndarray]): The input image data.
-                #     alpha_matting (bool, optional): Flag indicating whether to use alpha matting. Defaults to False.
-                #     alpha_matting_foreground_threshold (int, optional): Foreground threshold for alpha matting. Defaults to 240.
-                #     alpha_matting_background_threshold (int, optional): Background threshold for alpha matting. Defaults to 10.
-                #     alpha_matting_erode_size (int, optional): Erosion size for alpha matting. Defaults to 10.
-                #     session (Optional[BaseSession], optional): A session object for the 'u2net' model. Defaults to None.
-                #     only_mask (bool, optional): Flag indicating whether to return only the binary masks. Defaults to False.
-                #     post_process_mask (bool, optional): Flag indicating whether to post-process the masks. Defaults to False.
-                #     bgcolor (Optional[Tuple[int, int, int, int]], optional): Background color for the cutout image. Defaults to None.
-                # ###
-                if isinstance(value, int):
-                    value = [value]
-                if len(value) < 2:
-                    img_update = remove(
-                        img_update,
-                        alpha_matting=True,
-                        alpha_matting_background_threshold=value,
-                    )
-                elif 2 <= len(value) < 3:
-                    img_update = remove(
-                        img_update,
-                        alpha_matting=True,
-                        alpha_matting_background_threshold=value[0],
-                        alpha_matting_foreground_threshold=value[1],
-                    )
-                elif 3 <= len(value) < 4:
-                    img_update = remove(
-                        img_update,
-                        alpha_matting=True,
-                        alpha_matting_background_threshold=value[0],
-                        alpha_matting_foreground_threshold=value[1],
-                        alpha_matting_erode_size=value[2],
-                    )
-            if isinstance(value, tuple):  # replace the background color
-                if len(value) == 3:
-                    value += (255,)
-                img_update = remove(img_update, bgcolor=value)
-            if isinstance(value, str):
-                if confirm_rembg_models(value):
-                    img_update = remove(img_update, session=new_session(value))
-                else:
-                    img_update = remove(img_update)
-        elif "bgcolor" in k.lower():
-            if isinstance(value, list):
-                value = tuple(value)
-            if isinstance(value, tuple):  # replace the background color
-                if len(value) == 3:
-                    value += (255,)
-                img_update = remove(img_update, bgcolor=value)
-    if filter_kws:
-        for filter_name, filter_value in filter_kws.items():
-            img_update = apply_filter(img_update, filter_name, filter_value)
-    # Display the image if requested
-    if show:
-        if figsize is None:
-            plt.figure(dpi=dpi)
-        else:
-            plt.figure(figsize=figsize, dpi=dpi)
-        plt.imshow(img_update)
-        plt.axis("on") if show_axis else plt.axis("off")
-    return img_update
-
-
-from sklearn.decomposition import PCA
-from skimage import transform, feature, filters, measure
-from skimage.color import rgb2gray
-from scipy.fftpack import fftshift, fft2
-import numpy as np
-import cv2  # Used for template matching
-
-
-def crop_black_borders(image):
-    """Crop the black borders from a rotated image."""
-    # Convert the image to grayscale if it's not already
-    if image.ndim == 3:
-        gray_image = color.rgb2gray(image)
-    else:
-        gray_image = image
-
-    # Find all the non-black (non-zero) pixels
-    mask = gray_image > 0  # Mask for non-black pixels (assuming black is zero)
-    coords = np.column_stack(np.where(mask))
-
-    # Get the bounding box of non-black pixels
-    if coords.any():  # Check if there are any non-black pixels
-        y_min, x_min = coords.min(axis=0)
-        y_max, x_max = coords.max(axis=0)
-
-        # Crop the image to the bounding box
-        cropped_image = image[y_min : y_max + 1, x_min : x_max + 1]
-    else:
-        # If the image is completely black (which shouldn't happen), return the original image
-        cropped_image = image
-
-    return cropped_image
-
-
-def detect_angle(image, by="median", template=None):
-    """Detect the angle of rotation using various methods."""
-    # Convert to grayscale
-    gray_image = rgb2gray(image)
-
-    # Detect edges using Canny edge detector
-    edges = feature.canny(gray_image, sigma=2)
-
-    # Use Hough transform to detect lines
-    lines = transform.probabilistic_hough_line(edges)
-
-    if not lines and any(["me" in by, "pca" in by]):
-        print("No lines detected. Adjust the edge detection parameters.")
-        return 0
-
-    # Hough Transform-based angle detection (Median/Mean)
-    if "me" in by:
-        angles = []
-        for line in lines:
-            (x0, y0), (x1, y1) = line
-            angle = np.arctan2(y1 - y0, x1 - x0) * 180 / np.pi
-            if 80 < abs(angle) < 100:
-                angles.append(angle)
-        if not angles:
-            return 0
-        if "di" in by:
-            median_angle = np.median(angles)
-            rotation_angle = (
-                90 - median_angle if median_angle > 0 else -90 - median_angle
-            )
-
-            return rotation_angle
-        else:
-            mean_angle = np.mean(angles)
-            rotation_angle = 90 - mean_angle if mean_angle > 0 else -90 - mean_angle
-
-            return rotation_angle
-
-    # PCA-based angle detection
-    elif "pca" in by:
-        y, x = np.nonzero(edges)
-        if len(x) == 0:
-            return 0
-        pca = PCA(n_components=2)
-        pca.fit(np.vstack((x, y)).T)
-        angle = np.arctan2(pca.components_[0, 1], pca.components_[0, 0]) * 180 / np.pi
-        return angle
-
-    # Gradient Orientation-based angle detection
-    elif "gra" in by:
-        gx, gy = np.gradient(gray_image)
-        angles = np.arctan2(gy, gx) * 180 / np.pi
-        hist, bin_edges = np.histogram(angles, bins=360, range=(-180, 180))
-        return bin_edges[np.argmax(hist)]
-
-    # Template Matching-based angle detection
-    elif "temp" in by:
-        if template is None:
-            # Automatically extract a template from the center of the image
-            height, width = gray_image.shape
-            center_x, center_y = width // 2, height // 2
-            size = (
-                min(height, width) // 4
-            )  # Size of the template as a fraction of image size
-            template = gray_image[
-                center_y - size : center_y + size, center_x - size : center_x + size
-            ]
-        best_angle = None
-        best_corr = -1
-        for angle in range(0, 180, 1):  # Checking every degree
-            rotated_template = transform.rotate(template, angle)
-            res = cv2.matchTemplate(gray_image, rotated_template, cv2.TM_CCOEFF)
-            _, max_val, _, _ = cv2.minMaxLoc(res)
-            if max_val > best_corr:
-                best_corr = max_val
-                best_angle = angle
-        return best_angle
-
-    # Image Moments-based angle detection
-    elif "mo" in by:
-        moments = measure.moments_central(gray_image)
-        angle = (
-            0.5
-            * np.arctan2(2 * moments[1, 1], moments[0, 2] - moments[2, 0])
-            * 180
-            / np.pi
-        )
-        return angle
-
-    # Fourier Transform-based angle detection
-    elif "fft" in by:
-        f = fft2(gray_image)
-        fshift = fftshift(f)
-        magnitude_spectrum = np.log(np.abs(fshift) + 1)
-        rows, cols = magnitude_spectrum.shape
-        r, c = np.unravel_index(np.argmax(magnitude_spectrum), (rows, cols))
-        angle = np.arctan2(r - rows // 2, c - cols // 2) * 180 / np.pi
-        return angle
-
-    else:
-        print(f"Unknown method {by}")
-        return 0
-
-
def imgsets(img, **kwargs):
    """
    Apply various enhancements and filters to an image using PIL's ImageEnhance and ImageFilter modules.
@@ -3179,7 +3201,9 @@ def imgsets(img, **kwargs):
        if "shar" in k.lower():
            enhancer = ImageEnhance.Sharpness(img_update)
            img_update = enhancer.enhance(value)
-        elif
+        elif all(
+            ["col" in k.lower(), "bg" not in k.lower(), "background" not in k.lower()]
+        ):
            enhancer = ImageEnhance.Color(img_update)
            img_update = enhancer.enhance(value)
        elif "contr" in k.lower():
@@ -3201,6 +3225,9 @@ def imgsets(img, **kwargs):
            img_update = img_update.rotate(value)

        elif "si" in k.lower():
+            if isinstance(value, tuple):
+                value = list(value)
+            value = [int(i) for i in value]
            img_update = img_update.resize(value)
        elif "thum" in k.lower():
            img_update.thumbnail(value)
@@ -3221,21 +3248,7 @@ def imgsets(img, **kwargs):
                session = new_session("isnet-general-use")
                img_update = remove(img_update, session=session)
            elif value and isinstance(value, (int, float, list)):
-                print(
-                    'example usage: {"rm":[alpha_matting_background_threshold(20),alpha_matting_foreground_threshold(270),alpha_matting_erode_sive(11)]}'
-                )
                print("https://github.com/danielgatis/rembg/blob/main/USAGE.md")
-                # ### Parameters:
-                #     data (Union[bytes, PILImage, np.ndarray]): The input image data.
-                #     alpha_matting (bool, optional): Flag indicating whether to use alpha matting. Defaults to False.
-                #     alpha_matting_foreground_threshold (int, optional): Foreground threshold for alpha matting. Defaults to 240.
-                #     alpha_matting_background_threshold (int, optional): Background threshold for alpha matting. Defaults to 10.
-                #     alpha_matting_erode_size (int, optional): Erosion size for alpha matting. Defaults to 10.
-                #     session (Optional[BaseSession], optional): A session object for the 'u2net' model. Defaults to None.
-                #     only_mask (bool, optional): Flag indicating whether to return only the binary masks. Defaults to False.
-                #     post_process_mask (bool, optional): Flag indicating whether to post-process the masks. Defaults to False.
-                #     bgcolor (Optional[Tuple[int, int, int, int]], optional): Background color for the cutout image. Defaults to None.
-                # ###
                if isinstance(value, int):
                    value = [value]
                if len(value) < 2:
|
@@ -3294,16 +3307,6 @@ def imgsets(img, **kwargs):
|
|
3294
3307
|
return img_update
|
3295
3308
|
|
3296
3309
|
|
3297
|
-
# # usage:
|
3298
|
-
# img = imgsets(
|
3299
|
-
# fpath,
|
3300
|
-
# sets={"rota": -5},
|
3301
|
-
# dpi=200,
|
3302
|
-
# filter_kws={"EMBOSS": 5, "sharpen": 5, "EDGE_ENHANCE_MORE": 10},
|
3303
|
-
# show_axis=True,
|
3304
|
-
# )
|
3305
|
-
|
3306
|
-
|
3307
3310
|
def thumbnail(dir_img_list, figsize=(10, 10), dpi=100, dir_save=None, kind=".png"):
|
3308
3311
|
"""
|
3309
3312
|
Display a thumbnail figure of all images in the specified directory.
|
@@ -4248,7 +4251,7 @@ format_excel(
    print(f"Formatted Excel file saved as:\n{filename}")


-from IPython.display import display, HTML, Markdown
+from IPython.display import display, HTML, Markdown


def preview(var):
@@ -4298,8 +4301,6 @@ def preview(var):
# preview(pd.DataFrame({"Name": ["Alice", "Bob"], "Age": [25, 30]}))
# preview({"key": "value", "numbers": [1, 2, 3]})

-
-# ! DataFrame
def df_as_type(
    df: pd.DataFrame,
    columns: Optional[Union[str, List[str]]] = None,
@@ -4308,6 +4309,18 @@ def df_as_type(
    inplace: bool = True,
    errors: str = "coerce",  # Can be "ignore", "raise", or "coerce"
    **kwargs,
+):
+    return df_astype(df=df,columns=columns,astype=astype,format=format,inplace=inplace,errors=errors,**kwargs)
+
+# ! DataFrame
+def df_astype(
+    df: pd.DataFrame,
+    columns: Optional[Union[str, List[str]]] = None,
+    astype: str = "datetime",
+    fmt: Optional[str] = None,
+    inplace: bool = True,
+    errors: str = "coerce",  # Can be "ignore", "raise", or "coerce"
+    **kwargs,
) -> Optional[pd.DataFrame]:
    """
    Convert specified columns of a DataFrame to a specified type (e.g., datetime, float, int, numeric, timedelta).
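df_as_type is now a thin wrapper that forwards to df_astype, whose fmt parameter carries the datetime format. A hedged usage sketch of the new entry point:

    import pandas as pd
    df = pd.DataFrame({"when": ["2024-01-01", "2024-02-01"]})
    # parse the column in place as datetime using an explicit format string
    df_astype(df, columns="when", astype="datetime", fmt="%Y-%m-%d", inplace=True)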
@@ -4317,7 +4330,7 @@ def df_as_type(
    - df: DataFrame containing the columns to convert.
    - columns: Either a single column name, a list of column names, or None to convert all columns.
    - astype: The target type to convert the columns to ('datetime', 'float', 'int', 'numeric', 'timedelta', etc.).
-    -
+    - fmt: Optional; format to specify the datetime format (only relevant for 'datetime' conversion).
    - inplace: Whether to modify the DataFrame in place or return a new one. Defaults to False.
    - errors: Can be "ignore", "raise", or "coerce"
    - **kwargs: Additional keyword arguments to pass to the conversion function (e.g., errors='ignore' for pd.to_datetime or pd.to_numeric).
|
|
4397
4410
|
# convert it as type: datetime
|
4398
4411
|
if isinstance(column, int):
|
4399
4412
|
df.iloc[:, column] = pd.to_datetime(
|
4400
|
-
df.iloc[:, column], format=
|
4413
|
+
df.iloc[:, column], format=fmt, errors=errors, **kwargs
|
4401
4414
|
)
|
4402
4415
|
# further convert:
|
4403
4416
|
if astype == "time":
|
@@ -4419,9 +4432,9 @@ def df_as_type(
            else:
                df[column] = (
                    pd.to_datetime(
-                        df[column], format=
+                        df[column], format=fmt, errors=errors, **kwargs
                    )
-                    if
+                    if fmt
                    else pd.to_datetime(df[column], errors=errors, **kwargs)
                )
                # further convert:
@@ -4528,3 +4541,726 @@ def df_sort_values(df, column, by=None, ascending=True, inplace=False, **kwargs)
|
|
4528
4541
|
# display(sorted_df_month)
|
4529
4542
|
# df_sort_values(df_month, "month", month_order, ascending=True, inplace=True)
|
4530
4543
|
# display(df_month)
|
4544
|
+
|
4545
|
+
|
4546
|
+
def df_cluster(
|
4547
|
+
data: pd.DataFrame,
|
4548
|
+
columns: Optional[list] = None,
|
4549
|
+
n_clusters: Optional[int] = None,
|
4550
|
+
range_n_clusters: Union[range, np.ndarray] = range(2, 11),
|
4551
|
+
scale: bool = True,
|
4552
|
+
plot: Union[str, list] = "all",
|
4553
|
+
inplace: bool = True,
|
4554
|
+
ax: Optional[plt.Axes] = None,
|
4555
|
+
) -> tuple[pd.DataFrame, int, Optional[plt.Axes]]:
|
4556
|
+
from sklearn.preprocessing import StandardScaler
|
4557
|
+
from sklearn.cluster import KMeans
|
4558
|
+
from sklearn.metrics import silhouette_score, silhouette_samples
|
4559
|
+
import seaborn as sns
|
4560
|
+
import numpy as np
|
4561
|
+
import pandas as pd
|
4562
|
+
import matplotlib.pyplot as plt
|
4563
|
+
import seaborn as sns
|
4564
|
+
|
4565
|
+
"""
|
4566
|
+
Performs clustering analysis on the provided feature matrix using K-Means.
|
4567
|
+
|
4568
|
+
Parameters:
|
4569
|
+
X (np.ndarray):
|
4570
|
+
A 2D numpy array or DataFrame containing numerical feature data,
|
4571
|
+
where each row corresponds to an observation and each column to a feature.
|
4572
|
+
|
4573
|
+
range_n_clusters (range):
|
4574
|
+
A range object specifying the number of clusters to evaluate for K-Means clustering.
|
4575
|
+
Default is range(2, 11), meaning it will evaluate from 2 to 10 clusters.
|
4576
|
+
|
4577
|
+
scale (bool):
|
4578
|
+
A flag indicating whether to standardize the features before clustering.
|
4579
|
+
        Default is True, which scales the data to have a mean of 0 and variance of 1.

    plot (bool):
        A flag indicating whether to generate visualizations of the clustering analysis.
        Default is True, which will plot silhouette scores, inertia, and other relevant plots.
    Returns:
        tuple:
            A tuple containing the modified DataFrame with cluster labels,
            the optimal number of clusters, and the Axes object (if any).
    """
    X = data[columns].values if columns is not None else data.values

    silhouette_avg_scores = []
    inertia_scores = []

    # Standardize the features
    if scale:
        scaler = StandardScaler()
        X = scaler.fit_transform(X)

    for n_cluster in range_n_clusters:
        kmeans = KMeans(n_clusters=n_cluster, random_state=42)
        cluster_labels = kmeans.fit_predict(X)

        silhouette_avg = silhouette_score(X, cluster_labels)
        silhouette_avg_scores.append(silhouette_avg)
        inertia_scores.append(kmeans.inertia_)
        print(
            f"For n_clusters = {n_cluster}, the average silhouette_score is : {silhouette_avg:.4f}"
        )

    # Determine the optimal number of clusters based on the maximum silhouette score
    if n_clusters is None:
        n_clusters = range_n_clusters[np.argmax(silhouette_avg_scores)]
    print(f"n_clusters = {n_clusters}")

    # Apply K-Means clustering with the optimal number of clusters
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    cluster_labels = kmeans.fit_predict(X)

    if plot:
        # ! Interpreting the plots from the K-Means clustering analysis
        # ! 1. Silhouette Score and Inertia vs Number of Clusters
        # Description:
        # This plot has two y-axes: the left y-axis represents the Silhouette Score, and the right y-axis
        # represents Inertia.
        # The x-axis represents the number of clusters (k).

        # Interpretation:

        # Silhouette Score:
        # Ranges from -1 to 1, where a score close to 1 indicates that points are well clustered, while a
        # score close to -1 indicates that points might be incorrectly clustered.
        # A higher silhouette score generally suggests that the data points are appropriately clustered.
        # Look for the highest value to determine the optimal number of clusters.

        # Inertia:
        # Represents the sum of squared distances from each point to its assigned cluster center.
        # Lower inertia values indicate tighter clusters.
        # As the number of clusters increases, inertia typically decreases, but the rate of decrease
        # slows down, indicating diminishing returns for additional clusters.

        # Optimal number of clusters:
        # The optimal number of clusters is where the silhouette score is maximized and inertia starts
        # to plateau (the "elbow" point).
        # Beyond this point, increasing the number of clusters yields less meaningful separations.
        if ax is None:
            _, ax = plt.subplots(figsize=inch2cm(10, 6))
        color = "tab:blue"
        ax.plot(
            range_n_clusters,
            silhouette_avg_scores,
            marker="o",
            color=color,
            label="Silhouette Score",
        )
        ax.set_xlabel("Number of Clusters")
        ax.set_ylabel("Silhouette Score", color=color)
        ax.tick_params(axis="y", labelcolor=color)
        # add right axis: inertia
        ax2 = ax.twinx()
        color = "tab:red"
        ax2.set_ylabel("Inertia", color=color)
        ax2.plot(
            range_n_clusters,
            inertia_scores,
            marker="x",
            color=color,
            label="Inertia",
        )
        ax2.tick_params(axis="y", labelcolor=color)

        plt.title("Silhouette Score and Inertia vs Number of Clusters")
        plt.xticks(range_n_clusters)
        plt.grid()
        plt.axvline(x=n_clusters, linestyle="--", color="r", label="Optimal n_clusters")
        # ! 2. Elbow Method Plot
        # Description:
        # This plot shows the Inertia against the number of clusters.

        # Interpretation:
        # The elbow point is where the inertia begins to decrease at a slower rate. Adding more clusters
        # beyond this point does not significantly improve the clustering performance.
        # Look for a noticeable bend in the curve to identify the optimal number of clusters, indicated by
        # the vertical dashed line.
        # Inertia plot
        plt.figure(figsize=inch2cm(10, 6))
        plt.plot(range_n_clusters, inertia_scores, marker="o")
        plt.title("Elbow Method for Optimal k")
        plt.xlabel("Number of clusters")
        plt.ylabel("Inertia")
        plt.grid()
        plt.axvline(
            x=np.argmax(silhouette_avg_scores) + 2,
            linestyle="--",
            color="r",
            label="Optimal n_clusters",
        )
        plt.legend()
        # ! Silhouette Plots
        # 3. Silhouette Plot for Various Clusters
        # Description:
        # This horizontal bar plot shows the silhouette coefficient values for each sample, organized by cluster.

        # Interpretation:
        # Each bar represents the silhouette score of a sample within a specific cluster. Longer bars indicate
        # that the samples are well clustered.
        # The height of the bars shows how similar points within the same cluster are to one another compared
        # to points in other clusters.
        # The vertical red dashed line indicates the average silhouette score for all samples.
        # Ideally, the majority of silhouette values lie above the average line, indicating that most points
        # are well clustered.

        # The following does not need to be run a second time:
        # n_clusters = (
        #     np.argmax(silhouette_avg_scores) + 2
        # )  # Optimal clusters based on max silhouette score
        # kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        # cluster_labels = kmeans.fit_predict(X)
        silhouette_vals = silhouette_samples(X, cluster_labels)

        plt.figure(figsize=inch2cm(10, 6))
        y_lower = 10
        for i in range(n_clusters):
            # Aggregate the silhouette scores for samples belonging to cluster i
            ith_cluster_silhouette_values = silhouette_vals[cluster_labels == i]

            # Sort the values
            ith_cluster_silhouette_values.sort()

            size_cluster_i = ith_cluster_silhouette_values.shape[0]
            y_upper = y_lower + size_cluster_i

            # Create a horizontal bar plot for the silhouette scores
            plt.barh(range(y_lower, y_upper), ith_cluster_silhouette_values, height=0.5)

            # Label the silhouette scores
            plt.text(-0.05, (y_lower + y_upper) / 2, str(i + 2))
            y_lower = y_upper + 10  # 10 for the 0 samples

        plt.title("Silhouette Plot for the Various Clusters")
        plt.xlabel("Silhouette Coefficient Values")
        plt.ylabel("Cluster Label")
        plt.axvline(x=np.mean(silhouette_vals), color="red", linestyle="--")

        df_clusters = pd.DataFrame(
            X, columns=[f"Feature {i+1}" for i in range(X.shape[1])]
        )
        df_clusters["Cluster"] = cluster_labels
        # ! Pairplot of the clusters
        # Overview of the pairplot
        # Axes and grid:
        # The pairplot creates a grid of scatter plots for each pair of features in the dataset.
        # Each point in the scatter plots represents a sample, colored according to its cluster assignment.

        # Diagonal elements:
        # The diagonal plots show the distribution of each individual feature. Since the pairplot is only
        # drawn when X.shape[1] <= 6, at most six features are plotted against each other. The diagonal can
        # display histograms or kernel density estimates (KDE) for each feature.

        # Interpretation of the pairplot

        # Feature relationships:
        # Look at each off-diagonal scatter plot: it shows the relationship between two features. Points that
        # are close together suggest similar values for those features.
        # Cluster separation: clusters of different colors (representing different cluster labels) should be
        # visually distinct. Good separation indicates that the clustering algorithm effectively identified
        # different groups within the data.
        # Overlapping points: if points from different clusters overlap significantly in a scatter plot,
        # those clusters are not distinct in terms of the two features being compared.
        # Cluster characteristics:
        # Shape and distribution: observe whether the clusters are spherical, elongated, or irregular; this
        # gives insight into how well K-Means (or another clustering method) has performed:
        # Spherical clusters: clusters are well defined and separated.
        # Elongated clusters: the algorithm captures variation along specific axes but may benefit from
        # adjusted clustering parameters or methods.
        # Feature influence: identify which features contribute most to cluster separation. For instance, a
        # feature that consistently separates two clusters is likely a key factor for the clustering.
        # Diagonal histograms/KDE:
        # The diagonal plots show the distribution of individual features across all samples. Look for:
        # Distribution shape: is the distribution unimodal, bimodal, skewed, or uniform?
        # Concentration: areas with a high density of points indicate values that are common among samples.
        # Differences among clusters: distinct peaks in the histograms for different clusters suggest that
        # those clusters are characterized by specific ranges of feature values.
        # Example observations:
        # Feature 1 vs. Feature 2: clear, well-separated clusters in this scatter plot suggest that these two
        # features effectively distinguish between the clusters.
        # Feature 3 vs. Feature 4: significant overlap between clusters in this plot indicates that these
        # features do not provide a strong basis for clustering.
        # Diagonal plots: if one cluster has a higher density of points at lower values of a feature while
        # another is concentrated at higher values, that feature is critical for differentiating the clusters.

        # Pairplot of the clusters
        # * Why limit the pairplot to only a few features?
        # 2 features = 1 scatter plot   # 3 features = 3 scatter plots
        # 4 features = 6 scatter plots  # 5 features = 10 scatter plots
        # 6 features = 15 scatter plots # 10 features = 45 scatter plots
        # Pairplots work well with low-dimensional data; with many features the grid of subplots becomes so
        # large that it loses its usefulness.
        if X.shape[1] <= 6:
            plt.figure(figsize=(8, 4))
            sns.pairplot(df_clusters, hue="Cluster", palette="tab10")
            plt.suptitle("Pairplot of Clusters", y=1.02)

    # Add cluster labels to the DataFrame or modify in place
    if inplace:  # replace the original data
        data["Cluster"] = cluster_labels
        return None, n_clusters, kmeans, ax  # Return None when inplace is True
    else:
        data_copy = data.copy()
        data_copy["Cluster"] = cluster_labels
        return data_copy, n_clusters, kmeans, ax


# example:
# clustering_features = [marker + "_log" for marker in markers]
# df_cluster(data, columns=clustering_features, n_clusters=3, range_n_clusters=np.arange(3, 7))

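
# --- Editor's note -----------------------------------------------------------
# The comments inside df_cluster above explain how the silhouette score and the
# inertia ("elbow") guide the choice of k. The helper below is a minimal,
# self-contained sketch of that selection logic; it is not part of the py2ls
# API, and the synthetic blobs, the candidate range 2-7, and random_state=0 are
# illustrative assumptions only.
def _demo_silhouette_k_selection():
    import numpy as np
    from sklearn.cluster import KMeans
    from sklearn.datasets import make_blobs
    from sklearn.metrics import silhouette_score

    # Toy data with four well-separated blobs
    X, _ = make_blobs(n_samples=300, centers=4, random_state=0)
    candidates = list(range(2, 8))
    scores, inertias = [], []
    for k in candidates:
        km = KMeans(n_clusters=k, random_state=0, n_init=10).fit(X)
        scores.append(silhouette_score(X, km.labels_))  # higher is better
        inertias.append(km.inertia_)                    # lower means tighter clusters
    best_k = candidates[int(np.argmax(scores))]
    print(f"best k by silhouette: {best_k}")
    return best_k


# example: _demo_silhouette_k_selection()
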
"""
|
4817
|
+
# You're on the right track, but let's clarify how PCA and clustering (like KMeans) work, especially
|
4818
|
+
# in the context of your dataset with 7 columns and 23,121 rows.
|
4819
|
+
|
4820
|
+
# Principal Component Analysis (PCA)
|
4821
|
+
# Purpose of PCA:
|
4822
|
+
# PCA is a dimensionality reduction technique. It transforms your dataset from a high-dimensional space
|
4823
|
+
# (in your case, 7 dimensions corresponding to your 7 columns) to a lower-dimensional space while
|
4824
|
+
# retaining as much variance (information) as possible.
|
4825
|
+
# How PCA Works:
|
4826
|
+
# PCA computes new features called "principal components" that are linear combinations of the original
|
4827
|
+
# features.
|
4828
|
+
# The first principal component captures the most variance, the second captures the next most variance
|
4829
|
+
# (orthogonal to the first), and so on.
|
4830
|
+
# If you set n_components=2, for example, PCA will reduce your dataset from 7 columns to 2 columns.
|
4831
|
+
# This helps in visualizing and analyzing the data with fewer dimensions.
|
4832
|
+
# Result of PCA:
|
4833
|
+
# After applying PCA, your original dataset with 7 columns will be transformed into a new dataset with
|
4834
|
+
# the specified number of components (e.g., 2 or 3).
|
4835
|
+
# The transformed dataset will have fewer columns but should capture most of the important information
|
4836
|
+
# from the original dataset.
|
4837
|
+
|
4838
|
+
# Clustering (KMeans)
|
4839
|
+
# Purpose of Clustering:
|
4840
|
+
# Clustering is used to group data points based on their similarities. KMeans, specifically, partitions
|
4841
|
+
# your data into a specified number of clusters (groups).
|
4842
|
+
# How KMeans Works:
|
4843
|
+
# KMeans assigns each data point to one of the k clusters based on the feature space (original or
|
4844
|
+
# PCA-transformed).
|
4845
|
+
# It aims to minimize the variance within each cluster while maximizing the variance between clusters.
|
4846
|
+
# It does not classify the data in each column independently; instead, it considers the overall similarity
|
4847
|
+
# between data points based on their features.
|
4848
|
+
# Result of KMeans:
|
4849
|
+
# The output will be cluster labels for each data point (e.g., which cluster a particular observation
|
4850
|
+
# belongs to).
|
4851
|
+
# You can visualize how many groups were formed and analyze the characteristics of each cluster.
|
4852
|
+
|
4853
|
+
# Summary
|
4854
|
+
# PCA reduces the number of features (columns) in your dataset, transforming it into a lower-dimensional
|
4855
|
+
# space.
|
4856
|
+
# KMeans then classifies data points based on the features of the transformed dataset (or the original
|
4857
|
+
# if you choose) into different subgroups (clusters).
|
4858
|
+
# By combining these techniques, you can simplify the complexity of your data and uncover patterns that
|
4859
|
+
# might not be visible in the original high-dimensional space. Let me know if you have further questions!
|
4860
|
+
"""
|
4861
|
+
|
4862
|
+
|
4863
|
+
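
# --- Editor's note -----------------------------------------------------------
# The note above describes reducing a 7-column dataset with PCA and then
# clustering the reduced data with KMeans. The helper below is a hypothetical,
# self-contained sketch of that workflow, not part of the py2ls API; the
# 7-feature synthetic data, the choice of 2 components, and k=3 are assumptions
# made only for illustration.
def _demo_pca_then_kmeans():
    import numpy as np
    from sklearn.cluster import KMeans
    from sklearn.datasets import make_blobs
    from sklearn.decomposition import PCA
    from sklearn.preprocessing import StandardScaler

    # 7 columns, as in the explanation above (row count kept small for speed)
    X, _ = make_blobs(n_samples=500, n_features=7, centers=3, random_state=0)
    X = StandardScaler().fit_transform(X)       # scale before PCA
    pca = PCA(n_components=2)
    X_2d = pca.fit_transform(X)                 # 7 columns -> 2 components
    labels = KMeans(n_clusters=3, random_state=0, n_init=10).fit_predict(X_2d)
    print("variance retained by 2 components:", pca.explained_variance_ratio_.sum())
    print("cluster sizes:", np.bincount(labels))
    return X_2d, labels


# example: _demo_pca_then_kmeans()
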
def df_reducer(
    data: pd.DataFrame,
    columns: Optional[List[str]] = None,
    method: str = "umap",  # 'pca', 'umap'
    n_components: int = 2,  # Default for umap, but 50 for PCA
    umap_neighbors: int = 15,  # Default
    umap_min_dist: float = 0.1,  # Default
    scale: bool = True,
    fill_missing: bool = True,
    debug: bool = False,
    inplace: bool = True,  # replace the original data
) -> pd.DataFrame:
    """
    Reduces the dimensionality of the selected DataFrame using PCA or UMAP.

    Parameters:
    -----------
    data : pd.DataFrame
        The input DataFrame (samples x features).

    columns : List[str], optional
        List of column names to reduce. If None, all columns are used.

    method : str, optional, default="umap"
        Dimensionality reduction method, either "pca" or "umap".

    n_components : int, optional, default=2
        Number of components for PCA or UMAP.

    umap_neighbors : int, optional, default=15
        Number of neighbors considered for UMAP embedding.

    umap_min_dist : float, optional, default=0.1
        Minimum distance between points in UMAP embedding.

    scale : bool, optional, default=True
        Whether to scale the data using StandardScaler.

    fill_missing : bool, optional, default=True
        Whether to fill missing values using the mean before applying PCA/UMAP.

    debug : bool, optional, default=False
        Whether to plot the cumulative explained variance when the number of PCA
        components is chosen automatically.

    inplace : bool, optional, default=True
        Whether to add the reduced components as columns of the original DataFrame
        instead of returning a new DataFrame.

    Returns:
    --------
    reduced_df : pd.DataFrame
        DataFrame with the reduced dimensions (None is returned when inplace=True).
    """
    from sklearn.decomposition import PCA
    from sklearn.preprocessing import StandardScaler
    import umap
    from sklearn.impute import SimpleImputer

    # Select columns if specified, else use all columns
    X = data[columns].values if columns else data.values

    # Handle missing values
    if fill_missing:
        imputer = SimpleImputer(strategy="mean")
        X = imputer.fit_transform(X)

    # Optionally scale the data
    if scale:
        scaler = StandardScaler()
        X = scaler.fit_transform(X)

    # Check valid method input
    if method not in ["pca", "umap"]:
        raise ValueError(f"Invalid method '{method}'. Choose 'pca' or 'umap'.")

    # Apply PCA if selected
    if method == "pca":
        if n_components is None:
            # to get the n_components with threshold method:
            pca = PCA()
            pca_result = pca.fit_transform(X)

            # Calculate explained variance
            explained_variance = pca.explained_variance_ratio_
            # Cumulative explained variance
            cumulative_variance = np.cumsum(explained_variance)
            # Set a threshold for cumulative variance
            threshold = 0.95  # Example threshold
            n_components = (
                np.argmax(cumulative_variance >= threshold) + 1
            )  # Number of components to retain
            if debug:
                # debug:
                # Plot the cumulative explained variance
                plt.figure(figsize=(8, 5))
                plt.plot(
                    range(1, len(cumulative_variance) + 1),
                    cumulative_variance,
                    marker="o",
                    linestyle="-",
                )
                plt.title("Cumulative Explained Variance by Principal Components")
                plt.xlabel("Number of Principal Components")
                plt.ylabel("Cumulative Explained Variance")
                plt.xticks(range(1, len(cumulative_variance) + 1))
                # Add horizontal line for the threshold
                plt.axhline(
                    y=threshold, color="r", linestyle="--", label="Threshold (95%)"
                )
                # Add vertical line for n_components
                plt.axvline(
                    x=n_components,
                    color="g",
                    linestyle="--",
                    label=f"n_components = {n_components}",
                )
                plt.legend()
                plt.grid()
        pca = PCA(n_components=n_components)
        X_reduced = pca.fit_transform(X)
        print(f"PCA completed: Reduced to {n_components} components.")

    # Apply UMAP if selected
    elif method == "umap":
        umap_reducer = umap.UMAP(
            n_neighbors=umap_neighbors,
            min_dist=umap_min_dist,
            n_components=n_components,
        )
        X_reduced = umap_reducer.fit_transform(X)
        print(f"UMAP completed: Reduced to {n_components} components.")

    # Return reduced data as a new DataFrame with the same index
    reduced_df = pd.DataFrame(X_reduced, index=data.index)

    if inplace:
        # Replace or add new columns based on n_components
        for col_idx in range(n_components):
            data[f"Component_{col_idx+1}"] = reduced_df.iloc[:, col_idx]
        return None  # No return when inplace=True

    return reduced_df


# example:
# df_reducer(data=data_log, columns=markers, n_components=2)


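
# --- Editor's note -----------------------------------------------------------
# A hypothetical usage sketch for the df_reducer defined above, not part of the
# py2ls API: a random 5-column DataFrame is reduced to 2 principal components
# with method="pca" and inplace=False. Note that df_reducer imports umap at
# call time, so the umap-learn package must be installed even for the PCA path.
def _demo_df_reducer_pca():
    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(0)
    df = pd.DataFrame(rng.normal(size=(100, 5)), columns=list("abcde"))
    reduced = df_reducer(
        data=df, columns=list("abcde"), method="pca", n_components=2, inplace=False
    )
    print(reduced.shape)  # (100, 2): one row per sample, two components
    return reduced


# example: _demo_df_reducer_pca()
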
def plot_cluster(
    data: pd.DataFrame,
    labels: np.ndarray,
    metrics: dict = None,
    cmap="tab20",
    true_labels: Optional[np.ndarray] = None,
) -> None:
    """
    Visualize clustering results with various plots.

    Parameters:
    -----------
    data : pd.DataFrame
        The input data used for clustering.
    labels : np.ndarray
        Cluster labels assigned to each point.
    metrics : dict
        Dictionary containing evaluation metrics from the evaluate_cluster function.
    cmap : str, default="tab20"
        Colormap used for the cluster scatter plot.
    true_labels : Optional[np.ndarray], default=None
        Ground truth labels, if available.
    """
    import seaborn as sns
    from sklearn.metrics import silhouette_samples

    if metrics is None:
        metrics = evaluate_cluster(data=data, labels=labels, true_labels=true_labels)

    # 1. Scatter Plot of Clusters
    plt.figure(figsize=(15, 6))
    plt.subplot(1, 3, 1)
    plt.scatter(data.iloc[:, 0], data.iloc[:, 1], c=labels, cmap=cmap, s=20)
    plt.title("Cluster Scatter Plot")
    plt.xlabel("Component 1")
    plt.ylabel("Component 2")
    plt.colorbar(label="Cluster Label")
    plt.grid()

    # 2. Silhouette Plot
    if "Silhouette Score" in metrics:
        silhouette_vals = silhouette_samples(data, labels)
        plt.subplot(1, 3, 2)
        y_lower = 10
        for i in range(len(set(labels))):
            # Aggregate the silhouette scores for samples belonging to the current cluster
            cluster_silhouette_vals = silhouette_vals[labels == i]
            cluster_silhouette_vals.sort()
            size_cluster_i = cluster_silhouette_vals.shape[0]
            y_upper = y_lower + size_cluster_i

            plt.fill_betweenx(np.arange(y_lower, y_upper), 0, cluster_silhouette_vals)
            plt.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
            y_lower = y_upper + 10  # 10 for the 0 samples

        plt.title("Silhouette Plot")
        plt.xlabel("Silhouette Coefficient Values")
        plt.ylabel("Cluster Label")
        plt.axvline(x=metrics["Silhouette Score"], color="red", linestyle="--")
        plt.grid()

    # 3. Metrics Plot
    plt.subplot(1, 3, 3)
    metric_names = ["Davies-Bouldin Index", "Calinski-Harabasz Index"]
    metric_values = [
        metrics["Davies-Bouldin Index"],
        metrics["Calinski-Harabasz Index"],
    ]

    if true_labels is not None:
        metric_names += ["Homogeneity Score", "Completeness Score", "V-Measure"]
        metric_values += [
            metrics["Homogeneity Score"],
            metrics["Completeness Score"],
            metrics["V-Measure"],
        ]

    plt.barh(metric_names, metric_values, color="lightblue")
    plt.title("Clustering Metrics")
    plt.xlabel("Score")
    plt.axvline(x=0, color="gray", linestyle="--")
    plt.grid()
    plt.tight_layout()


def evaluate_cluster(
    data: pd.DataFrame, labels: np.ndarray, true_labels: Optional[np.ndarray] = None
) -> dict:
    """
    Evaluate clustering performance using various metrics.

    Parameters:
    -----------
    data : pd.DataFrame
        The input data used for clustering.
    labels : np.ndarray
        Cluster labels assigned to each point.
    true_labels : Optional[np.ndarray], default=None
        Ground truth labels, if available.

    Returns:
    --------
    metrics : dict
        Dictionary containing evaluation metrics.

    1. Silhouette Score:
        The silhouette score measures how similar an object is to its own cluster (cohesion) compared to
        how similar it is to other clusters (separation). The score ranges from -1 to +1:
        +1: Indicates that the data point is very far from the neighboring clusters and well clustered.
        0: Indicates that the data point is on or very close to the decision boundary between two
        neighboring clusters.
        -1: Indicates that the data point might have been assigned to the wrong cluster.

        Interpretation:
        A higher average silhouette score indicates better-defined clusters.
        If the score is consistently high (above 0.5), it suggests that the clusters are well separated.
        A score near 0 may indicate overlapping clusters, while negative scores suggest points may have
        been misclassified.

    2. Davies-Bouldin Index:
        The Davies-Bouldin Index (DBI) measures the average similarity ratio of each cluster with its
        most similar cluster. The index values range from 0 to ∞, with lower values indicating better clustering.
        It is defined as the ratio of within-cluster distances to between-cluster distances.

        Interpretation:
        A lower DBI value indicates that the clusters are compact and well-separated.
        Ideally, you want to minimize the Davies-Bouldin Index. If your DBI value is above 1, this indicates
        that your clusters might not be well-separated.

    3. Adjusted Rand Index (ARI):
        The Adjusted Rand Index (ARI) is a measure of the similarity between two data clusterings. The ARI
        score ranges from -1 to +1:
        1: Indicates perfect agreement between the two clusterings.
        0: Indicates that the clusterings are no better than random.
        Negative values: Indicate less agreement than expected by chance.

        Interpretation:
        A higher ARI score indicates better clustering, particularly if it's close to 1.
        An ARI score of 0 or lower suggests that the clustering results do not represent the true labels
        well, indicating a poor clustering performance.

    4. Calinski-Harabasz Index:
        The Calinski-Harabasz Index (also known as the Variance Ratio Criterion) evaluates the ratio of
        the sum of between-cluster dispersion to within-cluster dispersion. Higher values indicate better clustering.

        Interpretation:
        A higher Calinski-Harabasz Index suggests that clusters are dense and well-separated. It is typically
        used to validate the number of clusters, with higher values favoring more distinct clusters.

    5. Homogeneity Score:
        The homogeneity score measures how much a cluster contains only members of a single class (if true labels are provided).
        A score of 1 indicates perfect homogeneity, where all points in a cluster belong to the same class.

        Interpretation:
        A higher homogeneity score indicates that the clustering result is pure, meaning the clusters are composed
        of similar members. Lower values indicate mixed clusters, suggesting poor clustering performance.

    6. Completeness Score:
        The completeness score evaluates how well all members of a given class are assigned to the same cluster.
        A score of 1 indicates perfect completeness, meaning all points in a true class are assigned to a single cluster.

        Interpretation:
        A higher completeness score indicates that the clustering effectively groups all instances of a class together.
        Lower values suggest that some instances of a class are dispersed among multiple clusters.

    7. V-Measure:
        The V-measure is the harmonic mean of homogeneity and completeness, giving a balanced measure of clustering performance.

        Interpretation:
        A higher V-measure score indicates that the clusters are both homogenous (pure) and complete (cover all members of a class).
        Scores closer to 1 indicate better clustering quality.
    """
    from sklearn.metrics import (
        silhouette_score,
        davies_bouldin_score,
        adjusted_rand_score,
        calinski_harabasz_score,
        homogeneity_score,
        completeness_score,
        v_measure_score,
    )

    metrics = {}
    unique_labels = set(labels)
    if len(unique_labels) > 1 and len(unique_labels) < len(data):
        # Calculate Silhouette Score
        try:
            metrics["Silhouette Score"] = silhouette_score(data, labels)
        except Exception as e:
            metrics["Silhouette Score"] = np.nan
            print(f"Silhouette Score calculation failed: {e}")

        # Calculate Davies-Bouldin Index
        try:
            metrics["Davies-Bouldin Index"] = davies_bouldin_score(data, labels)
        except Exception as e:
            metrics["Davies-Bouldin Index"] = np.nan
            print(f"Davies-Bouldin Index calculation failed: {e}")

        # Calculate Calinski-Harabasz Index
        try:
            metrics["Calinski-Harabasz Index"] = calinski_harabasz_score(data, labels)
        except Exception as e:
            metrics["Calinski-Harabasz Index"] = np.nan
            print(f"Calinski-Harabasz Index calculation failed: {e}")

        # Calculate Adjusted Rand Index if true labels are provided
        if true_labels is not None:
            try:
                metrics["Adjusted Rand Index"] = adjusted_rand_score(
                    true_labels, labels
                )
            except Exception as e:
                metrics["Adjusted Rand Index"] = np.nan
                print(f"Adjusted Rand Index calculation failed: {e}")

            # Calculate Homogeneity Score
            try:
                metrics["Homogeneity Score"] = homogeneity_score(true_labels, labels)
            except Exception as e:
                metrics["Homogeneity Score"] = np.nan
                print(f"Homogeneity Score calculation failed: {e}")

            # Calculate Completeness Score
            try:
                metrics["Completeness Score"] = completeness_score(true_labels, labels)
            except Exception as e:
                metrics["Completeness Score"] = np.nan
                print(f"Completeness Score calculation failed: {e}")

            # Calculate V-Measure
            try:
                metrics["V-Measure"] = v_measure_score(true_labels, labels)
            except Exception as e:
                metrics["V-Measure"] = np.nan
                print(f"V-Measure calculation failed: {e}")
    else:
        # Metrics cannot be computed with 1 cluster or all points as noise
        metrics["Silhouette Score"] = np.nan
        metrics["Davies-Bouldin Index"] = np.nan
        metrics["Calinski-Harabasz Index"] = np.nan
        if true_labels is not None:
            metrics["Adjusted Rand Index"] = np.nan
            metrics["Homogeneity Score"] = np.nan
            metrics["Completeness Score"] = np.nan
            metrics["V-Measure"] = np.nan

    return metrics


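
# --- Editor's note -----------------------------------------------------------
# A hypothetical sketch, not part of the py2ls API, showing how the
# evaluate_cluster and plot_cluster functions defined above can be used
# together; the synthetic blobs, k=3, and random_state=0 are illustrative
# assumptions only.
def _demo_evaluate_and_plot_cluster():
    import pandas as pd
    from sklearn.cluster import KMeans
    from sklearn.datasets import make_blobs

    X, y_true = make_blobs(n_samples=300, centers=3, n_features=2, random_state=0)
    df = pd.DataFrame(X, columns=["Component 1", "Component 2"])
    labels = KMeans(n_clusters=3, random_state=0, n_init=10).fit_predict(X)

    metrics = evaluate_cluster(data=df, labels=labels, true_labels=y_true)
    print({k: round(v, 3) for k, v in metrics.items()})
    plot_cluster(data=df, labels=labels, metrics=metrics, true_labels=y_true)
    return metrics


# example: _demo_evaluate_and_plot_cluster()
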
def print_pd_usage(
    func_name="excel",
    verbose=True,
    dir_json="/Users/macjianfeng/Dropbox/github/python/py2ls/py2ls/data/usages_pd.json",
):
    default_settings = fload(dir_json, output="json")
    valid_kinds = list(default_settings.keys())
    kind = strcmp(func_name, valid_kinds)[0]
    usage = default_settings[kind]
    if verbose:
        for i, i_ in enumerate(ssplit(usage, by=",")):
            i_ = i_.replace("=", "\t= ") + ","
            print(i_) if i == 0 else print("\t", i_)
    else:
        print(usage)