halib 0.1.94__py3-none-any.whl → 0.1.96__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
halib/filetype/yamlfile.py CHANGED
@@ -6,6 +6,8 @@ from omegaconf import OmegaConf
 from rich.console import Console
 from argparse import ArgumentParser
 
+from ..research.mics import *
+
 console = Console()
 
 
@@ -51,6 +53,27 @@ def load_yaml(yaml_file, to_dict=False, log_info=False):
     else:
         return omgconf
 
+def load_yaml_with_PC_abbr(
+    yaml_file, pc_abbr_to_working_disk=DEFAULT_ABBR_WORKING_DISK
+):
+    # current PC abbreviation
+    pc_abbr = get_PC_abbr_name()
+
+    # current platform: windows or linux
+    current_platform = platform.system().lower()
+
+    assert pc_abbr in pc_abbr_to_working_disk, f"There is no mapping for {pc_abbr} to <working_disk>"
+
+    # working disk
+    working_disk = pc_abbr_to_working_disk.get(pc_abbr)
+
+    # load yaml file
+    data_dict = load_yaml(yaml_file=yaml_file, to_dict=True)
+
+    # Normalize paths in the loaded data
+    data_dict = normalize_paths(data_dict, working_disk, current_platform)
+    return data_dict
+
 
 def parse_args():
     parser = ArgumentParser(description="desc text")
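
For context, a minimal usage sketch of the new helper (not part of the diff; the config file name and its dataset_dir key are hypothetical):

    # Sketch: assumes this machine's hostname appears in PC_NAME_TO_ABBR and its
    # abbreviation has an entry in DEFAULT_ABBR_WORKING_DISK (both defined in mics.py).
    from halib.filetype.yamlfile import load_yaml_with_PC_abbr

    # config.yaml might contain:  dataset_dir: "E:/zdataset/DFire"
    cfg = load_yaml_with_PC_abbr("config.yaml")
    # On a machine mapped to "D:", the value becomes "D:/zdataset/DFire" on Windows
    # or "/mnt/d/zdataset/DFire" on Linux.
    print(cfg["dataset_dir"])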
halib/research/flop_csv.py ADDED
@@ -0,0 +1,30 @@
+from halib import *
+from flops import _calculate_flops_for_model
+
+
+def main():
+    csv_file = "./results-imagenet.csv"
+    df = pd.read_csv(csv_file)
+    # coerce param_count to float (non-numeric values become a large sentinel)
+    df['param_count'] = pd.to_numeric(df['param_count'], errors='coerce').fillna(99999).astype(float)
+    df = df[df['param_count'] < 5.0]  # keep models with param_count < 5M
+
+    dict_ls = []
+
+    for index, row in tqdm(df.iterrows()):
+        console.rule(f"Row {index+1}/{len(df)}")
+        model = row['model']
+        num_class = 2
+        _, _, mflops = _calculate_flops_for_model(model, num_class)
+        dict_ls.append({'model': model, 'param_count': row['param_count'], 'mflops': mflops})
+
+    # Create a DataFrame from the list of dictionaries
+    result_df = pd.DataFrame(dict_ls)
+
+    final_df = pd.merge(df, result_df, on=['model', 'param_count'])
+    final_df.sort_values(by='mflops', inplace=True, ascending=True)
+    csvfile.fn_display_df(final_df)
+
+
+if __name__ == "__main__":
+    main()
halib/research/flops.py ADDED
@@ -0,0 +1,156 @@
+import os
+import sys
+import torch
+import timm
+from argparse import ArgumentParser
+from fvcore.nn import FlopCountAnalysis
+from halib import *
+from halib.filetype import csvfile
+from curriculum.utils.config import *
+from curriculum.utils.model_helper import *
+
+
+# ---------------------------------------------------------------------
+# Argument Parser
+# ---------------------------------------------------------------------
+def parse_args():
+    parser = ArgumentParser(description="Calculate FLOPs for TIMM or trained models")
+
+    # Option 1: Direct TIMM model
+    parser.add_argument(
+        "--model_name", type=str, help="TIMM model name (e.g., efficientnet_b0)"
+    )
+    parser.add_argument(
+        "--num_classes", type=int, default=1000, help="Number of output classes"
+    )
+
+    # Option 2: Experiment directory
+    parser.add_argument(
+        "--indir",
+        type=str,
+        default=None,
+        help="Directory containing trained experiment (with .yaml and .pth)",
+    )
+    parser.add_argument(
+        "-o", "--o", action="store_true", help="Open output CSV after saving"
+    )
+    return parser.parse_args()
+
+
+# ---------------------------------------------------------------------
+# Helper Functions
+# ---------------------------------------------------------------------
+def _get_list_of_proc_dirs(indir):
+    assert os.path.exists(indir), f"Input directory {indir} does not exist."
+    pth_files = [f for f in os.listdir(indir) if f.endswith(".pth")]
+    if len(pth_files) > 0:
+        return [indir]
+    return [
+        os.path.join(indir, f)
+        for f in os.listdir(indir)
+        if os.path.isdir(os.path.join(indir, f))
+    ]
+
+
+def _calculate_flops_for_model(model_name, num_classes):
+    """Calculate FLOPs for a plain TIMM model."""
+    try:
+        model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
+        input_size = timm.data.resolve_data_config(model.default_cfg)["input_size"]
+        dummy_input = torch.randn(1, *input_size)
+        model.eval()  # ! set to eval mode to avoid some warnings or errors
+        flops = FlopCountAnalysis(model, dummy_input)
+        gflops = flops.total() / 1e9
+        mflops = flops.total() / 1e6
+        print(f"\nModel: **{model_name}**, Classes: {num_classes}")
+        print(f"Input size: {input_size}, FLOPs: **{gflops:.3f} GFLOPs**, **{mflops:.3f} MFLOPs**\n")
+        return model_name, gflops, mflops
+    except Exception as e:
+        print(f"[Error] Could not calculate FLOPs for {model_name}: {e}")
+        return model_name, -1, -1
+
+
+def _calculate_flops_for_experiment(exp_dir):
+    """Calculate FLOPs for a trained experiment directory."""
+    yaml_files = [f for f in os.listdir(exp_dir) if f.endswith(".yaml")]
+    pth_files = [f for f in os.listdir(exp_dir) if f.endswith(".pth")]
+
+    assert (
+        len(yaml_files) == 1
+    ), f"Expected 1 YAML file in {exp_dir}, found {len(yaml_files)}"
+    assert (
+        len(pth_files) == 1
+    ), f"Expected 1 PTH file in {exp_dir}, found {len(pth_files)}"
+
+    exp_cfg_yaml = os.path.join(exp_dir, yaml_files[0])
+    cfg = ExpConfig.from_yaml(exp_cfg_yaml)
+    ds_label_list = cfg.dataset.get_label_list()
+
+    try:
+        model = build_model(
+            cfg.model.name, num_classes=len(ds_label_list), pretrained=True
+        )
+        model_weights_path = os.path.join(exp_dir, pth_files[0])
+        model.load_state_dict(torch.load(model_weights_path, map_location="cpu"))
+        model.eval()
+
+        input_size = timm.data.resolve_data_config(model.default_cfg)["input_size"]
+        dummy_input = torch.randn(1, *input_size)
+        flops = FlopCountAnalysis(model, dummy_input)
+        gflops = flops.total() / 1e9
+        mflops = flops.total() / 1e6
+
+        return str(cfg), cfg.model.name, gflops, mflops
+    except Exception as e:
+        console.print(f"[red] Error processing {exp_dir}: {e}[/red]")
+        return str(cfg), cfg.model.name, -1, -1
+
+
+# ---------------------------------------------------------------------
+# Main Entry
+# ---------------------------------------------------------------------
+def main():
+    args = parse_args()
+
+    # Case 1: Direct TIMM model input
+    if args.model_name:
+        _calculate_flops_for_model(args.model_name, args.num_classes)
+        return
+
+    # Case 2: Experiment directory input
+    if args.indir is None:
+        print("[Error] Either --model_name or --indir must be specified.")
+        return
+
+    proc_dirs = _get_list_of_proc_dirs(args.indir)
+    pprint(proc_dirs)
+
+    dfmk = csvfile.DFCreator()
+    TABLE_NAME = "model_flops_results"
+    dfmk.create_table(TABLE_NAME, ["exp_name", "model_name", "gflops", "mflops"])
+
+    console.rule(f"Calculating FLOPs for models in {len(proc_dirs)} dir(s)...")
+    rows = []
+    for exp_dir in tqdm(proc_dirs):
+        dir_name = os.path.basename(exp_dir)
+        console.rule(f"{dir_name}")
+        exp_name, model_name, gflops, mflops = _calculate_flops_for_experiment(exp_dir)
+        rows.append([exp_name, model_name, gflops, mflops])
+
+    dfmk.insert_rows(TABLE_NAME, rows)
+    dfmk.fill_table_from_row_pool(TABLE_NAME)
+
+    outfile = f"zout/zreport/{now_str()}_model_flops_results.csv"
+    dfmk[TABLE_NAME].to_csv(outfile, sep=";", index=False)
+    csvfile.fn_display_df(dfmk[TABLE_NAME])
+
+    if args.o:
+        os.system(f"start {outfile}")
+
+
+# ---------------------------------------------------------------------
+# Script Entry
+# ---------------------------------------------------------------------
+if __name__ == "__main__":
+    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+    main()
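
A quick usage sketch for the new FLOPs helper (not part of the diff; the model name is only an example):

    # Sketch: direct FLOPs calculation for a TIMM model, assuming timm and
    # fvcore are installed (flops.py imports both).
    from halib.research.flops import _calculate_flops_for_model

    name, gflops, mflops = _calculate_flops_for_model("efficientnet_b0", num_classes=2)
    # Prints the resolved input size plus GFLOPs/MFLOPs; returns (name, -1, -1) on failure.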
halib/research/mics.py CHANGED
@@ -1,5 +1,10 @@
+from ..common import *
+from ..filetype import csvfile
+import pandas as pd
+import re  # used by normalize_paths below
 import platform
 
+
 PC_NAME_TO_ABBR = {
     "DESKTOP-JQD9K01": "MainPC",
     "DESKTOP-5IRHU87": "MSI_Laptop",
@@ -8,9 +13,57 @@ PC_NAME_TO_ABBR = {
     "DESKTOP-QNS3DNF": "1GPU_SV"
 }
 
+DEFAULT_ABBR_WORKING_DISK = {
+    "MainPC": "E:",
+    "MSI_Laptop": "D:",
+    "4090_SV": "E:",
+    "4GPU_SV": "D:",
+}
+
+def list_PCs(show=True):
+    df = pd.DataFrame(list(PC_NAME_TO_ABBR.items()), columns=["PC Name", "Abbreviation"])
+    if show:
+        csvfile.fn_display_df(df)
+    return df
+
 def get_PC_name():
     return platform.node()
 
 def get_PC_abbr_name():
     pc_name = get_PC_name()
     return PC_NAME_TO_ABBR.get(pc_name, "Unknown")
+
+# ! This function searches for full paths in obj and normalizes them according to the current platform and working disk.
+# ! E.g., "E:/zdataset/DFire" with working_disk "D:" on windows => "D:/zdataset/DFire"
+# ! E.g., "E:/zdataset/DFire" with working_disk "D:" on linux => "/mnt/d/zdataset/DFire"
+def normalize_paths(obj, working_disk, current_platform):
+    if isinstance(obj, dict):
+        for key, value in obj.items():
+            obj[key] = normalize_paths(value, working_disk, current_platform)
+        return obj
+    elif isinstance(obj, list):
+        for i, item in enumerate(obj):
+            obj[i] = normalize_paths(item, working_disk, current_platform)
+        return obj
+    elif isinstance(obj, str):
+        # Normalize backslashes to forward slashes for consistency
+        obj = obj.replace("\\", "/")
+        # Regex for Windows-style path: e.g., "E:/zdataset/DFire"
+        win_match = re.match(r"^([A-Z]):/(.*)$", obj)
+        # Regex for Linux-style path: e.g., "/mnt/e/zdataset/DFire"
+        lin_match = re.match(r"^/mnt/([a-z])/(.*)$", obj)
+        if win_match or lin_match:
+            rest = win_match.group(2) if win_match else lin_match.group(2)
+            if current_platform == "windows":
+                # working_disk is like "D:", so "D:/" + rest
+                new_path = working_disk + "/" + rest
+            elif current_platform == "linux":
+                # Extract drive letter from working_disk (e.g., "D:" -> "d")
+                drive_letter = working_disk[0].lower()
+                new_path = "/mnt/" + drive_letter + "/" + rest
+            else:
+                # Unknown platform, return original
+                return obj
+            return new_path
+    # For non-strings or non-path strings, return as is
+    return obj
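
To make the examples in the comment concrete, a minimal sketch (not part of the diff) of how normalize_paths rewrites nested config values:

    # Sketch: remap paths for a Linux machine whose working disk maps to "D:".
    cfg = {"dataset": {"dir": "E:/zdataset/DFire", "splits": ["E:/zdataset/DFire/train"]}}
    out = normalize_paths(cfg, working_disk="D:", current_platform="linux")
    # out["dataset"]["dir"]       -> "/mnt/d/zdataset/DFire"
    # out["dataset"]["splits"][0] -> "/mnt/d/zdataset/DFire/train"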
halib/research/perfcalc.py CHANGED
@@ -227,9 +227,9 @@ class PerfCalc(ABC):  # Abstract base class for performance calculation
        ), "No metric columns found in the DataFrame. Ensure that the CSV files contain metric columns starting with 'metric_'."
        final_cols = sticky_cols + metric_cols
        df = df[final_cols]
-       # !hahv debug
-       pprint("------ Final DataFrame Columns ------")
-       csvfile.fn_display_df(df)
+       # # !hahv debug
+       # pprint("------ Final DataFrame Columns ------")
+       # csvfile.fn_display_df(df)
        # ! validate all rows in df before returning
        # make sure all rows will have at least values for REQUIRED_COLS and at least one metric column
        for index, row in df.iterrows():
halib/research/perftb.py CHANGED
@@ -308,7 +308,8 @@ class PerfTB:
         if save_path:
             export_success = False
             try:
-                fig.write_image(save_path, engine="kaleido")
+                # fig.write_image(save_path, engine="kaleido")
+                fig.write_image(save_path, engine="kaleido", width=width, height=height * len(metric_list))
                 export_success = True
                 # pprint(f"Saved: {os.path.abspath(save_path)}")
             except Exception as e:
halib/research/plot.py CHANGED
@@ -281,7 +281,7 @@ class PlotHelper:
         # --- Convert to wide "multi-list" format ---
         df = pd.DataFrame(data)
         row_col = df.columns[0]  # first col = row labels
-        col_cols = df.columns[1:]  # the rest = groupable cols
+        # col_cols = df.columns[1:]  # the rest = groupable cols
 
         df = (
             df.melt(id_vars=[row_col], var_name="col", value_name="path")
@@ -356,244 +356,9 @@ class PlotHelper:
     @staticmethod
     def plot_image_grid(
         indir_or_csvf_or_df: Union[str, pd.DataFrame],
-        img_width: int = 300,
-        img_height: int = 300,
-        img_stack_direction: str = "horizontal",  # "horizontal" or "vertical"
-        img_stack_padding_px: int = 5,
-        img_scale_mode: str = "fit",  # "fit" or "fill"
-        format_row_label_func: Optional[Callable[[str], str]] = None,
-        format_col_label_func: Optional[Callable[[str], str]] = None,
-        title: str = "",
-        tickfont=dict(size=16, family="Arial", color="black"),  # <-- bigger labels
-        fig_margin: dict = dict(l=50, r=50, t=50, b=50),
-        outline_color: str = "",
-        outline_size: int = 1,
-        cell_margin_px: int = 10,  # spacing between cells
-        row_line_size: int = 0,  # if >0, draw horizontal dotted lines
-        col_line_size: int = 0  # if >0, draw vertical dotted lines
-    ):
-        """
-        Plot a grid of images using Plotly.
-
-        - Accepts DataFrame where each cell is either:
-            * a Python list object,
-            * a string representation of a Python list (e.g. "['a','b']"),
-            * a JSON list string, or
-            * a single path string.
-        - For each cell, stack the images into a single composite that exactly fits
-          (img_width, img_height) is the target size for each individual image in the stack.
-          The final cell size will depend on the number of images and stacking direction.
-        """
-
-        def process_image_for_slot(path: str, target_size: Tuple[int, int], scale_mode: str, outline: str, outline_size: int) -> Image.Image:
-            try:
-                img = Image.open(path).convert("RGB")
-            except Exception:
-                return Image.new("RGB", target_size, (255, 255, 255))
-
-            if scale_mode == "fit":
-                img_ratio = img.width / img.height
-                target_ratio = target_size[0] / target_size[1]
-
-                if img_ratio > target_ratio:
-                    new_height = target_size[1]
-                    new_width = max(1, int(new_height * img_ratio))
-                else:
-                    new_width = target_size[0]
-                    new_height = max(1, int(new_width / img_ratio))
-
-                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
-                left = (new_width - target_size[0]) // 2
-                top = (new_height - target_size[1]) // 2
-                right = left + target_size[0]
-                bottom = top + target_size[1]
-
-                if len(outline) == 7 and outline.startswith("#"):
-                    border_px = outline_size
-                    bordered = Image.new("RGB", (target_size[0] + 2*border_px, target_size[1] + 2*border_px), outline)
-                    bordered.paste(img.crop((left, top, right, bottom)), (border_px, border_px))
-                    return bordered
-                return img.crop((left, top, right, bottom))
-
-            elif scale_mode == "fill":
-                if len(outline) == 7 and outline.startswith("#"):
-                    border_px = outline_size
-                    bordered = Image.new("RGB", (target_size[0] + 2*border_px, target_size[1] + 2*border_px), outline)
-                    img = img.resize(target_size, Image.Resampling.LANCZOS)
-                    bordered.paste(img, (border_px, border_px))
-                    return bordered
-                return img.resize(target_size, Image.Resampling.LANCZOS)
-            else:
-                raise ValueError("img_scale_mode must be 'fit' or 'fill'.")
-
-        def stack_images_base64(image_paths: List[str], direction: str, single_img_size: Tuple[int,int], outline: str, outline_size: int, padding: int) -> Tuple[str, Tuple[int,int]]:
-            image_paths = [p for p in image_paths if p is not None and str(p).strip() != ""]
-            n = len(image_paths)
-            if n == 0:
-                blank = Image.new("RGB", single_img_size, (255,255,255))
-                buf = BytesIO()
-                blank.save(buf, format="PNG")
-                return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode(), single_img_size
-
-            processed = [process_image_for_slot(p, single_img_size, img_scale_mode, outline, outline_size) for p in image_paths]
-            pad_total = padding * (n-1)
-
-            if direction == "horizontal":
-                total_w = sum(im.width for im in processed) + pad_total
-                total_h = max(im.height for im in processed)
-                stacked = Image.new("RGB", (total_w, total_h), (255,255,255))
-                x = 0
-                for im in processed:
-                    stacked.paste(im, (x,0))
-                    x += im.width + padding
-            elif direction == "vertical":
-                total_w = max(im.width for im in processed)
-                total_h = sum(im.height for im in processed) + pad_total
-                stacked = Image.new("RGB", (total_w, total_h), (255,255,255))
-                y = 0
-                for im in processed:
-                    stacked.paste(im, (0,y))
-                    y += im.height + padding
-            else:
-                raise ValueError("img_stack_direction must be 'horizontal' or 'vertical'.")
-
-            buf = BytesIO()
-            stacked.save(buf, format="PNG")
-            encoded = base64.b64encode(buf.getvalue()).decode()
-            return f"data:image/png;base64,{encoded}", (total_w, total_h)
-
-        # --- Load DataFrame ---
-        if isinstance(indir_or_csvf_or_df, str):
-            fname, ext = os.path.splitext(indir_or_csvf_or_df)
-            if ext.lower() == ".csv":
-                df = pd.read_csv(indir_or_csvf_or_df)
-            elif os.path.isdir(indir_or_csvf_or_df):
-                df = PlotHelper.img_grid_indir_1(indir_or_csvf_or_df, log=False)
-            else:
-                raise ValueError("Input string must be a valid CSV file or directory path")
-        elif isinstance(indir_or_csvf_or_df, pd.DataFrame):
-            df = indir_or_csvf_or_df.copy()
-        else:
-            raise ValueError("Input must be CSV file path, DataFrame, or directory path")
-
-        rows = df.iloc[:,0].astype(str).tolist()
-        columns = list(df.columns[1:])
-        n_rows, n_cols = len(rows), len(columns)
-
-        fig = go.Figure()
-        col_widths = [0]*n_cols
-        row_heights = [0]*n_rows
-
-        cell_imgs = [[None]*n_cols for _ in range(n_rows)]
-        for i in range(n_rows):
-            for j, col_label in enumerate(columns):
-                raw_cell = df.iloc[i, j+1]
-                image_paths = PlotHelper._parse_cell_to_list(raw_cell)
-                image_paths = [str(p).strip() for p in image_paths if str(p).strip() != ""]
-
-                img_src, (cell_w_actual, cell_h_actual) = stack_images_base64(
-                    image_paths, img_stack_direction, (img_width, img_height),
-                    outline=outline_color, outline_size=outline_size,
-                    padding=img_stack_padding_px
-                )
-
-                col_widths[j] = max(col_widths[j], cell_w_actual)
-                row_heights[i] = max(row_heights[i], cell_h_actual)
-                cell_imgs[i][j] = img_src
-
-        # Compute x/y positions including cell_margin
-        x_positions = []
-        cum_w = 0
-        for w in col_widths:
-            x_positions.append(cum_w)
-            cum_w += w + cell_margin_px
-
-        y_positions = []
-        cum_h = 0
-        for h in row_heights:
-            y_positions.append(-cum_h)
-            cum_h += h + cell_margin_px
-
-        # Add images to figure
-        for i in range(n_rows):
-            for j in range(n_cols):
-                fig.add_layout_image(
-                    dict(
-                        source=cell_imgs[i][j],
-                        x=x_positions[j],
-                        y=y_positions[i],
-                        xref="x",
-                        yref="y",
-                        sizex=col_widths[j],
-                        sizey=row_heights[i],
-                        xanchor="left",
-                        yanchor="top",
-                        layer="above",
-                    )
-                )
-        # ! Optional grid lines
-        # Add horizontal grid lines if row_line_size > 0
-        if row_line_size > 0:
-            for i in range(1, n_rows):
-                # Place line in the middle of the gap between rows
-                y = (
-                    y_positions[i - 1] - row_heights[i - 1] - y_positions[i]
-                ) / 2 + y_positions[i]
-                fig.add_shape(
-                    type="line",
-                    x0=-cell_margin_px,
-                    x1=cum_w - cell_margin_px,
-                    y0=y,
-                    y1=y,
-                    line=dict(width=row_line_size, color="black", dash="dot"),
-                )
-
-        # Add vertical grid lines if col_line_size > 0
-        if col_line_size > 0:
-            for j in range(1, n_cols):
-                x = x_positions[j] - cell_margin_px / 2
-                fig.add_shape(
-                    type="line",
-                    x0=x,
-                    x1=x,
-                    y0=cell_margin_px,
-                    y1=-cum_h + cell_margin_px,
-                    line=dict(width=col_line_size, color="black", dash="dot"),
-                )
-        # Axis labels
-        col_labels = [format_col_label_func(c) if format_col_label_func else c for c in columns]
-        row_labels = [format_row_label_func(r) if format_row_label_func else r for r in rows]
-
-        fig.update_xaxes(
-            tickvals=[x_positions[j] + col_widths[j]/2 for j in range(n_cols)],
-            ticktext=col_labels,
-            range=[-cell_margin_px, cum_w - cell_margin_px],
-            showgrid=False,
-            zeroline=False,
-            tickfont=tickfont  # <-- apply bigger font here
-        )
-        fig.update_yaxes(
-            tickvals=[y_positions[i] - row_heights[i]/2 for i in range(n_rows)],
-            ticktext=row_labels,
-            range=[-cum_h + cell_margin_px, cell_margin_px],
-            showgrid=False,
-            zeroline=False,
-            tickfont=tickfont  # <-- apply bigger font here
-        )
-
-        fig.update_layout(
-            width=cum_w + 100,
-            height=cum_h + 100,
-            title=title,
-            title_x=0.5,
-            margin=fig_margin,
-        )
-
-        fig.show()
-
-    @staticmethod
-    def plot_image_grid1(
-        indir_or_csvf_or_df: Union[str, pd.DataFrame],
+        save_path: str = None,
+        dpi: int = 300,  # DPI for saving raster images or PDF
+        show: bool = True,  # whether to show the plot in an interactive window
         img_width: int = 300,
         img_height: int = 300,
         img_stack_direction: str = "horizontal",  # "horizontal" or "vertical"
@@ -609,7 +374,7 @@ class PlotHelper:
         cell_margin_px: int = 10,  # padding (top, left, right, bottom) inside each cell
         row_line_size: int = 0,  # if >0, draw horizontal dotted lines
         col_line_size: int = 0,  # if >0, draw vertical dotted lines
-    ):
+    ) -> go.Figure:
         """
         Plot a grid of images using Plotly.
 
@@ -931,7 +696,21 @@ class PlotHelper:
             margin=fig_margin,
         )
 
-        fig.show()
+        # === EXPORT IF save_path IS GIVEN ===
+        if save_path:
+            import kaleido  # lazy import – only needed when saving
+            import os
+
+            ext = os.path.splitext(save_path)[1].lower()
+            if ext in [".png", ".jpg", ".jpeg"]:
+                fig.write_image(save_path, scale=dpi / 96)  # scale = dpi / base 96
+            elif ext in [".pdf", ".svg"]:
+                fig.write_image(save_path)  # PDF/SVG are vector → dpi ignored
+            else:
+                raise ValueError("save_path must end with .png, .jpg, .pdf, or .svg")
+        if show:
+            fig.show()
+        return fig
 
 
 @click.command()
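
A minimal usage sketch of the reworked method (not part of the diff; the CSV path and output file are hypothetical):

    # Sketch: render an image grid from a CSV of image paths and save it as a PNG.
    from halib.research.plot import PlotHelper

    fig = PlotHelper.plot_image_grid(
        "grid.csv",            # CSV file, DataFrame, or input directory
        save_path="grid.png",  # .png/.jpg are scaled by dpi; .pdf/.svg stay vector
        dpi=300,
        show=False,            # skip the interactive window
    )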
halib-0.1.96.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: halib
-Version: 0.1.94
+Version: 0.1.96
 Summary: Small library for common tasks
 Author: Hoang Van Ha
 Author-email: hoangvanhauit@gmail.com
@@ -52,7 +52,7 @@ Dynamic: summary
 
 # Helper package for coding and automation
 
-**Version 0.1.94**
+**Version 0.1.96**
 + `research/plot`: add `PlotHelper` class to plot train history + plot grid of images (e.g., image samples from dataset or model outputs)
 
 
halib-0.1.96.dist-info/RECORD CHANGED
@@ -21,7 +21,7 @@ halib/filetype/csvfile.py,sha256=4Klf8YNzY1MaCD3o5Wp5GG3KMfQIBOEVzHV_7DO5XBo,660
 halib/filetype/jsonfile.py,sha256=9LBdM7LV9QgJA1bzJRkq69qpWOP22HDXPGirqXTgSCw,480
 halib/filetype/textfile.py,sha256=QtuI5PdLxu4hAqSeafr3S8vCXwtvgipWV4Nkl7AzDYM,399
 halib/filetype/videofile.py,sha256=4nfVAYYtoT76y8P4WYyxNna4Iv1o2iV6xaMcUzNPC4s,4736
-halib/filetype/yamlfile.py,sha256=CqYqWtWSysQm_KBmcLjU-FXSXY4TYLPDGP4IqtnPPF4,2037
+halib/filetype/yamlfile.py,sha256=59P9cdqTx655XXeQtkmAJoR_UhhVN4L8Tro-kd8Ri5g,2741
 halib/online/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 halib/online/gdrive.py,sha256=RmF4y6UPxektkKIctmfT-pKWZsBM9FVUeld6zZmJkp0,7787
 halib/online/gdrive_mkdir.py,sha256=wSJkQMJCDuS1gxQ2lHQHq_IrJ4xR_SEoPSo9n_2WNFU,1474
@@ -32,12 +32,14 @@ halib/research/base_config.py,sha256=AqZHZ0NNQ3WmUOfRzs36lf3o0FrehSdVLbdmgNpbV7A
 halib/research/base_exp.py,sha256=hiO2flt_I0iJJ4bWcQwyh2ISezoC8t2k3PtxHeVr0eI,3278
 halib/research/benchquery.py,sha256=FuKnbWQtCEoRRtJAfN-zaN-jPiO_EzsakmTOMiqi7GQ,4626
 halib/research/dataset.py,sha256=QU0Hr5QFb8_XlvnOMgC9QJGIpwXAZ9lDd0RdQi_QRec,6743
+halib/research/flop_csv.py,sha256=JeIUWgPFmhkPqvmhe-MLwwvAu9yR5F2k3qaViJCJJD4,1148
+halib/research/flops.py,sha256=Us0VudX8QMOm7YenICGf-Tq57C_l9x9hj-MUGA8_hCg,5773
 halib/research/metrics.py,sha256=PXPCy8r1_0lpMKfjc5SjIpRHnX80gHmeZ1C4eVj9U_s,5200
-halib/research/mics.py,sha256=uX17AGrBGER-OFMqUULE_A9YPPbn1RpQ4o5-omrmqZ8,377
+halib/research/mics.py,sha256=nZyric8d0yKP5HrwwLsN4AjszrdxAhpJCRo1oy-EKJI,2612
 halib/research/params_gen.py,sha256=GcTMlniL0iE3HalJY-gVRiYa8Qy8u6nX4LkKZeMkct8,4262
-halib/research/perfcalc.py,sha256=qDa0sqfpWrwGZVJtjuUVFK7JX6j8xyXP9OnnfYmdamg,15898
-halib/research/perftb.py,sha256=FWg0b8wSgy4UwuvHSXwEqvTq1Rhi-z-HtAKuQg1lWc4,30989
-halib/research/plot.py,sha256=4xMGJuP1lGN1wF27XFM5eMFb73Gu9qB582VZhTdcCSA,38418
+halib/research/perfcalc.py,sha256=G8WpGB95AY5KQCt0__bPK1yUa2M1onNhXLM7twkElxg,15904
+halib/research/perftb.py,sha256=YlBXMeWn8S0LhsgxONEQZrKomRTju2T8QGGspUOy_6Y,31100
+halib/research/plot.py,sha256=GBCXP1QnzRlNqjAl9UvGvW3I9II61DBStJNQThrLy38,28578
 halib/research/profiler.py,sha256=GRAewTo0jGkOputjmRwtYVfJYBze_ivsOnrW9exWkPQ,11772
 halib/research/torchloader.py,sha256=yqUjcSiME6H5W210363HyRUrOi3ISpUFAFkTr1w4DCw,6503
 halib/research/wandb_op.py,sha256=YzLEqME5kIRxi3VvjFkW83wnFrsn92oYeqYuNwtYRkY,4188
@@ -54,8 +56,8 @@ halib/utils/gpu_mon.py,sha256=vD41_ZnmPLKguuq9X44SB_vwd9JrblO4BDzHLXZhhFY,2233
 halib/utils/listop.py,sha256=Vpa8_2fI0wySpB2-8sfTBkyi_A4FhoFVVvFiuvW8N64,339
 halib/utils/tele_noti.py,sha256=-4WXZelCA4W9BroapkRyIdUu9cUVrcJJhegnMs_WpGU,5928
 halib/utils/video.py,sha256=zLoj5EHk4SmP9OnoHjO8mLbzPdtq6gQPzTQisOEDdO8,3261
-halib-0.1.94.dist-info/licenses/LICENSE.txt,sha256=qZssdna4aETiR8znYsShUjidu-U4jUT9Q-EWNlZ9yBQ,1100
-halib-0.1.94.dist-info/METADATA,sha256=5R20faqnHeXHEoRWO4_LumYsz4_HIG5LnfpfrXoR5Y0,6200
-halib-0.1.94.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-halib-0.1.94.dist-info/top_level.txt,sha256=7AD6PLaQTreE0Fn44mdZsoHBe_Zdd7GUmjsWPyQ7I-k,6
-halib-0.1.94.dist-info/RECORD,,
+halib-0.1.96.dist-info/licenses/LICENSE.txt,sha256=qZssdna4aETiR8znYsShUjidu-U4jUT9Q-EWNlZ9yBQ,1100
+halib-0.1.96.dist-info/METADATA,sha256=3y5yIsCp4dZVdVSSJbTPiqfmW_ZomowRQXP-_fdHit4,6200
+halib-0.1.96.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+halib-0.1.96.dist-info/top_level.txt,sha256=7AD6PLaQTreE0Fn44mdZsoHBe_Zdd7GUmjsWPyQ7I-k,6
+halib-0.1.96.dist-info/RECORD,,