mlsort 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlsort/__init__.py +30 -0
- mlsort/algorithms.py +160 -0
- mlsort/api.py +159 -0
- mlsort/baseline.py +33 -0
- mlsort/benchmark.py +118 -0
- mlsort/cli_bench_compare.py +44 -0
- mlsort/cli_bench_install.py +25 -0
- mlsort/cli_init.py +38 -0
- mlsort/cli_optimize_cutoffs.py +34 -0
- mlsort/config.py +35 -0
- mlsort/data.py +109 -0
- mlsort/decision.py +48 -0
- mlsort/features.py +178 -0
- mlsort/installer.py +139 -0
- mlsort/model.py +84 -0
- mlsort/optimize.py +80 -0
- mlsort-0.1.0.dist-info/METADATA +135 -0
- mlsort-0.1.0.dist-info/RECORD +22 -0
- mlsort-0.1.0.dist-info/WHEEL +5 -0
- mlsort-0.1.0.dist-info/entry_points.txt +5 -0
- mlsort-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlsort-0.1.0.dist-info/top_level.txt +1 -0
mlsort/__init__.py
ADDED
@@ -0,0 +1,30 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import os
|
4
|
+
|
5
|
+
from .api import sort, select_algorithm
|
6
|
+
from .config import get_artifacts_dir, get_env_bool, get_seed
|
7
|
+
from .installer import train_thresholds, save_thresholds, load_thresholds
|
8
|
+
from .optimize import gen_cases, optimize_cutoffs
|
9
|
+
|
10
|
+
# Names bound by ``from mlsort import *``: the two entry points re-exported
# from .api plus submodule names (a star-import on a package also imports the
# submodules listed here). Other names imported above, such as
# train_thresholds or gen_cases, remain attributes but are not listed.
__all__ = ["sort", "select_algorithm", "features", "algorithms", "baseline", "model"]
|
11
|
+
|
12
|
+
|
13
|
+
def _maybe_init_on_import() -> None:
    """Optionally build the thresholds file when the package is imported.

    Does nothing unless MLSORT_INIT_ON_IMPORT is truthy and
    ``<artifacts dir>/thresholds.json`` does not exist yet. Otherwise it:
      1. trains thresholds (600 samples, max_n=120_000, max_depth=3) and saves them;
      2. tunes cutoff_n / activation_n with optimize_cutoffs on 250 more cases
         (seed + 7) and saves the tuned result over the same file.
    """
    if not get_env_bool("MLSORT_INIT_ON_IMPORT", False):
        return
    thresholds_file = os.path.join(get_artifacts_dir(), "thresholds.json")
    if os.path.exists(thresholds_file):
        return

    os.makedirs(os.path.dirname(thresholds_file) or ".", exist_ok=True)
    seed = get_seed()
    trained = train_thresholds(num_samples=600, max_n=120_000, seed=seed, max_depth=3)
    # Persisted before tuning, presumably so an untuned file exists even if
    # the tuning step fails.
    save_thresholds(thresholds_file, trained)

    tuning = optimize_cutoffs(trained, gen_cases(num_samples=250, max_n=120_000, seed=seed + 7))
    best = tuning["best"]
    trained.cutoff_n = int(best["cutoff_n"])  # type: ignore[attr-defined]
    trained.activation_n = int(best["activation_n"])  # type: ignore[attr-defined]
    save_thresholds(thresholds_file, trained)


_maybe_init_on_import()
|
mlsort/algorithms.py
ADDED
@@ -0,0 +1,160 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import math
|
4
|
+
import time
|
5
|
+
from typing import Any, Dict, List, Sequence, Tuple
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
|
9
|
+
|
10
|
+
# String identifiers for the sorting backends. They are used as keys of the
# timing dicts and as the labels returned by the selection code.
ALG_TIMSORT = "timsort"  # Python's built-in list.sort
ALG_NP_QUICK = "np_quick"  # np.sort(kind="quicksort")
ALG_NP_MERGE = "np_merge"  # np.sort(kind="mergesort")
ALG_COUNTING = "counting"  # sort_counting, integer data only
ALG_RADIX = "radix"  # sort_radix_lsd, integer data only

# Every known backend identifier.
ALL_ALGOS = [ALG_TIMSORT, ALG_NP_QUICK, ALG_NP_MERGE, ALG_COUNTING, ALG_RADIX]
|
17
|
+
|
18
|
+
|
19
|
+
def _as_numpy(arr: Sequence[Any]) -> np.ndarray:
|
20
|
+
if isinstance(arr, np.ndarray):
|
21
|
+
return arr
|
22
|
+
return np.asarray(arr)
|
23
|
+
|
24
|
+
|
25
|
+
def _as_list(arr: Sequence[Any]) -> List[Any]:
|
26
|
+
if isinstance(arr, list):
|
27
|
+
return list(arr)
|
28
|
+
return list(arr)
|
29
|
+
|
30
|
+
|
31
|
+
def sort_timsort(arr: Sequence[Any]) -> List[Any]:
    """Return a new ascending list of *arr* using Python's built-in Timsort.

    The input is never modified; incomparable elements raise TypeError.
    """
    return sorted(arr)
|
35
|
+
|
36
|
+
|
37
|
+
def sort_np(arr: Sequence[Any], kind: str) -> np.ndarray:
    """Return a sorted NumPy copy of *arr* using ``np.sort(kind=kind)``.

    Non-ndarray input is converted with ``np.asarray`` first.
    """
    data = arr if isinstance(arr, np.ndarray) else np.asarray(arr)
    return np.sort(data, kind=kind)
|
40
|
+
|
41
|
+
|
42
|
+
def sort_counting(arr: Sequence[int]) -> List[int]:
    """Sort integer data with a counting sort and return a list of Python ints.

    Args:
        arr: integer sequence or NumPy integer array.

    Returns:
        The values in ascending order as Python ints.

    Raises:
        TypeError: if the data does not have an integer dtype.
        ValueError: if ``max - min + 1`` exceeds 1_000_000 (memory cap).
    """
    # Same conversion _as_numpy performs for non-ndarray input; an ndarray
    # passes through unchanged.
    a = np.asarray(arr)
    if not np.issubdtype(a.dtype, np.integer):
        raise TypeError("counting sort requires integer dtype")
    if a.size == 0:
        return []
    amin = int(a.min())
    amax = int(a.max())
    rng = amax - amin + 1
    # Safety cap: avoid huge memory
    if rng > 1_000_000:
        raise ValueError("range too large for counting sort")

    # Widen before subtracting the minimum. Doing it in a narrow dtype
    # overflows, e.g. int8 100 - (-100) = 200 wraps to -56. Unsigned data is
    # widened to uint64 so values above the int64 maximum stay exact.
    wide = a.astype(np.uint64 if a.dtype.kind == "u" else np.int64)
    lo = wide.min()
    # Offsets lie in [0, rng) with rng <= 1_000_000, so intp is always safe
    # (and np.bincount refuses uint64 input).
    offsets = (wide - lo).astype(np.intp)
    counts = np.bincount(offsets, minlength=rng)
    # Emit each offset value `count` times, then shift back to the real values.
    out = np.repeat(np.arange(rng, dtype=wide.dtype), counts) + lo
    return out.astype(a.dtype).tolist()
|
70
|
+
|
71
|
+
|
72
|
+
def sort_radix_lsd(arr: Sequence[int], base: int = 256) -> List[int]:
    """Sort integer data with an LSD (least-significant digit first) radix sort.

    Each value is mapped to an unsigned key of the same width whose unsigned
    order equals the original order. Signed values get their sign bit
    flipped; unsigned values are used as-is. The keys are then stably
    reordered one ``log2(base)``-bit digit at a time, lowest digit first.

    Args:
        arr: integer sequence or NumPy integer array.
        base: radix; must be a power of two >= 2.

    Returns:
        The values in ascending order as Python ints.

    Raises:
        TypeError: if the data does not have an integer dtype.
        ValueError: if ``base`` is not a power of two >= 2.
    """
    # Same conversion _as_numpy performs for non-ndarray input; an ndarray
    # passes through unchanged.
    a = np.asarray(arr)
    if not np.issubdtype(a.dtype, np.integer):
        raise TypeError("radix sort requires integer dtype")
    if a.size == 0:
        return []
    if base < 2 or base & (base - 1):
        raise ValueError("radix base must be a power of two >= 2")

    # Adding a 2**(bits-1) bias has two problems: in int64 it overflows for
    # int64 input, and on unsigned input it pushes values past the bits being
    # sorted. Flipping the sign bit of a same-width unsigned view avoids both.
    # Native byte order is required so the view reinterprets bits correctly.
    native = a.dtype.newbyteorder("=")
    key_dtype = np.dtype(f"u{native.itemsize}")
    bits = key_dtype.itemsize * 8
    keys = np.ascontiguousarray(a, dtype=native).reshape(-1).view(key_dtype)
    signed = native.kind == "i"
    if signed:
        sign_bit = key_dtype.type(1 << (bits - 1))
        keys = keys ^ sign_bit

    digit_bits = base.bit_length() - 1
    # A digit wider than the key simply covers the whole key.
    digit_max = min(base - 1, int(np.iinfo(key_dtype).max))
    mask = key_dtype.type(digit_max)
    digit_dtype = np.min_scalar_type(digit_max)
    shift = 0
    while shift < bits:
        digits = ((keys >> key_dtype.type(shift)) & mask).astype(digit_dtype)
        # A *stable* reorder by the current digit is what makes LSD radix correct.
        keys = keys[np.argsort(digits, kind="stable")]
        shift += digit_bits

    if signed:
        keys = keys ^ sign_bit
    return keys.view(native).tolist()
|
108
|
+
|
109
|
+
|
110
|
+
def available_algorithms_for(arr: Sequence[Any]) -> List[str]:
    """List the algorithm identifiers applicable to *arr*.

    Timsort and both NumPy sorts are always offered. For integer dtypes,
    radix sort is appended. Counting sort is inserted before radix when the
    input is non-empty and its value range ``max - min + 1`` is at most
    100_000 and at most ``8 * size``.
    """
    data = _as_numpy(arr)
    candidates = [ALG_TIMSORT, ALG_NP_QUICK, ALG_NP_MERGE]
    if not np.issubdtype(data.dtype, np.integer):
        return candidates
    if data.size > 0:
        low = int(data.min())
        span = int(data.max()) - low + 1
        if span <= 100_000 and span <= 8 * data.size:
            candidates.append(ALG_COUNTING)
    candidates.append(ALG_RADIX)
    return candidates
|
123
|
+
|
124
|
+
|
125
|
+
def time_algorithm(arr: Sequence[Any], algo: str, repeats: int = 1) -> float:
    """Return the fastest of *repeats* timed runs of *algo* on *arr*, in seconds.

    Timing uses ``time.perf_counter``. When ``repeats < 1`` nothing is run
    and ``inf`` is returned. An unknown *algo* raises ValueError on the first
    run; exceptions from the sort itself propagate unchanged.
    """
    runners = {
        ALG_TIMSORT: lambda: sort_timsort(arr),
        ALG_NP_QUICK: lambda: sort_np(arr, kind="quicksort"),
        ALG_NP_MERGE: lambda: sort_np(arr, kind="mergesort"),
        ALG_COUNTING: lambda: sort_counting(arr),
        ALG_RADIX: lambda: sort_radix_lsd(arr),
    }
    run = runners.get(algo)
    clock = time.perf_counter
    fastest = float("inf")
    for _ in range(repeats):
        if run is None:
            raise ValueError(f"unknown algo {algo}")
        began = clock()
        run()
        fastest = min(fastest, clock() - began)
    return fastest
|
144
|
+
|
145
|
+
|
146
|
+
def measure_best_algorithm(arr: Sequence[Any], repeats: int = 1):
    """Time every applicable algorithm on *arr* and return ``(fastest, timings)``.

    Algorithms whose run raises are left out of ``timings``. If none
    succeeded, returns ``(ALG_TIMSORT, {ALG_TIMSORT: inf})``. Ties go to the
    first algorithm in the order given by available_algorithms_for.
    """
    timings: Dict[str, float] = {}
    for name in available_algorithms_for(arr):
        try:
            timings[name] = time_algorithm(arr, name, repeats=repeats)
        except Exception:
            # skip invalid
            continue
    if not timings:
        # fallback
        return ALG_TIMSORT, {ALG_TIMSORT: float("inf")}
    return min(timings, key=timings.__getitem__), timings
|
mlsort/api.py
ADDED
@@ -0,0 +1,159 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
from typing import Any, Dict, Iterable, List, Sequence, Tuple
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
|
9
|
+
from .config import get_artifacts_dir, get_env_bool, get_seed
|
10
|
+
from .decision import decide
|
11
|
+
from .installer import load_thresholds, train_thresholds, save_thresholds, Thresholds
|
12
|
+
from .optimize import gen_cases, optimize_cutoffs
|
13
|
+
from .algorithms import (
|
14
|
+
ALG_TIMSORT, ALG_NP_QUICK, ALG_NP_MERGE, ALG_COUNTING, ALG_RADIX,
|
15
|
+
sort_timsort, sort_np, sort_counting, sort_radix_lsd, available_algorithms_for
|
16
|
+
)
|
17
|
+
|
18
|
+
|
19
|
+
# Logger named "mlsort"; select_algorithm and sort emit debug records only
# when MLSORT_DEBUG is truthy.
log = logging.getLogger("mlsort")
|
20
|
+
|
21
|
+
|
22
|
+
def _ensure_thresholds(path: str) -> Thresholds:
    """Load thresholds from *path*, creating the file first if it is missing.

    * If *path* exists, it is loaded and returned.
    * Otherwise, when MLSORT_ENABLE_INSTALL_BENCH is not truthy, a static
      default is saved and returned: cutoff_n=1024, activation_n=98304, and a
      single-leaf tree labelled ALG_NP_QUICK.
    * When the flag is truthy, thresholds are trained (600 samples,
      max_n=120_000, max_depth=3) and saved. Their cutoff_n / activation_n
      are then tuned with optimize_cutoffs on 250 further cases and saved
      again.
    """
    # Lazy init: controlled by env flags
    if os.path.exists(path):
        return load_thresholds(path)
    # Both creation branches write *path*, so create its directory for both.
    # Callers such as profile_decisions do not create it themselves.
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    if not get_env_bool("MLSORT_ENABLE_INSTALL_BENCH", False):
        # Safe default if benchmarks disabled
        th = Thresholds(cutoff_n=1024, activation_n=98304, tree={"leaf": True, "label": ALG_NP_QUICK}, feature_names=[
            "n","dtype_code","est_sortedness","est_dup_ratio","est_range","est_entropy","est_run_len"
        ])
        save_thresholds(path, th)
        return th
    # Run small-budget train+optimize
    seed = get_seed()
    th = train_thresholds(num_samples=600, max_n=120_000, seed=seed, max_depth=3)
    save_thresholds(path, th)
    arrays = gen_cases(num_samples=250, max_n=120_000, seed=seed + 7)
    res = optimize_cutoffs(th, arrays)
    th.cutoff_n = int(res["best"]["cutoff_n"])  # type: ignore[attr-defined]
    th.activation_n = int(res["best"]["activation_n"])  # type: ignore[attr-defined]
    save_thresholds(path, th)
    return th
|
44
|
+
|
45
|
+
|
46
|
+
def select_algorithm(arr: Sequence[Any], thresholds_path: str | None = None, *, key: Any = None, reverse: bool = False) -> str:
    """Choose the sorting backend for *arr*.

    Cheap checks run first. They return ALG_TIMSORT (Python's built-in sort)
    whenever a NumPy-based backend could change the result:

    * empty input, or a ``key`` function is given;
    * an ndarray with object / unicode / bytes dtype;
    * a sampled element is str/bytes, None, or any non-numeric object;
    * sampled elements are of mixed kinds. Examples: ints with floats, which
      ``np.asarray`` would coerce to floats; or bools with ints, which would
      come back as ints.

    Otherwise thresholds are loaded from (or created at) *thresholds_path*,
    defaulting to ``<artifacts dir>/thresholds.json``. The choice is then
    delegated to ``decide``.

    Args:
        arr: sequence supporting ``len`` and integer indexing.
        thresholds_path: optional path of the thresholds JSON file.
        key: optional sort key; any non-None value forces Timsort.
        reverse: accepted for signature parity with ``sort``; it does not
            influence the choice.

    Returns:
        One of the ``ALG_*`` identifiers.

    Raises:
        TypeError: if ``len(arr)`` fails.
    """
    # Input validation
    try:
        n = len(arr)  # type: ignore[arg-type]
    except Exception:
        raise TypeError("arr must be a sequence with __len__ and indexable by int") from None
    if n == 0:
        return ALG_TIMSORT
    # If a key function is provided, prefer builtin Timsort for correctness and stability
    if key is not None:
        return ALG_TIMSORT
    # If data are strings or mixed/object types, default to Python's Timsort
    try:
        if isinstance(arr, np.ndarray):
            if arr.dtype.kind in {"O", "U", "S"}:
                return ALG_TIMSORT
        else:
            # Sample about 256 positions spread over the whole sequence, plus
            # the last element, rather than only the prefix. This way a
            # heterogeneous tail (e.g. a trailing string) is still detected.
            step = -(-n // 256)  # ceil(n / 256); >= 1 because n >= 1
            idxs = list(range(0, n, step))
            if idxs[-1] != n - 1:
                idxs.append(n - 1)
            cats = set()
            for i in idxs:
                v = arr[i]
                if isinstance(v, (str, bytes)):
                    cats.add("string")
                elif isinstance(v, (bool, np.bool_)):
                    # Checked before int because bool subclasses int.
                    cats.add("bool")
                elif isinstance(v, (int, np.integer)):
                    cats.add("int")
                elif isinstance(v, (float, np.floating)):
                    cats.add("float")
                else:
                    # None and any other object type
                    cats.add("other")
                if len(cats) > 1:
                    break
            # Mixed kinds or any non-numeric kind: keep Python semantics.
            if len(cats) != 1 or cats & {"string", "other"}:
                return ALG_TIMSORT
    except Exception:
        # On any detection error, prefer safe fallback
        return ALG_TIMSORT
    # Ensure thresholds
    thr_path = thresholds_path or os.path.join(get_artifacts_dir(), "thresholds.json")
    os.makedirs(os.path.dirname(thr_path) or ".", exist_ok=True)
    th = _ensure_thresholds(thr_path)
    algo = decide(arr, th)
    if get_env_bool("MLSORT_DEBUG", False):
        log.debug("mlsort.select algo=%s n=%d path=%s", algo, n, thr_path)
    return algo
|
95
|
+
|
96
|
+
|
97
|
+
def sort(
    arr: Sequence[Any],
    thresholds_path: str | None = None,
    *,
    key: Any = None,
    reverse: bool = False,
) -> List[Any]:
    """Return a new sorted list of *arr* using the backend from select_algorithm.

    If selection raises, Timsort is used. Timsort, and any unrecognised
    choice, runs ``list.sort(key=key, reverse=reverse)`` on a copy.

    NumPy backends return ``np.sort(...).tolist()``. Counting and radix sort
    fall back to NumPy quicksort if they raise. For every non-Timsort backend,
    ``reverse`` is honoured by reversing the ascending result.
    """
    try:
        choice = select_algorithm(arr, thresholds_path, key=key, reverse=reverse)
    except Exception as exc:  # strict safety: fallback
        if get_env_bool("MLSORT_DEBUG", False):
            log.debug("mlsort.select failed: %s; falling back to timsort", exc)
        choice = ALG_TIMSORT

    if choice in (ALG_NP_QUICK, ALG_NP_MERGE):
        kind = "quicksort" if choice == ALG_NP_QUICK else "mergesort"
        ascending = sort_np(arr, kind=kind).tolist()
    elif choice in (ALG_COUNTING, ALG_RADIX):
        specialised = sort_counting if choice == ALG_COUNTING else sort_radix_lsd
        try:
            ascending = specialised(arr)
        except Exception:
            # The integer-only sorts may reject the input; NumPy quicksort is the fallback.
            ascending = sort_np(arr, kind="quicksort").tolist()
    else:
        # Timsort or an unrecognised choice: the built-in sort handles key/reverse itself.
        result = list(arr)
        result.sort(key=key, reverse=reverse)
        return result
    return ascending[::-1] if reverse else ascending
|
144
|
+
|
145
|
+
|
146
|
+
def profile_decisions(samples: int = 100, max_n: int = 200_000, thresholds_path: str | None = None) -> Dict[str, Any]:
    """Measure decision overhead and resulting sort time on generated cases.

    For each of the *samples* arrays from ``gen_cases`` (seed = get_seed() + 99),
    records the array length, the chosen algorithm, the time spent in
    ``decide`` (milliseconds), and one timed run of that algorithm (seconds).

    Returns:
        ``{"count": <number of cases>, "rows": <first 50 rows>}``.
    """
    import time
    from .algorithms import time_algorithm

    path = thresholds_path or os.path.join(get_artifacts_dir(), "thresholds.json")
    thresholds = _ensure_thresholds(path)
    rows: List[Dict[str, Any]] = []
    for case in gen_cases(samples, max_n, seed=get_seed() + 99):
        started = time.perf_counter()
        chosen = decide(case, thresholds)
        decided = time.perf_counter()
        rows.append({
            "n": len(case),
            "algo": chosen,
            "decision_ms": (decided - started) * 1000.0,
            "sort_s": time_algorithm(case, chosen, repeats=1),
        })
    return {"count": len(rows), "rows": rows[:50]}
|
mlsort/baseline.py
ADDED
@@ -0,0 +1,33 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import Dict
|
4
|
+
|
5
|
+
from .algorithms import ALG_COUNTING, ALG_NP_MERGE, ALG_NP_QUICK, ALG_RADIX, ALG_TIMSORT
|
6
|
+
|
7
|
+
|
8
|
+
def heuristic_baseline(props: Dict[str, float]) -> str:
    """Pick an algorithm from estimated array properties using fixed rules.

    Rules, checked in order:
      1. ``est_sortedness >= 0.9`` or ``est_run_len >= 32``: Timsort.
      2. Integer data (``dtype_code == 1``) with ``0 < est_range <= max(1024, 8 * n)``,
         ``est_dup_ratio >= 0.3`` and ``est_entropy <= 0.7``: counting sort.
      3. Integer data with ``n >= 512`` and ``est_entropy <= 0.9``: radix sort.
      4. Otherwise NumPy quicksort when ``n >= 2000``, else NumPy mergesort.

    Args:
        props: feature dict; all seven keys read below must be present.
    """
    # Every feature is read up front, so a missing key fails immediately.
    n = props["n"]
    dtype_code = int(props["dtype_code"])  # 0 float, 1 int
    sortedness = props["est_sortedness"]
    dup_ratio = props["est_dup_ratio"]
    value_range = props["est_range"]
    entropy = props["est_entropy"]
    run_len = props["est_run_len"]

    if sortedness >= 0.9 or run_len >= 32:
        return ALG_TIMSORT

    if dtype_code == 1:
        small_range = 0 < value_range <= max(1024.0, 8.0 * n)
        if small_range and dup_ratio >= 0.3 and entropy <= 0.7:
            return ALG_COUNTING
        if n >= 512 and entropy <= 0.9:
            return ALG_RADIX

    return ALG_NP_QUICK if n >= 2000 else ALG_NP_MERGE
|
mlsort/benchmark.py
ADDED
@@ -0,0 +1,118 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import json
|
4
|
+
import os
|
5
|
+
import statistics
|
6
|
+
import time
|
7
|
+
from typing import Dict, List
|
8
|
+
|
9
|
+
import numpy as np
|
10
|
+
|
11
|
+
from .algorithms import (
|
12
|
+
ALG_NP_MERGE,
|
13
|
+
ALG_NP_QUICK,
|
14
|
+
ALG_TIMSORT,
|
15
|
+
time_algorithm,
|
16
|
+
)
|
17
|
+
from .decision import decide
|
18
|
+
from .installer import load_thresholds, train_thresholds, save_thresholds
|
19
|
+
from .data import (
|
20
|
+
gen_sorted, gen_reverse, gen_nearly_sorted, gen_uniform, gen_small_range, gen_zipf, gen_normal,
|
21
|
+
)
|
22
|
+
|
23
|
+
|
24
|
+
# Fixed single-algorithm baselines the decision policy is compared against:
# the three backends that available_algorithms_for always offers.
Naive = [ALG_TIMSORT, ALG_NP_QUICK, ALG_NP_MERGE]
|
25
|
+
|
26
|
+
|
27
|
+
def ensure_thresholds(path: str, samples: int, max_n: int, seed: int) -> str:
    """Make sure a thresholds file exists at *path* and return *path*.

    An existing file is left untouched. Otherwise thresholds are trained with
    the given budget (max_depth=3) and written to *path*, creating its parent
    directory if needed.
    """
    if not os.path.exists(path):
        trained = train_thresholds(num_samples=samples, max_n=max_n, seed=seed, max_depth=3)
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        save_thresholds(path, trained)
    return path
|
34
|
+
|
35
|
+
|
36
|
+
def _random_case(rng: np.random.Generator, max_n: int):
    """Draw one synthetic benchmark array from *rng*.

    The size is drawn first, uniformly from ``[128, max_n]``. Then one of
    eight generators is picked uniformly: sorted, reversed, nearly sorted,
    uniform ints, uniform floats, small range, Zipf ints, or normal floats.
    """
    size = int(rng.integers(128, max_n + 1))
    generators = (
        lambda m: gen_sorted(m, "int"),
        lambda m: gen_reverse(m, "int"),
        lambda m: gen_nearly_sorted(m, dtype="int"),
        lambda m: gen_uniform(m, "int", 0, 10_000),
        lambda m: gen_uniform(m, "float"),
        lambda m: gen_small_range(m, 128),
        lambda m: gen_zipf(m, a=2.0, dtype="int", max_val=50_000),
        lambda m: gen_normal(m, dtype="float"),
    )
    pick = int(rng.integers(0, len(generators)))
    return generators[pick](size)
|
51
|
+
|
52
|
+
|
53
|
+
def run_benchmark(num_samples: int, max_n: int, seed: int, thresholds_path: str) -> Dict:
    """Benchmark the learned decision policy against fixed single-algorithm choices.

    Ensures thresholds exist at *thresholds_path*, training them if needed.
    It then draws *num_samples* random arrays with sizes in ``[128, max_n]``.
    For each array it records:

    * the policy's total time: ``decide`` plus sorting with the chosen algorithm;
    * the sort time of every algorithm in ``Naive``.

    Returns:
        A dict containing:

        * the run parameters;
        * ``aggregate``: per-method ``mean`` / ``median`` / ``p90`` seconds;
        * ``best_naive``: the baseline with the lowest mean;
        * ``speedup_vs_best_naive``: best naive mean divided by the policy mean;
        * ``win_rates``: per-baseline fraction of cases where the policy time
          is <= the baseline time;
        * ``cases_head``: the first 50 per-case rows.

    Raises:
        ValueError: if ``num_samples < 1`` or ``max_n < 128``. This is checked
            before any thresholds are trained.
    """
    # Fail fast, before training:
    # - zero samples would only crash later in the statistics (empty list /
    #   division by zero);
    # - _random_case cannot draw a size when max_n < 128.
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    if max_n < 128:
        raise ValueError(f"max_n must be >= 128, got {max_n}")

    ensure_thresholds(thresholds_path, samples=max(600, num_samples // 2), max_n=max_n, seed=seed)
    th = load_thresholds(thresholds_path)

    rng = np.random.default_rng(seed + 101)

    times_decision: List[float] = []
    times_naive: Dict[str, List[float]] = {k: [] for k in Naive}
    per_case: List[Dict] = []

    for _ in range(num_samples):
        arr = _random_case(rng, max_n)

        t0 = time.perf_counter()
        algo = decide(arr, th)
        t1 = time.perf_counter()
        t_sort = time_algorithm(arr, algo, repeats=1)
        # Policy cost = decision overhead + sort with the chosen algorithm.
        t_decision_total = (t1 - t0) + t_sort
        times_decision.append(t_decision_total)

        for k in Naive:
            times_naive[k].append(time_algorithm(arr, k, repeats=1))

        per_case.append({
            "n": int(len(arr)),
            "decision_algo": algo,
            "decision_total_time": t_decision_total,
            **{f"time_{k}": times_naive[k][-1] for k in Naive},
        })

    def stats(vals: List[float]) -> Dict[str, float]:
        # Summary of a non-empty sample: mean, upper median, nearest-rank p90.
        vals_sorted = sorted(vals)
        count = len(vals_sorted)
        p50 = vals_sorted[count // 2]
        # Nearest-rank p90 is the ceil(0.9 * count)-th smallest value;
        # (9 * count + 9) // 10 is that ceiling computed exactly in integers.
        p90 = vals_sorted[(9 * count + 9) // 10 - 1]
        return {
            "mean": float(statistics.fmean(vals)),
            "median": float(p50),
            "p90": float(p90),
        }

    agg = {"decision": stats(times_decision)}
    for k in Naive:
        agg[k] = stats(times_naive[k])

    mean_naive = {k: agg[k]["mean"] for k in Naive}
    best_naive_key = min(mean_naive.items(), key=lambda kv: kv[1])[0]
    best_naive_mean = mean_naive[best_naive_key]
    speedup_vs_best = best_naive_mean / agg["decision"]["mean"] if agg["decision"]["mean"] > 0 else 1.0

    win_rates = {
        k: sum(d <= b for d, b in zip(times_decision, times_naive[k])) / num_samples
        for k in Naive
    }

    return {
        "samples": num_samples,
        "max_n": max_n,
        "seed": seed,
        "thresholds": thresholds_path,
        "aggregate": agg,
        "best_naive": {"key": best_naive_key, "mean_time": best_naive_mean},
        "speedup_vs_best_naive": speedup_vs_best,
        "win_rates": win_rates,
        "cases_head": per_case[:50],
    }
|
mlsort/cli_bench_compare.py
ADDED
@@ -0,0 +1,44 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import argparse
|
4
|
+
import json
|
5
|
+
import os
|
6
|
+
|
7
|
+
from mlsort.benchmark import run_benchmark
|
8
|
+
|
9
|
+
|
10
|
+
def main():
    """Benchmark the decision policy against naive baselines and write reports.

    Writes the full run_benchmark result as JSON to --out-json, and a short
    Markdown summary to --out-md. The summary contains mean / median / p90
    per method, the best naive baseline, the speedup, and win rates.
    """
    # Default to the thresholds location the library itself reads
    # (mlsort.api joins get_artifacts_dir() with "thresholds.json"). A
    # hard-coded ~/.cache path diverges from it on macOS, on Windows, or
    # when MLSORT_ARTIFACTS_DIR / XDG_CACHE_HOME is set.
    from mlsort.config import get_artifacts_dir

    parser = argparse.ArgumentParser(description="Benchmark decision policy vs naive baselines")
    parser.add_argument("--samples", type=int, default=600)
    parser.add_argument("--max-n", type=int, default=20000)
    parser.add_argument("--seed", type=int, default=123)
    parser.add_argument("--thresholds", type=str, default=os.path.join(get_artifacts_dir(), "thresholds.json"))
    parser.add_argument("--out-json", type=str, default="bench_compare.json")
    parser.add_argument("--out-md", type=str, default="report.md")
    args = parser.parse_args()

    results = run_benchmark(args.samples, args.max_n, args.seed, args.thresholds)

    with open(args.out_json, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    def stat_line(label: str, s: dict) -> str:
        # The section heading promises mean, median and p90, so report all three.
        return f"- {label} mean: {s['mean']:.6f}, median: {s['median']:.6f}, p90: {s['p90']:.6f}\n"

    agg = results["aggregate"]
    md = [
        "# mlsort: Decision Policy vs. Naive Single-Choice\n\n",
        f"- Samples: {results['samples']} \n- Max n: {results['max_n']} \n- Seed: {results['seed']} \n- Thresholds: {results['thresholds']}\n\n",
        "## Mean/Median/P90 (seconds)\n",
        stat_line("Decision", agg["decision"]),
    ]
    md.extend(stat_line(k, v) for k, v in agg.items() if k != "decision")
    md.append(f"\nBest naive: {results['best_naive']['key']} at {results['best_naive']['mean_time']:.6f} s\n")
    md.append(f"Speedup vs best naive: {results['speedup_vs_best_naive']:.3f}x\n")

    win_rates = results.get("win_rates") or {}
    if win_rates:
        md.append("\n## Decision policy win rate (cases where its time <= baseline)\n")
        md.extend(f"- vs {k}: {rate:.1%}\n" for k, rate in win_rates.items())

    with open(args.out_md, "w", encoding="utf-8") as f:
        f.write("".join(md))

    print(f"Wrote {args.out_json} and {args.out_md}")  # noqa: T201


if __name__ == "__main__":
    main()
|
mlsort/cli_bench_install.py
ADDED
@@ -0,0 +1,25 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import argparse
|
4
|
+
import os
|
5
|
+
|
6
|
+
from mlsort.installer import train_thresholds, save_thresholds
|
7
|
+
|
8
|
+
|
9
|
+
def main():
    """Train decision thresholds with the given budget and save them to --out.

    The parent directory of --out is created if needed.
    """
    # Default to the thresholds location the library itself reads
    # (mlsort.api joins get_artifacts_dir() with "thresholds.json"), instead
    # of a hard-coded ~/.cache path that differs on macOS/Windows or when
    # MLSORT_ARTIFACTS_DIR / XDG_CACHE_HOME is set.
    from mlsort.config import get_artifacts_dir

    parser = argparse.ArgumentParser(description="Benchmark and derive thresholds for mlsort")
    parser.add_argument("--samples", type=int, default=1200)
    parser.add_argument("--max-n", type=int, default=20000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max-depth", type=int, default=3)
    parser.add_argument("--out", type=str, default=os.path.join(get_artifacts_dir(), "thresholds.json"))
    args = parser.parse_args()

    # dirname() is "" for a bare filename such as "t.json", and
    # os.makedirs("") raises FileNotFoundError; fall back to the cwd.
    os.makedirs(os.path.dirname(args.out) or ".", exist_ok=True)
    th = train_thresholds(num_samples=args.samples, max_n=args.max_n, seed=args.seed, max_depth=args.max_depth)
    save_thresholds(args.out, th)
    print(f"Saved thresholds to {args.out}")  # noqa: T201


if __name__ == "__main__":
    main()
|
mlsort/cli_init.py
ADDED
@@ -0,0 +1,38 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import argparse
|
4
|
+
import os
|
5
|
+
|
6
|
+
from mlsort.config import get_artifacts_dir, get_seed
|
7
|
+
from mlsort.installer import train_thresholds, save_thresholds, load_thresholds
|
8
|
+
from mlsort.optimize import gen_cases, optimize_cutoffs
|
9
|
+
|
10
|
+
|
11
|
+
def main():
    """Train thresholds, tune their cutoffs, and store them in the artifacts dir.

    Steps:
      1. Train a max_depth=3 model on --samples cases and save it.
      2. Tune cutoff_n / activation_n with optimize_cutoffs on min(300, samples)
         freshly generated cases (seed + 17), and save the tuned thresholds.
      3. Print a summary dict.
    """
    ap = argparse.ArgumentParser(description="Initialize mlsort artifacts; optional params have defaults")
    ap.add_argument("--samples", type=int, default=1200, help="training samples")
    ap.add_argument("--max-n", type=int, default=200000, help="max array size in benchmarking")
    ap.add_argument("--seed", type=int, default=None, help="random seed; default from MLSORT_SEED or 42")
    ap.add_argument("--artifacts", type=str, default=None, help="artifacts dir; default MLSORT_ARTIFACTS_DIR or OS cache")
    opts = ap.parse_args()

    seed = get_seed() if opts.seed is None else opts.seed
    artifacts_dir = opts.artifacts or get_artifacts_dir()
    thresholds_file = os.path.join(artifacts_dir, "thresholds.json")
    os.makedirs(artifacts_dir, exist_ok=True)

    # Stage 1: train and persist the untuned thresholds.
    save_thresholds(
        thresholds_file,
        train_thresholds(num_samples=opts.samples, max_n=opts.max_n, seed=seed, max_depth=3),
    )

    # Stage 2: tune cutoffs on a separate, smaller case set.
    cases = gen_cases(num_samples=min(300, opts.samples), max_n=opts.max_n, seed=seed + 17)
    outcome = optimize_cutoffs(load_thresholds(thresholds_file), cases)
    tuned = load_thresholds(thresholds_file)
    best = outcome["best"]
    tuned.cutoff_n = int(best["cutoff_n"])  # type: ignore[attr-defined]
    tuned.activation_n = int(best["activation_n"])  # type: ignore[attr-defined]
    save_thresholds(thresholds_file, tuned)

    print({"artifacts": artifacts_dir, "thresholds": thresholds_file, "best": best})  # noqa: T201


if __name__ == "__main__":
    main()
|
mlsort/cli_optimize_cutoffs.py
ADDED
@@ -0,0 +1,34 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import argparse
|
4
|
+
import os
|
5
|
+
|
6
|
+
from mlsort.config import get_artifacts_dir, get_seed
|
7
|
+
from mlsort.installer import load_thresholds, save_thresholds
|
8
|
+
from mlsort.optimize import gen_cases, optimize_cutoffs
|
9
|
+
|
10
|
+
|
11
|
+
def main():
    """Re-tune cutoff_n / activation_n of an existing thresholds file in place.

    Generates --samples cases (seed + 10) with sizes up to --max-n, runs
    optimize_cutoffs, writes the best cutoffs back to the same file, and
    prints a summary dict. Exits with a usage error if the file does not
    exist.
    """
    parser = argparse.ArgumentParser(description="Optimize cutoff and activation thresholds")
    parser.add_argument("--samples", type=int, default=350)
    parser.add_argument("--max-n", type=int, default=100000)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--thresholds", type=str, default=None)
    args = parser.parse_args()

    seed = args.seed if args.seed is not None else get_seed()
    thr_path = args.thresholds or os.path.join(get_artifacts_dir(), "thresholds.json")
    # This command only tunes an existing file. Report a missing file as a
    # usage error up front rather than letting load_thresholds fail on it.
    if not os.path.exists(thr_path):
        parser.error(
            f"thresholds file not found: {thr_path} "
            "(create it first, e.g. with `python -m mlsort.cli_init`)"
        )
    th = load_thresholds(thr_path)

    arrays = gen_cases(args.samples, args.max_n, seed=seed + 10)
    res = optimize_cutoffs(th, arrays)

    th.cutoff_n = int(res["best"]["cutoff_n"])  # type: ignore[attr-defined]
    th.activation_n = int(res["best"]["activation_n"])  # type: ignore[attr-defined]
    save_thresholds(thr_path, th)

    print({"best": res["best"], "thresholds": thr_path})  # noqa: T201


if __name__ == "__main__":
    main()
|
mlsort/config.py
ADDED
@@ -0,0 +1,35 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import os
|
4
|
+
import sys
|
5
|
+
from typing import Optional
|
6
|
+
|
7
|
+
|
8
|
+
def get_cache_root() -> str:
    """Return the per-user cache directory for the current platform.

    * macOS: ``~/Library/Caches``
    * Windows: ``%LOCALAPPDATA%``, or ``~\\AppData\\Local`` if unset or empty
    * Others: ``$XDG_CACHE_HOME`` if set to an absolute path, else ``~/.cache``
    """
    if sys.platform == "darwin":
        return os.path.expanduser("~/Library/Caches")
    if os.name == "nt":
        local = os.environ.get("LOCALAPPDATA")
        # An empty value would otherwise make the cache relative to the cwd.
        return local if local else os.path.expanduser("~\\AppData\\Local")
    xdg = os.environ.get("XDG_CACHE_HOME")
    # The XDG Base Directory spec says an empty or relative value must be
    # ignored in favour of the default.
    if xdg and os.path.isabs(xdg):
        return xdg
    return os.path.expanduser("~/.cache")
|
17
|
+
|
18
|
+
|
19
|
+
def get_artifacts_dir() -> str:
    """Return the directory where mlsort keeps artifacts such as thresholds.json.

    MLSORT_ARTIFACTS_DIR overrides the default ``<cache root>/mlsort``, with
    a leading ``~`` expanded. An empty value is treated as unset; otherwise
    ``os.path.join("", "thresholds.json")`` would place artifacts in the
    current working directory.
    """
    override = os.environ.get("MLSORT_ARTIFACTS_DIR")
    if override:
        return os.path.expanduser(override)
    return os.path.join(get_cache_root(), "mlsort")
|
21
|
+
|
22
|
+
|
23
|
+
def get_env_bool(name: str, default: bool = False) -> bool:
    """Read environment variable *name* as a boolean flag.

    Returns *default* when the variable is unset. Otherwise returns True only
    for "1", "true", "yes" or "on" (case-insensitive, surrounding whitespace
    ignored). Any other value, including the empty string, gives False.
    """
    raw = os.environ.get(name)
    if raw is not None:
        return raw.strip().lower() in ("1", "true", "yes", "on")
    return default
|
28
|
+
|
29
|
+
|
30
|
+
def get_seed(default: int = 42) -> int:
    """Return the integer seed from MLSORT_SEED.

    Returns *default* when the variable is unset or does not parse as an int.
    """
    raw = os.environ.get("MLSORT_SEED")
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        # int() on a str can only fail with ValueError; a broad except
        # would hide unrelated bugs.
        return default
|
35
|
+
|