baclast 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,16 @@
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv/
11
+
12
+ # Model files
13
+ *.pkl
14
+
15
+ # Notebook cache
16
+ cache/
baclast-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,52 @@
1
+ Metadata-Version: 2.4
2
+ Name: baclast
3
+ Version: 0.1.0
4
+ Summary: Fast ESKAPE bacterial genome classifier using k-mer profiles
5
+ Requires-Python: >=3.12
6
+ Requires-Dist: biopython>=1.81
7
+ Requires-Dist: joblib>=1.3
8
+ Requires-Dist: numpy>=1.24
9
+ Requires-Dist: scikit-learn>=1.3
10
+ Description-Content-Type: text/markdown
11
+
12
+ # BaClasT -- Bacterial Classification Tool
13
+
14
+ Fast classification of assembled bacterial genomes into ESKAPE pathogen species using k-mer frequency profiling.
15
+
16
+ ## Install
17
+
18
+ ```bash
19
+ uv add baclast
20
+ ```
21
+
22
+ ## CLI
23
+
24
+ ```bash
25
+ baclast --predict genome.fna
26
+ baclast --predict genomes/ -o results.csv
27
+ ```
28
+
29
+ ## Python
30
+
31
+ ```python
32
+ import baclast
33
+
34
+ result = baclast.predict(file="genome.fna")
35
+ baclast.to_csv(result, "results.csv")
36
+ ```
37
+
38
+ ## What it classifies
39
+
40
+ ESKAPE pathogens (*E. faecium*, *S. aureus*, *K. pneumoniae*, *A. baumannii*, *P. aeruginosa*, *E. cloacae*) plus an "Other" class for non-ESKAPE bacteria. Includes centroid-based out-of-distribution detection.
41
+
42
+ ## How it works
43
+
44
+ Computes 4-mer frequency profiles (256 features) from genome assemblies and classifies with a Random Forest. A bundled pre-trained model is included -- no training data or setup required.
45
+
46
+ ## Requirements
47
+
48
+ Python >= 3.12, biopython, scikit-learn, joblib, numpy.
49
+
50
+ ## License
51
+
52
+ MIT
@@ -0,0 +1,41 @@
1
+ # BaClasT -- Bacterial Classification Tool
2
+
3
+ Fast classification of assembled bacterial genomes into ESKAPE pathogen species using k-mer frequency profiling.
4
+
5
+ ## Install
6
+
7
+ ```bash
8
+ uv add baclast
9
+ ```
10
+
11
+ ## CLI
12
+
13
+ ```bash
14
+ baclast --predict genome.fna
15
+ baclast --predict genomes/ -o results.csv
16
+ ```
17
+
18
+ ## Python
19
+
20
+ ```python
21
+ import baclast
22
+
23
+ result = baclast.predict(file="genome.fna")
24
+ baclast.to_csv(result, "results.csv")
25
+ ```
26
+
27
+ ## What it classifies
28
+
29
+ ESKAPE pathogens (*E. faecium*, *S. aureus*, *K. pneumoniae*, *A. baumannii*, *P. aeruginosa*, *E. cloacae*) plus an "Other" class for non-ESKAPE bacteria. Includes centroid-based out-of-distribution detection.
30
+
31
+ ## How it works
32
+
33
+ Computes 4-mer frequency profiles (256 features) from genome assemblies and classifies with a Random Forest. A bundled pre-trained model is included -- no training data or setup required.
34
+
35
+ ## Requirements
36
+
37
+ Python >= 3.12, biopython, scikit-learn, joblib, numpy.
38
+
39
+ ## License
40
+
41
+ MIT
@@ -0,0 +1,92 @@
1
__version__ = "0.1.0"

import csv as _csv
from pathlib import Path as _Path

import numpy as _np

# Pre-trained model file that ships in the same directory as this module.
_BUNDLED_MODEL = _Path(__file__).resolve().with_name("model.pkl")

# Holds the loaded model payload; stays None until _load_model() first runs.
_model_cache = None

# Column order of the CSV written by to_csv().
_FIELDS = [
    "filepath",
    "filename",
    "organism_prediction",
    "confidence",
    "confidence_warning",
    "nearest_centroid",
    "distance",
    "threshold",
    "within_distribution",
    "baclast_version",
]
16
+
17
+
18
def _load_model():
    """Return the bundled model payload, loading it only on the first call.

    The payload comes from ``baclast.model.load_model(_BUNDLED_MODEL)`` and is
    kept in the module-level ``_model_cache``; later calls return that same
    object without reloading.
    """
    global _model_cache
    if _model_cache is not None:
        return _model_cache
    # Imported lazily so importing the package does not pull in the model code.
    from baclast.model import load_model
    _model_cache = load_model(_BUNDLED_MODEL)
    return _model_cache
24
+
25
+
26
def predict(file: str, *, min_confidence: float = 70.0) -> dict:
    """Classify a bacterial genome FASTA file with the bundled model.

    Args:
        file: Path to a FASTA file (.fasta, .fa, .fna).
        min_confidence: Confidence (in percent) below which
            ``confidence_warning`` is set to ``"LOW"``. Defaults to 70.0.

    Returns:
        Dict with keys: filepath, filename, organism_prediction, confidence,
        confidence_warning, nearest_centroid, distance, threshold,
        within_distribution, baclast_version. ``confidence`` is a percentage
        rounded to 2 decimals. The four out-of-distribution fields
        (nearest_centroid, distance, threshold, within_distribution) are empty
        strings when the model payload has no centroids or no distance
        threshold.
    """
    from baclast.features import genome_to_vector
    from baclast.model import novelty_score

    payload = _load_model()
    clf = payload["classifier"]
    label_names = payload["label_names"]
    k = payload["k"]
    kmer_vocab = payload["kmer_vocab"]
    centroids = payload.get("centroids")
    threshold = payload.get("distance_threshold")

    fpath = _Path(file)
    vec = genome_to_vector(fpath, k, kmer_vocab)
    proba = clf.predict_proba(_np.array([vec]))[0]
    # Take the winner straight from the probabilities. For a scikit-learn
    # classifier this is exactly what predict() returns (classes_[argmax]),
    # it avoids evaluating the forest a second time, and it reads the
    # confidence from the correct column even if classes_ is not 0..n-1.
    best = int(_np.argmax(proba))
    pred = clf.classes_[best]
    species = label_names[pred]
    confidence = round(float(proba[best]) * 100, 2)

    result = {
        "filepath": str(fpath),
        "filename": fpath.name,
        "organism_prediction": species,
        "confidence": confidence,
        "confidence_warning": "LOW" if confidence < min_confidence else "",
        "baclast_version": __version__,
    }

    # Explicit None/length checks: truthiness of a numpy centroid array is
    # ambiguous (raises), and a legitimate threshold of 0.0 must not disable
    # the out-of-distribution check.
    if centroids is not None and len(centroids) > 0 and threshold is not None:
        nearest, dist = novelty_score(vec, centroids)
        result["nearest_centroid"] = nearest
        result["distance"] = round(float(dist), 6)
        result["threshold"] = round(float(threshold), 6)
        result["within_distribution"] = "Yes" if dist <= threshold else "No"
    else:
        result["nearest_centroid"] = ""
        result["distance"] = ""
        result["threshold"] = ""
        result["within_distribution"] = ""

    return result
78
+
79
+
80
+ def to_csv(result: dict, path: str) -> None:
81
+ """Write a prediction result dict to a CSV file.
82
+
83
+ Raises:
84
+ FileExistsError: If the file already exists.
85
+ """
86
+ p = _Path(path)
87
+ if p.exists():
88
+ raise FileExistsError(f"Output file already exists: {p}")
89
+ with open(p, "w", newline="") as f:
90
+ writer = _csv.DictWriter(f, fieldnames=_FIELDS)
91
+ writer.writeheader()
92
+ writer.writerow(result)
@@ -0,0 +1,184 @@
1
+ """BaClasT CLI — bacterial genome classification tool."""
2
+
3
+ import argparse
4
+ import csv
5
+ import io
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+
11
+ from baclast import __version__
12
+ from baclast.features import genome_to_vector
13
+ from baclast.model import load_model, novelty_score
14
+
15
# Model file bundled in the package directory, used when --model is omitted.
_BUNDLED_MODEL = Path(__file__).resolve().with_name("model.pkl")

# File suffixes accepted as FASTA input.
_FASTA_EXTENSIONS = {".fna", ".fa", ".fasta"}

# Confidence (percent) under which a prediction is flagged "LOW".
_MIN_CONFIDENCE = 70.0

# Column order of the CSV report (stdout or --output).
_CSV_FIELDS = [
    "filepath", "filename",
    "organism_prediction", "confidence", "confidence_warning",
    "nearest_centroid", "distance", "threshold", "within_distribution",
    "baclast_version",
]
30
+
31
+
32
+ def _find_model(user_path: str | None) -> Path:
33
+ """Resolve model path: user-provided, or bundled default."""
34
+ if user_path:
35
+ p = Path(user_path)
36
+ if not p.exists():
37
+ sys.exit(f"Error: Model file not found: {p}")
38
+ return p
39
+ if _BUNDLED_MODEL.exists():
40
+ return _BUNDLED_MODEL
41
+ sys.exit(f"Error: No model found. Provide --model or install a model to {_BUNDLED_MODEL}")
42
+
43
+
44
def _collect_fastas(target: Path) -> list[Path]:
    """Return the FASTA files to classify for a file or directory path.

    A single file must carry one of the known FASTA suffixes. A directory is
    scanned (non-recursively) for regular files with a FASTA suffix, which are
    returned sorted. Suffix matching ignores case, so ``genome.FNA`` counts.
    The process exits with an error message when nothing usable is found.
    """
    def is_fasta(p: Path) -> bool:
        # _FASTA_EXTENSIONS holds lowercase suffixes; compare case-insensitively.
        return p.suffix.lower() in _FASTA_EXTENSIONS

    if target.is_file():
        if not is_fasta(target):
            sys.exit(f"Error: {target} does not look like a FASTA file. "
                     f"Expected extensions: {', '.join(sorted(_FASTA_EXTENSIONS))}")
        return [target]
    if target.is_dir():
        # is_file() skips sub-directories that merely end in ".fa" etc.
        fastas = sorted(f for f in target.iterdir() if f.is_file() and is_fasta(f))
        if not fastas:
            sys.exit(f"Error: No FASTA files found in {target}")
        return fastas
    sys.exit(f"Error: Path not found: {target}")
57
+
58
+
59
def _classify_one(fpath: Path, clf, label_names, k, kmer_vocab, centroids, threshold) -> dict:
    """Classify a single FASTA and return one CSV result row.

    Args:
        fpath: FASTA file to classify.
        clf: Trained classifier from the model payload (must expose
            ``predict_proba`` and ``classes_``).
        label_names: Species names indexed by class label.
        k, kmer_vocab: k-mer length and vocabulary used for feature extraction.
        centroids, threshold: Optional out-of-distribution reference data; when
            either is missing the OOD columns are left empty.

    Returns:
        Dict keyed by the ``_CSV_FIELDS`` column names. ``confidence`` is a
        percentage rounded to 2 decimals; ``confidence_warning`` is ``"LOW"``
        below ``_MIN_CONFIDENCE``.
    """
    vec = genome_to_vector(fpath, k, kmer_vocab)
    proba = clf.predict_proba(np.array([vec]))[0]
    # argmax over the probabilities is what a scikit-learn predict() returns
    # (classes_[argmax]); doing it here evaluates the forest once and reads the
    # confidence from the right column even if classes_ is not 0..n-1.
    best = int(np.argmax(proba))
    species = label_names[clf.classes_[best]]
    confidence = round(float(proba[best]) * 100, 2)

    row = {
        "filepath": str(fpath),
        "filename": fpath.name,
        "organism_prediction": species,
        "confidence": confidence,
        "confidence_warning": "LOW" if confidence < _MIN_CONFIDENCE else "",
        "baclast_version": __version__,
    }

    # Explicit checks: a numpy centroid array has no unambiguous truth value,
    # and a threshold of 0.0 is still a threshold.
    if centroids is not None and len(centroids) > 0 and threshold is not None:
        nearest, dist = novelty_score(vec, centroids)
        row["nearest_centroid"] = nearest
        row["distance"] = round(float(dist), 6)
        row["threshold"] = round(float(threshold), 6)
        row["within_distribution"] = "Yes" if dist <= threshold else "No"
    else:
        row.update(nearest_centroid="", distance="", threshold="", within_distribution="")

    return row
95
+
96
+
97
def _skipped_row(fpath: Path, exc: Exception) -> dict:
    """Build the CSV row recorded for a genome that could not be classified."""
    row = dict.fromkeys(_CSV_FIELDS, "")
    row.update(
        filepath=str(fpath),
        filename=fpath.name,
        organism_prediction="SKIPPED",
        confidence_warning=str(exc),
        baclast_version=__version__,
    )
    return row


def _write_rows(stream, rows: list[dict]) -> None:
    """Write the CSV header followed by ``rows`` to a text stream."""
    writer = csv.DictWriter(stream, fieldnames=_CSV_FIELDS)
    writer.writeheader()
    writer.writerows(rows)


def main():
    """Run the ``baclast`` command: classify FASTA input and emit a CSV report.

    The report goes to stdout, or to ``--output``, which must not already exist.
    For a directory of genomes, progress is printed to stderr and a genome that
    fails to classify is recorded as a SKIPPED row. For a single file, a failure
    aborts with an error message instead.
    """
    parser = argparse.ArgumentParser(
        prog="baclast",
        description="BaClasT — fast bacterial genome classification using k-mer profiles",
    )
    parser.add_argument(
        "--predict", required=True, metavar="PATH",
        help="Path to a FASTA file or directory of FASTAs",
    )
    parser.add_argument(
        "-o", "--output", default=None, metavar="FILE",
        help="Write results to a CSV file instead of stdout",
    )
    parser.add_argument(
        "--model", default=None, metavar="FILE",
        help="Path to model .pkl (uses bundled model if omitted)",
    )

    args = parser.parse_args()

    out_path = Path(args.output) if args.output else None
    # Fail fast, before loading the model and classifying anything.
    if out_path is not None and out_path.exists():
        sys.exit(f"Error: Output file already exists: {out_path}")

    # Load model
    payload = load_model(_find_model(args.model))
    clf = payload["classifier"]
    label_names = payload["label_names"]
    k = payload["k"]
    kmer_vocab = payload["kmer_vocab"]
    centroids = payload.get("centroids")
    threshold = payload.get("distance_threshold")

    # Collect input FASTAs
    fastas = _collect_fastas(Path(args.predict))
    total = len(fastas)
    batch = total > 1

    # Classify
    rows = []
    for i, fpath in enumerate(fastas, 1):
        if batch:
            print(f" [{i}/{total}] {fpath.name} ... ", end="", flush=True, file=sys.stderr)
        try:
            row = _classify_one(fpath, clf, label_names, k, kmer_vocab, centroids, threshold)
        except Exception as exc:
            # Deliberately broad: any failure on one genome (bad FASTA, feature
            # extraction error, ...) is reported without aborting a batch.
            if not batch:
                sys.exit(f"Error: {exc}")
            rows.append(_skipped_row(fpath, exc))
            print(f"SKIPPED: {exc}", file=sys.stderr)
            continue
        rows.append(row)
        if batch:
            status = f"{row['organism_prediction']} ({row['confidence']}%)"
            if row["confidence_warning"]:
                status += f" [{row['confidence_warning']} CONFIDENCE]"
            print(status, file=sys.stderr)

    # Output
    if out_path is not None:
        try:
            # Exclusive-create: never overwrite a file that appeared while the
            # (possibly long) classification run was in progress.
            fh = open(out_path, "x", newline="")
        except FileExistsError:
            sys.exit(f"Error: Output file already exists: {out_path}")
        with fh:
            _write_rows(fh, rows)
        print(f"Results written to {out_path} ({len(rows)} genomes)", file=sys.stderr)
    else:
        buf = io.StringIO()
        _write_rows(buf, rows)
        print(buf.getvalue(), end="")


if __name__ == "__main__":
    main()
@@ -0,0 +1,172 @@
1
+ """BaClasT — main CLI entry point for training and prediction."""
2
+
3
+ import argparse
4
+ import sys
5
+
6
+ from baclast.features import all_kmers, genome_to_vector, load_dataset
7
+ from baclast.model import evaluate, load_model, save_model, train_classifier
8
+ from baclast.utils import print_banner, setup_logging
9
+
10
+
11
def cmd_train(args):
    """Execute the 'train' sub-command.

    Loads per-species genomes from ``args.data_dir`` as k-mer vectors. It holds
    out a stratified 20% test split (``random_state=42``) and trains a Random
    Forest with ``args.n_estimators`` trees on the rest. It then prints held-out
    metrics and, when ``--cv`` is given, stratified k-fold accuracy over all
    genomes. The model is saved to ``args.output`` with the total genome count
    stored in the payload as ``n_genomes``. Exits with an error message on
    unusable input.
    """
    # The logger returned by setup_logging is not used here; the call is kept
    # for whatever logging configuration it performs.
    setup_logging(args.verbose)
    print_banner()

    # Validate up front: StratifiedKFold rejects fewer than 2 folds, but only
    # after loading every genome and training the model.
    if args.cv is not None and args.cv < 2:
        sys.exit("Error: --cv must be at least 2.")

    k = args.k
    kmer_vocab = all_kmers(k)

    print(f"Loading genomes from {args.data_dir} (k={k})...")
    try:
        X, y, label_names = load_dataset(args.data_dir, k, kmer_vocab)
    except Exception as exc:
        sys.exit(f"Error: {exc}")

    if len(label_names) < 2:
        sys.exit("Error: Need at least 2 species to train a classifier.")

    n_genomes = len(y)
    print(f"Loaded {n_genomes} genomes across {len(label_names)} species.")

    from sklearn.model_selection import train_test_split

    # 80/20 stratified train/test split
    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, stratify=y, random_state=42
        )
    except ValueError as exc:
        # e.g. a species with a single genome cannot be stratified; report it
        # instead of dumping a traceback.
        sys.exit(f"Error: Cannot split genomes into train/test sets: {exc}")

    print(f"Training Random Forest ({args.n_estimators} trees)...")
    clf = train_classifier(X_train, y_train, n_estimators=args.n_estimators)

    print("Evaluating on held-out test set:")
    evaluate(clf, X_test, y_test, label_names)

    # Optional cross-validation (cross_val_score refits clones of clf, so the
    # trained model saved below is not modified).
    if args.cv is not None:
        from sklearn.model_selection import StratifiedKFold, cross_val_score

        print(f"\nRunning {args.cv}-fold stratified cross-validation...")
        cv = StratifiedKFold(n_splits=args.cv, shuffle=True, random_state=42)
        scores = cross_val_score(clf, X, y, cv=cv, scoring="accuracy")
        print(f"CV accuracy: {scores.mean():.4f} +/- {scores.std():.4f}")

    # Save model
    save_model(clf, label_names, k, kmer_vocab, args.output)

    # Patch in n_genomes (save_model doesn't know the total count)
    import joblib

    payload = joblib.load(args.output)
    payload["n_genomes"] = n_genomes
    joblib.dump(payload, args.output)

    print(f"\nModel saved to {args.output}")
64
+
65
+
66
def cmd_predict(args):
    """Execute the 'predict' sub-command.

    Loads the model at ``args.model``, extracts k-mer features from
    ``args.fasta`` and prints the predicted species with its confidence.
    With ``--verbose``, every species' probability is also printed as a text
    bar chart, highest first. Exits with an error message when the model or
    FASTA file is missing or rejected.
    """
    # The logger returned by setup_logging is not used here; the call is kept
    # for whatever logging configuration it performs.
    setup_logging(args.verbose)
    print_banner()

    # Load model
    try:
        payload = load_model(args.model)
    except FileNotFoundError:
        sys.exit(f"Error: Model file not found: {args.model}")
    except ValueError as exc:
        sys.exit(f"Error: {exc}")

    clf = payload["classifier"]
    label_names = payload["label_names"]
    k = payload["k"]
    kmer_vocab = payload["kmer_vocab"]

    # Extract features from input FASTA
    try:
        vec = genome_to_vector(args.fasta, k, kmer_vocab)
    except FileNotFoundError:
        sys.exit(f"Error: FASTA file not found: {args.fasta}")
    except ValueError as exc:
        sys.exit(f"Error: {exc}")

    import numpy as np

    proba = clf.predict_proba(np.array([vec]))[0]
    # Column i of predict_proba belongs to clf.classes_[i]; map columns to
    # species names through classes_ instead of assuming classes_ == 0..n-1.
    # argmax over proba is what a scikit-learn predict() returns, so the
    # forest is evaluated only once.
    column_names = [label_names[c] for c in clf.classes_]
    best = int(np.argmax(proba))
    species = column_names[best]
    confidence = proba[best] * 100

    print(f"Predicted species: {species}")
    print(f"Confidence: {confidence:.1f}%")

    if args.verbose:
        print("\nAll species probabilities:")
        # Sort by probability descending
        ranked = sorted(zip(column_names, proba), key=lambda x: x[1], reverse=True)
        max_name_len = max(len(name) for name in label_names)
        for name, prob in ranked:
            bar = "#" * int(prob * 40)
            print(f" {name:<{max_name_len}} {prob * 100:5.1f}% |{bar}")
115
+
116
+
117
def main():
    """Main entry point — parse arguments and dispatch to sub-commands.

    ``train`` builds a model from a directory of species sub-folders;
    ``predict`` classifies one FASTA file with a saved model.
    """
    parser = argparse.ArgumentParser(
        prog="baclast",  # tool name shown in usage/help and error messages
        description="BaClasT — Bacterial Classification Tool",
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    # train sub-command
    train_parser = subparsers.add_parser("train", help="Train a classifier")
    train_parser.add_argument(
        "--data_dir", required=True, help="Directory with species sub-folders"
    )
    train_parser.add_argument(
        "--output", default="model.pkl", help="Output model path (default: model.pkl)"
    )
    train_parser.add_argument(
        "--k", type=int, default=4, help="K-mer length (default: 4)"
    )
    train_parser.add_argument(
        "--n_estimators",
        type=int,
        default=200,
        help="Number of trees (default: 200)",
    )
    train_parser.add_argument(
        "--cv", type=int, default=None, help="Number of CV folds (optional)"
    )
    train_parser.add_argument(
        "--verbose", action="store_true", help="Enable verbose output"
    )

    # predict sub-command
    predict_parser = subparsers.add_parser(
        "predict", help="Predict species for a FASTA file"
    )
    predict_parser.add_argument(
        "--model", required=True, help="Path to trained model .pkl"
    )
    predict_parser.add_argument(
        "--fasta", required=True, help="Path to input FASTA file"
    )
    predict_parser.add_argument(
        "--verbose", action="store_true", help="Show all species probabilities"
    )

    args = parser.parse_args()

    if args.command == "train":
        cmd_train(args)
    elif args.command == "predict":
        cmd_predict(args)


if __name__ == "__main__":
    main()