halib 0.2.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- halib/__init__.py +94 -0
- halib/common/__init__.py +0 -0
- halib/common/common.py +326 -0
- halib/common/rich_color.py +285 -0
- halib/common.py +151 -0
- halib/csvfile.py +48 -0
- halib/cuda.py +39 -0
- halib/dataset.py +209 -0
- halib/exp/__init__.py +0 -0
- halib/exp/core/__init__.py +0 -0
- halib/exp/core/base_config.py +167 -0
- halib/exp/core/base_exp.py +147 -0
- halib/exp/core/param_gen.py +170 -0
- halib/exp/core/wandb_op.py +117 -0
- halib/exp/data/__init__.py +0 -0
- halib/exp/data/dataclass_util.py +41 -0
- halib/exp/data/dataset.py +208 -0
- halib/exp/data/torchloader.py +165 -0
- halib/exp/perf/__init__.py +0 -0
- halib/exp/perf/flop_calc.py +190 -0
- halib/exp/perf/gpu_mon.py +58 -0
- halib/exp/perf/perfcalc.py +470 -0
- halib/exp/perf/perfmetrics.py +137 -0
- halib/exp/perf/perftb.py +778 -0
- halib/exp/perf/profiler.py +507 -0
- halib/exp/viz/__init__.py +0 -0
- halib/exp/viz/plot.py +754 -0
- halib/filesys.py +117 -0
- halib/filetype/__init__.py +0 -0
- halib/filetype/csvfile.py +192 -0
- halib/filetype/ipynb.py +61 -0
- halib/filetype/jsonfile.py +19 -0
- halib/filetype/textfile.py +12 -0
- halib/filetype/videofile.py +266 -0
- halib/filetype/yamlfile.py +87 -0
- halib/gdrive.py +179 -0
- halib/gdrive_mkdir.py +41 -0
- halib/gdrive_test.py +37 -0
- halib/jsonfile.py +22 -0
- halib/listop.py +13 -0
- halib/online/__init__.py +0 -0
- halib/online/gdrive.py +229 -0
- halib/online/gdrive_mkdir.py +53 -0
- halib/online/gdrive_test.py +50 -0
- halib/online/projectmake.py +131 -0
- halib/online/tele_noti.py +165 -0
- halib/plot.py +301 -0
- halib/projectmake.py +115 -0
- halib/research/__init__.py +0 -0
- halib/research/base_config.py +100 -0
- halib/research/base_exp.py +157 -0
- halib/research/benchquery.py +131 -0
- halib/research/core/__init__.py +0 -0
- halib/research/core/base_config.py +144 -0
- halib/research/core/base_exp.py +157 -0
- halib/research/core/param_gen.py +108 -0
- halib/research/core/wandb_op.py +117 -0
- halib/research/data/__init__.py +0 -0
- halib/research/data/dataclass_util.py +41 -0
- halib/research/data/dataset.py +208 -0
- halib/research/data/torchloader.py +165 -0
- halib/research/dataset.py +208 -0
- halib/research/flop_csv.py +34 -0
- halib/research/flops.py +156 -0
- halib/research/metrics.py +137 -0
- halib/research/mics.py +74 -0
- halib/research/params_gen.py +108 -0
- halib/research/perf/__init__.py +0 -0
- halib/research/perf/flop_calc.py +190 -0
- halib/research/perf/gpu_mon.py +58 -0
- halib/research/perf/perfcalc.py +363 -0
- halib/research/perf/perfmetrics.py +137 -0
- halib/research/perf/perftb.py +778 -0
- halib/research/perf/profiler.py +301 -0
- halib/research/perfcalc.py +361 -0
- halib/research/perftb.py +780 -0
- halib/research/plot.py +758 -0
- halib/research/profiler.py +300 -0
- halib/research/torchloader.py +162 -0
- halib/research/viz/__init__.py +0 -0
- halib/research/viz/plot.py +754 -0
- halib/research/wandb_op.py +116 -0
- halib/rich_color.py +285 -0
- halib/sys/__init__.py +0 -0
- halib/sys/cmd.py +8 -0
- halib/sys/filesys.py +124 -0
- halib/system/__init__.py +0 -0
- halib/system/_list_pc.csv +6 -0
- halib/system/cmd.py +8 -0
- halib/system/filesys.py +164 -0
- halib/system/path.py +106 -0
- halib/tele_noti.py +166 -0
- halib/textfile.py +13 -0
- halib/torchloader.py +162 -0
- halib/utils/__init__.py +0 -0
- halib/utils/dataclass_util.py +40 -0
- halib/utils/dict.py +317 -0
- halib/utils/dict_op.py +9 -0
- halib/utils/gpu_mon.py +58 -0
- halib/utils/list.py +17 -0
- halib/utils/listop.py +13 -0
- halib/utils/slack.py +86 -0
- halib/utils/tele_noti.py +166 -0
- halib/utils/video.py +82 -0
- halib/videofile.py +139 -0
- halib-0.2.30.dist-info/METADATA +237 -0
- halib-0.2.30.dist-info/RECORD +110 -0
- halib-0.2.30.dist-info/WHEEL +5 -0
- halib-0.2.30.dist-info/licenses/LICENSE.txt +17 -0
- halib-0.2.30.dist-info/top_level.txt +1 -0

@@ -0,0 +1,754 @@
import ast
import os
import json
import time
import click
import base64
import pandas as pd
from PIL import Image
from io import BytesIO
import plotly.express as px
import plotly.graph_objects as go
from rich.console import Console
from typing import Callable, Optional, Tuple, List, Union

from ...common.common import now_str
from ...filetype import csvfile
from ...system import filesys as fs

console = Console()
desktop_path = os.path.expanduser("~/Desktop")

class PlotHelper:
    def _verify_csv(self, csv_file):
        """Read a CSV and normalize its column names to lowercase."""
        try:
            df = csvfile.read_auto_sep(csv_file)
            df.columns = [col.lower() for col in df.columns]
            return df
        except FileNotFoundError:
            raise FileNotFoundError(f"CSV file '{csv_file}' not found")
        except Exception as e:
            raise ValueError(f"Error reading CSV file '{csv_file}': {str(e)}")

    @staticmethod
    def _norm_str(s):
        """Normalize a string: lowercase, with spaces and hyphens replaced by underscores."""
        return s.lower().replace(" ", "_").replace("-", "_")

    @staticmethod
    def _get_file_name(file_path):
        """Extract the file name without its extension."""
        return os.path.splitext(os.path.basename(file_path))[0]

    def _get_valid_tags(self, csv_files, tags):
        """Generate tags from file names if not provided."""
        if tags:
            return list(tags)
        return [self._norm_str(self._get_file_name(f)) for f in csv_files]

    def _prepare_long_df(self, csv_files, tags, x_col, y_cols, log=False):
        """Convert multiple CSVs into a single long-form dataframe for Plotly."""
        dfs = []
        for csv_file, tag in zip(csv_files, tags):
            df = self._verify_csv(csv_file)
            # Check that the required columns exist
            if x_col not in df.columns:
                raise ValueError(f"{csv_file} is missing x_col '{x_col}'")
            missing = [c for c in y_cols if c not in df.columns]
            if missing:
                raise ValueError(f"{csv_file} is missing y_cols {missing}")

            if log:
                console.log(f"Plotting {csv_file}")
                console.print(df)

            # Wide to long
            df_long = df.melt(
                id_vars=x_col,
                value_vars=y_cols,
                var_name="metric_type",
                value_name="value",
            )
            df_long["tag"] = tag
            dfs.append(df_long)

        return pd.concat(dfs, ignore_index=True)
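    # A minimal sketch (added for illustration; the column names are assumptions):
    # for a CSV with columns epoch, train_loss, train_acc, the melt above turns
    # each wide row
    #     epoch=1, train_loss=0.9, train_acc=0.55
    # into long-form rows
    #     epoch=1, metric_type="train_loss", value=0.9,  tag=<tag>
    #     epoch=1, metric_type="train_acc",  value=0.55, tag=<tag>
    # so Plotly can color by "tag" and vary the line dash by "metric_type".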
    def _plot_with_plotly(
        self,
        df_long,
        tags,
        outdir,
        save_fig,
        out_fmt="svg",
        font_size=16,
        x_col="epoch",
        y_cols=None,
    ):
        """Generate Plotly plots for the given metrics."""
        assert out_fmt in ["svg", "pdf", "png"], "Unsupported format"
        if y_cols is None:
            raise ValueError("y_cols must be provided")

        # Group by suffix (e.g., "loss", "acc") so that names like
        # train_loss and val_loss land on the same figure
        metric_groups = sorted(set(col.split("_")[-1] for col in y_cols))

        for metric in metric_groups:
            subset = df_long[df_long["metric_type"].str.contains(metric)]

            if out_fmt == "svg":  # LaTeX-style titles
                title = f"${'+'.join(tags)}\\_{metric}\\text{{-by-{x_col}}}$"
                xaxis_title = f"$\\text{{{x_col.capitalize()}}}$"
                yaxis_title = f"${metric.capitalize()}$"
            else:
                title = f"{'+'.join(tags)}_{metric}-by-{x_col}"
                xaxis_title = x_col.capitalize()
                yaxis_title = metric.capitalize()

            fig = px.line(
                subset,
                x=x_col,
                y="value",
                color="tag",
                line_dash="metric_type",
                title=title,
            )
            fig.update_layout(
                font=dict(family="Computer Modern", size=font_size),
                xaxis_title=xaxis_title,
                yaxis_title=yaxis_title,
            )
            fig.show()

            if save_fig:
                os.makedirs(outdir, exist_ok=True)
                timestamp = now_str()
                filename = f"{timestamp}_{'+'.join(tags)}_{metric}"
                try:
                    fig.write_image(os.path.join(outdir, f"{filename}.{out_fmt}"))
                except Exception as e:
                    console.log(f"Error saving figure '{filename}.{out_fmt}': {str(e)}")

    @classmethod
    def plot_csv_timeseries(
        cls,
        csv_files,
        outdir="./out/plot",
        tags=None,
        log=False,
        save_fig=False,
        update_in_min=0,
        out_fmt="svg",
        font_size=16,
        x_col="epoch",
        y_cols=["train_loss", "train_acc"],
    ):
        """Plot CSV files with Plotly as a class method, with optional live updates."""
        if isinstance(csv_files, str):
            csv_files = [csv_files]
        if isinstance(tags, str):
            tags = [tags]

        if not y_cols:
            raise ValueError("You must specify y_cols explicitly")

        # Instantiate PlotHelper to call instance methods
        plot_helper = cls()
        valid_tags = plot_helper._get_valid_tags(csv_files, tags)
        assert len(valid_tags) == len(
            csv_files
        ), "Number of tags must match number of CSV files"

        def run_once():
            df_long = plot_helper._prepare_long_df(
                csv_files, valid_tags, x_col, y_cols, log
            )
            plot_helper._plot_with_plotly(
                df_long, valid_tags, outdir, save_fig, out_fmt, font_size, x_col, y_cols
            )

        if update_in_min > 0:
            interval = int(update_in_min * 60)
            console.log(f"Live update every {interval}s. Press Ctrl+C to stop.")
            try:
                while True:
                    run_once()
                    time.sleep(interval)
            except KeyboardInterrupt:
                console.log("Stopped live updates.")
        else:
            run_once()
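    # Usage sketch (illustrative only; the file name and tag are assumptions):
    # plot train/val loss from a training log, save the figure, and refresh
    # every two minutes until Ctrl+C:
    #
    #     PlotHelper.plot_csv_timeseries(
    #         "out/train_log.csv",
    #         tags="baseline",
    #         save_fig=True,
    #         update_in_min=2,
    #         y_cols=["train_loss", "val_loss"],
    #     )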
    @staticmethod
    def get_img_grid_df(input_dir, log=False):
        """
        Build a dataframe for plot_image_grid from the images in input_dir.

        Directory structures supported:

        A. Row/col structure:

            input_dir/
            ├── row0/
            │   ├── col0/
            │   │   ├── 0.png
            │   │   └── 1.png
            │   └── col1/
            │       ├── 0.png
            │       └── 1.png
            └── row1/
                ├── col0/
                │   ├── 0.png
                │   └── 1.png
                └── col1/
                    ├── 0.png
                    └── 1.png

        B. Row-only structure (no cols):

            input_dir/
            ├── row0/
            │   ├── 0.png
            │   └── 1.png
            └── row1/
                ├── 0.png
                └── 1.png

        Returns:
            pd.DataFrame: DataFrame suitable for plot_image_grid.
            Each cell contains a list of image paths.
        """
        # --- Collect row dirs ---
        rows = sorted([r for r in fs.list_dirs(input_dir) if r.startswith("row")])
        if not rows:
            raise ValueError(f"No 'row*' directories found in {input_dir}")

        first_row_path = os.path.join(input_dir, rows[0])
        subdirs = fs.list_dirs(first_row_path)

        if subdirs:  # --- Case A: row/col structure ---
            cols_ref = sorted(subdirs)

            # Ensure column consistency across rows
            meta_dict = {
                row: sorted(fs.list_dirs(os.path.join(input_dir, row))) for row in rows
            }
            for row, cols in meta_dict.items():
                if cols != cols_ref:
                    raise ValueError(
                        f"Row {row} has mismatched columns: {cols} vs {cols_ref}"
                    )

            # Collect image paths
            meta_with_paths = {
                row: {
                    col: fs.filter_files_by_extension(
                        os.path.join(input_dir, row, col), ["png", "jpg", "jpeg"]
                    )
                    for col in cols_ref
                }
                for row in rows
            }

            # Validate equal number of images per (row, col)
            n_imgs = len(meta_with_paths[rows[0]][cols_ref[0]])
            for row, cols in meta_with_paths.items():
                for col, paths in cols.items():
                    if len(paths) != n_imgs:
                        raise ValueError(
                            f"Inconsistent file counts in {row}/{col}: {len(paths)} vs expected {n_imgs}"
                        )

            # Flatten to long format
            data = {"row": [row for row in rows for _ in range(n_imgs)]}
            for col in cols_ref:
                data[col] = [
                    meta_with_paths[row][col][i] for row in rows for i in range(n_imgs)
                ]

        else:  # --- Case B: row-only structure ---
            meta_with_paths = {
                row: fs.filter_files_by_extension(
                    os.path.join(input_dir, row), ["png", "jpg", "jpeg"]
                )
                for row in rows
            }

            # Validate equal number of images per row
            n_imgs = len(next(iter(meta_with_paths.values())))
            for row, paths in meta_with_paths.items():
                if len(paths) != n_imgs:
                    raise ValueError(
                        f"Inconsistent file counts in {row}: {len(paths)} vs expected {n_imgs}"
                    )

            # Flatten to long format (images indexed as img0, img1, ...)
            data = {"row": rows}
            for i in range(n_imgs):
                data[f"img{i}"] = [meta_with_paths[row][i] for row in rows]

        # --- Convert to wide "multi-list" format ---
        df = pd.DataFrame(data)
        row_col = df.columns[0]  # first col = row labels
        # col_cols = df.columns[1:]  # the rest = groupable cols

        df = (
            df.melt(id_vars=[row_col], var_name="col", value_name="path")
            .groupby([row_col, "col"])["path"]
            .apply(list)
            .unstack("col")
            .reset_index()
        )

        if log:
            csvfile.fn_display_df(df)

        return df
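    # A minimal sketch (added for illustration) of the frame returned for
    # case A with two row dirs, two col dirs, and two images per cell:
    #
    #     row     col0                                 col1
    #     row0    [row0/col0/0.png, row0/col0/1.png]   [row0/col1/0.png, ...]
    #     row1    [row1/col0/0.png, row1/col0/1.png]   [row1/col1/0.png, ...]
    #
    # i.e. one list of image paths per grid cell, ready for plot_image_grid.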
    @staticmethod
    def _parse_cell_to_list(cell) -> List[str]:
        """Parse a DataFrame cell that may already be a list, a Python-list string,
        a JSON list string, or a single path. Returns list[str]."""
        if cell is None:
            return []
        # pandas NA
        try:
            if pd.isna(cell):
                return []
        except Exception:
            pass

        if isinstance(cell, list):
            return [str(x) for x in cell]

        if isinstance(cell, (tuple, set)):
            return [str(x) for x in cell]

        if isinstance(cell, str):
            s = cell.strip()
            if not s:
                return []

            # Try a Python literal (e.g. "['a','b']")
            try:
                val = ast.literal_eval(s)
                if isinstance(val, (list, tuple)):
                    return [str(x) for x in val]
                if isinstance(val, str):
                    return [val]
            except Exception:
                pass

            # Try JSON
            try:
                val = json.loads(s)
                if isinstance(val, list):
                    return [str(x) for x in val]
                if isinstance(val, str):
                    return [val]
            except Exception:
                pass

            # Fallback: split on common separators
            for sep in [";;", ";", "|", ", "]:
                if sep in s:
                    parts = [p.strip() for p in s.split(sep) if p.strip()]
                    if parts:
                        return parts

            # Single path string
            return [s]

        # Anything else: coerce to string
        return [str(cell)]
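    # Behavior sketch (added for illustration):
    #
    #     _parse_cell_to_list(["a.png", "b.png"])    -> ["a.png", "b.png"]
    #     _parse_cell_to_list("['a.png', 'b.png']")  -> ["a.png", "b.png"]
    #     _parse_cell_to_list("a.png;b.png")         -> ["a.png", "b.png"]
    #     _parse_cell_to_list("a.png")               -> ["a.png"]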
    @staticmethod
    def plot_image_grid(
        indir_or_csvf_or_df: Union[str, pd.DataFrame],
        save_path: Optional[str] = None,
        dpi: int = 300,  # DPI for saving raster images (ignored for vector formats)
        show: bool = True,  # whether to show the plot in an interactive window
        img_width: int = 300,
        img_height: int = 300,
        img_stack_direction: str = "horizontal",  # "horizontal" or "vertical"
        img_stack_padding_px: int = 5,
        img_scale_mode: str = "fit",  # "fit" or "fill"
        format_row_label_func: Optional[Callable[[str], str]] = None,
        format_col_label_func: Optional[Callable[[str], str]] = None,
        title: str = "",
        tickfont=dict(size=16, family="Arial", color="black"),  # larger tick labels
        fig_margin: dict = dict(l=50, r=50, t=50, b=50),
        outline_color: str = "",
        outline_size: int = 1,
        cell_margin_px: int = 10,  # padding (top, left, right, bottom) inside each cell
        row_line_size: int = 0,  # if >0, draw horizontal dotted lines between rows
        col_line_size: int = 0,  # if >0, draw vertical dotted lines between columns
    ) -> go.Figure:
        """
        Plot a grid of images using Plotly.

        - Accepts a DataFrame where each cell is either:
            * a Python list object,
            * a string representation of a Python list (e.g. "['a','b']"),
            * a JSON list string, or
            * a single path string.
        - For each cell, the images are stacked into a single composite;
          (img_width, img_height) is the target size for each individual image
          in the stack, so the final cell size depends on the number of images
          and the stacking direction.
        """

        def process_image_for_slot(
            path: str,
            target_size: Tuple[int, int],
            scale_mode: str,
            outline: str,
            outline_size: int,
        ) -> Image.Image:
            try:
                img = Image.open(path).convert("RGB")
            except Exception:
                # Unreadable image: fall back to a white placeholder
                return Image.new("RGB", target_size, (255, 255, 255))

            if scale_mode == "fit":
                img_ratio = img.width / img.height
                target_ratio = target_size[0] / target_size[1]

                if img_ratio > target_ratio:
                    new_height = target_size[1]
                    new_width = max(1, int(new_height * img_ratio))
                else:
                    new_width = target_size[0]
                    new_height = max(1, int(new_width / img_ratio))

                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
                left = (new_width - target_size[0]) // 2
                top = (new_height - target_size[1]) // 2
                right = left + target_size[0]
                bottom = top + target_size[1]

                if len(outline) == 7 and outline.startswith("#"):
                    border_px = outline_size
                    bordered = Image.new(
                        "RGB",
                        (target_size[0] + 2 * border_px, target_size[1] + 2 * border_px),
                        outline,
                    )
                    bordered.paste(
                        img.crop((left, top, right, bottom)), (border_px, border_px)
                    )
                    return bordered
                return img.crop((left, top, right, bottom))

            elif scale_mode == "fill":
                if len(outline) == 7 and outline.startswith("#"):
                    border_px = outline_size
                    bordered = Image.new(
                        "RGB",
                        (target_size[0] + 2 * border_px, target_size[1] + 2 * border_px),
                        outline,
                    )
                    img = img.resize(target_size, Image.Resampling.LANCZOS)
                    bordered.paste(img, (border_px, border_px))
                    return bordered
                return img.resize(target_size, Image.Resampling.LANCZOS)
            else:
                raise ValueError("img_scale_mode must be 'fit' or 'fill'.")

        def stack_images_base64(
            image_paths: List[str],
            direction: str,
            single_img_size: Tuple[int, int],
            outline: str,
            outline_size: int,
            padding: int,
        ) -> Tuple[str, Tuple[int, int]]:
            image_paths = [
                p for p in image_paths if p is not None and str(p).strip() != ""
            ]
            n = len(image_paths)
            if n == 0:
                blank = Image.new("RGB", single_img_size, (255, 255, 255))
                buf = BytesIO()
                blank.save(buf, format="PNG")
                return (
                    "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode(),
                    single_img_size,
                )

            processed = [
                process_image_for_slot(
                    p, single_img_size, img_scale_mode, outline, outline_size
                )
                for p in image_paths
            ]
            pad_total = padding * (n - 1)

            if direction == "horizontal":
                total_w = sum(im.width for im in processed) + pad_total
                total_h = max(im.height for im in processed)
                stacked = Image.new("RGB", (total_w, total_h), (255, 255, 255))
                x = 0
                for im in processed:
                    stacked.paste(im, (x, 0))
                    x += im.width + padding
            elif direction == "vertical":
                total_w = max(im.width for im in processed)
                total_h = sum(im.height for im in processed) + pad_total
                stacked = Image.new("RGB", (total_w, total_h), (255, 255, 255))
                y = 0
                for im in processed:
                    stacked.paste(im, (0, y))
                    y += im.height + padding
            else:
                raise ValueError("img_stack_direction must be 'horizontal' or 'vertical'.")

            buf = BytesIO()
            stacked.save(buf, format="PNG")
            encoded = base64.b64encode(buf.getvalue()).decode()
            return f"data:image/png;base64,{encoded}", (total_w, total_h)

        def compute_stacked_size(
            image_paths: List[str],
            direction: str,
            single_w: int,
            single_h: int,
            padding: int,
            outline: str,
            outline_size: int,
        ) -> Tuple[int, int]:
            image_paths = [
                p for p in image_paths if p is not None and str(p).strip() != ""
            ]
            n = len(image_paths)
            if n == 0:
                return single_w, single_h
            has_outline = len(outline) == 7 and outline.startswith("#")
            border = 2 * outline_size if has_outline else 0
            unit_w = single_w + border
            unit_h = single_h + border
            if direction == "horizontal":
                total_w = n * unit_w + (n - 1) * padding
                total_h = unit_h
            elif direction == "vertical":
                total_w = unit_w
                total_h = n * unit_h + (n - 1) * padding
            else:
                raise ValueError("img_stack_direction must be 'horizontal' or 'vertical'.")
            return total_w, total_h

        # --- Load DataFrame ---
        if isinstance(indir_or_csvf_or_df, str):
            _, ext = os.path.splitext(indir_or_csvf_or_df)
            if ext.lower() == ".csv":
                df = pd.read_csv(indir_or_csvf_or_df)
            elif os.path.isdir(indir_or_csvf_or_df):
                df = PlotHelper.get_img_grid_df(indir_or_csvf_or_df, log=False)
            else:
                raise ValueError("Input string must be a valid CSV file or directory path")
        elif isinstance(indir_or_csvf_or_df, pd.DataFrame):
            df = indir_or_csvf_or_df.copy()
        else:
            raise ValueError("Input must be a CSV file path, DataFrame, or directory path")

        rows = df.iloc[:, 0].astype(str).tolist()
        columns = list(df.columns[1:])
        n_rows, n_cols = len(rows), len(columns)

        fig = go.Figure()

        # First pass: compute content sizes
        content_col_max = [0] * n_cols
        content_row_max = [0] * n_rows
        cell_paths = [[None] * n_cols for _ in range(n_rows)]
        for i in range(n_rows):
            for j in range(n_cols):
                raw_cell = df.iloc[i, j + 1]
                paths = PlotHelper._parse_cell_to_list(raw_cell)
                image_paths = [str(p).strip() for p in paths if str(p).strip() != ""]
                cell_paths[i][j] = image_paths
                cw, ch = compute_stacked_size(
                    image_paths,
                    img_stack_direction,
                    img_width,
                    img_height,
                    img_stack_padding_px,
                    outline_color,
                    outline_size,
                )
                content_col_max[j] = max(content_col_max[j], cw)
                content_row_max[i] = max(content_row_max[i], ch)

        # Compute display sizes (content max + padding)
        display_col_w = [content_col_max[j] + 2 * cell_margin_px for j in range(n_cols)]
        display_row_h = [content_row_max[i] + 2 * cell_margin_px for i in range(n_rows)]

        # Compute positions (cells adjacent)
        x_positions = []
        cum_w = 0
        for dw in display_col_w:
            x_positions.append(cum_w)
            cum_w += dw

        y_positions = []
        cum_h = 0
        for dh in display_row_h:
            y_positions.append(-cum_h)
            cum_h += dh

        # Second pass: create padded cell images (centered content)
        cell_imgs = [[None] * n_cols for _ in range(n_rows)]
        p = cell_margin_px
        for i in range(n_rows):
            for j in range(n_cols):
                image_paths = cell_paths[i][j]
                content_src, (cw, ch) = stack_images_base64(
                    image_paths,
                    img_stack_direction,
                    (img_width, img_height),
                    outline_color,
                    outline_size,
                    img_stack_padding_px,
                )
                if cw == 0 or ch == 0:
                    # Empty cell: use a white placeholder of the display size
                    pad_w = display_col_w[j]
                    pad_h = display_row_h[i]
                    padded = Image.new("RGB", (pad_w, pad_h), (255, 255, 255))
                else:
                    content_img = Image.open(
                        BytesIO(base64.b64decode(content_src.split(",")[1]))
                    )
                    ca_w = content_col_max[j]
                    ca_h = content_row_max[i]
                    left_offset = (ca_w - cw) // 2
                    top_offset = (ca_h - ch) // 2
                    pad_w = display_col_w[j]
                    pad_h = display_row_h[i]
                    padded = Image.new("RGB", (pad_w, pad_h), (255, 255, 255))
                    paste_x = p + left_offset
                    paste_y = p + top_offset
                    padded.paste(content_img, (paste_x, paste_y))
                buf = BytesIO()
                padded.save(buf, format="PNG")
                encoded = base64.b64encode(buf.getvalue()).decode()
                cell_imgs[i][j] = f"data:image/png;base64,{encoded}"

        # Add images to the figure
        for i in range(n_rows):
            for j in range(n_cols):
                fig.add_layout_image(
                    dict(
                        source=cell_imgs[i][j],
                        x=x_positions[j],
                        y=y_positions[i],
                        xref="x",
                        yref="y",
                        sizex=display_col_w[j],
                        sizey=display_row_h[i],
                        xanchor="left",
                        yanchor="top",
                        layer="above",
                    )
                )

        # Optional grid lines at cell boundaries
        if row_line_size > 0:
            for i in range(1, n_rows):
                y = (y_positions[i - 1] - display_row_h[i - 1] + y_positions[i]) / 2
                fig.add_shape(
                    type="line",
                    x0=-p,
                    x1=cum_w,
                    y0=y,
                    y1=y,
                    line=dict(width=row_line_size, color="black", dash="dot"),
                )

        if col_line_size > 0:
            for j in range(1, n_cols):
                x = x_positions[j]
                fig.add_shape(
                    type="line",
                    x0=x,
                    x1=x,
                    y0=p,
                    y1=-cum_h,
                    line=dict(width=col_line_size, color="black", dash="dot"),
                )

        # Axis labels
        col_labels = [
            format_col_label_func(c) if format_col_label_func else c for c in columns
        ]
        row_labels = [
            format_row_label_func(r) if format_row_label_func else r for r in rows
        ]

        fig.update_xaxes(
            tickvals=[x_positions[j] + display_col_w[j] / 2 for j in range(n_cols)],
            ticktext=col_labels,
            range=[-p, cum_w],
            showgrid=False,
            zeroline=False,
            tickfont=tickfont,
        )
        fig.update_yaxes(
            tickvals=[y_positions[i] - display_row_h[i] / 2 for i in range(n_rows)],
            ticktext=row_labels,
            range=[-cum_h, p],
            showgrid=False,
            zeroline=False,
            tickfont=tickfont,
        )

        fig.update_layout(
            width=cum_w + 100,
            height=cum_h + 100,
            title=title,
            title_x=0.5,
            margin=fig_margin,
        )

        # === Export if save_path is given ===
        if save_path:
            import kaleido  # noqa: F401  (lazy import; only needed when exporting)

            ext = os.path.splitext(save_path)[1].lower()
            if ext in [".png", ".jpg", ".jpeg"]:
                fig.write_image(save_path, scale=dpi / 96)  # scale relative to the 96 DPI base
            elif ext in [".pdf", ".svg"]:
                fig.write_image(save_path)  # PDF/SVG are vector formats, so dpi is ignored
            else:
                raise ValueError("save_path must end with .png, .jpg, .pdf, or .svg")
        if show:
            fig.show()
        return fig

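# Usage sketch for the image grid (illustrative only; the directory path is an
# assumption and should follow the layout documented in get_img_grid_df):
#
#     df = PlotHelper.get_img_grid_df("./results/qualitative")
#     PlotHelper.plot_image_grid(
#         df,
#         save_path="grid.pdf",
#         img_stack_direction="horizontal",
#         row_line_size=1,
#     )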
@click.command()
@click.option("--csvfiles", "-f", multiple=True, type=str, help="CSV files to plot")
@click.option(
    "--outdir", "-o", type=str, default=str(desktop_path), help="output directory"
)
@click.option(
    "--tags", "-t", multiple=True, type=str, default=[], help="tags for the CSV files"
)
@click.option("--log", "-l", is_flag=True, help="log the CSV files")
@click.option("--save_fig", "-s", is_flag=True, help="save the plot to a file")
@click.option(
    "--update_in_min",
    "-u",
    type=float,
    default=0.0,
    help="update the plot every N minutes",
)
@click.option(
    "--x_col", "-x", type=str, default="epoch", help="column to use as the x-axis"
)
@click.option(
    "--y_cols",
    "-y",
    multiple=True,
    type=str,
    required=True,
    help="columns to plot as y (can be repeated)",
)
def main(csvfiles, outdir, tags, log, save_fig, update_in_min, x_col, y_cols):
    PlotHelper.plot_csv_timeseries(
        list(csvfiles),
        outdir,
        list(tags),
        log,
        save_fig,
        update_in_min,
        x_col=x_col,
        y_cols=list(y_cols),
    )


if __name__ == "__main__":
    main()
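
Given the relative imports at the top of the file, the Click entry point above must be run as a module rather than as a script. A sketch of an invocation, assuming this file is halib/exp/viz/plot.py (the wheel also ships a halib/research/viz/plot.py of the same size) and train_log.csv stands in for a real log:

    python -m halib.exp.viz.plot -f train_log.csv -y train_loss -y val_loss -s -u 2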