halib 0.1.7__py3-none-any.whl → 0.1.99__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. halib/__init__.py +84 -0
  2. halib/common.py +151 -0
  3. halib/cuda.py +39 -0
  4. halib/dataset.py +209 -0
  5. halib/filetype/csvfile.py +151 -45
  6. halib/filetype/ipynb.py +63 -0
  7. halib/filetype/jsonfile.py +1 -1
  8. halib/filetype/textfile.py +4 -4
  9. halib/filetype/videofile.py +44 -33
  10. halib/filetype/yamlfile.py +95 -0
  11. halib/gdrive.py +1 -1
  12. halib/online/gdrive.py +104 -54
  13. halib/online/gdrive_mkdir.py +29 -17
  14. halib/online/gdrive_test.py +31 -18
  15. halib/online/projectmake.py +58 -43
  16. halib/plot.py +296 -11
  17. halib/projectmake.py +1 -1
  18. halib/research/__init__.py +0 -0
  19. halib/research/base_config.py +100 -0
  20. halib/research/base_exp.py +100 -0
  21. halib/research/benchquery.py +131 -0
  22. halib/research/dataset.py +208 -0
  23. halib/research/flop_csv.py +34 -0
  24. halib/research/flops.py +156 -0
  25. halib/research/metrics.py +133 -0
  26. halib/research/mics.py +68 -0
  27. halib/research/params_gen.py +108 -0
  28. halib/research/perfcalc.py +336 -0
  29. halib/research/perftb.py +780 -0
  30. halib/research/plot.py +758 -0
  31. halib/research/profiler.py +300 -0
  32. halib/research/torchloader.py +162 -0
  33. halib/research/wandb_op.py +116 -0
  34. halib/rich_color.py +285 -0
  35. halib/sys/filesys.py +17 -10
  36. halib/system/__init__.py +0 -0
  37. halib/system/cmd.py +8 -0
  38. halib/system/filesys.py +124 -0
  39. halib/tele_noti.py +166 -0
  40. halib/torchloader.py +162 -0
  41. halib/utils/__init__.py +0 -0
  42. halib/utils/dataclass_util.py +40 -0
  43. halib/utils/dict_op.py +9 -0
  44. halib/utils/gpu_mon.py +58 -0
  45. halib/utils/listop.py +13 -0
  46. halib/utils/tele_noti.py +166 -0
  47. halib/utils/video.py +82 -0
  48. halib/videofile.py +1 -1
  49. halib-0.1.99.dist-info/METADATA +209 -0
  50. halib-0.1.99.dist-info/RECORD +64 -0
  51. {halib-0.1.7.dist-info → halib-0.1.99.dist-info}/WHEEL +1 -1
  52. halib-0.1.7.dist-info/METADATA +0 -59
  53. halib-0.1.7.dist-info/RECORD +0 -30
  54. {halib-0.1.7.dist-info → halib-0.1.99.dist-info/licenses}/LICENSE.txt +0 -0
  55. {halib-0.1.7.dist-info → halib-0.1.99.dist-info}/top_level.txt +0 -0
halib/research/dataset.py ADDED
@@ -0,0 +1,208 @@
+ # This script creates a test version
+ # of the watcam (wc) dataset
+ # for testing the tflite model
+
+ from argparse import ArgumentParser
+
+ import os
+ import click
+ import shutil
+ from tqdm import tqdm
+ from rich import inspect
+ from rich.pretty import pprint
+ from torchvision.datasets import ImageFolder
+ from sklearn.model_selection import StratifiedShuffleSplit, ShuffleSplit
+
+ from ..system import filesys as fs
+ from ..common import console, seed_everything, ConsoleLog
+
+ def parse_args():
+     parser = ArgumentParser(description="desc text")
+     parser.add_argument(
+         "-indir",
+         "--indir",
+         type=str,
+         help="original dataset path",
+     )
+     parser.add_argument(
+         "-outdir",
+         "--outdir",
+         type=str,
+         help="dataset out path",
+         default=".",  # default to current dir
+     )
+     parser.add_argument(
+         "-val_size",
+         "--val_size",
+         type=float,
+         help="validation size",
+         default=0.2,
+     )
+     # add using StratifiedShuffleSplit or ShuffleSplit
+     parser.add_argument(
+         "-seed",
+         "--seed",
+         type=int,
+         help="random seed",
+         default=42,
+     )
+     parser.add_argument(
+         "-inplace",
+         "--inplace",
+         action="store_true",
+         help="inplace operation, will overwrite the outdir if exists",
+     )
+
+     parser.add_argument(
+         "-stratified",
+         "--stratified",
+         action="store_true",
+         help="use StratifiedShuffleSplit instead of ShuffleSplit",
+     )
+     parser.add_argument(
+         "-no_train",
+         "--no_train",
+         action="store_true",
+         help="only create test set, no train set",
+     )
+     parser.add_argument(
+         "-reverse",
+         "--reverse",
+         action="store_true",
+         help="combine train and val set back to original dataset",
+     )
+     return parser.parse_args()
+
+
+ def move_images(image_paths, target_set_dir):
+     for img_path in tqdm(image_paths):
+         # get folder name of the image
+         img_dir = os.path.dirname(img_path)
+         out_cls_dir = os.path.join(target_set_dir, os.path.basename(img_dir))
+         if not os.path.exists(out_cls_dir):
+             os.makedirs(out_cls_dir)
+         # move the image to the class folder
+         shutil.move(img_path, out_cls_dir)
+
+
+ def split_dataset_cls(
+     indir, outdir, val_size, seed, inplace, stratified_split, no_train
+ ):
+     seed_everything(seed)
+     console.rule("Config confirm?")
+     pprint(locals())
+     click.confirm("Continue?", abort=True)
+     assert os.path.exists(indir), f"{indir} does not exist"
+
+     if not inplace:
+         assert (not inplace) and (
+             not os.path.exists(outdir)
+         ), f"{outdir} already exists; SKIP ...."
+
+     if inplace:
+         outdir = indir
+     if not os.path.exists(outdir):
+         os.makedirs(outdir)
+
+     console.rule(f"Creating train/val dataset")
+
+     sss = (
+         ShuffleSplit(n_splits=1, test_size=val_size)
+         if not stratified_split
+         else StratifiedShuffleSplit(n_splits=1, test_size=val_size)
+     )
+
+     pprint({"split strategy": sss, "indir": indir, "outdir": outdir})
+     dataset = ImageFolder(
+         root=indir,
+         transform=None,
+     )
+     train_dataset_indices = None
+     val_dataset_indices = None  # val here means test
+     for train_indices, val_indices in sss.split(dataset.samples, dataset.targets):
+         train_dataset_indices = train_indices
+         val_dataset_indices = val_indices
+
+     # get image paths for train/val split dataset
+     train_image_paths = [dataset.imgs[i][0] for i in train_dataset_indices]
+     val_image_paths = [dataset.imgs[i][0] for i in val_dataset_indices]
+
+     # start creating train/val folders then move images
+     out_train_dir = os.path.join(outdir, "train")
+     out_val_dir = os.path.join(outdir, "val")
+     if inplace:
+         assert os.path.exists(out_train_dir) == False, f"{out_train_dir} already exists"
+         assert os.path.exists(out_val_dir) == False, f"{out_val_dir} already exists"
+
+     os.makedirs(out_train_dir)
+     os.makedirs(out_val_dir)
+
+     if not no_train:
+         with ConsoleLog(f"Moving train images to {out_train_dir} "):
+             move_images(train_image_paths, out_train_dir)
+     else:
+         pprint("test only, skip moving train images")
+         # remove out_train_dir
+         shutil.rmtree(out_train_dir)
+
+     with ConsoleLog(f"Moving val images to {out_val_dir} "):
+         move_images(val_image_paths, out_val_dir)
+
+     if inplace:
+         pprint(f"remove all folders, except train and val")
+         for cls_dir in os.listdir(outdir):
+             if cls_dir not in ["train", "val"]:
+                 shutil.rmtree(os.path.join(indir, cls_dir))
+
+
+ def reverse_split_ds(indir):
+     console.rule(f"Reversing split dataset <{indir}>...")
+     ls_dirs = os.listdir(indir)
+     # make sure there are only two dirs 'train' and 'val'
+     assert len(ls_dirs) == 2, f"Expected 2 dirs (train, val), found {len(ls_dirs)} dirs"
+     assert "train" in ls_dirs, f"train dir not found in {indir}"
+     assert "val" in ls_dirs, f"val dir not found in {indir}"
+     train_dir = os.path.join(indir, "train")
+     val_dir = os.path.join(indir, "val")
+     all_train_files = fs.filter_files_by_extension(
+         train_dir, ["jpg", "jpeg", "png", "bmp", "gif", "tiff"]
+     )
+     all_val_files = fs.filter_files_by_extension(
+         val_dir, ["jpg", "jpeg", "png", "bmp", "gif", "tiff"]
+     )
+     # move all files from train to indir
+     with ConsoleLog(f"Moving train images to {indir} "):
+         move_images(all_train_files, indir)
+     with ConsoleLog(f"Moving val images to {indir} "):
+         move_images(all_val_files, indir)
+     with ConsoleLog(f"Removing train and val dirs"):
+         # remove train and val dirs
+         shutil.rmtree(train_dir)
+         shutil.rmtree(val_dir)
+
+
+ def main():
+     args = parse_args()
+     indir = args.indir
+     outdir = args.outdir
+     if outdir == ".":
+         # get current folder of the indir
+         indir_parent_dir = os.path.dirname(os.path.normpath(indir))
+         indir_name = os.path.basename(indir)
+         outdir = os.path.join(indir_parent_dir, f"{indir_name}_split")
+     val_size = args.val_size
+     seed = args.seed
+     inplace = args.inplace
+     stratified_split = args.stratified
+     no_train = args.no_train
+     reverse = args.reverse
+     if not reverse:
+         split_dataset_cls(
+             indir, outdir, val_size, seed, inplace, stratified_split, no_train
+         )
+     else:
+         reverse_split_ds(indir)
+
+
+ if __name__ == "__main__":
+     main()
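For orientation, a minimal usage sketch of the split/merge helpers added above (the module path halib.research.dataset and the data paths are assumptions; the script is normally driven through its argparse CLI and asks for confirmation before moving files):

from halib.research.dataset import split_dataset_cls, reverse_split_ds

# Split an ImageFolder-style dataset (one sub-folder per class) into train/ and val/.
split_dataset_cls(
    indir="data/watcam",          # hypothetical dataset root
    outdir="data/watcam_split",
    val_size=0.2,
    seed=42,
    inplace=False,
    stratified_split=True,        # StratifiedShuffleSplit keeps class ratios
    no_train=False,
)

# Undo a split: move the train/val images back into their class folders.
# reverse_split_ds("data/watcam_split")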
halib/research/flop_csv.py ADDED
@@ -0,0 +1,34 @@
+ from halib import *
+ from flops import _calculate_flops_for_model
+
+ from halib import *
+ from argparse import ArgumentParser
+
+
+ def main():
+     csv_file = "./results-imagenet.csv"
+     df = pd.read_csv(csv_file)
+     # make param_count column as float
+     # df['param_count'] = df['param_count'].astype(float)
+     df['param_count'] = pd.to_numeric(df['param_count'], errors='coerce').fillna(99999).astype(float)
+     df = df[df['param_count'] < 5.0]  # filter models with param_count < 5M
+
+     dict_ls = []
+
+     for index, row in tqdm(df.iterrows()):
+         console.rule(f"Row {index+1}/{len(df)}")
+         model = row['model']
+         num_class = 2
+         _, _, mflops = _calculate_flops_for_model(model, num_class)
+         dict_ls.append({'model': model, 'param_count': row['param_count'], 'mflops': mflops})
+
+     # Create a DataFrame from the list of dictionaries
+     result_df = pd.DataFrame(dict_ls)
+
+     final_df = pd.merge(df, result_df, on=['model', 'param_count'])
+     final_df.sort_values(by='mflops', inplace=True, ascending=True)
+     csvfile.fn_display_df(final_df)
+
+
+ if __name__ == "__main__":
+     main()
halib/research/flops.py ADDED
@@ -0,0 +1,156 @@
+ import os
+ import sys
+ import torch
+ import timm
+ from argparse import ArgumentParser
+ from fvcore.nn import FlopCountAnalysis
+ from halib import *
+ from halib.filetype import csvfile
+ from curriculum.utils.config import *
+ from curriculum.utils.model_helper import *
+
+
+ # ---------------------------------------------------------------------
+ # Argument Parser
+ # ---------------------------------------------------------------------
+ def parse_args():
+     parser = ArgumentParser(description="Calculate FLOPs for TIMM or trained models")
+
+     # Option 1: Direct TIMM model
+     parser.add_argument(
+         "--model_name", type=str, help="TIMM model name (e.g., efficientnet_b0)"
+     )
+     parser.add_argument(
+         "--num_classes", type=int, default=1000, help="Number of output classes"
+     )
+
+     # Option 2: Experiment directory
+     parser.add_argument(
+         "--indir",
+         type=str,
+         default=None,
+         help="Directory containing trained experiment (with .yaml and .pth)",
+     )
+     parser.add_argument(
+         "-o", "--o", action="store_true", help="Open output CSV after saving"
+     )
+     return parser.parse_args()
+
+
+ # ---------------------------------------------------------------------
+ # Helper Functions
+ # ---------------------------------------------------------------------
+ def _get_list_of_proc_dirs(indir):
+     assert os.path.exists(indir), f"Input directory {indir} does not exist."
+     pth_files = [f for f in os.listdir(indir) if f.endswith(".pth")]
+     if len(pth_files) > 0:
+         return [indir]
+     return [
+         os.path.join(indir, f)
+         for f in os.listdir(indir)
+         if os.path.isdir(os.path.join(indir, f))
+     ]
+
+
+ def _calculate_flops_for_model(model_name, num_classes):
+     """Calculate FLOPs for a plain TIMM model."""
+     try:
+         model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
+         input_size = timm.data.resolve_data_config(model.default_cfg)["input_size"]
+         dummy_input = torch.randn(1, *input_size)
+         model.eval()  # ! set to eval mode to avoid some warnings or errors
+         flops = FlopCountAnalysis(model, dummy_input)
+         gflops = flops.total() / 1e9
+         mflops = flops.total() / 1e6
+         print(f"\nModel: **{model_name}**, Classes: {num_classes}")
+         print(f"Input size: {input_size}, FLOPs: **{gflops:.3f} GFLOPs**, **{mflops:.3f} MFLOPs**\n")
+         return model_name, gflops, mflops
+     except Exception as e:
+         print(f"[Error] Could not calculate FLOPs for {model_name}: {e}")
+         return model_name, -1, -1
+
+
+ def _calculate_flops_for_experiment(exp_dir):
+     """Calculate FLOPs for a trained experiment directory."""
+     yaml_files = [f for f in os.listdir(exp_dir) if f.endswith(".yaml")]
+     pth_files = [f for f in os.listdir(exp_dir) if f.endswith(".pth")]
+
+     assert (
+         len(yaml_files) == 1
+     ), f"Expected 1 YAML file in {exp_dir}, found {len(yaml_files)}"
+     assert (
+         len(pth_files) == 1
+     ), f"Expected 1 PTH file in {exp_dir}, found {len(pth_files)}"
+
+     exp_cfg_yaml = os.path.join(exp_dir, yaml_files[0])
+     cfg = ExpConfig.from_yaml(exp_cfg_yaml)
+     ds_label_list = cfg.dataset.get_label_list()
+
+     try:
+         model = build_model(
+             cfg.model.name, num_classes=len(ds_label_list), pretrained=True
+         )
+         model_weights_path = os.path.join(exp_dir, pth_files[0])
+         model.load_state_dict(torch.load(model_weights_path, map_location="cpu"))
+         model.eval()
+
+         input_size = timm.data.resolve_data_config(model.default_cfg)["input_size"]
+         dummy_input = torch.randn(1, *input_size)
+         flops = FlopCountAnalysis(model, dummy_input)
+         gflops = flops.total() / 1e9
+         mflops = flops.total() / 1e6
+
+         return str(cfg), cfg.model.name, gflops, mflops
+     except Exception as e:
+         console.print(f"[red] Error processing {exp_dir}: {e}[/red]")
+         return str(cfg), cfg.model.name, -1, -1
+
+
+ # ---------------------------------------------------------------------
+ # Main Entry
+ # ---------------------------------------------------------------------
+ def main():
+     args = parse_args()
+
+     # Case 1: Direct TIMM model input
+     if args.model_name:
+         _calculate_flops_for_model(args.model_name, args.num_classes)
+         return
+
+     # Case 2: Experiment directory input
+     if args.indir is None:
+         print("[Error] Either --model_name or --indir must be specified.")
+         return
+
+     proc_dirs = _get_list_of_proc_dirs(args.indir)
+     pprint(proc_dirs)
+
+     dfmk = csvfile.DFCreator()
+     TABLE_NAME = "model_flops_results"
+     dfmk.create_table(TABLE_NAME, ["exp_name", "model_name", "gflops", "mflops"])
+
+     console.rule(f"Calculating FLOPs for models in {len(proc_dirs)} dir(s)...")
+     rows = []
+     for exp_dir in tqdm(proc_dirs):
+         dir_name = os.path.basename(exp_dir)
+         console.rule(f"{dir_name}")
+         exp_name, model_name, gflops, mflops = _calculate_flops_for_experiment(exp_dir)
+         rows.append([exp_name, model_name, gflops, mflops])
+
+     dfmk.insert_rows(TABLE_NAME, rows)
+     dfmk.fill_table_from_row_pool(TABLE_NAME)
+
+     outfile = f"zout/zreport/{now_str()}_model_flops_results.csv"
+     dfmk[TABLE_NAME].to_csv(outfile, sep=";", index=False)
+     csvfile.fn_display_df(dfmk[TABLE_NAME])
+
+     if args.o:
+         os.system(f"start {outfile}")
+
+
+ # ---------------------------------------------------------------------
+ # Script Entry
+ # ---------------------------------------------------------------------
+ if __name__ == "__main__":
+     sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+     main()
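At its core, _calculate_flops_for_model above is the standard timm + fvcore pattern; a self-contained sketch of that pattern, with an illustrative model name (timm and fvcore must be installed):

import timm
import torch
from fvcore.nn import FlopCountAnalysis

model = timm.create_model("efficientnet_b0", pretrained=False, num_classes=2)
model.eval()
# Use the model's default input resolution for the dummy forward pass.
input_size = timm.data.resolve_data_config(model.default_cfg)["input_size"]
flops = FlopCountAnalysis(model, torch.randn(1, *input_size))
print(f"{flops.total() / 1e9:.3f} GFLOPs, {flops.total() / 1e6:.3f} MFLOPs")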
halib/research/metrics.py ADDED
@@ -0,0 +1,133 @@
+ # -------------------------------
+ # Metrics Backend Interface
+ # -------------------------------
+ import inspect
+ from typing import Dict, Union, List, Any
+ from abc import ABC, abstractmethod
+
+ class MetricsBackend(ABC):
+     """Interface for pluggable metrics computation backends."""
+
+     def __init__(self, metrics_info: Union[List[str], Dict[str, Any]]):
+         """
+         Initialize the backend with the given metrics_info.
+         """
+         self.metric_info = metrics_info
+         self.validate_metrics_info(self.metric_info)
+
+     @property
+     def metric_names(self) -> List[str]:
+         """
+         Return a list of metric names.
+         If metric_info is a dict, return its keys; if it's a list, return it directly.
+         """
+         if isinstance(self.metric_info, dict):
+             return list(self.metric_info.keys())
+         elif isinstance(self.metric_info, list):
+             return self.metric_info
+         else:
+             raise TypeError("metric_info must be a list or a dict")
+
+     def validate_metrics_info(self, metrics_info):
+         if isinstance(metrics_info, list):
+             return metrics_info
+         elif isinstance(metrics_info, dict):
+             return {k: v for k, v in metrics_info.items() if isinstance(k, str)}
+         else:
+             raise TypeError(
+                 "metrics_info must be a list of strings or a dict with string keys"
+             )
+
+     @abstractmethod
+     def compute_metrics(
+         self, metrics_info: Union[List[str], Dict[str, Any]], metrics_data_dict: Dict[str, Any], *args, **kwargs
+     ) -> Dict[str, Any]:
+         pass
+
+     def prepare_metrics_backend_data(
+         self, raw_metric_data, *args, **kwargs
+     ):
+         """
+         Prepare the data for the metrics backend.
+         This method can be overridden by subclasses to customize data preparation.
+         """
+         return raw_metric_data
+
+     def calc_metrics(
+         self, metrics_data_dict: Dict[str, Any], *args, **kwargs
+     ) -> Dict[str, Any]:
+         """
+         Calculate metrics based on the provided metrics_info and data.
+         Validates the data, prepares it, then delegates to compute_metrics (implemented by subclasses).
+         """
+         # prevalidate the metrics_data_dict
+         for metric in self.metric_names:
+             if metric not in metrics_data_dict:
+                 raise ValueError(f"Metric '{metric}' not found in provided data.")
+         # Prepare the data for the backend
+         metrics_data_dict = self.prepare_metrics_backend_data(
+             metrics_data_dict, *args, **kwargs
+         )
+         # Call the abstract method to compute metrics
+         return self.compute_metrics(self.metric_info, metrics_data_dict, *args, **kwargs)
+
+ class TorchMetricsBackend(MetricsBackend):
+     """TorchMetrics-based backend implementation."""
+
+     def __init__(self, metrics_info: Union[List[str], Dict[str, Any]]):
+         try:
+             import torch
+             from torchmetrics import Metric
+         except ImportError:
+             raise ImportError(
+                 "TorchMetricsBackend requires torch and torchmetrics to be installed."
+             )
+         self.metric_info = metrics_info
+         self.torch = torch
+         self.Metric = Metric
+         self.validate_metrics_info(metrics_info)
+
+     def validate_metrics_info(self, metrics_info):
+         if not isinstance(metrics_info, dict):
+             raise TypeError(
+                 "TorchMetricsBackend requires metrics_info as a dict {name: MetricInstance}"
+             )
+         for k, v in metrics_info.items():
+             if not isinstance(k, str):
+                 raise TypeError(f"Key '{k}' is not a string")
+             if not isinstance(v, self.Metric):
+                 raise TypeError(f"Value for key '{k}' must be a torchmetrics.Metric")
+         return metrics_info
+
+     def compute_metrics(self, metrics_info, metrics_data_dict, *args, **kwargs):
+         out_dict = {}
+         for metric, metric_instance in metrics_info.items():
+             if metric not in metrics_data_dict:
+                 raise ValueError(f"Metric '{metric}' not found in provided data.")
+
+             metric_data = metrics_data_dict[metric]
+             sig = inspect.signature(metric_instance.update)
+             expected_args = list(sig.parameters.values())
+
+             if isinstance(metric_data, dict):
+                 args = [metric_data[param.name] for param in expected_args]
+             elif isinstance(metric_data, (list, tuple)):
+                 args = metric_data
+             else:
+                 args = metric_data
+             if len(expected_args) == 1:
+                 metric_instance.update(args)
+             else:
+                 metric_instance.update(*args)
+
+             computed_value = metric_instance.compute()
+             if isinstance(computed_value, self.torch.Tensor):
+                 computed_value = (
+                     computed_value.item()
+                     if computed_value.numel() == 1
+                     else computed_value.tolist()
+                 )
+
+             out_dict[metric] = computed_value
+         return out_dict
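A minimal usage sketch of TorchMetricsBackend, assuming torchmetrics is installed (the metric, the tensors, and the module path halib.research.metrics are illustrative):

import torch
from torchmetrics.classification import MulticlassAccuracy
from halib.research.metrics import TorchMetricsBackend  # assumed module path

backend = TorchMetricsBackend({"acc": MulticlassAccuracy(num_classes=2, average="micro")})
preds = torch.tensor([0, 1, 1, 0])
target = torch.tensor([0, 1, 0, 0])
# Each metric's data is a dict keyed by its update() parameter names (preds, target).
print(backend.calc_metrics({"acc": {"preds": preds, "target": target}}))  # {'acc': 0.75}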
halib/research/mics.py ADDED
@@ -0,0 +1,68 @@
+ from ..common import *
+ from ..filetype import csvfile
+ import pandas as pd
+ import platform
+ import re  # used by normalize_paths below
+
+
+ PC_NAME_TO_ABBR = {
+     "DESKTOP-JQD9K01": "MainPC",
+     "DESKTOP-5IRHU87": "MSI_Laptop",
+     "DESKTOP-96HQCNO": "4090_SV",
+     "DESKTOP-Q2IKLC0": "4GPU_SV",
+     "DESKTOP-QNS3DNF": "1GPU_SV"
+ }
+
+ DEFAULT_ABBR_WORKING_DISK = {
+     "MainPC": "E:",
+     "MSI_Laptop": "D:",
+     "4090_SV": "E:",
+     "4GPU_SV": "D:",
+ }
+
+ def list_PCs(show=True):
+     df = pd.DataFrame(list(PC_NAME_TO_ABBR.items()), columns=["PC Name", "Abbreviation"])
+     if show:
+         csvfile.fn_display_df(df)
+     return df
+
+ def get_PC_name():
+     return platform.node()
+
+ def get_PC_abbr_name():
+     pc_name = get_PC_name()
+     return PC_NAME_TO_ABBR.get(pc_name, "Unknown")
+
+ # ! This function searches for full paths in the obj and normalizes them according to the current platform and working disk
+ # ! E.g: "E:/zdataset/DFire", but working_disk: "D:", current_platform: "windows" => "D:/zdataset/DFire"
+ # ! E.g: "E:/zdataset/DFire", but working_disk: "D:", current_platform: "linux" => "/mnt/d/zdataset/DFire"
+ def normalize_paths(obj, working_disk, current_platform):
+     if isinstance(obj, dict):
+         for key, value in obj.items():
+             obj[key] = normalize_paths(value, working_disk, current_platform)
+         return obj
+     elif isinstance(obj, list):
+         for i, item in enumerate(obj):
+             obj[i] = normalize_paths(item, working_disk, current_platform)
+         return obj
+     elif isinstance(obj, str):
+         # Normalize backslashes to forward slashes for consistency
+         obj = obj.replace("\\", "/")
+         # Regex for Windows-style path: e.g., "E:/zdataset/DFire"
+         win_match = re.match(r"^([A-Z]):/(.*)$", obj)
+         # Regex for Linux-style path: e.g., "/mnt/e/zdataset/DFire"
+         lin_match = re.match(r"^/mnt/([a-z])/(.*)$", obj)
+         if win_match or lin_match:
+             rest = win_match.group(2) if win_match else lin_match.group(2)
+             if current_platform == "windows":
+                 # working_disk is like "D:", so "D:/" + rest
+                 new_path = working_disk + "/" + rest
+             elif current_platform == "linux":
+                 # Extract drive letter from working_disk (e.g., "D:" -> "d")
+                 drive_letter = working_disk[0].lower()
+                 new_path = "/mnt/" + drive_letter + "/" + rest
+             else:
+                 # Unknown platform, return original
+                 return obj
+             return new_path
+     # For non-strings or non-path strings, return as is
+     return obj
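A quick sketch of how normalize_paths rewrites drive-prefixed paths (the config values are illustrative; the module path halib.research.mics is assumed):

from halib.research.mics import normalize_paths

cfg = {"train_dir": "E:/zdataset/DFire/train",
       "extra": ["E:\\zdataset\\DFire\\val", 123]}
# On a Linux machine whose working copy lives on the drive Windows calls D:
print(normalize_paths(cfg, working_disk="D:", current_platform="linux"))
# {'train_dir': '/mnt/d/zdataset/DFire/train', 'extra': ['/mnt/d/zdataset/DFire/val', 123]}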