mlsort 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mlsort/api.py CHANGED
@@ -43,6 +43,10 @@ def _ensure_thresholds(path: str) -> Thresholds:
43
43
  return th
44
44
 
45
45
 
46
+ def _use_fast_model() -> bool:
47
+ return os.environ.get("MLSORT_USE_FAST_MODEL", "0").lower() in {"1", "true", "yes", "on"}
48
+
49
+
46
50
  def select_algorithm(arr: Sequence[Any], thresholds_path: str | None = None, *, key: Any = None, reverse: bool = False) -> str:
47
51
  # Input validation
48
52
  try:
@@ -88,7 +92,17 @@ def select_algorithm(arr: Sequence[Any], thresholds_path: str | None = None, *,
88
92
  thr_path = thresholds_path or os.path.join(get_artifacts_dir(), "thresholds.json")
89
93
  os.makedirs(os.path.dirname(thr_path) or ".", exist_ok=True)
90
94
  th = _ensure_thresholds(thr_path)
91
- algo = decide(arr, th)
95
+ # Large arrays: optionally use fast model
96
+ if _use_fast_model():
97
+ try:
98
+ from .features import estimate_properties
99
+ from .fast_model import predict_fast
100
+ props = estimate_properties(arr)
101
+ algo = predict_fast(props)
102
+ except Exception:
103
+ algo = decide(arr, th)
104
+ else:
105
+ algo = decide(arr, th)
92
106
  if get_env_bool("MLSORT_DEBUG", False):
93
107
  log.debug("mlsort.select algo=%s n=%d path=%s", algo, n, thr_path)
94
108
  return algo
mlsort/cli_export_forest.py ADDED
@@ -0,0 +1,47 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ from pathlib import Path
6
+
7
+ import joblib
8
+ from sklearn.ensemble import RandomForestClassifier
9
+
10
+ from .model import LABELS
11
+
12
+
13
def export_forest(model: RandomForestClassifier) -> dict:
    """Convert a fitted forest into a plain, JSON-serialisable structure.

    Every estimator in ``model.estimators_`` becomes ``{"nodes": [...]}``,
    with one entry per node of its ``tree_`` in node-index order:

    * a node whose left and right child are both ``-1`` is written as
      ``{"value": [...]}``, taken from ``tree_.value[i][0]``;
    * any other node is written with its ``feature`` index, ``threshold``
      and the node indices of its ``left`` and ``right`` children.

    Args:
        model: Fitted forest exposing ``estimators_``.

    Returns:
        ``{"label_names": LABELS, "trees": [...]}``. ``LABELS`` is the
        module-level object itself, not a copy.
    """
    # NOTE(review): the order of each leaf "value" vector presumably follows
    # model.classes_, which is not exported here - confirm that readers of
    # this format may treat the position in the vector as the label id.
    forest = []
    for estimator in model.estimators_:
        sk_tree = estimator.tree_
        left_ids = sk_tree.children_left
        right_ids = sk_tree.children_right
        encoded = []
        for idx in range(sk_tree.node_count):
            is_leaf = left_ids[idx] == -1 and right_ids[idx] == -1
            if is_leaf:
                encoded.append({"value": sk_tree.value[idx][0].tolist()})
                continue
            # Cast to builtin types so json.dumps can serialise the node.
            encoded.append(
                {
                    "feature": int(sk_tree.feature[idx]),
                    "threshold": float(sk_tree.threshold[idx]),
                    "left": int(left_ids[idx]),
                    "right": int(right_ids[idx]),
                }
            )
        forest.append({"nodes": encoded})
    return {"label_names": LABELS, "trees": forest}
31
+
32
+
33
def main() -> None:
    """CLI entry point: load a joblib model and write it as forest JSON.

    Parses the required ``--model`` and ``--out`` arguments, loads the model
    with joblib, converts it with :func:`export_forest`, creates the output
    file's parent directory if needed, writes the JSON document and prints a
    confirmation line.
    """
    parser = argparse.ArgumentParser(description="Export sklearn RandomForest to fast JSON format")
    parser.add_argument("--model", required=True, help="Path to model.joblib")
    parser.add_argument("--out", required=True, help="Path to write forest.json")
    args = parser.parse_args()

    # SECURITY: joblib.load unpickles the file, and unpickling can execute
    # arbitrary code. Only pass model files that come from a trusted source.
    model: RandomForestClassifier = joblib.load(args.model)
    spec = export_forest(model)

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8 so the written file does not depend on the platform's
    # locale default encoding.
    out_path.write_text(json.dumps(spec), encoding="utf-8")
    print(f"Wrote {args.out}")  # noqa: T201
44
+
45
+
46
+ if __name__ == "__main__":
47
+ main()
mlsort/fast_model.py ADDED
@@ -0,0 +1,56 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from typing import Dict, List, Optional
6
+
7
+ from .features import to_feature_vector
8
+ from .model import ID_TO_LABEL
9
+ from .config import get_artifacts_dir
10
+
11
+ _FAST_MODEL: Optional[Dict] = None
12
+ _FAST_MODEL_PATH: Optional[str] = None
13
+
14
+
15
def _get_default_fast_model_path() -> str:
    """Return the default fast-model path: ``forest.json`` in the artifacts dir."""
    artifacts_dir = get_artifacts_dir()
    return os.path.join(artifacts_dir, "forest.json")
17
+
18
+
19
def load_fast_model(path: Optional[str] = None) -> Dict:
    """Load the JSON forest, caching the most recently loaded file.

    Args:
        path: JSON file to read. When omitted (or empty), the path returned
            by ``_get_default_fast_model_path()`` is used.

    Returns:
        The parsed JSON document. Repeated calls with the same path return
        the same cached object; the file is not re-read if it changes on
        disk. Requesting a different path replaces the cached model.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    global _FAST_MODEL, _FAST_MODEL_PATH
    use_path = path or _get_default_fast_model_path()
    if _FAST_MODEL is None or _FAST_MODEL_PATH != use_path:
        # Read as UTF-8 explicitly instead of relying on the locale default
        # encoding, which varies between platforms.
        with open(use_path, "r", encoding="utf-8") as f:
            loaded = json.load(f)
        # Publish model and path together, only after a successful parse.
        _FAST_MODEL = loaded
        _FAST_MODEL_PATH = use_path
    return _FAST_MODEL  # type: ignore[return-value]
27
+
28
+
29
+ def _tree_predict(tree: Dict, x: List[float]) -> int:
30
+ nodes = tree["nodes"]
31
+ i = 0
32
+ while True:
33
+ node = nodes[i]
34
+ if "value" in node:
35
+ vec: List[float] = node["value"]
36
+ return int(max(range(len(vec)), key=lambda k: vec[k]))
37
+ feat = node["feature"]
38
+ thr = node["threshold"]
39
+ left = node["left"]
40
+ right = node["right"]
41
+ i = left if x[feat] <= thr else right
42
+
43
+
44
def predict_fast(props: Dict[str, float], *, model_path: Optional[str] = None) -> str:
    """Pick an algorithm label by majority vote over the exported forest.

    The properties are turned into a feature vector with
    ``to_feature_vector``. Every tree in the loaded model casts one vote via
    ``_tree_predict``, and the most frequent vote is mapped through
    ``ID_TO_LABEL``. On a tie, the class that received its first vote
    earliest wins.

    Args:
        props: Properties passed to ``to_feature_vector``.
        model_path: Optional forest JSON path forwarded to
            ``load_fast_model``.

    Returns:
        The label for the winning class id, or ``"timsort"`` when the model
        contains no trees.
    """
    forest = load_fast_model(model_path)
    features = to_feature_vector(props)

    # Tally votes in tree order; dict insertion order decides ties below.
    tally: Dict[int, int] = {}
    for member in forest["trees"]:
        class_id = _tree_predict(member, features)
        tally[class_id] = tally.get(class_id, 0) + 1

    if not tally:
        return "timsort"

    winner = max(tally, key=tally.get)
    return ID_TO_LABEL[int(winner)]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mlsort
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Summary: ML-guided sorting backend selector with install-time benchmarking
5
5
  Author: Siddharth Chaudhary
6
6
  License: MIT License
@@ -43,9 +43,10 @@ Requires-Python: >=3.9
43
43
  Description-Content-Type: text/markdown
44
44
  License-File: LICENSE
45
45
  Requires-Dist: numpy>=1.24
46
- Requires-Dist: scikit-learn>=1.3
47
- Requires-Dist: scipy>=1.10
48
- Requires-Dist: joblib>=1.3
46
+ Provides-Extra: train
47
+ Requires-Dist: scikit-learn>=1.3; extra == "train"
48
+ Requires-Dist: scipy>=1.10; extra == "train"
49
+ Requires-Dist: joblib>=1.3; extra == "train"
49
50
  Dynamic: license-file
50
51
 
51
52
  # mlsort
@@ -1,22 +1,24 @@
1
1
  mlsort/__init__.py,sha256=49ZFRUBmCcD_YpHDLtAvb6CjCOAUoDqczL0c5pTWhPs,1121
2
2
  mlsort/algorithms.py,sha256=MgOOe9SHy9D_af7siDS4jWtuLK6alhIv8sPusqCx9qI,4475
3
- mlsort/api.py,sha256=T1T_ND-ybfld0FTHyjAihhNgPTcjcWgXnv50bnVnoKw,6026
3
+ mlsort/api.py,sha256=gyxjAYYu254QJyKDU0d0ApkWdf_O-DAOjj9zI-NZuiE,6500
4
4
  mlsort/baseline.py,sha256=2nZrEY7P5QAQ8RPOxqNz47rR_WZMyd3iONOSR38u_-Y,1104
5
5
  mlsort/benchmark.py,sha256=Ez_-HOnbvzfZD0323Nv_8vGj3xhENflztD-7IOEAalo,3713
6
6
  mlsort/cli_bench_compare.py,sha256=HH1C8H8IpWWdORT81_1gsO_je7emLW9NEIpqPedDbgw,1771
7
7
  mlsort/cli_bench_install.py,sha256=g28V9TZ_b5rJIEbV9LAkXZWYZaUMYQK10npGGxTZ5jI,936
8
+ mlsort/cli_export_forest.py,sha256=t2qpwfU85bMv1XNphjtA1u9YETM2j7WHcIeWdMctqWE,1481
8
9
  mlsort/cli_init.py,sha256=kDcnne1lLTRg5IEZVBOYiV8kkzb_5JHmTybZEvvPpKw,1721
9
10
  mlsort/cli_optimize_cutoffs.py,sha256=6De71xb93z6JScfMIduKkDe9DZ7xaG2Pg085JQe6HB0,1216
10
11
  mlsort/config.py,sha256=3Qzumm41uCvseH9LbRaDe06ffFOsJ3k3f20pM_bdjg8,966
11
12
  mlsort/data.py,sha256=HHtffrqOE15jLRxN6sc__mXBKsptT-USLxsZhq09SIc,3246
12
13
  mlsort/decision.py,sha256=YB3epa2L7Wa1faFmbyTSNjdGj0NigO2mro7K_k5IklA,1527
14
+ mlsort/fast_model.py,sha256=DQ1Pg3i1BeYcfvIX-6L17DO8b-FcbVRU2q7KyKcaxS8,1699
13
15
  mlsort/features.py,sha256=MJOwPnC4z8VDwiA9PXBgjwSr5r9djup7ZieGzSmwMbg,5303
14
16
  mlsort/installer.py,sha256=M7Dj2lMEecNblWY4YoMBME3fjqNf8L2BolNbvLcYUxE,4904
15
17
  mlsort/model.py,sha256=OY6b_04unIIbjrjmQ4LKrG52TmB7FvUBbflcBNg2-d0,2378
16
18
  mlsort/optimize.py,sha256=7Yi6tmiJcnj_6NtDfViuxzY1nVR3zoq-sbOj0v7yEis,2945
17
- mlsort-0.1.0.dist-info/licenses/LICENSE,sha256=yzOA5llIyAHw7tVsir3l5NgRm1_pkvXy2r4bUFcZY0g,1076
18
- mlsort-0.1.0.dist-info/METADATA,sha256=m4fy4EuvfqrB-enbeFxXrEQZb0nhbErdEaqAooD05DU,5794
19
- mlsort-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
20
- mlsort-0.1.0.dist-info/entry_points.txt,sha256=HKRZnDWd50NuGw9uFEhAUZE6OOEKl74MT3J3yHL-1Is,218
21
- mlsort-0.1.0.dist-info/top_level.txt,sha256=0tl8OhYGP3bgyXuS76DsDreFASPKloMccz5pGfteKp0,7
22
- mlsort-0.1.0.dist-info/RECORD,,
19
+ mlsort-0.1.1.dist-info/licenses/LICENSE,sha256=yzOA5llIyAHw7tVsir3l5NgRm1_pkvXy2r4bUFcZY0g,1076
20
+ mlsort-0.1.1.dist-info/METADATA,sha256=_yZ_ztOpM8KorrJdFJt43p2vP234Un8Ck0QQqFRB5Pc,5870
21
+ mlsort-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
22
+ mlsort-0.1.1.dist-info/entry_points.txt,sha256=2oMHL1Z_f3U-hEwWb369v4XTwgKCkztdc1UJUPMhaDM,271
23
+ mlsort-0.1.1.dist-info/top_level.txt,sha256=0tl8OhYGP3bgyXuS76DsDreFASPKloMccz5pGfteKp0,7
24
+ mlsort-0.1.1.dist-info/RECORD,,
@@ -1,5 +1,6 @@
1
1
  [console_scripts]
2
2
  mlsort-bench-compare = mlsort.cli_bench_compare:main
3
3
  mlsort-bench-install = mlsort.cli_bench_install:main
4
+ mlsort-export-forest = mlsort.cli_export_forest:main
4
5
  mlsort-init = mlsort.cli_init:main
5
6
  mlsort-optimize-cutoffs = mlsort.cli_optimize_cutoffs:main
File without changes