XspecT 0.5.3__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of XspecT might be problematic. Click here for more details.
- xspect/classify.py +38 -8
- xspect/definitions.py +30 -10
- xspect/file_io.py +2 -1
- xspect/filter_sequences.py +20 -4
- xspect/main.py +126 -28
- xspect/misclassification_detection/__init__.py +0 -0
- xspect/misclassification_detection/mapping.py +168 -0
- xspect/misclassification_detection/point_pattern_analysis.py +102 -0
- xspect/misclassification_detection/simulate_reads.py +55 -0
- xspect/mlst_feature/mlst_helper.py +15 -19
- xspect/mlst_feature/pub_mlst_handler.py +16 -19
- xspect/model_management.py +14 -17
- xspect/models/probabilistic_filter_mlst_model.py +11 -10
- xspect/models/probabilistic_filter_model.py +142 -8
- xspect/models/probabilistic_filter_svm_model.py +29 -14
- xspect/models/probabilistic_single_filter_model.py +9 -7
- xspect/models/result.py +22 -15
- xspect/ncbi.py +82 -7
- xspect/train.py +21 -4
- xspect/web.py +13 -4
- {xspect-0.5.3.dist-info → xspect-0.6.0.dist-info}/METADATA +4 -1
- {xspect-0.5.3.dist-info → xspect-0.6.0.dist-info}/RECORD +26 -22
- {xspect-0.5.3.dist-info → xspect-0.6.0.dist-info}/WHEEL +0 -0
- {xspect-0.5.3.dist-info → xspect-0.6.0.dist-info}/entry_points.txt +0 -0
- {xspect-0.5.3.dist-info → xspect-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {xspect-0.5.3.dist-info → xspect-0.6.0.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Probabilistic filter model for sequence data"""
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
+
import shutil
|
|
4
5
|
from math import ceil
|
|
5
6
|
from pathlib import Path
|
|
6
7
|
from typing import Any
|
|
@@ -9,9 +10,19 @@ from Bio.SeqRecord import SeqRecord
|
|
|
9
10
|
from Bio import SeqIO
|
|
10
11
|
from slugify import slugify
|
|
11
12
|
import cobs_index as cobs
|
|
12
|
-
from xspect.definitions import
|
|
13
|
+
from xspect.definitions import (
|
|
14
|
+
fasta_endings,
|
|
15
|
+
fastq_endings,
|
|
16
|
+
get_xspect_misclassification_path,
|
|
17
|
+
)
|
|
13
18
|
from xspect.file_io import get_record_iterator
|
|
19
|
+
from xspect.misclassification_detection.mapping import MappingHandler
|
|
14
20
|
from xspect.models.result import ModelResult
|
|
21
|
+
from collections import defaultdict
|
|
22
|
+
from xspect.ncbi import NCBIHandler
|
|
23
|
+
from xspect.misclassification_detection.point_pattern_analysis import (
|
|
24
|
+
PointPatternAnalysis,
|
|
25
|
+
)
|
|
15
26
|
|
|
16
27
|
|
|
17
28
|
class ProbabilisticFilterModel:
|
|
@@ -135,8 +146,8 @@ class ProbabilisticFilterModel:
|
|
|
135
146
|
display_names (dict | None): A dictionary mapping file names to display names.
|
|
136
147
|
If None, uses file names as display names.
|
|
137
148
|
training_accessions (dict[str, list[str]] | None): A dictionary mapping filter IDs to
|
|
138
|
-
lists of accession numbers used for training the model. If None, no training
|
|
139
|
-
are set.
|
|
149
|
+
lists of accession numbers used for training the model. If None, no training
|
|
150
|
+
accessions are set.
|
|
140
151
|
Raises:
|
|
141
152
|
ValueError: If the directory path is invalid, does not exist, or is not a directory.
|
|
142
153
|
"""
|
|
@@ -230,6 +241,8 @@ class ProbabilisticFilterModel:
|
|
|
230
241
|
),
|
|
231
242
|
filter_ids: list[str] = None,
|
|
232
243
|
step: int = 1,
|
|
244
|
+
display_name: bool = False,
|
|
245
|
+
validation: bool = False,
|
|
233
246
|
) -> ModelResult:
|
|
234
247
|
"""
|
|
235
248
|
Returns a model result object for the sequence(s) based on the filters in the model
|
|
@@ -246,6 +259,8 @@ class ProbabilisticFilterModel:
|
|
|
246
259
|
filter_ids (list[str]): A list of filter IDs to filter the results. If None,
|
|
247
260
|
all results are returned.
|
|
248
261
|
step (int): The step size for the k-mer search. Default is 1.
|
|
262
|
+
display_name (bool): Includes a display name for each tax_ID.
|
|
263
|
+
validation (bool): Sorts out misclassified reads.
|
|
249
264
|
|
|
250
265
|
Returns:
|
|
251
266
|
ModelResult: An object containing the hits for each sequence, the number of kmers,
|
|
@@ -253,11 +268,12 @@ class ProbabilisticFilterModel:
|
|
|
253
268
|
|
|
254
269
|
Raises:
|
|
255
270
|
ValueError: If the input sequence is not valid, or if it is not a Seq object,
|
|
256
|
-
a list of Seq objects, a SeqIO iterator, or a Path object to a fasta/fastq
|
|
271
|
+
a list of Seq objects, a SeqIO iterator, or a Path object to a fasta/fastq
|
|
272
|
+
file.
|
|
257
273
|
"""
|
|
258
274
|
if isinstance(sequence_input, (SeqRecord)):
|
|
259
275
|
return ProbabilisticFilterModel.predict(
|
|
260
|
-
self, [sequence_input], filter_ids, step
|
|
276
|
+
self, [sequence_input], filter_ids, step, display_name, validation
|
|
261
277
|
)
|
|
262
278
|
|
|
263
279
|
if self._is_sequence_list(sequence_input) | self._is_sequence_iterator(
|
|
@@ -265,19 +281,42 @@ class ProbabilisticFilterModel:
|
|
|
265
281
|
):
|
|
266
282
|
hits = {}
|
|
267
283
|
num_kmers = {}
|
|
284
|
+
if validation and self._is_sequence_iterator(sequence_input):
|
|
285
|
+
sequence_input = list(sequence_input)
|
|
286
|
+
|
|
268
287
|
for individual_sequence in sequence_input:
|
|
269
288
|
individual_hits = self.calculate_hits(
|
|
270
|
-
individual_sequence.seq, filter_ids, step
|
|
289
|
+
individual_sequence.seq, filter_ids, step
|
|
271
290
|
)
|
|
272
291
|
num_kmers[individual_sequence.id] = self._count_kmers(
|
|
273
|
-
individual_sequence, step
|
|
292
|
+
individual_sequence, step
|
|
274
293
|
)
|
|
294
|
+
|
|
295
|
+
if display_name:
|
|
296
|
+
individual_hits.update(
|
|
297
|
+
{
|
|
298
|
+
f"{key} -{self.display_names.get(key, 'Unknown').replace(
|
|
299
|
+
self.model_display_name, '', 1)}": individual_hits.pop(
|
|
300
|
+
key
|
|
301
|
+
)
|
|
302
|
+
for key in list(individual_hits.keys())
|
|
303
|
+
}
|
|
304
|
+
)
|
|
305
|
+
|
|
275
306
|
hits[individual_sequence.id] = individual_hits
|
|
307
|
+
|
|
308
|
+
if validation:
|
|
309
|
+
hits = self.detecting_misclassification(hits, sequence_input)
|
|
310
|
+
|
|
276
311
|
return ModelResult(self.slug(), hits, num_kmers, sparse_sampling_step=step)
|
|
277
312
|
|
|
278
313
|
if isinstance(sequence_input, Path):
|
|
279
314
|
return ProbabilisticFilterModel.predict(
|
|
280
|
-
self,
|
|
315
|
+
self,
|
|
316
|
+
get_record_iterator(sequence_input),
|
|
317
|
+
step=step,
|
|
318
|
+
display_name=display_name,
|
|
319
|
+
validation=validation,
|
|
281
320
|
)
|
|
282
321
|
|
|
283
322
|
raise ValueError(
|
|
@@ -460,3 +499,98 @@ class ProbabilisticFilterModel:
|
|
|
460
499
|
sequence_input,
|
|
461
500
|
(SeqIO.FastaIO.FastaIterator, SeqIO.QualityIO.FastqPhredIterator),
|
|
462
501
|
)
|
|
502
|
+
|
|
503
|
+
def detecting_misclassification(
|
|
504
|
+
self,
|
|
505
|
+
hits: dict[str, dict[str, int]],
|
|
506
|
+
seq_records: list[SeqRecord],
|
|
507
|
+
min_reads: int = 10,
|
|
508
|
+
) -> dict[str, dict[str, int]]:
|
|
509
|
+
"""
|
|
510
|
+
Notes:
|
|
511
|
+
Developed by Oemer Cetin as part of a Bsc thesis at Goethe University Frankfurt am Main (2025).
|
|
512
|
+
(An Integration of Alignment-Free and Alignment-Based Approaches for Bacterial Taxon Assignment)
|
|
513
|
+
|
|
514
|
+
Detects misclassification for short sequences.
|
|
515
|
+
|
|
516
|
+
This function is an alignment-based procedure that groups species by highest XspecT scores.
|
|
517
|
+
Each species group is mapped against the respective reference genome.
|
|
518
|
+
Start coordinates are extracted and scanned for local clustering.
|
|
519
|
+
When local clustering is detected, all sequences belonging to the species are sorted out.
|
|
520
|
+
|
|
521
|
+
Args:
|
|
522
|
+
hits (dict): The species annotations from the prediction step.
|
|
523
|
+
seq_records (list): The provided sequences.
|
|
524
|
+
min_reads (int): Minimum amount of reads, that species groups should have.
|
|
525
|
+
|
|
526
|
+
Returns:
|
|
527
|
+
dict: hits where misclassified sequences have been sorted out.
|
|
528
|
+
"""
|
|
529
|
+
rec_by_id = {record.id: record for record in seq_records}
|
|
530
|
+
grouped: dict[int, list[SeqRecord]] = defaultdict(list)
|
|
531
|
+
misclassified = {}
|
|
532
|
+
|
|
533
|
+
# group by species annotation
|
|
534
|
+
for record, score_dict in hits.items():
|
|
535
|
+
if record == "misclassified":
|
|
536
|
+
continue
|
|
537
|
+
sorted_hits = sorted(
|
|
538
|
+
score_dict.items(), key=lambda entry: entry[1], reverse=True
|
|
539
|
+
)
|
|
540
|
+
if sorted_hits[0][1] > sorted_hits[1][1]: # unique highest score
|
|
541
|
+
highest_tax_id = int(sorted_hits[0][0]) # tax_id
|
|
542
|
+
if record in rec_by_id:
|
|
543
|
+
# groups all reads with the highest score by tax_id
|
|
544
|
+
grouped[highest_tax_id].append(rec_by_id[record])
|
|
545
|
+
filtered_grouped = {
|
|
546
|
+
tax_id: seq for tax_id, seq in grouped.items() if len(seq) > min_reads
|
|
547
|
+
}
|
|
548
|
+
largest_group = max(
|
|
549
|
+
filtered_grouped,
|
|
550
|
+
key=lambda tax_id: len(filtered_grouped[tax_id]),
|
|
551
|
+
default=None,
|
|
552
|
+
)
|
|
553
|
+
|
|
554
|
+
# mapping procedure
|
|
555
|
+
handler = NCBIHandler()
|
|
556
|
+
out_dir = get_xspect_misclassification_path()
|
|
557
|
+
out_dir.mkdir(parents=True, exist_ok=True)
|
|
558
|
+
for tax_id, reads in filtered_grouped.items():
|
|
559
|
+
if tax_id == largest_group:
|
|
560
|
+
continue
|
|
561
|
+
|
|
562
|
+
tax_dir = out_dir / str(tax_id)
|
|
563
|
+
tax_dir.mkdir(parents=True, exist_ok=True)
|
|
564
|
+
fasta_path = tax_dir / f"{tax_id}.fasta"
|
|
565
|
+
SeqIO.write(reads, fasta_path, "fasta")
|
|
566
|
+
reference_path = tax_dir / f"{tax_id}.fna"
|
|
567
|
+
# download reference once
|
|
568
|
+
if not (reference_path.exists() and reference_path.stat().st_size > 0):
|
|
569
|
+
handler.download_reference_genome(tax_id, tax_dir)
|
|
570
|
+
if not reference_path.exists():
|
|
571
|
+
shutil.rmtree(tax_dir)
|
|
572
|
+
continue
|
|
573
|
+
|
|
574
|
+
mapping_handler = MappingHandler(str(reference_path), str(fasta_path))
|
|
575
|
+
mapping_handler.map_reads_onto_reference()
|
|
576
|
+
mapping_handler.extract_starting_coordinates()
|
|
577
|
+
genome_length = mapping_handler.get_total_genome_length()
|
|
578
|
+
start_coordinates = mapping_handler.get_start_coordinates()
|
|
579
|
+
|
|
580
|
+
if len(start_coordinates) < min_reads:
|
|
581
|
+
continue
|
|
582
|
+
|
|
583
|
+
# cluster analysis
|
|
584
|
+
analysis = PointPatternAnalysis(start_coordinates, genome_length)
|
|
585
|
+
clustered = analysis.ripleys_k_edge_corrected()
|
|
586
|
+
|
|
587
|
+
if clustered[0]: # True or False
|
|
588
|
+
bucket = misclassified.setdefault(tax_id, {})
|
|
589
|
+
for read in reads:
|
|
590
|
+
data = hits.pop(read.id, None) # remove false reads from main hits
|
|
591
|
+
if data is not None:
|
|
592
|
+
bucket[read.id] = data
|
|
593
|
+
|
|
594
|
+
if misclassified:
|
|
595
|
+
hits["misclassified"] = misclassified
|
|
596
|
+
return hits
|
|
@@ -55,10 +55,14 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
|
|
|
55
55
|
base_path (Path): The base path where the model will be stored.
|
|
56
56
|
kernel (str): The kernel type for the SVM (e.g., 'linear', 'rbf').
|
|
57
57
|
c (float): Regularization parameter for the SVM.
|
|
58
|
-
fpr (float, optional): False positive rate for the probabilistic filter.
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
58
|
+
fpr (float, optional): False positive rate for the probabilistic filter.
|
|
59
|
+
Defaults to 0.01.
|
|
60
|
+
num_hashes (int, optional): Number of hashes for the probabilistic filter.
|
|
61
|
+
Defaults to 7.
|
|
62
|
+
training_accessions (dict[str, list[str]] | None, optional): Accessions used for
|
|
63
|
+
training the probabilistic filter. Defaults to None.
|
|
64
|
+
svm_accessions (dict[str, list[str]] | None, optional): Accessions used for
|
|
65
|
+
training the SVM. Defaults to None.
|
|
62
66
|
"""
|
|
63
67
|
super().__init__(
|
|
64
68
|
k=k,
|
|
@@ -112,17 +116,18 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
|
|
|
112
116
|
"""
|
|
113
117
|
Fit the SVM to the sequences and labels.
|
|
114
118
|
|
|
115
|
-
This method first trains the probabilistic filter model and then
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
+
This method first trains the probabilistic filter model and then calculates scores for
|
|
120
|
+
the SVM training. It expects the sequences to be in the specified directory and the SVM
|
|
121
|
+
training sequences to be in the specified SVM path. The scores are saved in a CSV file
|
|
122
|
+
for later use.
|
|
119
123
|
|
|
120
124
|
Args:
|
|
121
125
|
dir_path (Path): The directory containing the training sequences.
|
|
122
126
|
svm_path (Path): The directory containing the SVM training sequences.
|
|
123
127
|
display_names (dict[str, str] | None): A mapping of accession IDs to display names.
|
|
124
128
|
svm_step (int): Step size for sparse sampling in SVM training.
|
|
125
|
-
training_accessions (dict[str, list[str]] | None): Accessions used for training the
|
|
129
|
+
training_accessions (dict[str, list[str]] | None): Accessions used for training the
|
|
130
|
+
probabilistic filter.
|
|
126
131
|
svm_accessions (dict[str, list[str]] | None): Accessions used for training the SVM.
|
|
127
132
|
"""
|
|
128
133
|
|
|
@@ -178,6 +183,8 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
|
|
|
178
183
|
),
|
|
179
184
|
filter_ids: list[str] = None,
|
|
180
185
|
step: int = 1,
|
|
186
|
+
display_name: bool = False,
|
|
187
|
+
validation: bool = False,
|
|
181
188
|
) -> ModelResult:
|
|
182
189
|
"""
|
|
183
190
|
Predict the labels of the sequences.
|
|
@@ -187,19 +194,26 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
|
|
|
187
194
|
with the probabilistic filter model, and it will return a `ModelResult`.
|
|
188
195
|
|
|
189
196
|
Args:
|
|
190
|
-
sequence_input (SeqRecord | list[SeqRecord] | SeqIO.FastaIO.FastaIterator |
|
|
191
|
-
|
|
197
|
+
sequence_input (SeqRecord | list[SeqRecord] | SeqIO.FastaIO.FastaIterator |
|
|
198
|
+
SeqIO.QualityIO.FastqPhredIterator | Path): The input sequences to predict.
|
|
199
|
+
filter_ids (list[str], optional): A list of IDs to filter the predictions.
|
|
192
200
|
step (int, optional): Step size for sparse sampling. Defaults to 1.
|
|
201
|
+
display_name (bool): Includes a display name for each tax_ID.
|
|
202
|
+
validation (bool): Sorts out misclassified reads .
|
|
193
203
|
|
|
194
204
|
Returns:
|
|
195
|
-
ModelResult: The result of the prediction containing hits, number of kmers, and the
|
|
205
|
+
ModelResult: The result of the prediction containing hits, number of kmers, and the
|
|
206
|
+
predicted label.
|
|
196
207
|
"""
|
|
197
208
|
# get scores and format them for the SVM
|
|
198
|
-
res = super().predict(
|
|
209
|
+
res = super().predict(
|
|
210
|
+
sequence_input, filter_ids, step, display_name, validation
|
|
211
|
+
)
|
|
199
212
|
svm_scores = dict(sorted(res.get_scores()["total"].items()))
|
|
200
213
|
svm_scores = [list(svm_scores.values())]
|
|
201
214
|
|
|
202
215
|
svm = self._get_svm(filter_ids)
|
|
216
|
+
res.hits["misclassified"] = res.misclassified
|
|
203
217
|
return ModelResult(
|
|
204
218
|
self.slug(),
|
|
205
219
|
res.hits,
|
|
@@ -217,7 +231,8 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
|
|
|
217
231
|
training data to only include those keys.
|
|
218
232
|
|
|
219
233
|
Args:
|
|
220
|
-
id_keys (list[str] | None): A list of IDs to filter the training data.
|
|
234
|
+
id_keys (list[str] | None): A list of IDs to filter the training data.
|
|
235
|
+
If None, all data is used.
|
|
221
236
|
|
|
222
237
|
Returns:
|
|
223
238
|
SVC: The trained SVM model.
|
|
@@ -34,8 +34,8 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
|
|
|
34
34
|
) -> None:
|
|
35
35
|
"""Initialize probabilistic single filter model.
|
|
36
36
|
|
|
37
|
-
This model uses a Bloom filter to store k-mers from the training sequences. It is designed
|
|
38
|
-
be used with a single filter, which is suitable e.g. for genus-level classification.
|
|
37
|
+
This model uses a Bloom filter to store k-mers from the training sequences. It is designed
|
|
38
|
+
to be used with a single filter, which is suitable e.g. for genus-level classification.
|
|
39
39
|
|
|
40
40
|
Args:
|
|
41
41
|
k (int): Length of the k-mers to use for filtering
|
|
@@ -45,7 +45,7 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
|
|
|
45
45
|
model_type (str): Type of the model, e.g. "probabilistic_single_filter"
|
|
46
46
|
base_path (Path): Base path where the model will be saved
|
|
47
47
|
fpr (float): False positive rate for the Bloom filter, default is 0.01
|
|
48
|
-
training_accessions (list[str] | None): List of accessions used for training
|
|
48
|
+
training_accessions (list[str] | None): List of accessions used for training
|
|
49
49
|
"""
|
|
50
50
|
super().__init__(
|
|
51
51
|
k=k,
|
|
@@ -75,7 +75,7 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
|
|
|
75
75
|
Args:
|
|
76
76
|
file_path (Path): Path to the file containing sequences in FASTA format
|
|
77
77
|
display_name (str): Display name for the model
|
|
78
|
-
training_accessions (list[str] | None): List of accessions used for training
|
|
78
|
+
training_accessions (list[str] | None): List of accessions used for training
|
|
79
79
|
"""
|
|
80
80
|
self.training_accessions = training_accessions
|
|
81
81
|
|
|
@@ -104,7 +104,7 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
|
|
|
104
104
|
Calculates the number of k-mers in the sequence that are present in the Bloom filter.
|
|
105
105
|
|
|
106
106
|
Args:
|
|
107
|
-
sequence (Seq | SeqRecord): Sequence to calculate hits for
|
|
107
|
+
sequence (Seq | SeqRecord): Sequence to calculate hits for
|
|
108
108
|
filter_ids (list[str] | None): List of filter IDs to use, default is None
|
|
109
109
|
step (int): Step size for generating k-mers, default is 1
|
|
110
110
|
Returns:
|
|
@@ -162,13 +162,15 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
|
|
|
162
162
|
"""
|
|
163
163
|
Generate kmers from the sequence
|
|
164
164
|
|
|
165
|
-
Generates k-mers from the sequence, considering both the forward and reverse complement
|
|
165
|
+
Generates k-mers from the sequence, considering both the forward and reverse complement
|
|
166
|
+
strands.
|
|
166
167
|
|
|
167
168
|
Args:
|
|
168
169
|
sequence (Seq): Sequence to generate k-mers from
|
|
169
170
|
step (int): Step size for generating k-mers, default is 1
|
|
170
171
|
Yields:
|
|
171
|
-
str: The minimizer k-mer (the lexicographically smallest k-mer between the forward and
|
|
172
|
+
str: The minimizer k-mer (the lexicographically smallest k-mer between the forward and
|
|
173
|
+
reverse complement)
|
|
172
174
|
"""
|
|
173
175
|
num_kmers = ceil((len(sequence) - self.k + 1) / step)
|
|
174
176
|
for i in range(num_kmers):
|
xspect/models/result.py
CHANGED
|
@@ -40,6 +40,7 @@ class ModelResult:
|
|
|
40
40
|
self.sparse_sampling_step = sparse_sampling_step
|
|
41
41
|
self.prediction = prediction
|
|
42
42
|
self.input_source = input_source
|
|
43
|
+
self.misclassified = self.hits.pop("misclassified", None)
|
|
43
44
|
|
|
44
45
|
def get_scores(self) -> dict:
|
|
45
46
|
"""
|
|
@@ -50,7 +51,8 @@ class ModelResult:
|
|
|
50
51
|
|
|
51
52
|
Returns:
|
|
52
53
|
dict: A dictionary where keys are subsequence names and values are dictionaries
|
|
53
|
-
with labels as keys and scores as values. Also includes a 'total' key for
|
|
54
|
+
with labels as keys and scores as values. Also includes a 'total' key for
|
|
55
|
+
overall scores.
|
|
54
56
|
"""
|
|
55
57
|
scores = {
|
|
56
58
|
subsequence: {
|
|
@@ -78,7 +80,8 @@ class ModelResult:
|
|
|
78
80
|
The total hits are calculated by summing the hits for each label across all subsequences.
|
|
79
81
|
|
|
80
82
|
Returns:
|
|
81
|
-
dict: A dictionary where keys are labels and values are the total number of hits for
|
|
83
|
+
dict: A dictionary where keys are labels and values are the total number of hits for
|
|
84
|
+
that label.
|
|
82
85
|
"""
|
|
83
86
|
total_hits = {label: 0 for label in list(self.hits.values())[0]}
|
|
84
87
|
for _, subsequence_hits in self.hits.items():
|
|
@@ -97,8 +100,8 @@ class ModelResult:
|
|
|
97
100
|
|
|
98
101
|
Args:
|
|
99
102
|
label (str): The label for which to filter the subsequences.
|
|
100
|
-
filter_threshold (float): The threshold for filtering subsequences. Must be between 0
|
|
101
|
-
or -1 to return the subsequence with the maximum score for the label.
|
|
103
|
+
filter_threshold (float): The threshold for filtering subsequences. Must be between 0
|
|
104
|
+
and 1, or -1 to return the subsequence with the maximum score for the label.
|
|
102
105
|
|
|
103
106
|
Returns:
|
|
104
107
|
dict[str, bool]: A dictionary where keys are subsequence names and values are booleans
|
|
@@ -114,11 +117,10 @@ class ModelResult:
|
|
|
114
117
|
subsequence: score[label] >= filter_threshold
|
|
115
118
|
for subsequence, score in scores.items()
|
|
116
119
|
}
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
}
|
|
120
|
+
return {
|
|
121
|
+
subsequence: score[label] == max(score.values())
|
|
122
|
+
for subsequence, score in scores.items()
|
|
123
|
+
}
|
|
122
124
|
|
|
123
125
|
def get_filtered_subsequence_labels(
|
|
124
126
|
self, label: str, filter_threshold: float = 0.7
|
|
@@ -126,15 +128,17 @@ class ModelResult:
|
|
|
126
128
|
"""
|
|
127
129
|
Return the labels of filtered subsequences.
|
|
128
130
|
|
|
129
|
-
This method filters subsequences based on the scores for a given label and a filter
|
|
131
|
+
This method filters subsequences based on the scores for a given label and a filter
|
|
132
|
+
threshold.
|
|
130
133
|
|
|
131
134
|
Args:
|
|
132
135
|
label (str): The label for which to filter the subsequences.
|
|
133
|
-
filter_threshold (float): The threshold for filtering subsequences. Must be between 0
|
|
134
|
-
or -1 to return the subsequence with the maximum score for the label.
|
|
136
|
+
filter_threshold (float): The threshold for filtering subsequences. Must be between 0
|
|
137
|
+
and 1, or -1 to return the subsequence with the maximum score for the label.
|
|
135
138
|
|
|
136
139
|
Returns:
|
|
137
|
-
list[str]: A list of subsequence names that meet the filter criteria for the given
|
|
140
|
+
list[str]: A list of subsequence names that meet the filter criteria for the given
|
|
141
|
+
label.
|
|
138
142
|
"""
|
|
139
143
|
return [
|
|
140
144
|
subsequence
|
|
@@ -148,11 +152,13 @@ class ModelResult:
|
|
|
148
152
|
"""
|
|
149
153
|
Return the result as a dictionary.
|
|
150
154
|
|
|
151
|
-
This method converts the ModelResult object into a dictionary format suitable for
|
|
155
|
+
This method converts the ModelResult object into a dictionary format suitable for
|
|
156
|
+
serialization.
|
|
152
157
|
|
|
153
158
|
Returns:
|
|
154
159
|
dict: A dictionary representation of the ModelResult object, including model slug,
|
|
155
|
-
sparse sampling step, hits, scores, number of k-mers, input source, and prediction if
|
|
160
|
+
sparse sampling step, hits, scores, number of k-mers, input source, and prediction if
|
|
161
|
+
available.
|
|
156
162
|
"""
|
|
157
163
|
res = {
|
|
158
164
|
"model_slug": self.model_slug,
|
|
@@ -160,6 +166,7 @@ class ModelResult:
|
|
|
160
166
|
"hits": self.hits,
|
|
161
167
|
"scores": self.get_scores(),
|
|
162
168
|
"num_kmers": self.num_kmers,
|
|
169
|
+
"misclassified": self.misclassified,
|
|
163
170
|
"input_source": self.input_source,
|
|
164
171
|
}
|
|
165
172
|
|
xspect/ncbi.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
"""NCBI handler for the NCBI Datasets API."""
|
|
2
2
|
|
|
3
|
+
import shutil
|
|
3
4
|
from enum import Enum
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
import time
|
|
6
7
|
from loguru import logger
|
|
7
8
|
import requests
|
|
9
|
+
import zipfile
|
|
8
10
|
|
|
9
11
|
# pylint: disable=line-too-long
|
|
10
12
|
|
|
@@ -194,8 +196,9 @@ class NCBIHandler:
|
|
|
194
196
|
assembly_level: AssemblyLevel,
|
|
195
197
|
assembly_source: AssemblySource,
|
|
196
198
|
count: int,
|
|
197
|
-
min_n50: int
|
|
198
|
-
exclude_atypical: bool
|
|
199
|
+
min_n50: int,
|
|
200
|
+
exclude_atypical: bool,
|
|
201
|
+
allow_inconclusive: bool,
|
|
199
202
|
exclude_paired_reports: bool = True,
|
|
200
203
|
current_version_only: bool = True,
|
|
201
204
|
) -> list[str]:
|
|
@@ -211,11 +214,13 @@ class NCBIHandler:
|
|
|
211
214
|
assembly_level (AssemblyLevel): The assembly level to get the accessions for.
|
|
212
215
|
assembly_source (AssemblySource): The assembly source to get the accessions for.
|
|
213
216
|
count (int): The number of accessions to get.
|
|
214
|
-
min_n50 (int
|
|
215
|
-
exclude_atypical (bool
|
|
217
|
+
min_n50 (int): The minimum contig n50 to filter the accessions.
|
|
218
|
+
exclude_atypical (bool): Whether to exclude atypical accessions.
|
|
219
|
+
allow_inconclusive (bool): Whether to allow accessions with an inconclusive taxonomy check status.
|
|
216
220
|
exclude_paired_reports (bool, optional): Whether to exclude paired reports. Defaults to True.
|
|
217
221
|
current_version_only (bool, optional): Whether to get only the current version of the accessions. Defaults to True.
|
|
218
222
|
|
|
223
|
+
|
|
219
224
|
Returns:
|
|
220
225
|
list[str]: A list containing the accessions.
|
|
221
226
|
"""
|
|
@@ -240,8 +245,11 @@ class NCBIHandler:
|
|
|
240
245
|
report["accession"]
|
|
241
246
|
for report in response["reports"]
|
|
242
247
|
if report["assembly_stats"]["contig_n50"] >= min_n50
|
|
243
|
-
and
|
|
244
|
-
|
|
248
|
+
and (
|
|
249
|
+
allow_inconclusive
|
|
250
|
+
or report["average_nucleotide_identity"]["taxonomy_check_status"]
|
|
251
|
+
== "OK"
|
|
252
|
+
)
|
|
245
253
|
]
|
|
246
254
|
except (IndexError, KeyError, TypeError):
|
|
247
255
|
logger.debug(
|
|
@@ -251,7 +259,13 @@ class NCBIHandler:
|
|
|
251
259
|
return accessions[:count] # Limit to count
|
|
252
260
|
|
|
253
261
|
def get_highest_quality_accessions(
|
|
254
|
-
self,
|
|
262
|
+
self,
|
|
263
|
+
taxon_id: int,
|
|
264
|
+
assembly_source: AssemblySource,
|
|
265
|
+
count: int,
|
|
266
|
+
min_n50: int,
|
|
267
|
+
exclude_atypical: bool,
|
|
268
|
+
allow_inconclusive: bool,
|
|
255
269
|
) -> list[str]:
|
|
256
270
|
"""
|
|
257
271
|
Get the highest quality accessions for a given taxon id (based on the assembly level).
|
|
@@ -263,6 +277,9 @@ class NCBIHandler:
|
|
|
263
277
|
taxon_id (int): The taxon id to get the accessions for.
|
|
264
278
|
assembly_source (AssemblySource): The assembly source to get the accessions for.
|
|
265
279
|
count (int): The number of accessions to get.
|
|
280
|
+
min_n50 (int): The minimum contig n50 to filter the accessions.
|
|
281
|
+
exclude_atypical (bool): Whether to exclude atypical accessions.
|
|
282
|
+
allow_inconclusive (bool): Whether to allow accessions with an inconclusive taxonomy check status.
|
|
266
283
|
|
|
267
284
|
Returns:
|
|
268
285
|
list[str]: A list containing the highest quality accessions.
|
|
@@ -274,6 +291,9 @@ class NCBIHandler:
|
|
|
274
291
|
assembly_level,
|
|
275
292
|
assembly_source,
|
|
276
293
|
count,
|
|
294
|
+
min_n50=min_n50,
|
|
295
|
+
exclude_atypical=exclude_atypical,
|
|
296
|
+
allow_inconclusive=allow_inconclusive,
|
|
277
297
|
)
|
|
278
298
|
if len(set(accessions)) >= count:
|
|
279
299
|
break
|
|
@@ -302,3 +322,58 @@ class NCBIHandler:
|
|
|
302
322
|
with open(output_dir / "ncbi_dataset.zip", "wb") as f:
|
|
303
323
|
for chunk in response.iter_content(chunk_size=8192):
|
|
304
324
|
f.write(chunk)
|
|
325
|
+
|
|
326
|
+
def download_reference_genome(self, taxon_id: int, output_dir: Path) -> Path | None:
|
|
327
|
+
"""
|
|
328
|
+
Notes:
|
|
329
|
+
Developed by Oemer Cetin as part of a Bsc thesis at Goethe University Frankfurt am Main (2025).
|
|
330
|
+
(An Integration of Alignment-Free and Alignment-Based Approaches for Bacterial Taxon Assignment)
|
|
331
|
+
|
|
332
|
+
Downloads the reference genome from the RefSeq-DB for a given taxon ID.
|
|
333
|
+
|
|
334
|
+
This function queries the NCBI Datasets API for the reference genome and downloads it.
|
|
335
|
+
|
|
336
|
+
Args:
|
|
337
|
+
taxon_id (int): The taxonomy ID of the species.
|
|
338
|
+
output_dir (Path): Directory where the genome will be saved.
|
|
339
|
+
|
|
340
|
+
Returns:
|
|
341
|
+
Path: Path to the downloaded ZIP file.
|
|
342
|
+
"""
|
|
343
|
+
accessions = self.get_accessions(
|
|
344
|
+
taxon_id=taxon_id,
|
|
345
|
+
assembly_level=AssemblyLevel.REFERENCE,
|
|
346
|
+
assembly_source=AssemblySource.REFSEQ,
|
|
347
|
+
count=1, # only one reference exists
|
|
348
|
+
min_n50=0,
|
|
349
|
+
exclude_atypical=True,
|
|
350
|
+
allow_inconclusive=False,
|
|
351
|
+
)
|
|
352
|
+
|
|
353
|
+
if not accessions:
|
|
354
|
+
return None
|
|
355
|
+
|
|
356
|
+
logger.info(
|
|
357
|
+
f"Downloading reference genome for taxon {taxon_id}: {accessions[0]}"
|
|
358
|
+
)
|
|
359
|
+
self.download_assemblies(accessions, output_dir)
|
|
360
|
+
|
|
361
|
+
zip_path = output_dir / "ncbi_dataset.zip"
|
|
362
|
+
|
|
363
|
+
fna_file = ""
|
|
364
|
+
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
|
365
|
+
for file in zip_ref.namelist():
|
|
366
|
+
if file.endswith(".fna"):
|
|
367
|
+
extracted_path = zip_ref.extract(file, path=output_dir)
|
|
368
|
+
fna_file = output_dir / f"{taxon_id}.fna"
|
|
369
|
+
Path(extracted_path).rename(
|
|
370
|
+
fna_file
|
|
371
|
+
) # consistent file name (tax_id)
|
|
372
|
+
logger.info(f"Extracted reference genome to {fna_file}")
|
|
373
|
+
break
|
|
374
|
+
|
|
375
|
+
# clean up
|
|
376
|
+
zip_path.unlink()
|
|
377
|
+
shutil.rmtree(output_dir / "ncbi_dataset")
|
|
378
|
+
|
|
379
|
+
return fna_file
|
xspect/train.py
CHANGED
|
@@ -186,6 +186,11 @@ def train_from_ncbi(
|
|
|
186
186
|
author: str | None = None,
|
|
187
187
|
author_email: str | None = None,
|
|
188
188
|
ncbi_api_key: str | None = None,
|
|
189
|
+
min_n50: int = 10000,
|
|
190
|
+
exclude_atypical: bool = True,
|
|
191
|
+
allow_inconclusive: bool = False,
|
|
192
|
+
allow_candidatus: bool = False,
|
|
193
|
+
allow_sp: bool = False,
|
|
189
194
|
):
|
|
190
195
|
"""
|
|
191
196
|
Train a model using NCBI assembly data for a given genus.
|
|
@@ -200,6 +205,11 @@ def train_from_ncbi(
|
|
|
200
205
|
author (str, optional): Author of the model. Defaults to None.
|
|
201
206
|
author_email (str, optional): Author's email. Defaults to None.
|
|
202
207
|
ncbi_api_key (str, optional): NCBI API key for accessing NCBI resources. Defaults to None.
|
|
208
|
+
min_n50 (int, optional): Minimum N50 value for assemblies. Defaults to 10000.
|
|
209
|
+
exclude_atypical (bool, optional): Exclude atypical assemblies. Defaults to True.
|
|
210
|
+
allow_inconclusive (bool, optional): Allow use of accessions with inconclusive taxonomy check status. Defaults to False.
|
|
211
|
+
allow_candidatus (bool, optional): Allow use of Candidatus species for training. Defaults to False.
|
|
212
|
+
allow_sp (bool, optional): Allow use of species with "sp." in their names. Defaults to False.
|
|
203
213
|
|
|
204
214
|
Raises:
|
|
205
215
|
TypeError: If `genus` is not a string.
|
|
@@ -221,8 +231,8 @@ def train_from_ncbi(
|
|
|
221
231
|
filtered_species_ids = [
|
|
222
232
|
tax_id
|
|
223
233
|
for tax_id in species_ids
|
|
224
|
-
if "candidatus" not in species_names[tax_id].lower()
|
|
225
|
-
and " sp." not in species_names[tax_id].lower()
|
|
234
|
+
if (allow_candidatus or "candidatus" not in species_names[tax_id].lower())
|
|
235
|
+
and (allow_sp or " sp." not in species_names[tax_id].lower())
|
|
226
236
|
]
|
|
227
237
|
filtered_species_names = {
|
|
228
238
|
str(tax_id): species_names[tax_id] for tax_id in filtered_species_ids
|
|
@@ -231,7 +241,12 @@ def train_from_ncbi(
|
|
|
231
241
|
accessions = {}
|
|
232
242
|
for tax_id in filtered_species_ids:
|
|
233
243
|
taxon_accessions = ncbi_handler.get_highest_quality_accessions(
|
|
234
|
-
tax_id,
|
|
244
|
+
tax_id,
|
|
245
|
+
AssemblySource.REFSEQ,
|
|
246
|
+
8,
|
|
247
|
+
min_n50,
|
|
248
|
+
exclude_atypical,
|
|
249
|
+
allow_inconclusive,
|
|
235
250
|
)
|
|
236
251
|
if not taxon_accessions:
|
|
237
252
|
logger.warning(f"No assemblies found for tax_id {tax_id}. Skipping.")
|
|
@@ -241,7 +256,9 @@ def train_from_ncbi(
|
|
|
241
256
|
|
|
242
257
|
if not accessions:
|
|
243
258
|
raise ValueError(
|
|
244
|
-
"No species with accessions found.
|
|
259
|
+
"No species with accessions found. "
|
|
260
|
+
"Please check if the genus name is correct or if there are any data quality issues "
|
|
261
|
+
"(e. g. inconclusive taxonomy check status, atypical assemblies, low N50 values)."
|
|
245
262
|
)
|
|
246
263
|
|
|
247
264
|
with TemporaryDirectory() as tmp_dir:
|