halib 0.1.95__tar.gz → 0.1.96__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {halib-0.1.95 → halib-0.1.96}/PKG-INFO +2 -2
- {halib-0.1.95 → halib-0.1.96}/README.md +1 -1
- halib-0.1.96/halib/research/flop_csv.py +34 -0
- halib-0.1.96/halib/research/flops.py +156 -0
- {halib-0.1.95 → halib-0.1.96}/halib/research/mics.py +4 -2
- {halib-0.1.95 → halib-0.1.96}/halib/research/plot.py +20 -241
- {halib-0.1.95 → halib-0.1.96}/halib.egg-info/PKG-INFO +2 -2
- {halib-0.1.95 → halib-0.1.96}/halib.egg-info/SOURCES.txt +2 -0
- {halib-0.1.95 → halib-0.1.96}/setup.py +1 -1
- {halib-0.1.95 → halib-0.1.96}/.gitignore +0 -0
- {halib-0.1.95 → halib-0.1.96}/GDriveFolder.txt +0 -0
- {halib-0.1.95 → halib-0.1.96}/LICENSE.txt +0 -0
- {halib-0.1.95 → halib-0.1.96}/MANIFEST.in +0 -0
- {halib-0.1.95 → halib-0.1.96}/guide_publish_pip.pdf +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/__init__.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/common.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/cuda.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/filetype/__init__.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/filetype/csvfile.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/filetype/jsonfile.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/filetype/textfile.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/filetype/videofile.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/filetype/yamlfile.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/online/__init__.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/online/gdrive.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/online/gdrive_mkdir.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/online/gdrive_test.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/online/projectmake.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/research/__init__.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/research/base_config.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/research/base_exp.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/research/dataset.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/research/metrics.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/research/params_gen.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/research/perfcalc.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/research/perftb.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/research/profiler.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/research/torchloader.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/research/wandb_op.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/rich_color.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/system/__init__.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/system/cmd.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/system/filesys.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/utils/__init__.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/utils/dataclass_util.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/utils/dict_op.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/utils/gpu_mon.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/utils/listop.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/utils/tele_noti.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib/utils/video.py +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib.egg-info/dependency_links.txt +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib.egg-info/requires.txt +0 -0
- {halib-0.1.95 → halib-0.1.96}/halib.egg-info/top_level.txt +0 -0
- {halib-0.1.95 → halib-0.1.96}/setup.cfg +0 -0
{halib-0.1.95 → halib-0.1.96}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: halib
-Version: 0.1.95
+Version: 0.1.96
 Summary: Small library for common tasks
 Author: Hoang Van Ha
 Author-email: hoangvanhauit@gmail.com
@@ -52,7 +52,7 @@ Dynamic: summary
 
 # Helper package for coding and automation
 
-**Version 0.1.95**
+**Version 0.1.96**
 + `research/plot': add `PlotHelper` class to plot train history + plot grid of images (e.g., image samples from dataset or model outputs)
 
 
halib-0.1.96/halib/research/flop_csv.py (new file)

@@ -0,0 +1,34 @@
+from halib import *
+from flops import _calculate_flops_for_model
+
+from halib import *
+from argparse import ArgumentParser
+
+
+def main():
+    csv_file = "./results-imagenet.csv"
+    df = pd.read_csv(csv_file)
+    # make param_count column as float
+    # df['param_count'] = df['param_count'].astype(float)
+    df['param_count'] = pd.to_numeric(df['param_count'], errors='coerce').fillna(99999).astype(float)
+    df = df[df['param_count'] < 5.0]  # filter models with param_count < 5M
+
+    dict_ls = []
+
+    for index, row in tqdm(df.iterrows()):
+        console.rule(f"Row {index+1}/{len(df)}")
+        model = row['model']
+        num_class = 2
+        _, _, mflops = _calculate_flops_for_model(model, num_class)
+        dict_ls.append({'model': model, 'param_count': row['param_count'], 'mflops': mflops})
+
+    # Create a DataFrame from the list of dictionaries
+    result_df = pd.DataFrame(dict_ls)
+
+    final_df = pd.merge(df, result_df, on=['model', 'param_count'])
+    final_df.sort_values(by='mflops', inplace=True, ascending=True)
+    csvfile.fn_display_df(final_df)
+
+
+if __name__ == "__main__":
+    main()
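Note: the script above coerces timm's param_count column (read from the CSV as strings) to float before filtering. A minimal standalone sketch of that step, using a made-up frame in place of the real results-imagenet.csv:

import pandas as pd

# Hypothetical stand-in for timm's results-imagenet.csv
df = pd.DataFrame({"model": ["a", "b", "c"], "param_count": ["3.2", "4.9", "n/a"]})

# Unparseable entries become NaN, then a large sentinel (99999),
# so the size filter below drops them along with large models.
df["param_count"] = pd.to_numeric(df["param_count"], errors="coerce").fillna(99999).astype(float)
df = df[df["param_count"] < 5.0]  # keep models under 5M parameters
print(df)
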
halib-0.1.96/halib/research/flops.py (new file)

@@ -0,0 +1,156 @@
+import os
+import sys
+import torch
+import timm
+from argparse import ArgumentParser
+from fvcore.nn import FlopCountAnalysis
+from halib import *
+from halib.filetype import csvfile
+from curriculum.utils.config import *
+from curriculum.utils.model_helper import *
+
+
+# ---------------------------------------------------------------------
+# Argument Parser
+# ---------------------------------------------------------------------
+def parse_args():
+    parser = ArgumentParser(description="Calculate FLOPs for TIMM or trained models")
+
+    # Option 1: Direct TIMM model
+    parser.add_argument(
+        "--model_name", type=str, help="TIMM model name (e.g., efficientnet_b0)"
+    )
+    parser.add_argument(
+        "--num_classes", type=int, default=1000, help="Number of output classes"
+    )
+
+    # Option 2: Experiment directory
+    parser.add_argument(
+        "--indir",
+        type=str,
+        default=None,
+        help="Directory containing trained experiment (with .yaml and .pth)",
+    )
+    parser.add_argument(
+        "-o", "--o", action="store_true", help="Open output CSV after saving"
+    )
+    return parser.parse_args()
+
+
+# ---------------------------------------------------------------------
+# Helper Functions
+# ---------------------------------------------------------------------
+def _get_list_of_proc_dirs(indir):
+    assert os.path.exists(indir), f"Input directory {indir} does not exist."
+    pth_files = [f for f in os.listdir(indir) if f.endswith(".pth")]
+    if len(pth_files) > 0:
+        return [indir]
+    return [
+        os.path.join(indir, f)
+        for f in os.listdir(indir)
+        if os.path.isdir(os.path.join(indir, f))
+    ]
+
+
+def _calculate_flops_for_model(model_name, num_classes):
+    """Calculate FLOPs for a plain TIMM model."""
+    try:
+        model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
+        input_size = timm.data.resolve_data_config(model.default_cfg)["input_size"]
+        dummy_input = torch.randn(1, *input_size)
+        model.eval()  # ! set to eval mode to avoid some warnings or errors
+        flops = FlopCountAnalysis(model, dummy_input)
+        gflops = flops.total() / 1e9
+        mflops = flops.total() / 1e6
+        print(f"\nModel: **{model_name}**, Classes: {num_classes}")
+        print(f"Input size: {input_size}, FLOPs: **{gflops:.3f} GFLOPs**, **{mflops:.3f} MFLOPs**\n")
+        return model_name, gflops, mflops
+    except Exception as e:
+        print(f"[Error] Could not calculate FLOPs for {model_name}: {e}")
+        return model_name, -1, -1
+
+
+def _calculate_flops_for_experiment(exp_dir):
+    """Calculate FLOPs for a trained experiment directory."""
+    yaml_files = [f for f in os.listdir(exp_dir) if f.endswith(".yaml")]
+    pth_files = [f for f in os.listdir(exp_dir) if f.endswith(".pth")]
+
+    assert (
+        len(yaml_files) == 1
+    ), f"Expected 1 YAML file in {exp_dir}, found {len(yaml_files)}"
+    assert (
+        len(pth_files) == 1
+    ), f"Expected 1 PTH file in {exp_dir}, found {len(pth_files)}"
+
+    exp_cfg_yaml = os.path.join(exp_dir, yaml_files[0])
+    cfg = ExpConfig.from_yaml(exp_cfg_yaml)
+    ds_label_list = cfg.dataset.get_label_list()
+
+    try:
+        model = build_model(
+            cfg.model.name, num_classes=len(ds_label_list), pretrained=True
+        )
+        model_weights_path = os.path.join(exp_dir, pth_files[0])
+        model.load_state_dict(torch.load(model_weights_path, map_location="cpu"))
+        model.eval()
+
+        input_size = timm.data.resolve_data_config(model.default_cfg)["input_size"]
+        dummy_input = torch.randn(1, *input_size)
+        flops = FlopCountAnalysis(model, dummy_input)
+        gflops = flops.total() / 1e9
+        mflops = flops.total() / 1e6
+
+        return str(cfg), cfg.model.name, gflops, mflops
+    except Exception as e:
+        console.print(f"[red] Error processing {exp_dir}: {e}[/red]")
+        return str(cfg), cfg.model.name, -1, -1
+
+
+# ---------------------------------------------------------------------
+# Main Entry
+# ---------------------------------------------------------------------
+def main():
+    args = parse_args()
+
+    # Case 1: Direct TIMM model input
+    if args.model_name:
+        _calculate_flops_for_model(args.model_name, args.num_classes)
+        return
+
+    # Case 2: Experiment directory input
+    if args.indir is None:
+        print("[Error] Either --model_name or --indir must be specified.")
+        return
+
+    proc_dirs = _get_list_of_proc_dirs(args.indir)
+    pprint(proc_dirs)
+
+    dfmk = csvfile.DFCreator()
+    TABLE_NAME = "model_flops_results"
+    dfmk.create_table(TABLE_NAME, ["exp_name", "model_name", "gflops", "mflops"])
+
+    console.rule(f"Calculating FLOPs for models in {len(proc_dirs)} dir(s)...")
+    rows = []
+    for exp_dir in tqdm(proc_dirs):
+        dir_name = os.path.basename(exp_dir)
+        console.rule(f"{dir_name}")
+        exp_name, model_name, gflops, mflops = _calculate_flops_for_experiment(exp_dir)
+        rows.append([exp_name, model_name, gflops, mflops])
+
+    dfmk.insert_rows(TABLE_NAME, rows)
+    dfmk.fill_table_from_row_pool(TABLE_NAME)
+
+    outfile = f"zout/zreport/{now_str()}_model_flops_results.csv"
+    dfmk[TABLE_NAME].to_csv(outfile, sep=";", index=False)
+    csvfile.fn_display_df(dfmk[TABLE_NAME])
+
+    if args.o:
+        os.system(f"start {outfile}")
+
+
+# ---------------------------------------------------------------------
+# Script Entry
+# ---------------------------------------------------------------------
+if __name__ == "__main__":
+    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+    main()
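Note: the core measurement in _calculate_flops_for_model reduces to a few timm/fvcore calls. A minimal sketch, assuming torch, timm, and fvcore are installed (the model name is only an example):

import torch
import timm
from fvcore.nn import FlopCountAnalysis

model = timm.create_model("efficientnet_b0", pretrained=False, num_classes=2)
model.eval()  # eval mode avoids warnings from training-only ops
# default_cfg carries the expected input resolution, e.g. (3, 224, 224)
input_size = timm.data.resolve_data_config(model.default_cfg)["input_size"]
flops = FlopCountAnalysis(model, torch.randn(1, *input_size))
print(f"{flops.total() / 1e9:.3f} GFLOPs, {flops.total() / 1e6:.3f} MFLOPs")
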
{halib-0.1.95 → halib-0.1.96}/halib/research/mics.py

@@ -19,9 +19,11 @@ DEFAULT_ABBR_WORKING_DISK = {
     "4GPU_SV": "D:",
 }
 
-def list_PCs():
+def list_PCs(show=True):
     df = pd.DataFrame(list(PC_NAME_TO_ABBR.items()), columns=["PC Name", "Abbreviation"])
-
+    if show:
+        csvfile.fn_display_df(df)
+    return df
 
 def get_PC_name():
     return platform.node()
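Note: with the new show flag, callers can fetch the PC table without printing it. A usage sketch (import path per this package layout):

from halib.research import mics

pc_df = mics.list_PCs(show=False)  # returns the DataFrame silently
print(pc_df["Abbreviation"].tolist())
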
{halib-0.1.95 → halib-0.1.96}/halib/research/plot.py

@@ -281,7 +281,7 @@ class PlotHelper:
         # --- Convert to wide "multi-list" format ---
         df = pd.DataFrame(data)
         row_col = df.columns[0]  # first col = row labels
-        col_cols = df.columns[1:]  # the rest = groupable cols
+        # col_cols = df.columns[1:]  # the rest = groupable cols
 
         df = (
             df.melt(id_vars=[row_col], var_name="col", value_name="path")
@@ -356,244 +356,9 @@ class PlotHelper:
     @staticmethod
     def plot_image_grid(
         indir_or_csvf_or_df: Union[str, pd.DataFrame],
-
-
-
-        img_stack_padding_px: int = 5,
-        img_scale_mode: str = "fit",  # "fit" or "fill"
-        format_row_label_func: Optional[Callable[[str], str]] = None,
-        format_col_label_func: Optional[Callable[[str], str]] = None,
-        title: str = "",
-        tickfont=dict(size=16, family="Arial", color="black"),  # <-- bigger labels
-        fig_margin: dict = dict(l=50, r=50, t=50, b=50),
-        outline_color: str = "",
-        outline_size: int = 1,
-        cell_margin_px: int = 10,  # spacing between cells
-        row_line_size: int = 0,  # if >0, draw horizontal dotted lines
-        col_line_size: int = 0  # if >0, draw vertical dotted lines
-    ):
-        """
-        Plot a grid of images using Plotly.
-
-        - Accepts DataFrame where each cell is either:
-            * a Python list object,
-            * a string representation of a Python list (e.g. "['a','b']"),
-            * a JSON list string, or
-            * a single path string.
-        - For each cell, stack the images into a single composite that exactly fits
-          (img_width, img_height) is the target size for each individual image in the stack.
-          The final cell size will depend on the number of images and stacking direction.
-        """
-
-        def process_image_for_slot(path: str, target_size: Tuple[int, int], scale_mode: str, outline: str, outline_size: int) -> Image.Image:
-            try:
-                img = Image.open(path).convert("RGB")
-            except Exception:
-                return Image.new("RGB", target_size, (255, 255, 255))
-
-            if scale_mode == "fit":
-                img_ratio = img.width / img.height
-                target_ratio = target_size[0] / target_size[1]
-
-                if img_ratio > target_ratio:
-                    new_height = target_size[1]
-                    new_width = max(1, int(new_height * img_ratio))
-                else:
-                    new_width = target_size[0]
-                    new_height = max(1, int(new_width / img_ratio))
-
-                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
-                left = (new_width - target_size[0]) // 2
-                top = (new_height - target_size[1]) // 2
-                right = left + target_size[0]
-                bottom = top + target_size[1]
-
-                if len(outline) == 7 and outline.startswith("#"):
-                    border_px = outline_size
-                    bordered = Image.new("RGB", (target_size[0] + 2*border_px, target_size[1] + 2*border_px), outline)
-                    bordered.paste(img.crop((left, top, right, bottom)), (border_px, border_px))
-                    return bordered
-                return img.crop((left, top, right, bottom))
-
-            elif scale_mode == "fill":
-                if len(outline) == 7 and outline.startswith("#"):
-                    border_px = outline_size
-                    bordered = Image.new("RGB", (target_size[0] + 2*border_px, target_size[1] + 2*border_px), outline)
-                    img = img.resize(target_size, Image.Resampling.LANCZOS)
-                    bordered.paste(img, (border_px, border_px))
-                    return bordered
-                return img.resize(target_size, Image.Resampling.LANCZOS)
-            else:
-                raise ValueError("img_scale_mode must be 'fit' or 'fill'.")
-
-        def stack_images_base64(image_paths: List[str], direction: str, single_img_size: Tuple[int,int], outline: str, outline_size: int, padding: int) -> Tuple[str, Tuple[int,int]]:
-            image_paths = [p for p in image_paths if p is not None and str(p).strip() != ""]
-            n = len(image_paths)
-            if n == 0:
-                blank = Image.new("RGB", single_img_size, (255,255,255))
-                buf = BytesIO()
-                blank.save(buf, format="PNG")
-                return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode(), single_img_size
-
-            processed = [process_image_for_slot(p, single_img_size, img_scale_mode, outline, outline_size) for p in image_paths]
-            pad_total = padding * (n-1)
-
-            if direction == "horizontal":
-                total_w = sum(im.width for im in processed) + pad_total
-                total_h = max(im.height for im in processed)
-                stacked = Image.new("RGB", (total_w, total_h), (255,255,255))
-                x = 0
-                for im in processed:
-                    stacked.paste(im, (x,0))
-                    x += im.width + padding
-            elif direction == "vertical":
-                total_w = max(im.width for im in processed)
-                total_h = sum(im.height for im in processed) + pad_total
-                stacked = Image.new("RGB", (total_w, total_h), (255,255,255))
-                y = 0
-                for im in processed:
-                    stacked.paste(im, (0,y))
-                    y += im.height + padding
-            else:
-                raise ValueError("img_stack_direction must be 'horizontal' or 'vertical'.")
-
-            buf = BytesIO()
-            stacked.save(buf, format="PNG")
-            encoded = base64.b64encode(buf.getvalue()).decode()
-            return f"data:image/png;base64,{encoded}", (total_w, total_h)
-
-        # --- Load DataFrame ---
-        if isinstance(indir_or_csvf_or_df, str):
-            fname, ext = os.path.splitext(indir_or_csvf_or_df)
-            if ext.lower() == ".csv":
-                df = pd.read_csv(indir_or_csvf_or_df)
-            elif os.path.isdir(indir_or_csvf_or_df):
-                df = PlotHelper.img_grid_indir_1(indir_or_csvf_or_df, log=False)
-            else:
-                raise ValueError("Input string must be a valid CSV file or directory path")
-        elif isinstance(indir_or_csvf_or_df, pd.DataFrame):
-            df = indir_or_csvf_or_df.copy()
-        else:
-            raise ValueError("Input must be CSV file path, DataFrame, or directory path")
-
-        rows = df.iloc[:,0].astype(str).tolist()
-        columns = list(df.columns[1:])
-        n_rows, n_cols = len(rows), len(columns)
-
-        fig = go.Figure()
-        col_widths = [0]*n_cols
-        row_heights = [0]*n_rows
-
-        cell_imgs = [[None]*n_cols for _ in range(n_rows)]
-        for i in range(n_rows):
-            for j, col_label in enumerate(columns):
-                raw_cell = df.iloc[i, j+1]
-                image_paths = PlotHelper._parse_cell_to_list(raw_cell)
-                image_paths = [str(p).strip() for p in image_paths if str(p).strip() != ""]
-
-                img_src, (cell_w_actual, cell_h_actual) = stack_images_base64(
-                    image_paths, img_stack_direction, (img_width, img_height),
-                    outline=outline_color, outline_size=outline_size,
-                    padding=img_stack_padding_px
-                )
-
-                col_widths[j] = max(col_widths[j], cell_w_actual)
-                row_heights[i] = max(row_heights[i], cell_h_actual)
-                cell_imgs[i][j] = img_src
-
-        # Compute x/y positions including cell_margin
-        x_positions = []
-        cum_w = 0
-        for w in col_widths:
-            x_positions.append(cum_w)
-            cum_w += w + cell_margin_px
-
-        y_positions = []
-        cum_h = 0
-        for h in row_heights:
-            y_positions.append(-cum_h)
-            cum_h += h + cell_margin_px
-
-        # Add images to figure
-        for i in range(n_rows):
-            for j in range(n_cols):
-                fig.add_layout_image(
-                    dict(
-                        source=cell_imgs[i][j],
-                        x=x_positions[j],
-                        y=y_positions[i],
-                        xref="x",
-                        yref="y",
-                        sizex=col_widths[j],
-                        sizey=row_heights[i],
-                        xanchor="left",
-                        yanchor="top",
-                        layer="above",
-                    )
-                )
-        # ! Optional grid lines
-        # Add horizontal grid lines if row_line_size > 0
-        if row_line_size > 0:
-            for i in range(1, n_rows):
-                # Place line in the middle of the gap between rows
-                y = (
-                    y_positions[i - 1] - row_heights[i - 1] - y_positions[i]
-                ) / 2 + y_positions[i]
-                fig.add_shape(
-                    type="line",
-                    x0=-cell_margin_px,
-                    x1=cum_w - cell_margin_px,
-                    y0=y,
-                    y1=y,
-                    line=dict(width=row_line_size, color="black", dash="dot"),
-                )
-
-        # Add vertical grid lines if col_line_size > 0
-        if col_line_size > 0:
-            for j in range(1, n_cols):
-                x = x_positions[j] - cell_margin_px / 2
-                fig.add_shape(
-                    type="line",
-                    x0=x,
-                    x1=x,
-                    y0=cell_margin_px,
-                    y1=-cum_h + cell_margin_px,
-                    line=dict(width=col_line_size, color="black", dash="dot"),
-                )
-        # Axis labels
-        col_labels = [format_col_label_func(c) if format_col_label_func else c for c in columns]
-        row_labels = [format_row_label_func(r) if format_row_label_func else r for r in rows]
-
-        fig.update_xaxes(
-            tickvals=[x_positions[j] + col_widths[j]/2 for j in range(n_cols)],
-            ticktext=col_labels,
-            range=[-cell_margin_px, cum_w - cell_margin_px],
-            showgrid=False,
-            zeroline=False,
-            tickfont=tickfont  # <-- apply bigger font here
-        )
-        fig.update_yaxes(
-            tickvals=[y_positions[i] - row_heights[i]/2 for i in range(n_rows)],
-            ticktext=row_labels,
-            range=[-cum_h + cell_margin_px, cell_margin_px],
-            showgrid=False,
-            zeroline=False,
-            tickfont=tickfont  # <-- apply bigger font here
-        )
-
-        fig.update_layout(
-            width=cum_w + 100,
-            height=cum_h + 100,
-            title=title,
-            title_x=0.5,
-            margin=fig_margin,
-        )
-
-        fig.show()
-
-    @staticmethod
-    def plot_image_grid1(
-        indir_or_csvf_or_df: Union[str, pd.DataFrame],
+        save_path: str = None,
+        dpi: int = 300,  # DPI for saving raster images or PDF
+        show: bool = True,  # whether to show the plot in an interactive window
         img_width: int = 300,
         img_height: int = 300,
         img_stack_direction: str = "horizontal",  # "horizontal" or "vertical"
@@ -609,7 +374,7 @@ class PlotHelper:
         cell_margin_px: int = 10,  # padding (top, left, right, bottom) inside each cell
         row_line_size: int = 0,  # if >0, draw horizontal dotted lines
         col_line_size: int = 0,  # if >0, draw vertical dotted lines
-    ):
+    ) -> go.Figure:
         """
         Plot a grid of images using Plotly.
 
@@ -931,7 +696,21 @@ class PlotHelper:
             margin=fig_margin,
         )
 
-
+        # === EXPORT IF save_path IS GIVEN ===
+        if save_path:
+            import kaleido  # lazy import – only needed when saving
+            import os
+
+            ext = os.path.splitext(save_path)[1].lower()
+            if ext in [".png", ".jpg", ".jpeg"]:
+                fig.write_image(save_path, scale=dpi / 96)  # scale = dpi / base 96
+            elif ext in [".pdf", ".svg"]:
+                fig.write_image(save_path)  # PDF/SVG are vector → dpi ignored
+            else:
+                raise ValueError("save_path must end with .png, .jpg, .pdf, or .svg")
+        if show:
+            fig.show()
+        return fig
 
 
 @click.command()
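Note: plot_image_grid now returns the figure and can export it; PNG/JPG export goes through kaleido at the given DPI. A hypothetical call (grid.csv is an assumed file whose first column holds row labels and whose remaining columns hold image paths):

from halib.research.plot import PlotHelper

fig = PlotHelper.plot_image_grid(
    "grid.csv",            # hypothetical CSV of image paths
    save_path="grid.png",  # raster export; scale = dpi / 96 via kaleido
    dpi=300,
    show=False,            # skip the interactive window
)
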
{halib-0.1.95 → halib-0.1.96}/halib.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: halib
-Version: 0.1.95
+Version: 0.1.96
 Summary: Small library for common tasks
 Author: Hoang Van Ha
 Author-email: hoangvanhauit@gmail.com
@@ -52,7 +52,7 @@ Dynamic: summary
 
 # Helper package for coding and automation
 
-**Version 0.1.95**
+**Version 0.1.96**
 + `research/plot': add `PlotHelper` class to plot train history + plot grid of images (e.g., image samples from dataset or model outputs)
 
 
{halib-0.1.95 → halib-0.1.96}/halib.egg-info/SOURCES.txt

@@ -29,6 +29,8 @@ halib/research/__init__.py
 halib/research/base_config.py
 halib/research/base_exp.py
 halib/research/dataset.py
+halib/research/flop_csv.py
+halib/research/flops.py
 halib/research/metrics.py
 halib/research/mics.py
 halib/research/params_gen.py