halib 0.1.7__py3-none-any.whl → 0.1.99__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- halib/__init__.py +84 -0
- halib/common.py +151 -0
- halib/cuda.py +39 -0
- halib/dataset.py +209 -0
- halib/filetype/csvfile.py +151 -45
- halib/filetype/ipynb.py +63 -0
- halib/filetype/jsonfile.py +1 -1
- halib/filetype/textfile.py +4 -4
- halib/filetype/videofile.py +44 -33
- halib/filetype/yamlfile.py +95 -0
- halib/gdrive.py +1 -1
- halib/online/gdrive.py +104 -54
- halib/online/gdrive_mkdir.py +29 -17
- halib/online/gdrive_test.py +31 -18
- halib/online/projectmake.py +58 -43
- halib/plot.py +296 -11
- halib/projectmake.py +1 -1
- halib/research/__init__.py +0 -0
- halib/research/base_config.py +100 -0
- halib/research/base_exp.py +100 -0
- halib/research/benchquery.py +131 -0
- halib/research/dataset.py +208 -0
- halib/research/flop_csv.py +34 -0
- halib/research/flops.py +156 -0
- halib/research/metrics.py +133 -0
- halib/research/mics.py +68 -0
- halib/research/params_gen.py +108 -0
- halib/research/perfcalc.py +336 -0
- halib/research/perftb.py +780 -0
- halib/research/plot.py +758 -0
- halib/research/profiler.py +300 -0
- halib/research/torchloader.py +162 -0
- halib/research/wandb_op.py +116 -0
- halib/rich_color.py +285 -0
- halib/sys/filesys.py +17 -10
- halib/system/__init__.py +0 -0
- halib/system/cmd.py +8 -0
- halib/system/filesys.py +124 -0
- halib/tele_noti.py +166 -0
- halib/torchloader.py +162 -0
- halib/utils/__init__.py +0 -0
- halib/utils/dataclass_util.py +40 -0
- halib/utils/dict_op.py +9 -0
- halib/utils/gpu_mon.py +58 -0
- halib/utils/listop.py +13 -0
- halib/utils/tele_noti.py +166 -0
- halib/utils/video.py +82 -0
- halib/videofile.py +1 -1
- halib-0.1.99.dist-info/METADATA +209 -0
- halib-0.1.99.dist-info/RECORD +64 -0
- {halib-0.1.7.dist-info → halib-0.1.99.dist-info}/WHEEL +1 -1
- halib-0.1.7.dist-info/METADATA +0 -59
- halib-0.1.7.dist-info/RECORD +0 -30
- {halib-0.1.7.dist-info → halib-0.1.99.dist-info/licenses}/LICENSE.txt +0 -0
- {halib-0.1.7.dist-info → halib-0.1.99.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import time
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from pprint import pprint
|
|
7
|
+
from threading import Lock
|
|
8
|
+
|
|
9
|
+
from plotly.subplots import make_subplots
|
|
10
|
+
import plotly.graph_objects as go
|
|
11
|
+
import plotly.express as px # for dynamic color scales
|
|
12
|
+
from ..common import ConsoleLog
|
|
13
|
+
|
|
14
|
+
from loguru import logger
|
|
15
|
+
|
|
16
|
+
class zProfiler:
    """A process-wide singleton profiler for named contexts and their steps.

    Timestamps come from ``time.perf_counter()``. Every call to ``zProfiler()``
    returns the same instance, so timings recorded anywhere in the process
    accumulate in one shared ``time_dict``. The constructor takes no arguments.

    Example:
        prof = zProfiler()
        prof.ctx_start("my_context")
        prof.step_start("my_context", "step1")
        time.sleep(0.1)
        prof.step_end("my_context", "step1")
        prof.ctx_end("my_context")
        report = prof.get_report_dict()
    """

    _instance = None
    _lock = Lock()

    def __new__(cls, *args, **kwargs):
        # Serialise first-time creation so concurrent callers still share one object.
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(
        self,
    ):
        # __init__ runs on every zProfiler() call; initialise state only once so
        # later calls do not wipe timings that were already recorded.
        if not hasattr(self, "_initialized"):
            # Layout per context (see ctx_start / ctx_end / step_start / step_end):
            #   {ctx_name: {"start": float, "step_dict": {step: [[start, end], ...]},
            #               "report_count": int, "end": float, "duration": float}}
            # "end" and "duration" only exist after ctx_end has been called.
            self.time_dict = {}
            self._initialized = True

    def ctx_start(self, ctx_name="ctx_default"):
        """Register a context (first call) and count how often it is started.

        The start timestamp is taken only the first time ``ctx_name`` is seen;
        later calls keep that start and just increment ``report_count``.

        Raises:
            ValueError: if ``ctx_name`` is not a non-empty string.
        """
        if not isinstance(ctx_name, str) or not ctx_name:
            raise ValueError("ctx_name must be a non-empty string")
        if ctx_name not in self.time_dict:
            self.time_dict[ctx_name] = {
                "start": time.perf_counter(),
                "step_dict": {},
                "report_count": 0,
            }
        self.time_dict[ctx_name]["report_count"] += 1

    def ctx_end(self, ctx_name="ctx_default", report_func=None):
        """Record the end time of a context and its duration since the first ctx_start.

        Unknown contexts are ignored. ``report_func`` is accepted but unused.
        """
        if ctx_name not in self.time_dict:
            return
        ctx = self.time_dict[ctx_name]
        ctx["end"] = time.perf_counter()
        ctx["duration"] = ctx["end"] - ctx["start"]

    def step_start(self, ctx_name, step_name):
        """Open a new timing interval for ``step_name`` inside ``ctx_name``.

        Does nothing if the context has not been started.

        Raises:
            ValueError: if ``step_name`` is not a non-empty string.
        """
        if not isinstance(step_name, str) or not step_name:
            raise ValueError("step_name must be a non-empty string")
        if ctx_name not in self.time_dict:
            return
        step_dict = self.time_dict[ctx_name]["step_dict"]
        step_dict.setdefault(step_name, []).append([time.perf_counter()])

    def step_end(self, ctx_name, step_name):
        """Close the most recent open interval of ``step_name`` in ``ctx_name``.

        Unknown contexts/steps are ignored, and so is a repeated call on an
        interval that is already closed.
        """
        ctx = self.time_dict.get(ctx_name)
        if ctx is None or step_name not in ctx["step_dict"]:
            return
        last_interval = ctx["step_dict"][step_name][-1]
        # A second step_end would otherwise append a third timestamp, and the
        # interval would then be dropped from every report (len != 2).
        if len(last_interval) == 1:
            last_interval.append(time.perf_counter())

    def _step_dict_to_detail(self, ctx_step_dict):
        """Convert raw step timestamps into per-call elapsed times.

        Example input::

            {
                "preprocess": [[278090.947, 278090.960], [278091.178, 278091.231]],
                "infer":      [[278090.960, 278091.178], [278091.231, 278091.251]],
            }

        Returns:
            dict mapping each step name to a list of ``(call_index, elapsed_seconds)``
            tuples. Intervals that were never closed get ``-1``. An empty input
            yields an empty dict.

        Raises:
            ValueError: if a step's data is not a list.
        """
        normed_ctx_step_dict = {}
        for step_name, time_list in ctx_step_dict.items():
            if not isinstance(time_list, list):
                raise ValueError(f"Step data for {step_name} must be a list")
            normed_ctx_step_dict[step_name] = [
                (idx, time_data[1] - time_data[0] if len(time_data) == 2 else -1)
                for idx, time_data in enumerate(time_list)
            ]
        return normed_ctx_step_dict

    def get_report_dict(self, with_detail=False):
        """Summarise recorded timings for every context.

        Result shape per context::

            {ctx_name: {
                "duration": float,   # 0.0 if ctx_end was never called
                "step_dict": {
                    "summary": {
                        "avg_time": {"avg_<step>": seconds, ...},
                        "percent_time": {"per_<step>": percent, ...},
                        "total_avg_time": seconds,
                    },
                    "detail": {...},  # only filled when with_detail=True
                },
            }}

        Only closed (start, end) intervals are averaged. Steps whose average
        is below 1e-5 s are left out of the summary.
        """
        epsilon = 1e-5  # steps faster than this are treated as noise and skipped
        report_dict = {}
        for ctx_name, ctx_dict in self.time_dict.items():
            detail = (
                self._step_dict_to_detail(ctx_dict["step_dict"]) if with_detail else {}
            )

            avg_time_list = []
            for step_name, step_list in ctx_dict["step_dict"].items():
                try:
                    durations = [
                        time_data[1] - time_data[0]
                        for time_data in step_list
                        if len(time_data) == 2
                    ]
                except Exception as e:
                    logger.error(
                        f"Error processing step {step_name} in context {ctx_name}: {e}"
                    )
                    continue
                if not durations:
                    continue
                avg_time = sum(durations) / len(durations)
                if avg_time < epsilon:
                    continue
                avg_time_list.append((step_name, avg_time))

            # Tiny positive fallback keeps the percentage division safe.
            total_avg_time = sum(t for _, t in avg_time_list) or 1e-10
            summary = {"avg_time": {}, "percent_time": {}}
            for step_name, avg_time in avg_time_list:
                summary["percent_time"][f"per_{step_name}"] = (
                    avg_time / total_avg_time
                ) * 100.0
                summary["avg_time"][f"avg_{step_name}"] = avg_time
            summary["total_avg_time"] = total_avg_time

            report_dict[ctx_name] = {
                "duration": ctx_dict.get("duration", 0.0),
                "step_dict": {
                    "summary": dict(sorted(summary.items())),
                    "detail": detail,
                },
            }
        return report_dict

    @classmethod
    def plot_formatted_data(
        cls, profiler_data, outdir=None, file_format="png", do_show=False, tag=""
    ):
        """Plot each context as one figure: bar chart of average step time (ms)
        next to a donut chart of each step's share of the total.

        Args:
            profiler_data (dict): report in the format returned by ``get_report_dict``.
            outdir (str | None): if given, it is created when missing and each
                figure is saved there as ``[<tag>_]<ctx>_summary.<file_format>``.
            file_format (str): "png" or "svg" (case-insensitive).
            do_show (bool): call ``fig.show()`` for each figure.
            tag (str): optional label prefixed to titles and file names.

        Returns:
            dict mapping context name to its plotly figure. Contexts without any
            summarised step are skipped with a warning.

        Raises:
            ValueError: if ``file_format`` is not "png" or "svg".
        """
        file_format = file_format.lower()
        if file_format not in ("png", "svg"):
            raise ValueError("file_format must be 'png' or 'svg'")
        if outdir is not None:
            os.makedirs(outdir, exist_ok=True)

        tag_str = tag or ""
        results = {}  # {context: fig}

        for ctx, ctx_data in profiler_data.items():
            summary = ctx_data["step_dict"]["summary"]
            avg_times = summary["avg_time"]
            percent_times = summary["percent_time"]

            # Keys look like "avg_<step>": strip only the leading prefix so step
            # names that themselves contain "avg_" stay intact.
            step_names = []
            avg_ms = []
            for key, seconds in avg_times.items():
                step_names.append(key[len("avg_"):] if key.startswith("avg_") else key)
                avg_ms.append(seconds * 1000.0)
            n_steps = len(step_names)
            if n_steps == 0:
                logger.warning(f"No steps found for context: {ctx}; skipping plot.")
                continue
            # Look percentages up by name rather than relying on dict order.
            percents = [percent_times.get(f"per_{s}", 0.0) for s in step_names]

            # One Viridis colour per step, shared by the bar and pie traces.
            positions = (
                [i / (n_steps - 1) for i in range(n_steps)] if n_steps > 1 else [0]
            )
            step_colors = px.colors.sample_colorscale("Viridis", positions)

            fig = make_subplots(
                rows=1,
                cols=2,
                subplot_titles=["Avg Time", "% Time"],
                specs=[[{"type": "bar"}, {"type": "pie"}]],
            )

            # Bar chart: values in milliseconds to match the axis title.
            fig.add_trace(
                go.Bar(
                    x=step_names,
                    y=avg_ms,
                    text=[f"{v:.2f} ms" for v in avg_ms],
                    textposition="outside",
                    marker=dict(color=step_colors),
                    name="",
                    showlegend=False,
                ),
                row=1,
                col=1,
            )

            # Donut chart (colours match the bars).
            fig.add_trace(
                go.Pie(
                    labels=step_names,
                    values=percents,
                    marker=dict(colors=step_colors),
                    hole=0.4,
                    name="",
                    showlegend=True,
                ),
                row=1,
                col=2,
            )

            title_prefix = f"[{tag_str}] " if tag_str else ""
            fig.update_layout(
                title_text=f"{title_prefix}Context Profiler: {ctx}",
                width=1000,
                height=400,
                showlegend=True,
                legend=dict(title="Steps", x=1.05, y=0.5, traceorder="normal"),
                hovermode="x unified",
            )
            fig.update_xaxes(title_text="Steps", row=1, col=1)
            fig.update_yaxes(title_text="Avg Time (ms)", row=1, col=1)

            if do_show:
                fig.show()

            if outdir is not None:
                file_prefix = f"{tag_str}_{ctx}" if tag_str else ctx
                file_path = os.path.join(outdir, f"{file_prefix}_summary.{file_format}")
                fig.write_image(file_path)
                print(f"Saved figure: {file_path}")

            results[ctx] = fig

        return results

    def report_and_plot(self, outdir=None, file_format="png", do_show=False, tag=""):
        """Build the summary report and plot it with ``plot_formatted_data``.

        Args:
            outdir (str): directory to save figures in; if None nothing is saved.
            file_format (str): "png" or "svg". Default is "png".
            do_show (bool): whether to display the plots. Default is False.
            tag (str): optional label prefixed to figure titles and file names.

        Returns:
            dict mapping context name to its plotly figure.
        """
        report = self.get_report_dict(with_detail=False)
        return self.plot_formatted_data(
            report, outdir=outdir, file_format=file_format, do_show=do_show, tag=tag
        )

    def meta_info(self):
        """Print, per context, the names of the steps recorded so far.

        Useful for debugging the profiler's internal state.
        """
        for ctx_name, ctx_dict in self.time_dict.items():
            with ConsoleLog(f"Context: {ctx_name}"):
                for step_name in ctx_dict["step_dict"]:
                    pprint(f"Step: {step_name}")

    def save_report_dict(self, output_file, with_detail=False):
        """Write ``get_report_dict(with_detail)`` to ``output_file`` as indented JSON.

        Any failure (e.g. an unwritable path) is logged with ``logger.error``
        instead of being raised.
        """
        try:
            report = self.get_report_dict(with_detail=with_detail)
            with open(output_file, "w", encoding="utf-8") as f:
                json.dump(report, f, indent=4)
        except Exception as e:
            logger.error(f"Failed to save report to {output_file}: {e}")
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
"""
|
|
2
|
+
* @author Hoang Van-Ha
|
|
3
|
+
* @email hoangvanhauit@gmail.com
|
|
4
|
+
* @create date 2024-03-27 15:40:22
|
|
5
|
+
* @modify date 2024-03-27 15:40:22
|
|
6
|
+
* @desc Utility for finding the DataLoader configuration (num_workers, batch_size, pin_memory, etc.) that best fits your hardware.
|
|
7
|
+
"""
|
|
8
|
+
from argparse import ArgumentParser
|
|
9
|
+
from ..common import *
|
|
10
|
+
from ..filetype import csvfile
|
|
11
|
+
from ..filetype.yamlfile import load_yaml
|
|
12
|
+
from rich import inspect
|
|
13
|
+
from torch.utils.data import DataLoader
|
|
14
|
+
from torchvision import datasets, transforms
|
|
15
|
+
from tqdm import tqdm
|
|
16
|
+
from typing import Union
|
|
17
|
+
import itertools as it # for cartesian product
|
|
18
|
+
import os
|
|
19
|
+
import time
|
|
20
|
+
import traceback
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def parse_args():
    """Read the command line for the DataLoader configuration search.

    Returns:
        argparse.Namespace with one attribute, ``cfg``: the path of the search
        YAML file, or None when the option is not given.
    """
    cli = ArgumentParser(description="desc text")
    cli.add_argument(
        "-cfg",
        "--cfg",
        type=str,
        help="cfg file for searching",
    )
    parsed = cli.parse_args()
    return parsed
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def get_test_range(cfg: dict, search_item="num_workers"):
    """Return the candidate values to try for one DataLoader option.

    ``cfg["search_space"][search_item]`` may be:
      * a list              -> returned as-is;
      * a dict with ``"mode": "range"`` and ``"range": [start, stop, step]``
                            -> ``list(range(start, stop, step))``;
      * a dict with ``"mode": "list"`` and ``"list": [...]``
                            -> that list;
      * any other value (int, float, str, bool, ...)
                            -> wrapped in a one-element list.

    Raises:
        ValueError: if the item is missing or None, a dict spec has no or an
            unsupported ``"mode"``, the data for the mode is missing, or a
            range spec does not have exactly three values.
    """
    item_search_cfg = cfg["search_space"].get(search_item, None)
    if item_search_cfg is None:
        raise ValueError(f"search_item: {search_item} not found in cfg")
    if isinstance(item_search_cfg, list):
        return item_search_cfg
    if not isinstance(item_search_cfg, dict):
        return [item_search_cfg]  # for int, float, str, bool, etc.

    # A dict spec must say how its values are given; otherwise the function
    # would implicitly return None and break the caller's iteration.
    if "mode" not in item_search_cfg:
        raise ValueError(
            f"<{search_item}>: dict config requires a 'mode' key ('range' or 'list')"
        )
    mode = item_search_cfg["mode"]
    if mode not in ("range", "list"):
        raise ValueError(f"mode: {mode} not supported")
    value_in_mode = item_search_cfg.get(mode, None)
    if value_in_mode is None:
        raise ValueError(f"mode<{mode}>: data not found in <{search_item}>")

    if mode == "range":
        if len(value_in_mode) != 3:
            raise ValueError("range must have 3 values: start, stop, step")
        start, stop, step = value_in_mode
        return list(range(start, stop, step))
    return value_in_mode
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def load_an_batch(loader_iter):
    """Fetch one item (batch) from ``loader_iter`` and return the time it took.

    Elapsed time is measured with ``time.perf_counter`` (monotonic, high
    resolution) rather than ``time.time``, which can jump when the system
    clock is adjusted.

    Returns:
        float: seconds spent in ``next(loader_iter)``.

    Raises:
        StopIteration: propagated when the iterator is exhausted.
    """
    start = time.perf_counter()
    next(loader_iter)
    return time.perf_counter() - start
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _measure_avg_batch_time(dataloader, batch_limit, log_batch_info=False):
    """Return the mean seconds per batch over up to ``batch_limit`` batches.

    Stops early if the loader runs out of batches and averages over the
    batches actually fetched; returns 0.0 when no batch was fetched.
    """
    loader_iter = iter(dataloader)
    total_time = 0.0
    n_loaded = 0
    pprint("Start testing...")
    for i in tqdm(range(batch_limit)):
        try:
            single_batch_time = load_an_batch(loader_iter)
        except StopIteration:
            pprint(f"Loader exhausted after {n_loaded} batches (batch_limit={batch_limit})")
            break
        if log_batch_info:
            pprint(f"Batch {i+1} took {single_batch_time:.4f} seconds to load")
        total_time += single_batch_time
        n_loaded += 1
    return total_time / n_loaded if n_loaded else 0.0


def test_dataloader_with_cfg(origin_dataloader: DataLoader, cfg: Union[dict, str]):
    """Grid-search DataLoader settings and report the average batch-load time.

    Every combination of ``batch_size``, ``num_workers``, ``persistent_workers``
    and ``pin_memory`` from ``cfg["search_space"]`` (see ``get_test_range``) is
    used to build a new DataLoader over ``origin_dataloader.dataset``. For each
    one, the mean time to fetch up to ``cfg["general"]["batch_limit"]`` batches
    is recorded. Results are shown as a table and, when
    ``cfg["general"]["to_csv"]["enabled"]`` is true, written to
    ``<log_dir>/<now_str()>_<filename>.csv``.

    Args:
        origin_dataloader: loader whose ``dataset`` is reused for every trial.
        cfg: the search configuration as a dict, or a path to a YAML file.

    Any exception is caught: the traceback is printed together with a hint
    about the expected config file, and the function returns None.
    """
    try:
        if isinstance(cfg, str):
            cfg = load_yaml(cfg, to_dict=True)
        dfmk = csvfile.DFCreator()
        search_items = ["batch_size", "num_workers", "persistent_workers", "pin_memory"]

        general_cfg = cfg["general"]
        batch_limit = general_cfg["batch_limit"]
        log_batch_info = general_cfg["log_batch_info"]
        csv_cfg = general_cfg["to_csv"]
        save_to_csv = csv_cfg["enabled"]
        log_dir = csv_cfg["log_dir"]
        outfile = os.path.join(log_dir, f"{now_str()}_{csv_cfg['filename']}.csv")

        dfmk.create_table("cfg_search", search_items + ["avg_time_taken"])

        # One list of (option, value) pairs per option; their cartesian product
        # enumerates every configuration to try.
        ls_range_test = [
            [(item, value) for value in get_test_range(cfg, search_item=item)]
            for item in search_items
        ]
        all_combinations = list(it.product(*ls_range_test))

        rows = []
        for cfg_idx, combine in enumerate(all_combinations):
            console.rule(f"Testing cfg {cfg_idx+1}/{len(all_combinations)}")
            inspect(combine)
            params = dict(combine)
            batch_size = params["batch_size"]
            num_workers = params["num_workers"]
            persistent_workers = params["persistent_workers"]
            pin_memory = params["pin_memory"]

            # torch's DataLoader raises ValueError for persistent_workers=True with
            # num_workers == 0; skip that combination instead of aborting the search.
            if persistent_workers and num_workers == 0:
                pprint("Skipping: persistent_workers=True requires num_workers > 0")
                continue

            test_dataloader = DataLoader(
                origin_dataloader.dataset,
                batch_size=batch_size,
                num_workers=num_workers,
                persistent_workers=persistent_workers,
                pin_memory=pin_memory,
                shuffle=True,
            )
            avg_time = _measure_avg_batch_time(test_dataloader, batch_limit, log_batch_info)
            rows.append([batch_size, num_workers, persistent_workers, pin_memory, avg_time])

        dfmk.insert_rows("cfg_search", rows)
        dfmk.fill_table_from_row_pool("cfg_search")
        with ConsoleLog("results"):
            csvfile.fn_display_df(dfmk["cfg_search"])
        if save_to_csv:
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            dfmk["cfg_search"].to_csv(outfile, index=False)
            console.print(f"[red] Data saved to <{outfile}> [/red]")

    except Exception as e:
        traceback.print_exc()
        print(e)
        # Point the user at the example config shipped next to this module.
        current_dir = os.path.dirname(os.path.realpath(__file__))
        example_cfg_path = os.path.join(current_dir, "torchloader_search.yaml")
        pprint(
            f"Make sure you get the right <cfg.yaml> file. An example of <cfg.yaml> file can be found at this path: {example_cfg_path}"
        )
|
|
136
|
+
|
|
137
|
+
def main():
    """Script entry point: benchmark DataLoader settings on the CIFAR-10 test split.

    Loads the search YAML given by ``--cfg``, builds a shuffled CIFAR-10 test-set
    loader (downloaded into ./data if absent, batch size 64) and passes both to
    ``test_dataloader_with_cfg``.
    """
    cli_args = parse_args()
    search_cfg = load_yaml(cli_args.cfg, to_dict=True)

    # Augmentation followed by conversion to tensors and normalisation.
    augment_and_normalize = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),  # random horizontal flip
            transforms.RandomRotation(10),  # random rotation in [-10, 10] degrees
            transforms.ToTensor(),  # image -> float tensor in [0, 1]
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
        ]
    )
    cifar_test = datasets.CIFAR10(
        root="./data", train=False, download=True, transform=augment_and_normalize
    )
    reference_loader = DataLoader(cifar_test, batch_size=64, shuffle=True)
    test_dataloader_with_cfg(reference_loader, search_cfg)


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
import glob
|
|
2
|
+
from rich.pretty import pprint
|
|
3
|
+
import os
|
|
4
|
+
import subprocess
|
|
5
|
+
import argparse
|
|
6
|
+
import wandb
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
from rich.console import Console
|
|
9
|
+
# Module-wide rich console used for all status output in this script.
console = Console()
|
|
10
|
+
|
|
11
|
+
def sync_runs(outdir):
    """Upload every local W&B run found under ``outdir`` with ``wandb sync``.

    Runs are looked up at ``<outdir>/<sub_dir>/wandb/*run-*`` for each immediate
    sub-directory of ``outdir``, in sorted order. Each run is synced in its own
    ``wandb sync`` subprocess whose output is echoed to the console. If a line
    containing " ERROR Error while calling W&B API" appears, that subprocess is
    terminated and the next run is processed.

    Args:
        outdir: parent directory of the experiment folders.

    Raises:
        FileNotFoundError: if ``outdir`` is not an existing directory.
        ValueError: if ``outdir`` has no sub-directories.
    """
    outdir = os.path.abspath(outdir)
    if not os.path.isdir(outdir):
        raise FileNotFoundError(f"Output directory {outdir} does not exist.")
    sub_dirs = sorted(
        name for name in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, name))
    )
    if not sub_dirs:
        raise ValueError(f"No subdirectories found in {outdir}.")
    console.rule("Parent Directory")
    console.print(f"[yellow]{outdir}[/yellow]")

    wandb_dirs = []
    for sub_dir in sub_dirs:
        exp_dir = os.path.join(outdir, sub_dir)
        # Escape the directory part so names containing glob metacharacters
        # ("[", "]", "*", "?") are matched literally.
        pattern = os.path.join(glob.escape(exp_dir), "wandb", "*run-*")
        wandb_dirs.extend(sorted(glob.glob(pattern)))
    if not wandb_dirs:
        console.print(f"No wandb runs found in {outdir}.")
        return
    console.print(f"Found [bold]{len(wandb_dirs)}[/bold] wandb runs in {outdir}.")

    for i, wandb_dir in enumerate(wandb_dirs):
        console.rule(f"Syncing wandb run {i + 1}/{len(wandb_dirs)}")
        console.print(f"Syncing: {wandb_dir}")
        process = subprocess.Popen(
            ["wandb", "sync", wandb_dir],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
        with process.stdout:
            for line in process.stdout:
                console.print(line.strip())
                if " ERROR Error while calling W&B API" in line:
                    # Stop this sync explicitly instead of leaving the child
                    # running with nobody reading its output.
                    process.terminate()
                    break
        process.wait()
        if process.returncode != 0:
            console.print(f"[red]Error syncing {wandb_dir}. Return code: {process.returncode}[/red]")
        else:
            console.print(f"Successfully synced {wandb_dir}.")
|
|
48
|
+
|
|
49
|
+
def delete_runs(project, pattern=None):
    """Interactively delete runs from a W&B project.

    After a y/N confirmation prompt, every run in ``project`` whose name
    contains ``pattern`` is deleted (all runs when ``pattern`` is None).
    A failure to delete one run is reported and the remaining runs are
    still processed.

    Args:
        project: project path passed to ``wandb.Api().runs``.
        pattern: optional substring a run name must contain to be deleted.
    """
    console.rule("Delete W&B Runs")
    confirm_msg = "Are you sure you want to delete all runs in"
    confirm_msg += f" \n\tproject: [red]{project}[/red]"
    if pattern:
        confirm_msg += f"\n\tpattern: [blue]{pattern}[/blue]"

    console.print(confirm_msg)
    confirmation = input("This action cannot be undone. [y/N]: ").strip().lower()
    if confirmation != "y":
        print("Cancelled.")
        return

    print("Confirmed. Proceeding...")
    api = wandb.Api()
    runs = api.runs(project)

    console.rule("Deleting W&B Runs")
    if len(runs) == 0:
        print("No runs found in the project.")
        return

    # Select the targets before deleting anything so the lazily-paginated
    # ``runs`` collection is not modified while it is still being iterated.
    targets = [run for run in runs if pattern is None or pattern in run.name]

    deleted = 0
    for run in tqdm(targets):
        try:
            run.delete()
        except Exception as e:  # report and continue with the remaining runs
            console.print(f"[red]Failed to delete run {run.name}: {e}[/red]")
            continue
        console.print(f"Deleted run: [red]{run.name}[/red]")
        deleted += 1

    console.print(f"Total runs deleted: {deleted}")
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def valid_argument(args):
    """Validate parsed CLI arguments for the selected operation.

    Args:
        args: namespace with ``op`` and, depending on it, ``outdir`` ("sync")
            or ``project`` ("delete").

    Raises:
        ValueError: if ``args.outdir`` is not an existing directory (sync),
            ``args.project`` is not a non-empty string (delete), or
            ``args.op`` is unknown.
    """
    # Explicit raises rather than ``assert`` so validation survives ``python -O``.
    if args.op == "sync":
        if not os.path.isdir(args.outdir):
            raise ValueError(
                f"Output directory {args.outdir} does not exist or is not a directory."
            )
    elif args.op == "delete":
        if not (isinstance(args.project, str) and args.project.strip()):
            raise ValueError("Project name must be a non-empty string.")
    else:
        raise ValueError(f"Unknown operation: {args.op}")
|
|
87
|
+
|
|
88
|
+
def parse_args():
    """Parse command-line options for the W&B run utilities.

    Options:
        -op/--op:          "sync" (default) or "delete".
        -prj/--project:    W&B project used by "delete" (default "fire-paper2-2025").
        -outdir/--outdir:  parent directory scanned by "sync" (default "./zout/train").
        -pt/--pattern:     optional substring a run name must contain to be deleted.

    Returns:
        argparse.Namespace with ``op``, ``project``, ``outdir`` and ``pattern``.
    """
    parser = argparse.ArgumentParser(description="Operations on W&B runs")
    parser.add_argument(
        "-op", "--op", type=str, help="Operation to perform", default="sync", choices=["delete", "sync"]
    )
    parser.add_argument(
        "-prj", "--project", type=str, default="fire-paper2-2025", help="W&B project name"
    )
    parser.add_argument(
        "-outdir",
        "--outdir",
        type=str,
        default="./zout/train",
        help="Parent directory whose experiment sub-folders contain wandb runs to sync",
    )
    parser.add_argument(
        "-pt",
        "--pattern",
        type=str,
        default=None,
        help="Run name pattern to match for deletion",
    )
    return parser.parse_args()
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def main():
    """Script entry point: parse and validate the CLI arguments, then run the
    requested operation ("sync" -> sync_runs, "delete" -> delete_runs)."""
    cli = parse_args()
    valid_argument(cli)  # raises on invalid input, which stops the script

    handlers = {
        "sync": lambda: sync_runs(cli.outdir),
        "delete": lambda: delete_runs(cli.project, cli.pattern),
    }
    handler = handlers.get(cli.op)
    if handler is None:
        raise ValueError(f"Unknown operation: {cli.op}")
    handler()


if __name__ == "__main__":
    main()
|