halib 0.1.91__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. halib/__init__.py +12 -6
  2. halib/common/__init__.py +0 -0
  3. halib/common/common.py +207 -0
  4. halib/common/rich_color.py +285 -0
  5. halib/common.py +53 -10
  6. halib/exp/__init__.py +0 -0
  7. halib/exp/core/__init__.py +0 -0
  8. halib/exp/core/base_config.py +167 -0
  9. halib/exp/core/base_exp.py +147 -0
  10. halib/exp/core/param_gen.py +189 -0
  11. halib/exp/core/wandb_op.py +117 -0
  12. halib/exp/data/__init__.py +0 -0
  13. halib/exp/data/dataclass_util.py +41 -0
  14. halib/exp/data/dataset.py +208 -0
  15. halib/exp/data/torchloader.py +165 -0
  16. halib/exp/perf/__init__.py +0 -0
  17. halib/exp/perf/flop_calc.py +190 -0
  18. halib/exp/perf/gpu_mon.py +58 -0
  19. halib/exp/perf/perfcalc.py +440 -0
  20. halib/exp/perf/perfmetrics.py +137 -0
  21. halib/exp/perf/perftb.py +778 -0
  22. halib/exp/perf/profiler.py +507 -0
  23. halib/exp/viz/__init__.py +0 -0
  24. halib/exp/viz/plot.py +754 -0
  25. halib/filetype/csvfile.py +3 -9
  26. halib/filetype/ipynb.py +61 -0
  27. halib/filetype/jsonfile.py +0 -3
  28. halib/filetype/textfile.py +0 -1
  29. halib/filetype/videofile.py +119 -3
  30. halib/filetype/yamlfile.py +16 -1
  31. halib/online/projectmake.py +7 -6
  32. halib/online/tele_noti.py +165 -0
  33. halib/research/base_exp.py +75 -18
  34. halib/research/core/__init__.py +0 -0
  35. halib/research/core/base_config.py +144 -0
  36. halib/research/core/base_exp.py +157 -0
  37. halib/research/core/param_gen.py +108 -0
  38. halib/research/core/wandb_op.py +117 -0
  39. halib/research/data/__init__.py +0 -0
  40. halib/research/data/dataclass_util.py +41 -0
  41. halib/research/data/dataset.py +208 -0
  42. halib/research/data/torchloader.py +165 -0
  43. halib/research/dataset.py +6 -7
  44. halib/research/flop_csv.py +34 -0
  45. halib/research/flops.py +156 -0
  46. halib/research/metrics.py +4 -0
  47. halib/research/mics.py +59 -1
  48. halib/research/perf/__init__.py +0 -0
  49. halib/research/perf/flop_calc.py +190 -0
  50. halib/research/perf/gpu_mon.py +58 -0
  51. halib/research/perf/perfcalc.py +363 -0
  52. halib/research/perf/perfmetrics.py +137 -0
  53. halib/research/perf/perftb.py +778 -0
  54. halib/research/perf/profiler.py +301 -0
  55. halib/research/perfcalc.py +60 -35
  56. halib/research/perftb.py +2 -1
  57. halib/research/plot.py +480 -218
  58. halib/research/viz/__init__.py +0 -0
  59. halib/research/viz/plot.py +754 -0
  60. halib/system/_list_pc.csv +6 -0
  61. halib/system/filesys.py +60 -20
  62. halib/system/path.py +106 -0
  63. halib/utils/dict.py +9 -0
  64. halib/utils/list.py +12 -0
  65. halib/utils/video.py +6 -0
  66. halib-0.2.21.dist-info/METADATA +192 -0
  67. halib-0.2.21.dist-info/RECORD +109 -0
  68. halib-0.1.91.dist-info/METADATA +0 -201
  69. halib-0.1.91.dist-info/RECORD +0 -61
  70. {halib-0.1.91.dist-info → halib-0.2.21.dist-info}/WHEEL +0 -0
  71. {halib-0.1.91.dist-info → halib-0.2.21.dist-info}/licenses/LICENSE.txt +0 -0
  72. {halib-0.1.91.dist-info → halib-0.2.21.dist-info}/top_level.txt +0 -0
halib/research/plot.py CHANGED
@@ -1,24 +1,28 @@
1
+ import ast
1
2
  import os
3
+ import json
4
+ import time
5
+ import click
6
+ import base64
2
7
  import pandas as pd
8
+
9
+ from PIL import Image
10
+ from io import BytesIO
11
+
3
12
  import plotly.express as px
4
- from rich.console import Console
5
- from ..common import now_str, norm_str, ConsoleLog
13
+ from ..common import now_str
6
14
  from ..filetype import csvfile
15
+ import plotly.graph_objects as go
7
16
  from ..system import filesys as fs
8
- import click
9
- import time
10
17
 
11
- import pandas as pd
12
- import plotly.graph_objects as go
13
- from PIL import Image
14
- import base64
15
- from io import BytesIO
18
+ from rich.console import Console
16
19
  from typing import Callable, Optional, Tuple, List, Union
17
20
 
18
21
 
19
22
  console = Console()
20
23
  desktop_path = os.path.expanduser("~/Desktop")
21
24
 
25
+
22
26
  class PlotHelper:
23
27
  def _verify_csv(self, csv_file):
24
28
  """Read a CSV and normalize column names (lowercase)."""
@@ -179,276 +183,534 @@ class PlotHelper:
179
183
  console.log("Stopped live updates.")
180
184
  else:
181
185
  run_once()
186
+
182
187
  @staticmethod
183
- def plot_image_grid(csv_path, sep=";", max_width=300, max_height=300):
188
+ def get_img_grid_df(input_dir, log=False):
184
189
  """
185
- Plot a grid of images using Plotly from a CSV file.
186
-
187
- Args:
188
- csv_path (str): Path to CSV file.
189
- max_width (int): Maximum width of each image in pixels.
190
- max_height (int): Maximum height of each image in pixels.
190
+ Use images in input_dir to create a dataframe for plot_image_grid.
191
+
192
+ Directory structures supported:
193
+
194
+ A. Row/Col structure:
195
+ input_dir/
196
+ ├── row0/
197
+ │ ├── col0/
198
+ │ │ ├── 0.png
199
+ │ │ ├── 1.png
200
+ │ └── col1/
201
+ │ ├── 0.png
202
+ │ ├── 1.png
203
+ ├── row1/
204
+ │ ├── col0/
205
+ │ │ ├── 0.png
206
+ │ │ ├── 1.png
207
+ │ └── col1/
208
+ │ ├── 0.png
209
+ │ ├── 1.png
210
+
211
+ B. Row-only structure (no cols):
212
+ input_dir/
213
+ ├── row0/
214
+ │ ├── 0.png
215
+ │ ├── 1.png
216
+ ├── row1/
217
+ │ ├── 0.png
218
+ │ ├── 1.png
219
+
220
+ Returns:
221
+ pd.DataFrame: DataFrame suitable for plot_image_grid.
222
+ Each cell contains a list of image paths.
191
223
  """
192
- # Load CSV
193
- df = csvfile.read_auto_sep(csv_path, sep=sep)
194
-
195
- # Column names for headers
196
- col_names = df.columns.tolist()
197
-
198
- # Function to convert image to base64
199
- def pil_to_base64(img_path):
200
- with Image.open(img_path) as im:
201
- im.thumbnail((max_width, max_height))
202
- buffer = BytesIO()
203
- im.save(buffer, format="PNG")
204
- encoded = base64.b64encode(buffer.getvalue()).decode()
205
- return "data:image/png;base64," + encoded
206
-
207
- # Initialize figure
208
- fig = go.Figure()
209
-
210
- n_rows = len(df)
211
- n_cols = len(df.columns) - 1 # skip label column
224
+ # --- Collect row dirs ---
225
+ rows = sorted([r for r in fs.list_dirs(input_dir) if r.startswith("row")])
226
+ if not rows:
227
+ raise ValueError(f"No 'row*' directories found in {input_dir}")
228
+
229
+ first_row_path = os.path.join(input_dir, rows[0])
230
+ subdirs = fs.list_dirs(first_row_path)
231
+
232
+ if subdirs: # --- Case A: row/col structure ---
233
+ cols_ref = sorted(subdirs)
234
+
235
+ # Ensure column consistency
236
+ meta_dict = {row: sorted(fs.list_dirs(os.path.join(input_dir, row))) for row in rows}
237
+ for row, cols in meta_dict.items():
238
+ if cols != cols_ref:
239
+ raise ValueError(f"Row {row} has mismatched columns: {cols} vs {cols_ref}")
240
+
241
+ # Collect image paths
242
+ meta_with_paths = {
243
+ row: {
244
+ col: fs.filter_files_by_extension(os.path.join(input_dir, row, col), ["png", "jpg", "jpeg"])
245
+ for col in cols_ref
246
+ }
247
+ for row in rows
248
+ }
212
249
 
213
- # Add images
214
- for i, row in df.iterrows():
215
- for j, col in enumerate(df.columns[1:]):
216
- img_path = row[col]
217
- img_src = pil_to_base64(img_path)
218
- fig.add_layout_image(
219
- dict(
220
- source=img_src,
221
- x=j,
222
- y=-i, # negative to have row 0 on top
223
- xref="x",
224
- yref="y",
225
- sizex=1,
226
- sizey=1,
227
- xanchor="left",
228
- yanchor="top",
229
- layer="above"
230
- )
231
- )
250
+ # Validate equal number of images per (row, col)
251
+ n_imgs = len(meta_with_paths[rows[0]][cols_ref[0]])
252
+ for row, cols in meta_with_paths.items():
253
+ for col, paths in cols.items():
254
+ if len(paths) != n_imgs:
255
+ raise ValueError(
256
+ f"Inconsistent file counts in {row}/{col}: {len(paths)} vs expected {n_imgs}"
257
+ )
258
+
259
+ # Flatten long format
260
+ data = {"row": [row for row in rows for _ in range(n_imgs)]}
261
+ for col in cols_ref:
262
+ data[col] = [meta_with_paths[row][col][i] for row in rows for i in range(n_imgs)]
263
+
264
+ else: # --- Case B: row-only structure ---
265
+ meta_with_paths = {
266
+ row: fs.filter_files_by_extension(os.path.join(input_dir, row), ["png", "jpg", "jpeg"])
267
+ for row in rows
268
+ }
232
269
 
233
- # Set axes for grid layout
234
- fig.update_xaxes(
235
- tickvals=list(range(n_cols)),
236
- ticktext=list(df.columns[1:]),
237
- range=[-0.5, n_cols-0.5],
238
- showgrid=False,
239
- zeroline=False
240
- )
241
- fig.update_yaxes(
242
- tickvals=[-i for i in range(n_rows)],
243
- ticktext=df[df.columns[0]],
244
- range=[-n_rows + 0.5, 0.5],
245
- showgrid=False,
246
- zeroline=False
270
+ # Validate equal number of images per row
271
+ n_imgs = len(next(iter(meta_with_paths.values())))
272
+ for row, paths in meta_with_paths.items():
273
+ if len(paths) != n_imgs:
274
+ raise ValueError(f"Inconsistent file counts in {row}: {len(paths)} vs expected {n_imgs}")
275
+
276
+ # Flatten long format (images indexed as img0,img1,...)
277
+ data = {"row": rows}
278
+ for i in range(n_imgs):
279
+ data[f"img{i}"] = [meta_with_paths[row][i] for row in rows]
280
+
281
+ # --- Convert to wide "multi-list" format ---
282
+ df = pd.DataFrame(data)
283
+ row_col = df.columns[0] # first col = row labels
284
+ # col_cols = df.columns[1:] # the rest = groupable cols
285
+
286
+ df = (
287
+ df.melt(id_vars=[row_col], var_name="col", value_name="path")
288
+ .groupby([row_col, "col"])["path"]
289
+ .apply(list)
290
+ .unstack("col")
291
+ .reset_index()
247
292
  )
248
293
 
249
- fig.update_layout(
250
- width=max_width*n_cols,
251
- height=max_height*n_rows,
252
- margin=dict(l=100, r=20, t=50, b=50)
253
- )
294
+ if log:
295
+ csvfile.fn_display_df(df)
254
296
 
255
- fig.show()
297
+ return df
256
298
 
257
299
  @staticmethod
258
- # this plot_df contains the data to be plotted (row, column)
259
- def img_grid_df(input_dir, log=False):
260
- rows = fs.list_dirs(input_dir)
261
- rows = [r for r in rows if r.startswith("row")]
262
- meta_dict = {}
263
- cols_of_row = None
264
- for row in rows:
265
- row_path = os.path.join(input_dir, row)
266
- cols = sorted(fs.list_dirs(row_path))
267
- if cols_of_row is None:
268
- cols_of_row = cols
269
- else:
270
- if cols_of_row != cols:
271
- raise ValueError(
272
- f"Row {row} has different columns than previous rows: {cols_of_row} vs {cols}"
273
- )
274
- meta_dict[row] = cols
300
+ def _parse_cell_to_list(cell) -> List[str]:
301
+ """Parse a DataFrame cell that may already be a list, a Python-list string, JSON list string,
302
+ or a single path. Returns list[str]."""
303
+ if cell is None:
304
+ return []
305
+ # pandas NA
306
+ try:
307
+ if pd.isna(cell):
308
+ return []
309
+ except Exception:
310
+ pass
275
311
 
276
- meta_dict_with_paths = {}
277
- for row, cols in meta_dict.items():
278
- meta_dict_with_paths[row] = {
279
- col: fs.filter_files_by_extension(
280
- os.path.join(input_dir, row, col), ["png", "jpg", "jpeg"]
281
- )
282
- for col in cols
283
- }
284
- first_row = list(meta_dict_with_paths.keys())[0]
285
- first_col = list(meta_dict_with_paths[first_row].keys())[0]
286
- len_first_col = len(meta_dict_with_paths[first_row][first_col])
287
- for row, cols in meta_dict_with_paths.items():
288
- for col, paths in cols.items():
289
- if len(paths) != len_first_col:
290
- raise ValueError(
291
- f"Row {row}, Column {col} has different number of files: {len(paths)} vs {len_first_col}"
292
- )
293
- cols = sorted(meta_dict_with_paths[first_row].keys())
294
- rows_set = sorted(meta_dict_with_paths.keys())
295
- row_per_col = len(meta_dict_with_paths[first_row][first_col])
296
- rows = [item for item in rows_set for _ in range(row_per_col)]
297
- data_dict = {}
298
- data_dict["row"] = rows
299
- col_data = {col: [] for col in cols}
300
- for row_base in rows_set:
301
- for col in cols:
302
- for i in range(row_per_col):
303
- col_data[col].append(meta_dict_with_paths[row_base][col][i])
304
- data_dict.update(col_data)
305
- df = pd.DataFrame(data_dict)
306
- if log:
307
- csvfile.fn_display_df(df)
308
- return df
312
+ if isinstance(cell, list):
313
+ return [str(x) for x in cell]
314
+
315
+ if isinstance(cell, (tuple, set)):
316
+ return [str(x) for x in cell]
317
+
318
+ if isinstance(cell, str):
319
+ s = cell.strip()
320
+ if not s:
321
+ return []
322
+
323
+ # Try Python literal (e.g. "['a','b']")
324
+ try:
325
+ val = ast.literal_eval(s)
326
+ if isinstance(val, (list, tuple)):
327
+ return [str(x) for x in val]
328
+ if isinstance(val, str):
329
+ return [val]
330
+ except Exception:
331
+ pass
332
+
333
+ # Try JSON
334
+ try:
335
+ val = json.loads(s)
336
+ if isinstance(val, list):
337
+ return [str(x) for x in val]
338
+ if isinstance(val, str):
339
+ return [val]
340
+ except Exception:
341
+ pass
342
+
343
+ # Fallback: split on common separators
344
+ for sep in [";;", ";", "|", ", "]:
345
+ if sep in s:
346
+ parts = [p.strip() for p in s.split(sep) if p.strip()]
347
+ if parts:
348
+ return parts
349
+
350
+ # Single path string
351
+ return [s]
352
+
353
+ # anything else -> coerce to string
354
+ return [str(cell)]
309
355
 
310
356
  @staticmethod
311
357
  def plot_image_grid(
312
- csv_file_or_df: Union[str, pd.DataFrame],
313
- max_width: int = 300,
314
- max_height: int = 300,
315
- img_stack_direction: str = "horizontal",
316
- img_stack_padding_px: int = 10,
358
+ indir_or_csvf_or_df: Union[str, pd.DataFrame],
359
+ save_path: str = None,
360
+ dpi: int = 300, # DPI for saving raster images or PDF
361
+ show: bool = True, # whether to show the plot in an interactive window
362
+ img_width: int = 300,
363
+ img_height: int = 300,
364
+ img_stack_direction: str = "horizontal", # "horizontal" or "vertical"
365
+ img_stack_padding_px: int = 5,
366
+ img_scale_mode: str = "fit", # "fit" or "fill"
317
367
  format_row_label_func: Optional[Callable[[str], str]] = None,
318
- format_col_label_func: Optional[Callable[[str, str], str]] = None,
368
+ format_col_label_func: Optional[Callable[[str], str]] = None,
319
369
  title: str = "",
320
- ):
370
+ tickfont=dict(size=16, family="Arial", color="black"), # <-- bigger labels
371
+ fig_margin: dict = dict(l=50, r=50, t=50, b=50),
372
+ outline_color: str = "",
373
+ outline_size: int = 1,
374
+ cell_margin_px: int = 10, # padding (top, left, right, bottom) inside each cell
375
+ row_line_size: int = 0, # if >0, draw horizontal dotted lines
376
+ col_line_size: int = 0, # if >0, draw vertical dotted lines
377
+ ) -> go.Figure:
321
378
  """
322
- Plot a grid of images using Plotly from a DataFrame.
323
-
324
- Args:
325
- df (pd.DataFrame): DataFrame with first column as row labels, remaining columns as image paths.
326
- max_width (int): Maximum width of stacked images per cell in pixels.
327
- max_height (int): Maximum height of stacked images per cell in pixels.
328
- img_stack_direction (str): "horizontal" or "vertical" stacking.
329
- img_stack_padding_px (int): Padding between stacked images in pixels.
330
- format_row_label_func (Callable): Function to format row labels.
331
- format_col_label_func (Callable): Function to format column labels.
332
- title (str): Figure title.
379
+ Plot a grid of images using Plotly.
380
+
381
+ - Accepts DataFrame where each cell is either:
382
+ * a Python list object,
383
+ * a string representation of a Python list (e.g. "['a','b']"),
384
+ * a JSON list string, or
385
+ * a single path string.
386
+ - For each cell, stack the images into a single composite that exactly fits
387
+ (img_width, img_height) is the target size for each individual image in the stack.
388
+ The final cell size will depend on the number of images and stacking direction.
333
389
  """
334
390
 
391
+ def process_image_for_slot(
392
+ path: str,
393
+ target_size: Tuple[int, int],
394
+ scale_mode: str,
395
+ outline: str,
396
+ outline_size: int,
397
+ ) -> Image.Image:
398
+ try:
399
+ img = Image.open(path).convert("RGB")
400
+ except Exception:
401
+ return Image.new("RGB", target_size, (255, 255, 255))
402
+
403
+ if scale_mode == "fit":
404
+ img_ratio = img.width / img.height
405
+ target_ratio = target_size[0] / target_size[1]
406
+
407
+ if img_ratio > target_ratio:
408
+ new_height = target_size[1]
409
+ new_width = max(1, int(new_height * img_ratio))
410
+ else:
411
+ new_width = target_size[0]
412
+ new_height = max(1, int(new_width / img_ratio))
413
+
414
+ img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
415
+ left = (new_width - target_size[0]) // 2
416
+ top = (new_height - target_size[1]) // 2
417
+ right = left + target_size[0]
418
+ bottom = top + target_size[1]
419
+
420
+ if len(outline) == 7 and outline.startswith("#"):
421
+ border_px = outline_size
422
+ bordered = Image.new(
423
+ "RGB",
424
+ (target_size[0] + 2 * border_px, target_size[1] + 2 * border_px),
425
+ outline,
426
+ )
427
+ bordered.paste(
428
+ img.crop((left, top, right, bottom)), (border_px, border_px)
429
+ )
430
+ return bordered
431
+ return img.crop((left, top, right, bottom))
432
+
433
+ elif scale_mode == "fill":
434
+ if len(outline) == 7 and outline.startswith("#"):
435
+ border_px = outline_size
436
+ bordered = Image.new(
437
+ "RGB",
438
+ (target_size[0] + 2 * border_px, target_size[1] + 2 * border_px),
439
+ outline,
440
+ )
441
+ img = img.resize(target_size, Image.Resampling.LANCZOS)
442
+ bordered.paste(img, (border_px, border_px))
443
+ return bordered
444
+ return img.resize(target_size, Image.Resampling.LANCZOS)
445
+ else:
446
+ raise ValueError("img_scale_mode must be 'fit' or 'fill'.")
447
+
335
448
  def stack_images_base64(
336
- image_paths: List[str], direction: str, target_size: Tuple[int, int]
337
- ) -> str:
338
- """Stack images and return base64-encoded PNG."""
339
- if not image_paths:
340
- return ""
341
-
342
- processed_images = []
343
- for path in image_paths:
344
- try:
345
- img = Image.open(path).convert("RGB")
346
- img.thumbnail(target_size, Image.Resampling.LANCZOS)
347
- processed_images.append(img)
348
- except:
349
- # blank image if error
350
- processed_images.append(Image.new("RGB", target_size, (255, 255, 255)))
351
-
352
- # Stack
353
- widths, heights = zip(*(img.size for img in processed_images))
354
- if direction == "horizontal":
355
- total_width = sum(widths) + img_stack_padding_px * (
356
- len(processed_images) - 1
449
+ image_paths: List[str],
450
+ direction: str,
451
+ single_img_size: Tuple[int, int],
452
+ outline: str,
453
+ outline_size: int,
454
+ padding: int,
455
+ ) -> Tuple[str, Tuple[int, int]]:
456
+ image_paths = [p for p in image_paths if p is not None and str(p).strip() != ""]
457
+ n = len(image_paths)
458
+ if n == 0:
459
+ blank = Image.new("RGB", single_img_size, (255, 255, 255))
460
+ buf = BytesIO()
461
+ blank.save(buf, format="PNG")
462
+ return (
463
+ "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode(),
464
+ single_img_size,
357
465
  )
358
- total_height = max(heights)
359
- stacked = Image.new("RGB", (total_width, total_height), (255, 255, 255))
360
- x_offset = 0
361
- for im in processed_images:
362
- stacked.paste(im, (x_offset, 0))
363
- x_offset += im.width + img_stack_padding_px
364
- elif direction == "vertical":
365
- total_width = max(widths)
366
- total_height = sum(heights) + img_stack_padding_px * (
367
- len(processed_images) - 1
368
- )
369
- stacked = Image.new("RGB", (total_width, total_height), (255, 255, 255))
370
- y_offset = 0
371
- for im in processed_images:
372
- stacked.paste(im, (0, y_offset))
373
- y_offset += im.height + img_stack_padding_px
374
- else:
375
- raise ValueError("img_stack_direction must be 'horizontal' or 'vertical'")
376
466
 
377
- # Encode as base64
378
- buffer = BytesIO()
379
- stacked.save(buffer, format="PNG")
380
- encoded = base64.b64encode(buffer.getvalue()).decode()
381
- return "data:image/png;base64," + encoded
467
+ processed = [
468
+ process_image_for_slot(
469
+ p, single_img_size, img_scale_mode, outline, outline_size
470
+ )
471
+ for p in image_paths
472
+ ]
473
+ pad_total = padding * (n - 1)
382
474
 
383
- # Load DataFrame if a file path is provided
384
- if isinstance(csv_file_or_df, str):
385
- df = csvfile.read_auto_sep(csv_file_or_df)
475
+ if direction == "horizontal":
476
+ total_w = sum(im.width for im in processed) + pad_total
477
+ total_h = max(im.height for im in processed)
478
+ stacked = Image.new("RGB", (total_w, total_h), (255, 255, 255))
479
+ x = 0
480
+ for im in processed:
481
+ stacked.paste(im, (x, 0))
482
+ x += im.width + padding
483
+ elif direction == "vertical":
484
+ total_w = max(im.width for im in processed)
485
+ total_h = sum(im.height for im in processed) + pad_total
486
+ stacked = Image.new("RGB", (total_w, total_h), (255, 255, 255))
487
+ y = 0
488
+ for im in processed:
489
+ stacked.paste(im, (0, y))
490
+ y += im.height + padding
491
+ else:
492
+ raise ValueError("img_stack_direction must be 'horizontal' or 'vertical'.")
493
+
494
+ buf = BytesIO()
495
+ stacked.save(buf, format="PNG")
496
+ encoded = base64.b64encode(buf.getvalue()).decode()
497
+ return f"data:image/png;base64,{encoded}", (total_w, total_h)
498
+
499
+ def compute_stacked_size(
500
+ image_paths: List[str],
501
+ direction: str,
502
+ single_w: int,
503
+ single_h: int,
504
+ padding: int,
505
+ outline: str,
506
+ outline_size: int,
507
+ ) -> Tuple[int, int]:
508
+ image_paths = [p for p in image_paths if p is not None and str(p).strip() != ""]
509
+ n = len(image_paths)
510
+ if n == 0:
511
+ return single_w, single_h
512
+ has_outline = len(outline) == 7 and outline.startswith("#")
513
+ border = 2 * outline_size if has_outline else 0
514
+ unit_w = single_w + border
515
+ unit_h = single_h + border
516
+ if direction == "horizontal":
517
+ total_w = n * unit_w + (n - 1) * padding
518
+ total_h = unit_h
519
+ elif direction == "vertical":
520
+ total_w = unit_w
521
+ total_h = n * unit_h + (n - 1) * padding
522
+ else:
523
+ raise ValueError("img_stack_direction must be 'horizontal' or 'vertical'.")
524
+ return total_w, total_h
525
+
526
+ # --- Load DataFrame ---
527
+ if isinstance(indir_or_csvf_or_df, str):
528
+ fname, ext = os.path.splitext(indir_or_csvf_or_df)
529
+ if ext.lower() == ".csv":
530
+ df = pd.read_csv(indir_or_csvf_or_df)
531
+ elif os.path.isdir(indir_or_csvf_or_df):
532
+ df = PlotHelper.img_grid_indir_1(indir_or_csvf_or_df, log=False)
533
+ else:
534
+ raise ValueError("Input string must be a valid CSV file or directory path")
535
+ elif isinstance(indir_or_csvf_or_df, pd.DataFrame):
536
+ df = indir_or_csvf_or_df.copy()
386
537
  else:
387
- df = csv_file_or_df
388
- assert isinstance(df, pd.DataFrame), "Input must be a DataFrame or valid CSV file path"
538
+ raise ValueError("Input must be CSV file path, DataFrame, or directory path")
389
539
 
390
- rows = df[df.columns[0]].tolist()
391
- columns = df.columns[1:].tolist()
540
+ rows = df.iloc[:, 0].astype(str).tolist()
541
+ columns = list(df.columns[1:])
392
542
  n_rows, n_cols = len(rows), len(columns)
393
543
 
394
544
  fig = go.Figure()
395
545
 
396
- for i, row_label in enumerate(rows):
397
- for j, col_label in enumerate(columns):
398
- image_paths = df.loc[i, col_label]
399
- if isinstance(image_paths, str):
400
- image_paths = [image_paths]
401
- img_src = stack_images_base64(
402
- image_paths, img_stack_direction, (max_width, max_height)
546
+ # First pass: compute content sizes
547
+ content_col_max = [0] * n_cols
548
+ content_row_max = [0] * n_rows
549
+ cell_paths = [[None] * n_cols for _ in range(n_rows)]
550
+ for i in range(n_rows):
551
+ for j in range(n_cols):
552
+ raw_cell = df.iloc[i, j + 1]
553
+ paths = PlotHelper._parse_cell_to_list(raw_cell)
554
+ image_paths = [str(p).strip() for p in paths if str(p).strip() != ""]
555
+ cell_paths[i][j] = image_paths
556
+ cw, ch = compute_stacked_size(
557
+ image_paths,
558
+ img_stack_direction,
559
+ img_width,
560
+ img_height,
561
+ img_stack_padding_px,
562
+ outline_color,
563
+ outline_size,
403
564
  )
404
-
565
+ content_col_max[j] = max(content_col_max[j], cw)
566
+ content_row_max[i] = max(content_row_max[i], ch)
567
+
568
+ # Compute display sizes (content max + padding)
569
+ display_col_w = [content_col_max[j] + 2 * cell_margin_px for j in range(n_cols)]
570
+ display_row_h = [content_row_max[i] + 2 * cell_margin_px for i in range(n_rows)]
571
+
572
+ # Compute positions (cells adjacent)
573
+ x_positions = []
574
+ cum_w = 0
575
+ for dw in display_col_w:
576
+ x_positions.append(cum_w)
577
+ cum_w += dw
578
+
579
+ y_positions = []
580
+ cum_h = 0
581
+ for dh in display_row_h:
582
+ y_positions.append(-cum_h)
583
+ cum_h += dh
584
+
585
+ # Second pass: create padded images (centered content)
586
+ cell_imgs = [[None] * n_cols for _ in range(n_rows)]
587
+ p = cell_margin_px
588
+ for i in range(n_rows):
589
+ for j in range(n_cols):
590
+ image_paths = cell_paths[i][j]
591
+ content_src, (cw, ch) = stack_images_base64(
592
+ image_paths,
593
+ img_stack_direction,
594
+ (img_width, img_height),
595
+ outline_color,
596
+ outline_size,
597
+ img_stack_padding_px,
598
+ )
599
+ if cw == 0 or ch == 0:
600
+ # Skip empty, but create white padded
601
+ pad_w = display_col_w[j]
602
+ pad_h = display_row_h[i]
603
+ padded = Image.new("RGB", (pad_w, pad_h), (255, 255, 255))
604
+ else:
605
+ content_img = Image.open(
606
+ BytesIO(base64.b64decode(content_src.split(",")[1]))
607
+ )
608
+ ca_w = content_col_max[j]
609
+ ca_h = content_row_max[i]
610
+ left_offset = (ca_w - cw) // 2
611
+ top_offset = (ca_h - ch) // 2
612
+ pad_w = display_col_w[j]
613
+ pad_h = display_row_h[i]
614
+ padded = Image.new("RGB", (pad_w, pad_h), (255, 255, 255))
615
+ paste_x = p + left_offset
616
+ paste_y = p + top_offset
617
+ padded.paste(content_img, (paste_x, paste_y))
618
+ buf = BytesIO()
619
+ padded.save(buf, format="PNG")
620
+ encoded = base64.b64encode(buf.getvalue()).decode()
621
+ cell_imgs[i][j] = f"data:image/png;base64,{encoded}"
622
+
623
+ # Add images to figure
624
+ for i in range(n_rows):
625
+ for j in range(n_cols):
405
626
  fig.add_layout_image(
406
627
  dict(
407
- source=img_src,
408
- x=j,
409
- y=-i, # negative so row 0 on top
628
+ source=cell_imgs[i][j],
629
+ x=x_positions[j],
630
+ y=y_positions[i],
410
631
  xref="x",
411
632
  yref="y",
412
- sizex=1,
413
- sizey=1,
633
+ sizex=display_col_w[j],
634
+ sizey=display_row_h[i],
414
635
  xanchor="left",
415
636
  yanchor="top",
416
637
  layer="above",
417
638
  )
418
639
  )
419
640
 
420
- # Format axis labels
641
+ # Optional grid lines (at cell boundaries, adjusted for inter-content spaces)
642
+ if row_line_size > 0:
643
+ for i in range(1, n_rows):
644
+ y = (y_positions[i - 1] - display_row_h[i - 1] + y_positions[i]) / 2
645
+ fig.add_shape(
646
+ type="line",
647
+ x0=-p,
648
+ x1=cum_w,
649
+ y0=y,
650
+ y1=y,
651
+ line=dict(width=row_line_size, color="black", dash="dot"),
652
+ )
653
+
654
+ if col_line_size > 0:
655
+ for j in range(1, n_cols):
656
+ x = x_positions[j]
657
+ fig.add_shape(
658
+ type="line",
659
+ x0=x,
660
+ x1=x,
661
+ y0=p,
662
+ y1=-cum_h,
663
+ line=dict(width=col_line_size, color="black", dash="dot"),
664
+ )
665
+
666
+ # Axis labels
421
667
  col_labels = [
422
- format_col_label_func(c, pattern="___") if format_col_label_func else c
423
- for c in columns
668
+ format_col_label_func(c) if format_col_label_func else c for c in columns
424
669
  ]
425
670
  row_labels = [
426
671
  format_row_label_func(r) if format_row_label_func else r for r in rows
427
672
  ]
428
673
 
429
674
  fig.update_xaxes(
430
- tickvals=list(range(n_cols)),
675
+ tickvals=[x_positions[j] + display_col_w[j] / 2 for j in range(n_cols)],
431
676
  ticktext=col_labels,
432
- range=[-0.5, n_cols - 0.5],
677
+ range=[-p, cum_w],
433
678
  showgrid=False,
434
679
  zeroline=False,
680
+ tickfont=tickfont, # <-- apply bigger font here
435
681
  )
436
682
  fig.update_yaxes(
437
- tickvals=[-i for i in range(n_rows)],
683
+ tickvals=[y_positions[i] - display_row_h[i] / 2 for i in range(n_rows)],
438
684
  ticktext=row_labels,
439
- range=[-n_rows + 0.5, 0.5],
685
+ range=[-cum_h, p],
440
686
  showgrid=False,
441
687
  zeroline=False,
688
+ tickfont=tickfont, # <-- apply bigger font here
442
689
  )
443
690
 
444
691
  fig.update_layout(
445
- width=max_width * n_cols + 200, # extra for labels
446
- height=max_height * n_rows + 100,
692
+ width=cum_w + 100,
693
+ height=cum_h + 100,
447
694
  title=title,
448
- margin=dict(l=100, r=20, t=50, b=50),
695
+ title_x=0.5,
696
+ margin=fig_margin,
449
697
  )
450
698
 
451
- fig.show()
699
+ # === EXPORT IF save_path IS GIVEN ===
700
+ if save_path:
701
+ import kaleido # lazy import – only needed when saving
702
+ import os
703
+
704
+ ext = os.path.splitext(save_path)[1].lower()
705
+ if ext in [".png", ".jpg", ".jpeg"]:
706
+ fig.write_image(save_path, scale=dpi / 96) # scale = dpi / base 96
707
+ elif ext in [".pdf", ".svg"]:
708
+ fig.write_image(save_path) # PDF/SVG are vector → dpi ignored
709
+ else:
710
+ raise ValueError("save_path must end with .png, .jpg, .pdf, or .svg")
711
+ if show:
712
+ fig.show()
713
+ return fig
452
714
 
453
715
 
454
716
  @click.command()