baclast 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- baclast-0.1.0/.gitignore +16 -0
- baclast-0.1.0/PKG-INFO +52 -0
- baclast-0.1.0/README.md +41 -0
- baclast-0.1.0/baclast/__init__.py +92 -0
- baclast-0.1.0/baclast/cli.py +184 -0
- baclast-0.1.0/baclast/eskape_classifier.py +172 -0
- baclast-0.1.0/baclast/features.py +155 -0
- baclast-0.1.0/baclast/model.pkl +0 -0
- baclast-0.1.0/baclast/model.py +211 -0
- baclast-0.1.0/baclast/utils.py +38 -0
- baclast-0.1.0/baclast/viz.py +127 -0
- baclast-0.1.0/pyproject.toml +39 -0
- baclast-0.1.0/tests/__init__.py +0 -0
- baclast-0.1.0/tests/conftest.py +43 -0
- baclast-0.1.0/tests/test_features.py +133 -0
- baclast-0.1.0/tests/test_model.py +156 -0
baclast-0.1.0/.gitignore
ADDED
baclast-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: baclast
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Fast ESKAPE bacterial genome classifier using k-mer profiles
|
|
5
|
+
Requires-Python: >=3.12
|
|
6
|
+
Requires-Dist: biopython>=1.81
|
|
7
|
+
Requires-Dist: joblib>=1.3
|
|
8
|
+
Requires-Dist: numpy>=1.24
|
|
9
|
+
Requires-Dist: scikit-learn>=1.3
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
|
|
12
|
+
# BaClasT -- Bacterial Classification Tool
|
|
13
|
+
|
|
14
|
+
Fast classification of assembled bacterial genomes into ESKAPE pathogen species using k-mer frequency profiling.
|
|
15
|
+
|
|
16
|
+
## Install
|
|
17
|
+
|
|
18
|
+
```bash
|
|
19
|
+
uv add baclast
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
## CLI
|
|
23
|
+
|
|
24
|
+
```bash
|
|
25
|
+
baclast --predict genome.fna
|
|
26
|
+
baclast --predict genomes/ -o results.csv
|
|
27
|
+
```
|
|
28
|
+
|
|
29
|
+
## Python
|
|
30
|
+
|
|
31
|
+
```python
|
|
32
|
+
import baclast
|
|
33
|
+
|
|
34
|
+
baclast.predict(file="genome.fna")
|
|
35
|
+
baclast.to_csv(baclast.predict(file="genome.fna"), "results.csv")
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
## What it classifies
|
|
39
|
+
|
|
40
|
+
ESKAPE pathogens (*E. faecium*, *S. aureus*, *K. pneumoniae*, *A. baumannii*, *P. aeruginosa*, *E. cloacae*) plus an "Other" class for non-ESKAPE bacteria. Includes centroid-based out-of-distribution detection.
|
|
41
|
+
|
|
42
|
+
## How it works
|
|
43
|
+
|
|
44
|
+
Computes 4-mer frequency profiles (256 features) from genome assemblies and classifies with a Random Forest. A bundled pre-trained model is included -- no training data or setup required.
|
|
45
|
+
|
|
46
|
+
## Requirements
|
|
47
|
+
|
|
48
|
+
Python >= 3.12, biopython, scikit-learn, joblib, numpy.
|
|
49
|
+
|
|
50
|
+
## License
|
|
51
|
+
|
|
52
|
+
MIT
|
baclast-0.1.0/README.md
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# BaClasT -- Bacterial Classification Tool
|
|
2
|
+
|
|
3
|
+
Fast classification of assembled bacterial genomes into ESKAPE pathogen species using k-mer frequency profiling.
|
|
4
|
+
|
|
5
|
+
## Install
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
uv add baclast
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
## CLI
|
|
12
|
+
|
|
13
|
+
```bash
|
|
14
|
+
baclast --predict genome.fna
|
|
15
|
+
baclast --predict genomes/ -o results.csv
|
|
16
|
+
```
|
|
17
|
+
|
|
18
|
+
## Python
|
|
19
|
+
|
|
20
|
+
```python
|
|
21
|
+
import baclast
|
|
22
|
+
|
|
23
|
+
baclast.predict(file="genome.fna")
|
|
24
|
+
baclast.to_csv(baclast.predict(file="genome.fna"), "results.csv")
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
## What it classifies
|
|
28
|
+
|
|
29
|
+
ESKAPE pathogens (*E. faecium*, *S. aureus*, *K. pneumoniae*, *A. baumannii*, *P. aeruginosa*, *E. cloacae*) plus an "Other" class for non-ESKAPE bacteria. Includes centroid-based out-of-distribution detection.
|
|
30
|
+
|
|
31
|
+
## How it works
|
|
32
|
+
|
|
33
|
+
Computes 4-mer frequency profiles (256 features) from genome assemblies and classifies with a Random Forest. A bundled pre-trained model is included -- no training data or setup required.
|
|
34
|
+
|
|
35
|
+
## Requirements
|
|
36
|
+
|
|
37
|
+
Python >= 3.12, biopython, scikit-learn, joblib, numpy.
|
|
38
|
+
|
|
39
|
+
## License
|
|
40
|
+
|
|
41
|
+
MIT
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
__version__ = "0.1.0"
|
|
2
|
+
|
|
3
|
+
import csv as _csv
|
|
4
|
+
from pathlib import Path as _Path
|
|
5
|
+
|
|
6
|
+
import numpy as _np
|
|
7
|
+
|
|
8
|
+
_BUNDLED_MODEL = _Path(__file__).resolve().parent / "model.pkl"
|
|
9
|
+
_model_cache = None
|
|
10
|
+
|
|
11
|
+
_FIELDS = [
|
|
12
|
+
"filepath", "filename", "organism_prediction", "confidence",
|
|
13
|
+
"confidence_warning", "nearest_centroid", "distance", "threshold",
|
|
14
|
+
"within_distribution", "baclast_version",
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _load_model():
    """Return the model payload, loading the bundled model.pkl on first use.

    The loaded payload is memoized in the module-level ``_model_cache`` so
    repeated calls to :func:`predict` deserialize the model only once.
    """
    global _model_cache
    if _model_cache is not None:
        return _model_cache
    # Deferred import keeps `import baclast` cheap until a prediction is made.
    from baclast.model import load_model
    _model_cache = load_model(_BUNDLED_MODEL)
    return _model_cache
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def predict(file: str) -> dict:
    """Classify a bacterial genome FASTA file.

    Args:
        file: Path to a FASTA file (.fasta, .fa, .fna).

    Returns:
        Dict with keys: filepath, filename, organism_prediction, confidence,
        confidence_warning, nearest_centroid, distance, threshold,
        within_distribution, baclast_version.
    """
    from baclast.features import genome_to_vector
    from baclast.model import novelty_score

    payload = _load_model()
    clf = payload["classifier"]
    label_names = payload["label_names"]
    k = payload["k"]
    kmer_vocab = payload["kmer_vocab"]
    centroids = payload.get("centroids")
    threshold = payload.get("distance_threshold")

    fpath = _Path(file)
    vec = genome_to_vector(fpath, k, kmer_vocab)
    X_q = _np.array([vec])
    pred = clf.predict(X_q)[0]
    proba = clf.predict_proba(X_q)[0]
    species = label_names[pred]
    # Cast to builtin float so the rounded value is a plain Python number.
    confidence = round(float(proba[pred]) * 100, 2)

    result = {
        "filepath": str(fpath),
        "filename": fpath.name,
        "organism_prediction": species,
        "confidence": confidence,
        "confidence_warning": "LOW" if confidence < 70.0 else "",
        "baclast_version": __version__,
    }

    # Use explicit None checks: a falsy-but-valid threshold (e.g. 0.0) must
    # still trigger the novelty check, and array-typed centroids would raise
    # on plain truthiness evaluation.
    if centroids is not None and threshold is not None:
        nearest, dist = novelty_score(vec, centroids)
        result["nearest_centroid"] = nearest
        result["distance"] = round(float(dist), 6)
        result["threshold"] = round(float(threshold), 6)
        result["within_distribution"] = "Yes" if dist <= threshold else "No"
    else:
        result["nearest_centroid"] = ""
        result["distance"] = ""
        result["threshold"] = ""
        result["within_distribution"] = ""

    return result
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def to_csv(result: dict, path: str) -> None:
    """Write a prediction result dict to a CSV file.

    Args:
        result: A result dict as returned by :func:`predict`.
        path: Destination CSV path; must not already exist.

    Raises:
        FileExistsError: If the file already exists.
    """
    p = _Path(path)
    # Exclusive-create mode "x" raises FileExistsError atomically, avoiding
    # the check-then-open race of a separate p.exists() test.
    try:
        with open(p, "x", newline="") as f:
            writer = _csv.DictWriter(f, fieldnames=_FIELDS)
            writer.writeheader()
            writer.writerow(result)
    except FileExistsError:
        raise FileExistsError(f"Output file already exists: {p}") from None
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""BaClasT CLI — bacterial genome classification tool."""
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import csv
|
|
5
|
+
import io
|
|
6
|
+
import sys
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from baclast import __version__
|
|
12
|
+
from baclast.features import genome_to_vector
|
|
13
|
+
from baclast.model import load_model, novelty_score
|
|
14
|
+
|
|
15
|
+
_BUNDLED_MODEL = Path(__file__).resolve().parent / "model.pkl"
|
|
16
|
+
_FASTA_EXTENSIONS = {".fasta", ".fa", ".fna"}
|
|
17
|
+
_MIN_CONFIDENCE = 70.0 # below this, flag as low confidence
|
|
18
|
+
_CSV_FIELDS = [
|
|
19
|
+
"filepath",
|
|
20
|
+
"filename",
|
|
21
|
+
"organism_prediction",
|
|
22
|
+
"confidence",
|
|
23
|
+
"confidence_warning",
|
|
24
|
+
"nearest_centroid",
|
|
25
|
+
"distance",
|
|
26
|
+
"threshold",
|
|
27
|
+
"within_distribution",
|
|
28
|
+
"baclast_version",
|
|
29
|
+
]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _find_model(user_path: str | None) -> Path:
    """Resolve model path: user-provided, or bundled default."""
    # A user-supplied path always wins; a missing one is a hard error.
    if user_path:
        candidate = Path(user_path)
        if candidate.exists():
            return candidate
        sys.exit(f"Error: Model file not found: {candidate}")
    # No explicit path: fall back to the model shipped with the package.
    if not _BUNDLED_MODEL.exists():
        sys.exit(f"Error: No model found. Provide --model or install a model to {_BUNDLED_MODEL}")
    return _BUNDLED_MODEL
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _collect_fastas(target: Path) -> list[Path]:
    """Return a list of FASTA files from a file path or directory."""
    if target.is_dir():
        # Sorted for deterministic processing order across filesystems.
        hits = [entry for entry in target.iterdir() if entry.suffix in _FASTA_EXTENSIONS]
        if not hits:
            sys.exit(f"Error: No FASTA files found in {target}")
        return sorted(hits)
    if not target.is_file():
        sys.exit(f"Error: Path not found: {target}")
    if target.suffix not in _FASTA_EXTENSIONS:
        sys.exit(f"Error: {target} does not look like a FASTA file. "
                 f"Expected extensions: {', '.join(sorted(_FASTA_EXTENSIONS))}")
    return [target]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _classify_one(fpath: Path, clf, label_names, k, kmer_vocab, centroids, threshold) -> dict:
    """Classify a single FASTA and return a result row.

    Args:
        fpath: Path to the FASTA file.
        clf: Trained classifier exposing predict / predict_proba.
        label_names: Index -> species-name mapping for classifier outputs.
        k: K-mer length used by the model.
        kmer_vocab: K-mer vocabulary matching the model's feature order.
        centroids: Per-class centroid profiles, or None if unavailable.
        threshold: Novelty distance threshold, or None if unavailable.

    Returns:
        Dict row with all _CSV_FIELDS keys populated.
    """
    vec = genome_to_vector(fpath, k, kmer_vocab)
    X_q = np.array([vec])
    pred = clf.predict(X_q)[0]
    proba = clf.predict_proba(X_q)[0]
    species = label_names[pred]
    # Cast to builtin float so the CSV shows a plain number rather than a
    # numpy scalar repr (matches baclast.__init__.predict).
    confidence = round(float(proba[pred]) * 100, 2)

    warning = "LOW" if confidence < _MIN_CONFIDENCE else ""

    row = {
        "filepath": str(fpath),
        "filename": fpath.name,
        "organism_prediction": species,
        "confidence": confidence,
        "confidence_warning": warning,
        "baclast_version": __version__,
    }

    # Explicit None checks: a falsy-but-valid threshold (e.g. 0.0) must still
    # trigger the novelty check, and array-typed centroids would raise on
    # plain truthiness evaluation.
    if centroids is not None and threshold is not None:
        nearest, dist = novelty_score(vec, centroids)
        row["nearest_centroid"] = nearest
        row["distance"] = round(float(dist), 6)
        row["threshold"] = round(float(threshold), 6)
        row["within_distribution"] = "Yes" if dist <= threshold else "No"
    else:
        row["nearest_centroid"] = ""
        row["distance"] = ""
        row["threshold"] = ""
        row["within_distribution"] = ""

    return row
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def main():
    """CLI entry point: parse args, classify FASTA input(s), emit CSV."""
    parser = argparse.ArgumentParser(
        prog="baclast",
        description="BaClasT — fast bacterial genome classification using k-mer profiles",
    )
    parser.add_argument(
        "--predict", required=True, metavar="PATH",
        help="Path to a FASTA file or directory of FASTAs",
    )
    parser.add_argument(
        "-o", "--output", default=None, metavar="FILE",
        help="Write results to a CSV file instead of stdout",
    )
    parser.add_argument(
        "--model", default=None, metavar="FILE",
        help="Path to model .pkl (uses bundled model if omitted)",
    )

    args = parser.parse_args()

    # Check output file doesn't already exist (fail fast, before the
    # potentially long classification run).
    if args.output:
        out_path = Path(args.output)
        if out_path.exists():
            sys.exit(f"Error: Output file already exists: {out_path}")

    # Load model
    payload = load_model(_find_model(args.model))
    clf = payload["classifier"]
    label_names = payload["label_names"]
    k = payload["k"]
    kmer_vocab = payload["kmer_vocab"]
    centroids = payload.get("centroids")
    threshold = payload.get("distance_threshold")

    # Collect input FASTAs
    target = Path(args.predict)
    fastas = _collect_fastas(target)

    # Classify
    rows = []
    for i, fpath in enumerate(fastas, 1):
        if len(fastas) > 1:
            print(f" [{i}/{len(fastas)}] {fpath.name} ... ", end="", flush=True, file=sys.stderr)
        try:
            row = _classify_one(fpath, clf, label_names, k, kmer_vocab, centroids, threshold)
            rows.append(row)
            status = f"{row['organism_prediction']} ({row['confidence']}%)"
            if row["confidence_warning"]:
                status += f" [{row['confidence_warning']} CONFIDENCE]"
            if len(fastas) > 1:
                print(status, file=sys.stderr)
        # Catch broadly so one bad genome in a batch doesn't abort the run;
        # (the original `except (ValueError, Exception)` was redundant —
        # ValueError is already an Exception subclass).
        except Exception as exc:
            rows.append({
                "filepath": str(fpath),
                "filename": fpath.name,
                "organism_prediction": "SKIPPED",
                "confidence": "",
                "confidence_warning": str(exc),
                "nearest_centroid": "",
                "distance": "",
                "threshold": "",
                "within_distribution": "",
                "baclast_version": __version__,
            })
            if len(fastas) > 1:
                print(f"SKIPPED: {exc}", file=sys.stderr)
            else:
                sys.exit(f"Error: {exc}")

    # Output: CSV file when -o is given, otherwise CSV to stdout.
    if args.output:
        out_path = Path(args.output)
        with open(out_path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=_CSV_FIELDS)
            writer.writeheader()
            writer.writerows(rows)
        print(f"Results written to {out_path} ({len(rows)} genomes)", file=sys.stderr)
    else:
        buf = io.StringIO()
        writer = csv.DictWriter(buf, fieldnames=_CSV_FIELDS)
        writer.writeheader()
        writer.writerows(rows)
        print(buf.getvalue(), end="")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
"""BaClasT — main CLI entry point for training and prediction."""
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
from baclast.features import all_kmers, genome_to_vector, load_dataset
|
|
7
|
+
from baclast.model import evaluate, load_model, save_model, train_classifier
|
|
8
|
+
from baclast.utils import print_banner, setup_logging
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def cmd_train(args):
    """Execute the 'train' sub-command.

    Loads genomes from args.data_dir, trains a Random Forest on k-mer
    profiles, evaluates on a held-out split (and optional CV), and saves
    the model payload to args.output.
    """
    # Configure logging for side effect only; the returned logger was unused.
    setup_logging(args.verbose)
    print_banner()

    k = args.k
    kmer_vocab = all_kmers(k)

    print(f"Loading genomes from {args.data_dir} (k={k})...")
    try:
        X, y, label_names = load_dataset(args.data_dir, k, kmer_vocab)
    except Exception as exc:
        sys.exit(f"Error: {exc}")

    if len(label_names) < 2:
        sys.exit("Error: Need at least 2 species to train a classifier.")

    n_genomes = len(y)
    print(f"Loaded {n_genomes} genomes across {len(label_names)} species.")

    # 80/20 stratified train/test split
    from sklearn.model_selection import train_test_split

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=42
    )

    print(f"Training Random Forest ({args.n_estimators} trees)...")
    clf = train_classifier(X_train, y_train, n_estimators=args.n_estimators)

    print("Evaluating on held-out test set:")
    evaluate(clf, X_test, y_test, label_names)

    # Optional cross-validation
    if args.cv is not None:
        from sklearn.model_selection import StratifiedKFold, cross_val_score

        print(f"\nRunning {args.cv}-fold stratified cross-validation...")
        cv = StratifiedKFold(n_splits=args.cv, shuffle=True, random_state=42)
        scores = cross_val_score(clf, X, y, cv=cv, scoring="accuracy")
        print(f"CV accuracy: {scores.mean():.4f} +/- {scores.std():.4f}")

    # Save model
    save_model(clf, label_names, k, kmer_vocab, args.output)

    # Patch in n_genomes (save_model doesn't know the total count)
    import joblib

    payload = joblib.load(args.output)
    payload["n_genomes"] = n_genomes
    joblib.dump(payload, args.output)

    print(f"\nModel saved to {args.output}")
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def cmd_predict(args):
    """Execute the 'predict' sub-command.

    Loads a trained model, extracts k-mer features from args.fasta, and
    prints the predicted species with its confidence (plus the full
    probability table when --verbose is given).
    """
    # Configure logging for side effect only; the returned logger was unused.
    setup_logging(args.verbose)
    print_banner()

    # Load model
    try:
        payload = load_model(args.model)
    except FileNotFoundError:
        sys.exit(f"Error: Model file not found: {args.model}")
    except ValueError as exc:
        sys.exit(f"Error: {exc}")

    clf = payload["classifier"]
    label_names = payload["label_names"]
    k = payload["k"]
    kmer_vocab = payload["kmer_vocab"]

    # Extract features from input FASTA
    try:
        vec = genome_to_vector(args.fasta, k, kmer_vocab)
    except FileNotFoundError:
        sys.exit(f"Error: FASTA file not found: {args.fasta}")
    except ValueError as exc:
        sys.exit(f"Error: {exc}")

    import numpy as np

    X_query = np.array([vec])
    pred = clf.predict(X_query)[0]
    proba = clf.predict_proba(X_query)[0]

    species = label_names[pred]
    confidence = proba[pred] * 100

    print(f"Predicted species: {species}")
    print(f"Confidence: {confidence:.1f}%")

    if args.verbose:
        print("\nAll species probabilities:")
        # Sort by probability descending
        ranked = sorted(
            zip(label_names, proba), key=lambda x: x[1], reverse=True
        )
        max_name_len = max(len(name) for name in label_names)
        for name, prob in ranked:
            # 40-char ASCII bar scaled to the probability.
            bar_len = int(prob * 40)
            bar = "#" * bar_len
            print(f"  {name:<{max_name_len}} {prob * 100:5.1f}% |{bar}")
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def main():
    """Main entry point — parse arguments and dispatch to sub-commands."""
    parser = argparse.ArgumentParser(
        # Fixed typo: prog was "baclasp" but the tool is installed as "baclast".
        prog="baclast",
        description="BaClasT — Bacterial Classification Tool",
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    # train sub-command
    train_parser = subparsers.add_parser("train", help="Train a classifier")
    train_parser.add_argument(
        "--data_dir", required=True, help="Directory with species sub-folders"
    )
    train_parser.add_argument(
        "--output", default="model.pkl", help="Output model path (default: model.pkl)"
    )
    train_parser.add_argument(
        "--k", type=int, default=4, help="K-mer length (default: 4)"
    )
    train_parser.add_argument(
        "--n_estimators",
        type=int,
        default=200,
        help="Number of trees (default: 200)",
    )
    train_parser.add_argument(
        "--cv", type=int, default=None, help="Number of CV folds (optional)"
    )
    train_parser.add_argument(
        "--verbose", action="store_true", help="Enable verbose output"
    )

    # predict sub-command
    predict_parser = subparsers.add_parser(
        "predict", help="Predict species for a FASTA file"
    )
    predict_parser.add_argument(
        "--model", required=True, help="Path to trained model .pkl"
    )
    predict_parser.add_argument(
        "--fasta", required=True, help="Path to input FASTA file"
    )
    predict_parser.add_argument(
        "--verbose", action="store_true", help="Show all species probabilities"
    )

    args = parser.parse_args()

    if args.command == "train":
        cmd_train(args)
    elif args.command == "predict":
        cmd_predict(args)


if __name__ == "__main__":
    main()
|