halib 0.1.7__py3-none-any.whl → 0.1.99__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- halib/__init__.py +84 -0
- halib/common.py +151 -0
- halib/cuda.py +39 -0
- halib/dataset.py +209 -0
- halib/filetype/csvfile.py +151 -45
- halib/filetype/ipynb.py +63 -0
- halib/filetype/jsonfile.py +1 -1
- halib/filetype/textfile.py +4 -4
- halib/filetype/videofile.py +44 -33
- halib/filetype/yamlfile.py +95 -0
- halib/gdrive.py +1 -1
- halib/online/gdrive.py +104 -54
- halib/online/gdrive_mkdir.py +29 -17
- halib/online/gdrive_test.py +31 -18
- halib/online/projectmake.py +58 -43
- halib/plot.py +296 -11
- halib/projectmake.py +1 -1
- halib/research/__init__.py +0 -0
- halib/research/base_config.py +100 -0
- halib/research/base_exp.py +100 -0
- halib/research/benchquery.py +131 -0
- halib/research/dataset.py +208 -0
- halib/research/flop_csv.py +34 -0
- halib/research/flops.py +156 -0
- halib/research/metrics.py +133 -0
- halib/research/mics.py +68 -0
- halib/research/params_gen.py +108 -0
- halib/research/perfcalc.py +336 -0
- halib/research/perftb.py +780 -0
- halib/research/plot.py +758 -0
- halib/research/profiler.py +300 -0
- halib/research/torchloader.py +162 -0
- halib/research/wandb_op.py +116 -0
- halib/rich_color.py +285 -0
- halib/sys/filesys.py +17 -10
- halib/system/__init__.py +0 -0
- halib/system/cmd.py +8 -0
- halib/system/filesys.py +124 -0
- halib/tele_noti.py +166 -0
- halib/torchloader.py +162 -0
- halib/utils/__init__.py +0 -0
- halib/utils/dataclass_util.py +40 -0
- halib/utils/dict_op.py +9 -0
- halib/utils/gpu_mon.py +58 -0
- halib/utils/listop.py +13 -0
- halib/utils/tele_noti.py +166 -0
- halib/utils/video.py +82 -0
- halib/videofile.py +1 -1
- halib-0.1.99.dist-info/METADATA +209 -0
- halib-0.1.99.dist-info/RECORD +64 -0
- {halib-0.1.7.dist-info → halib-0.1.99.dist-info}/WHEEL +1 -1
- halib-0.1.7.dist-info/METADATA +0 -59
- halib-0.1.7.dist-info/RECORD +0 -30
- {halib-0.1.7.dist-info → halib-0.1.99.dist-info/licenses}/LICENSE.txt +0 -0
- {halib-0.1.7.dist-info → halib-0.1.99.dist-info}/top_level.txt +0 -0
halib/__init__.py
CHANGED
@@ -0,0 +1,84 @@
__all__ = [
    "arrow",
    "cmd",
    "console_log",
    "console",
    "ConsoleLog",
    "csvfile",
    "DictConfig",
    "filetype",
    "fs",
    "inspect",
    "load_yaml",
    "logger",
    "norm_str",
    "now_str",
    "np",
    "omegaconf",
    "OmegaConf",
    "os",
    "pd",
    "plt",
    "pprint",
    "pprint_box",
    "pprint_local_path",
    "px",
    "pprint_local_path",
    "rcolor_all_str",
    "rcolor_palette_all",
    "rcolor_palette",
    "rcolor_str",
    "re",
    "rprint",
    "sns",
    "tcuda",
    "timebudget",
    "tqdm",
    "warnings",
    "time",
]
import warnings

warnings.filterwarnings("ignore", message="Unable to import Axes3D")

# common libraries
import re
from tqdm import tqdm
import arrow
import numpy as np
import pandas as pd
import os
import time

# my own modules
from .filetype import *
from .filetype.yamlfile import load_yaml
from .system import cmd
from .system import filesys as fs
from .filetype import csvfile
from .cuda import tcuda
from .common import (
    console,
    console_log,
    ConsoleLog,
    now_str,
    norm_str,
    pprint_box,
    pprint_local_path,
)

# for log
from loguru import logger
from rich import inspect
from rich import print as rprint
from rich.pretty import pprint
from timebudget import timebudget
import omegaconf
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from .rich_color import rcolor_str, rcolor_palette, rcolor_palette_all, rcolor_all_str

# for visualization
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
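Note: a minimal usage sketch of the new top-level exports (an illustration only; it assumes halib 0.1.99 and its optional heavy dependencies are installed, and the names come from __all__ above):

import halib

halib.console.rule("hello")              # rich Console re-exported at package level
df = halib.pd.DataFrame({"a": [1, 2]})   # pandas re-exported as pd
print(halib.now_str())                   # timestamp from halib.common, e.g. "20240101.123456"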
halib/common.py
ADDED
@@ -0,0 +1,151 @@
import os
import re
import rich
import arrow
import pathlib
from pathlib import Path
import urllib.parse

from rich import print
from rich.panel import Panel
from rich.console import Console
from rich.pretty import pprint, Pretty
from pathlib import PureWindowsPath


console = Console()

def seed_everything(seed=42):
    import random
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    # import torch if it is available
    try:
        import torch

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except ImportError:
        pprint("torch not imported, skipping torch seed_everything")
        pass


def now_str(sep_date_time="."):
    assert sep_date_time in [
        ".",
        "_",
        "-",
    ], "sep_date_time must be one of '.', '_', or '-'"
    now_string = arrow.now().format(f"YYYYMMDD{sep_date_time}HHmmss")
    return now_string


def norm_str(in_str):
    # Replace one or more whitespace characters with a single underscore
    norm_string = re.sub(r"\s+", "_", in_str)
    # Remove leading and trailing spaces
    norm_string = norm_string.strip()
    return norm_string


def pprint_box(obj, title="", border_style="green"):
    """
    Pretty print an object in a box.
    """
    rich.print(
        Panel(Pretty(obj, expand_all=True), title=title, border_style=border_style)
    )

def console_rule(msg, do_norm_msg=True, is_end_tag=False):
    msg = norm_str(msg) if do_norm_msg else msg
    if is_end_tag:
        console.rule(f"</{msg}>")
    else:
        console.rule(f"<{msg}>")


def console_log(func):
    def wrapper(*args, **kwargs):
        console_rule(func.__name__)
        result = func(*args, **kwargs)
        console_rule(func.__name__, is_end_tag=True)
        return result

    return wrapper


class ConsoleLog:
    def __init__(self, message):
        self.message = message

    def __enter__(self):
        console_rule(self.message)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        console_rule(self.message, is_end_tag=True)
        if exc_type is not None:
            print(f"An exception of type {exc_type} occurred.")
            print(f"Exception message: {exc_value}")


def linux_to_wins_path(path: str) -> str:
    """
    Convert a Linux-style WSL path (/mnt/c/... or /mnt/d/...) to a Windows-style path (C:\...).
    """
    # Handle only /mnt/<drive>/... style
    if (
        path.startswith("/mnt/")
        and len(path) > 6
        and path[5].isalpha()
        and path[6] == "/"
    ):
        drive = path[5].upper()  # Extract drive letter
        win_path = f"{drive}:{path[6:]}"  # Replace "/mnt/c/" with "C:/"
    else:
        win_path = path  # Return unchanged if not a WSL-style path
    # Normalize to Windows-style backslashes
    return str(PureWindowsPath(win_path))


def pprint_local_path(
    local_path: str, get_wins_path: bool = False, tag: str = ""
) -> str:
    """
    Pretty-print a local path with emoji and clickable file:// URI.

    Args:
        local_path: Path to file or directory (Linux or Windows style).
        get_wins_path: If True on Linux, convert WSL-style path to Windows style before printing.
        tag: Optional console log tag.

    Returns:
        The file URI string.
    """
    p = Path(local_path).resolve()
    type_str = "📄" if p.is_file() else "📁" if p.is_dir() else "❓"

    if get_wins_path and os.name == "posix":
        # Try WSL → Windows conversion
        converted = linux_to_wins_path(str(p))
        if converted != str(p):  # Conversion happened
            file_uri = str(PureWindowsPath(converted).as_uri())
        else:
            file_uri = p.as_uri()
    else:
        file_uri = p.as_uri()

    content_str = f"{type_str} [link={file_uri}]{file_uri}[/link]"

    if tag:
        with ConsoleLog(tag):
            console.print(content_str)
    else:
        console.print(content_str)

    return file_uri
halib/cuda.py
ADDED
@@ -0,0 +1,39 @@
import importlib
from rich.pretty import pprint
from rich.console import Console

console = Console()


def tcuda():
    NOT_INSTALLED = "Not Installed"
    GPU_AVAILABLE = "GPU(s) Available"
    ls_lib = ["torch", "tensorflow"]
    lib_stats = {lib: NOT_INSTALLED for lib in ls_lib}
    for lib in ls_lib:
        spec = importlib.util.find_spec(lib)
        if spec:
            if lib == "torch":
                import torch

                lib_stats[lib] = str(torch.cuda.device_count()) + " " + GPU_AVAILABLE
            elif lib == "tensorflow":
                import tensorflow as tf

                lib_stats[lib] = (
                    str(len(tf.config.list_physical_devices("GPU")))
                    + " "
                    + GPU_AVAILABLE
                )
    console.rule("<CUDA Library Stats>")
    pprint(lib_stats)
    console.rule("</CUDA Library Stats>")
    return lib_stats


def main():
    tcuda()


if __name__ == "__main__":
    main()
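Note: tcuda() can also be called directly to report GPU visibility for whichever of torch/tensorflow is installed (a sketch; the printed stats depend on the local environment):

from halib.cuda import tcuda

stats = tcuda()  # prints a rich-ruled summary and returns a dict,
                 # e.g. {"torch": "1 GPU(s) Available", "tensorflow": "Not Installed"}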
halib/dataset.py
ADDED
@@ -0,0 +1,209 @@
# This script create a test version
# of the watcam (wc) dataset
# for testing the tflite model

from argparse import ArgumentParser

from rich import inspect
from common import console, seed_everything, ConsoleLog
from sklearn.model_selection import StratifiedShuffleSplit, ShuffleSplit
from tqdm import tqdm
import os
import click
from torchvision.datasets import ImageFolder
import shutil
from rich.pretty import pprint
from system import filesys as fs
import glob


def parse_args():
    parser = ArgumentParser(description="desc text")
    parser.add_argument(
        "-indir",
        "--indir",
        type=str,
        help="orignal dataset path",
    )
    parser.add_argument(
        "-outdir",
        "--outdir",
        type=str,
        help="dataset out path",
        default=".",  # default to current dir
    )
    parser.add_argument(
        "-val_size",
        "--val_size",
        type=float,
        help="validation size",  # no default value to force user to input
        default=0.2,
    )
    # add using StratifiedShuffleSplit or ShuffleSplit
    parser.add_argument(
        "-seed",
        "--seed",
        type=int,
        help="random seed",
        default=42,
    )
    parser.add_argument(
        "-inplace",
        "--inplace",
        action="store_true",
        help="inplace operation, will overwrite the outdir if exists",
    )

    parser.add_argument(
        "-stratified",
        "--stratified",
        action="store_true",
        help="use StratifiedShuffleSplit instead of ShuffleSplit",
    )
    parser.add_argument(
        "-no_train",
        "--no_train",
        action="store_true",
        help="only create test set, no train set",
    )
    parser.add_argument(
        "-reverse",
        "--reverse",
        action="store_true",
        help="combine train and val set back to original dataset",
    )
    return parser.parse_args()


def move_images(image_paths, target_set_dir):
    for img_path in tqdm(image_paths):
        # get folder name of the image
        img_dir = os.path.dirname(img_path)
        out_cls_dir = os.path.join(target_set_dir, os.path.basename(img_dir))
        if not os.path.exists(out_cls_dir):
            os.makedirs(out_cls_dir)
        # move the image to the class folder
        shutil.move(img_path, out_cls_dir)


def split_dataset_cls(
    indir, outdir, val_size, seed, inplace, stratified_split, no_train
):
    seed_everything(seed)
    console.rule("Config confirm?")
    pprint(locals())
    click.confirm("Continue?", abort=True)
    assert os.path.exists(indir), f"{indir} does not exist"

    if not inplace:
        assert (not inplace) and (
            not os.path.exists(outdir)
        ), f"{outdir} already exists; SKIP ...."

    if inplace:
        outdir = indir
    if not os.path.exists(outdir):
        os.makedirs(outdir)

    console.rule(f"Creating train/val dataset")

    sss = (
        ShuffleSplit(n_splits=1, test_size=val_size)
        if not stratified_split
        else StratifiedShuffleSplit(n_splits=1, test_size=val_size)
    )

    pprint({"split strategy": sss, "indir": indir, "outdir": outdir})
    dataset = ImageFolder(
        root=indir,
        transform=None,
    )
    train_dataset_indices = None
    val_dataset_indices = None  # val here means test
    for train_indices, val_indices in sss.split(dataset.samples, dataset.targets):
        train_dataset_indices = train_indices
        val_dataset_indices = val_indices

    # get image paths for train/val split dataset
    train_image_paths = [dataset.imgs[i][0] for i in train_dataset_indices]
    val_image_paths = [dataset.imgs[i][0] for i in val_dataset_indices]

    # start creating train/val folders then move images
    out_train_dir = os.path.join(outdir, "train")
    out_val_dir = os.path.join(outdir, "val")
    if inplace:
        assert os.path.exists(out_train_dir) == False, f"{out_train_dir} already exists"
        assert os.path.exists(out_val_dir) == False, f"{out_val_dir} already exists"

    os.makedirs(out_train_dir)
    os.makedirs(out_val_dir)

    if not no_train:
        with ConsoleLog(f"Moving train images to {out_train_dir} "):
            move_images(train_image_paths, out_train_dir)
    else:
        pprint("test only, skip moving train images")
        # remove out_train_dir
        shutil.rmtree(out_train_dir)

    with ConsoleLog(f"Moving val images to {out_val_dir} "):
        move_images(val_image_paths, out_val_dir)

    if inplace:
        pprint(f"remove all folders, except train and val")
        for cls_dir in os.listdir(outdir):
            if cls_dir not in ["train", "val"]:
                shutil.rmtree(os.path.join(indir, cls_dir))


def reverse_split_ds(indir):
    console.rule(f"Reversing split dataset <{indir}>...")
    ls_dirs = os.listdir(indir)
    # make sure there are only two dirs 'train' and 'val'
    assert len(ls_dirs) == 2, f"Found more than 2 dirs: {len(ls_dirs) } dirs"
    assert "train" in ls_dirs, f"train dir not found in {indir}"
    assert "val" in ls_dirs, f"val dir not found in {indir}"
    train_dir = os.path.join(indir, "train")
    val_dir = os.path.join(indir, "val")
    all_train_files = fs.filter_files_by_extension(
        train_dir, ["jpg", "jpeg", "png", "bmp", "gif", "tiff"]
    )
    all_val_files = fs.filter_files_by_extension(
        val_dir, ["jpg", "jpeg", "png", "bmp", "gif", "tiff"]
    )
    # move all files from train to indir
    with ConsoleLog(f"Moving train images to {indir} "):
        move_images(all_train_files, indir)
    with ConsoleLog(f"Moving val images to {indir} "):
        move_images(all_val_files, indir)
    with ConsoleLog(f"Removing train and val dirs"):
        # remove train and val dirs
        shutil.rmtree(train_dir)
        shutil.rmtree(val_dir)


def main():
    args = parse_args()
    indir = args.indir
    outdir = args.outdir
    if outdir == ".":
        # get current folder of the indir
        indir_parent_dir = os.path.dirname(os.path.normpath(indir))
        indir_name = os.path.basename(indir)
        outdir = os.path.join(indir_parent_dir, f"{indir_name}_split")
    val_size = args.val_size
    seed = args.seed
    inplace = args.inplace
    stratified_split = args.stratified
    no_train = args.no_train
    reverse = args.reverse
    if not reverse:
        split_dataset_cls(
            indir, outdir, val_size, seed, inplace, stratified_split, no_train
        )
    else:
        reverse_split_ds(indir)


if __name__ == "__main__":
    main()
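Note: a hedged programmatic sketch of the splitter above (the dataset paths are hypothetical; the module uses plain "from common import ..." / "from system import ..." rather than relative imports, so it is written to be run from inside the halib package directory, and split_dataset_cls prompts interactively via click.confirm before doing any work):

from dataset import split_dataset_cls, reverse_split_ds

# 80/20 stratified split of an ImageFolder-style tree into <outdir>/train and <outdir>/val
split_dataset_cls(
    indir="/data/watcam",          # hypothetical dataset root (one folder per class)
    outdir="/data/watcam_split",   # must not already exist unless inplace=True
    val_size=0.2,
    seed=42,
    inplace=False,
    stratified_split=True,
    no_train=False,
)

# merge train/ and val/ back into the original class-folder layout
reverse_split_ds("/data/watcam_split")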