halib 0.1.84__py3-none-any.whl → 0.1.86__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
halib/filetype/csvfile.py CHANGED
@@ -9,6 +9,7 @@ from loguru import logger
9
9
  from itables import init_notebook_mode, show
10
10
  import pygwalker as pyg
11
11
  import textwrap
12
+ import csv
12
13
 
13
14
  console = Console()
14
15
 
@@ -18,6 +19,34 @@ def read(file, separator=","):
18
19
  return df
19
20
 
20
21
 
22
def read_auto_sep(filepath, sample_size=2048, **kwargs):
    """
    Read a CSV file with automatic delimiter detection.

    The delimiter is sniffed from the beginning of the file unless the
    caller already supplies ``sep`` or ``delimiter`` in ``kwargs``; an
    explicit separator always wins and is forwarded to pandas untouched.

    Parameters
    ----------
    filepath : str
        Path to the CSV file.
    sample_size : int, optional
        Number of characters to read for delimiter sniffing (the file is
        opened in text mode).
    **kwargs : dict
        Extra keyword args passed to pandas.read_csv. ``encoding`` is also
        used to open the file for sniffing (default ``"utf-8"``).

    Returns
    -------
    df : pandas.DataFrame
    """
    # Passing our sniffed ``sep`` together with a caller-supplied one would
    # raise "got multiple values for keyword argument 'sep'", so honour the
    # explicit choice and skip sniffing entirely.
    if kwargs.get("sep") is not None or kwargs.get("delimiter") is not None:
        return pd.read_csv(filepath, **kwargs)
    # A literal ``sep=None`` carries no information; drop it so it cannot
    # collide with the sniffed separator below.
    kwargs.pop("sep", None)

    with open(
        filepath, "r", newline="", encoding=kwargs.get("encoding", "utf-8")
    ) as f:
        sample = f.read(sample_size)

    try:
        # ``delimiters`` is the set of candidate characters, given as a string.
        sep = csv.Sniffer().sniff(sample, delimiters=",;\t|:").delimiter
    except csv.Error:
        sep = ","  # fallback if detection fails

    return pd.read_csv(filepath, sep=sep, **kwargs)
49
+
21
50
  # for append, mode = 'a'
22
51
  def fn_write(df, outfile, mode="w", header=True, index_label=None):
23
52
  if not outfile.endswith(".csv"):
halib/research/mics.py ADDED
@@ -0,0 +1,16 @@
1
+ import platform
2
+
3
+ PC_NAME_TO_ABBR = {
4
+ "DESKTOP-JQD9K01": "MainPC",
5
+ "DESKTOP-5IRHU87": "MSI_Laptop",
6
+ "DESKTOP-96HQCNO": "4090_SV",
7
+ "DESKTOP-Q2IKLC0": "4GPU_SV",
8
+ "DESKTOP-QNS3DNF": "1GPU_SV"
9
+ }
10
+
11
def get_PC_name():
    """Return this machine's network name as reported by ``platform.node()``."""
    node_name = platform.node()
    return node_name
13
+
14
def get_PC_abbr_name():
    """Return the short alias registered in ``PC_NAME_TO_ABBR`` for this machine.

    Falls back to the string "Unknown" when the current machine name has no
    entry in the mapping.
    """
    return PC_NAME_TO_ABBR.get(get_PC_name(), "Unknown")
halib/research/plot.py CHANGED
@@ -1,300 +1,495 @@
1
+ import os
2
+ import pandas as pd
3
+ import plotly.express as px
4
+ from rich.console import Console
1
5
  from ..common import now_str, norm_str, ConsoleLog
2
6
  from ..filetype import csvfile
3
7
  from ..system import filesys as fs
4
- from functools import partial
5
- from rich.console import Console
6
- from rich.pretty import pprint
7
8
  import click
8
- import csv
9
- import matplotlib
10
- import matplotlib.pyplot as plt
11
- import numpy as np
12
- import os
9
+ import time
10
+
13
11
  import pandas as pd
14
- import seaborn as sns
12
+ import plotly.graph_objects as go
13
+ from PIL import Image
14
+ import base64
15
+ from io import BytesIO
16
+ from typing import Callable, Optional, Tuple, List, Union
15
17
 
16
18
 
17
19
  console = Console()
18
20
  desktop_path = os.path.expanduser("~/Desktop")
19
- REQUIRED_COLUMNS = ["epoch", "train_loss", "val_loss", "train_acc", "val_acc"]
20
-
21
- import csv
22
21
 
22
+ class PlotHelper:
23
def _verify_csv(self, csv_file):
    """Read a CSV and normalize column names (lowercase).

    Args:
        csv_file (str): Path to the CSV file; the delimiter is detected by
            ``csvfile.read_auto_sep``.

    Returns:
        pd.DataFrame: The loaded frame with every column name lowercased.

    Raises:
        FileNotFoundError: If ``csv_file`` does not exist.
        ValueError: If reading the file or normalizing its columns fails for
            any other reason. The original exception is chained as the cause.
    """
    try:
        df = csvfile.read_auto_sep(csv_file)
        df.columns = [col.lower() for col in df.columns]
        return df
    except FileNotFoundError as e:
        # Chain the cause so the underlying OS error stays in the traceback.
        raise FileNotFoundError(f"CSV file '{csv_file}' not found") from e
    except Exception as e:
        raise ValueError(f"Error reading CSV file '{csv_file}': {str(e)}") from e
33
+
34
@staticmethod
def _norm_str(s):
    """Lowercase ``s`` and turn every space or hyphen into an underscore."""
    lowered = s.lower()
    return lowered.translate(str.maketrans({" ": "_", "-": "_"}))
38
+
39
@staticmethod
def _get_file_name(file_path):
    """Return the last path component of ``file_path`` without its final extension."""
    base_name = os.path.basename(file_path)
    stem, _ext = os.path.splitext(base_name)
    return stem
43
+
44
def _get_valid_tags(self, csv_files, tags):
    """Return the caller's tags as a list, or derive one tag per CSV file name.

    When ``tags`` is empty or None, each tag is the normalized file name
    (no directory, no extension) of the corresponding entry in ``csv_files``.
    """
    if not tags:
        derived = []
        for path in csv_files:
            derived.append(self._norm_str(self._get_file_name(path)))
        return derived
    return list(tags)
49
+
50
def _prepare_long_df(self, csv_files, tags, x_col, y_cols, log=False):
    """Load each CSV and stack them into one long-form frame for Plotly.

    Every file is read through ``self._verify_csv``, checked for ``x_col`` and
    all of ``y_cols``, melted to the columns ``x_col`` / "metric_type" /
    "value", and labelled with its tag in a "tag" column.

    Raises:
        ValueError: If a file lacks ``x_col`` or any of ``y_cols``.
    """
    long_frames = []
    for path, label in zip(csv_files, tags):
        wide = self._verify_csv(path)
        available = wide.columns

        # Validate the requested columns before reshaping.
        if x_col not in available:
            raise ValueError(f"{path} is missing x_col '{x_col}'")
        absent = [name for name in y_cols if name not in available]
        if absent:
            raise ValueError(f"{path} is missing y_cols {absent}")

        if log:
            console.log(f"Plotting {path}")
            console.print(wide)

        # Wide -> long: one row per (x value, metric).
        melted = wide.melt(
            id_vars=x_col,
            value_vars=y_cols,
            var_name="metric_type",
            value_name="value",
        )
        melted["tag"] = label
        long_frames.append(melted)

    return pd.concat(long_frames, ignore_index=True)
77
+
78
def _plot_with_plotly(
    self,
    df_long,
    tags,
    outdir,
    save_fig,
    out_fmt="svg",
    font_size=16,
    x_col="epoch",
    y_cols=None,
):
    """Generate Plotly plots for given metrics.

    One figure is produced per metric group, where a group is the part of a
    y-column name after its last underscore (e.g. "train_loss" and
    "val_loss" both belong to "loss"). Each figure is shown and, when
    ``save_fig`` is true, written to ``outdir``.

    Args:
        df_long (pd.DataFrame): Long-form data with the columns ``x_col``,
            "metric_type", "value" and "tag".
        tags (list[str]): Tags joined with "+" for titles and file names.
        outdir (str): Directory for saved figures (created if missing).
        save_fig (bool): Whether to write each figure to disk.
        out_fmt (str): One of "svg", "pdf", "png"; "svg" uses LaTeX-style
            titles and axis labels.
        font_size (int): Layout font size.
        x_col (str): Column used for the x-axis.
        y_cols (list[str]): Metric columns that were melted into ``df_long``.

    Raises:
        ValueError: If ``y_cols`` is None.
    """
    assert out_fmt in ["svg", "pdf", "png"], "Unsupported format"
    if y_cols is None:
        raise ValueError("y_cols must be provided")

    # Group by suffix (e.g., "loss", "acc") if names like train_loss exist
    metric_groups = sorted(set(col.split("_")[-1] for col in y_cols))

    # Suffix of each row's metric name, computed once with the same rule used
    # to build the groups. Selecting rows by exact suffix (instead of a regex
    # substring search) keeps rows of one metric out of another group's plot
    # when one suffix happens to occur inside another metric's name.
    row_suffix = df_long["metric_type"].str.split("_").str[-1]
    tags_str = "+".join(tags)

    for metric in metric_groups:
        subset = df_long[row_suffix == metric]

        if out_fmt == "svg":  # LaTeX-style
            title = f"${tags_str}\\_{metric}\\text{{-by-{x_col}}}$"
            xaxis_title = f"$\\text{{{x_col.capitalize()}}}$"
            yaxis_title = f"${metric.capitalize()}$"
        else:
            title = f"{tags_str}_{metric}-by-{x_col}"
            xaxis_title = x_col.capitalize()
            yaxis_title = metric.capitalize()

        fig = px.line(
            subset,
            x=x_col,
            y="value",
            color="tag",
            line_dash="metric_type",
            title=title,
        )
        fig.update_layout(
            font=dict(family="Computer Modern", size=font_size),
            xaxis_title=xaxis_title,
            yaxis_title=yaxis_title,
        )
        fig.show()

        if save_fig:
            os.makedirs(outdir, exist_ok=True)
            timestamp = now_str()
            filename = f"{timestamp}_{tags_str}_{metric}"
            try:
                fig.write_image(os.path.join(outdir, f"{filename}.{out_fmt}"))
            except Exception as e:
                # Best effort: a failed export is logged, not fatal, so the
                # remaining metric groups are still plotted.
                console.log(f"Error saving figure '{filename}.{out_fmt}': {str(e)}")
132
+
133
+ @classmethod
134
+ def plot_csv_timeseries(
135
+ cls,
136
+ csv_files,
137
+ outdir="./out/plot",
138
+ tags=None,
139
+ log=False,
140
+ save_fig=False,
141
+ update_in_min=0,
142
+ out_fmt="svg",
143
+ font_size=16,
144
+ x_col="epoch",
145
+ y_cols=["train_loss", "train_acc"],
146
+ ):
147
+ """Plot CSV files with Plotly, supporting live updates, as a class method."""
148
+ if isinstance(csv_files, str):
149
+ csv_files = [csv_files]
150
+ if isinstance(tags, str):
151
+ tags = [tags]
152
+
153
+ if not y_cols:
154
+ raise ValueError("You must specify y_cols explicitly")
155
+
156
+ # Instantiate PlotHelper to call instance methods
157
+ plot_helper = cls()
158
+ valid_tags = plot_helper._get_valid_tags(csv_files, tags)
159
+ assert len(valid_tags) == len(
160
+ csv_files
161
+ ), "Number of tags must match number of CSV files"
162
+
163
+ def run_once():
164
+ df_long = plot_helper._prepare_long_df(
165
+ csv_files, valid_tags, x_col, y_cols, log
166
+ )
167
+ plot_helper._plot_with_plotly(
168
+ df_long, valid_tags, outdir, save_fig, out_fmt, font_size, x_col, y_cols
169
+ )
23
170
 
24
- def get_delimiter(file_path, bytes=4096):
25
- sniffer = csv.Sniffer()
26
- data = open(file_path, "r").read(bytes)
27
- delimiter = sniffer.sniff(data).delimiter
28
- return delimiter
171
+ if update_in_min > 0:
172
+ interval = int(update_in_min * 60)
173
+ console.log(f"Live update every {interval}s. Press Ctrl+C to stop.")
174
+ try:
175
+ while True:
176
+ run_once()
177
+ time.sleep(interval)
178
+ except KeyboardInterrupt:
179
+ console.log("Stopped live updates.")
180
+ else:
181
+ run_once()
182
+ @staticmethod
183
+ def plot_image_grid(csv_path, sep=";", max_width=300, max_height=300):
184
+ """
185
+ Plot a grid of images using Plotly from a CSV file.
186
+
187
+ Args:
188
+ csv_path (str): Path to CSV file.
189
+ max_width (int): Maximum width of each image in pixels.
190
+ max_height (int): Maximum height of each image in pixels.
191
+ """
192
+ # Load CSV
193
+ df = csvfile.read_auto_sep(csv_path, sep=sep)
194
+
195
+ # Column names for headers
196
+ col_names = df.columns.tolist()
197
+
198
+ # Function to convert image to base64
199
+ def pil_to_base64(img_path):
200
+ with Image.open(img_path) as im:
201
+ im.thumbnail((max_width, max_height))
202
+ buffer = BytesIO()
203
+ im.save(buffer, format="PNG")
204
+ encoded = base64.b64encode(buffer.getvalue()).decode()
205
+ return "data:image/png;base64," + encoded
206
+
207
+ # Initialize figure
208
+ fig = go.Figure()
209
+
210
+ n_rows = len(df)
211
+ n_cols = len(df.columns) - 1 # skip label column
212
+
213
+ # Add images
214
+ for i, row in df.iterrows():
215
+ for j, col in enumerate(df.columns[1:]):
216
+ img_path = row[col]
217
+ img_src = pil_to_base64(img_path)
218
+ fig.add_layout_image(
219
+ dict(
220
+ source=img_src,
221
+ x=j,
222
+ y=-i, # negative to have row 0 on top
223
+ xref="x",
224
+ yref="y",
225
+ sizex=1,
226
+ sizey=1,
227
+ xanchor="left",
228
+ yanchor="top",
229
+ layer="above"
230
+ )
231
+ )
232
+
233
+ # Set axes for grid layout
234
+ fig.update_xaxes(
235
+ tickvals=list(range(n_cols)),
236
+ ticktext=list(df.columns[1:]),
237
+ range=[-0.5, n_cols-0.5],
238
+ showgrid=False,
239
+ zeroline=False
240
+ )
241
+ fig.update_yaxes(
242
+ tickvals=[-i for i in range(n_rows)],
243
+ ticktext=df[df.columns[0]],
244
+ range=[-n_rows + 0.5, 0.5],
245
+ showgrid=False,
246
+ zeroline=False
247
+ )
29
248
 
249
+ fig.update_layout(
250
+ width=max_width*n_cols,
251
+ height=max_height*n_rows,
252
+ margin=dict(l=100, r=20, t=50, b=50)
253
+ )
30
254
 
31
- # Function to verify that the DataFrame has the required columns, and only the required columns
32
- def verify_csv(csv_file, required_columns=REQUIRED_COLUMNS):
33
- delimiter = get_delimiter(csv_file)
34
- df = pd.read_csv(csv_file, sep=delimiter)
35
- # change the column names to lower case
36
- df.columns = [col.lower() for col in df.columns]
37
- for col in required_columns:
38
- if col not in df.columns:
39
- raise ValueError(
40
- f"Required columns are: {REQUIRED_COLUMNS}, but found {df.columns}"
41
- )
42
- df = df[required_columns].copy()
43
- return df
44
-
45
-
46
- def get_valid_tags(csv_files, tags):
47
- if tags is not None and len(tags) > 0:
48
- assert all(
49
- isinstance(tag, str) for tag in tags
50
- ), "tags must be a list of strings"
51
- assert all(
52
- len(tag) > 0 for tag in tags
53
- ), "tags must be a list of non-empty strings"
54
- valid_tags = tags
55
- else:
56
- valid_tags = []
57
- for csv_file in csv_files:
58
- file_name = fs.get_file_name(csv_file, split_file_ext=True)[0]
59
- tag = norm_str(file_name)
60
- valid_tags.append(tag)
61
- return valid_tags
62
-
63
-
64
- def plot_ax(df, ax, metric="loss", tag=""):
65
- pprint(locals())
66
- # reset plt
67
- assert metric in ["loss", "acc"], "metric must be either 'loss' or 'acc'"
68
- part = ["train", "val"]
69
- for p in part:
70
- label = f"{tag}_{p}_{metric}"
71
- ax.plot(df["epoch"], df[f"{p}_{metric}"], label=label)
72
- return ax
73
-
74
-
75
- def actual_plot_seaborn(frame, csv_files, axes, tags, log):
76
- # clear the axes
77
- for ax in axes:
78
- ax.clear()
79
- ls_df = []
80
- valid_tags = get_valid_tags(csv_files, tags)
81
- for csv_file in csv_files:
82
- df = verify_csv(csv_file)
255
+ fig.show()
256
+
257
+ @staticmethod
258
+ # this plot_df contains the data to be plotted (row, column)
259
+ def img_grid_df(input_dir, log=False):
260
+ rows = fs.list_dirs(input_dir)
261
+ rows = [r for r in rows if r.startswith("row")]
262
+ meta_dict = {}
263
+ cols_of_row = None
264
+ for row in rows:
265
+ row_path = os.path.join(input_dir, row)
266
+ cols = sorted(fs.list_dirs(row_path))
267
+ if cols_of_row is None:
268
+ cols_of_row = cols
269
+ else:
270
+ if cols_of_row != cols:
271
+ raise ValueError(
272
+ f"Row {row} has different columns than previous rows: {cols_of_row} vs {cols}"
273
+ )
274
+ meta_dict[row] = cols
275
+
276
+ meta_dict_with_paths = {}
277
+ for row, cols in meta_dict.items():
278
+ meta_dict_with_paths[row] = {
279
+ col: fs.filter_files_by_extension(
280
+ os.path.join(input_dir, row, col), ["png", "jpg", "jpeg"]
281
+ )
282
+ for col in cols
283
+ }
284
+ first_row = list(meta_dict_with_paths.keys())[0]
285
+ first_col = list(meta_dict_with_paths[first_row].keys())[0]
286
+ len_first_col = len(meta_dict_with_paths[first_row][first_col])
287
+ for row, cols in meta_dict_with_paths.items():
288
+ for col, paths in cols.items():
289
+ if len(paths) != len_first_col:
290
+ raise ValueError(
291
+ f"Row {row}, Column {col} has different number of files: {len(paths)} vs {len_first_col}"
292
+ )
293
+ cols = sorted(meta_dict_with_paths[first_row].keys())
294
+ rows_set = sorted(meta_dict_with_paths.keys())
295
+ row_per_col = len(meta_dict_with_paths[first_row][first_col])
296
+ rows = [item for item in rows_set for _ in range(row_per_col)]
297
+ data_dict = {}
298
+ data_dict["row"] = rows
299
+ col_data = {col: [] for col in cols}
300
+ for row_base in rows_set:
301
+ for col in cols:
302
+ for i in range(row_per_col):
303
+ col_data[col].append(meta_dict_with_paths[row_base][col][i])
304
+ data_dict.update(col_data)
305
+ df = pd.DataFrame(data_dict)
83
306
  if log:
84
- with ConsoleLog(f"plotting {csv_file}"):
85
- csvfile.fn_display_df(df)
86
- ls_df.append(df)
87
-
88
- ls_metrics = ["loss", "acc"]
89
- for df_item, tag in zip(ls_df, valid_tags):
90
- # add tag to columns,excpet epoch
91
- df_item.columns = [
92
- f"{tag}_{col}" if col != "epoch" else col for col in df_item.columns
307
+ csvfile.fn_display_df(df)
308
+ return df
309
+
310
+ @staticmethod
311
+ def plot_image_grid(
312
+ csv_file_or_df: Union[str, pd.DataFrame],
313
+ max_width: int = 300,
314
+ max_height: int = 300,
315
+ img_stack_direction: str = "horizontal",
316
+ img_stack_padding_px: int = 10,
317
+ format_row_label_func: Optional[Callable[[str], str]] = None,
318
+ format_col_label_func: Optional[Callable[[str, str], str]] = None,
319
+ title: str = "",
320
+ ):
321
+ """
322
+ Plot a grid of images using Plotly from a DataFrame.
323
+
324
+ Args:
325
+ df (pd.DataFrame): DataFrame with first column as row labels, remaining columns as image paths.
326
+ max_width (int): Maximum width of stacked images per cell in pixels.
327
+ max_height (int): Maximum height of stacked images per cell in pixels.
328
+ img_stack_direction (str): "horizontal" or "vertical" stacking.
329
+ img_stack_padding_px (int): Padding between stacked images in pixels.
330
+ format_row_label_func (Callable): Function to format row labels.
331
+ format_col_label_func (Callable): Function to format column labels.
332
+ title (str): Figure title.
333
+ """
334
+
335
+ def stack_images_base64(
336
+ image_paths: List[str], direction: str, target_size: Tuple[int, int]
337
+ ) -> str:
338
+ """Stack images and return base64-encoded PNG."""
339
+ if not image_paths:
340
+ return ""
341
+
342
+ processed_images = []
343
+ for path in image_paths:
344
+ try:
345
+ img = Image.open(path).convert("RGB")
346
+ img.thumbnail(target_size, Image.Resampling.LANCZOS)
347
+ processed_images.append(img)
348
+ except:
349
+ # blank image if error
350
+ processed_images.append(Image.new("RGB", target_size, (255, 255, 255)))
351
+
352
+ # Stack
353
+ widths, heights = zip(*(img.size for img in processed_images))
354
+ if direction == "horizontal":
355
+ total_width = sum(widths) + img_stack_padding_px * (
356
+ len(processed_images) - 1
357
+ )
358
+ total_height = max(heights)
359
+ stacked = Image.new("RGB", (total_width, total_height), (255, 255, 255))
360
+ x_offset = 0
361
+ for im in processed_images:
362
+ stacked.paste(im, (x_offset, 0))
363
+ x_offset += im.width + img_stack_padding_px
364
+ elif direction == "vertical":
365
+ total_width = max(widths)
366
+ total_height = sum(heights) + img_stack_padding_px * (
367
+ len(processed_images) - 1
368
+ )
369
+ stacked = Image.new("RGB", (total_width, total_height), (255, 255, 255))
370
+ y_offset = 0
371
+ for im in processed_images:
372
+ stacked.paste(im, (0, y_offset))
373
+ y_offset += im.height + img_stack_padding_px
374
+ else:
375
+ raise ValueError("img_stack_direction must be 'horizontal' or 'vertical'")
376
+
377
+ # Encode as base64
378
+ buffer = BytesIO()
379
+ stacked.save(buffer, format="PNG")
380
+ encoded = base64.b64encode(buffer.getvalue()).decode()
381
+ return "data:image/png;base64," + encoded
382
+
383
+ # Load DataFrame if a file path is provided
384
+ if isinstance(csv_file_or_df, str):
385
+ df = csvfile.read_auto_sep(csv_file_or_df)
386
+ else:
387
+ df = csv_file_or_df
388
+ assert isinstance(df, pd.DataFrame), "Input must be a DataFrame or valid CSV file path"
389
+
390
+ rows = df[df.columns[0]].tolist()
391
+ columns = df.columns[1:].tolist()
392
+ n_rows, n_cols = len(rows), len(columns)
393
+
394
+ fig = go.Figure()
395
+
396
+ for i, row_label in enumerate(rows):
397
+ for j, col_label in enumerate(columns):
398
+ image_paths = df.loc[i, col_label]
399
+ if isinstance(image_paths, str):
400
+ image_paths = [image_paths]
401
+ img_src = stack_images_base64(
402
+ image_paths, img_stack_direction, (max_width, max_height)
403
+ )
404
+
405
+ fig.add_layout_image(
406
+ dict(
407
+ source=img_src,
408
+ x=j,
409
+ y=-i, # negative so row 0 on top
410
+ xref="x",
411
+ yref="y",
412
+ sizex=1,
413
+ sizey=1,
414
+ xanchor="left",
415
+ yanchor="top",
416
+ layer="above",
417
+ )
418
+ )
419
+
420
+ # Format axis labels
421
+ col_labels = [
422
+ format_col_label_func(c, pattern="___") if format_col_label_func else c
423
+ for c in columns
93
424
  ]
94
- # merge the dataframes on the epoch column
95
- df_combined = ls_df[0]
96
- for df_item in ls_df[1:]:
97
- df_combined = pd.merge(df_combined, df_item, on="epoch", how="outer")
98
- # csvfile.fn_display_df(df_combined)
99
-
100
- for i, metric in enumerate(ls_metrics):
101
- tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
102
- title = f"{tags_str}_{metric}-by-epoch"
103
- cols = [col for col in df_combined.columns if col != "epoch" and metric in col]
104
- cols = sorted(cols)
105
- # pprint(cols)
106
- plot_data = df_combined[cols]
107
-
108
- # line from same csv file (same tag) should have the same marker
109
- all_markers = [
110
- marker for marker in plt.Line2D.markers if marker and marker != " "
425
+ row_labels = [
426
+ format_row_label_func(r) if format_row_label_func else r for r in rows
111
427
  ]
112
- tag2marker = {tag: marker for tag, marker in zip(valid_tags, all_markers)}
113
- plot_markers = []
114
- for col in cols:
115
- # find the tag:
116
- tag = None
117
- for valid_tag in valid_tags:
118
- if valid_tag in col:
119
- tag = valid_tag
120
- break
121
- plot_markers.append(tag2marker[tag])
122
- # pprint(list(zip(cols, plot_markers)))
123
-
124
- # create color
125
- sequential_palettes = [
126
- "Reds",
127
- "Greens",
128
- "Blues",
129
- "Oranges",
130
- "Purples",
131
- "Greys",
132
- "BuGn",
133
- "BuPu",
134
- "GnBu",
135
- "OrRd",
136
- "PuBu",
137
- "PuRd",
138
- "RdPu",
139
- "YlGn",
140
- "PuBuGn",
141
- "YlGnBu",
142
- "YlOrBr",
143
- "YlOrRd",
144
- ]
145
- # each csvfile (tag) should have a unique color
146
- tag2palette = {
147
- tag: palette for tag, palette in zip(valid_tags, sequential_palettes)
148
- }
149
- plot_colors = []
150
- for tag in valid_tags:
151
- palette = tag2palette[tag]
152
- total_colors = 10
153
- ls_colors = sns.color_palette(palette, total_colors).as_hex()
154
- num_part = len(ls_metrics)
155
- subarr = np.array_split(np.arange(total_colors), num_part)
156
- for idx, col in enumerate(cols):
157
- if tag in col:
158
- chosen_color = ls_colors[
159
- subarr[int(idx % num_part)].mean().astype(int)
160
- ]
161
- plot_colors.append(chosen_color)
162
-
163
- # pprint(list(zip(cols, plot_colors)))
164
- sns.lineplot(
165
- data=plot_data,
166
- markers=plot_markers,
167
- palette=plot_colors,
168
- ax=axes[i],
169
- dashes=False,
170
- )
171
- axes[i].set(xlabel="epoch", ylabel=metric, title=title)
172
- axes[i].legend()
173
- axes[i].grid()
174
428
 
175
-
176
- def actual_plot(frame, csv_files, axes, tags, log):
177
- ls_df = []
178
- valid_tags = get_valid_tags(csv_files, tags)
179
- for csv_file in csv_files:
180
- df = verify_csv(csv_file)
181
- if log:
182
- with ConsoleLog(f"plotting {csv_file}"):
183
- csvfile.fn_display_df(df)
184
- ls_df.append(df)
185
-
186
- metric_values = ["loss", "acc"]
187
- for i, metric in enumerate(metric_values):
188
- for df_item, tag in zip(ls_df, valid_tags):
189
- metric_ax = plot_ax(df_item, axes[i], metric, tag)
190
-
191
- # set the title, xlabel, ylabel, legend, and grid
192
- tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
193
- metric_ax.set(
194
- xlabel="epoch", ylabel=metric, title=f"{tags_str}_{metric}-by-epoch"
429
+ fig.update_xaxes(
430
+ tickvals=list(range(n_cols)),
431
+ ticktext=col_labels,
432
+ range=[-0.5, n_cols - 0.5],
433
+ showgrid=False,
434
+ zeroline=False,
195
435
  )
196
- metric_ax.legend()
197
- metric_ax.grid()
198
-
199
-
200
- def plot_csv_files(
201
- csv_files,
202
- outdir="./out/plot",
203
- tags=None,
204
- log=False,
205
- save_fig=False,
206
- update_in_min=1,
207
- ):
208
- # if csv_files is a string, convert it to a list
209
- if isinstance(csv_files, str):
210
- csv_files = [csv_files]
211
- # if tags is a string, convert it to a list
212
- if isinstance(tags, str):
213
- tags = [tags]
214
- valid_tags = get_valid_tags(csv_files, tags)
215
- assert len(valid_tags) == len(
216
- csv_files
217
- ), "Unable to determine tags for each csv file"
218
- live_update_in_ms = int(update_in_min * 60 * 1000)
219
- fig, axes = plt.subplots(2, 1, figsize=(10, 17))
220
- if live_update_in_ms: # live update in min should be > 0
221
- from matplotlib.animation import FuncAnimation
222
-
223
- anim = FuncAnimation(
224
- fig,
225
- partial(
226
- actual_plot_seaborn, csv_files=csv_files, axes=axes, tags=tags, log=log
227
- ),
228
- interval=live_update_in_ms,
229
- blit=False,
230
- cache_frame_data=False,
436
+ fig.update_yaxes(
437
+ tickvals=[-i for i in range(n_rows)],
438
+ ticktext=row_labels,
439
+ range=[-n_rows + 0.5, 0.5],
440
+ showgrid=False,
441
+ zeroline=False,
231
442
  )
232
- plt.show()
233
- else:
234
- actual_plot_seaborn(None, csv_files, axes, tags, log)
235
- plt.show()
236
-
237
- if save_fig:
238
- os.makedirs(outdir, exist_ok=True)
239
- tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
240
- tag = f"{now_str()}_{tags_str}"
241
- fig.savefig(f"{outdir}/{tag}_plot.png")
242
- enable_plot_pgf()
243
- fig.savefig(f"{outdir}/{tag}_plot.pdf")
244
- if live_update_in_ms:
245
- return anim
246
-
247
-
248
- def enable_plot_pgf():
249
- matplotlib.use("pdf")
250
- matplotlib.rcParams.update(
251
- {
252
- "pgf.texsystem": "pdflatex",
253
- "font.family": "serif",
254
- "text.usetex": True,
255
- "pgf.rcfonts": False,
256
- }
257
- )
258
443
 
444
+ fig.update_layout(
445
+ width=max_width * n_cols + 200, # extra for labels
446
+ height=max_height * n_rows + 100,
447
+ title=title,
448
+ margin=dict(l=100, r=20, t=50, b=50),
449
+ )
259
450
 
260
- def save_fig_latex_pgf(filename, directory="."):
261
- enable_plot_pgf()
262
- if ".pgf" not in filename:
263
- filename = f"{directory}/{filename}.pgf"
264
- plt.savefig(filename)
451
+ fig.show()
265
452
 
266
453
 
267
- # https: // click.palletsprojects.com/en/8.1.x/api/
268
454
  @click.command()
269
455
  @click.option("--csvfiles", "-f", multiple=True, type=str, help="csv files to plot")
270
456
  @click.option(
271
- "--outdir",
272
- "-o",
273
- type=str,
274
- help="output directory for the plot",
275
- default=str(desktop_path),
457
+ "--outdir", "-o", type=str, default=str(desktop_path), help="output directory"
276
458
  )
277
459
  @click.option(
278
- "--tags", "-t", multiple=True, type=str, help="tags for the csv files", default=[]
460
+ "--tags", "-t", multiple=True, type=str, default=[], help="tags for the csv files"
279
461
  )
280
462
  @click.option("--log", "-l", is_flag=True, help="log the csv files")
281
- @click.option("--save_fig", "-s", is_flag=True, help="save the plot as a file")
463
+ @click.option("--save_fig", "-s", is_flag=True, help="save the plot as file")
282
464
  @click.option(
283
465
  "--update_in_min",
284
466
  "-u",
285
467
  type=float,
286
- help="update the plot every x minutes",
287
468
  default=0.0,
469
+ help="update the plot every x minutes",
470
+ )
471
+ @click.option(
472
+ "--x_col", "-x", type=str, default="epoch", help="column to use as x-axis"
288
473
  )
289
- def main(
290
- csvfiles,
291
- outdir,
292
- tags,
293
- log,
294
- save_fig,
295
- update_in_min,
296
- ):
297
- plot_csv_files(list(csvfiles), outdir, list(tags), log, save_fig, update_in_min)
474
+ @click.option(
475
+ "--y_cols",
476
+ "-y",
477
+ multiple=True,
478
+ type=str,
479
+ required=True,
480
+ help="columns to plot as y (can repeat)",
481
+ )
482
+ def main(csvfiles, outdir, tags, log, save_fig, update_in_min, x_col, y_cols):
483
+ PlotHelper.plot_csv_timeseries(
484
+ list(csvfiles),
485
+ outdir,
486
+ list(tags),
487
+ log,
488
+ save_fig,
489
+ update_in_min,
490
+ x_col=x_col,
491
+ y_cols=list(y_cols),
492
+ )
298
493
 
299
494
 
300
495
  if __name__ == "__main__":
@@ -238,10 +238,10 @@ class zProfiler:
238
238
  row=1,
239
239
  col=2,
240
240
  )
241
- tag_str = f"[{tag}]" if tag else ""
241
+ tag_str = tag if tag and len(tag) > 0 else ""
242
242
  # Layout
243
243
  fig.update_layout(
244
- title_text=f"[{tag_str} Context Profiler: {ctx}",
244
+ title_text=f"[{tag_str}] Context Profiler: {ctx}",
245
245
  width=1000,
246
246
  height=400,
247
247
  showlegend=True,
@@ -258,7 +258,7 @@ class zProfiler:
258
258
 
259
259
  # Save figure
260
260
  if outdir is not None:
261
- file_prefix = {ctx} if len(tag_str) == 0 else f"{tag_str}_{ctx}"
261
+ file_prefix = ctx if len(tag_str) == 0 else f"{tag_str}_{ctx}"
262
262
  file_path = os.path.join(outdir, f"{file_prefix}_summary.{file_format.lower()}")
263
263
  fig.write_image(file_path)
264
264
  print(f"Saved figure: {file_path}")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: halib
3
- Version: 0.1.84
3
+ Version: 0.1.86
4
4
  Summary: Small library for common tasks
5
5
  Author: Hoang Van Ha
6
6
  Author-email: hoangvanhauit@gmail.com
@@ -52,7 +52,7 @@ Dynamic: summary
52
52
 
53
53
  Helper package for coding and automation
54
54
 
55
- **Version 0.1.84**
55
+ **Version 0.1.86**
56
56
 
57
57
  + `research/profiler`: add `zProfiler` class to measure execution time of contexts and steps, with support for dynamic color scales in plots.
58
58
 
@@ -17,7 +17,7 @@ halib/textfile.py,sha256=EhVFrit-nRBJx18e6rtIqcE1cSbgsLnMXe_kdhi1EPI,399
17
17
  halib/torchloader.py,sha256=-q9YE-AoHZE1xQX2dgNxdqtucEXYs4sQ22WXdl6EGfI,6500
18
18
  halib/videofile.py,sha256=NTLTZ-j6YD47duw2LN2p-lDQDglYFP1LpEU_0gzHLdI,4737
19
19
  halib/filetype/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
- halib/filetype/csvfile.py,sha256=YtJHYft72I4VmKo9QpMv6TPV_62chcwdAIyRRumJKOI,5727
20
+ halib/filetype/csvfile.py,sha256=4Klf8YNzY1MaCD3o5Wp5GG3KMfQIBOEVzHV_7DO5XBo,6604
21
21
  halib/filetype/jsonfile.py,sha256=9LBdM7LV9QgJA1bzJRkq69qpWOP22HDXPGirqXTgSCw,480
22
22
  halib/filetype/textfile.py,sha256=QtuI5PdLxu4hAqSeafr3S8vCXwtvgipWV4Nkl7AzDYM,399
23
23
  halib/filetype/videofile.py,sha256=4nfVAYYtoT76y8P4WYyxNna4Iv1o2iV6xaMcUzNPC4s,4736
@@ -33,10 +33,11 @@ halib/research/base_exp.py,sha256=tz03FF2XMI9b6Ram4ZJBjKo053Vb7T8yTTCH0jeLvp4,32
33
33
  halib/research/benchquery.py,sha256=FuKnbWQtCEoRRtJAfN-zaN-jPiO_EzsakmTOMiqi7GQ,4626
34
34
  halib/research/dataset.py,sha256=QU0Hr5QFb8_XlvnOMgC9QJGIpwXAZ9lDd0RdQi_QRec,6743
35
35
  halib/research/metrics.py,sha256=Xgv0GUGo-o-RJaBOmkRCRpQJaYijF_1xeKkyYU_Bv4U,5249
36
+ halib/research/mics.py,sha256=uX17AGrBGER-OFMqUULE_A9YPPbn1RpQ4o5-omrmqZ8,377
36
37
  halib/research/perfcalc.py,sha256=qDa0sqfpWrwGZVJtjuUVFK7JX6j8xyXP9OnnfYmdamg,15898
37
38
  halib/research/perftb.py,sha256=FWg0b8wSgy4UwuvHSXwEqvTq1Rhi-z-HtAKuQg1lWc4,30989
38
- halib/research/plot.py,sha256=-pDUk4z3C_GnyJ5zWmf-mGMdT4gaipVJWzIgcpIPiRk,9448
39
- halib/research/profiler.py,sha256=IW__mJbOypes_Vl5VNVB9YQS-xIpLJN_l7iuzZyK_q8,11761
39
+ halib/research/plot.py,sha256=A3di1HZhIHIKf7d9b-I68yu_cm4u2LpHoPKlirCaNOI,17956
40
+ halib/research/profiler.py,sha256=a5ndHzVCatmHIBm4Z2e7F41hJLCu-u3g97pq490mPrg,11770
40
41
  halib/research/torchloader.py,sha256=yqUjcSiME6H5W210363HyRUrOi3ISpUFAFkTr1w4DCw,6503
41
42
  halib/research/wandb_op.py,sha256=YzLEqME5kIRxi3VvjFkW83wnFrsn92oYeqYuNwtYRkY,4188
42
43
  halib/sys/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -52,8 +53,8 @@ halib/utils/gpu_mon.py,sha256=vD41_ZnmPLKguuq9X44SB_vwd9JrblO4BDzHLXZhhFY,2233
52
53
  halib/utils/listop.py,sha256=Vpa8_2fI0wySpB2-8sfTBkyi_A4FhoFVVvFiuvW8N64,339
53
54
  halib/utils/tele_noti.py,sha256=-4WXZelCA4W9BroapkRyIdUu9cUVrcJJhegnMs_WpGU,5928
54
55
  halib/utils/video.py,sha256=ZqzNVPgc1RZr_T0OlHvZ6SzyBpL7O27LtB86JMbBuR0,3059
55
- halib-0.1.84.dist-info/licenses/LICENSE.txt,sha256=qZssdna4aETiR8znYsShUjidu-U4jUT9Q-EWNlZ9yBQ,1100
56
- halib-0.1.84.dist-info/METADATA,sha256=7rx1qsP50k17bZ_K9ffKdpBR3fLUtKpmYu-Q-xZTf8Q,5864
57
- halib-0.1.84.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
58
- halib-0.1.84.dist-info/top_level.txt,sha256=7AD6PLaQTreE0Fn44mdZsoHBe_Zdd7GUmjsWPyQ7I-k,6
59
- halib-0.1.84.dist-info/RECORD,,
56
+ halib-0.1.86.dist-info/licenses/LICENSE.txt,sha256=qZssdna4aETiR8znYsShUjidu-U4jUT9Q-EWNlZ9yBQ,1100
57
+ halib-0.1.86.dist-info/METADATA,sha256=u2LiEtPvElf7eADlv05vYAIFbec0W3k6W6jDOxqRt6Q,5864
58
+ halib-0.1.86.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
59
+ halib-0.1.86.dist-info/top_level.txt,sha256=7AD6PLaQTreE0Fn44mdZsoHBe_Zdd7GUmjsWPyQ7I-k,6
60
+ halib-0.1.86.dist-info/RECORD,,
File without changes