halib 0.1.94__py3-none-any.whl → 0.1.96__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- halib/filetype/yamlfile.py +23 -0
- halib/research/flop_csv.py +34 -0
- halib/research/flops.py +156 -0
- halib/research/mics.py +52 -0
- halib/research/perfcalc.py +3 -3
- halib/research/perftb.py +2 -1
- halib/research/plot.py +20 -241
- {halib-0.1.94.dist-info → halib-0.1.96.dist-info}/METADATA +2 -2
- {halib-0.1.94.dist-info → halib-0.1.96.dist-info}/RECORD +12 -10
- {halib-0.1.94.dist-info → halib-0.1.96.dist-info}/WHEEL +0 -0
- {halib-0.1.94.dist-info → halib-0.1.96.dist-info}/licenses/LICENSE.txt +0 -0
- {halib-0.1.94.dist-info → halib-0.1.96.dist-info}/top_level.txt +0 -0
halib/filetype/yamlfile.py
CHANGED
@@ -6,6 +6,8 @@ from omegaconf import OmegaConf
 from rich.console import Console
 from argparse import ArgumentParser

+from ..research.mics import *
+
 console = Console()


@@ -51,6 +53,27 @@ def load_yaml(yaml_file, to_dict=False, log_info=False):
     else:
         return omgconf

+def load_yaml_with_PC_abbr(
+    yaml_file, pc_abbr_to_working_disk=DEFAULT_ABBR_WORKING_DISK
+):
+    # current PC abbreviation
+    pc_abbr = get_PC_abbr_name()
+
+    # current platform: windows or linux
+    current_platform = platform.system().lower()
+
+    assert pc_abbr in pc_abbr_to_working_disk, f"There is no mapping for {pc_abbr} to <working_disk>"
+
+    # working disk
+    working_disk = pc_abbr_to_working_disk.get(pc_abbr)
+
+    # load yaml file
+    data_dict = load_yaml(yaml_file=yaml_file, to_dict=True)
+
+    # Normalize paths in the loaded data
+    data_dict = normalize_paths(data_dict, working_disk, current_platform)
+    return data_dict
+

 def parse_args():
     parser = ArgumentParser(description="desc text")
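The new `load_yaml_with_PC_abbr` resolves machine-specific absolute paths at load time. A minimal usage sketch (the YAML filename is hypothetical; the current machine's abbreviation must appear in the abbreviation-to-disk mapping, otherwise the assert fires):

    from halib.filetype.yamlfile import load_yaml_with_PC_abbr

    # Paths authored on another machine (e.g. "E:/zdataset/DFire") are
    # rewritten for this machine's platform and working disk.
    cfg = load_yaml_with_PC_abbr("exp_config.yaml")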
halib/research/flop_csv.py
ADDED
@@ -0,0 +1,34 @@
+from halib import *
+from flops import _calculate_flops_for_model
+
+from halib import *
+from argparse import ArgumentParser
+
+
+def main():
+    csv_file = "./results-imagenet.csv"
+    df = pd.read_csv(csv_file)
+    # make param_count column as float
+    # df['param_count'] = df['param_count'].astype(float)
+    df['param_count'] = pd.to_numeric(df['param_count'], errors='coerce').fillna(99999).astype(float)
+    df = df[df['param_count'] < 5.0]  # filter models with param_count < 5M
+
+    dict_ls = []
+
+    for index, row in tqdm(df.iterrows()):
+        console.rule(f"Row {index+1}/{len(df)}")
+        model = row['model']
+        num_class = 2
+        _, _, mflops = _calculate_flops_for_model(model, num_class)
+        dict_ls.append({'model': model, 'param_count': row['param_count'], 'mflops': mflops})
+
+    # Create a DataFrame from the list of dictionaries
+    result_df = pd.DataFrame(dict_ls)
+
+    final_df = pd.merge(df, result_df, on=['model', 'param_count'])
+    final_df.sort_values(by='mflops', inplace=True, ascending=True)
+    csvfile.fn_display_df(final_df)
+
+
+if __name__ == "__main__":
+    main()
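This is a one-off report script: it keeps models under 5M parameters, measures their MFLOPs at 2 output classes, and prints the merged table sorted by MFLOPs. A hedged invocation sketch (assumes a timm benchmark export `results-imagenet.csv` with at least `model` and `param_count` columns in the working directory, and `flops.py` importable beside it):

    python flop_csv.py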
halib/research/flops.py
ADDED
@@ -0,0 +1,156 @@
+import os
+import sys
+import torch
+import timm
+from argparse import ArgumentParser
+from fvcore.nn import FlopCountAnalysis
+from halib import *
+from halib.filetype import csvfile
+from curriculum.utils.config import *
+from curriculum.utils.model_helper import *
+
+
+# ---------------------------------------------------------------------
+# Argument Parser
+# ---------------------------------------------------------------------
+def parse_args():
+    parser = ArgumentParser(description="Calculate FLOPs for TIMM or trained models")
+
+    # Option 1: Direct TIMM model
+    parser.add_argument(
+        "--model_name", type=str, help="TIMM model name (e.g., efficientnet_b0)"
+    )
+    parser.add_argument(
+        "--num_classes", type=int, default=1000, help="Number of output classes"
+    )
+
+    # Option 2: Experiment directory
+    parser.add_argument(
+        "--indir",
+        type=str,
+        default=None,
+        help="Directory containing trained experiment (with .yaml and .pth)",
+    )
+    parser.add_argument(
+        "-o", "--o", action="store_true", help="Open output CSV after saving"
+    )
+    return parser.parse_args()
+
+
+# ---------------------------------------------------------------------
+# Helper Functions
+# ---------------------------------------------------------------------
+def _get_list_of_proc_dirs(indir):
+    assert os.path.exists(indir), f"Input directory {indir} does not exist."
+    pth_files = [f for f in os.listdir(indir) if f.endswith(".pth")]
+    if len(pth_files) > 0:
+        return [indir]
+    return [
+        os.path.join(indir, f)
+        for f in os.listdir(indir)
+        if os.path.isdir(os.path.join(indir, f))
+    ]
+
+
+def _calculate_flops_for_model(model_name, num_classes):
+    """Calculate FLOPs for a plain TIMM model."""
+    try:
+        model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
+        input_size = timm.data.resolve_data_config(model.default_cfg)["input_size"]
+        dummy_input = torch.randn(1, *input_size)
+        model.eval()  # ! set to eval mode to avoid some warnings or errors
+        flops = FlopCountAnalysis(model, dummy_input)
+        gflops = flops.total() / 1e9
+        mflops = flops.total() / 1e6
+        print(f"\nModel: **{model_name}**, Classes: {num_classes}")
+        print(f"Input size: {input_size}, FLOPs: **{gflops:.3f} GFLOPs**, **{mflops:.3f} MFLOPs**\n")
+        return model_name, gflops, mflops
+    except Exception as e:
+        print(f"[Error] Could not calculate FLOPs for {model_name}: {e}")
+        return model_name, -1, -1
+
+
+def _calculate_flops_for_experiment(exp_dir):
+    """Calculate FLOPs for a trained experiment directory."""
+    yaml_files = [f for f in os.listdir(exp_dir) if f.endswith(".yaml")]
+    pth_files = [f for f in os.listdir(exp_dir) if f.endswith(".pth")]
+
+    assert (
+        len(yaml_files) == 1
+    ), f"Expected 1 YAML file in {exp_dir}, found {len(yaml_files)}"
+    assert (
+        len(pth_files) == 1
+    ), f"Expected 1 PTH file in {exp_dir}, found {len(pth_files)}"
+
+    exp_cfg_yaml = os.path.join(exp_dir, yaml_files[0])
+    cfg = ExpConfig.from_yaml(exp_cfg_yaml)
+    ds_label_list = cfg.dataset.get_label_list()
+
+    try:
+        model = build_model(
+            cfg.model.name, num_classes=len(ds_label_list), pretrained=True
+        )
+        model_weights_path = os.path.join(exp_dir, pth_files[0])
+        model.load_state_dict(torch.load(model_weights_path, map_location="cpu"))
+        model.eval()
+
+        input_size = timm.data.resolve_data_config(model.default_cfg)["input_size"]
+        dummy_input = torch.randn(1, *input_size)
+        flops = FlopCountAnalysis(model, dummy_input)
+        gflops = flops.total() / 1e9
+        mflops = flops.total() / 1e6
+
+        return str(cfg), cfg.model.name, gflops, mflops
+    except Exception as e:
+        console.print(f"[red] Error processing {exp_dir}: {e}[/red]")
+        return str(cfg), cfg.model.name, -1, -1
+
+
+# ---------------------------------------------------------------------
+# Main Entry
+# ---------------------------------------------------------------------
+def main():
+    args = parse_args()
+
+    # Case 1: Direct TIMM model input
+    if args.model_name:
+        _calculate_flops_for_model(args.model_name, args.num_classes)
+        return
+
+    # Case 2: Experiment directory input
+    if args.indir is None:
+        print("[Error] Either --model_name or --indir must be specified.")
+        return
+
+    proc_dirs = _get_list_of_proc_dirs(args.indir)
+    pprint(proc_dirs)
+
+    dfmk = csvfile.DFCreator()
+    TABLE_NAME = "model_flops_results"
+    dfmk.create_table(TABLE_NAME, ["exp_name", "model_name", "gflops", "mflops"])
+
+    console.rule(f"Calculating FLOPs for models in {len(proc_dirs)} dir(s)...")
+    rows = []
+    for exp_dir in tqdm(proc_dirs):
+        dir_name = os.path.basename(exp_dir)
+        console.rule(f"{dir_name}")
+        exp_name, model_name, gflops, mflops = _calculate_flops_for_experiment(exp_dir)
+        rows.append([exp_name, model_name, gflops, mflops])
+
+    dfmk.insert_rows(TABLE_NAME, rows)
+    dfmk.fill_table_from_row_pool(TABLE_NAME)
+
+    outfile = f"zout/zreport/{now_str()}_model_flops_results.csv"
+    dfmk[TABLE_NAME].to_csv(outfile, sep=";", index=False)
+    csvfile.fn_display_df(dfmk[TABLE_NAME])
+
+    if args.o:
+        os.system(f"start {outfile}")


+# ---------------------------------------------------------------------
+# Script Entry
+# ---------------------------------------------------------------------
+if __name__ == "__main__":
+    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+    main()
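Outside the script, the core measurement reduces to a few lines. A sketch mirroring `_calculate_flops_for_model` (assumes timm and fvcore are installed; the model name is just an example):

    import timm
    import torch
    from fvcore.nn import FlopCountAnalysis

    model = timm.create_model("efficientnet_b0", pretrained=False, num_classes=1000)
    model.eval()  # eval mode avoids tracing warnings from train-only layers
    input_size = timm.data.resolve_data_config(model.default_cfg)["input_size"]  # e.g. (3, 224, 224)
    flops = FlopCountAnalysis(model, torch.randn(1, *input_size))
    print(f"{flops.total() / 1e9:.3f} GFLOPs / {flops.total() / 1e6:.1f} MFLOPs")

The script itself is driven either as `python flops.py --model_name efficientnet_b0 --num_classes 1000` or as `python flops.py --indir <experiment_dir>` for trained checkpoints.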
halib/research/mics.py
CHANGED
@@ -1,5 +1,9 @@
+from ..common import *
+from ..filetype import csvfile
+import pandas as pd
 import platform

+
 PC_NAME_TO_ABBR = {
     "DESKTOP-JQD9K01": "MainPC",
     "DESKTOP-5IRHU87": "MSI_Laptop",
@@ -8,9 +12,57 @@ PC_NAME_TO_ABBR = {
     "DESKTOP-QNS3DNF": "1GPU_SV"
 }

+DEFAULT_ABBR_WORKING_DISK = {
+    "MainPC": "E:",
+    "MSI_Laptop": "D:",
+    "4090_SV": "E:",
+    "4GPU_SV": "D:",
+}
+
+def list_PCs(show=True):
+    df = pd.DataFrame(list(PC_NAME_TO_ABBR.items()), columns=["PC Name", "Abbreviation"])
+    if show:
+        csvfile.fn_display_df(df)
+    return df
+
 def get_PC_name():
     return platform.node()

 def get_PC_abbr_name():
     pc_name = get_PC_name()
     return PC_NAME_TO_ABBR.get(pc_name, "Unknown")
+
+# ! This function searches for full paths in the obj and normalizes them according to the current platform and working disk
+# ! E.g: "E:/zdataset/DFire", but working_disk: "D:", current_platform: "windows" => "D:/zdataset/DFire"
+# ! E.g: "E:/zdataset/DFire", but working_disk: "D:", current_platform: "linux" => "/mnt/d/zdataset/DFire"
+def normalize_paths(obj, working_disk, current_platform):
+    if isinstance(obj, dict):
+        for key, value in obj.items():
+            obj[key] = normalize_paths(value, working_disk, current_platform)
+        return obj
+    elif isinstance(obj, list):
+        for i, item in enumerate(obj):
+            obj[i] = normalize_paths(item, working_disk, current_platform)
+        return obj
+    elif isinstance(obj, str):
+        # Normalize backslashes to forward slashes for consistency
+        obj = obj.replace("\\", "/")
+        # Regex for Windows-style path: e.g., "E:/zdataset/DFire"
+        win_match = re.match(r"^([A-Z]):/(.*)$", obj)
+        # Regex for Linux-style path: e.g., "/mnt/e/zdataset/DFire"
+        lin_match = re.match(r"^/mnt/([a-z])/(.*)$", obj)
+        if win_match or lin_match:
+            rest = win_match.group(2) if win_match else lin_match.group(2)
+            if current_platform == "windows":
+                # working_disk is like "D:", so "D:/" + rest
+                new_path = working_disk + "/" + rest
+            elif current_platform == "linux":
+                # Extract drive letter from working_disk (e.g., "D:" -> "d")
+                drive_letter = working_disk[0].lower()
+                new_path = "/mnt/" + drive_letter + "/" + rest
+            else:
+                # Unknown platform, return original
+                return obj
+            return new_path
+    # For non-strings or non-path strings, return as is
+    return obj
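`normalize_paths` walks any nested dict/list structure and rewrites absolute paths for the current machine. A quick sketch of its behavior (values are illustrative):

    cfg = {"root": "E:/zdataset/DFire", "splits": ["E:/zdataset/DFire/train"]}
    normalize_paths(cfg, working_disk="D:", current_platform="linux")
    # -> {'root': '/mnt/d/zdataset/DFire', 'splits': ['/mnt/d/zdataset/DFire/train']}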
halib/research/perfcalc.py
CHANGED
@@ -227,9 +227,9 @@ class PerfCalc(ABC):  # Abstract base class for performance calculation
         ), "No metric columns found in the DataFrame. Ensure that the CSV files contain metric columns starting with 'metric_'."
         final_cols = sticky_cols + metric_cols
         df = df[final_cols]
-        # !hahv debug
-        pprint("------ Final DataFrame Columns ------")
-        csvfile.fn_display_df(df)
+        # # !hahv debug
+        # pprint("------ Final DataFrame Columns ------")
+        # csvfile.fn_display_df(df)
         # ! validate all rows in df before returning
         # make sure all rows will have at least values for REQUIRED_COLS and at least one metric column
         for index, row in df.iterrows():
halib/research/perftb.py
CHANGED
@@ -308,7 +308,8 @@ class PerfTB:
         if save_path:
             export_success = False
             try:
-                fig.write_image(save_path, engine="kaleido")
+                # fig.write_image(save_path, engine="kaleido")
+                fig.write_image(save_path, engine="kaleido", width=width, height=height * len(metric_list))
                 export_success = True
                 # pprint(f"Saved: {os.path.abspath(save_path)}")
             except Exception as e:
halib/research/plot.py
CHANGED
@@ -281,7 +281,7 @@ class PlotHelper:
         # --- Convert to wide "multi-list" format ---
         df = pd.DataFrame(data)
         row_col = df.columns[0]  # first col = row labels
-        col_cols = df.columns[1:]  # the rest = groupable cols
+        # col_cols = df.columns[1:]  # the rest = groupable cols

         df = (
             df.melt(id_vars=[row_col], var_name="col", value_name="path")
@@ -356,244 +356,9 @@ class PlotHelper:
     @staticmethod
     def plot_image_grid(
         indir_or_csvf_or_df: Union[str, pd.DataFrame],
-
-
-
-        img_stack_padding_px: int = 5,
-        img_scale_mode: str = "fit",  # "fit" or "fill"
-        format_row_label_func: Optional[Callable[[str], str]] = None,
-        format_col_label_func: Optional[Callable[[str], str]] = None,
-        title: str = "",
-        tickfont=dict(size=16, family="Arial", color="black"),  # <-- bigger labels
-        fig_margin: dict = dict(l=50, r=50, t=50, b=50),
-        outline_color: str = "",
-        outline_size: int = 1,
-        cell_margin_px: int = 10,  # spacing between cells
-        row_line_size: int = 0,  # if >0, draw horizontal dotted lines
-        col_line_size: int = 0  # if >0, draw vertical dotted lines
-    ):
-        """
-        Plot a grid of images using Plotly.
-
-        - Accepts DataFrame where each cell is either:
-            * a Python list object,
-            * a string representation of a Python list (e.g. "['a','b']"),
-            * a JSON list string, or
-            * a single path string.
-        - For each cell, stack the images into a single composite that exactly fits
-          (img_width, img_height) is the target size for each individual image in the stack.
-          The final cell size will depend on the number of images and stacking direction.
-        """
-
-        def process_image_for_slot(path: str, target_size: Tuple[int, int], scale_mode: str, outline: str, outline_size: int) -> Image.Image:
-            try:
-                img = Image.open(path).convert("RGB")
-            except Exception:
-                return Image.new("RGB", target_size, (255, 255, 255))
-
-            if scale_mode == "fit":
-                img_ratio = img.width / img.height
-                target_ratio = target_size[0] / target_size[1]
-
-                if img_ratio > target_ratio:
-                    new_height = target_size[1]
-                    new_width = max(1, int(new_height * img_ratio))
-                else:
-                    new_width = target_size[0]
-                    new_height = max(1, int(new_width / img_ratio))
-
-                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
-                left = (new_width - target_size[0]) // 2
-                top = (new_height - target_size[1]) // 2
-                right = left + target_size[0]
-                bottom = top + target_size[1]
-
-                if len(outline) == 7 and outline.startswith("#"):
-                    border_px = outline_size
-                    bordered = Image.new("RGB", (target_size[0] + 2*border_px, target_size[1] + 2*border_px), outline)
-                    bordered.paste(img.crop((left, top, right, bottom)), (border_px, border_px))
-                    return bordered
-                return img.crop((left, top, right, bottom))
-
-            elif scale_mode == "fill":
-                if len(outline) == 7 and outline.startswith("#"):
-                    border_px = outline_size
-                    bordered = Image.new("RGB", (target_size[0] + 2*border_px, target_size[1] + 2*border_px), outline)
-                    img = img.resize(target_size, Image.Resampling.LANCZOS)
-                    bordered.paste(img, (border_px, border_px))
-                    return bordered
-                return img.resize(target_size, Image.Resampling.LANCZOS)
-            else:
-                raise ValueError("img_scale_mode must be 'fit' or 'fill'.")
-
-        def stack_images_base64(image_paths: List[str], direction: str, single_img_size: Tuple[int,int], outline: str, outline_size: int, padding: int) -> Tuple[str, Tuple[int,int]]:
-            image_paths = [p for p in image_paths if p is not None and str(p).strip() != ""]
-            n = len(image_paths)
-            if n == 0:
-                blank = Image.new("RGB", single_img_size, (255,255,255))
-                buf = BytesIO()
-                blank.save(buf, format="PNG")
-                return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode(), single_img_size
-
-            processed = [process_image_for_slot(p, single_img_size, img_scale_mode, outline, outline_size) for p in image_paths]
-            pad_total = padding * (n-1)
-
-            if direction == "horizontal":
-                total_w = sum(im.width for im in processed) + pad_total
-                total_h = max(im.height for im in processed)
-                stacked = Image.new("RGB", (total_w, total_h), (255,255,255))
-                x = 0
-                for im in processed:
-                    stacked.paste(im, (x,0))
-                    x += im.width + padding
-            elif direction == "vertical":
-                total_w = max(im.width for im in processed)
-                total_h = sum(im.height for im in processed) + pad_total
-                stacked = Image.new("RGB", (total_w, total_h), (255,255,255))
-                y = 0
-                for im in processed:
-                    stacked.paste(im, (0,y))
-                    y += im.height + padding
-            else:
-                raise ValueError("img_stack_direction must be 'horizontal' or 'vertical'.")
-
-            buf = BytesIO()
-            stacked.save(buf, format="PNG")
-            encoded = base64.b64encode(buf.getvalue()).decode()
-            return f"data:image/png;base64,{encoded}", (total_w, total_h)
-
-        # --- Load DataFrame ---
-        if isinstance(indir_or_csvf_or_df, str):
-            fname, ext = os.path.splitext(indir_or_csvf_or_df)
-            if ext.lower() == ".csv":
-                df = pd.read_csv(indir_or_csvf_or_df)
-            elif os.path.isdir(indir_or_csvf_or_df):
-                df = PlotHelper.img_grid_indir_1(indir_or_csvf_or_df, log=False)
-            else:
-                raise ValueError("Input string must be a valid CSV file or directory path")
-        elif isinstance(indir_or_csvf_or_df, pd.DataFrame):
-            df = indir_or_csvf_or_df.copy()
-        else:
-            raise ValueError("Input must be CSV file path, DataFrame, or directory path")
-
-        rows = df.iloc[:,0].astype(str).tolist()
-        columns = list(df.columns[1:])
-        n_rows, n_cols = len(rows), len(columns)
-
-        fig = go.Figure()
-        col_widths = [0]*n_cols
-        row_heights = [0]*n_rows
-
-        cell_imgs = [[None]*n_cols for _ in range(n_rows)]
-        for i in range(n_rows):
-            for j, col_label in enumerate(columns):
-                raw_cell = df.iloc[i, j+1]
-                image_paths = PlotHelper._parse_cell_to_list(raw_cell)
-                image_paths = [str(p).strip() for p in image_paths if str(p).strip() != ""]
-
-                img_src, (cell_w_actual, cell_h_actual) = stack_images_base64(
-                    image_paths, img_stack_direction, (img_width, img_height),
-                    outline=outline_color, outline_size=outline_size,
-                    padding=img_stack_padding_px
-                )
-
-                col_widths[j] = max(col_widths[j], cell_w_actual)
-                row_heights[i] = max(row_heights[i], cell_h_actual)
-                cell_imgs[i][j] = img_src
-
-        # Compute x/y positions including cell_margin
-        x_positions = []
-        cum_w = 0
-        for w in col_widths:
-            x_positions.append(cum_w)
-            cum_w += w + cell_margin_px
-
-        y_positions = []
-        cum_h = 0
-        for h in row_heights:
-            y_positions.append(-cum_h)
-            cum_h += h + cell_margin_px
-
-        # Add images to figure
-        for i in range(n_rows):
-            for j in range(n_cols):
-                fig.add_layout_image(
-                    dict(
-                        source=cell_imgs[i][j],
-                        x=x_positions[j],
-                        y=y_positions[i],
-                        xref="x",
-                        yref="y",
-                        sizex=col_widths[j],
-                        sizey=row_heights[i],
-                        xanchor="left",
-                        yanchor="top",
-                        layer="above",
-                    )
-                )
-        # ! Optional grid lines
-        # Add horizontal grid lines if row_line_size > 0
-        if row_line_size > 0:
-            for i in range(1, n_rows):
-                # Place line in the middle of the gap between rows
-                y = (
-                    y_positions[i - 1] - row_heights[i - 1] - y_positions[i]
-                ) / 2 + y_positions[i]
-                fig.add_shape(
-                    type="line",
-                    x0=-cell_margin_px,
-                    x1=cum_w - cell_margin_px,
-                    y0=y,
-                    y1=y,
-                    line=dict(width=row_line_size, color="black", dash="dot"),
-                )
-
-        # Add vertical grid lines if col_line_size > 0
-        if col_line_size > 0:
-            for j in range(1, n_cols):
-                x = x_positions[j] - cell_margin_px / 2
-                fig.add_shape(
-                    type="line",
-                    x0=x,
-                    x1=x,
-                    y0=cell_margin_px,
-                    y1=-cum_h + cell_margin_px,
-                    line=dict(width=col_line_size, color="black", dash="dot"),
-                )
-        # Axis labels
-        col_labels = [format_col_label_func(c) if format_col_label_func else c for c in columns]
-        row_labels = [format_row_label_func(r) if format_row_label_func else r for r in rows]
-
-        fig.update_xaxes(
-            tickvals=[x_positions[j] + col_widths[j]/2 for j in range(n_cols)],
-            ticktext=col_labels,
-            range=[-cell_margin_px, cum_w - cell_margin_px],
-            showgrid=False,
-            zeroline=False,
-            tickfont=tickfont  # <-- apply bigger font here
-        )
-        fig.update_yaxes(
-            tickvals=[y_positions[i] - row_heights[i]/2 for i in range(n_rows)],
-            ticktext=row_labels,
-            range=[-cum_h + cell_margin_px, cell_margin_px],
-            showgrid=False,
-            zeroline=False,
-            tickfont=tickfont  # <-- apply bigger font here
-        )
-
-        fig.update_layout(
-            width=cum_w + 100,
-            height=cum_h + 100,
-            title=title,
-            title_x=0.5,
-            margin=fig_margin,
-        )
-
-        fig.show()
-
-    @staticmethod
-    def plot_image_grid1(
-        indir_or_csvf_or_df: Union[str, pd.DataFrame],
+        save_path: str = None,
+        dpi: int = 300,  # DPI for saving raster images or PDF
+        show: bool = True,  # whether to show the plot in an interactive window
         img_width: int = 300,
         img_height: int = 300,
         img_stack_direction: str = "horizontal",  # "horizontal" or "vertical"
@@ -609,7 +374,7 @@ class PlotHelper:
         cell_margin_px: int = 10,  # padding (top, left, right, bottom) inside each cell
         row_line_size: int = 0,  # if >0, draw horizontal dotted lines
         col_line_size: int = 0,  # if >0, draw vertical dotted lines
-    ):
+    ) -> go.Figure:
         """
         Plot a grid of images using Plotly.

@@ -931,7 +696,21 @@ class PlotHelper:
             margin=fig_margin,
         )

-
+        # === EXPORT IF save_path IS GIVEN ===
+        if save_path:
+            import kaleido  # lazy import – only needed when saving
+            import os
+
+            ext = os.path.splitext(save_path)[1].lower()
+            if ext in [".png", ".jpg", ".jpeg"]:
+                fig.write_image(save_path, scale=dpi / 96)  # scale = dpi / base 96
+            elif ext in [".pdf", ".svg"]:
+                fig.write_image(save_path)  # PDF/SVG are vector → dpi ignored
+            else:
+                raise ValueError("save_path must end with .png, .jpg, .pdf, or .svg")
+        if show:
+            fig.show()
+        return fig


 @click.command()
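With the new parameters, `plot_image_grid` now returns the figure and can export it directly instead of only calling `fig.show()`. A usage sketch (the CSV path and output name are hypothetical):

    fig = PlotHelper.plot_image_grid(
        "samples.csv",         # grid spec: first column = row labels, rest = image paths
        save_path="grid.png",  # .png/.jpg are rasterized at scale = dpi/96; .pdf/.svg stay vector
        dpi=300,
        show=False,            # skip the interactive window in batch runs
    )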
{halib-0.1.94.dist-info → halib-0.1.96.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: halib
-Version: 0.1.94
+Version: 0.1.96
 Summary: Small library for common tasks
 Author: Hoang Van Ha
 Author-email: hoangvanhauit@gmail.com
@@ -52,7 +52,7 @@ Dynamic: summary

 # Helper package for coding and automation

-**Version 0.1.94**
+**Version 0.1.96**
 + `research/plot`: add `PlotHelper` class to plot train history + plot grid of images (e.g., image samples from dataset or model outputs)

{halib-0.1.94.dist-info → halib-0.1.96.dist-info}/RECORD
CHANGED
@@ -21,7 +21,7 @@ halib/filetype/csvfile.py,sha256=4Klf8YNzY1MaCD3o5Wp5GG3KMfQIBOEVzHV_7DO5XBo,660
 halib/filetype/jsonfile.py,sha256=9LBdM7LV9QgJA1bzJRkq69qpWOP22HDXPGirqXTgSCw,480
 halib/filetype/textfile.py,sha256=QtuI5PdLxu4hAqSeafr3S8vCXwtvgipWV4Nkl7AzDYM,399
 halib/filetype/videofile.py,sha256=4nfVAYYtoT76y8P4WYyxNna4Iv1o2iV6xaMcUzNPC4s,4736
-halib/filetype/yamlfile.py,sha256=
+halib/filetype/yamlfile.py,sha256=59P9cdqTx655XXeQtkmAJoR_UhhVN4L8Tro-kd8Ri5g,2741
 halib/online/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 halib/online/gdrive.py,sha256=RmF4y6UPxektkKIctmfT-pKWZsBM9FVUeld6zZmJkp0,7787
 halib/online/gdrive_mkdir.py,sha256=wSJkQMJCDuS1gxQ2lHQHq_IrJ4xR_SEoPSo9n_2WNFU,1474
@@ -32,12 +32,14 @@ halib/research/base_config.py,sha256=AqZHZ0NNQ3WmUOfRzs36lf3o0FrehSdVLbdmgNpbV7A
 halib/research/base_exp.py,sha256=hiO2flt_I0iJJ4bWcQwyh2ISezoC8t2k3PtxHeVr0eI,3278
 halib/research/benchquery.py,sha256=FuKnbWQtCEoRRtJAfN-zaN-jPiO_EzsakmTOMiqi7GQ,4626
 halib/research/dataset.py,sha256=QU0Hr5QFb8_XlvnOMgC9QJGIpwXAZ9lDd0RdQi_QRec,6743
+halib/research/flop_csv.py,sha256=JeIUWgPFmhkPqvmhe-MLwwvAu9yR5F2k3qaViJCJJD4,1148
+halib/research/flops.py,sha256=Us0VudX8QMOm7YenICGf-Tq57C_l9x9hj-MUGA8_hCg,5773
 halib/research/metrics.py,sha256=PXPCy8r1_0lpMKfjc5SjIpRHnX80gHmeZ1C4eVj9U_s,5200
-halib/research/mics.py,sha256=
+halib/research/mics.py,sha256=nZyric8d0yKP5HrwwLsN4AjszrdxAhpJCRo1oy-EKJI,2612
 halib/research/params_gen.py,sha256=GcTMlniL0iE3HalJY-gVRiYa8Qy8u6nX4LkKZeMkct8,4262
-halib/research/perfcalc.py,sha256=
-halib/research/perftb.py,sha256=
-halib/research/plot.py,sha256=
+halib/research/perfcalc.py,sha256=G8WpGB95AY5KQCt0__bPK1yUa2M1onNhXLM7twkElxg,15904
+halib/research/perftb.py,sha256=YlBXMeWn8S0LhsgxONEQZrKomRTju2T8QGGspUOy_6Y,31100
+halib/research/plot.py,sha256=GBCXP1QnzRlNqjAl9UvGvW3I9II61DBStJNQThrLy38,28578
 halib/research/profiler.py,sha256=GRAewTo0jGkOputjmRwtYVfJYBze_ivsOnrW9exWkPQ,11772
 halib/research/torchloader.py,sha256=yqUjcSiME6H5W210363HyRUrOi3ISpUFAFkTr1w4DCw,6503
 halib/research/wandb_op.py,sha256=YzLEqME5kIRxi3VvjFkW83wnFrsn92oYeqYuNwtYRkY,4188
@@ -54,8 +56,8 @@ halib/utils/gpu_mon.py,sha256=vD41_ZnmPLKguuq9X44SB_vwd9JrblO4BDzHLXZhhFY,2233
 halib/utils/listop.py,sha256=Vpa8_2fI0wySpB2-8sfTBkyi_A4FhoFVVvFiuvW8N64,339
 halib/utils/tele_noti.py,sha256=-4WXZelCA4W9BroapkRyIdUu9cUVrcJJhegnMs_WpGU,5928
 halib/utils/video.py,sha256=zLoj5EHk4SmP9OnoHjO8mLbzPdtq6gQPzTQisOEDdO8,3261
-halib-0.1.
-halib-0.1.
-halib-0.1.
-halib-0.1.
-halib-0.1.
+halib-0.1.96.dist-info/licenses/LICENSE.txt,sha256=qZssdna4aETiR8znYsShUjidu-U4jUT9Q-EWNlZ9yBQ,1100
+halib-0.1.96.dist-info/METADATA,sha256=3y5yIsCp4dZVdVSSJbTPiqfmW_ZomowRQXP-_fdHit4,6200
+halib-0.1.96.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+halib-0.1.96.dist-info/top_level.txt,sha256=7AD6PLaQTreE0Fn44mdZsoHBe_Zdd7GUmjsWPyQ7I-k,6
+halib-0.1.96.dist-info/RECORD,,

{halib-0.1.94.dist-info → halib-0.1.96.dist-info}/WHEEL
File without changes

{halib-0.1.94.dist-info → halib-0.1.96.dist-info}/licenses/LICENSE.txt
File without changes

{halib-0.1.94.dist-info → halib-0.1.96.dist-info}/top_level.txt
File without changes