halib 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110):
  1. halib/__init__.py +94 -0
  2. halib/common/__init__.py +0 -0
  3. halib/common/common.py +326 -0
  4. halib/common/rich_color.py +285 -0
  5. halib/common.py +151 -0
  6. halib/csvfile.py +48 -0
  7. halib/cuda.py +39 -0
  8. halib/dataset.py +209 -0
  9. halib/exp/__init__.py +0 -0
  10. halib/exp/core/__init__.py +0 -0
  11. halib/exp/core/base_config.py +167 -0
  12. halib/exp/core/base_exp.py +147 -0
  13. halib/exp/core/param_gen.py +170 -0
  14. halib/exp/core/wandb_op.py +117 -0
  15. halib/exp/data/__init__.py +0 -0
  16. halib/exp/data/dataclass_util.py +41 -0
  17. halib/exp/data/dataset.py +208 -0
  18. halib/exp/data/torchloader.py +165 -0
  19. halib/exp/perf/__init__.py +0 -0
  20. halib/exp/perf/flop_calc.py +190 -0
  21. halib/exp/perf/gpu_mon.py +58 -0
  22. halib/exp/perf/perfcalc.py +470 -0
  23. halib/exp/perf/perfmetrics.py +137 -0
  24. halib/exp/perf/perftb.py +778 -0
  25. halib/exp/perf/profiler.py +507 -0
  26. halib/exp/viz/__init__.py +0 -0
  27. halib/exp/viz/plot.py +754 -0
  28. halib/filesys.py +117 -0
  29. halib/filetype/__init__.py +0 -0
  30. halib/filetype/csvfile.py +192 -0
  31. halib/filetype/ipynb.py +61 -0
  32. halib/filetype/jsonfile.py +19 -0
  33. halib/filetype/textfile.py +12 -0
  34. halib/filetype/videofile.py +266 -0
  35. halib/filetype/yamlfile.py +87 -0
  36. halib/gdrive.py +179 -0
  37. halib/gdrive_mkdir.py +41 -0
  38. halib/gdrive_test.py +37 -0
  39. halib/jsonfile.py +22 -0
  40. halib/listop.py +13 -0
  41. halib/online/__init__.py +0 -0
  42. halib/online/gdrive.py +229 -0
  43. halib/online/gdrive_mkdir.py +53 -0
  44. halib/online/gdrive_test.py +50 -0
  45. halib/online/projectmake.py +131 -0
  46. halib/online/tele_noti.py +165 -0
  47. halib/plot.py +301 -0
  48. halib/projectmake.py +115 -0
  49. halib/research/__init__.py +0 -0
  50. halib/research/base_config.py +100 -0
  51. halib/research/base_exp.py +157 -0
  52. halib/research/benchquery.py +131 -0
  53. halib/research/core/__init__.py +0 -0
  54. halib/research/core/base_config.py +144 -0
  55. halib/research/core/base_exp.py +157 -0
  56. halib/research/core/param_gen.py +108 -0
  57. halib/research/core/wandb_op.py +117 -0
  58. halib/research/data/__init__.py +0 -0
  59. halib/research/data/dataclass_util.py +41 -0
  60. halib/research/data/dataset.py +208 -0
  61. halib/research/data/torchloader.py +165 -0
  62. halib/research/dataset.py +208 -0
  63. halib/research/flop_csv.py +34 -0
  64. halib/research/flops.py +156 -0
  65. halib/research/metrics.py +137 -0
  66. halib/research/mics.py +74 -0
  67. halib/research/params_gen.py +108 -0
  68. halib/research/perf/__init__.py +0 -0
  69. halib/research/perf/flop_calc.py +190 -0
  70. halib/research/perf/gpu_mon.py +58 -0
  71. halib/research/perf/perfcalc.py +363 -0
  72. halib/research/perf/perfmetrics.py +137 -0
  73. halib/research/perf/perftb.py +778 -0
  74. halib/research/perf/profiler.py +301 -0
  75. halib/research/perfcalc.py +361 -0
  76. halib/research/perftb.py +780 -0
  77. halib/research/plot.py +758 -0
  78. halib/research/profiler.py +300 -0
  79. halib/research/torchloader.py +162 -0
  80. halib/research/viz/__init__.py +0 -0
  81. halib/research/viz/plot.py +754 -0
  82. halib/research/wandb_op.py +116 -0
  83. halib/rich_color.py +285 -0
  84. halib/sys/__init__.py +0 -0
  85. halib/sys/cmd.py +8 -0
  86. halib/sys/filesys.py +124 -0
  87. halib/system/__init__.py +0 -0
  88. halib/system/_list_pc.csv +6 -0
  89. halib/system/cmd.py +8 -0
  90. halib/system/filesys.py +164 -0
  91. halib/system/path.py +106 -0
  92. halib/tele_noti.py +166 -0
  93. halib/textfile.py +13 -0
  94. halib/torchloader.py +162 -0
  95. halib/utils/__init__.py +0 -0
  96. halib/utils/dataclass_util.py +40 -0
  97. halib/utils/dict.py +317 -0
  98. halib/utils/dict_op.py +9 -0
  99. halib/utils/gpu_mon.py +58 -0
  100. halib/utils/list.py +17 -0
  101. halib/utils/listop.py +13 -0
  102. halib/utils/slack.py +86 -0
  103. halib/utils/tele_noti.py +166 -0
  104. halib/utils/video.py +82 -0
  105. halib/videofile.py +139 -0
  106. halib-0.2.30.dist-info/METADATA +237 -0
  107. halib-0.2.30.dist-info/RECORD +110 -0
  108. halib-0.2.30.dist-info/WHEEL +5 -0
  109. halib-0.2.30.dist-info/licenses/LICENSE.txt +17 -0
  110. halib-0.2.30.dist-info/top_level.txt +1 -0
@@ -0,0 +1,754 @@
1
+ import ast
2
+ import os
3
+ import json
4
+ import time
5
+ import click
6
+ import base64
7
+ import pandas as pd
8
+ from PIL import Image
9
+ from io import BytesIO
10
+ import plotly.express as px
11
+ import plotly.graph_objects as go
12
+ from rich.console import Console
13
+ from typing import Callable, Optional, Tuple, List, Union
14
+
15
+ from ...common.common import now_str
16
+ from ...filetype import csvfile
17
+ from ...system import filesys as fs
18
+
19
# Module-wide rich console shared by every log/print call in this module.
console: Console = Console()
# Default output directory for the CLI entry point: the user's Desktop folder.
desktop_path: str = os.path.expanduser("~/Desktop")
21
+
22
class PlotHelper:
    """Plotly helpers for training-curve CSV files and image-grid figures."""

    def _verify_csv(self, csv_file):
        """Read ``csv_file`` (separator auto-detected) and lowercase its column names.

        Raises:
            FileNotFoundError: if the file does not exist.
            ValueError: if the file cannot be parsed.
        """
        try:
            df = csvfile.read_auto_sep(csv_file)
        except FileNotFoundError as e:
            raise FileNotFoundError(f"CSV file '{csv_file}' not found") from e
        except Exception as e:
            raise ValueError(f"Error reading CSV file '{csv_file}': {str(e)}") from e
        # str() guards against non-string headers (e.g. integer column labels).
        df.columns = [str(col).lower() for col in df.columns]
        return df

    @staticmethod
    def _norm_str(s):
        """Lowercase ``s`` and replace spaces and hyphens with underscores."""
        return s.lower().replace(" ", "_").replace("-", "_")

    @staticmethod
    def _get_file_name(file_path):
        """Return the base name of ``file_path`` without its extension."""
        return os.path.splitext(os.path.basename(file_path))[0]

    def _get_valid_tags(self, csv_files, tags):
        """Return ``tags`` as a list, or derive one tag per CSV from its file name."""
        if tags:
            return list(tags)
        return [self._norm_str(self._get_file_name(f)) for f in csv_files]

    def _prepare_long_df(self, csv_files, tags, x_col, y_cols, log=False):
        """Stack several CSVs into one long-form DataFrame for Plotly.

        The result has the columns ``x_col``, ``metric_type`` (name of the
        source y column), ``value`` and ``tag`` (tag of the originating CSV).

        Raises:
            ValueError: if a CSV lacks ``x_col`` or any of ``y_cols``, or if no
                CSV file is given.
        """
        y_cols = list(y_cols)
        dfs = []
        for csv_file, tag in zip(csv_files, tags):
            df = self._verify_csv(csv_file)
            if x_col not in df.columns:
                raise ValueError(f"{csv_file} is missing x_col '{x_col}'")
            missing = [c for c in y_cols if c not in df.columns]
            if missing:
                raise ValueError(f"{csv_file} is missing y_cols {missing}")

            if log:
                console.log(f"Plotting {csv_file}")
                console.print(df)

            # Wide -> long: one row per (x, metric) pair.
            df_long = df.melt(
                id_vars=x_col,
                value_vars=y_cols,
                var_name="metric_type",
                value_name="value",
            )
            df_long["tag"] = tag
            dfs.append(df_long)

        if not dfs:
            # pd.concat([]) would fail with an unhelpful message.
            raise ValueError("No CSV files to plot")
        return pd.concat(dfs, ignore_index=True)

    @staticmethod
    def _metric_suffix(col):
        """Return the metric group of a column name: its last '_' token ("train_loss" -> "loss")."""
        return str(col).split("_")[-1]

    @staticmethod
    def _select_metric_rows(df_long, metric):
        """Return the rows of ``df_long`` whose ``metric_type`` belongs to group ``metric``.

        Matches on the exact last '_' token, the same rule used to build the
        groups, so "loss_scale" is not pulled into the "loss" plot.
        """
        suffixes = df_long["metric_type"].astype(str).str.split("_").str[-1]
        return df_long[suffixes == metric]

    @staticmethod
    def _build_titles(tags, metric, x_col, out_fmt) -> Tuple[str, str, str]:
        """Return ``(title, xaxis_title, yaxis_title)`` for one metric plot.

        For ``out_fmt == "svg"`` the titles are LaTeX (MathJax) strings. In that
        case underscores inside tags are escaped so they are not rendered as
        subscripts.
        """
        if out_fmt == "svg":  # LaTeX-style
            tag_str = "+".join(str(t).replace("_", "\\_") for t in tags)
            title = f"${tag_str}\\_{metric}\\text{{-by-{x_col}}}$"
            xaxis_title = f"$\\text{{{x_col.capitalize()}}}$"
            yaxis_title = f"${metric.capitalize()}$"
        else:
            tag_str = "+".join(str(t) for t in tags)
            title = f"{tag_str}_{metric}-by-{x_col}"
            xaxis_title = x_col.capitalize()
            yaxis_title = metric.capitalize()
        return title, xaxis_title, yaxis_title

    def _plot_with_plotly(
        self,
        df_long,
        tags,
        outdir,
        save_fig,
        out_fmt="svg",
        font_size=16,
        x_col="epoch",
        y_cols=None,
    ):
        """Show (and optionally save) one Plotly line chart per metric group.

        y columns are grouped by their last '_' token, e.g. train_loss/val_loss
        share the "loss" figure. Saved files are named
        ``<timestamp>_<tags>_<metric>.<out_fmt>`` inside ``outdir``.

        Raises:
            ValueError: if ``out_fmt`` is unsupported or ``y_cols`` is None.
        """
        if out_fmt not in ("svg", "pdf", "png"):
            raise ValueError("Unsupported format")
        if y_cols is None:
            raise ValueError("y_cols must be provided")

        tag_str = "+".join(str(t) for t in tags)
        metric_groups = sorted({self._metric_suffix(col) for col in y_cols})

        for metric in metric_groups:
            subset = self._select_metric_rows(df_long, metric)
            title, xaxis_title, yaxis_title = self._build_titles(
                tags, metric, x_col, out_fmt
            )

            fig = px.line(
                subset,
                x=x_col,
                y="value",
                color="tag",
                line_dash="metric_type",
                title=title,
            )
            fig.update_layout(
                font=dict(family="Computer Modern", size=font_size),
                xaxis_title=xaxis_title,
                yaxis_title=yaxis_title,
            )
            fig.show()

            if save_fig:
                os.makedirs(outdir, exist_ok=True)
                filename = f"{now_str()}_{tag_str}_{metric}"
                try:
                    fig.write_image(os.path.join(outdir, f"{filename}.{out_fmt}"))
                except Exception as e:
                    # Best effort: export needs an image engine (kaleido); keep
                    # going so the remaining metrics are still plotted.
                    console.log(f"Error saving figure '{filename}.{out_fmt}': {str(e)}")

    @classmethod
    def plot_csv_timeseries(
        cls,
        csv_files,
        outdir="./out/plot",
        tags=None,
        log=False,
        save_fig=False,
        update_in_min=0,
        out_fmt="svg",
        font_size=16,
        x_col="epoch",
        y_cols=("train_loss", "train_acc"),
    ):
        """Plot one or more CSV time series with Plotly, optionally refreshing live.

        Args:
            csv_files: a CSV path or a list of paths.
            outdir: directory for saved figures (used when ``save_fig``).
            tags: legend tag per CSV (str or list); derived from file names if empty.
            log: print every CSV before plotting.
            save_fig: also write each figure to ``outdir``.
            update_in_min: if > 0, re-plot every this many minutes until Ctrl+C.
            out_fmt: "svg" (LaTeX titles), "pdf" or "png".
            font_size: layout font size.
            x_col: column used as the x axis.
            y_cols: column name(s) to plot; grouped per figure by their last '_' token.

        Raises:
            ValueError: if ``y_cols`` is empty or the tag count does not match
                the CSV count.
        """
        if isinstance(csv_files, str):
            csv_files = [csv_files]
        if isinstance(tags, str):
            tags = [tags]
        if not y_cols:
            raise ValueError("You must specify y_cols explicitly")
        if isinstance(y_cols, str):
            y_cols = [y_cols]
        csv_files = list(csv_files)
        y_cols = list(y_cols)

        plot_helper = cls()
        valid_tags = plot_helper._get_valid_tags(csv_files, tags)
        if len(valid_tags) != len(csv_files):
            raise ValueError("Number of tags must match number of CSV files")

        def run_once():
            df_long = plot_helper._prepare_long_df(
                csv_files, valid_tags, x_col, y_cols, log
            )
            plot_helper._plot_with_plotly(
                df_long, valid_tags, outdir, save_fig, out_fmt, font_size, x_col, y_cols
            )

        if update_in_min > 0:
            # At least 1 s, so tiny intervals do not turn into a busy re-plot loop.
            interval = max(1, int(update_in_min * 60))
            console.log(f"Live update every {interval}s. Press Ctrl+C to stop.")
            try:
                while True:
                    run_once()
                    time.sleep(interval)
            except KeyboardInterrupt:
                console.log("Stopped live updates.")
        else:
            run_once()

    @staticmethod
    def get_img_grid_df(input_dir, log=False):
        """
        Use images in input_dir to create a dataframe for plot_image_grid.

        Directory structures supported:

        A. Row/Col structure:
            input_dir/
            ├── row0/
            │   ├── col0/
            │   │   ├── 0.png
            │   │   ├── 1.png
            │   └── col1/
            │       ├── 0.png
            │       ├── 1.png
            ├── row1/
            │   ├── col0/
            │   │   ├── 0.png
            │   │   ├── 1.png
            │   └── col1/
            │       ├── 0.png
            │       ├── 1.png

        B. Row-only structure (no cols):
            input_dir/
            ├── row0/
            │   ├── 0.png
            │   ├── 1.png
            ├── row1/
            │   ├── 0.png
            │   ├── 1.png

        Only directories whose names start with "row" are used; the first row
        decides between layout A and B. Image files are png/jpg/jpeg.

        Returns:
            pd.DataFrame: first column "row" (row labels), then one column per
            grid column ("colX" dirs in A, "img0", "img1", ... in B). Each cell
            holds a list of image paths.

        Raises:
            ValueError: if no row dirs exist, rows have different column dirs,
                or image counts differ between cells.
        """
        img_exts = ["png", "jpg", "jpeg"]

        # --- Collect row dirs ---
        rows = sorted(r for r in fs.list_dirs(input_dir) if r.startswith("row"))
        if not rows:
            raise ValueError(f"No 'row*' directories found in {input_dir}")

        subdirs = fs.list_dirs(os.path.join(input_dir, rows[0]))

        if subdirs:  # --- Case A: row/col structure ---
            cols_ref = sorted(subdirs)

            # Ensure column consistency
            for row in rows:
                cols = sorted(fs.list_dirs(os.path.join(input_dir, row)))
                if cols != cols_ref:
                    raise ValueError(f"Row {row} has mismatched columns: {cols} vs {cols_ref}")

            meta_with_paths = {
                row: {
                    col: fs.filter_files_by_extension(os.path.join(input_dir, row, col), img_exts)
                    for col in cols_ref
                }
                for row in rows
            }

            # Validate equal number of images per (row, col)
            n_imgs = len(meta_with_paths[rows[0]][cols_ref[0]])
            for row, cols in meta_with_paths.items():
                for col, paths in cols.items():
                    if len(paths) != n_imgs:
                        raise ValueError(
                            f"Inconsistent file counts in {row}/{col}: {len(paths)} vs expected {n_imgs}"
                        )

            # Long format: one record per (row, image index)
            data = {"row": [row for row in rows for _ in range(n_imgs)]}
            for col in cols_ref:
                data[col] = [meta_with_paths[row][col][i] for row in rows for i in range(n_imgs)]

        else:  # --- Case B: row-only structure ---
            meta_with_paths = {
                row: fs.filter_files_by_extension(os.path.join(input_dir, row), img_exts)
                for row in rows
            }

            # Validate equal number of images per row
            n_imgs = len(meta_with_paths[rows[0]])
            for row, paths in meta_with_paths.items():
                if len(paths) != n_imgs:
                    raise ValueError(f"Inconsistent file counts in {row}: {len(paths)} vs expected {n_imgs}")

            # One column per image index (img0, img1, ...)
            data = {"row": rows}
            for i in range(n_imgs):
                data[f"img{i}"] = [meta_with_paths[row][i] for row in rows]

        value_cols = [c for c in data if c != "row"]

        # --- Convert to wide "multi-list" format: each cell = list of paths ---
        df = (
            pd.DataFrame(data)
            .melt(id_vars=["row"], var_name="col", value_name="path")
            .groupby(["row", "col"])["path"]
            .apply(list)
            .unstack("col")
            .reset_index()
        )
        # unstack() sorts column labels lexicographically ("img10" before
        # "img2"); restore the order in which the columns were built.
        df = df[["row"] + [c for c in value_cols if c in df.columns]]

        if log:
            csvfile.fn_display_df(df)

        return df

    @staticmethod
    def _parse_cell_to_list(cell) -> List[str]:
        """Parse a grid cell into a list of path strings.

        Accepts a list/tuple/set, a Python-literal list string, a JSON list
        string, a string joined by ';;', ';', '|' or ', ', or a single path.
        None, NaN and blank strings yield [].
        """
        if cell is None:
            return []

        # Containers first: pd.isna() on a list returns an array whose truth
        # value is ambiguous.
        if isinstance(cell, (list, tuple, set)):
            return [str(x) for x in cell]

        try:
            if pd.isna(cell):
                return []
        except (TypeError, ValueError):
            # Array-like cell (e.g. ndarray): not a scalar NA, keep going.
            pass

        if isinstance(cell, str):
            s = cell.strip()
            if not s:
                return []

            # Python literal, e.g. "['a', 'b']" (literal_eval is safe: no code execution)
            try:
                val = ast.literal_eval(s)
                if isinstance(val, (list, tuple)):
                    return [str(x) for x in val]
                if isinstance(val, str):
                    return [val]
            except Exception:
                pass

            # JSON, e.g. '["a", null]'
            try:
                val = json.loads(s)
                if isinstance(val, list):
                    return [str(x) for x in val]
                if isinstance(val, str):
                    return [val]
            except Exception:
                pass

            # Fallback: split on common separators
            for sep in [";;", ";", "|", ", "]:
                if sep in s:
                    parts = [p.strip() for p in s.split(sep) if p.strip()]
                    if parts:
                        return parts

            # Single path string
            return [s]

        # Anything else -> coerce to string
        return [str(cell)]

    @staticmethod
    def plot_image_grid(
        indir_or_csvf_or_df: Union[str, pd.DataFrame],
        save_path: str = None,
        dpi: int = 300,  # DPI for saving raster images
        show: bool = True,  # whether to show the plot in an interactive window
        img_width: int = 300,
        img_height: int = 300,
        img_stack_direction: str = "horizontal",  # "horizontal" or "vertical"
        img_stack_padding_px: int = 5,
        img_scale_mode: str = "fit",  # "fit" (scale + center-crop) or "fill" (stretch)
        format_row_label_func: Optional[Callable[[str], str]] = None,
        format_col_label_func: Optional[Callable[[str], str]] = None,
        title: str = "",
        tickfont: Optional[dict] = None,  # default: Arial, size 16, black
        fig_margin: Optional[dict] = None,  # default: l=r=t=b=50
        outline_color: str = "",  # "#RRGGBB" draws a border around every image
        outline_size: int = 1,
        cell_margin_px: int = 10,  # padding (top, left, right, bottom) inside each cell
        row_line_size: int = 0,  # if >0, draw horizontal dotted lines
        col_line_size: int = 0,  # if >0, draw vertical dotted lines
    ) -> "go.Figure":
        """
        Plot a grid of images using Plotly.

        Input is a DataFrame, a CSV path, or a directory in a layout accepted by
        ``get_img_grid_df``. The first column holds the row labels; every other
        column is one grid column, labelled by its header. Each cell may be:
          * a Python list/tuple of paths,
          * a string representation of a Python list (e.g. "['a','b']"),
          * a JSON list string, or
          * a single path string.

        Every image of a cell is scaled to (img_width, img_height), plus an
        optional outline border. The images are stacked along
        ``img_stack_direction``, so a cell's size depends on its image count.
        Each column is as wide as its widest content, and each row as tall as its
        tallest, plus ``cell_margin_px`` on every side. Content is centered.
        Unreadable images are drawn as white placeholders.

        Returns:
            The Plotly figure. It is also shown when ``show`` is True and written
            to ``save_path`` when given.

        Raises:
            ValueError: on an unsupported input, ``img_scale_mode``,
                ``img_stack_direction`` or ``save_path`` extension.
        """
        if img_scale_mode not in ("fit", "fill"):
            raise ValueError("img_scale_mode must be 'fit' or 'fill'.")
        if img_stack_direction not in ("horizontal", "vertical"):
            raise ValueError("img_stack_direction must be 'horizontal' or 'vertical'.")
        if tickfont is None:
            tickfont = dict(size=16, family="Arial", color="black")
        if fig_margin is None:
            fig_margin = dict(l=50, r=50, t=50, b=50)

        outline_color = outline_color or ""
        has_outline = len(outline_color) == 7 and outline_color.startswith("#")
        border_px = outline_size if has_outline else 0
        slot_w, slot_h = img_width, img_height
        unit_w, unit_h = slot_w + 2 * border_px, slot_h + 2 * border_px
        horizontal = img_stack_direction == "horizontal"
        white = (255, 255, 255)

        def load_slot_image(path: str) -> "Image.Image":
            """Load one image, scale it to the slot and add the outline (if any)."""
            try:
                with Image.open(path) as src:  # closes the file handle
                    img = src.convert("RGB")
            except Exception as e:
                # Best effort: a bad path must not abort the whole grid.
                console.log(f"Could not open image '{path}': {e}")
                img = Image.new("RGB", (slot_w, slot_h), white)
            else:
                if img_scale_mode == "fit":
                    # Scale so the image covers the slot, then center-crop it.
                    img_ratio = img.width / img.height
                    if img_ratio > slot_w / slot_h:
                        new_h = slot_h
                        new_w = max(1, slot_w, int(new_h * img_ratio))
                    else:
                        new_w = slot_w
                        # max() guards against float truncation undershooting the slot.
                        new_h = max(1, slot_h, int(new_w / img_ratio))
                    img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
                    left = (new_w - slot_w) // 2
                    top = (new_h - slot_h) // 2
                    img = img.crop((left, top, left + slot_w, top + slot_h))
                else:  # "fill": stretch to the slot
                    img = img.resize((slot_w, slot_h), Image.Resampling.LANCZOS)

            if has_outline:
                bordered = Image.new("RGB", (unit_w, unit_h), outline_color)
                bordered.paste(img, (border_px, border_px))
                return bordered
            return img

        def stacked_size(n_imgs: int) -> Tuple[int, int]:
            """Pixel size of a cell holding ``n_imgs`` stacked images (0 -> blank slot)."""
            if n_imgs == 0:
                return slot_w, slot_h
            gap = (n_imgs - 1) * img_stack_padding_px
            if horizontal:
                return n_imgs * unit_w + gap, unit_h
            return unit_w, n_imgs * unit_h + gap

        def stack_images(image_paths: List[str]) -> "Image.Image":
            """Compose the images of one cell into a single image of ``stacked_size``."""
            if not image_paths:
                return Image.new("RGB", (slot_w, slot_h), white)
            canvas = Image.new("RGB", stacked_size(len(image_paths)), white)
            offset = 0
            for path in image_paths:
                tile = load_slot_image(path)
                if horizontal:
                    canvas.paste(tile, (offset, 0))
                    offset += tile.width + img_stack_padding_px
                else:
                    canvas.paste(tile, (0, offset))
                    offset += tile.height + img_stack_padding_px
            return canvas

        def to_data_uri(img) -> str:
            """Encode a PIL image as a PNG data URI for Plotly layout images."""
            buf = BytesIO()
            img.save(buf, format="PNG")
            return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

        # --- Load DataFrame ---
        src = indir_or_csvf_or_df
        if isinstance(src, pd.DataFrame):
            df = src.copy()
        elif isinstance(src, (str, os.PathLike)):
            src = os.fspath(src)
            if os.path.splitext(src)[1].lower() == ".csv":
                df = pd.read_csv(src)
            elif os.path.isdir(src):
                df = PlotHelper.get_img_grid_df(src, log=False)
            else:
                raise ValueError("Input string must be a valid CSV file or directory path")
        else:
            raise ValueError("Input must be CSV file path, DataFrame, or directory path")

        rows = df.iloc[:, 0].astype(str).tolist()
        columns = list(df.columns[1:])
        n_rows, n_cols = len(rows), len(columns)

        # First pass: parse every cell and measure its stacked content.
        cell_paths = [[[] for _ in range(n_cols)] for _ in range(n_rows)]
        content_col_max = [0] * n_cols
        content_row_max = [0] * n_rows
        for i in range(n_rows):
            for j in range(n_cols):
                parsed = PlotHelper._parse_cell_to_list(df.iloc[i, j + 1])
                image_paths = [s for s in (str(p).strip() for p in parsed) if s]
                cell_paths[i][j] = image_paths
                cw, ch = stacked_size(len(image_paths))
                content_col_max[j] = max(content_col_max[j], cw)
                content_row_max[i] = max(content_row_max[i], ch)

        # Display sizes = largest content + margin on both sides.
        display_col_w = [w + 2 * cell_margin_px for w in content_col_max]
        display_row_h = [h + 2 * cell_margin_px for h in content_row_max]

        # Cells are adjacent: x grows rightwards from 0, y grows downwards (negative).
        x_positions, cum_w = [], 0
        for w in display_col_w:
            x_positions.append(cum_w)
            cum_w += w
        y_positions, cum_h = [], 0
        for h in display_row_h:
            y_positions.append(-cum_h)
            cum_h += h

        fig = go.Figure()
        m = cell_margin_px

        # Second pass: center each cell's content on a white canvas and add it.
        for i in range(n_rows):
            for j in range(n_cols):
                content = stack_images(cell_paths[i][j])
                cell = Image.new("RGB", (display_col_w[j], display_row_h[i]), white)
                left = m + (content_col_max[j] - content.width) // 2
                top = m + (content_row_max[i] - content.height) // 2
                cell.paste(content, (left, top))
                fig.add_layout_image(
                    dict(
                        source=to_data_uri(cell),
                        x=x_positions[j],
                        y=y_positions[i],
                        xref="x",
                        yref="y",
                        sizex=display_col_w[j],
                        sizey=display_row_h[i],
                        xanchor="left",
                        yanchor="top",
                        layer="above",
                    )
                )

        # Optional dotted grid lines on the boundaries between cells.
        if row_line_size > 0:
            for i in range(1, n_rows):
                y = y_positions[i]
                fig.add_shape(
                    type="line",
                    x0=-m,
                    x1=cum_w,
                    y0=y,
                    y1=y,
                    line=dict(width=row_line_size, color="black", dash="dot"),
                )
        if col_line_size > 0:
            for j in range(1, n_cols):
                x = x_positions[j]
                fig.add_shape(
                    type="line",
                    x0=x,
                    x1=x,
                    y0=m,
                    y1=-cum_h,
                    line=dict(width=col_line_size, color="black", dash="dot"),
                )

        # Axis labels
        col_labels = [format_col_label_func(c) if format_col_label_func else c for c in columns]
        row_labels = [format_row_label_func(r) if format_row_label_func else r for r in rows]

        fig.update_xaxes(
            tickvals=[x_positions[j] + display_col_w[j] / 2 for j in range(n_cols)],
            ticktext=col_labels,
            range=[-m, cum_w],
            showgrid=False,
            zeroline=False,
            tickfont=tickfont,
        )
        fig.update_yaxes(
            tickvals=[y_positions[i] - display_row_h[i] / 2 for i in range(n_rows)],
            ticktext=row_labels,
            range=[-cum_h, m],
            showgrid=False,
            zeroline=False,
            tickfont=tickfont,
        )
        fig.update_layout(
            width=cum_w + 100,
            height=cum_h + 100,
            title=title,
            title_x=0.5,
            margin=fig_margin,
        )

        # === EXPORT IF save_path IS GIVEN ===
        if save_path:
            import kaleido  # noqa: F401  # fail fast with ImportError if the export engine is missing

            ext = os.path.splitext(save_path)[1].lower()
            if ext not in (".png", ".jpg", ".jpeg", ".pdf", ".svg"):
                raise ValueError("save_path must end with .png, .jpg, .jpeg, .pdf, or .svg")
            out_dir = os.path.dirname(save_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            if ext in (".png", ".jpg", ".jpeg"):
                fig.write_image(save_path, scale=dpi / 96)  # scale = dpi / base 96
            else:
                fig.write_image(save_path)  # PDF/SVG are vector -> dpi ignored
        if show:
            fig.show()
        return fig
710
+
711
+
712
@click.command()
@click.option("--csvfiles", "-f", multiple=True, type=str, help="csv files to plot")
@click.option("--outdir", "-o", type=str, default=str(desktop_path), help="output directory")
@click.option("--tags", "-t", multiple=True, type=str, default=[], help="tags for the csv files")
@click.option("--log", "-l", is_flag=True, help="log the csv files")
@click.option("--save_fig", "-s", is_flag=True, help="save the plot as file")
@click.option("--update_in_min", "-u", type=float, default=0.0, help="update the plot every x minutes")
@click.option("--x_col", "-x", type=str, default="epoch", help="column to use as x-axis")
@click.option("--y_cols", "-y", multiple=True, type=str, required=True, help="columns to plot as y (can repeat)")
def main(csvfiles, outdir, tags, log, save_fig, update_in_min, x_col, y_cols):
    # CLI entry point. Comments only (no docstring): click would turn a
    # docstring into --help text.
    # click delivers repeatable options as tuples; the plotting API takes lists.
    PlotHelper.plot_csv_timeseries(
        csv_files=list(csvfiles),
        outdir=outdir,
        tags=list(tags),
        log=log,
        save_fig=save_fig,
        update_in_min=update_in_min,
        x_col=x_col,
        y_cols=list(y_cols),
    )


# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()