mlsort-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mlsort/__init__.py ADDED
@@ -0,0 +1,30 @@
+ from __future__ import annotations
+
+ import os
+
+ from .api import sort, select_algorithm
+ from .config import get_artifacts_dir, get_env_bool, get_seed
+ from .installer import train_thresholds, save_thresholds, load_thresholds
+ from .optimize import gen_cases, optimize_cutoffs
+
+ __all__ = ["sort", "select_algorithm", "train_thresholds", "save_thresholds", "load_thresholds", "gen_cases", "optimize_cutoffs", "get_artifacts_dir", "get_env_bool", "get_seed"]
+
+
+ def _maybe_init_on_import() -> None:
+     if not get_env_bool("MLSORT_INIT_ON_IMPORT", False):
+         return
+     thr_path = os.path.join(get_artifacts_dir(), "thresholds.json")
+     if os.path.exists(thr_path):
+         return
+     os.makedirs(os.path.dirname(thr_path) or ".", exist_ok=True)
+     seed = get_seed()
+     th = train_thresholds(num_samples=600, max_n=120_000, seed=seed, max_depth=3)
+     save_thresholds(thr_path, th)
+     arrays = gen_cases(num_samples=250, max_n=120_000, seed=seed + 7)
+     res = optimize_cutoffs(th, arrays)
+     th.cutoff_n = int(res["best"]["cutoff_n"])  # type: ignore[attr-defined]
+     th.activation_n = int(res["best"]["activation_n"])  # type: ignore[attr-defined]
+     save_thresholds(thr_path, th)
+
+
+ _maybe_init_on_import()
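
For orientation, a minimal sketch of exercising the opt-in import hook above (assuming the package is installed and the cache directory is writable; `MLSORT_INIT_ON_IMPORT` and `MLSORT_SEED` are the flags read via `get_env_bool`/`get_seed`):

    import os

    # Opt in before the first import; training runs once and caches thresholds.json.
    os.environ["MLSORT_INIT_ON_IMPORT"] = "1"
    os.environ["MLSORT_SEED"] = "42"

    import mlsort  # triggers _maybe_init_on_import()

    print(mlsort.sort([3, 1, 2]))  # [1, 2, 3]
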
mlsort/algorithms.py ADDED
@@ -0,0 +1,160 @@
+ from __future__ import annotations
+
+ import math
+ import time
+ from typing import Any, Dict, List, Sequence, Tuple
+
+ import numpy as np
+
+
+ ALG_TIMSORT = "timsort"
+ ALG_NP_QUICK = "np_quick"
+ ALG_NP_MERGE = "np_merge"
+ ALG_COUNTING = "counting"
+ ALG_RADIX = "radix"
+
+ ALL_ALGOS = [ALG_TIMSORT, ALG_NP_QUICK, ALG_NP_MERGE, ALG_COUNTING, ALG_RADIX]
+
+
+ def _as_numpy(arr: Sequence[Any]) -> np.ndarray:
+     if isinstance(arr, np.ndarray):
+         return arr
+     return np.asarray(arr)
+
+
+ def _as_list(arr: Sequence[Any]) -> List[Any]:
+     return list(arr)
+
+
+ def sort_timsort(arr: Sequence[Any]) -> List[Any]:
+     a = _as_list(arr)
+     a.sort()
+     return a
+
+
+ def sort_np(arr: Sequence[Any], kind: str) -> np.ndarray:
+     a = _as_numpy(arr)
+     return np.sort(a, kind=kind)
+
+
+ def sort_counting(arr: Sequence[int]) -> List[int]:
+     a = _as_numpy(arr)
+     if not np.issubdtype(a.dtype, np.integer):
+         raise TypeError("counting sort requires integer dtype")
+     if a.size == 0:
+         return []
+     amin = int(a.min())
+     amax = int(a.max())
+     rng = amax - amin + 1
+     # Safety cap: avoid allocating a huge counts array
+     if rng > 1_000_000:
+         raise ValueError("range too large for counting sort")
+     counts = np.zeros(rng, dtype=np.int64)
+     # Shift values to zero-based
+     shifted = (a - amin).astype(np.int64)
+     for v in shifted:
+         counts[v] += 1
+     # Build output
+     out = np.empty_like(shifted)
+     total = 0
+     for i in range(rng):
+         c = int(counts[i])
+         if c:
+             out[total: total + c] = i
+             total += c
+     # Shift back
+     out = (out + amin).astype(a.dtype, copy=False)
+     return out.tolist()
+
+
+ def sort_radix_lsd(arr: Sequence[int], base: int = 256) -> List[int]:
+     a = _as_numpy(arr)
+     if not np.issubdtype(a.dtype, np.integer):
+         raise TypeError("radix sort requires integer dtype")
+     if a.size == 0:
+         return []
+     if base < 2 or base & (base - 1):
+         raise ValueError("base must be a power of two")
+     # Bias signed values so unsigned digit order matches numeric order.
+     # Arithmetic is done in uint64, where wraparound implements the bias
+     # even for 64-bit dtypes (int64 + (1 << 63) would overflow otherwise).
+     dtype = a.dtype
+     bits = np.iinfo(dtype).bits
+     bias = np.uint64(1 << (bits - 1)) if np.issubdtype(dtype, np.signedinteger) else np.uint64(0)
+     u = a.astype(np.int64).astype(np.uint64) + bias
+     out = u.copy()
+     mask = base - 1
+     shift = 0
+     tmp = np.empty_like(out)
+     while shift < bits:
+         counts = np.zeros(base, dtype=np.int64)
+         # Count digit occurrences (int(v) avoids uint64/int promotion errors)
+         for v in out:
+             counts[(int(v) >> shift) & mask] += 1
+         # Prefix sums give each bucket's start offset
+         total = 0
+         for i in range(base):
+             c = counts[i]
+             counts[i] = total
+             total += c
+         # Stable reorder by the current digit
+         for v in out:
+             b = (int(v) >> shift) & mask
+             tmp[counts[b]] = v
+             counts[b] += 1
+         out, tmp = tmp, out
+         shift += int(math.log2(base))
+     # Un-bias via uint64 wraparound, then narrow back to the input dtype
+     res = (out - bias).astype(np.int64).astype(dtype, copy=False)
+     return res.tolist()
+
+
+ def available_algorithms_for(arr: Sequence[Any]) -> List[str]:
+     a = _as_numpy(arr)
+     algos = [ALG_TIMSORT, ALG_NP_QUICK, ALG_NP_MERGE]
+     if np.issubdtype(a.dtype, np.integer):
+         # Counting sort only when the value range is manageable
+         if a.size > 0:
+             amin = int(a.min())
+             amax = int(a.max())
+             rng = amax - amin + 1
+             if rng <= 100_000 and rng <= 8 * a.size:
+                 algos.append(ALG_COUNTING)
+         algos.append(ALG_RADIX)
+     return algos
+
+
+ def time_algorithm(arr: Sequence[Any], algo: str, repeats: int = 1) -> float:
+     best = float("inf")
+     for _ in range(repeats):
+         t0 = time.perf_counter()
+         if algo == ALG_TIMSORT:
+             sort_timsort(arr)
+         elif algo == ALG_NP_QUICK:
+             sort_np(arr, kind="quicksort")
+         elif algo == ALG_NP_MERGE:
+             sort_np(arr, kind="mergesort")
+         elif algo == ALG_COUNTING:
+             sort_counting(arr)
+         elif algo == ALG_RADIX:
+             sort_radix_lsd(arr)
+         else:
+             raise ValueError(f"unknown algo {algo}")
+         best = min(best, time.perf_counter() - t0)
+     return best
+
+
+ def measure_best_algorithm(arr: Sequence[Any], repeats: int = 1) -> Tuple[str, Dict[str, float]]:
+     algos = available_algorithms_for(arr)
+     times: Dict[str, float] = {}
+     for algo in algos:
+         try:
+             times[algo] = time_algorithm(arr, algo, repeats=repeats)
+         except Exception:
+             # Skip algorithms that reject this input (e.g., range cap)
+             continue
+     if not times:
+         # Fallback: Timsort always applies
+         return ALG_TIMSORT, {ALG_TIMSORT: float("inf")}
+     best_algo = min(times.items(), key=lambda kv: kv[1])[0]
+     return best_algo, times
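
A short sketch of how these primitives compose, using the functions defined above (timings naturally vary by machine):

    import numpy as np
    from mlsort.algorithms import measure_best_algorithm, sort_radix_lsd

    rng = np.random.default_rng(0)
    arr = rng.integers(-1_000, 1_000, size=10_000)  # int64 array

    best, times = measure_best_algorithm(arr, repeats=3)
    print(best, times)  # e.g. "np_quick" plus per-algorithm best times

    # The specialized sorts agree with the builtin on integer inputs.
    assert sort_radix_lsd(arr) == sorted(arr.tolist())
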
mlsort/api.py ADDED
@@ -0,0 +1,159 @@
+ from __future__ import annotations
+
+ import logging
+ import os
+ from typing import Any, Dict, List, Sequence
+
+ import numpy as np
+
+ from .config import get_artifacts_dir, get_env_bool, get_seed
+ from .decision import decide
+ from .installer import load_thresholds, train_thresholds, save_thresholds, Thresholds
+ from .optimize import gen_cases, optimize_cutoffs
+ from .algorithms import (
+     ALG_TIMSORT, ALG_NP_QUICK, ALG_NP_MERGE, ALG_COUNTING, ALG_RADIX,
+     sort_timsort, sort_np, sort_counting, sort_radix_lsd, available_algorithms_for
+ )
+
+
+ log = logging.getLogger("mlsort")
+
+
+ def _ensure_thresholds(path: str) -> Thresholds:
+     # Lazy init: controlled by env flags
+     if os.path.exists(path):
+         return load_thresholds(path)
+     if not get_env_bool("MLSORT_ENABLE_INSTALL_BENCH", False):
+         # Safe default if benchmarks are disabled
+         th = Thresholds(cutoff_n=1024, activation_n=98304, tree={"leaf": True, "label": ALG_NP_QUICK}, feature_names=[
+             "n", "dtype_code", "est_sortedness", "est_dup_ratio", "est_range", "est_entropy", "est_run_len"
+         ])
+         os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
+         save_thresholds(path, th)
+         return th
+     # Run a small-budget train+optimize pass
+     seed = get_seed()
+     th = train_thresholds(num_samples=600, max_n=120_000, seed=seed, max_depth=3)
+     save_thresholds(path, th)
+     arrays = gen_cases(num_samples=250, max_n=120_000, seed=seed + 7)
+     res = optimize_cutoffs(th, arrays)
+     th.cutoff_n = int(res["best"]["cutoff_n"])  # type: ignore[attr-defined]
+     th.activation_n = int(res["best"]["activation_n"])  # type: ignore[attr-defined]
+     save_thresholds(path, th)
+     return th
+
+
+ def select_algorithm(arr: Sequence[Any], thresholds_path: str | None = None, *, key: Any = None, reverse: bool = False) -> str:
+     # Input validation
+     try:
+         n = len(arr)  # type: ignore[arg-type]
+     except Exception:
+         raise TypeError("arr must be a sequence with __len__ and indexable by int")
+     if n == 0:
+         return ALG_TIMSORT
+     # If a key function is provided, prefer builtin Timsort for correctness and stability
+     if key is not None:
+         return ALG_TIMSORT
+     # If data are strings or mixed/object types, default to Python's Timsort
+     try:
+         if isinstance(arr, np.ndarray):
+             if arr.dtype.kind in {"O", "U", "S"}:
+                 return ALG_TIMSORT
+         else:
+             # Inspect a prefix of the data to determine type categories
+             sample_count = min(n, 256)
+             idxs = range(sample_count)
+             cats = set()
+             for i in idxs:
+                 v = arr[i]
+                 if isinstance(v, (str, bytes)):
+                     cats.add("string")
+                 elif isinstance(v, (int, float, np.integer, np.floating)):
+                     cats.add("number")
+                 elif v is None:
+                     cats.add("other")
+                 else:
+                     # Unknown/object type
+                     cats.add("other")
+                 if len(cats) > 1:
+                     break
+             if "string" in cats:
+                 return ALG_TIMSORT
+             if len(cats) > 1 or (cats and next(iter(cats)) == "other"):
+                 return ALG_TIMSORT
+     except Exception:
+         # On any detection error, prefer the safe fallback
+         return ALG_TIMSORT
+     # Ensure thresholds
+     thr_path = thresholds_path or os.path.join(get_artifacts_dir(), "thresholds.json")
+     os.makedirs(os.path.dirname(thr_path) or ".", exist_ok=True)
+     th = _ensure_thresholds(thr_path)
+     algo = decide(arr, th)
+     if get_env_bool("MLSORT_DEBUG", False):
+         log.debug("mlsort.select algo=%s n=%d path=%s", algo, n, thr_path)
+     return algo
+
+
+ def sort(
+     arr: Sequence[Any],
+     thresholds_path: str | None = None,
+     *,
+     key: Any = None,
+     reverse: bool = False,
+ ) -> List[Any]:
+     # Always keep a safe fallback path
+     try:
+         algo = select_algorithm(arr, thresholds_path, key=key, reverse=reverse)
+     except Exception as e:  # strict safety: fall back
+         if get_env_bool("MLSORT_DEBUG", False):
+             log.debug("mlsort.select failed: %s; falling back to timsort", e)
+         algo = ALG_TIMSORT
+
+     # Execute with correct key/reverse handling
+     if algo == ALG_TIMSORT:
+         a = list(arr)
+         a.sort(key=key, reverse=reverse)
+         return a
+
+     # For non-Timsort backends, key is unsupported (it would have forced Timsort above)
+     if algo == ALG_NP_QUICK:
+         res = sort_np(arr, kind="quicksort").tolist()
+         return res[::-1] if reverse else res
+     if algo == ALG_NP_MERGE:
+         res = sort_np(arr, kind="mergesort").tolist()
+         return res[::-1] if reverse else res
+     if algo == ALG_COUNTING:
+         try:
+             res = sort_counting(arr)
+             return res[::-1] if reverse else res
+         except Exception:
+             res = sort_np(arr, kind="quicksort").tolist()
+             return res[::-1] if reverse else res
+     if algo == ALG_RADIX:
+         try:
+             res = sort_radix_lsd(arr)
+             return res[::-1] if reverse else res
+         except Exception:
+             res = sort_np(arr, kind="quicksort").tolist()
+             return res[::-1] if reverse else res
+
+     # Last resort: builtin sort
+     a = list(arr)
+     a.sort(key=key, reverse=reverse)
+     return a
+
+
+ def profile_decisions(samples: int = 100, max_n: int = 200_000, thresholds_path: str | None = None) -> Dict[str, Any]:
+     import time
+     from .algorithms import time_algorithm
+     thr_path = thresholds_path or os.path.join(get_artifacts_dir(), "thresholds.json")
+     th = _ensure_thresholds(thr_path)
+     arrays = gen_cases(samples, max_n, seed=get_seed() + 99)
+     rows = []
+     for arr in arrays:
+         t0 = time.perf_counter()
+         algo = decide(arr, th)
+         t1 = time.perf_counter()
+         t_sort = time_algorithm(arr, algo, repeats=1)
+         rows.append({"n": len(arr), "algo": algo, "decision_ms": (t1 - t0) * 1000.0, "sort_s": t_sort})
+     return {"count": len(rows), "rows": rows[:50]}
mlsort/baseline.py ADDED
@@ -0,0 +1,33 @@
+ from __future__ import annotations
+
+ from typing import Dict
+
+ from .algorithms import ALG_COUNTING, ALG_NP_MERGE, ALG_NP_QUICK, ALG_RADIX, ALG_TIMSORT
+
+
+ def heuristic_baseline(props: Dict[str, float]) -> str:
+     n = props["n"]
+     dtype = int(props["dtype_code"])  # 0 float, 1 int
+     sortedness = props["est_sortedness"]
+     dup_ratio = props["est_dup_ratio"]
+     rng = props["est_range"]
+     entropy = props["est_entropy"]
+     run_len = props["est_run_len"]
+
+     # If the data are almost sorted or have long runs, Timsort shines
+     if sortedness >= 0.9 or run_len >= 32:
+         return ALG_TIMSORT
+
+     if dtype == 1:
+         # Counting sort when the range is relatively small and duplicates abound
+         if rng > 0 and rng <= max(1024.0, 8.0 * n) and dup_ratio >= 0.3 and entropy <= 0.7:
+             return ALG_COUNTING
+         # Radix for wide-range ints with moderate entropy
+         if n >= 512 and entropy <= 0.9:
+             return ALG_RADIX
+
+     # General case: NumPy quicksort for speed on large inputs, mergesort otherwise
+     if n >= 2000:
+         return ALG_NP_QUICK
+     else:
+         return ALG_NP_MERGE
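
To make the branch structure concrete, a sketch with illustrative feature values (in the package these would come from a feature-estimation module not shown in this diff):

    from mlsort.baseline import heuristic_baseline

    # Illustrative values: unsorted ints, small range, many duplicates.
    props = {
        "n": 50_000.0, "dtype_code": 1.0, "est_sortedness": 0.1,
        "est_dup_ratio": 0.6, "est_range": 512.0,
        "est_entropy": 0.4, "est_run_len": 2.0,
    }
    print(heuristic_baseline(props))  # -> "counting"
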
mlsort/benchmark.py ADDED
@@ -0,0 +1,117 @@
+ from __future__ import annotations
+
+ import os
+ import statistics
+ import time
+ from typing import Dict, List
+
+ import numpy as np
+
+ from .algorithms import (
+     ALG_NP_MERGE,
+     ALG_NP_QUICK,
+     ALG_TIMSORT,
+     time_algorithm,
+ )
+ from .decision import decide
+ from .installer import load_thresholds, train_thresholds, save_thresholds
+ from .data import (
+     gen_sorted, gen_reverse, gen_nearly_sorted, gen_uniform, gen_small_range, gen_zipf, gen_normal,
+ )
+
+
+ NAIVE = [ALG_TIMSORT, ALG_NP_QUICK, ALG_NP_MERGE]
+
+
+ def ensure_thresholds(path: str, samples: int, max_n: int, seed: int) -> str:
+     if os.path.exists(path):
+         return path
+     th = train_thresholds(num_samples=samples, max_n=max_n, seed=seed, max_depth=3)
+     os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
+     save_thresholds(path, th)
+     return path
+
+
+ def _random_case(rng: np.random.Generator, max_n: int):
+     n = int(rng.integers(128, max_n + 1))
+     gens = [
+         lambda n: gen_sorted(n, "int"),
+         lambda n: gen_reverse(n, "int"),
+         lambda n: gen_nearly_sorted(n, dtype="int"),
+         lambda n: gen_uniform(n, "int", 0, 10_000),
+         lambda n: gen_uniform(n, "float"),
+         lambda n: gen_small_range(n, 128),
+         lambda n: gen_zipf(n, a=2.0, dtype="int", max_val=50_000),
+         lambda n: gen_normal(n, dtype="float"),
+     ]
+     g = gens[int(rng.integers(0, len(gens)))]
+     arr = g(n)
+     return arr
+
+
+ def run_benchmark(num_samples: int, max_n: int, seed: int, thresholds_path: str) -> Dict:
+     ensure_thresholds(thresholds_path, samples=max(600, num_samples // 2), max_n=max_n, seed=seed)
+     th = load_thresholds(thresholds_path)
+
+     rng = np.random.default_rng(seed + 101)
+
+     times_decision: List[float] = []
+     times_naive: Dict[str, List[float]] = {k: [] for k in NAIVE}
+     per_case: List[Dict] = []
+
+     for _ in range(num_samples):
+         arr = _random_case(rng, max_n)
+
+         t0 = time.perf_counter()
+         algo = decide(arr, th)
+         t1 = time.perf_counter()
+         t_sort = time_algorithm(arr, algo, repeats=1)
+         t_decision_total = (t1 - t0) + t_sort
+         times_decision.append(t_decision_total)
+
+         for k in NAIVE:
+             t = time_algorithm(arr, k, repeats=1)
+             times_naive[k].append(t)
+
+         per_case.append({
+             "n": int(len(arr)),
+             "decision_algo": algo,
+             "decision_total_time": t_decision_total,
+             **{f"time_{k}": times_naive[k][-1] for k in NAIVE},
+         })
+
+     def stats(vals: List[float]):
+         vals_sorted = sorted(vals)
+         p50 = vals_sorted[len(vals_sorted) // 2]
+         p90 = vals_sorted[int(len(vals_sorted) * 0.9) - 1]
+         return {
+             "mean": float(statistics.fmean(vals)),
+             "median": float(p50),
+             "p90": float(p90),
+         }
+
+     agg = {"decision": stats(times_decision)}
+     for k in NAIVE:
+         agg[k] = stats(times_naive[k])
+
+     mean_naive = {k: agg[k]["mean"] for k in NAIVE}
+     best_naive_key = min(mean_naive.items(), key=lambda kv: kv[1])[0]
+     best_naive_mean = mean_naive[best_naive_key]
+     speedup_vs_best = best_naive_mean / agg["decision"]["mean"] if agg["decision"]["mean"] > 0 else 1.0
+
+     win_rates = {}
+     for k in NAIVE:
+         wins = sum(1 for i in range(num_samples) if times_decision[i] <= times_naive[k][i])
+         win_rates[k] = wins / num_samples
+
+     return {
+         "samples": num_samples,
+         "max_n": max_n,
+         "seed": seed,
+         "thresholds": thresholds_path,
+         "aggregate": agg,
+         "best_naive": {"key": best_naive_key, "mean_time": best_naive_mean},
+         "speedup_vs_best_naive": speedup_vs_best,
+         "win_rates": win_rates,
+         "cases_head": per_case[:50],
+     }
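
A small driving sketch for `run_benchmark` (the thresholds path here is an arbitrary writable location, not one mandated by the package):

    from mlsort.benchmark import run_benchmark

    res = run_benchmark(num_samples=50, max_n=5_000, seed=7,
                        thresholds_path="/tmp/mlsort/thresholds.json")
    print(res["speedup_vs_best_naive"])  # >1.0 means the policy beat the best naive mean
    print(res["win_rates"])              # per-baseline fraction of cases won
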
@@ -0,0 +1,44 @@
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import os
+
+ from mlsort.benchmark import run_benchmark
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Benchmark decision policy vs naive baselines")
+     parser.add_argument("--samples", type=int, default=600)
+     parser.add_argument("--max-n", type=int, default=20000)
+     parser.add_argument("--seed", type=int, default=123)
+     parser.add_argument("--thresholds", type=str, default=os.path.join(os.path.expanduser("~"), ".cache", "mlsort", "thresholds.json"))
+     parser.add_argument("--out-json", type=str, default="bench_compare.json")
+     parser.add_argument("--out-md", type=str, default="report.md")
+     args = parser.parse_args()
+
+     results = run_benchmark(args.samples, args.max_n, args.seed, args.thresholds)
+
+     with open(args.out_json, "w") as f:
+         json.dump(results, f, indent=2)
+
+     md = []
+     md.append("# mlsort: Decision Policy vs. Naive Single-Choice\n\n")
+     md.append(f"- Samples: {results['samples']} \n- Max n: {results['max_n']} \n- Seed: {results['seed']} \n- Thresholds: {results['thresholds']}\n\n")
+     md.append("## Mean/Median/P90 (seconds)\n")
+     md.append(f"- Decision mean: {results['aggregate']['decision']['mean']:.6f}\n")
+     for k, v in results['aggregate'].items():
+         if k == 'decision':
+             continue
+         md.append(f"- {k} mean: {v['mean']:.6f}\n")
+     md.append(f"\nBest naive: {results['best_naive']['key']} at {results['best_naive']['mean_time']:.6f} s\n")
+     md.append(f"Speedup vs best naive: {results['speedup_vs_best_naive']:.3f}x\n")
+
+     with open(args.out_md, "w") as f:
+         f.write("".join(md))
+
+     print(f"Wrote {args.out_json} and {args.out_md}")  # noqa: T201
+
+
+ if __name__ == "__main__":
+     main()
@@ -0,0 +1,25 @@
+ from __future__ import annotations
+
+ import argparse
+ import os
+
+ from mlsort.installer import train_thresholds, save_thresholds
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Benchmark and derive thresholds for mlsort")
+     parser.add_argument("--samples", type=int, default=1200)
+     parser.add_argument("--max-n", type=int, default=20000)
+     parser.add_argument("--seed", type=int, default=42)
+     parser.add_argument("--max-depth", type=int, default=3)
+     parser.add_argument("--out", type=str, default=os.path.join(os.path.expanduser("~"), ".cache", "mlsort", "thresholds.json"))
+     args = parser.parse_args()
+
+     os.makedirs(os.path.dirname(args.out), exist_ok=True)
+     th = train_thresholds(num_samples=args.samples, max_n=args.max_n, seed=args.seed, max_depth=args.max_depth)
+     save_thresholds(args.out, th)
+     print(f"Saved thresholds to {args.out}")  # noqa: T201
+
+
+ if __name__ == "__main__":
+     main()
mlsort/cli_init.py ADDED
@@ -0,0 +1,38 @@
+ from __future__ import annotations
+
+ import argparse
+ import os
+
+ from mlsort.config import get_artifacts_dir, get_seed
+ from mlsort.installer import train_thresholds, save_thresholds, load_thresholds
+ from mlsort.optimize import gen_cases, optimize_cutoffs
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Initialize mlsort artifacts; optional params have defaults")
+     parser.add_argument("--samples", type=int, default=1200, help="training samples")
+     parser.add_argument("--max-n", type=int, default=200000, help="max array size in benchmarking")
+     parser.add_argument("--seed", type=int, default=None, help="random seed; default from MLSORT_SEED or 42")
+     parser.add_argument("--artifacts", type=str, default=None, help="artifacts dir; default MLSORT_ARTIFACTS_DIR or OS cache")
+     args = parser.parse_args()
+
+     seed = args.seed if args.seed is not None else get_seed()
+     artifacts = args.artifacts or get_artifacts_dir()
+     thr_path = os.path.join(artifacts, "thresholds.json")
+     os.makedirs(artifacts, exist_ok=True)
+
+     th = train_thresholds(num_samples=args.samples, max_n=args.max_n, seed=seed, max_depth=3)
+     save_thresholds(thr_path, th)
+
+     arrays = gen_cases(num_samples=min(300, args.samples), max_n=args.max_n, seed=seed + 17)
+     res = optimize_cutoffs(load_thresholds(thr_path), arrays)
+     th_best = load_thresholds(thr_path)
+     th_best.cutoff_n = int(res["best"]["cutoff_n"])  # type: ignore[attr-defined]
+     th_best.activation_n = int(res["best"]["activation_n"])  # type: ignore[attr-defined]
+     save_thresholds(thr_path, th_best)
+
+     print({"artifacts": artifacts, "thresholds": thr_path, "best": res["best"]})  # noqa: T201
+
+
+ if __name__ == "__main__":
+     main()
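
Invoked as a module, the script above would be driven like this (a sketch; the argument values are examples only):

    import runpy, sys

    # Equivalent to: python -m mlsort.cli_init --samples 400 --max-n 50000
    sys.argv = ["mlsort.cli_init", "--samples", "400", "--max-n", "50000"]
    runpy.run_module("mlsort.cli_init", run_name="__main__")
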
@@ -0,0 +1,34 @@
+ from __future__ import annotations
+
+ import argparse
+ import os
+
+ from mlsort.config import get_artifacts_dir, get_seed
+ from mlsort.installer import load_thresholds, save_thresholds
+ from mlsort.optimize import gen_cases, optimize_cutoffs
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Optimize cutoff and activation thresholds")
+     parser.add_argument("--samples", type=int, default=350)
+     parser.add_argument("--max-n", type=int, default=100000)
+     parser.add_argument("--seed", type=int, default=None)
+     parser.add_argument("--thresholds", type=str, default=None)
+     args = parser.parse_args()
+
+     seed = args.seed if args.seed is not None else get_seed()
+     thr_path = args.thresholds or os.path.join(get_artifacts_dir(), "thresholds.json")
+     th = load_thresholds(thr_path)
+
+     arrays = gen_cases(args.samples, args.max_n, seed=seed + 10)
+     res = optimize_cutoffs(th, arrays)
+
+     th.cutoff_n = int(res["best"]["cutoff_n"])  # type: ignore[attr-defined]
+     th.activation_n = int(res["best"]["activation_n"])  # type: ignore[attr-defined]
+     save_thresholds(thr_path, th)
+
+     print({"best": res["best"], "thresholds": thr_path})  # noqa: T201
+
+
+ if __name__ == "__main__":
+     main()
mlsort/config.py ADDED
@@ -0,0 +1,34 @@
+ from __future__ import annotations
+
+ import os
+ import sys
+
+
+ def get_cache_root() -> str:
+     # macOS: ~/Library/Caches
+     # Linux: XDG_CACHE_HOME or ~/.cache
+     # Windows: LOCALAPPDATA
+     if sys.platform == "darwin":
+         return os.path.expanduser("~/Library/Caches")
+     if os.name == "nt":
+         return os.environ.get("LOCALAPPDATA", os.path.expanduser("~\\AppData\\Local"))
+     return os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+
+
+ def get_artifacts_dir() -> str:
+     return os.environ.get("MLSORT_ARTIFACTS_DIR", os.path.join(get_cache_root(), "mlsort"))
+
+
+ def get_env_bool(name: str, default: bool = False) -> bool:
+     v = os.environ.get(name)
+     if v is None:
+         return default
+     return v.strip().lower() in {"1", "true", "yes", "on"}
+
+
+ def get_seed(default: int = 42) -> int:
+     try:
+         return int(os.environ.get("MLSORT_SEED", str(default)))
+     except Exception:
+         return default
+
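
Taken together, the configuration helpers above can be exercised as follows (values are examples only):

    import os

    os.environ["MLSORT_ARTIFACTS_DIR"] = "/tmp/mlsort-artifacts"
    os.environ["MLSORT_SEED"] = "7"
    os.environ["MLSORT_ENABLE_INSTALL_BENCH"] = "true"  # parsed by get_env_bool

    from mlsort.config import get_artifacts_dir, get_env_bool, get_seed
    assert get_artifacts_dir() == "/tmp/mlsort-artifacts"
    assert get_seed() == 7
    assert get_env_bool("MLSORT_ENABLE_INSTALL_BENCH")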