mlsort 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlsort/api.py +15 -1
- mlsort/cli_export_forest.py +47 -0
- mlsort/fast_model.py +56 -0
- {mlsort-0.1.0.dist-info → mlsort-0.1.1.dist-info}/METADATA +5 -4
- {mlsort-0.1.0.dist-info → mlsort-0.1.1.dist-info}/RECORD +9 -7
- {mlsort-0.1.0.dist-info → mlsort-0.1.1.dist-info}/entry_points.txt +1 -0
- {mlsort-0.1.0.dist-info → mlsort-0.1.1.dist-info}/WHEEL +0 -0
- {mlsort-0.1.0.dist-info → mlsort-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {mlsort-0.1.0.dist-info → mlsort-0.1.1.dist-info}/top_level.txt +0 -0
mlsort/api.py
CHANGED
@@ -43,6 +43,10 @@ def _ensure_thresholds(path: str) -> Thresholds:
|
|
43
43
|
return th
|
44
44
|
|
45
45
|
|
46
|
+
def _use_fast_model() -> bool:
    """Report whether MLSORT_USE_FAST_MODEL requests the JSON fast model.

    The variable is compared case-insensitively against "1", "true", "yes"
    and "on". An unset variable is treated as "0", so the result is False.
    """
    raw_value = os.environ.get("MLSORT_USE_FAST_MODEL", "0")
    truthy_tokens = ("1", "true", "yes", "on")
    return raw_value.lower() in truthy_tokens
|
48
|
+
|
49
|
+
|
46
50
|
def select_algorithm(arr: Sequence[Any], thresholds_path: str | None = None, *, key: Any = None, reverse: bool = False) -> str:
|
47
51
|
# Input validation
|
48
52
|
try:
|
@@ -88,7 +92,17 @@ def select_algorithm(arr: Sequence[Any], thresholds_path: str | None = None, *,
|
|
88
92
|
thr_path = thresholds_path or os.path.join(get_artifacts_dir(), "thresholds.json")
|
89
93
|
os.makedirs(os.path.dirname(thr_path) or ".", exist_ok=True)
|
90
94
|
th = _ensure_thresholds(thr_path)
|
91
|
-
|
95
|
+
# Large arrays: optionally use fast model
|
96
|
+
if _use_fast_model():
|
97
|
+
try:
|
98
|
+
from .features import estimate_properties
|
99
|
+
from .fast_model import predict_fast
|
100
|
+
props = estimate_properties(arr)
|
101
|
+
algo = predict_fast(props)
|
102
|
+
except Exception:
|
103
|
+
algo = decide(arr, th)
|
104
|
+
else:
|
105
|
+
algo = decide(arr, th)
|
92
106
|
if get_env_bool("MLSORT_DEBUG", False):
|
93
107
|
log.debug("mlsort.select algo=%s n=%d path=%s", algo, n, thr_path)
|
94
108
|
return algo
|
@@ -0,0 +1,47 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import argparse
|
4
|
+
import json
|
5
|
+
from pathlib import Path
|
6
|
+
|
7
|
+
import joblib
|
8
|
+
from sklearn.ensemble import RandomForestClassifier
|
9
|
+
|
10
|
+
from .model import LABELS
|
11
|
+
|
12
|
+
|
13
|
+
def export_forest(model: RandomForestClassifier) -> dict:
    """Serialize a fitted RandomForestClassifier into a plain, JSON-ready spec.

    Each entry of ``model.estimators_`` becomes ``{"nodes": [...]}``. Node ``i``
    mirrors index ``i`` of that estimator's ``tree_``:

    * Leaves (both children are -1) store ``{"value": [...]}``. This is the
      first output row of ``tree_.value`` for the node.
    * Split nodes store ``feature``, ``threshold``, ``left`` and ``right``.
      ``left`` and ``right`` are child indices into the same ``nodes`` list.

    The result also carries:

    * ``label_names``: the project's ``LABELS``.
    * ``classes``: ``model.classes_`` as a plain list.

    Position ``k`` of every leaf ``value`` vector refers to ``classes[k]``.
    That coincides with label id ``k`` only when every class appeared in the
    training data, so consumers should map through ``classes`` when present.

    Args:
        model: A fitted ``RandomForestClassifier``. The function reads
            ``estimators_`` and ``classes_``.

    Returns:
        A dict containing only JSON-serializable Python types.
    """
    trees = []
    for est in model.estimators_:
        t = est.tree_
        # In scikit-learn these Tree attributes are computed properties that
        # build a fresh array view on every access. Fetch each one once per
        # tree instead of several times per node.
        left_children = t.children_left
        right_children = t.children_right
        features = t.feature
        thresholds = t.threshold
        values = t.value
        nodes = []
        for i in range(t.node_count):
            left = int(left_children[i])
            right = int(right_children[i])
            if left == -1 and right == -1:
                nodes.append({"value": values[i][0].tolist()})
            else:
                nodes.append({
                    "feature": int(features[i]),
                    "threshold": float(thresholds[i]),
                    "left": left,
                    "right": right,
                })
        trees.append({"nodes": nodes})
    return {
        "label_names": LABELS,
        # Index -> class mapping used by the leaf value vectors.
        # tolist() converts numpy scalars to plain JSON-friendly Python values.
        "classes": model.classes_.tolist(),
        "trees": trees,
    }
|
31
|
+
|
32
|
+
|
33
|
+
def main():
    """CLI entry point: export a joblib-saved RandomForest as forest JSON.

    Flags:
        --model  Path to the ``model.joblib`` to load (required).
        --out    Destination for the JSON spec (required). Missing parent
                 directories are created.
    """
    parser = argparse.ArgumentParser(description="Export sklearn RandomForest to fast JSON format")
    parser.add_argument("--model", required=True, help="Path to model.joblib")
    parser.add_argument("--out", required=True, help="Path to write forest.json")
    options = parser.parse_args()

    # joblib.load unpickles its input: only point --model at trusted files.
    model: RandomForestClassifier = joblib.load(options.model)
    forest_spec = export_forest(model)

    destination = Path(options.out)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(json.dumps(forest_spec))
    print(f"Wrote {options.out}")  # noqa: T201


if __name__ == "__main__":
    main()
|
mlsort/fast_model.py
ADDED
@@ -0,0 +1,56 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import json
|
4
|
+
import os
|
5
|
+
from typing import Dict, List, Optional
|
6
|
+
|
7
|
+
from .features import to_feature_vector
|
8
|
+
from .model import ID_TO_LABEL
|
9
|
+
from .config import get_artifacts_dir
|
10
|
+
|
11
|
+
# Module-level cache: the most recently loaded forest spec and the path it
# was read from. load_fast_model re-reads the file only when the path changes.
_FAST_MODEL: Optional[Dict] = None
_FAST_MODEL_PATH: Optional[str] = None


def _get_default_fast_model_path() -> str:
    """Return ``forest.json`` inside the directory given by ``get_artifacts_dir()``."""
    return os.path.join(get_artifacts_dir(), "forest.json")


def load_fast_model(path: Optional[str] = None) -> Dict:
    """Load the JSON forest spec, caching it per path.

    Args:
        path: The JSON file to read. A falsy value falls back to
            ``_get_default_fast_model_path()``.

    Returns:
        The parsed spec. Repeated calls with the same path return the same
        cached object, so callers must not mutate it.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    global _FAST_MODEL, _FAST_MODEL_PATH
    use_path = path or _get_default_fast_model_path()
    if _FAST_MODEL is None or _FAST_MODEL_PATH != use_path:
        # JSON is UTF-8. Without an explicit encoding, open() uses the locale
        # default (e.g. cp1252 on Windows) and can mis-decode the file.
        with open(use_path, "r", encoding="utf-8") as f:
            loaded = json.load(f)
        # Update both globals only after a successful parse, so a failed load
        # leaves the previous cache entry intact and consistent.
        _FAST_MODEL = loaded
        _FAST_MODEL_PATH = use_path
    return _FAST_MODEL  # type: ignore[return-value]
|
27
|
+
|
28
|
+
|
29
|
+
def _tree_predict(tree: Dict, x: List[float]) -> int:
    """Route ``x`` through one exported tree and return its leaf's argmax index.

    Routing starts at node 0. It follows ``left`` when
    ``x[feature] <= threshold`` and ``right`` otherwise, and stops at the first
    node that has a ``value`` entry. The result is the index of the largest
    entry in that leaf's ``value`` list; on ties the earliest index wins.
    """
    node_table = tree["nodes"]
    current = node_table[0]
    while "value" not in current:
        feature_idx, cutoff = current["feature"], current["threshold"]
        left_idx, right_idx = current["left"], current["right"]
        next_idx = left_idx if x[feature_idx] <= cutoff else right_idx
        current = node_table[next_idx]
    leaf_scores: List[float] = current["value"]
    return int(max(range(len(leaf_scores)), key=leaf_scores.__getitem__))
|
42
|
+
|
43
|
+
|
44
|
+
def _leaf_value(tree: Dict, x: List[float]) -> List[float]:
    """Return the ``value`` vector of the leaf that ``x`` reaches in one exported tree.

    Routing starts at node 0. It goes to ``left`` when
    ``x[feature] <= threshold`` and to ``right`` otherwise.
    """
    nodes = tree["nodes"]
    node = nodes[0]
    while "value" not in node:
        nxt = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
        node = nodes[nxt]
    return node["value"]


def predict_fast(props: Dict[str, float], *, model_path: Optional[str] = None) -> str:
    """Predict an algorithm label for ``props`` using the exported JSON forest.

    This mirrors ``RandomForestClassifier.predict``, which uses soft voting:

    1. Each tree's leaf ``value`` vector is normalized to sum to 1. A leaf
       that sums to zero is left unscaled.
    2. The normalized vectors are summed across trees.
    3. The class index with the largest total wins. On ties the lowest index
       wins, as with ``numpy.argmax``.

    A hard majority vote can disagree with the sklearn model the forest was
    exported from.

    The winning index is mapped through the spec's ``classes`` list when the
    exporter wrote one. Otherwise it is used directly as a label id for
    ``ID_TO_LABEL``, which matches the behavior for specs without ``classes``.

    NOTE(review): sklearn casts features to float32 before comparing them with
    thresholds. Values extremely close to a threshold may route differently
    here; confirm this if exact parity is required.

    Args:
        props: Feature properties, converted with ``to_feature_vector``.
        model_path: The forest JSON to use. ``None`` selects the default path
            resolved by ``load_fast_model``.

    Returns:
        An algorithm label string. Returns ``"timsort"`` when the spec yields
        no class scores, e.g. when it has no trees.
    """
    fm = load_fast_model(model_path)
    x = to_feature_vector(props)
    totals: List[float] = []
    for tree in fm["trees"]:
        leaf = _leaf_value(tree, x)
        # Per-tree probability normalization, like sklearn's
        # DecisionTreeClassifier.predict_proba. A zero-sum leaf is divided by 1.
        mass = float(sum(leaf))
        if mass == 0.0:
            mass = 1.0
        if len(leaf) > len(totals):
            totals.extend([0.0] * (len(leaf) - len(totals)))
        for k, v in enumerate(leaf):
            totals[k] += v / mass
    if not totals:
        return "timsort"
    best_idx = max(range(len(totals)), key=totals.__getitem__)
    classes = fm.get("classes")
    if classes:
        cls = classes[best_idx]
        # classes_ may hold label strings or integer label ids. Ids are
        # presumably the case here, since ID_TO_LABEL is indexed by int;
        # verify against the training code.
        if isinstance(cls, str):
            return cls
        return ID_TO_LABEL[int(cls)]
    return ID_TO_LABEL[best_idx]
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: mlsort
|
3
|
-
Version: 0.1.0
|
3
|
+
Version: 0.1.1
|
4
4
|
Summary: ML-guided sorting backend selector with install-time benchmarking
|
5
5
|
Author: Siddharth Chaudhary
|
6
6
|
License: MIT License
|
@@ -43,9 +43,10 @@ Requires-Python: >=3.9
|
|
43
43
|
Description-Content-Type: text/markdown
|
44
44
|
License-File: LICENSE
|
45
45
|
Requires-Dist: numpy>=1.24
|
46
|
-
|
47
|
-
Requires-Dist:
|
48
|
-
Requires-Dist:
|
46
|
+
Provides-Extra: train
|
47
|
+
Requires-Dist: scikit-learn>=1.3; extra == "train"
|
48
|
+
Requires-Dist: scipy>=1.10; extra == "train"
|
49
|
+
Requires-Dist: joblib>=1.3; extra == "train"
|
49
50
|
Dynamic: license-file
|
50
51
|
|
51
52
|
# mlsort
|
@@ -1,22 +1,24 @@
|
|
1
1
|
mlsort/__init__.py,sha256=49ZFRUBmCcD_YpHDLtAvb6CjCOAUoDqczL0c5pTWhPs,1121
|
2
2
|
mlsort/algorithms.py,sha256=MgOOe9SHy9D_af7siDS4jWtuLK6alhIv8sPusqCx9qI,4475
|
3
|
-
mlsort/api.py,sha256=
|
3
|
+
mlsort/api.py,sha256=gyxjAYYu254QJyKDU0d0ApkWdf_O-DAOjj9zI-NZuiE,6500
|
4
4
|
mlsort/baseline.py,sha256=2nZrEY7P5QAQ8RPOxqNz47rR_WZMyd3iONOSR38u_-Y,1104
|
5
5
|
mlsort/benchmark.py,sha256=Ez_-HOnbvzfZD0323Nv_8vGj3xhENflztD-7IOEAalo,3713
|
6
6
|
mlsort/cli_bench_compare.py,sha256=HH1C8H8IpWWdORT81_1gsO_je7emLW9NEIpqPedDbgw,1771
|
7
7
|
mlsort/cli_bench_install.py,sha256=g28V9TZ_b5rJIEbV9LAkXZWYZaUMYQK10npGGxTZ5jI,936
|
8
|
+
mlsort/cli_export_forest.py,sha256=t2qpwfU85bMv1XNphjtA1u9YETM2j7WHcIeWdMctqWE,1481
|
8
9
|
mlsort/cli_init.py,sha256=kDcnne1lLTRg5IEZVBOYiV8kkzb_5JHmTybZEvvPpKw,1721
|
9
10
|
mlsort/cli_optimize_cutoffs.py,sha256=6De71xb93z6JScfMIduKkDe9DZ7xaG2Pg085JQe6HB0,1216
|
10
11
|
mlsort/config.py,sha256=3Qzumm41uCvseH9LbRaDe06ffFOsJ3k3f20pM_bdjg8,966
|
11
12
|
mlsort/data.py,sha256=HHtffrqOE15jLRxN6sc__mXBKsptT-USLxsZhq09SIc,3246
|
12
13
|
mlsort/decision.py,sha256=YB3epa2L7Wa1faFmbyTSNjdGj0NigO2mro7K_k5IklA,1527
|
14
|
+
mlsort/fast_model.py,sha256=DQ1Pg3i1BeYcfvIX-6L17DO8b-FcbVRU2q7KyKcaxS8,1699
|
13
15
|
mlsort/features.py,sha256=MJOwPnC4z8VDwiA9PXBgjwSr5r9djup7ZieGzSmwMbg,5303
|
14
16
|
mlsort/installer.py,sha256=M7Dj2lMEecNblWY4YoMBME3fjqNf8L2BolNbvLcYUxE,4904
|
15
17
|
mlsort/model.py,sha256=OY6b_04unIIbjrjmQ4LKrG52TmB7FvUBbflcBNg2-d0,2378
|
16
18
|
mlsort/optimize.py,sha256=7Yi6tmiJcnj_6NtDfViuxzY1nVR3zoq-sbOj0v7yEis,2945
|
17
|
-
mlsort-0.1.
|
18
|
-
mlsort-0.1.
|
19
|
-
mlsort-0.1.
|
20
|
-
mlsort-0.1.
|
21
|
-
mlsort-0.1.
|
22
|
-
mlsort-0.1.
|
19
|
+
mlsort-0.1.1.dist-info/licenses/LICENSE,sha256=yzOA5llIyAHw7tVsir3l5NgRm1_pkvXy2r4bUFcZY0g,1076
|
20
|
+
mlsort-0.1.1.dist-info/METADATA,sha256=_yZ_ztOpM8KorrJdFJt43p2vP234Un8Ck0QQqFRB5Pc,5870
|
21
|
+
mlsort-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
22
|
+
mlsort-0.1.1.dist-info/entry_points.txt,sha256=2oMHL1Z_f3U-hEwWb369v4XTwgKCkztdc1UJUPMhaDM,271
|
23
|
+
mlsort-0.1.1.dist-info/top_level.txt,sha256=0tl8OhYGP3bgyXuS76DsDreFASPKloMccz5pGfteKp0,7
|
24
|
+
mlsort-0.1.1.dist-info/RECORD,,
|
@@ -1,5 +1,6 @@
|
|
1
1
|
[console_scripts]
|
2
2
|
mlsort-bench-compare = mlsort.cli_bench_compare:main
|
3
3
|
mlsort-bench-install = mlsort.cli_bench_install:main
|
4
|
+
mlsort-export-forest = mlsort.cli_export_forest:main
|
4
5
|
mlsort-init = mlsort.cli_init:main
|
5
6
|
mlsort-optimize-cutoffs = mlsort.cli_optimize_cutoffs:main
|
File without changes
|
File without changes
|
File without changes
|