halib 0.1.7__py3-none-any.whl → 0.1.99__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. halib/__init__.py +84 -0
  2. halib/common.py +151 -0
  3. halib/cuda.py +39 -0
  4. halib/dataset.py +209 -0
  5. halib/filetype/csvfile.py +151 -45
  6. halib/filetype/ipynb.py +63 -0
  7. halib/filetype/jsonfile.py +1 -1
  8. halib/filetype/textfile.py +4 -4
  9. halib/filetype/videofile.py +44 -33
  10. halib/filetype/yamlfile.py +95 -0
  11. halib/gdrive.py +1 -1
  12. halib/online/gdrive.py +104 -54
  13. halib/online/gdrive_mkdir.py +29 -17
  14. halib/online/gdrive_test.py +31 -18
  15. halib/online/projectmake.py +58 -43
  16. halib/plot.py +296 -11
  17. halib/projectmake.py +1 -1
  18. halib/research/__init__.py +0 -0
  19. halib/research/base_config.py +100 -0
  20. halib/research/base_exp.py +100 -0
  21. halib/research/benchquery.py +131 -0
  22. halib/research/dataset.py +208 -0
  23. halib/research/flop_csv.py +34 -0
  24. halib/research/flops.py +156 -0
  25. halib/research/metrics.py +133 -0
  26. halib/research/mics.py +68 -0
  27. halib/research/params_gen.py +108 -0
  28. halib/research/perfcalc.py +336 -0
  29. halib/research/perftb.py +780 -0
  30. halib/research/plot.py +758 -0
  31. halib/research/profiler.py +300 -0
  32. halib/research/torchloader.py +162 -0
  33. halib/research/wandb_op.py +116 -0
  34. halib/rich_color.py +285 -0
  35. halib/sys/filesys.py +17 -10
  36. halib/system/__init__.py +0 -0
  37. halib/system/cmd.py +8 -0
  38. halib/system/filesys.py +124 -0
  39. halib/tele_noti.py +166 -0
  40. halib/torchloader.py +162 -0
  41. halib/utils/__init__.py +0 -0
  42. halib/utils/dataclass_util.py +40 -0
  43. halib/utils/dict_op.py +9 -0
  44. halib/utils/gpu_mon.py +58 -0
  45. halib/utils/listop.py +13 -0
  46. halib/utils/tele_noti.py +166 -0
  47. halib/utils/video.py +82 -0
  48. halib/videofile.py +1 -1
  49. halib-0.1.99.dist-info/METADATA +209 -0
  50. halib-0.1.99.dist-info/RECORD +64 -0
  51. {halib-0.1.7.dist-info → halib-0.1.99.dist-info}/WHEEL +1 -1
  52. halib-0.1.7.dist-info/METADATA +0 -59
  53. halib-0.1.7.dist-info/RECORD +0 -30
  54. {halib-0.1.7.dist-info → halib-0.1.99.dist-info/licenses}/LICENSE.txt +0 -0
  55. {halib-0.1.7.dist-info → halib-0.1.99.dist-info}/top_level.txt +0 -0
halib/research/profiler.py
@@ -0,0 +1,300 @@
+ import os
+ import time
+ import json
+
+ from pathlib import Path
+ from pprint import pprint
+ from threading import Lock
+
+ from plotly.subplots import make_subplots
+ import plotly.graph_objects as go
+ import plotly.express as px  # for dynamic color scales
+ from ..common import ConsoleLog
+
+ from loguru import logger
+
+
+ class zProfiler:
+     """A singleton profiler that measures execution time of contexts and steps.
+
+     The constructor takes no arguments; all instances share one internal
+     time dictionary, so the profiler can be created anywhere in a process.
+
+     Example:
+         prof = zProfiler()
+         prof.ctx_start("my_context")
+         prof.step_start("my_context", "step1")
+         time.sleep(0.1)
+         prof.step_end("my_context", "step1")
+         prof.ctx_end("my_context")
+     """
+
+     _instance = None
+     _lock = Lock()
+
+     def __new__(cls, *args, **kwargs):
+         with cls._lock:
+             if cls._instance is None:
+                 cls._instance = super().__new__(cls)
+             return cls._instance
+
+     def __init__(self):
+         if not hasattr(self, "_initialized"):
+             self.time_dict = {}
+             self._initialized = True
+
+     def ctx_start(self, ctx_name="ctx_default"):
+         if not isinstance(ctx_name, str) or not ctx_name:
+             raise ValueError("ctx_name must be a non-empty string")
+         if ctx_name not in self.time_dict:
+             self.time_dict[ctx_name] = {
+                 "start": time.perf_counter(),
+                 "step_dict": {},
+                 "report_count": 0,
+             }
+         self.time_dict[ctx_name]["report_count"] += 1
+
+     def ctx_end(self, ctx_name="ctx_default", report_func=None):
+         if ctx_name not in self.time_dict:
+             return
+         self.time_dict[ctx_name]["end"] = time.perf_counter()
+         self.time_dict[ctx_name]["duration"] = (
+             self.time_dict[ctx_name]["end"] - self.time_dict[ctx_name]["start"]
+         )
+
+     def step_start(self, ctx_name, step_name):
+         if not isinstance(step_name, str) or not step_name:
+             raise ValueError("step_name must be a non-empty string")
+         if ctx_name not in self.time_dict:
+             return
+         if step_name not in self.time_dict[ctx_name]["step_dict"]:
+             self.time_dict[ctx_name]["step_dict"][step_name] = []
+         self.time_dict[ctx_name]["step_dict"][step_name].append([time.perf_counter()])
+
+     def step_end(self, ctx_name, step_name):
+         if (
+             ctx_name not in self.time_dict
+             or step_name not in self.time_dict[ctx_name]["step_dict"]
+         ):
+             return
+         self.time_dict[ctx_name]["step_dict"][step_name][-1].append(time.perf_counter())
+
+     def _step_dict_to_detail(self, ctx_step_dict):
+         """Convert raw [start, end] pairs into (index, elapsed) tuples per step.
+
+         Example input:
+             {
+                 "preprocess": [[278090.947, 278090.960], [278091.178, 278091.231]],
+                 "infer": [[278090.960, 278091.178], [278091.231, 278091.251]],
+             }
+         """
+         assert len(ctx_step_dict) > 0, "step_dict must contain at least one step"
+         normed_ctx_step_dict = {}
+         for step_name, time_list in ctx_step_dict.items():
+             if not isinstance(time_list, list):
+                 raise ValueError(f"Step data for {step_name} must be a list")
+             normed_time_ls = []
+             for idx, time_data in enumerate(time_list):
+                 elapsed_time = -1  # marks an unfinished step (no matching step_end)
+                 if len(time_data) == 2:
+                     start, end = time_data
+                     elapsed_time = end - start
+                 normed_time_ls.append((idx, elapsed_time))
+             normed_ctx_step_dict[step_name] = normed_time_ls
+         return normed_ctx_step_dict
+
+     def get_report_dict(self, with_detail=False):
+         report_dict = {}
+         for ctx_name, ctx_dict in self.time_dict.items():
+             report_dict[ctx_name] = {
+                 "duration": ctx_dict.get("duration", 0.0),
+                 "step_dict": {
+                     "summary": {"avg_time": {}, "percent_time": {}},
+                     "detail": {},
+                 },
+             }
+
+             if with_detail:
+                 report_dict[ctx_name]["step_dict"]["detail"] = (
+                     self._step_dict_to_detail(ctx_dict["step_dict"])
+                 )
+             avg_time_list = []
+             epsilon = 1e-5
+             for step_name, step_list in ctx_dict["step_dict"].items():
+                 durations = []
+                 try:
+                     for time_data in step_list:
+                         if len(time_data) != 2:
+                             continue
+                         start, end = time_data
+                         durations.append(end - start)
+                 except Exception as e:
+                     logger.error(
+                         f"Error processing step {step_name} in context {ctx_name}: {e}"
+                     )
+                     continue
+                 if not durations:
+                     continue
+                 avg_time = sum(durations) / len(durations)
+                 if avg_time < epsilon:
+                     continue
+                 avg_time_list.append((step_name, avg_time))
+             total_avg_time = (
+                 sum(time for _, time in avg_time_list) or 1e-10
+             )  # avoid division by zero
+             for step_name, avg_time in avg_time_list:
+                 report_dict[ctx_name]["step_dict"]["summary"]["percent_time"][
+                     f"per_{step_name}"
+                 ] = (avg_time / total_avg_time) * 100.0
+                 report_dict[ctx_name]["step_dict"]["summary"]["avg_time"][
+                     f"avg_{step_name}"
+                 ] = avg_time
+             report_dict[ctx_name]["step_dict"]["summary"]["total_avg_time"] = total_avg_time
+             report_dict[ctx_name]["step_dict"]["summary"] = dict(
+                 sorted(report_dict[ctx_name]["step_dict"]["summary"].items())
+             )
+         return report_dict
+
+     @classmethod
+     def plot_formatted_data(
+         cls, profiler_data, outdir=None, file_format="png", do_show=False, tag=""
+     ):
+         """Plot each context in a separate figure with bar + pie charts.
+
+         Each figure is saved in the specified format (png or svg).
+         """
+         if outdir is not None:
+             os.makedirs(outdir, exist_ok=True)
+
+         if file_format.lower() not in ["png", "svg"]:
+             raise ValueError("file_format must be 'png' or 'svg'")
+
+         results = {}  # {context: fig}
+
+         for ctx, ctx_data in profiler_data.items():
+             summary = ctx_data["step_dict"]["summary"]
+             avg_times = summary["avg_time"]
+             percent_times = summary["percent_time"]
+
+             step_names = [s.replace("avg_", "") for s in avg_times.keys()]
+             n_steps = len(step_names)
+             assert n_steps > 0, f"No steps found for context: {ctx}"
+
+             # Generate dynamic colors from the Viridis scale
+             colors = (
+                 px.colors.sample_colorscale(
+                     "Viridis", [i / (n_steps - 1) for i in range(n_steps)]
+                 )
+                 if n_steps > 1
+                 else [px.colors.sample_colorscale("Viridis", [0])[0]]
+             )
+             color_map = dict(zip(step_names, colors))
+
+             # One row, two panels: bar chart (avg time) and pie chart (% time)
+             fig = make_subplots(
+                 rows=1,
+                 cols=2,
+                 subplot_titles=["Avg Time", "% Time"],
+                 specs=[[{"type": "bar"}, {"type": "pie"}]],
+             )
+
+             # Bar chart
+             fig.add_trace(
+                 go.Bar(
+                     x=step_names,
+                     y=list(avg_times.values()),
+                     text=[f"{v*1000:.2f} ms" for v in avg_times.values()],
+                     textposition="outside",
+                     marker=dict(color=[color_map[s] for s in step_names]),
+                     name="",  # unified legend
+                     showlegend=False,
+                 ),
+                 row=1,
+                 col=1,
+             )
+
+             # Pie chart (colors match bar)
+             fig.add_trace(
+                 go.Pie(
+                     labels=step_names,
+                     values=list(percent_times.values()),
+                     marker=dict(colors=[color_map[s] for s in step_names]),
+                     hole=0.4,
+                     name="",
+                     showlegend=True,
+                 ),
+                 row=1,
+                 col=2,
+             )
+
+             tag_str = tag or ""
+             # Layout
+             fig.update_layout(
+                 title_text=f"[{tag_str}] Context Profiler: {ctx}",
+                 width=1000,
+                 height=400,
+                 showlegend=True,
+                 legend=dict(title="Steps", x=1.05, y=0.5, traceorder="normal"),
+                 hovermode="x unified",
+             )
+             fig.update_xaxes(title_text="Steps", row=1, col=1)
+             fig.update_yaxes(title_text="Avg Time (s)", row=1, col=1)
+
+             # Show figure
+             if do_show:
+                 fig.show()
+
+             # Save figure
+             if outdir is not None:
+                 file_prefix = ctx if len(tag_str) == 0 else f"{tag_str}_{ctx}"
+                 file_path = os.path.join(
+                     outdir, f"{file_prefix}_summary.{file_format.lower()}"
+                 )
+                 fig.write_image(file_path)
+                 print(f"Saved figure: {file_path}")
+
+             results[ctx] = fig
+
+         return results
+
+     def report_and_plot(self, outdir=None, file_format="png", do_show=False, tag=""):
+         """Generate the profiling report and plot the formatted data.
+
+         Args:
+             outdir (str): Directory to save figures. If None, figures are not saved.
+             file_format (str): Target file format, "png" or "svg". Default is "png".
+             do_show (bool): Whether to display the plots. Default is False.
+         """
+         report = self.get_report_dict()
+         return self.plot_formatted_data(
+             report, outdir=outdir, file_format=file_format, do_show=do_show, tag=tag
+         )
+
+     def meta_info(self):
+         """Print the structure of the profiler's time dictionary.
+
+         Useful for debugging and understanding the profiler's internal state.
+         """
+         for ctx_name, ctx_dict in self.time_dict.items():
+             with ConsoleLog(f"Context: {ctx_name}"):
+                 for step_name in ctx_dict["step_dict"]:
+                     pprint(f"Step: {step_name}")
+
+     def save_report_dict(self, output_file, with_detail=False):
+         try:
+             report = self.get_report_dict(with_detail=with_detail)
+             with open(output_file, "w") as f:
+                 json.dump(report, f, indent=4)
+         except Exception as e:
+             logger.error(f"Failed to save report to {output_file}: {e}")
halib/research/torchloader.py
@@ -0,0 +1,162 @@
+ """
+ * @author Hoang Van-Ha
+ * @email hoangvanhauit@gmail.com
+ * @create date 2024-03-27 15:40:22
+ * @modify date 2024-03-27 15:40:22
+ * @desc this module is a utility tool for finding the dataloader configuration (num_workers, batch_size, pin_memory, etc.) that best fits your hardware.
+ """
+ from argparse import ArgumentParser
+ from ..common import *
+ from ..filetype import csvfile
+ from ..filetype.yamlfile import load_yaml
+ from rich import inspect
+ from torch.utils.data import DataLoader
+ from torchvision import datasets, transforms
+ from tqdm import tqdm
+ from typing import Union
+ import itertools as it  # for cartesian product
+ import os
+ import time
+ import traceback
+
+
+ def parse_args():
+     parser = ArgumentParser(description="Search for the DataLoader configuration that best fits your hardware")
+     parser.add_argument("-cfg", "--cfg", type=str, help="YAML config file describing the search space")
+     return parser.parse_args()
+
+
+ def get_test_range(cfg: dict, search_item="num_workers"):
+     item_search_cfg = cfg["search_space"].get(search_item, None)
+     if item_search_cfg is None:
+         raise ValueError(f"search_item: {search_item} not found in cfg")
+     if isinstance(item_search_cfg, list):
+         return item_search_cfg
+     elif isinstance(item_search_cfg, dict):
+         if "mode" in item_search_cfg:
+             mode = item_search_cfg["mode"]
+             assert mode in ["range", "list"], f"mode: {mode} not supported"
+             value_in_mode = item_search_cfg.get(mode, None)
+             if value_in_mode is None:
+                 raise ValueError(f"mode<{mode}>: data not found in <{search_item}>")
+             if mode == "range":
+                 assert len(value_in_mode) == 3, "range must have 3 values: start, stop, step"
+                 start, stop, step = value_in_mode
+                 return list(range(start, stop, step))
+             elif mode == "list":
+                 return item_search_cfg["list"]
+     else:
+         return [item_search_cfg]  # for int, float, str, bool, etc.
+
+
+ def load_an_batch(loader_iter):
+     start = time.time()
+     next(loader_iter)
+     end = time.time()
+     return end - start
+
+
+ def test_dataloader_with_cfg(origin_dataloader: DataLoader, cfg: Union[dict, str]):
+     try:
+         if isinstance(cfg, str):
+             cfg = load_yaml(cfg, to_dict=True)
+         dfmk = csvfile.DFCreator()
+         search_items = ["batch_size", "num_workers", "persistent_workers", "pin_memory"]
+         batch_limit = cfg["general"]["batch_limit"]
+         csv_cfg = cfg["general"]["to_csv"]
+         log_batch_info = cfg["general"]["log_batch_info"]
+
+         save_to_csv = csv_cfg["enabled"]
+         log_dir = csv_cfg["log_dir"]
+         filename = csv_cfg["filename"]
+         filename = f"{now_str()}_{filename}.csv"
+         outfile = os.path.join(log_dir, filename)
+
+         dfmk.create_table(
+             "cfg_search",
+             (search_items + ["avg_time_taken"]),
+         )
+         ls_range_test = []
+         for item in search_items:
+             range_test = get_test_range(cfg, search_item=item)
+             range_test = [(item, i) for i in range_test]
+             ls_range_test.append(range_test)
+
+         all_combinations = list(it.product(*ls_range_test))
+
+         rows = []
+         for cfg_idx, combine in enumerate(all_combinations):
+             console.rule(f"Testing cfg {cfg_idx+1}/{len(all_combinations)}")
+             inspect(combine)
+             batch_size = combine[search_items.index("batch_size")][1]
+             num_workers = combine[search_items.index("num_workers")][1]
+             persistent_workers = combine[search_items.index("persistent_workers")][1]
+             pin_memory = combine[search_items.index("pin_memory")][1]
+
+             test_dataloader = DataLoader(
+                 origin_dataloader.dataset,
+                 batch_size=batch_size,
+                 num_workers=num_workers,
+                 persistent_workers=persistent_workers,
+                 pin_memory=pin_memory,
+                 shuffle=True,
+             )
+             row = [
+                 batch_size,
+                 num_workers,
+                 persistent_workers,
+                 pin_memory,
+                 0.0,
+             ]
+
+             # calculate the avg time taken to load the data for <batch_limit> batches
+             trainiter = iter(test_dataloader)
+             time_elapsed = 0
+             pprint("Start testing...")
+             for i in tqdm(range(batch_limit)):
+                 single_batch_time = load_an_batch(trainiter)
+                 if log_batch_info:
+                     pprint(f"Batch {i+1} took {single_batch_time:.4f} seconds to load")
+                 time_elapsed += single_batch_time
+             row[-1] = time_elapsed / batch_limit
+             rows.append(row)
+         dfmk.insert_rows("cfg_search", rows)
+         dfmk.fill_table_from_row_pool("cfg_search")
+         with ConsoleLog("results"):
+             csvfile.fn_display_df(dfmk["cfg_search"])
+         if save_to_csv:
+             dfmk["cfg_search"].to_csv(outfile, index=False)
+             console.print(f"[red] Data saved to <{outfile}> [/red]")
+
+     except Exception as e:
+         traceback.print_exc()
+         print(e)
+         # get current directory of this python file
+         current_dir = os.path.dirname(os.path.realpath(__file__))
+         standard_cfg_path = os.path.join(current_dir, "torchloader_search.yaml")
+         pprint(
+             f"Make sure you get the right <cfg.yaml> file. An example of <cfg.yaml> can be found at: {standard_cfg_path}"
+         )
+         return
+
+
+ def main():
+     args = parse_args()
+     cfg_yaml = args.cfg
+     cfg_dict = load_yaml(cfg_yaml, to_dict=True)
+
+     # Define transforms for data augmentation and normalization
+     transform = transforms.Compose(
+         [
+             transforms.RandomHorizontalFlip(),  # randomly flip images horizontally
+             transforms.RandomRotation(10),  # randomly rotate images by 10 degrees
+             transforms.ToTensor(),  # convert images to PyTorch tensors
+             transforms.Normalize(
+                 (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
+             ),  # normalize pixel values to [-1, 1]
+         ]
+     )
+     test_dataset = datasets.CIFAR10(
+         root="./data", train=False, download=True, transform=transform
+     )
+     batch_size = 64
+     train_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
+     test_dataloader_with_cfg(train_loader, cfg_dict)
+
+
+ if __name__ == "__main__":
+     main()
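The module reads its search space from a YAML file (see torchloader_search.yaml referenced in the error handler). Based on the keys consumed above (general.batch_limit, general.to_csv, general.log_batch_info, and the search_space entries handled by get_test_range), a config equivalent to what load_yaml(..., to_dict=True) returns would look roughly like this sketch; the concrete values are illustrative, not shipped defaults:

    cfg = {
        "general": {
            "batch_limit": 20,        # batches timed per candidate configuration
            "log_batch_info": False,  # print per-batch load times
            "to_csv": {"enabled": True, "log_dir": "./zout", "filename": "loader_search"},
        },
        "search_space": {
            # dict with mode "list": the listed values are tested
            "batch_size": {"mode": "list", "list": [32, 64, 128]},
            # dict with mode "range": start, stop, step (stop is exclusive)
            "num_workers": {"mode": "range", "range": [0, 9, 2]},
            # a plain list is used as-is
            "persistent_workers": [True, False],
            # a scalar is wrapped into a single-value range
            "pin_memory": True,
        },
    }

Every combination in the Cartesian product of these ranges is benchmarked, so 3 x 5 x 2 x 1 = 30 DataLoader configurations would be timed here.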
halib/research/wandb_op.py
@@ -0,0 +1,116 @@
+ import glob
+ from rich.pretty import pprint
+ import os
+ import subprocess
+ import argparse
+ import wandb
+ from tqdm import tqdm
+ from rich.console import Console
+
+ console = Console()
+
+
+ def sync_runs(outdir):
+     outdir = os.path.abspath(outdir)
+     assert os.path.exists(outdir), f"Output directory {outdir} does not exist."
+     sub_dirs = [
+         name for name in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, name))
+     ]
+     assert len(sub_dirs) > 0, f"No subdirectories found in {outdir}."
+     console.rule("Parent Directory")
+     console.print(f"[yellow]{outdir}[/yellow]")
+
+     exp_dirs = [os.path.join(outdir, sub_dir) for sub_dir in sub_dirs]
+     wandb_dirs = []
+     for exp_dir in exp_dirs:
+         wandb_dirs.extend(glob.glob(f"{exp_dir}/wandb/*run-*"))
+     if len(wandb_dirs) == 0:
+         console.print(f"No wandb runs found in {outdir}.")
+         return
+     else:
+         console.print(f"Found [bold]{len(wandb_dirs)}[/bold] wandb runs in {outdir}.")
+     for i, wandb_dir in enumerate(wandb_dirs):
+         console.rule(f"Syncing wandb run {i + 1}/{len(wandb_dirs)}")
+         console.print(f"Syncing: {wandb_dir}")
+         process = subprocess.Popen(
+             ["wandb", "sync", wandb_dir],
+             stdout=subprocess.PIPE,
+             stderr=subprocess.STDOUT,
+             text=True,
+         )
+
+         for line in process.stdout:
+             console.print(line.strip())
+             if " ERROR Error while calling W&B API" in line:
+                 break
+         process.stdout.close()
+         process.wait()
+         if process.returncode != 0:
+             console.print(
+                 f"[red]Error syncing {wandb_dir}. Return code: {process.returncode}[/red]"
+             )
+         else:
+             console.print(f"Successfully synced {wandb_dir}.")
+
+
+ def delete_runs(project, pattern=None):
+     console.rule("Delete W&B Runs")
+     confirm_msg = "Are you sure you want to delete all runs in"
+     confirm_msg += f" \n\tproject: [red]{project}[/red]"
+     if pattern:
+         confirm_msg += f"\n\tpattern: [blue]{pattern}[/blue]"
+
+     console.print(confirm_msg)
+     confirmation = input("This action cannot be undone. [y/N]: ").strip().lower()
+     if confirmation != "y":
+         print("Cancelled.")
+         return
+
+     print("Confirmed. Proceeding...")
+     api = wandb.Api()
+     runs = api.runs(project)
+
+     deleted = 0
+     console.rule("Deleting W&B Runs")
+     if len(runs) == 0:
+         print("No runs found in the project.")
+         return
+     for run in tqdm(runs):
+         if pattern is None or pattern in run.name:
+             run.delete()
+             console.print(f"Deleted run: [red]{run.name}[/red]")
+             deleted += 1
+
+     console.print(f"Total runs deleted: {deleted}")
+
+
+ def valid_argument(args):
+     if args.op == "sync":
+         assert os.path.exists(args.outdir), f"Output directory {args.outdir} does not exist."
+     elif args.op == "delete":
+         assert (
+             isinstance(args.project, str) and len(args.project.strip()) > 0
+         ), "Project name must be a non-empty string."
+     else:
+         raise ValueError(f"Unknown operation: {args.op}")
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Operations on W&B runs")
+     parser.add_argument("-op", "--op", type=str, help="Operation to perform", default="sync", choices=["delete", "sync"])
+     parser.add_argument("-prj", "--project", type=str, default="fire-paper2-2025", help="W&B project name")
+     parser.add_argument("-outdir", "--outdir", type=str, help="Parent directory containing offline wandb runs to sync", default="./zout/train")
+     parser.add_argument(
+         "-pt", "--pattern",
+         type=str,
+         default=None,
+         help="Run name pattern to match for deletion",
+     )
+
+     return parser.parse_args()
+
+
+ def main():
+     args = parse_args()
+     # Validate arguments, stop if invalid
+     valid_argument(args)
+
+     op = args.op
+     if op == "sync":
+         sync_runs(args.outdir)
+     elif op == "delete":
+         delete_runs(args.project, args.pattern)
+     else:
+         raise ValueError(f"Unknown operation: {op}")
+
+
+ if __name__ == "__main__":
+     main()
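For reference, a minimal sketch of invoking these operations programmatically rather than through the CLI (the import path assumes the file lands at halib/research/wandb_op.py as listed in the file table; the pattern string is illustrative):

    from halib.research.wandb_op import sync_runs, delete_runs

    # Re-upload all offline runs found under ./zout/train/*/wandb/
    sync_runs("./zout/train")

    # Interactively delete runs in the default project whose name contains "debug"
    delete_runs("fire-paper2-2025", pattern="debug")

The same two paths are reachable from the command line via `-op sync -outdir ./zout/train` and `-op delete -prj <project> -pt <pattern>`.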