halib 0.1.7__py3-none-any.whl → 0.1.99__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. halib/__init__.py +84 -0
  2. halib/common.py +151 -0
  3. halib/cuda.py +39 -0
  4. halib/dataset.py +209 -0
  5. halib/filetype/csvfile.py +151 -45
  6. halib/filetype/ipynb.py +63 -0
  7. halib/filetype/jsonfile.py +1 -1
  8. halib/filetype/textfile.py +4 -4
  9. halib/filetype/videofile.py +44 -33
  10. halib/filetype/yamlfile.py +95 -0
  11. halib/gdrive.py +1 -1
  12. halib/online/gdrive.py +104 -54
  13. halib/online/gdrive_mkdir.py +29 -17
  14. halib/online/gdrive_test.py +31 -18
  15. halib/online/projectmake.py +58 -43
  16. halib/plot.py +296 -11
  17. halib/projectmake.py +1 -1
  18. halib/research/__init__.py +0 -0
  19. halib/research/base_config.py +100 -0
  20. halib/research/base_exp.py +100 -0
  21. halib/research/benchquery.py +131 -0
  22. halib/research/dataset.py +208 -0
  23. halib/research/flop_csv.py +34 -0
  24. halib/research/flops.py +156 -0
  25. halib/research/metrics.py +133 -0
  26. halib/research/mics.py +68 -0
  27. halib/research/params_gen.py +108 -0
  28. halib/research/perfcalc.py +336 -0
  29. halib/research/perftb.py +780 -0
  30. halib/research/plot.py +758 -0
  31. halib/research/profiler.py +300 -0
  32. halib/research/torchloader.py +162 -0
  33. halib/research/wandb_op.py +116 -0
  34. halib/rich_color.py +285 -0
  35. halib/sys/filesys.py +17 -10
  36. halib/system/__init__.py +0 -0
  37. halib/system/cmd.py +8 -0
  38. halib/system/filesys.py +124 -0
  39. halib/tele_noti.py +166 -0
  40. halib/torchloader.py +162 -0
  41. halib/utils/__init__.py +0 -0
  42. halib/utils/dataclass_util.py +40 -0
  43. halib/utils/dict_op.py +9 -0
  44. halib/utils/gpu_mon.py +58 -0
  45. halib/utils/listop.py +13 -0
  46. halib/utils/tele_noti.py +166 -0
  47. halib/utils/video.py +82 -0
  48. halib/videofile.py +1 -1
  49. halib-0.1.99.dist-info/METADATA +209 -0
  50. halib-0.1.99.dist-info/RECORD +64 -0
  51. {halib-0.1.7.dist-info → halib-0.1.99.dist-info}/WHEEL +1 -1
  52. halib-0.1.7.dist-info/METADATA +0 -59
  53. halib-0.1.7.dist-info/RECORD +0 -30
  54. {halib-0.1.7.dist-info → halib-0.1.99.dist-info/licenses}/LICENSE.txt +0 -0
  55. {halib-0.1.7.dist-info → halib-0.1.99.dist-info}/top_level.txt +0 -0
halib/research/plot.py ADDED
@@ -0,0 +1,758 @@
1
+ import ast
2
+ import os
3
+ import json
4
+ import time
5
+ import click
6
+ import base64
7
+ import pandas as pd
8
+
9
+ from PIL import Image
10
+ from io import BytesIO
11
+
12
+ import plotly.express as px
13
+ from ..common import now_str
14
+ from ..filetype import csvfile
15
+ import plotly.graph_objects as go
16
+ from ..system import filesys as fs
17
+
18
+ from rich.console import Console
19
+ from typing import Callable, Optional, Tuple, List, Union
20
+
21
+
22
# Default --outdir for the command-line entry point: the user's Desktop folder.
desktop_path = os.path.expanduser("~/Desktop")
# Module-wide rich console shared by every log/print call in this file.
console = Console()
24
+
25
+
26
class PlotHelper:
    """Plotly-based helpers for training-curve CSV files and image grids."""

    def _verify_csv(self, csv_file):
        """Read ``csv_file`` with ``csvfile.read_auto_sep`` and lowercase its column names.

        Raises:
            FileNotFoundError: if the file does not exist.
            ValueError: if reading or normalizing the file fails for any other reason.
        """
        try:
            df = csvfile.read_auto_sep(csv_file)
            df.columns = [col.lower() for col in df.columns]
            return df
        except FileNotFoundError as e:
            raise FileNotFoundError(f"CSV file '{csv_file}' not found") from e
        except Exception as e:
            raise ValueError(f"Error reading CSV file '{csv_file}': {str(e)}") from e

    @staticmethod
    def _norm_str(s):
        """Lowercase ``s`` and replace spaces and hyphens with underscores."""
        return s.lower().replace(" ", "_").replace("-", "_")

    @staticmethod
    def _get_file_name(file_path):
        """Return the base name of ``file_path`` without its extension."""
        return os.path.splitext(os.path.basename(file_path))[0]

    def _get_valid_tags(self, csv_files, tags):
        """Return ``tags`` as a list, or derive one normalized tag per CSV file name."""
        if tags:
            return list(tags)
        return [self._norm_str(self._get_file_name(f)) for f in csv_files]

    def _prepare_long_df(self, csv_files, tags, x_col, y_cols, log=False):
        """Melt several CSVs into one long-form dataframe for Plotly.

        The result has the columns ``x_col``, ``metric_type`` (the original
        y column name), ``value`` and ``tag`` (the tag paired with each file).

        Raises:
            ValueError: if no files are given, or a file lacks ``x_col`` / ``y_cols``.
        """
        if not csv_files:
            raise ValueError("No CSV files to plot")
        dfs = []
        for csv_file, tag in zip(csv_files, tags):
            df = self._verify_csv(csv_file)
            if x_col not in df.columns:
                raise ValueError(f"{csv_file} is missing x_col '{x_col}'")
            missing = [c for c in y_cols if c not in df.columns]
            if missing:
                raise ValueError(f"{csv_file} is missing y_cols {missing}")

            if log:
                console.log(f"Plotting {csv_file}")
                console.print(df)

            # Wide -> long: one row per (x, metric column) pair.
            df_long = df.melt(
                id_vars=x_col,
                value_vars=y_cols,
                var_name="metric_type",
                value_name="value",
            )
            df_long["tag"] = tag
            dfs.append(df_long)

        return pd.concat(dfs, ignore_index=True)

    @staticmethod
    def _select_metric_rows(df_long, metric):
        """Return the rows of ``df_long`` whose ``metric_type`` suffix equals ``metric``.

        The suffix is the text after the last ``_`` (``train_loss`` -> ``loss``),
        i.e. the same rule used to build the metric groups. Exact suffix matching
        keeps e.g. ``acc`` from also selecting ``acc_top5``.
        """
        suffixes = df_long["metric_type"].astype(str).str.rsplit("_", n=1).str[-1]
        return df_long[suffixes == metric]

    def _plot_with_plotly(
        self,
        df_long,
        tags,
        outdir,
        save_fig,
        out_fmt="svg",
        font_size=16,
        x_col="epoch",
        y_cols=None,
    ):
        """Show (and optionally save) one line chart per metric group.

        ``y_cols`` are grouped by the suffix after their last ``_`` (``train_loss``
        and ``val_loss`` share the ``loss`` chart). In each chart, color encodes the
        tag and dash style encodes the individual metric column.

        Raises:
            ValueError: if ``out_fmt`` is not svg/pdf/png or ``y_cols`` is missing.
        """
        if out_fmt not in ("svg", "pdf", "png"):
            raise ValueError(f"Unsupported format '{out_fmt}'; expected svg, pdf or png")
        if y_cols is None:
            raise ValueError("y_cols must be provided")

        joined_tags = "+".join(tags)
        metric_groups = sorted({col.split("_")[-1] for col in y_cols})

        for metric in metric_groups:
            subset = self._select_metric_rows(df_long, metric)

            if out_fmt == "svg":  # LaTeX-style labels
                title = f"${joined_tags}\\_{metric}\\text{{-by-{x_col}}}$"
                xaxis_title = f"$\\text{{{x_col.capitalize()}}}$"
                yaxis_title = f"${metric.capitalize()}$"
            else:
                title = f"{joined_tags}_{metric}-by-{x_col}"
                xaxis_title = x_col.capitalize()
                yaxis_title = metric.capitalize()

            fig = px.line(
                subset,
                x=x_col,
                y="value",
                color="tag",
                line_dash="metric_type",
                title=title,
            )
            fig.update_layout(
                font=dict(family="Computer Modern", size=font_size),
                xaxis_title=xaxis_title,
                yaxis_title=yaxis_title,
            )
            fig.show()

            if save_fig:
                os.makedirs(outdir, exist_ok=True)
                filename = f"{now_str()}_{joined_tags}_{metric}"
                try:
                    fig.write_image(os.path.join(outdir, f"{filename}.{out_fmt}"))
                except Exception as e:
                    # Best effort: report the export failure (e.g. a missing image
                    # export backend) and continue with the remaining metrics.
                    console.log(f"Error saving figure '{filename}.{out_fmt}': {str(e)}")

    @classmethod
    def plot_csv_timeseries(
        cls,
        csv_files,
        outdir="./out/plot",
        tags=None,
        log=False,
        save_fig=False,
        update_in_min=0,
        out_fmt="svg",
        font_size=16,
        x_col="epoch",
        y_cols=("train_loss", "train_acc"),
    ):
        """Plot metric columns of one or more CSV files, optionally re-plotting periodically.

        Args:
            csv_files: a CSV path or a sequence of paths.
            outdir: directory for saved figures (used when ``save_fig``).
            tags: one tag per file (a single string is accepted); derived from the
                file names when omitted.
            log: print every CSV before plotting.
            save_fig: also write each figure to ``outdir``.
            update_in_min: if > 0, re-read and re-plot every this many minutes until
                Ctrl+C.
            out_fmt: "svg", "pdf" or "png".
            font_size: layout font size.
            x_col: x-axis column. Matched case-insensitively, because CSV headers are
                lowercased on read.
            y_cols: metric column name(s) to plot. Also matched case-insensitively.

        Raises:
            ValueError: for missing files/y_cols or a tag/file count mismatch.
        """
        if isinstance(csv_files, str):
            csv_files = [csv_files]
        csv_files = list(csv_files)
        if not csv_files:
            raise ValueError("At least one CSV file must be provided")
        if isinstance(tags, str):
            tags = [tags]
        if isinstance(y_cols, str):
            y_cols = [y_cols]
        if not y_cols:
            raise ValueError("You must specify y_cols explicitly")

        # _verify_csv lowercases CSV headers, so look the columns up lowercased too.
        x_col = x_col.lower()
        y_cols = [col.lower() for col in y_cols]

        # Instantiate PlotHelper to call instance methods.
        plot_helper = cls()
        valid_tags = plot_helper._get_valid_tags(csv_files, tags)
        if len(valid_tags) != len(csv_files):
            raise ValueError("Number of tags must match number of CSV files")

        def run_once():
            df_long = plot_helper._prepare_long_df(
                csv_files, valid_tags, x_col, y_cols, log
            )
            plot_helper._plot_with_plotly(
                df_long, valid_tags, outdir, save_fig, out_fmt, font_size, x_col, y_cols
            )

        if update_in_min > 0:
            # Clamp to >= 1s so very small intervals cannot become a busy loop.
            interval = max(1, int(update_in_min * 60))
            console.log(f"Live update every {interval}s. Press Ctrl+C to stop.")
            try:
                while True:
                    run_once()
                    time.sleep(interval)
            except KeyboardInterrupt:
                console.log("Stopped live updates.")
        else:
            run_once()

    @staticmethod
    def get_img_grid_df(input_dir, log=False):
        """
        Use images in input_dir to create a dataframe for plot_image_grid.

        Directory structures supported:

        A. Row/Col structure:
            input_dir/
            ├── row0/
            │   ├── col0/
            │   │   ├── 0.png
            │   │   ├── 1.png
            │   └── col1/
            │       ├── 0.png
            │       ├── 1.png
            ├── row1/
            │   ├── col0/
            │   │   ├── 0.png
            │   │   ├── 1.png
            │   └── col1/
            │       ├── 0.png
            │       ├── 1.png

        B. Row-only structure (no cols):
            input_dir/
            ├── row0/
            │   ├── 0.png
            │   ├── 1.png
            ├── row1/
            │   ├── 0.png
            │   ├── 1.png

        Row and column directory names are sorted as strings. In case B the image
        columns are named img0, img1, ... and keep their numeric order.

        Returns:
            pd.DataFrame: first column "row" (row labels), then one column per grid
            column. Each cell contains a list of image paths.

        Raises:
            ValueError: if there are no ``row*`` directories, rows have different
                column directories, or image counts differ between cells.
        """
        rows = sorted(r for r in fs.list_dirs(input_dir) if r.startswith("row"))
        if not rows:
            raise ValueError(f"No 'row*' directories found in {input_dir}")

        first_row_subdirs = fs.list_dirs(os.path.join(input_dir, rows[0]))

        if first_row_subdirs:  # --- Case A: row/col structure ---
            cols_ref = sorted(first_row_subdirs)

            # Every row must contain exactly the same column directories.
            for row in rows:
                cols = sorted(fs.list_dirs(os.path.join(input_dir, row)))
                if cols != cols_ref:
                    raise ValueError(f"Row {row} has mismatched columns: {cols} vs {cols_ref}")

            meta = {
                row: {
                    col: fs.filter_files_by_extension(
                        os.path.join(input_dir, row, col), ["png", "jpg", "jpeg"]
                    )
                    for col in cols_ref
                }
                for row in rows
            }

            # Every (row, col) cell must hold the same number of images.
            n_imgs = len(meta[rows[0]][cols_ref[0]])
            for row, by_col in meta.items():
                for col, paths in by_col.items():
                    if len(paths) != n_imgs:
                        raise ValueError(
                            f"Inconsistent file counts in {row}/{col}: {len(paths)} vs expected {n_imgs}"
                        )

            # Long format: one record per (row, image index).
            data = {"row": [row for row in rows for _ in range(n_imgs)]}
            for col in cols_ref:
                data[col] = [meta[row][col][i] for row in rows for i in range(n_imgs)]

        else:  # --- Case B: row-only structure ---
            meta = {
                row: fs.filter_files_by_extension(
                    os.path.join(input_dir, row), ["png", "jpg", "jpeg"]
                )
                for row in rows
            }

            n_imgs = len(meta[rows[0]])
            for row, paths in meta.items():
                if len(paths) != n_imgs:
                    raise ValueError(f"Inconsistent file counts in {row}: {len(paths)} vs expected {n_imgs}")

            # One column per image position: img0, img1, ...
            data = {"row": rows}
            for i in range(n_imgs):
                data[f"img{i}"] = [meta[row][i] for row in rows]

        # Intended grid column order, captured before pandas reshuffles it.
        col_order = list(data)[1:]

        # --- Convert to wide "multi-list" format ---
        df = pd.DataFrame(data)
        row_col = df.columns[0]  # first col = row labels
        df = (
            df.melt(id_vars=[row_col], var_name="col", value_name="path")
            .groupby([row_col, "col"])["path"]
            .apply(list)
            .unstack("col")
            # unstack orders labels as strings (img10 < img2); restore intended order.
            .reindex(columns=col_order)
            .reset_index()
        )

        if log:
            csvfile.fn_display_df(df)

        return df

    @staticmethod
    def _parse_cell_to_list(cell) -> List[str]:
        """Parse a DataFrame cell into a list of path strings.

        The cell may already be list-like (list/tuple/set/array/Series), a
        Python-list string, a JSON list string, a string joined with ";;", ";",
        "|" or ", ", or a single path. Missing values (None/NaN/NA) give ``[]``,
        and missing entries inside list-like cells are dropped.
        """

        def is_missing(value) -> bool:
            # Only scalars are tested with pd.isna; on containers it returns an
            # array whose truth value is ambiguous.
            return value is None or (
                pd.api.types.is_scalar(value) and bool(pd.isna(value))
            )

        if pd.api.types.is_list_like(cell) and not isinstance(cell, dict):
            return [str(x) for x in cell if not is_missing(x)]

        if is_missing(cell):
            return []

        if isinstance(cell, str):
            s = cell.strip()
            if not s:
                return []

            # Python literal, e.g. "['a','b']" (this also accepts most JSON lists).
            try:
                val = ast.literal_eval(s)
                if isinstance(val, (list, tuple)):
                    return [str(x) for x in val if not is_missing(x)]
                if isinstance(val, str):
                    return [val]
            except Exception:
                pass

            # JSON, e.g. '["a", "b"]'.
            try:
                val = json.loads(s)
                if isinstance(val, list):
                    return [str(x) for x in val if not is_missing(x)]
                if isinstance(val, str):
                    return [val]
            except Exception:
                pass

            # Fallback: split on the first common separator present.
            for sep in [";;", ";", "|", ", "]:
                if sep in s:
                    parts = [p.strip() for p in s.split(sep) if p.strip()]
                    if parts:
                        return parts

            # Single path string.
            return [s]

        # Anything else -> coerce to string.
        return [str(cell)]

    @staticmethod
    def plot_image_grid(
        indir_or_csvf_or_df: Union[str, pd.DataFrame],
        save_path: Optional[str] = None,
        dpi: int = 300,  # DPI for saving raster images
        show: bool = True,  # whether to show the plot in an interactive window
        img_width: int = 300,
        img_height: int = 300,
        img_stack_direction: str = "horizontal",  # "horizontal" or "vertical"
        img_stack_padding_px: int = 5,
        img_scale_mode: str = "fit",  # "fit" or "fill"
        format_row_label_func: Optional[Callable[[str], str]] = None,
        format_col_label_func: Optional[Callable[[str], str]] = None,
        title: str = "",
        tickfont: Optional[dict] = None,  # default: size 16, Arial, black
        fig_margin: Optional[dict] = None,  # default: l=r=t=b=50
        outline_color: str = "",  # "#RRGGBB" enables an outline around each image
        outline_size: int = 1,
        cell_margin_px: int = 10,  # padding (top, left, right, bottom) inside each cell
        row_line_size: int = 0,  # if >0, draw horizontal dotted lines
        col_line_size: int = 0,  # if >0, draw vertical dotted lines
    ) -> "go.Figure":
        """
        Plot a grid of images using Plotly.

        ``indir_or_csvf_or_df`` is a DataFrame, a ``.csv`` path, or a directory laid
        out as described in ``get_img_grid_df``. The first column holds row labels,
        and every other column is one grid column. Each cell may be:
          * a Python list (or other list-like) of paths,
          * a string representation of a Python list (e.g. "['a','b']"),
          * a JSON list string,
          * a string joined with ";;", ";", "|" or ", ", or
          * a single path string.

        Every image is brought to exactly (img_width, img_height):
          * "fit" scales the image to cover the slot and center-crops the overflow.
          * "fill" resizes the image to the slot size.
        Each image is then optionally outlined. The images of a cell are stacked
        along ``img_stack_direction`` with ``img_stack_padding_px`` between them.
        Each cell is centered within its column width / row height, plus
        ``cell_margin_px`` on every side. Unreadable paths become blank slots of
        the same size, and empty cells show one blank slot.

        Returns:
            The Plotly figure. It is also shown when ``show`` is true and written
            to ``save_path`` when that is given. Raster formats are scaled by
            ``dpi / 96``.

        Raises:
            ValueError: for an invalid input type/path, ``img_scale_mode``,
                ``img_stack_direction``, or ``save_path`` extension.
        """
        if tickfont is None:
            tickfont = dict(size=16, family="Arial", color="black")
        if fig_margin is None:
            fig_margin = dict(l=50, r=50, t=50, b=50)
        if img_scale_mode not in ("fit", "fill"):
            raise ValueError("img_scale_mode must be 'fit' or 'fill'.")
        if img_stack_direction not in ("horizontal", "vertical"):
            raise ValueError("img_stack_direction must be 'horizontal' or 'vertical'.")

        horizontal = img_stack_direction == "horizontal"
        slot_size = (img_width, img_height)
        white = (255, 255, 255)
        # Only "#RRGGBB" strings enable the outline.
        has_outline = (
            isinstance(outline_color, str)
            and len(outline_color) == 7
            and outline_color.startswith("#")
        )
        border_px = outline_size if has_outline else 0

        def scale_to_slot(img: Image.Image) -> Image.Image:
            """Return ``img`` resized/cropped to exactly ``slot_size``."""
            tw, th = slot_size
            if img_scale_mode == "fill":
                return img.resize(slot_size, Image.Resampling.LANCZOS)
            # "fit": cover the slot, then center-crop. Rounding plus the max()
            # guard keeps the crop inside the image, so no black band appears.
            scale = max(tw / img.width, th / img.height)
            new_w = max(tw, round(img.width * scale))
            new_h = max(th, round(img.height * scale))
            img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
            left = (new_w - tw) // 2
            top = (new_h - th) // 2
            return img.crop((left, top, left + tw, top + th))

        def load_slot(path: str) -> Image.Image:
            """Load one image as a slot-sized tile (blank if unreadable), outlined if enabled."""
            try:
                with Image.open(path) as src:  # closes the file handle
                    img = src.convert("RGB")
            except Exception:
                # Deliberate best effort: unreadable images become blank slots.
                img = Image.new("RGB", slot_size, white)
            else:
                img = scale_to_slot(img)
            if not has_outline:
                return img
            framed = Image.new(
                "RGB",
                (img.width + 2 * border_px, img.height + 2 * border_px),
                outline_color,
            )
            framed.paste(img, (border_px, border_px))
            return framed

        def stacked_size(n: int) -> Tuple[int, int]:
            """Pixel size of a cell holding ``n`` stacked tiles."""
            if n == 0:
                return slot_size
            unit_w = img_width + 2 * border_px
            unit_h = img_height + 2 * border_px
            pad = img_stack_padding_px
            if horizontal:
                return n * unit_w + (n - 1) * pad, unit_h
            return unit_w, n * unit_h + (n - 1) * pad

        def stack_cell(paths: List[str]) -> Image.Image:
            """Stack a cell's tiles into one image of size ``stacked_size(len(paths))``."""
            if not paths:
                return Image.new("RGB", slot_size, white)
            canvas = Image.new("RGB", stacked_size(len(paths)), white)
            offset = 0
            for path in paths:
                tile = load_slot(path)
                canvas.paste(tile, (offset, 0) if horizontal else (0, offset))
                offset += (tile.width if horizontal else tile.height) + img_stack_padding_px
            return canvas

        def to_data_uri(img: Image.Image) -> str:
            """Encode ``img`` as a base64 PNG data URI."""
            buf = BytesIO()
            img.save(buf, format="PNG")
            return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

        # --- Load DataFrame ---
        source = indir_or_csvf_or_df
        if isinstance(source, os.PathLike):
            source = os.fspath(source)
        if isinstance(source, str):
            if os.path.splitext(source)[1].lower() == ".csv":
                df = pd.read_csv(source)
            elif os.path.isdir(source):
                df = PlotHelper.get_img_grid_df(source, log=False)
            else:
                raise ValueError("Input string must be a valid CSV file or directory path")
        elif isinstance(source, pd.DataFrame):
            df = source.copy()
        else:
            raise ValueError("Input must be CSV file path, DataFrame, or directory path")

        rows = df.iloc[:, 0].astype(str).tolist()
        columns = list(df.columns[1:])
        n_rows, n_cols = len(rows), len(columns)
        m = cell_margin_px

        # Pass 1: parse every cell and compute the content size per column/row.
        cell_paths = [[[] for _ in range(n_cols)] for _ in range(n_rows)]
        content_col_max = [0] * n_cols
        content_row_max = [0] * n_rows
        for i, row_cells in enumerate(df.iloc[:, 1:].itertuples(index=False, name=None)):
            for j, raw_cell in enumerate(row_cells):
                paths = [
                    p.strip()
                    for p in PlotHelper._parse_cell_to_list(raw_cell)
                    if p.strip()
                ]
                cell_paths[i][j] = paths
                cw, ch = stacked_size(len(paths))
                content_col_max[j] = max(content_col_max[j], cw)
                content_row_max[i] = max(content_row_max[i], ch)

        # Display sizes = content max + margin on both sides; cells are adjacent.
        display_col_w = [w + 2 * m for w in content_col_max]
        display_row_h = [h + 2 * m for h in content_row_max]

        x_positions = []
        cum_w = 0
        for dw in display_col_w:
            x_positions.append(cum_w)
            cum_w += dw

        # Rows grow downward, so y positions are negative offsets from the top.
        y_positions = []
        cum_h = 0
        for dh in display_row_h:
            y_positions.append(-cum_h)
            cum_h += dh

        # Pass 2: render each cell (content centered in its padded box) and add it.
        fig = go.Figure()
        for i in range(n_rows):
            for j in range(n_cols):
                content = stack_cell(cell_paths[i][j])
                cell_img = Image.new("RGB", (display_col_w[j], display_row_h[i]), white)
                cell_img.paste(
                    content,
                    (
                        m + (content_col_max[j] - content.width) // 2,
                        m + (content_row_max[i] - content.height) // 2,
                    ),
                )
                fig.add_layout_image(
                    dict(
                        source=to_data_uri(cell_img),
                        x=x_positions[j],
                        y=y_positions[i],
                        xref="x",
                        yref="y",
                        sizex=display_col_w[j],
                        sizey=display_row_h[i],
                        xanchor="left",
                        yanchor="top",
                        layer="above",
                    )
                )

        # Optional dotted grid lines on the internal cell boundaries.
        if row_line_size > 0:
            for y in y_positions[1:]:
                fig.add_shape(
                    type="line",
                    x0=-m,
                    x1=cum_w,
                    y0=y,
                    y1=y,
                    line=dict(width=row_line_size, color="black", dash="dot"),
                )

        if col_line_size > 0:
            for x in x_positions[1:]:
                fig.add_shape(
                    type="line",
                    x0=x,
                    x1=x,
                    y0=m,
                    y1=-cum_h,
                    line=dict(width=col_line_size, color="black", dash="dot"),
                )

        # Axis labels
        col_labels = [
            format_col_label_func(c) if format_col_label_func else c for c in columns
        ]
        row_labels = [
            format_row_label_func(r) if format_row_label_func else r for r in rows
        ]

        fig.update_xaxes(
            tickvals=[x_positions[j] + display_col_w[j] / 2 for j in range(n_cols)],
            ticktext=col_labels,
            range=[-m, cum_w],
            showgrid=False,
            zeroline=False,
            tickfont=tickfont,
        )
        fig.update_yaxes(
            tickvals=[y_positions[i] - display_row_h[i] / 2 for i in range(n_rows)],
            ticktext=row_labels,
            range=[-cum_h, m],
            showgrid=False,
            zeroline=False,
            tickfont=tickfont,
        )

        fig.update_layout(
            width=cum_w + 100,
            height=cum_h + 100,
            title=title,
            title_x=0.5,
            margin=fig_margin,
        )

        # === EXPORT IF save_path IS GIVEN ===
        if save_path:
            import kaleido  # noqa: F401  lazy import - only needed when saving

            ext = os.path.splitext(save_path)[1].lower()
            if ext in (".png", ".jpg", ".jpeg"):
                fig.write_image(save_path, scale=dpi / 96)  # scale = dpi / base 96
            elif ext in (".pdf", ".svg"):
                fig.write_image(save_path)  # PDF/SVG are vector -> dpi ignored
            else:
                raise ValueError("save_path must end with .png, .jpg, .pdf, or .svg")
        if show:
            fig.show()
        return fig
714
+
715
+
716
@click.command()
@click.option(
    "--csvfiles",
    "-f",
    multiple=True,
    required=True,
    type=str,
    help="csv files to plot (can repeat)",
)
@click.option("--outdir", "-o", type=str, default=desktop_path, help="output directory")
@click.option(
    "--tags", "-t", multiple=True, type=str, default=(), help="tags for the csv files"
)
@click.option("--log", "-l", is_flag=True, help="log the csv files")
@click.option("--save_fig", "-s", is_flag=True, help="save the plot as file")
@click.option(
    "--update_in_min",
    "-u",
    type=float,
    default=0.0,
    help="update the plot every x minutes",
)
@click.option(
    "--x_col", "-x", type=str, default="epoch", help="column to use as x-axis"
)
@click.option(
    "--y_cols",
    "-y",
    multiple=True,
    type=str,
    required=True,
    help="columns to plot as y (can repeat)",
)
@click.option(
    "--out_fmt",
    type=click.Choice(["svg", "pdf", "png"]),
    default="svg",
    show_default=True,
    help="file format used when saving figures",
)
@click.option(
    "--font_size", type=int, default=16, show_default=True, help="plot font size"
)
def main(
    csvfiles,
    outdir,
    tags,
    log,
    save_fig,
    update_in_min,
    x_col,
    y_cols,
    out_fmt="svg",
    font_size=16,
):
    """Plot metric columns from one or more CSV log files."""
    # click delivers repeatable options as tuples; the plotting API expects lists.
    PlotHelper.plot_csv_timeseries(
        list(csvfiles),
        outdir=outdir,
        tags=list(tags),
        log=log,
        save_fig=save_fig,
        update_in_min=update_in_min,
        out_fmt=out_fmt,
        font_size=font_size,
        x_col=x_col,
        y_cols=list(y_cols),
    )


if __name__ == "__main__":
    main()