halib 0.1.7__py3-none-any.whl → 0.1.99__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. halib/__init__.py +84 -0
  2. halib/common.py +151 -0
  3. halib/cuda.py +39 -0
  4. halib/dataset.py +209 -0
  5. halib/filetype/csvfile.py +151 -45
  6. halib/filetype/ipynb.py +63 -0
  7. halib/filetype/jsonfile.py +1 -1
  8. halib/filetype/textfile.py +4 -4
  9. halib/filetype/videofile.py +44 -33
  10. halib/filetype/yamlfile.py +95 -0
  11. halib/gdrive.py +1 -1
  12. halib/online/gdrive.py +104 -54
  13. halib/online/gdrive_mkdir.py +29 -17
  14. halib/online/gdrive_test.py +31 -18
  15. halib/online/projectmake.py +58 -43
  16. halib/plot.py +296 -11
  17. halib/projectmake.py +1 -1
  18. halib/research/__init__.py +0 -0
  19. halib/research/base_config.py +100 -0
  20. halib/research/base_exp.py +100 -0
  21. halib/research/benchquery.py +131 -0
  22. halib/research/dataset.py +208 -0
  23. halib/research/flop_csv.py +34 -0
  24. halib/research/flops.py +156 -0
  25. halib/research/metrics.py +133 -0
  26. halib/research/mics.py +68 -0
  27. halib/research/params_gen.py +108 -0
  28. halib/research/perfcalc.py +336 -0
  29. halib/research/perftb.py +780 -0
  30. halib/research/plot.py +758 -0
  31. halib/research/profiler.py +300 -0
  32. halib/research/torchloader.py +162 -0
  33. halib/research/wandb_op.py +116 -0
  34. halib/rich_color.py +285 -0
  35. halib/sys/filesys.py +17 -10
  36. halib/system/__init__.py +0 -0
  37. halib/system/cmd.py +8 -0
  38. halib/system/filesys.py +124 -0
  39. halib/tele_noti.py +166 -0
  40. halib/torchloader.py +162 -0
  41. halib/utils/__init__.py +0 -0
  42. halib/utils/dataclass_util.py +40 -0
  43. halib/utils/dict_op.py +9 -0
  44. halib/utils/gpu_mon.py +58 -0
  45. halib/utils/listop.py +13 -0
  46. halib/utils/tele_noti.py +166 -0
  47. halib/utils/video.py +82 -0
  48. halib/videofile.py +1 -1
  49. halib-0.1.99.dist-info/METADATA +209 -0
  50. halib-0.1.99.dist-info/RECORD +64 -0
  51. {halib-0.1.7.dist-info → halib-0.1.99.dist-info}/WHEEL +1 -1
  52. halib-0.1.7.dist-info/METADATA +0 -59
  53. halib-0.1.7.dist-info/RECORD +0 -30
  54. {halib-0.1.7.dist-info → halib-0.1.99.dist-info/licenses}/LICENSE.txt +0 -0
  55. {halib-0.1.7.dist-info → halib-0.1.99.dist-info}/top_level.txt +0 -0
halib/plot.py CHANGED
@@ -1,16 +1,301 @@
+ from .common import now_str, norm_str, ConsoleLog
+ from .filetype import csvfile
+ from .system import filesys as fs
+ from functools import partial
+ from rich.console import Console
+ from rich.pretty import pprint
+ import click
+ import csv
+ import matplotlib
  import matplotlib.pyplot as plt
+ import numpy as np
+ import os
+ import pandas as pd
  import seaborn as sns
- import matplotlib


- def save_fig_latex_pgf(filename, directory='.'):
-     matplotlib.use("pgf")
-     matplotlib.rcParams.update({
-         "pgf.texsystem": "pdflatex",
-         'font.family': 'serif',
-         'text.usetex': True,
-         'pgf.rcfonts': False,
-     })
-     if '.pgf' not in filename:
-         filename = f'{directory}/{filename}.pgf'
+ console = Console()
+ desktop_path = os.path.expanduser("~/Desktop")
+ REQUIRED_COLUMNS = ["epoch", "train_loss", "val_loss", "train_acc", "val_acc"]
+
+ import csv
+
+
+ def get_delimiter(file_path, bytes=4096):
+     sniffer = csv.Sniffer()
+     data = open(file_path, "r").read(bytes)
+     delimiter = sniffer.sniff(data).delimiter
+     return delimiter
+
+
+ # Function to verify that the DataFrame has the required columns, and only the required columns
+ def verify_csv(csv_file, required_columns=REQUIRED_COLUMNS):
+     delimiter = get_delimiter(csv_file)
+     df = pd.read_csv(csv_file, sep=delimiter)
+     # change the column names to lower case
+     df.columns = [col.lower() for col in df.columns]
+     for col in required_columns:
+         if col not in df.columns:
+             raise ValueError(
+                 f"Required columns are: {REQUIRED_COLUMNS}, but found {df.columns}"
+             )
+     df = df[required_columns].copy()
+     return df
+
+
+ def get_valid_tags(csv_files, tags):
+     if tags is not None and len(tags) > 0:
+         assert all(
+             isinstance(tag, str) for tag in tags
+         ), "tags must be a list of strings"
+         assert all(
+             len(tag) > 0 for tag in tags
+         ), "tags must be a list of non-empty strings"
+         valid_tags = tags
+     else:
+         valid_tags = []
+         for csv_file in csv_files:
+             file_name = fs.get_file_name(csv_file, split_file_ext=True)[0]
+             tag = norm_str(file_name)
+             valid_tags.append(tag)
+     return valid_tags
+
+
+ def plot_ax(df, ax, metric="loss", tag=""):
+     pprint(locals())
+     # reset plt
+     assert metric in ["loss", "acc"], "metric must be either 'loss' or 'acc'"
+     part = ["train", "val"]
+     for p in part:
+         label = f"{tag}_{p}_{metric}"
+         ax.plot(df["epoch"], df[f"{p}_{metric}"], label=label)
+     return ax
+
+
+ def actual_plot_seaborn(frame, csv_files, axes, tags, log):
+     # clear the axes
+     for ax in axes:
+         ax.clear()
+     ls_df = []
+     valid_tags = get_valid_tags(csv_files, tags)
+     for csv_file in csv_files:
+         df = verify_csv(csv_file)
+         if log:
+             with ConsoleLog(f"plotting {csv_file}"):
+                 csvfile.fn_display_df(df)
+         ls_df.append(df)
+
+     ls_metrics = ["loss", "acc"]
+     for df_item, tag in zip(ls_df, valid_tags):
+         # add tag to columns,excpet epoch
+         df_item.columns = [
+             f"{tag}_{col}" if col != "epoch" else col for col in df_item.columns
+         ]
+     # merge the dataframes on the epoch column
+     df_combined = ls_df[0]
+     for df_item in ls_df[1:]:
+         df_combined = pd.merge(df_combined, df_item, on="epoch", how="outer")
+     # csvfile.fn_display_df(df_combined)
+
+     for i, metric in enumerate(ls_metrics):
+         tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
+         title = f"{tags_str}_{metric}-by-epoch"
+         cols = [col for col in df_combined.columns if col != "epoch" and metric in col]
+         cols = sorted(cols)
+         # pprint(cols)
+         plot_data = df_combined[cols]
+
+         # line from same csv file (same tag) should have the same marker
+         all_markers = [
+             marker for marker in plt.Line2D.markers if marker and marker != " "
+         ]
+         tag2marker = {tag: marker for tag, marker in zip(valid_tags, all_markers)}
+         plot_markers = []
+         for col in cols:
+             # find the tag:
+             tag = None
+             for valid_tag in valid_tags:
+                 if valid_tag in col:
+                     tag = valid_tag
+                     break
+             plot_markers.append(tag2marker[tag])
+         # pprint(list(zip(cols, plot_markers)))
+
+         # create color
+         sequential_palettes = [
+             "Reds",
+             "Greens",
+             "Blues",
+             "Oranges",
+             "Purples",
+             "Greys",
+             "BuGn",
+             "BuPu",
+             "GnBu",
+             "OrRd",
+             "PuBu",
+             "PuRd",
+             "RdPu",
+             "YlGn",
+             "PuBuGn",
+             "YlGnBu",
+             "YlOrBr",
+             "YlOrRd",
+         ]
+         # each csvfile (tag) should have a unique color
+         tag2palette = {
+             tag: palette for tag, palette in zip(valid_tags, sequential_palettes)
+         }
+         plot_colors = []
+         for tag in valid_tags:
+             palette = tag2palette[tag]
+             total_colors = 10
+             ls_colors = sns.color_palette(palette, total_colors).as_hex()
+             num_part = len(ls_metrics)
+             subarr = np.array_split(np.arange(total_colors), num_part)
+             for idx, col in enumerate(cols):
+                 if tag in col:
+                     chosen_color = ls_colors[
+                         subarr[int(idx % num_part)].mean().astype(int)
+                     ]
+                     plot_colors.append(chosen_color)
+
+         # pprint(list(zip(cols, plot_colors)))
+         sns.lineplot(
+             data=plot_data,
+             markers=plot_markers,
+             palette=plot_colors,
+             ax=axes[i],
+             dashes=False,
+         )
+         axes[i].set(xlabel="epoch", ylabel=metric, title=title)
+         axes[i].legend()
+         axes[i].grid()
+
+
+ def actual_plot(frame, csv_files, axes, tags, log):
+     ls_df = []
+     valid_tags = get_valid_tags(csv_files, tags)
+     for csv_file in csv_files:
+         df = verify_csv(csv_file)
+         if log:
+             with ConsoleLog(f"plotting {csv_file}"):
+                 csvfile.fn_display_df(df)
+         ls_df.append(df)
+
+     metric_values = ["loss", "acc"]
+     for i, metric in enumerate(metric_values):
+         for df_item, tag in zip(ls_df, valid_tags):
+             metric_ax = plot_ax(df_item, axes[i], metric, tag)
+
+         # set the title, xlabel, ylabel, legend, and grid
+         tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
+         metric_ax.set(
+             xlabel="epoch", ylabel=metric, title=f"{tags_str}_{metric}-by-epoch"
+         )
+         metric_ax.legend()
+         metric_ax.grid()
+
+
+ def plot_csv_files(
+     csv_files,
+     outdir="./out/plot",
+     tags=None,
+     log=False,
+     save_fig=False,
+     update_in_min=1,
+ ):
+     # if csv_files is a string, convert it to a list
+     if isinstance(csv_files, str):
+         csv_files = [csv_files]
+     # if tags is a string, convert it to a list
+     if isinstance(tags, str):
+         tags = [tags]
+     valid_tags = get_valid_tags(csv_files, tags)
+     assert len(valid_tags) == len(
+         csv_files
+     ), "Unable to determine tags for each csv file"
+     live_update_in_ms = int(update_in_min * 60 * 1000)
+     fig, axes = plt.subplots(2, 1, figsize=(10, 17))
+     if live_update_in_ms:  # live update in min should be > 0
+         from matplotlib.animation import FuncAnimation
+
+         anim = FuncAnimation(
+             fig,
+             partial(
+                 actual_plot_seaborn, csv_files=csv_files, axes=axes, tags=tags, log=log
+             ),
+             interval=live_update_in_ms,
+             blit=False,
+             cache_frame_data=False,
+         )
+         plt.show()
+     else:
+         actual_plot_seaborn(None, csv_files, axes, tags, log)
+         plt.show()
+
+     if save_fig:
+         os.makedirs(outdir, exist_ok=True)
+         tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
+         tag = f"{now_str()}_{tags_str}"
+         fig.savefig(f"{outdir}/{tag}_plot.png")
+         enable_plot_pgf()
+         fig.savefig(f"{outdir}/{tag}_plot.pdf")
+     if live_update_in_ms:
+         return anim
+
+
+ def enable_plot_pgf():
+     matplotlib.use("pdf")
+     matplotlib.rcParams.update(
+         {
+             "pgf.texsystem": "pdflatex",
+             "font.family": "serif",
+             "text.usetex": True,
+             "pgf.rcfonts": False,
+         }
+     )
+
+
+ def save_fig_latex_pgf(filename, directory="."):
+     enable_plot_pgf()
+     if ".pgf" not in filename:
+         filename = f"{directory}/{filename}.pgf"
      plt.savefig(filename)
+
+
+ # https: // click.palletsprojects.com/en/8.1.x/api/
+ @click.command()
+ @click.option("--csvfiles", "-f", multiple=True, type=str, help="csv files to plot")
+ @click.option(
+     "--outdir",
+     "-o",
+     type=str,
+     help="output directory for the plot",
+     default=str(desktop_path),
+ )
+ @click.option(
+     "--tags", "-t", multiple=True, type=str, help="tags for the csv files", default=[]
+ )
+ @click.option("--log", "-l", is_flag=True, help="log the csv files")
+ @click.option("--save_fig", "-s", is_flag=True, help="save the plot as a file")
+ @click.option(
+     "--update_in_min",
+     "-u",
+     type=float,
+     help="update the plot every x minutes",
+     default=0.0,
+ )
+ def main(
+     csvfiles,
+     outdir,
+     tags,
+     log,
+     save_fig,
+     update_in_min,
+ ):
+     plot_csv_files(list(csvfiles), outdir, list(tags), log, save_fig, update_in_min)
+
+
+ if __name__ == "__main__":
+     main()
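For reference, here is a minimal sketch (not part of the diff) of driving the new plot_csv_files API from Python rather than through the click CLI defined above. The file names run_a.csv/run_b.csv and the metric values are invented; only the column names required by REQUIRED_COLUMNS are taken from the code above.

# Illustrative only -- not part of the released package. Assumes halib 0.1.99
# is installed and that the CSV files contain the REQUIRED_COLUMNS above.
import pandas as pd
from halib import plot as haplot

for name in ("run_a", "run_b"):
    pd.DataFrame(
        {
            "epoch": [1, 2, 3],
            "train_loss": [0.9, 0.6, 0.4],
            "val_loss": [1.0, 0.7, 0.5],
            "train_acc": [0.55, 0.70, 0.82],
            "val_acc": [0.50, 0.66, 0.78],
        }
    ).to_csv(f"{name}.csv", index=False)

# update_in_min=0 skips the FuncAnimation live-update branch; save_fig=True
# writes <timestamp>_<tags>_plot.png and .pdf into outdir.
haplot.plot_csv_files(
    ["run_a.csv", "run_b.csv"],
    outdir="./out/plot",
    tags=["run_a", "run_b"],
    log=False,
    save_fig=True,
    update_in_min=0,
)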
halib/projectmake.py CHANGED
@@ -10,7 +10,7 @@ import certifi
  import pycurl

  from halib.filetype import jsonfile
- from halib.sys import filesys
+ from halib.system import filesys


  def get_curl(url, user_and_pass, verbose=True):
halib/research/base_config.py ADDED
@@ -0,0 +1,100 @@
+ import os
+ from rich.pretty import pprint
+ from abc import ABC, abstractmethod
+ from dataclass_wizard import YAMLWizard
+
+
+ class NamedConfig(ABC):
+     """
+     Base class for named configurations.
+     All configurations should have a name.
+     """
+
+     @abstractmethod
+     def get_name(self):
+         """
+         Get the name of the configuration.
+         This method should be implemented in subclasses.
+         """
+         pass
+
+
+ class ExpBaseConfig(ABC, YAMLWizard):
+     """
+     Base class for configuration objects.
+     What a cfg class must have:
+     1 - a dataset cfg
+     2 - a metric cfg
+     3 - a method cfg
+     """
+
+     # Save to yaml fil
+     def save_to_outdir(
+         self, filename: str = "__config.yaml", outdir=None, override: bool = False
+     ) -> None:
+         """
+         Save the configuration to the output directory.
+         """
+         if outdir is not None:
+             output_dir = outdir
+         else:
+             output_dir = self.get_outdir()
+         os.makedirs(output_dir, exist_ok=True)
+         assert (output_dir is not None) and (
+             os.path.isdir(output_dir)
+         ), f"Output directory '{output_dir}' does not exist or is not a directory."
+         file_path = os.path.join(output_dir, filename)
+         if os.path.exists(file_path) and not override:
+             pprint(
+                 f"File '{file_path}' already exists. Use 'override=True' to overwrite."
+             )
+         else:
+             # method of YAMLWizard to_yaml_file
+             self.to_yaml_file(file_path)
+
+     @classmethod
+     @abstractmethod
+     # load from a custom YAML file
+     def from_custom_yaml_file(cls, yaml_file: str):
+         """Load a configuration from a custom YAML file."""
+         pass
+
+     @abstractmethod
+     def get_cfg_name(self):
+         """
+         Get the name of the configuration.
+         This method should be implemented in subclasses.
+         """
+         pass
+
+     @abstractmethod
+     def get_outdir(self):
+         """
+         Get the output directory for the configuration.
+         This method should be implemented in subclasses.
+         """
+         return None
+
+     @abstractmethod
+     def get_general_cfg(self):
+         """
+         Get the general configuration like output directory, log settings, SEED, etc.
+         This method should be implemented in subclasses.
+         """
+         pass
+
+     @abstractmethod
+     def get_dataset_cfg(self) -> NamedConfig:
+         """
+         Get the dataset configuration.
+         This method should be implemented in subclasses.
+         """
+         pass
+
+     @abstractmethod
+     def get_metric_cfg(self) -> NamedConfig:
+         """
+         Get the metric configuration.
+         This method should be implemented in subclasses.
+         """
+         pass
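As a reading aid, here is a hedged sketch (not in the diff) of what a concrete config built on these base classes might look like. The field names, the "exp1" name, and the use of YAMLWizard's from_yaml_file loader for from_custom_yaml_file are assumptions for illustration, not part of halib.

# Illustrative only; a minimal concrete ExpBaseConfig / NamedConfig pair.
from dataclasses import dataclass, field

from halib.research.base_config import ExpBaseConfig, NamedConfig


@dataclass
class DatasetCfg(NamedConfig):
    name: str = "cifar10"

    def get_name(self):
        return self.name


@dataclass
class MetricCfg(NamedConfig):
    name: str = "accuracy"

    def get_name(self):
        return self.name


@dataclass
class MyExpConfig(ExpBaseConfig):
    out_dir: str = "./out/exp1"
    seed: int = 42
    dataset: DatasetCfg = field(default_factory=DatasetCfg)
    metric: MetricCfg = field(default_factory=MetricCfg)

    @classmethod
    def from_custom_yaml_file(cls, yaml_file: str):
        # assumed: reuse the stock YAMLWizard loader
        return cls.from_yaml_file(yaml_file)

    def get_cfg_name(self):
        return "exp1"

    def get_outdir(self):
        return self.out_dir

    def get_general_cfg(self):
        return {"seed": self.seed}

    def get_dataset_cfg(self):
        return self.dataset

    def get_metric_cfg(self):
        return self.metric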
halib/research/base_exp.py ADDED
@@ -0,0 +1,100 @@
+ from abc import ABC, abstractmethod
+
+ from ..research.base_config import ExpBaseConfig
+ from ..research.perfcalc import PerfCalc
+ from ..research.metrics import MetricsBackend
+
+ # ! SEE https://github.com/hahv/base_exp for sample usage
+ class BaseExperiment(PerfCalc, ABC):
+     """
+     Base class for experiments.
+     Orchestrates the experiment pipeline using a pluggable metrics backend.
+     """
+
+     def __init__(self, config: ExpBaseConfig):
+         self.config = config
+         self.metric_backend = None
+
+     # -----------------------
+     # PerfCalc Required Methods
+     # -----------------------
+     def get_dataset_name(self):
+         return self.config.get_dataset_cfg().get_name()
+
+     def get_experiment_name(self):
+         return self.config.get_cfg_name()
+
+     def get_metric_backend(self):
+         if not self.metric_backend:
+             self.metric_backend = self.prepare_metrics(self.config.get_metric_cfg())
+         return self.metric_backend
+
+     # -----------------------
+     # Abstract Experiment Steps
+     # -----------------------
+     @abstractmethod
+     def init_general(self, general_cfg):
+         """Setup general settings like SEED, logging, env variables."""
+         pass
+
+     @abstractmethod
+     def prepare_dataset(self, dataset_cfg):
+         """Load/prepare dataset."""
+         pass
+
+     @abstractmethod
+     def prepare_metrics(self, metric_cfg) -> MetricsBackend:
+         """
+         Prepare the metrics for the experiment.
+         This method should be implemented in subclasses.
+         """
+         pass
+
+     @abstractmethod
+     def exec_exp(self, *args, **kwargs):
+         """Run experiment process, e.g.: training/evaluation loop.
+         Return: raw_metrics_data, and extra_data as input for calc_and_save_exp_perfs
+         """
+         pass
+
+     def eval_exp(self):
+         """Optional: re-run evaluation from saved results."""
+         pass
+
+     # -----------------------
+     # Main Experiment Runner
+     # -----------------------
+     def run_exp(self, do_calc_metrics=True, *args, **kwargs):
+         """
+         Run the whole experiment pipeline.
+         Params:
+         + 'outfile' to save csv file results,
+         + 'outdir' to set output directory for experiment results.
+         + 'return_df' to return a DataFrame of results instead of a dictionary.
+
+         Full pipeline:
+         1. Init
+         2. Dataset
+         3. Metrics Preparation
+         4. Save Config
+         5. Execute
+         6. Calculate & Save Metrics
+         """
+         self.init_general(self.config.get_general_cfg())
+         self.prepare_dataset(self.config.get_dataset_cfg())
+         self.prepare_metrics(self.config.get_metric_cfg())
+
+         # Save config before running
+         self.config.save_to_outdir()
+
+         # Execute experiment
+         results = self.exec_exp(*args, **kwargs)
+         if do_calc_metrics:
+             metrics_data, extra_data = results
+             # Calculate & Save metrics
+             perf_results = self.calc_and_save_exp_perfs(
+                 raw_metrics_data=metrics_data, extra_data=extra_data, *args, **kwargs
+             )
+             return perf_results
+         else:
+             return results
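A hedged sketch (again, not part of the diff) of how a subclass is expected to plug into this pipeline. MyExpConfig is the hypothetical config from the previous sketch; the metrics and PerfCalc details are left out because halib/research/metrics.py and perfcalc.py are only listed, not shown, in this diff.

# Illustrative only; shapes of the required overrides, with placeholder bodies.
import random

from halib.research.base_exp import BaseExperiment


class MyExperiment(BaseExperiment):
    def init_general(self, general_cfg):
        # e.g. fix the RNG seed from the general config
        random.seed(general_cfg.get("seed", 0))

    def prepare_dataset(self, dataset_cfg):
        self.dataset_name = dataset_cfg.get_name()  # load/prepare data here

    def prepare_metrics(self, metric_cfg):
        # Should return a MetricsBackend; omitted because that API lives in
        # halib/research/metrics.py, which this diff does not show.
        return None

    def exec_exp(self, *args, **kwargs):
        raw_metrics_data = {"acc": 0.9}   # placeholder for real run output
        extra_data = {"epochs": 10}
        return raw_metrics_data, extra_data


# Whether this instantiates directly depends on PerfCalc's own abstract
# methods (not shown in this diff), so the calls are left commented out:
# exp = MyExperiment(MyExpConfig())
# exp.run_exp(do_calc_metrics=False)  # skips calc_and_save_exp_perfs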
halib/research/benchquery.py ADDED
@@ -0,0 +1,131 @@
+ import pandas as pd
+ from rich.pretty import pprint
+ from argparse import ArgumentParser
+
+ def cols_to_col_groups(df):
+     columns = list(df.columns)
+     # pprint(columns)
+
+     col_groups = []
+     current_group = []
+
+     def have_unnamed(col_group):
+         return any("unnamed" in col.lower() for col in col_group)
+
+     for i, col in enumerate(columns):
+         # Add the first column to the current group
+         if not current_group:
+             current_group.append(col)
+             continue
+
+         prev_col = columns[i - 1]
+         # Check if current column is "unnamed" or shares base name with previous
+         # Assuming "equal" means same base name (before any suffix like '_1')
+         base_prev = (
+             prev_col.split("_")[0].lower() if "_" in prev_col else prev_col.lower()
+         )
+         base_col = col.split("_")[0].lower() if "_" in col else col.lower()
+         is_unnamed = "unnamed" in col.lower()
+         is_equal = base_col == base_prev
+
+         if is_unnamed or is_equal:
+             # Add to current group
+             current_group.append(col)
+         else:
+             # Start a new group
+             col_groups.append(current_group)
+             current_group = [col]
+     # Append the last group
+     if current_group:
+         col_groups.append(current_group)
+     meta_dict = {"common_cols": [], "db_cols": []}
+     for group in col_groups:
+         if not have_unnamed(group):
+             meta_dict["common_cols"].extend(group)
+         else:
+             # find the first unnamed column
+             named_col = next(
+                 (col for col in group if "unnamed" not in col.lower()), None
+             )
+             group_cols = [f"{named_col}_{i}" for i in range(len(group))]
+             meta_dict["db_cols"].extend(group_cols)
+     return meta_dict
+
+ # def bech_by_db_name(df, db_list="db1, db2", key_metrics="p, r, f1, acc"):
+
+
+ def str_2_list(input_str, sep=","):
+     out_ls = []
+     if len(input_str.strip()) == 0:
+         return out_ls
+     if sep not in input_str:
+         out_ls.append(input_str.strip())
+         return out_ls
+     else:
+         out_ls = [item.strip() for item in input_str.split(sep) if item.strip()]
+         return out_ls
+
+
+ def filter_bech_df_by_db_and_metrics(df, db_list="", key_metrics=""):
+     meta_cols_dict = cols_to_col_groups(df)
+     op_df = df.copy()
+     op_df.columns = (
+         meta_cols_dict["common_cols"].copy() + meta_cols_dict["db_cols"].copy()
+     )
+     filterd_cols = []
+     filterd_cols.extend(meta_cols_dict["common_cols"])
+
+     selected_db_list = str_2_list(db_list)
+     db_filted_cols = []
+     if len(selected_db_list) > 0:
+         for db_name in db_list.split(","):
+             db_name = db_name.strip()
+             for col_name in meta_cols_dict["db_cols"]:
+                 if db_name.lower() in col_name.lower():
+                     db_filted_cols.append(col_name)
+     else:
+         db_filted_cols = meta_cols_dict["db_cols"]
+
+     filterd_cols.extend(db_filted_cols)
+     df_filtered = op_df[filterd_cols].copy()
+     df_filtered
+
+     selected_metrics_ls = str_2_list(key_metrics)
+     if len(selected_metrics_ls) > 0:
+         # get the second row as metrics row (header)
+         metrics_row = df_filtered.iloc[0].copy()
+         # only get the values in columns in (db_filterd_cols)
+         metrics_values = metrics_row[db_filted_cols].values
+         keep_metrics_cols = []
+         # create a zip of db_filted_cols and metrics_values (in that metrics_row)
+         metrics_list = list(zip(metrics_values, db_filted_cols))
+         selected_metrics_ls = [metric.strip().lower() for metric in selected_metrics_ls]
+         for metric, col_name in metrics_list:
+             if metric.lower() in selected_metrics_ls:
+                 keep_metrics_cols.append(col_name)
+
+     else:
+         pprint("No metrics selected, keeping all db columns")
+         keep_metrics_cols = db_filted_cols
+
+     final_filterd_cols = meta_cols_dict["common_cols"].copy() + keep_metrics_cols
+     df_final = df_filtered[final_filterd_cols].copy()
+     return df_final
+
+
+ def parse_args():
+     parser = ArgumentParser(
+         description="desc text")
+     parser.add_argument('-csv', '--csv', type=str, help='CSV file path', default=r"E:\Dev\__halib\test\bench.csv")
+     return parser.parse_args()
+
+
+ def main():
+     args = parse_args()
+     csv_file = args.csv
+     df = pd.read_csv(csv_file, sep=";", encoding="utf-8")
+     filtered_df = filter_bech_df_by_db_and_metrics(df, "bowfire", "acc")
+     print(filtered_df)
+
+ if __name__ == "__main__":
+     main()
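To illustrate the input shape filter_bech_df_by_db_and_metrics expects, here is a small made-up benchmark table (not from the package or its tests): per-dataset column groups appear as a named column followed by pandas "Unnamed: N" columns, and the first data row carries the metric names.

# Illustrative only; the table contents are invented.
import pandas as pd
from halib.research.benchquery import filter_bech_df_by_db_and_metrics

df = pd.DataFrame(
    {
        "method":     ["",    "resnet18", "mobilenet"],
        "bowfire":    ["acc", "0.91",     "0.88"],
        "Unnamed: 2": ["f1",  "0.90",     "0.86"],
        "dfire":      ["acc", "0.84",     "0.80"],
        "Unnamed: 4": ["f1",  "0.83",     "0.79"],
    }
)

# Keep the common columns plus only the "bowfire" columns whose metric-row
# value is "acc" (column groups are detected by cols_to_col_groups above).
print(filter_bech_df_by_db_and_metrics(df, db_list="bowfire", key_metrics="acc"))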