halib 0.1.85__py3-none-any.whl → 0.1.87__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- halib/filetype/csvfile.py +29 -0
- halib/research/base_exp.py +1 -0
- halib/research/mics.py +16 -0
- halib/research/plot.py +456 -261
- {halib-0.1.85.dist-info → halib-0.1.87.dist-info}/METADATA +2 -2
- {halib-0.1.85.dist-info → halib-0.1.87.dist-info}/RECORD +9 -8
- {halib-0.1.85.dist-info → halib-0.1.87.dist-info}/WHEEL +0 -0
- {halib-0.1.85.dist-info → halib-0.1.87.dist-info}/licenses/LICENSE.txt +0 -0
- {halib-0.1.85.dist-info → halib-0.1.87.dist-info}/top_level.txt +0 -0
halib/filetype/csvfile.py
CHANGED
@@ -9,6 +9,7 @@ from loguru import logger
 from itables import init_notebook_mode, show
 import pygwalker as pyg
 import textwrap
+import csv

 console = Console()

@@ -18,6 +19,34 @@ def read(file, separator=","):
     return df


+def read_auto_sep(filepath, sample_size=2048, **kwargs):
+    """
+    Read a CSV file with automatic delimiter detection.
+
+    Parameters
+    ----------
+    filepath : str
+        Path to the CSV file.
+    sample_size : int, optional
+        Number of bytes to read for delimiter sniffing.
+    **kwargs : dict
+        Extra keyword args passed to pandas.read_csv.
+
+    Returns
+    -------
+    df : pandas.DataFrame
+    """
+    with open(filepath, "r", newline="", encoding=kwargs.get("encoding", "utf-8")) as f:
+        sample = f.read(sample_size)
+        f.seek(0)
+    try:
+        dialect = csv.Sniffer().sniff(sample, delimiters=[",", ";", "\t", "|", ":"])
+        sep = dialect.delimiter
+    except csv.Error:
+        sep = ","  # fallback if detection fails
+
+    return pd.read_csv(filepath, sep=sep, **kwargs)
+
 # for append, mode = 'a'
 def fn_write(df, outfile, mode="w", header=True, index_label=None):
     if not outfile.endswith(".csv"):
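As a quick illustration, the new reader loads a delimited file without the caller naming the separator; a minimal sketch, with a hypothetical file path:

from halib.filetype import csvfile

# The delimiter is sniffed from the first 2048 bytes, so the same call
# handles ",", ";", "\t", "|" and ":" files alike.
df = csvfile.read_auto_sep("metrics.csv")

# Extra keyword args are forwarded to pandas.read_csv.
df = csvfile.read_auto_sep("metrics.csv", usecols=["epoch", "train_loss"])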
halib/research/base_exp.py
CHANGED
@@ -82,6 +82,7 @@ class BaseExperiment(PerfCalc, ABC):
         """
         self.init_general(self.config.get_general_cfg())
         self.prepare_dataset(self.config.get_dataset_cfg())
+        self.prepare_metrics(self.config.get_metric_cfg())

         # Save config before running
         self.config.save_to_outdir()
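The added call hands the metric section of the config to the experiment during setup, alongside the dataset. A minimal sketch of the hook, assuming prepare_metrics is meant to be implemented by subclasses and that the metric config behaves like a mapping (neither is confirmed by this diff; other required overrides are omitted):

from halib.research.base_exp import BaseExperiment

class MyExperiment(BaseExperiment):
    def prepare_metrics(self, metric_cfg):
        # Hypothetical body: record the metric names listed in the config.
        self.metric_names = list(metric_cfg.get("names", []))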
halib/research/mics.py
ADDED
@@ -0,0 +1,16 @@
+import platform
+
+PC_NAME_TO_ABBR = {
+    "DESKTOP-JQD9K01": "MainPC",
+    "DESKTOP-5IRHU87": "MSI_Laptop",
+    "DESKTOP-96HQCNO": "4090_SV",
+    "DESKTOP-Q2IKLC0": "4GPU_SV",
+    "DESKTOP-QNS3DNF": "1GPU_SV"
+}
+
+def get_PC_name():
+    return platform.node()
+
+def get_PC_abbr_name():
+    pc_name = get_PC_name()
+    return PC_NAME_TO_ABBR.get(pc_name, "Unknown")
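Both helpers are trivial to call; hosts missing from the mapping fall back to "Unknown":

from halib.research import mics

print(mics.get_PC_name())       # platform.node(), e.g. "DESKTOP-JQD9K01"
print(mics.get_PC_abbr_name())  # "MainPC" for that host, "Unknown" otherwise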
halib/research/plot.py
CHANGED
@@ -1,28 +1,181 @@
+import os
+import pandas as pd
+import plotly.express as px
+from rich.console import Console
 from ..common import now_str, norm_str, ConsoleLog
 from ..filetype import csvfile
 from ..system import filesys as fs
-from functools import partial
-from rich.console import Console
-from rich.pretty import pprint
 import click
-import matplotlib
-
-import matplotlib.pyplot as plt
-import numpy as np
-import os
+import time
+
 import pandas as pd
-import seaborn as sns
+import plotly.graph_objects as go
+from PIL import Image
+import base64
+from io import BytesIO
+from typing import Callable, Optional, Tuple, List, Union


 console = Console()
 desktop_path = os.path.expanduser("~/Desktop")
-REQUIRED_COLUMNS = ["epoch", "train_loss", "val_loss", "train_acc", "val_acc"]
-
-import csv

+class PlotHelper:
+    def _verify_csv(self, csv_file):
+        """Read a CSV and normalize column names (lowercase)."""
+        try:
+            df = csvfile.read_auto_sep(csv_file)
+            df.columns = [col.lower() for col in df.columns]
+            return df
+        except FileNotFoundError:
+            raise FileNotFoundError(f"CSV file '{csv_file}' not found")
+        except Exception as e:
+            raise ValueError(f"Error reading CSV file '{csv_file}': {str(e)}")
+
+    @staticmethod
+    def _norm_str(s):
+        """Normalize string by converting to lowercase and replacing spaces/underscores."""
+        return s.lower().replace(" ", "_").replace("-", "_")
+
+    @staticmethod
+    def _get_file_name(file_path):
+        """Extract file name without extension."""
+        return os.path.splitext(os.path.basename(file_path))[0]
+
+    def _get_valid_tags(self, csv_files, tags):
+        """Generate tags from file names if not provided."""
+        if tags:
+            return list(tags)
+        return [self._norm_str(self._get_file_name(f)) for f in csv_files]
+
+    def _prepare_long_df(self, csv_files, tags, x_col, y_cols, log=False):
+        """Convert multiple CSVs into a single long-form dataframe for Plotly."""
+        dfs = []
+        for csv_file, tag in zip(csv_files, tags):
+            df = self._verify_csv(csv_file)
+            # Check columns
+            if x_col not in df.columns:
+                raise ValueError(f"{csv_file} is missing x_col '{x_col}'")
+            missing = [c for c in y_cols if c not in df.columns]
+            if missing:
+                raise ValueError(f"{csv_file} is missing y_cols {missing}")
+
+            if log:
+                console.log(f"Plotting {csv_file}")
+                console.print(df)
+
+            # Wide to long
+            df_long = df.melt(
+                id_vars=x_col,
+                value_vars=y_cols,
+                var_name="metric_type",
+                value_name="value",
+            )
+            df_long["tag"] = tag
+            dfs.append(df_long)
+
+        return pd.concat(dfs, ignore_index=True)
+
+    def _plot_with_plotly(
+        self,
+        df_long,
+        tags,
+        outdir,
+        save_fig,
+        out_fmt="svg",
+        font_size=16,
+        x_col="epoch",
+        y_cols=None,
+    ):
+        """Generate Plotly plots for given metrics."""
+        assert out_fmt in ["svg", "pdf", "png"], "Unsupported format"
+        if y_cols is None:
+            raise ValueError("y_cols must be provided")
+
+        # Group by suffix (e.g., "loss", "acc") if names like train_loss exist
+        metric_groups = sorted(set(col.split("_")[-1] for col in y_cols))
+
+        for metric in metric_groups:
+            subset = df_long[df_long["metric_type"].str.contains(metric)]
+
+            if out_fmt == "svg":  # LaTeX-style
+                title = f"${'+'.join(tags)}\\_{metric}\\text{{-by-{x_col}}}$"
+                xaxis_title = f"$\\text{{{x_col.capitalize()}}}$"
+                yaxis_title = f"${metric.capitalize()}$"
+            else:
+                title = f"{'+'.join(tags)}_{metric}-by-{x_col}"
+                xaxis_title = x_col.capitalize()
+                yaxis_title = metric.capitalize()
+
+            fig = px.line(
+                subset,
+                x=x_col,
+                y="value",
+                color="tag",
+                line_dash="metric_type",
+                title=title,
+            )
+            fig.update_layout(
+                font=dict(family="Computer Modern", size=font_size),
+                xaxis_title=xaxis_title,
+                yaxis_title=yaxis_title,
+            )
+            fig.show()
+
+            if save_fig:
+                os.makedirs(outdir, exist_ok=True)
+                timestamp = now_str()
+                filename = f"{timestamp}_{'+'.join(tags)}_{metric}"
+                try:
+                    fig.write_image(os.path.join(outdir, f"{filename}.{out_fmt}"))
+                except Exception as e:
+                    console.log(f"Error saving figure '{filename}.{out_fmt}': {str(e)}")
+
+    @classmethod
+    def plot_csv_timeseries(
+        cls,
+        csv_files,
+        outdir="./out/plot",
+        tags=None,
+        log=False,
+        save_fig=False,
+        update_in_min=0,
+        out_fmt="svg",
+        font_size=16,
+        x_col="epoch",
+        y_cols=["train_loss", "train_acc"],
+    ):
+        """Plot CSV files with Plotly, supporting live updates, as a class method."""
+        if isinstance(csv_files, str):
+            csv_files = [csv_files]
+        if isinstance(tags, str):
+            tags = [tags]
+
+        if not y_cols:
+            raise ValueError("You must specify y_cols explicitly")
+
+        # Instantiate PlotHelper to call instance methods
+        plot_helper = cls()
+        valid_tags = plot_helper._get_valid_tags(csv_files, tags)
+        assert len(valid_tags) == len(
+            csv_files
+        ), "Number of tags must match number of CSV files"
+
+        def run_once():
+            df_long = plot_helper._prepare_long_df(
+                csv_files, valid_tags, x_col, y_cols, log
+            )
+            plot_helper._plot_with_plotly(
+                df_long, valid_tags, outdir, save_fig, out_fmt, font_size, x_col, y_cols
+            )

-
-
-
-
-
+        if update_in_min > 0:
+            interval = int(update_in_min * 60)
+            console.log(f"Live update every {interval}s. Press Ctrl+C to stop.")
+            try:
+                while True:
+                    run_once()
+                    time.sleep(interval)
+            except KeyboardInterrupt:
+                console.log("Stopped live updates.")
+        else:
+            run_once()
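A sketch of calling the new class method directly (CSV paths hypothetical): with update_in_min left at 0 it plots once; otherwise it re-reads and re-plots the files on a timer until interrupted.

from halib.research.plot import PlotHelper

# One-shot comparison of two runs; tags default to normalized file names.
PlotHelper.plot_csv_timeseries(
    csv_files=["runA.csv", "runB.csv"],
    y_cols=["train_loss", "val_loss"],
    save_fig=True,
    out_fmt="png",
)

# Live mode: re-plot every 2 minutes until Ctrl+C.
PlotHelper.plot_csv_timeseries("runA.csv", y_cols=["train_loss"], update_in_min=2)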
@@ -29,54 +182,75 @@
+    @staticmethod
+    def plot_image_grid(csv_path, sep=";", max_width=300, max_height=300):
+        """
+        Plot a grid of images using Plotly from a CSV file.
+
+        Args:
+            csv_path (str): Path to CSV file.
+            max_width (int): Maximum width of each image in pixels.
+            max_height (int): Maximum height of each image in pixels.
+        """
+        # Load CSV
+        df = csvfile.read_auto_sep(csv_path, sep=sep)
+
+        # Column names for headers
+        col_names = df.columns.tolist()
+
+        # Function to convert image to base64
+        def pil_to_base64(img_path):
+            with Image.open(img_path) as im:
+                im.thumbnail((max_width, max_height))
+                buffer = BytesIO()
+                im.save(buffer, format="PNG")
+                encoded = base64.b64encode(buffer.getvalue()).decode()
+                return "data:image/png;base64," + encoded
+
+        # Initialize figure
+        fig = go.Figure()
+
+        n_rows = len(df)
+        n_cols = len(df.columns) - 1  # skip label column
+
+        # Add images
+        for i, row in df.iterrows():
+            for j, col in enumerate(df.columns[1:]):
+                img_path = row[col]
+                img_src = pil_to_base64(img_path)
+                fig.add_layout_image(
+                    dict(
+                        source=img_src,
+                        x=j,
+                        y=-i,  # negative to have row 0 on top
+                        xref="x",
+                        yref="y",
+                        sizex=1,
+                        sizey=1,
+                        xanchor="left",
+                        yanchor="top",
+                        layer="above"
+                    )
+                )
+
+        # Set axes for grid layout
+        fig.update_xaxes(
+            tickvals=list(range(n_cols)),
+            ticktext=list(df.columns[1:]),
+            range=[-0.5, n_cols-0.5],
+            showgrid=False,
+            zeroline=False
+        )
+        fig.update_yaxes(
+            tickvals=[-i for i in range(n_rows)],
+            ticktext=df[df.columns[0]],
+            range=[-n_rows + 0.5, 0.5],
+            showgrid=False,
+            zeroline=False
+        )

+        fig.update_layout(
+            width=max_width*n_cols,
+            height=max_height*n_rows,
+            margin=dict(l=100, r=20, t=50, b=50)
+        )

-
-
-
-
-
-
-
-
-
-
-        )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    df = verify_csv(csv_file)
+        fig.show()
+
@@ -83,10 +257,53 @@
+    @staticmethod
+    # this plot_df contains the data to be plotted (row, column)
+    def img_grid_df(input_dir, log=False):
+        rows = fs.list_dirs(input_dir)
+        rows = [r for r in rows if r.startswith("row")]
+        meta_dict = {}
+        cols_of_row = None
+        for row in rows:
+            row_path = os.path.join(input_dir, row)
+            cols = sorted(fs.list_dirs(row_path))
+            if cols_of_row is None:
+                cols_of_row = cols
+            else:
+                if cols_of_row != cols:
+                    raise ValueError(
+                        f"Row {row} has different columns than previous rows: {cols_of_row} vs {cols}"
+                    )
+            meta_dict[row] = cols
+
+        meta_dict_with_paths = {}
+        for row, cols in meta_dict.items():
+            meta_dict_with_paths[row] = {
+                col: fs.filter_files_by_extension(
+                    os.path.join(input_dir, row, col), ["png", "jpg", "jpeg"]
+                )
+                for col in cols
+            }
+        first_row = list(meta_dict_with_paths.keys())[0]
+        first_col = list(meta_dict_with_paths[first_row].keys())[0]
+        len_first_col = len(meta_dict_with_paths[first_row][first_col])
+        for row, cols in meta_dict_with_paths.items():
+            for col, paths in cols.items():
+                if len(paths) != len_first_col:
+                    raise ValueError(
+                        f"Row {row}, Column {col} has different number of files: {len(paths)} vs {len_first_col}"
+                    )
+        cols = sorted(meta_dict_with_paths[first_row].keys())
+        rows_set = sorted(meta_dict_with_paths.keys())
+        row_per_col = len(meta_dict_with_paths[first_row][first_col])
+        rows = [item for item in rows_set for _ in range(row_per_col)]
+        data_dict = {}
+        data_dict["row"] = rows
+        col_data = {col: [] for col in cols}
+        for row_base in rows_set:
+            for col in cols:
+                for i in range(row_per_col):
+                    col_data[col].append(meta_dict_with_paths[row_base][col][i])
+        data_dict.update(col_data)
+        df = pd.DataFrame(data_dict)
         if log:
-
-
-
-
-
-
-
-
-
+            csvfile.fn_display_df(df)
+        return df
+
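img_grid_df builds its DataFrame from a directory of row* folders that must all contain the same column subfolders with equal image counts; a sketch under that layout (directory names hypothetical):

from halib.research.plot import PlotHelper

# Expected layout:
#   grids/row_baseline/colA/*.png   grids/row_baseline/colB/*.png
#   grids/row_ours/colA/*.png       grids/row_ours/colB/*.png
df = PlotHelper.img_grid_df("grids", log=True)
# Result: a "row" column plus one image-path column per subfolder,
# with one DataFrame row per image index inside each row* folder.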
@@ -93,174 +310,144 @@
+    @staticmethod
+    def plot_image_grid(
+        csv_file_or_df: Union[str, pd.DataFrame],
+        max_width: int = 300,
+        max_height: int = 300,
+        img_stack_direction: str = "horizontal",
+        img_stack_padding_px: int = 10,
+        format_row_label_func: Optional[Callable[[str], str]] = None,
+        format_col_label_func: Optional[Callable[[str, str], str]] = None,
+        title: str = "",
+    ):
+        """
+        Plot a grid of images using Plotly from a DataFrame.
+
+        Args:
+            df (pd.DataFrame): DataFrame with first column as row labels, remaining columns as image paths.
+            max_width (int): Maximum width of stacked images per cell in pixels.
+            max_height (int): Maximum height of stacked images per cell in pixels.
+            img_stack_direction (str): "horizontal" or "vertical" stacking.
+            img_stack_padding_px (int): Padding between stacked images in pixels.
+            format_row_label_func (Callable): Function to format row labels.
+            format_col_label_func (Callable): Function to format column labels.
+            title (str): Figure title.
+        """
+
+        def stack_images_base64(
+            image_paths: List[str], direction: str, target_size: Tuple[int, int]
+        ) -> str:
+            """Stack images and return base64-encoded PNG."""
+            if not image_paths:
+                return ""
+
+            processed_images = []
+            for path in image_paths:
+                try:
+                    img = Image.open(path).convert("RGB")
+                    img.thumbnail(target_size, Image.Resampling.LANCZOS)
+                    processed_images.append(img)
+                except:
+                    # blank image if error
+                    processed_images.append(Image.new("RGB", target_size, (255, 255, 255)))
+
+            # Stack
+            widths, heights = zip(*(img.size for img in processed_images))
+            if direction == "horizontal":
+                total_width = sum(widths) + img_stack_padding_px * (
+                    len(processed_images) - 1
+                )
+                total_height = max(heights)
+                stacked = Image.new("RGB", (total_width, total_height), (255, 255, 255))
+                x_offset = 0
+                for im in processed_images:
+                    stacked.paste(im, (x_offset, 0))
+                    x_offset += im.width + img_stack_padding_px
+            elif direction == "vertical":
+                total_width = max(widths)
+                total_height = sum(heights) + img_stack_padding_px * (
+                    len(processed_images) - 1
+                )
+                stacked = Image.new("RGB", (total_width, total_height), (255, 255, 255))
+                y_offset = 0
+                for im in processed_images:
+                    stacked.paste(im, (0, y_offset))
+                    y_offset += im.height + img_stack_padding_px
+            else:
+                raise ValueError("img_stack_direction must be 'horizontal' or 'vertical'")
+
+            # Encode as base64
+            buffer = BytesIO()
+            stacked.save(buffer, format="PNG")
+            encoded = base64.b64encode(buffer.getvalue()).decode()
+            return "data:image/png;base64," + encoded
+
+        # Load DataFrame if a file path is provided
+        if isinstance(csv_file_or_df, str):
+            df = csvfile.read_auto_sep(csv_file_or_df)
+        else:
+            df = csv_file_or_df
+        assert isinstance(df, pd.DataFrame), "Input must be a DataFrame or valid CSV file path"
+
+        rows = df[df.columns[0]].tolist()
+        columns = df.columns[1:].tolist()
+        n_rows, n_cols = len(rows), len(columns)
+
+        fig = go.Figure()
+
+        for i, row_label in enumerate(rows):
+            for j, col_label in enumerate(columns):
+                image_paths = df.loc[i, col_label]
+                if isinstance(image_paths, str):
+                    image_paths = [image_paths]
+                img_src = stack_images_base64(
+                    image_paths, img_stack_direction, (max_width, max_height)
+                )
+
+                fig.add_layout_image(
+                    dict(
+                        source=img_src,
+                        x=j,
+                        y=-i,  # negative so row 0 on top
+                        xref="x",
+                        yref="y",
+                        sizex=1,
+                        sizey=1,
+                        xanchor="left",
+                        yanchor="top",
+                        layer="above",
+                    )
+                )
+
+        # Format axis labels
+        col_labels = [
+            format_col_label_func(c, pattern="___") if format_col_label_func else c
+            for c in columns
         ]
-
-
-    for df_item in ls_df[1:]:
-        df_combined = pd.merge(df_combined, df_item, on="epoch", how="outer")
-    # csvfile.fn_display_df(df_combined)
-
-    for i, metric in enumerate(ls_metrics):
-        tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
-        title = f"{tags_str}_{metric}-by-epoch"
-        cols = [col for col in df_combined.columns if col != "epoch" and metric in col]
-        cols = sorted(cols)
-        # pprint(cols)
-        plot_data = df_combined[cols]
-
-        # line from same csv file (same tag) should have the same marker
-        all_markers = [
-            marker for marker in plt.Line2D.markers if marker and marker != " "
+        row_labels = [
+            format_row_label_func(r) if format_row_label_func else r for r in rows
         ]
-        tag2marker = {tag: marker for tag, marker in zip(valid_tags, all_markers)}
-        plot_markers = []
-        for col in cols:
-            # find the tag:
-            tag = None
-            for valid_tag in valid_tags:
-                if valid_tag in col:
-                    tag = valid_tag
-                    break
-            plot_markers.append(tag2marker[tag])
-        # pprint(list(zip(cols, plot_markers)))
-
-        # create color
-        sequential_palettes = [
-            "Reds",
-            "Greens",
-            "Blues",
-            "Oranges",
-            "Purples",
-            "Greys",
-            "BuGn",
-            "BuPu",
-            "GnBu",
-            "OrRd",
-            "PuBu",
-            "PuRd",
-            "RdPu",
-            "YlGn",
-            "PuBuGn",
-            "YlGnBu",
-            "YlOrBr",
-            "YlOrRd",
-        ]
-        # each csvfile (tag) should have a unique color
-        tag2palette = {
-            tag: palette for tag, palette in zip(valid_tags, sequential_palettes)
-        }
-        plot_colors = []
-        for tag in valid_tags:
-            palette = tag2palette[tag]
-            total_colors = 10
-            ls_colors = sns.color_palette(palette, total_colors).as_hex()
-            num_part = len(ls_metrics)
-            subarr = np.array_split(np.arange(total_colors), num_part)
-            for idx, col in enumerate(cols):
-                if tag in col:
-                    chosen_color = ls_colors[
-                        subarr[int(idx % num_part)].mean().astype(int)
-                    ]
-                    plot_colors.append(chosen_color)
-
-        # pprint(list(zip(cols, plot_colors)))
-        sns.lineplot(
-            data=plot_data,
-            markers=plot_markers,
-            palette=plot_colors,
-            ax=axes[i],
-            dashes=False,
-        )
-        axes[i].set(xlabel="epoch", ylabel=metric, title=title)
-        axes[i].legend()
-        axes[i].grid()

-
-
-
-
-
-
-    if log:
-        with ConsoleLog(f"plotting {csv_file}"):
-            csvfile.fn_display_df(df)
-    ls_df.append(df)
-
-    metric_values = ["loss", "acc"]
-    for i, metric in enumerate(metric_values):
-        for df_item, tag in zip(ls_df, valid_tags):
-            metric_ax = plot_ax(df_item, axes[i], metric, tag)
-
-    # set the title, xlabel, ylabel, legend, and grid
-    tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
-    metric_ax.set(
-        xlabel="epoch", ylabel=metric, title=f"{tags_str}_{metric}-by-epoch"
+        fig.update_xaxes(
+            tickvals=list(range(n_cols)),
+            ticktext=col_labels,
+            range=[-0.5, n_cols - 0.5],
+            showgrid=False,
+            zeroline=False,
         )
-
-
-
-
-
-
-    outdir="./out/plot",
-    tags=None,
-    log=False,
-    save_fig=False,
-    update_in_min=1,
-):
-    # if csv_files is a string, convert it to a list
-    if isinstance(csv_files, str):
-        csv_files = [csv_files]
-    # if tags is a string, convert it to a list
-    if isinstance(tags, str):
-        tags = [tags]
-    valid_tags = get_valid_tags(csv_files, tags)
-    assert len(valid_tags) == len(
-        csv_files
-    ), "Unable to determine tags for each csv file"
-    live_update_in_ms = int(update_in_min * 60 * 1000)
-    fig, axes = plt.subplots(2, 1, figsize=(10, 17))
-    if live_update_in_ms:  # live update in min should be > 0
-        from matplotlib.animation import FuncAnimation
-
-        anim = FuncAnimation(
-            fig,
-            partial(
-                actual_plot_seaborn, csv_files=csv_files, axes=axes, tags=tags, log=log
-            ),
-            interval=live_update_in_ms,
-            blit=False,
-            cache_frame_data=False,
+        fig.update_yaxes(
+            tickvals=[-i for i in range(n_rows)],
+            ticktext=row_labels,
+            range=[-n_rows + 0.5, 0.5],
+            showgrid=False,
+            zeroline=False,
         )
-        plt.show()
-    else:
-        actual_plot_seaborn(None, csv_files, axes, tags, log)
-        plt.show()
-
-    if save_fig:
-        os.makedirs(outdir, exist_ok=True)
-        tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
-        tag = f"{now_str()}_{tags_str}"
-        fig.savefig(f"{outdir}/{tag}_plot.png")
-        enable_plot_pgf()
-        fig.savefig(f"{outdir}/{tag}_plot.pdf")
-    if live_update_in_ms:
-        return anim
-
-
-def enable_plot_pgf():
-    matplotlib.use("pdf")
-    matplotlib.rcParams.update(
-        {
-            "pgf.texsystem": "pdflatex",
-            "font.family": "serif",
-            "text.usetex": True,
-            "pgf.rcfonts": False,
-        }
-    )

+        fig.update_layout(
+            width=max_width * n_cols + 200,  # extra for labels
+            height=max_height * n_rows + 100,
+            title=title,
+            margin=dict(l=100, r=20, t=50, b=50),
+        )

-
-    enable_plot_pgf()
-    if ".pgf" not in filename:
-        filename = f"{directory}/{filename}.pgf"
-    plt.savefig(filename)
+        fig.show()

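Because the class defines plot_image_grid twice, this later definition is the one bound on the class at runtime. A usage sketch feeding it the img_grid_df output (paths hypothetical); note the column formatter is invoked as format_col_label_func(c, pattern="___"), so any callable passed there must accept a pattern keyword:

from halib.research.plot import PlotHelper

df = PlotHelper.img_grid_df("grids")
PlotHelper.plot_image_grid(
    df,
    img_stack_direction="vertical",
    format_row_label_func=lambda r: r.removeprefix("row_"),
    format_col_label_func=lambda c, pattern="___": c.split(pattern)[0],
    title="Qualitative comparison",
)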
@@ -267,34 +454,42 @@
-# https://click.palletsprojects.com/en/8.1.x/api/
 @click.command()
 @click.option("--csvfiles", "-f", multiple=True, type=str, help="csv files to plot")
 @click.option(
-    "--outdir",
-    "-o",
-    type=str,
-    help="output directory for the plot",
-    default=str(desktop_path),
+    "--outdir", "-o", type=str, default=str(desktop_path), help="output directory"
 )
 @click.option(
-    "--tags", "-t", multiple=True, type=str, help="tags for the csv files"
+    "--tags", "-t", multiple=True, type=str, default=[], help="tags for the csv files"
 )
 @click.option("--log", "-l", is_flag=True, help="log the csv files")
-@click.option("--save_fig", "-s", is_flag=True, help="save the plot as …
+@click.option("--save_fig", "-s", is_flag=True, help="save the plot as file")
 @click.option(
     "--update_in_min",
     "-u",
     type=float,
-    help="update the plot every x minutes",
     default=0.0,
+    help="update the plot every x minutes",
+)
+@click.option(
+    "--x_col", "-x", type=str, default="epoch", help="column to use as x-axis"
 )
-
-
-
-
-
-
-
-)
-
+@click.option(
+    "--y_cols",
+    "-y",
+    multiple=True,
+    type=str,
+    required=True,
+    help="columns to plot as y (can repeat)",
+)
+def main(csvfiles, outdir, tags, log, save_fig, update_in_min, x_col, y_cols):
+    PlotHelper.plot_csv_timeseries(
+        list(csvfiles),
+        outdir,
+        list(tags),
+        log,
+        save_fig,
+        update_in_min,
+        x_col=x_col,
+        y_cols=list(y_cols),
+    )


 if __name__ == "__main__":
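The rewritten CLI now requires at least one -y/--y_cols value; one way to exercise it without a shell is click's test runner (CSV path hypothetical):

from click.testing import CliRunner
from halib.research.plot import main

runner = CliRunner()
result = runner.invoke(
    main, ["-f", "runA.csv", "-y", "train_loss", "-y", "val_loss", "-x", "epoch"]
)
print(result.exit_code, result.output)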
{halib-0.1.85.dist-info → halib-0.1.87.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: halib
-Version: 0.1.85
+Version: 0.1.87
 Summary: Small library for common tasks
 Author: Hoang Van Ha
 Author-email: hoangvanhauit@gmail.com
@@ -52,7 +52,7 @@ Dynamic: summary

 Helper package for coding and automation

-**Version 0.1.85**
+**Version 0.1.87**

 + `research/profiler`: add `zProfiler` class to measure execution time of contexts and steps, with support for dynamic color scales in plots.

{halib-0.1.85.dist-info → halib-0.1.87.dist-info}/RECORD
CHANGED
@@ -17,7 +17,7 @@ halib/textfile.py,sha256=EhVFrit-nRBJx18e6rtIqcE1cSbgsLnMXe_kdhi1EPI,399
 halib/torchloader.py,sha256=-q9YE-AoHZE1xQX2dgNxdqtucEXYs4sQ22WXdl6EGfI,6500
 halib/videofile.py,sha256=NTLTZ-j6YD47duw2LN2p-lDQDglYFP1LpEU_0gzHLdI,4737
 halib/filetype/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-halib/filetype/csvfile.py,sha256=…
+halib/filetype/csvfile.py,sha256=4Klf8YNzY1MaCD3o5Wp5GG3KMfQIBOEVzHV_7DO5XBo,6604
 halib/filetype/jsonfile.py,sha256=9LBdM7LV9QgJA1bzJRkq69qpWOP22HDXPGirqXTgSCw,480
 halib/filetype/textfile.py,sha256=QtuI5PdLxu4hAqSeafr3S8vCXwtvgipWV4Nkl7AzDYM,399
 halib/filetype/videofile.py,sha256=4nfVAYYtoT76y8P4WYyxNna4Iv1o2iV6xaMcUzNPC4s,4736
@@ -29,13 +29,14 @@ halib/online/gdrive_test.py,sha256=hMWzz4RqZwETHp4GG4WwVNFfYvFQhp2Boz5t-DqwMo0,1
 halib/online/projectmake.py,sha256=Zrs96WgXvO4nIrwxnCOletL4aTBge-EoF0r7hpKO1w8,4034
 halib/research/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 halib/research/base_config.py,sha256=AIjVzl2ZJ9b8yIGb2X5EZwLmyGJ_9wNWqrib1nU3Wj0,2831
-halib/research/base_exp.py,sha256=…
+halib/research/base_exp.py,sha256=hiO2flt_I0iJJ4bWcQwyh2ISezoC8t2k3PtxHeVr0eI,3278
 halib/research/benchquery.py,sha256=FuKnbWQtCEoRRtJAfN-zaN-jPiO_EzsakmTOMiqi7GQ,4626
 halib/research/dataset.py,sha256=QU0Hr5QFb8_XlvnOMgC9QJGIpwXAZ9lDd0RdQi_QRec,6743
 halib/research/metrics.py,sha256=Xgv0GUGo-o-RJaBOmkRCRpQJaYijF_1xeKkyYU_Bv4U,5249
+halib/research/mics.py,sha256=uX17AGrBGER-OFMqUULE_A9YPPbn1RpQ4o5-omrmqZ8,377
 halib/research/perfcalc.py,sha256=qDa0sqfpWrwGZVJtjuUVFK7JX6j8xyXP9OnnfYmdamg,15898
 halib/research/perftb.py,sha256=FWg0b8wSgy4UwuvHSXwEqvTq1Rhi-z-HtAKuQg1lWc4,30989
-halib/research/plot.py,sha256=…
+halib/research/plot.py,sha256=A3di1HZhIHIKf7d9b-I68yu_cm4u2LpHoPKlirCaNOI,17956
 halib/research/profiler.py,sha256=a5ndHzVCatmHIBm4Z2e7F41hJLCu-u3g97pq490mPrg,11770
 halib/research/torchloader.py,sha256=yqUjcSiME6H5W210363HyRUrOi3ISpUFAFkTr1w4DCw,6503
 halib/research/wandb_op.py,sha256=YzLEqME5kIRxi3VvjFkW83wnFrsn92oYeqYuNwtYRkY,4188
@@ -52,8 +53,8 @@ halib/utils/gpu_mon.py,sha256=vD41_ZnmPLKguuq9X44SB_vwd9JrblO4BDzHLXZhhFY,2233
 halib/utils/listop.py,sha256=Vpa8_2fI0wySpB2-8sfTBkyi_A4FhoFVVvFiuvW8N64,339
 halib/utils/tele_noti.py,sha256=-4WXZelCA4W9BroapkRyIdUu9cUVrcJJhegnMs_WpGU,5928
 halib/utils/video.py,sha256=ZqzNVPgc1RZr_T0OlHvZ6SzyBpL7O27LtB86JMbBuR0,3059
-halib-0.1.85.dist-info/licenses/LICENSE.txt,sha256=…
-halib-0.1.85.dist-info/METADATA,sha256=…
-halib-0.1.85.dist-info/WHEEL,sha256=…
-halib-0.1.85.dist-info/top_level.txt,sha256=…
-halib-0.1.85.dist-info/RECORD,,
+halib-0.1.87.dist-info/licenses/LICENSE.txt,sha256=qZssdna4aETiR8znYsShUjidu-U4jUT9Q-EWNlZ9yBQ,1100
+halib-0.1.87.dist-info/METADATA,sha256=EEx-vMyPJvloMv-e87rnnlWS1mBwdAz9oBP6TZQIzOY,5864
+halib-0.1.87.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+halib-0.1.87.dist-info/top_level.txt,sha256=7AD6PLaQTreE0Fn44mdZsoHBe_Zdd7GUmjsWPyQ7I-k,6
+halib-0.1.87.dist-info/RECORD,,
{halib-0.1.85.dist-info → halib-0.1.87.dist-info}/WHEEL
File without changes
{halib-0.1.85.dist-info → halib-0.1.87.dist-info}/licenses/LICENSE.txt
File without changes
{halib-0.1.85.dist-info → halib-0.1.87.dist-info}/top_level.txt
File without changes