halib 0.1.85__tar.gz → 0.1.86__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {halib-0.1.85 → halib-0.1.86}/.gitignore +2 -0
- {halib-0.1.85 → halib-0.1.86}/PKG-INFO +2 -2
- {halib-0.1.85 → halib-0.1.86}/README.md +1 -1
- {halib-0.1.85 → halib-0.1.86}/halib/filetype/csvfile.py +29 -0
- halib-0.1.86/halib/research/mics.py +16 -0
- halib-0.1.86/halib/research/plot.py +496 -0
- {halib-0.1.85 → halib-0.1.86}/halib.egg-info/PKG-INFO +2 -2
- {halib-0.1.85 → halib-0.1.86}/halib.egg-info/SOURCES.txt +1 -0
- {halib-0.1.85 → halib-0.1.86}/setup.py +1 -1
- halib-0.1.85/halib/research/plot.py +0 -301
- {halib-0.1.85 → halib-0.1.86}/GDriveFolder.txt +0 -0
- {halib-0.1.85 → halib-0.1.86}/LICENSE.txt +0 -0
- {halib-0.1.85 → halib-0.1.86}/MANIFEST.in +0 -0
- {halib-0.1.85 → halib-0.1.86}/guide_publish_pip.pdf +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/__init__.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/common.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/cuda.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/filetype/__init__.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/filetype/jsonfile.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/filetype/textfile.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/filetype/videofile.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/filetype/yamlfile.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/online/__init__.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/online/gdrive.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/online/gdrive_mkdir.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/online/gdrive_test.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/online/projectmake.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/research/__init__.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/research/base_config.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/research/base_exp.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/research/dataset.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/research/metrics.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/research/perfcalc.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/research/perftb.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/research/profiler.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/research/torchloader.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/research/wandb_op.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/rich_color.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/system/__init__.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/system/cmd.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/system/filesys.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/utils/__init__.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/utils/dataclass_util.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/utils/dict_op.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/utils/gpu_mon.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/utils/listop.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/utils/tele_noti.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib/utils/video.py +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib.egg-info/dependency_links.txt +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib.egg-info/requires.txt +0 -0
- {halib-0.1.85 → halib-0.1.86}/halib.egg-info/top_level.txt +0 -0
- {halib-0.1.85 → halib-0.1.86}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: halib
|
3
|
-
Version: 0.1.85
|
3
|
+
Version: 0.1.86
|
4
4
|
Summary: Small library for common tasks
|
5
5
|
Author: Hoang Van Ha
|
6
6
|
Author-email: hoangvanhauit@gmail.com
|
@@ -52,7 +52,7 @@ Dynamic: summary
|
|
52
52
|
|
53
53
|
Helper package for coding and automation
|
54
54
|
|
55
|
-
**Version 0.1.85**
|
55
|
+
**Version 0.1.86**
|
56
56
|
|
57
57
|
+ `research/profiler`: add `zProfiler` class to measure execution time of contexts and steps, with support for dynamic color scales in plots.
|
58
58
|
|
@@ -9,6 +9,7 @@ from loguru import logger
|
|
9
9
|
from itables import init_notebook_mode, show
|
10
10
|
import pygwalker as pyg
|
11
11
|
import textwrap
|
12
|
+
import csv
|
12
13
|
|
13
14
|
console = Console()
|
14
15
|
|
@@ -18,6 +19,34 @@ def read(file, separator=","):
|
|
18
19
|
return df
|
19
20
|
|
20
21
|
|
22
|
+
def read_auto_sep(filepath, sample_size=2048, **kwargs):
    """
    Read a CSV file with automatic delimiter detection.

    Parameters
    ----------
    filepath : str
        Path to the CSV file.
    sample_size : int, optional
        Number of characters read from the start of the file for delimiter
        sniffing.
    **kwargs : dict
        Extra keyword args passed to pandas.read_csv. If ``sep`` or
        ``delimiter`` is given (and not None) it is used as-is and no
        sniffing is performed. ``encoding`` (default "utf-8") is also used
        to open the file for sniffing.

    Returns
    -------
    df : pandas.DataFrame
    """
    # An explicit separator wins. Forwarding it next to our own ``sep=`` would
    # otherwise raise "got multiple values for keyword argument 'sep'".
    if kwargs.get("sep") is not None or kwargs.get("delimiter") is not None:
        return pd.read_csv(filepath, **kwargs)

    # Drop explicit ``sep=None`` / ``delimiter=None`` placeholders; we decide below.
    kwargs.pop("sep", None)
    kwargs.pop("delimiter", None)

    with open(filepath, "r", newline="", encoding=kwargs.get("encoding", "utf-8")) as f:
        sample = f.read(sample_size)

    try:
        dialect = csv.Sniffer().sniff(sample, delimiters=[",", ";", "\t", "|", ":"])
        sep = dialect.delimiter
    except csv.Error:
        sep = ","  # fallback if detection fails (e.g. empty or single-column file)

    return pd.read_csv(filepath, sep=sep, **kwargs)
|
49
|
+
|
21
50
|
# for append, mode = 'a'
|
22
51
|
def fn_write(df, outfile, mode="w", header=True, index_label=None):
|
23
52
|
if not outfile.endswith(".csv"):
|
@@ -0,0 +1,16 @@
|
|
1
|
+
import platform
|
2
|
+
|
3
|
+
# Maps machine host names (as returned by get_PC_name / platform.node()) to short aliases.
PC_NAME_TO_ABBR = {
    "DESKTOP-JQD9K01": "MainPC",
    "DESKTOP-5IRHU87": "MSI_Laptop",
    "DESKTOP-96HQCNO": "4090_SV",
    "DESKTOP-Q2IKLC0": "4GPU_SV",
    "DESKTOP-QNS3DNF": "1GPU_SV",
}


def get_PC_name():
    """Return this machine's network name via ``platform.node()``.

    ``platform.node()`` returns an empty string when the name cannot be determined.
    """
    return platform.node()


def get_PC_abbr_name(default="Unknown"):
    """Return this machine's short alias from ``PC_NAME_TO_ABBR``.

    Args:
        default: value returned when the host name is not in the table
            (``"Unknown"`` by default).
    """
    return PC_NAME_TO_ABBR.get(get_PC_name(), default)
|
@@ -0,0 +1,496 @@
|
|
1
|
+
import os
|
2
|
+
import pandas as pd
|
3
|
+
import plotly.express as px
|
4
|
+
from rich.console import Console
|
5
|
+
from ..common import now_str, norm_str, ConsoleLog
|
6
|
+
from ..filetype import csvfile
|
7
|
+
from ..system import filesys as fs
|
8
|
+
import click
|
9
|
+
import time
|
10
|
+
|
11
|
+
import pandas as pd
|
12
|
+
import plotly.graph_objects as go
|
13
|
+
from PIL import Image
|
14
|
+
import base64
|
15
|
+
from io import BytesIO
|
16
|
+
from typing import Callable, Optional, Tuple, List, Union
|
17
|
+
|
18
|
+
|
19
|
+
console = Console()
|
20
|
+
desktop_path = os.path.expanduser("~/Desktop")
|
21
|
+
|
22
|
+
class PlotHelper:
    """Plotly-based plotting helpers for CSV training logs and image grids."""

    def _verify_csv(self, csv_file):
        """Read ``csv_file`` with delimiter auto-detection and lowercase its column names.

        Raises:
            FileNotFoundError: if the file does not exist.
            ValueError: if the file cannot be read or parsed.
        """
        try:
            df = csvfile.read_auto_sep(csv_file)
            df.columns = [col.lower() for col in df.columns]
            return df
        except FileNotFoundError as e:
            raise FileNotFoundError(f"CSV file '{csv_file}' not found") from e
        except Exception as e:
            # Chain the original error so the real cause stays in the traceback.
            raise ValueError(f"Error reading CSV file '{csv_file}': {str(e)}") from e

    @staticmethod
    def _norm_str(s):
        """Lowercase ``s`` and replace spaces and hyphens with underscores."""
        return s.lower().replace(" ", "_").replace("-", "_")

    @staticmethod
    def _get_file_name(file_path):
        """Return the base name of ``file_path`` without its extension."""
        return os.path.splitext(os.path.basename(file_path))[0]

    def _get_valid_tags(self, csv_files, tags):
        """Return ``tags`` as a list, or derive one tag per CSV file from its name."""
        if tags:
            return list(tags)
        return [self._norm_str(self._get_file_name(f)) for f in csv_files]

    @staticmethod
    def _metric_suffix(col):
        """Return the metric group of a column name: the text after its last ``_``.

        E.g. ``"train_loss"`` and ``"val_loss"`` both belong to group ``"loss"``.
        """
        return col.split("_")[-1]

    @classmethod
    def _select_metric(cls, df_long, metric):
        """Return the rows of ``df_long`` whose ``metric_type`` is in group ``metric``.

        Uses the same exact-suffix rule as the grouping, instead of a
        substring/regex search. For example, group ``"loss"`` does not also pick
        up a column such as ``"loss_scale_acc"``.
        """
        groups = df_long["metric_type"].map(cls._metric_suffix)
        return df_long[groups == metric]

    def _prepare_long_df(self, csv_files, tags, x_col, y_cols, log=False):
        """Load each CSV and stack them into one long-form DataFrame for Plotly.

        Output columns: ``x_col``, ``metric_type`` (the source column name),
        ``value``, and ``tag`` (the tag of the CSV the row came from).

        Raises:
            ValueError: if no files are given or a file lacks ``x_col``/``y_cols``.
        """
        dfs = []
        for csv_file, tag in zip(csv_files, tags):
            df = self._verify_csv(csv_file)
            if x_col not in df.columns:
                raise ValueError(f"{csv_file} is missing x_col '{x_col}'")
            missing = [c for c in y_cols if c not in df.columns]
            if missing:
                raise ValueError(f"{csv_file} is missing y_cols {missing}")

            if log:
                console.log(f"Plotting {csv_file}")
                console.print(df)

            # Wide to long: one row per (x, metric column) pair.
            df_long = df.melt(
                id_vars=x_col,
                value_vars=y_cols,
                var_name="metric_type",
                value_name="value",
            )
            df_long["tag"] = tag
            dfs.append(df_long)

        if not dfs:
            raise ValueError("No CSV files to plot")
        return pd.concat(dfs, ignore_index=True)

    def _plot_with_plotly(
        self,
        df_long,
        tags,
        outdir,
        save_fig,
        out_fmt="svg",
        font_size=16,
        x_col="epoch",
        y_cols=None,
    ):
        """Show, and optionally save, one line chart per metric group of ``y_cols``.

        Columns are grouped by their suffix after the last ``_`` (for example,
        ``train_loss`` and ``val_loss`` both go to ``loss``). Each chart colours
        lines by ``tag`` and dashes them by ``metric_type``.

        Raises:
            AssertionError: if ``out_fmt`` is not ``svg``, ``pdf`` or ``png``.
            ValueError: if ``y_cols`` is None.
        """
        assert out_fmt in ["svg", "pdf", "png"], "Unsupported format"
        if y_cols is None:
            raise ValueError("y_cols must be provided")

        tags_str = "+".join(tags)
        metric_groups = sorted({self._metric_suffix(col) for col in y_cols})
        if save_fig:
            os.makedirs(outdir, exist_ok=True)

        for metric in metric_groups:
            subset = self._select_metric(df_long, metric)

            if out_fmt == "svg":  # LaTeX-style labels
                title = f"${tags_str}\\_{metric}\\text{{-by-{x_col}}}$"
                xaxis_title = f"$\\text{{{x_col.capitalize()}}}$"
                yaxis_title = f"${metric.capitalize()}$"
            else:
                title = f"{tags_str}_{metric}-by-{x_col}"
                xaxis_title = x_col.capitalize()
                yaxis_title = metric.capitalize()

            fig = px.line(
                subset,
                x=x_col,
                y="value",
                color="tag",
                line_dash="metric_type",
                title=title,
            )
            fig.update_layout(
                font=dict(family="Computer Modern", size=font_size),
                xaxis_title=xaxis_title,
                yaxis_title=yaxis_title,
            )
            fig.show()

            if save_fig:
                filename = f"{now_str()}_{tags_str}_{metric}"
                try:
                    fig.write_image(os.path.join(outdir, f"{filename}.{out_fmt}"))
                except Exception as e:
                    # Best effort: static export can fail (e.g. missing export
                    # engine); report it and keep plotting the other metrics.
                    console.log(f"Error saving figure '{filename}.{out_fmt}': {str(e)}")

    @classmethod
    def plot_csv_timeseries(
        cls,
        csv_files,
        outdir="./out/plot",
        tags=None,
        log=False,
        save_fig=False,
        update_in_min=0,
        out_fmt="svg",
        font_size=16,
        x_col="epoch",
        y_cols=("train_loss", "train_acc"),
    ):
        """Plot metric columns from one or more CSV logs with Plotly.

        Args:
            csv_files: a CSV path or a list of CSV paths.
            outdir: directory for saved figures (used only when ``save_fig``).
            tags: a tag or list of tags, one per CSV. If empty, tags are derived
                from the file names.
            log: print each loaded DataFrame to the console.
            save_fig: also write each figure to ``outdir`` as ``out_fmt``.
            update_in_min: if > 0, re-plot every this many minutes until Ctrl+C.
            out_fmt: ``"svg"``, ``"pdf"`` or ``"png"``.
            font_size: font size used in the figures.
            x_col: x-axis column (matched case-insensitively).
            y_cols: a column name or sequence of metric columns to plot
                (matched case-insensitively).

        Raises:
            ValueError: if ``y_cols`` is empty or a CSV lacks a required column.
            AssertionError: if the number of tags differs from the number of files.
        """
        if isinstance(csv_files, str):
            csv_files = [csv_files]
        if isinstance(tags, str):
            tags = [tags]
        if not y_cols:
            raise ValueError("You must specify y_cols explicitly")
        if isinstance(y_cols, str):
            y_cols = [y_cols]

        # _verify_csv lowercases CSV headers, so match requested columns the same way.
        x_col = x_col.lower()
        y_cols = [c.lower() for c in y_cols]

        plot_helper = cls()
        valid_tags = plot_helper._get_valid_tags(csv_files, tags)
        assert len(valid_tags) == len(
            csv_files
        ), "Number of tags must match number of CSV files"

        def run_once():
            df_long = plot_helper._prepare_long_df(
                csv_files, valid_tags, x_col, y_cols, log
            )
            plot_helper._plot_with_plotly(
                df_long, valid_tags, outdir, save_fig, out_fmt, font_size, x_col, y_cols
            )

        if update_in_min > 0:
            # At least 1s between refreshes so tiny intervals can't busy-loop.
            interval = max(1, int(update_in_min * 60))
            console.log(f"Live update every {interval}s. Press Ctrl+C to stop.")
            try:
                while True:
                    run_once()
                    time.sleep(interval)
            except KeyboardInterrupt:
                console.log("Stopped live updates.")
        else:
            run_once()

    @staticmethod
    # this plot_df contains the data to be plotted (row, column)
    def img_grid_df(input_dir, log=False):
        """Build an image-grid DataFrame from a ``row*/<col>/<images>`` directory tree.

        Expected layout:
        - ``input_dir`` contains sub-directories whose names start with ``"row"``.
        - Every row directory contains the same set of column sub-directories.
        - Every column directory holds the same number of png/jpg/jpeg files.

        The returned DataFrame has:
        - a ``"row"`` column, with each row name (sorted) repeated once per image
          index;
        - one column per column directory (sorted), holding the entries returned
          by ``fs.filter_files_by_extension`` for that directory.

        Raises:
            ValueError: if no row/column directories exist or they are inconsistent.
        """
        row_names = [r for r in fs.list_dirs(input_dir) if r.startswith("row")]
        if not row_names:
            raise ValueError(f"No 'row*' sub-directories found in '{input_dir}'")

        # Every row must have exactly the same column sub-directories.
        cols_of_row = None
        for row in row_names:
            cols = sorted(fs.list_dirs(os.path.join(input_dir, row)))
            if cols_of_row is None:
                cols_of_row = cols
            elif cols_of_row != cols:
                raise ValueError(
                    f"Row {row} has different columns than previous rows: {cols_of_row} vs {cols}"
                )
        if not cols_of_row:
            raise ValueError(f"Rows in '{input_dir}' contain no column sub-directories")

        paths = {
            row: {
                col: fs.filter_files_by_extension(
                    os.path.join(input_dir, row, col), ["png", "jpg", "jpeg"]
                )
                for col in cols_of_row
            }
            for row in row_names
        }

        # Every (row, column) cell must hold the same number of images.
        n_per_cell = len(paths[row_names[0]][cols_of_row[0]])
        for row, col_paths in paths.items():
            for col, files in col_paths.items():
                if len(files) != n_per_cell:
                    raise ValueError(
                        f"Row {row}, Column {col} has different number of files: {len(files)} vs {n_per_cell}"
                    )

        sorted_rows = sorted(row_names)
        data = {"row": [r for r in sorted_rows for _ in range(n_per_cell)]}
        for col in cols_of_row:  # already sorted
            data[col] = [p for r in sorted_rows for p in paths[r][col]]

        df = pd.DataFrame(data)
        if log:
            csvfile.fn_display_df(df)
        return df

    @staticmethod
    def _cell_image_paths(cell) -> List[str]:
        """Normalize a grid cell to a list of image paths.

        A string becomes a one-element list and any other iterable becomes a
        list. Non-iterable scalars (e.g. NaN or None for an empty cell) become
        ``[]``.
        """
        if isinstance(cell, str):
            return [cell]
        try:
            return list(cell)
        except TypeError:
            return []

    @staticmethod
    def _stack_images_base64(
        image_paths: List[str],
        direction: str,
        target_size: Tuple[int, int],
        padding_px: int,
    ) -> str:
        """Load, thumbnail and stack images, returning a PNG ``data:`` URI.

        Behaviour:
        - Each image is shrunk with ``Image.thumbnail`` to fit within ``target_size``.
        - Images are pasted ``"horizontal"``ly or ``"vertical"``ly on a white
          canvas, with ``padding_px`` pixels between them.
        - An image that cannot be loaded is logged and replaced by a white
          ``target_size`` placeholder.
        - Returns ``""`` when ``image_paths`` is empty.

        Raises:
            ValueError: if ``direction`` is not ``"horizontal"`` or ``"vertical"``.
        """
        if not image_paths:
            return ""

        images = []
        for path in image_paths:
            try:
                # Context manager closes the underlying file handle.
                with Image.open(path) as src:
                    img = src.convert("RGB")
                img.thumbnail(target_size, Image.Resampling.LANCZOS)
            except Exception as e:
                console.log(f"Could not load image '{path}': {e}")
                img = Image.new("RGB", target_size, (255, 255, 255))
            images.append(img)

        widths, heights = zip(*(im.size for im in images))
        gaps = padding_px * (len(images) - 1)
        if direction == "horizontal":
            canvas = Image.new("RGB", (sum(widths) + gaps, max(heights)), (255, 255, 255))
            offset = 0
            for im in images:
                canvas.paste(im, (offset, 0))
                offset += im.width + padding_px
        elif direction == "vertical":
            canvas = Image.new("RGB", (max(widths), sum(heights) + gaps), (255, 255, 255))
            offset = 0
            for im in images:
                canvas.paste(im, (0, offset))
                offset += im.height + padding_px
        else:
            raise ValueError("img_stack_direction must be 'horizontal' or 'vertical'")

        buffer = BytesIO()
        canvas.save(buffer, format="PNG")
        return "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode()

    @staticmethod
    def plot_image_grid(
        csv_file_or_df: Union[str, pd.DataFrame],
        max_width: int = 300,
        max_height: int = 300,
        img_stack_direction: str = "horizontal",
        img_stack_padding_px: int = 10,
        format_row_label_func: Optional[Callable[[str], str]] = None,
        format_col_label_func: Optional[Callable[[str, str], str]] = None,
        title: str = "",
    ):
        """
        Plot a grid of images using Plotly from a DataFrame or CSV file.

        Args:
            csv_file_or_df (str | os.PathLike | pd.DataFrame): DataFrame (or CSV
                path). The first column holds row labels; the remaining columns
                hold an image path, or a list of image paths stacked in one cell.
                Empty (NaN/None) cells are left blank.
            max_width (int): Maximum width of each image in a cell, in pixels.
            max_height (int): Maximum height of each image in a cell, in pixels.
            img_stack_direction (str): "horizontal" or "vertical" stacking.
            img_stack_padding_px (int): Padding between stacked images in pixels.
            format_row_label_func (Callable): Function to format row labels.
            format_col_label_func (Callable): Function to format column labels;
                called as ``func(label, pattern="___")``.
            title (str): Figure title.

        Raises:
            ValueError: if ``img_stack_direction`` is invalid.
        """
        if img_stack_direction not in ("horizontal", "vertical"):
            raise ValueError("img_stack_direction must be 'horizontal' or 'vertical'")

        # Load DataFrame if a file path is provided
        if isinstance(csv_file_or_df, (str, os.PathLike)):
            df = csvfile.read_auto_sep(csv_file_or_df)
        else:
            df = csv_file_or_df
        assert isinstance(df, pd.DataFrame), "Input must be a DataFrame or valid CSV file path"

        rows = df[df.columns[0]].tolist()
        columns = df.columns[1:].tolist()
        n_rows, n_cols = len(rows), len(columns)

        fig = go.Figure()
        for i in range(n_rows):
            for j in range(n_cols):
                # Positional access works for any index (e.g. a filtered DataFrame).
                image_paths = PlotHelper._cell_image_paths(df.iat[i, j + 1])
                img_src = PlotHelper._stack_images_base64(
                    image_paths,
                    img_stack_direction,
                    (max_width, max_height),
                    img_stack_padding_px,
                )
                if not img_src:
                    continue  # empty cell: nothing to draw

                # Centre each image on its tick (x=j, y=-i) so a cell spans
                # [j-0.5, j+0.5] x [-i-0.5, -i+0.5], matching the axis ranges below.
                fig.add_layout_image(
                    dict(
                        source=img_src,
                        x=j,
                        y=-i,  # negative so row 0 is on top
                        xref="x",
                        yref="y",
                        sizex=1,
                        sizey=1,
                        xanchor="center",
                        yanchor="middle",
                        layer="above",
                    )
                )

        # Format axis labels
        col_labels = [
            format_col_label_func(c, pattern="___") if format_col_label_func else c
            for c in columns
        ]
        row_labels = [
            format_row_label_func(r) if format_row_label_func else r for r in rows
        ]

        fig.update_xaxes(
            tickvals=list(range(n_cols)),
            ticktext=col_labels,
            range=[-0.5, n_cols - 0.5],
            showgrid=False,
            zeroline=False,
        )
        fig.update_yaxes(
            tickvals=[-i for i in range(n_rows)],
            ticktext=row_labels,
            range=[-n_rows + 0.5, 0.5],
            showgrid=False,
            zeroline=False,
        )
        fig.update_layout(
            width=max_width * n_cols + 200,  # extra room for labels
            height=max_height * n_rows + 100,
            title=title,
            margin=dict(l=100, r=20, t=50, b=50),
        )

        fig.show()
|
452
|
+
|
453
|
+
|
454
|
+
@click.command()
@click.option(
    "--csvfiles", "-f", multiple=True, type=str, required=True, help="csv files to plot"
)
@click.option(
    "--outdir", "-o", type=str, default=str(desktop_path), help="output directory"
)
@click.option(
    "--tags", "-t", multiple=True, type=str, default=(), help="tags for the csv files"
)
@click.option("--log", "-l", is_flag=True, help="log the csv files")
@click.option("--save_fig", "-s", is_flag=True, help="save the plot as file")
@click.option(
    "--update_in_min",
    "-u",
    type=float,
    default=0.0,
    help="update the plot every x minutes",
)
@click.option(
    "--x_col", "-x", type=str, default="epoch", help="column to use as x-axis"
)
@click.option(
    "--y_cols",
    "-y",
    multiple=True,
    type=str,
    required=True,
    help="columns to plot as y (can repeat)",
)
@click.option(
    "--out_fmt",
    type=click.Choice(["svg", "pdf", "png"]),
    default="svg",
    show_default=True,
    help="file format used when --save_fig is set",
)
@click.option(
    "--font_size", type=int, default=16, show_default=True, help="font size of plot text"
)
def main(
    csvfiles,
    outdir,
    tags,
    log,
    save_fig,
    update_in_min,
    x_col,
    y_cols,
    out_fmt="svg",
    font_size=16,
):
    """Plot metric columns of one or more CSV logs (CLI for PlotHelper.plot_csv_timeseries)."""
    # Keyword arguments keep this call correct even if the method's
    # positional parameter order changes.
    PlotHelper.plot_csv_timeseries(
        list(csvfiles),
        outdir=outdir,
        tags=list(tags),
        log=log,
        save_fig=save_fig,
        update_in_min=update_in_min,
        out_fmt=out_fmt,
        font_size=font_size,
        x_col=x_col,
        y_cols=list(y_cols),
    )


if __name__ == "__main__":
    main()
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: halib
|
3
|
-
Version: 0.1.85
|
3
|
+
Version: 0.1.86
|
4
4
|
Summary: Small library for common tasks
|
5
5
|
Author: Hoang Van Ha
|
6
6
|
Author-email: hoangvanhauit@gmail.com
|
@@ -52,7 +52,7 @@ Dynamic: summary
|
|
52
52
|
|
53
53
|
Helper package for coding and automation
|
54
54
|
|
55
|
-
**Version 0.1.85**
|
55
|
+
**Version 0.1.86**
|
56
56
|
|
57
57
|
+ `research/profiler`: add `zProfiler` class to measure execution time of contexts and steps, with support for dynamic color scales in plots.
|
58
58
|
|
@@ -1,301 +0,0 @@
|
|
1
|
-
from ..common import now_str, norm_str, ConsoleLog
|
2
|
-
from ..filetype import csvfile
|
3
|
-
from ..system import filesys as fs
|
4
|
-
from functools import partial
|
5
|
-
from rich.console import Console
|
6
|
-
from rich.pretty import pprint
|
7
|
-
import click
|
8
|
-
import csv
|
9
|
-
import matplotlib
|
10
|
-
import matplotlib.pyplot as plt
|
11
|
-
import numpy as np
|
12
|
-
import os
|
13
|
-
import pandas as pd
|
14
|
-
import seaborn as sns
|
15
|
-
|
16
|
-
|
17
|
-
console = Console()
|
18
|
-
desktop_path = os.path.expanduser("~/Desktop")
|
19
|
-
REQUIRED_COLUMNS = ["epoch", "train_loss", "val_loss", "train_acc", "val_acc"]
|
20
|
-
|
21
|
-
import csv
|
22
|
-
|
23
|
-
|
24
|
-
def get_delimiter(file_path, bytes=4096):
|
25
|
-
sniffer = csv.Sniffer()
|
26
|
-
data = open(file_path, "r").read(bytes)
|
27
|
-
delimiter = sniffer.sniff(data).delimiter
|
28
|
-
return delimiter
|
29
|
-
|
30
|
-
|
31
|
-
# Function to verify that the DataFrame has the required columns, and only the required columns
|
32
|
-
def verify_csv(csv_file, required_columns=REQUIRED_COLUMNS):
|
33
|
-
delimiter = get_delimiter(csv_file)
|
34
|
-
df = pd.read_csv(csv_file, sep=delimiter)
|
35
|
-
# change the column names to lower case
|
36
|
-
df.columns = [col.lower() for col in df.columns]
|
37
|
-
for col in required_columns:
|
38
|
-
if col not in df.columns:
|
39
|
-
raise ValueError(
|
40
|
-
f"Required columns are: {REQUIRED_COLUMNS}, but found {df.columns}"
|
41
|
-
)
|
42
|
-
df = df[required_columns].copy()
|
43
|
-
return df
|
44
|
-
|
45
|
-
|
46
|
-
def get_valid_tags(csv_files, tags):
|
47
|
-
if tags is not None and len(tags) > 0:
|
48
|
-
assert all(
|
49
|
-
isinstance(tag, str) for tag in tags
|
50
|
-
), "tags must be a list of strings"
|
51
|
-
assert all(
|
52
|
-
len(tag) > 0 for tag in tags
|
53
|
-
), "tags must be a list of non-empty strings"
|
54
|
-
valid_tags = tags
|
55
|
-
else:
|
56
|
-
valid_tags = []
|
57
|
-
for csv_file in csv_files:
|
58
|
-
file_name = fs.get_file_name(csv_file, split_file_ext=True)[0]
|
59
|
-
tag = norm_str(file_name)
|
60
|
-
valid_tags.append(tag)
|
61
|
-
return valid_tags
|
62
|
-
|
63
|
-
|
64
|
-
def plot_ax(df, ax, metric="loss", tag=""):
|
65
|
-
pprint(locals())
|
66
|
-
# reset plt
|
67
|
-
assert metric in ["loss", "acc"], "metric must be either 'loss' or 'acc'"
|
68
|
-
part = ["train", "val"]
|
69
|
-
for p in part:
|
70
|
-
label = f"{tag}_{p}_{metric}"
|
71
|
-
ax.plot(df["epoch"], df[f"{p}_{metric}"], label=label)
|
72
|
-
return ax
|
73
|
-
|
74
|
-
|
75
|
-
def actual_plot_seaborn(frame, csv_files, axes, tags, log):
|
76
|
-
# clear the axes
|
77
|
-
for ax in axes:
|
78
|
-
ax.clear()
|
79
|
-
ls_df = []
|
80
|
-
valid_tags = get_valid_tags(csv_files, tags)
|
81
|
-
for csv_file in csv_files:
|
82
|
-
df = verify_csv(csv_file)
|
83
|
-
if log:
|
84
|
-
with ConsoleLog(f"plotting {csv_file}"):
|
85
|
-
csvfile.fn_display_df(df)
|
86
|
-
ls_df.append(df)
|
87
|
-
|
88
|
-
ls_metrics = ["loss", "acc"]
|
89
|
-
for df_item, tag in zip(ls_df, valid_tags):
|
90
|
-
# add tag to columns,excpet epoch
|
91
|
-
df_item.columns = [
|
92
|
-
f"{tag}_{col}" if col != "epoch" else col for col in df_item.columns
|
93
|
-
]
|
94
|
-
# merge the dataframes on the epoch column
|
95
|
-
df_combined = ls_df[0]
|
96
|
-
for df_item in ls_df[1:]:
|
97
|
-
df_combined = pd.merge(df_combined, df_item, on="epoch", how="outer")
|
98
|
-
# csvfile.fn_display_df(df_combined)
|
99
|
-
|
100
|
-
for i, metric in enumerate(ls_metrics):
|
101
|
-
tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
|
102
|
-
title = f"{tags_str}_{metric}-by-epoch"
|
103
|
-
cols = [col for col in df_combined.columns if col != "epoch" and metric in col]
|
104
|
-
cols = sorted(cols)
|
105
|
-
# pprint(cols)
|
106
|
-
plot_data = df_combined[cols]
|
107
|
-
|
108
|
-
# line from same csv file (same tag) should have the same marker
|
109
|
-
all_markers = [
|
110
|
-
marker for marker in plt.Line2D.markers if marker and marker != " "
|
111
|
-
]
|
112
|
-
tag2marker = {tag: marker for tag, marker in zip(valid_tags, all_markers)}
|
113
|
-
plot_markers = []
|
114
|
-
for col in cols:
|
115
|
-
# find the tag:
|
116
|
-
tag = None
|
117
|
-
for valid_tag in valid_tags:
|
118
|
-
if valid_tag in col:
|
119
|
-
tag = valid_tag
|
120
|
-
break
|
121
|
-
plot_markers.append(tag2marker[tag])
|
122
|
-
# pprint(list(zip(cols, plot_markers)))
|
123
|
-
|
124
|
-
# create color
|
125
|
-
sequential_palettes = [
|
126
|
-
"Reds",
|
127
|
-
"Greens",
|
128
|
-
"Blues",
|
129
|
-
"Oranges",
|
130
|
-
"Purples",
|
131
|
-
"Greys",
|
132
|
-
"BuGn",
|
133
|
-
"BuPu",
|
134
|
-
"GnBu",
|
135
|
-
"OrRd",
|
136
|
-
"PuBu",
|
137
|
-
"PuRd",
|
138
|
-
"RdPu",
|
139
|
-
"YlGn",
|
140
|
-
"PuBuGn",
|
141
|
-
"YlGnBu",
|
142
|
-
"YlOrBr",
|
143
|
-
"YlOrRd",
|
144
|
-
]
|
145
|
-
# each csvfile (tag) should have a unique color
|
146
|
-
tag2palette = {
|
147
|
-
tag: palette for tag, palette in zip(valid_tags, sequential_palettes)
|
148
|
-
}
|
149
|
-
plot_colors = []
|
150
|
-
for tag in valid_tags:
|
151
|
-
palette = tag2palette[tag]
|
152
|
-
total_colors = 10
|
153
|
-
ls_colors = sns.color_palette(palette, total_colors).as_hex()
|
154
|
-
num_part = len(ls_metrics)
|
155
|
-
subarr = np.array_split(np.arange(total_colors), num_part)
|
156
|
-
for idx, col in enumerate(cols):
|
157
|
-
if tag in col:
|
158
|
-
chosen_color = ls_colors[
|
159
|
-
subarr[int(idx % num_part)].mean().astype(int)
|
160
|
-
]
|
161
|
-
plot_colors.append(chosen_color)
|
162
|
-
|
163
|
-
# pprint(list(zip(cols, plot_colors)))
|
164
|
-
sns.lineplot(
|
165
|
-
data=plot_data,
|
166
|
-
markers=plot_markers,
|
167
|
-
palette=plot_colors,
|
168
|
-
ax=axes[i],
|
169
|
-
dashes=False,
|
170
|
-
)
|
171
|
-
axes[i].set(xlabel="epoch", ylabel=metric, title=title)
|
172
|
-
axes[i].legend()
|
173
|
-
axes[i].grid()
|
174
|
-
|
175
|
-
|
176
|
-
def actual_plot(frame, csv_files, axes, tags, log):
|
177
|
-
ls_df = []
|
178
|
-
valid_tags = get_valid_tags(csv_files, tags)
|
179
|
-
for csv_file in csv_files:
|
180
|
-
df = verify_csv(csv_file)
|
181
|
-
if log:
|
182
|
-
with ConsoleLog(f"plotting {csv_file}"):
|
183
|
-
csvfile.fn_display_df(df)
|
184
|
-
ls_df.append(df)
|
185
|
-
|
186
|
-
metric_values = ["loss", "acc"]
|
187
|
-
for i, metric in enumerate(metric_values):
|
188
|
-
for df_item, tag in zip(ls_df, valid_tags):
|
189
|
-
metric_ax = plot_ax(df_item, axes[i], metric, tag)
|
190
|
-
|
191
|
-
# set the title, xlabel, ylabel, legend, and grid
|
192
|
-
tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
|
193
|
-
metric_ax.set(
|
194
|
-
xlabel="epoch", ylabel=metric, title=f"{tags_str}_{metric}-by-epoch"
|
195
|
-
)
|
196
|
-
metric_ax.legend()
|
197
|
-
metric_ax.grid()
|
198
|
-
|
199
|
-
|
200
|
-
def plot_csv_files(
|
201
|
-
csv_files,
|
202
|
-
outdir="./out/plot",
|
203
|
-
tags=None,
|
204
|
-
log=False,
|
205
|
-
save_fig=False,
|
206
|
-
update_in_min=1,
|
207
|
-
):
|
208
|
-
# if csv_files is a string, convert it to a list
|
209
|
-
if isinstance(csv_files, str):
|
210
|
-
csv_files = [csv_files]
|
211
|
-
# if tags is a string, convert it to a list
|
212
|
-
if isinstance(tags, str):
|
213
|
-
tags = [tags]
|
214
|
-
valid_tags = get_valid_tags(csv_files, tags)
|
215
|
-
assert len(valid_tags) == len(
|
216
|
-
csv_files
|
217
|
-
), "Unable to determine tags for each csv file"
|
218
|
-
live_update_in_ms = int(update_in_min * 60 * 1000)
|
219
|
-
fig, axes = plt.subplots(2, 1, figsize=(10, 17))
|
220
|
-
if live_update_in_ms: # live update in min should be > 0
|
221
|
-
from matplotlib.animation import FuncAnimation
|
222
|
-
|
223
|
-
anim = FuncAnimation(
|
224
|
-
fig,
|
225
|
-
partial(
|
226
|
-
actual_plot_seaborn, csv_files=csv_files, axes=axes, tags=tags, log=log
|
227
|
-
),
|
228
|
-
interval=live_update_in_ms,
|
229
|
-
blit=False,
|
230
|
-
cache_frame_data=False,
|
231
|
-
)
|
232
|
-
plt.show()
|
233
|
-
else:
|
234
|
-
actual_plot_seaborn(None, csv_files, axes, tags, log)
|
235
|
-
plt.show()
|
236
|
-
|
237
|
-
if save_fig:
|
238
|
-
os.makedirs(outdir, exist_ok=True)
|
239
|
-
tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
|
240
|
-
tag = f"{now_str()}_{tags_str}"
|
241
|
-
fig.savefig(f"{outdir}/{tag}_plot.png")
|
242
|
-
enable_plot_pgf()
|
243
|
-
fig.savefig(f"{outdir}/{tag}_plot.pdf")
|
244
|
-
if live_update_in_ms:
|
245
|
-
return anim
|
246
|
-
|
247
|
-
|
248
|
-
def enable_plot_pgf():
|
249
|
-
matplotlib.use("pdf")
|
250
|
-
matplotlib.rcParams.update(
|
251
|
-
{
|
252
|
-
"pgf.texsystem": "pdflatex",
|
253
|
-
"font.family": "serif",
|
254
|
-
"text.usetex": True,
|
255
|
-
"pgf.rcfonts": False,
|
256
|
-
}
|
257
|
-
)
|
258
|
-
|
259
|
-
|
260
|
-
def save_fig_latex_pgf(filename, directory="."):
|
261
|
-
enable_plot_pgf()
|
262
|
-
if ".pgf" not in filename:
|
263
|
-
filename = f"{directory}/{filename}.pgf"
|
264
|
-
plt.savefig(filename)
|
265
|
-
|
266
|
-
|
267
|
-
# https: // click.palletsprojects.com/en/8.1.x/api/
|
268
|
-
@click.command()
|
269
|
-
@click.option("--csvfiles", "-f", multiple=True, type=str, help="csv files to plot")
|
270
|
-
@click.option(
|
271
|
-
"--outdir",
|
272
|
-
"-o",
|
273
|
-
type=str,
|
274
|
-
help="output directory for the plot",
|
275
|
-
default=str(desktop_path),
|
276
|
-
)
|
277
|
-
@click.option(
|
278
|
-
"--tags", "-t", multiple=True, type=str, help="tags for the csv files", default=[]
|
279
|
-
)
|
280
|
-
@click.option("--log", "-l", is_flag=True, help="log the csv files")
|
281
|
-
@click.option("--save_fig", "-s", is_flag=True, help="save the plot as a file")
|
282
|
-
@click.option(
|
283
|
-
"--update_in_min",
|
284
|
-
"-u",
|
285
|
-
type=float,
|
286
|
-
help="update the plot every x minutes",
|
287
|
-
default=0.0,
|
288
|
-
)
|
289
|
-
def main(
|
290
|
-
csvfiles,
|
291
|
-
outdir,
|
292
|
-
tags,
|
293
|
-
log,
|
294
|
-
save_fig,
|
295
|
-
update_in_min,
|
296
|
-
):
|
297
|
-
plot_csv_files(list(csvfiles), outdir, list(tags), log, save_fig, update_in_min)
|
298
|
-
|
299
|
-
|
300
|
-
if __name__ == "__main__":
|
301
|
-
main()
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|