XspecT 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of XspecT might be problematic. Click here for more details.
- xspect/classify.py +61 -13
- xspect/definitions.py +61 -13
- xspect/download_models.py +10 -2
- xspect/file_io.py +115 -48
- xspect/filter_sequences.py +81 -29
- xspect/main.py +90 -39
- xspect/mlst_feature/mlst_helper.py +3 -0
- xspect/mlst_feature/pub_mlst_handler.py +43 -1
- xspect/model_management.py +84 -14
- xspect/models/probabilistic_filter_mlst_model.py +75 -37
- xspect/models/probabilistic_filter_model.py +201 -19
- xspect/models/probabilistic_filter_svm_model.py +106 -13
- xspect/models/probabilistic_single_filter_model.py +73 -9
- xspect/models/result.py +77 -10
- xspect/ncbi.py +48 -12
- xspect/train.py +19 -11
- xspect/web.py +68 -12
- xspect/xspect-web/dist/assets/index-Ceo58xui.css +1 -0
- xspect/xspect-web/dist/assets/{index-CMG4V7fZ.js → index-Dt_UlbgE.js} +82 -77
- xspect/xspect-web/dist/index.html +2 -2
- xspect/xspect-web/src/App.tsx +4 -2
- xspect/xspect-web/src/api.tsx +23 -1
- xspect/xspect-web/src/components/filter-form.tsx +16 -3
- xspect/xspect-web/src/components/filtering-result.tsx +65 -0
- xspect/xspect-web/src/components/result.tsx +2 -2
- xspect/xspect-web/src/types.tsx +5 -0
- {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/METADATA +11 -5
- {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/RECORD +32 -31
- {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/WHEEL +1 -1
- xspect/xspect-web/dist/assets/index-jIKg1HIy.css +0 -1
- {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/entry_points.txt +0 -0
- {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/top_level.txt +0 -0
|
@@ -15,23 +15,51 @@ from xspect.models.result import ModelResult
|
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
|
|
18
|
-
"""
|
|
18
|
+
"""
|
|
19
|
+
Probabilistic filter SVM model for sequence data
|
|
20
|
+
|
|
21
|
+
In addition to the standard probabilistic filter model, this model uses an SVM to predict
|
|
22
|
+
labels based on their scores and training data. It requires the `scikit-learn` library
|
|
23
|
+
to be installed.
|
|
24
|
+
"""
|
|
19
25
|
|
|
20
26
|
def __init__(
|
|
21
27
|
self,
|
|
22
28
|
k: int,
|
|
23
29
|
model_display_name: str,
|
|
24
|
-
author: str,
|
|
25
|
-
author_email: str,
|
|
30
|
+
author: str | None,
|
|
31
|
+
author_email: str | None,
|
|
26
32
|
model_type: str,
|
|
27
33
|
base_path: Path,
|
|
28
34
|
kernel: str,
|
|
29
35
|
c: float,
|
|
30
36
|
fpr: float = 0.01,
|
|
31
37
|
num_hashes: int = 7,
|
|
32
|
-
training_accessions: dict[str, list[str]] = None,
|
|
33
|
-
svm_accessions: dict[str, list[str]] = None,
|
|
38
|
+
training_accessions: dict[str, list[str]] | None = None,
|
|
39
|
+
svm_accessions: dict[str, list[str]] | None = None,
|
|
34
40
|
) -> None:
|
|
41
|
+
"""
|
|
42
|
+
Initialize the SVM model with the given parameters.
|
|
43
|
+
|
|
44
|
+
In addition to the standard parameters, this model uses an SVM.
|
|
45
|
+
Therefore, it requires the `kernel` and `C` parameters to be set.
|
|
46
|
+
Furthermore, the `svm_accessions` parameter is used to store which accessions
|
|
47
|
+
are used for training the SVM.
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
k (int): The k-mer size for the probabilistic filter.
|
|
51
|
+
model_display_name (str): The display name of the model.
|
|
52
|
+
author (str | None): The author of the model.
|
|
53
|
+
author_email (str | None): The author's email address.
|
|
54
|
+
model_type (str): The type of the model.
|
|
55
|
+
base_path (Path): The base path where the model will be stored.
|
|
56
|
+
kernel (str): The kernel type for the SVM (e.g., 'linear', 'rbf').
|
|
57
|
+
c (float): Regularization parameter for the SVM.
|
|
58
|
+
fpr (float, optional): False positive rate for the probabilistic filter. Defaults to 0.01.
|
|
59
|
+
num_hashes (int, optional): Number of hashes for the probabilistic filter. Defaults to 7.
|
|
60
|
+
training_accessions (dict[str, list[str]] | None, optional): Accessions used for training the probabilistic filter. Defaults to None.
|
|
61
|
+
svm_accessions (dict[str, list[str]] | None, optional): Accessions used for training the SVM. Defaults to None.
|
|
62
|
+
"""
|
|
35
63
|
super().__init__(
|
|
36
64
|
k=k,
|
|
37
65
|
model_display_name=model_display_name,
|
|
@@ -48,6 +76,12 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
|
|
|
48
76
|
self.svm_accessions = svm_accessions
|
|
49
77
|
|
|
50
78
|
def to_dict(self) -> dict:
|
|
79
|
+
"""
|
|
80
|
+
Convert the model to a dictionary representation
|
|
81
|
+
|
|
82
|
+
Returns:
|
|
83
|
+
dict: A dictionary containing the model's parameters and state.
|
|
84
|
+
"""
|
|
51
85
|
return super().to_dict() | {
|
|
52
86
|
"kernel": self.kernel,
|
|
53
87
|
"C": self.c,
|
|
@@ -55,7 +89,13 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
|
|
|
55
89
|
}
|
|
56
90
|
|
|
57
91
|
def set_svm_params(self, kernel: str, c: float) -> None:
|
|
58
|
-
"""
|
|
92
|
+
"""
|
|
93
|
+
Set the parameters for the SVM
|
|
94
|
+
|
|
95
|
+
Args:
|
|
96
|
+
kernel (str): The kernel type for the SVM (e.g., 'linear', 'rbf').
|
|
97
|
+
c (float): Regularization parameter for the SVM.
|
|
98
|
+
"""
|
|
59
99
|
self.kernel = kernel
|
|
60
100
|
self.c = c
|
|
61
101
|
self.save()
|
|
@@ -64,12 +104,27 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
|
|
|
64
104
|
self,
|
|
65
105
|
dir_path: Path,
|
|
66
106
|
svm_path: Path,
|
|
67
|
-
display_names: dict = None,
|
|
107
|
+
display_names: dict[str, str] | None = None,
|
|
68
108
|
svm_step: int = 1,
|
|
69
|
-
training_accessions: list[str] = None,
|
|
70
|
-
svm_accessions: list[str] = None,
|
|
109
|
+
training_accessions: dict[str, list[str]] | None = None,
|
|
110
|
+
svm_accessions: dict[str, list[str]] | None = None,
|
|
71
111
|
) -> None:
|
|
72
|
-
"""
|
|
112
|
+
"""
|
|
113
|
+
Fit the SVM to the sequences and labels.
|
|
114
|
+
|
|
115
|
+
This method first trains the probabilistic filter model and then
|
|
116
|
+
calculates scores for the SVM training. It expects the sequences to be in
|
|
117
|
+
the specified directory and the SVM training sequences to be in the
|
|
118
|
+
specified SVM path. The scores are saved in a CSV file for later use.
|
|
119
|
+
|
|
120
|
+
Args:
|
|
121
|
+
dir_path (Path): The directory containing the training sequences.
|
|
122
|
+
svm_path (Path): The directory containing the SVM training sequences.
|
|
123
|
+
display_names (dict[str, str] | None): A mapping of accession IDs to display names.
|
|
124
|
+
svm_step (int): Step size for sparse sampling in SVM training.
|
|
125
|
+
training_accessions (dict[str, list[str]] | None): Accessions used for training the probabilistic filter.
|
|
126
|
+
svm_accessions (dict[str, list[str]] | None): Accessions used for training the SVM.
|
|
127
|
+
"""
|
|
73
128
|
|
|
74
129
|
# Since the SVM works with score data, we need to train
|
|
75
130
|
# the underlying data structure for score generation first
|
|
@@ -124,7 +179,21 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
|
|
|
124
179
|
filter_ids: list[str] = None,
|
|
125
180
|
step: int = 1,
|
|
126
181
|
) -> ModelResult:
|
|
127
|
-
"""
|
|
182
|
+
"""
|
|
183
|
+
Predict the labels of the sequences.
|
|
184
|
+
|
|
185
|
+
This method uses the SVM to predict labels based on the scores generated
|
|
186
|
+
from the sequences. It expects the sequences to be in a format compatible
|
|
187
|
+
with the probabilistic filter model, and it will return a `ModelResult`.
|
|
188
|
+
|
|
189
|
+
Args:
|
|
190
|
+
sequence_input (SeqRecord | list[SeqRecord] | SeqIO.FastaIO.FastaIterator | SeqIO.QualityIO.FastqPhredIterator | Path): The input sequences to predict.
|
|
191
|
+
filter_ids (list[str], optional): A list of IDs to filter the predictions. Defaults to None.
|
|
192
|
+
step (int, optional): Step size for sparse sampling. Defaults to 1.
|
|
193
|
+
|
|
194
|
+
Returns:
|
|
195
|
+
ModelResult: The result of the prediction containing hits, number of kmers, and the predicted label.
|
|
196
|
+
"""
|
|
128
197
|
# get scores and format them for the SVM
|
|
129
198
|
res = super().predict(sequence_input, filter_ids, step=step)
|
|
130
199
|
svm_scores = dict(sorted(res.get_scores()["total"].items()))
|
|
@@ -140,7 +209,19 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
|
|
|
140
209
|
)
|
|
141
210
|
|
|
142
211
|
def _get_svm(self, id_keys) -> SVC:
|
|
143
|
-
"""
|
|
212
|
+
"""
|
|
213
|
+
Get the SVM for the given id keys.
|
|
214
|
+
|
|
215
|
+
This method loads the SVM model from the scores CSV file and trains it
|
|
216
|
+
using the scores from the CSV. If `id_keys` is provided, it filters the
|
|
217
|
+
training data to only include those keys.
|
|
218
|
+
|
|
219
|
+
Args:
|
|
220
|
+
id_keys (list[str] | None): A list of IDs to filter the training data. If None, all data is used.
|
|
221
|
+
|
|
222
|
+
Returns:
|
|
223
|
+
SVC: The trained SVM model.
|
|
224
|
+
"""
|
|
144
225
|
svm = SVC(kernel=self.kernel, C=self.c)
|
|
145
226
|
# parse csv
|
|
146
227
|
with open(
|
|
@@ -160,7 +241,19 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
|
|
|
160
241
|
|
|
161
242
|
@staticmethod
|
|
162
243
|
def load(path: Path) -> "ProbabilisticFilterSVMModel":
|
|
163
|
-
"""
|
|
244
|
+
"""
|
|
245
|
+
Load the model from disk
|
|
246
|
+
|
|
247
|
+
Loads the model from the specified path. The path should point to a JSON file
|
|
248
|
+
containing the model's parameters and state. It also checks for the existence of
|
|
249
|
+
the COBS index file.
|
|
250
|
+
|
|
251
|
+
Args:
|
|
252
|
+
path (Path): The path to the model JSON file.
|
|
253
|
+
|
|
254
|
+
Returns:
|
|
255
|
+
ProbabilisticFilterSVMModel: The loaded model instance.
|
|
256
|
+
"""
|
|
164
257
|
with open(path, "r", encoding="utf-8") as file:
|
|
165
258
|
json_object = file.read()
|
|
166
259
|
model_json = json.loads(json_object)
|
|
@@ -14,19 +14,39 @@ from xspect.file_io import get_record_iterator
|
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
|
|
17
|
-
"""
|
|
17
|
+
"""
|
|
18
|
+
Probabilistic filter model for sequence data, with a single filter
|
|
19
|
+
|
|
20
|
+
This model uses a Bloom filter to store k-mers from the training sequences. It is designed to
|
|
21
|
+
be used with a single filter, which is suitable e. g. for genus-level classification.
|
|
22
|
+
"""
|
|
18
23
|
|
|
19
24
|
def __init__(
|
|
20
25
|
self,
|
|
21
26
|
k: int,
|
|
22
27
|
model_display_name: str,
|
|
23
|
-
author: str,
|
|
24
|
-
author_email: str,
|
|
28
|
+
author: str | None,
|
|
29
|
+
author_email: str | None,
|
|
25
30
|
model_type: str,
|
|
26
31
|
base_path: Path,
|
|
27
32
|
fpr: float = 0.01,
|
|
28
|
-
training_accessions: list[str] = None,
|
|
33
|
+
training_accessions: list[str] | None = None,
|
|
29
34
|
) -> None:
|
|
35
|
+
"""Initialize probabilistic single filter model.
|
|
36
|
+
|
|
37
|
+
This model uses a Bloom filter to store k-mers from the training sequences. It is designed to
|
|
38
|
+
be used with a single filter, which is suitable e.g. for genus-level classification.
|
|
39
|
+
|
|
40
|
+
Args:
|
|
41
|
+
k (int): Length of the k-mers to use for filtering
|
|
42
|
+
model_display_name (str): Display name of the model
|
|
43
|
+
author (str | None): Author of the model
|
|
44
|
+
author_email (str | None): Email of the author
|
|
45
|
+
model_type (str): Type of the model, e.g. "probabilistic_single_filter"
|
|
46
|
+
base_path (Path): Base path where the model will be saved
|
|
47
|
+
fpr (float): False positive rate for the Bloom filter, default is 0.01
|
|
48
|
+
training_accessions (list[str] | None): List of accessions used for training, default is None
|
|
49
|
+
"""
|
|
30
50
|
super().__init__(
|
|
31
51
|
k=k,
|
|
32
52
|
model_display_name=model_display_name,
|
|
@@ -41,9 +61,22 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
|
|
|
41
61
|
self.bf = None
|
|
42
62
|
|
|
43
63
|
def fit(
|
|
44
|
-
self,
|
|
64
|
+
self,
|
|
65
|
+
file_path: Path,
|
|
66
|
+
display_name: str,
|
|
67
|
+
training_accessions: list[str] | None = None,
|
|
45
68
|
) -> None:
|
|
46
|
-
"""
|
|
69
|
+
"""
|
|
70
|
+
Fit the bloom filter to the sequences.
|
|
71
|
+
|
|
72
|
+
Trains the model by reading sequences from the provided file path,
|
|
73
|
+
generating k-mers, and adding them to the Bloom filter.
|
|
74
|
+
|
|
75
|
+
Args:
|
|
76
|
+
file_path (Path): Path to the file containing sequences in FASTA format
|
|
77
|
+
display_name (str): Display name for the model
|
|
78
|
+
training_accessions (list[str] | None): List of accessions used for training, default is None
|
|
79
|
+
"""
|
|
47
80
|
self.training_accessions = training_accessions
|
|
48
81
|
|
|
49
82
|
# estimate number of kmers
|
|
@@ -65,7 +98,18 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
|
|
|
65
98
|
def calculate_hits(
|
|
66
99
|
self, sequence: Seq | SeqRecord, filter_ids=None, step: int = 1
|
|
67
100
|
) -> dict:
|
|
68
|
-
"""
|
|
101
|
+
"""
|
|
102
|
+
Calculate the hits for the sequence
|
|
103
|
+
|
|
104
|
+
Calculates the number of k-mers in the sequence that are present in the Bloom filter.
|
|
105
|
+
|
|
106
|
+
Args:
|
|
107
|
+
sequence (Seq | SeqRecord): Sequence to calculate hits for, can be a Bio.Seq or Bio.SeqRecord object
|
|
108
|
+
filter_ids (list[str] | None): List of filter IDs to use, default is None
|
|
109
|
+
step (int): Step size for generating k-mers, default is 1
|
|
110
|
+
Returns:
|
|
111
|
+
dict: Dictionary with the display name as key and the number of hits as value
|
|
112
|
+
"""
|
|
69
113
|
if isinstance(sequence, SeqRecord):
|
|
70
114
|
sequence = sequence.seq
|
|
71
115
|
|
|
@@ -82,7 +126,17 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
|
|
|
82
126
|
|
|
83
127
|
@staticmethod
|
|
84
128
|
def load(path: Path) -> "ProbabilisticSingleFilterModel":
|
|
85
|
-
"""
|
|
129
|
+
"""
|
|
130
|
+
Load the model from disk
|
|
131
|
+
|
|
132
|
+
This method reads the model's JSON file and the associated Bloom filter file,
|
|
133
|
+
reconstructing the model instance.
|
|
134
|
+
|
|
135
|
+
Args:
|
|
136
|
+
path (Path): Path to the model directory containing the JSON file
|
|
137
|
+
Returns:
|
|
138
|
+
ProbabilisticSingleFilterModel: An instance of the model loaded from disk
|
|
139
|
+
"""
|
|
86
140
|
with open(path, "r", encoding="utf-8") as file:
|
|
87
141
|
json_object = file.read()
|
|
88
142
|
model_json = json.loads(json_object)
|
|
@@ -105,7 +159,17 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
|
|
|
105
159
|
return model
|
|
106
160
|
|
|
107
161
|
def _generate_kmers(self, sequence: Seq, step: int = 1):
|
|
108
|
-
"""
|
|
162
|
+
"""
|
|
163
|
+
Generate kmers from the sequence
|
|
164
|
+
|
|
165
|
+
Generates k-mers from the sequence, considering both the forward and reverse complement strands.
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
sequence (Seq): Sequence to generate k-mers from
|
|
169
|
+
step (int): Step size for generating k-mers, default is 1
|
|
170
|
+
Yields:
|
|
171
|
+
str: The minimizer k-mer (the lexicographically smallest k-mer between the forward and reverse complement)
|
|
172
|
+
"""
|
|
109
173
|
num_kmers = ceil((len(sequence) - self.k + 1) / step)
|
|
110
174
|
for i in range(num_kmers):
|
|
111
175
|
start_pos = i * step
|
xspect/models/result.py
CHANGED
|
@@ -14,9 +14,22 @@ class ModelResult:
|
|
|
14
14
|
hits: dict[str, dict[str, int]],
|
|
15
15
|
num_kmers: dict[str, int],
|
|
16
16
|
sparse_sampling_step: int = 1,
|
|
17
|
-
prediction: str = None,
|
|
18
|
-
input_source: str = None,
|
|
17
|
+
prediction: str | None = None,
|
|
18
|
+
input_source: str | None = None,
|
|
19
19
|
):
|
|
20
|
+
"""
|
|
21
|
+
Initialize the ModelResult object.
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
model_slug (str): The slug of the model.
|
|
25
|
+
hits (dict[str, dict[str, int]]): A dictionary where keys are subsequence names
|
|
26
|
+
and values are dictionaries with labels as keys and hit counts as values.
|
|
27
|
+
num_kmers (dict[str, int]): A dictionary where keys are subsequence names
|
|
28
|
+
and values are the total number of k-mers for that subsequence.
|
|
29
|
+
sparse_sampling_step (int): The step size for sparse sampling, default is 1.
|
|
30
|
+
prediction (str | None): The prediction made by the model, default is None.
|
|
31
|
+
input_source (str | None): The source of the input data, default is None.
|
|
32
|
+
"""
|
|
20
33
|
if "total" in hits:
|
|
21
34
|
raise ValueError(
|
|
22
35
|
"'total' is a reserved key and cannot be used as a subsequence"
|
|
@@ -29,7 +42,16 @@ class ModelResult:
|
|
|
29
42
|
self.input_source = input_source
|
|
30
43
|
|
|
31
44
|
def get_scores(self) -> dict:
|
|
32
|
-
"""
|
|
45
|
+
"""
|
|
46
|
+
Return the scores of the model.
|
|
47
|
+
|
|
48
|
+
The scores are calculated as the number of hits divided by the total number of k-mers
|
|
49
|
+
for each subsequence and label. The scores are rounded to two decimal places.
|
|
50
|
+
|
|
51
|
+
Returns:
|
|
52
|
+
dict: A dictionary where keys are subsequence names and values are dictionaries
|
|
53
|
+
with labels as keys and scores as values. Also includes a 'total' key for overall scores.
|
|
54
|
+
"""
|
|
33
55
|
scores = {
|
|
34
56
|
subsequence: {
|
|
35
57
|
label: round(hits / self.num_kmers[subsequence], 2)
|
|
@@ -50,20 +72,37 @@ class ModelResult:
|
|
|
50
72
|
return scores
|
|
51
73
|
|
|
52
74
|
def get_total_hits(self) -> dict[str, int]:
|
|
53
|
-
"""
|
|
75
|
+
"""
|
|
76
|
+
Return the total hits of the model.
|
|
77
|
+
|
|
78
|
+
The total hits are calculated by summing the hits for each label across all subsequences.
|
|
79
|
+
|
|
80
|
+
Returns:
|
|
81
|
+
dict: A dictionary where keys are labels and values are the total number of hits for that label.
|
|
82
|
+
"""
|
|
54
83
|
total_hits = {label: 0 for label in list(self.hits.values())[0]}
|
|
55
|
-
for _,
|
|
56
|
-
for label, hits in
|
|
84
|
+
for _, subsequence_hits in self.hits.items():
|
|
85
|
+
for label, hits in subsequence_hits.items():
|
|
57
86
|
total_hits[label] += hits
|
|
58
87
|
return total_hits
|
|
59
88
|
|
|
60
89
|
def get_filter_mask(self, label: str, filter_threshold: float) -> dict[str, bool]:
|
|
61
|
-
"""
|
|
90
|
+
"""
|
|
91
|
+
Return a mask for filtered subsequences.
|
|
62
92
|
|
|
63
93
|
The mask is a dictionary with subsequence names as keys and boolean values
|
|
64
94
|
indicating whether the subsequence is above the filter threshold for the given label.
|
|
65
95
|
A value of -1 for filter_threshold indicates that the subsequence with the maximum score
|
|
66
96
|
for the given label should be returned.
|
|
97
|
+
|
|
98
|
+
Args:
|
|
99
|
+
label (str): The label for which to filter the subsequences.
|
|
100
|
+
filter_threshold (float): The threshold for filtering subsequences. Must be between 0 and 1,
|
|
101
|
+
or -1 to return the subsequence with the maximum score for the label.
|
|
102
|
+
|
|
103
|
+
Returns:
|
|
104
|
+
dict[str, bool]: A dictionary where keys are subsequence names and values are booleans
|
|
105
|
+
indicating whether the subsequence meets the filter criteria for the given label.
|
|
67
106
|
"""
|
|
68
107
|
if filter_threshold < 0 and not filter_threshold == -1 or filter_threshold > 1:
|
|
69
108
|
raise ValueError("The filter threshold must be between 0 and 1.")
|
|
@@ -84,7 +123,19 @@ class ModelResult:
|
|
|
84
123
|
def get_filtered_subsequence_labels(
|
|
85
124
|
self, label: str, filter_threshold: float = 0.7
|
|
86
125
|
) -> list[str]:
|
|
87
|
-
"""
|
|
126
|
+
"""
|
|
127
|
+
Return the labels of filtered subsequences.
|
|
128
|
+
|
|
129
|
+
This method filters subsequences based on the scores for a given label and a filter threshold.
|
|
130
|
+
|
|
131
|
+
Args:
|
|
132
|
+
label (str): The label for which to filter the subsequences.
|
|
133
|
+
filter_threshold (float): The threshold for filtering subsequences. Must be between 0 and 1,
|
|
134
|
+
or -1 to return the subsequence with the maximum score for the label.
|
|
135
|
+
|
|
136
|
+
Returns:
|
|
137
|
+
list[str]: A list of subsequence names that meet the filter criteria for the given label.
|
|
138
|
+
"""
|
|
88
139
|
return [
|
|
89
140
|
subsequence
|
|
90
141
|
for subsequence, mask in self.get_filter_mask(
|
|
@@ -94,7 +145,15 @@ class ModelResult:
|
|
|
94
145
|
]
|
|
95
146
|
|
|
96
147
|
def to_dict(self) -> dict:
|
|
97
|
-
"""
|
|
148
|
+
"""
|
|
149
|
+
Return the result as a dictionary.
|
|
150
|
+
|
|
151
|
+
This method converts the ModelResult object into a dictionary format suitable for serialization.
|
|
152
|
+
|
|
153
|
+
Returns:
|
|
154
|
+
dict: A dictionary representation of the ModelResult object, including model slug,
|
|
155
|
+
sparse sampling step, hits, scores, number of k-mers, input source, and prediction if available.
|
|
156
|
+
"""
|
|
98
157
|
res = {
|
|
99
158
|
"model_slug": self.model_slug,
|
|
100
159
|
"sparse_sampling_step": self.sparse_sampling_step,
|
|
@@ -110,6 +169,14 @@ class ModelResult:
|
|
|
110
169
|
return res
|
|
111
170
|
|
|
112
171
|
def save(self, path: Path) -> None:
|
|
113
|
-
"""
|
|
172
|
+
"""
|
|
173
|
+
Save the result as a JSON file.
|
|
174
|
+
|
|
175
|
+
This method serializes the ModelResult object to a JSON file at the specified path.
|
|
176
|
+
|
|
177
|
+
Args:
|
|
178
|
+
path (Path): The path where the JSON file will be saved.
|
|
179
|
+
"""
|
|
180
|
+
path.parent.mkdir(exist_ok=True, parents=True)
|
|
114
181
|
with open(path, "w", encoding="utf-8") as f:
|
|
115
182
|
f.write(dumps(self.to_dict(), indent=4))
|
xspect/ncbi.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
from enum import Enum
|
|
4
4
|
from pathlib import Path
|
|
5
5
|
import time
|
|
6
|
+
from loguru import logger
|
|
6
7
|
import requests
|
|
7
8
|
|
|
8
9
|
# pylint: disable=line-too-long
|
|
@@ -26,26 +27,35 @@ class AssemblySource(Enum):
|
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
class NCBIHandler:
|
|
29
|
-
"""
|
|
30
|
+
"""
|
|
31
|
+
This class uses the NCBI Datasets API to get data about taxa and their assemblies.
|
|
30
32
|
|
|
31
|
-
|
|
32
|
-
|
|
33
|
+
It provides methods to get taxon IDs, species, names, accessions, and download assemblies.
|
|
34
|
+
It also enforces rate limiting to comply with NCBI's API usage policies.
|
|
33
35
|
"""
|
|
34
36
|
|
|
35
37
|
def __init__(
|
|
36
38
|
self,
|
|
37
|
-
api_key: str = None,
|
|
39
|
+
api_key: str | None = None,
|
|
38
40
|
):
|
|
39
|
-
"""
|
|
41
|
+
"""
|
|
42
|
+
Initialise the NCBI handler.
|
|
43
|
+
|
|
44
|
+
This method sets up the base URL for the NCBI Datasets API and initializes the rate limiting parameters.
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
api_key (str | None): The NCBI API key. If None, the handler will use the public API without an API key.
|
|
48
|
+
"""
|
|
40
49
|
self.api_key = api_key
|
|
41
50
|
self.base_url = "https://api.ncbi.nlm.nih.gov/datasets/v2"
|
|
42
|
-
self.last_request_time = 0
|
|
51
|
+
self.last_request_time = 0.0
|
|
43
52
|
self.min_interval = (
|
|
44
53
|
1 / 10 if api_key else 1 / 5
|
|
45
54
|
) # NCBI allows 10 requests per second with if an API key, otherwise 5 requests per second
|
|
46
55
|
|
|
47
|
-
def _enforce_rate_limit(self):
|
|
48
|
-
"""
|
|
56
|
+
def _enforce_rate_limit(self) -> None:
|
|
57
|
+
"""
|
|
58
|
+
Enforce rate limiting for the NCBI Datasets API.
|
|
49
59
|
|
|
50
60
|
This method ensures that the requests to the API are limited to 5 requests per second
|
|
51
61
|
without an API key and 10 requests per second with an API key.
|
|
@@ -58,7 +68,11 @@ class NCBIHandler:
|
|
|
58
68
|
self.last_request_time = now
|
|
59
69
|
|
|
60
70
|
def _make_request(self, endpoint: str, timeout: int = 15) -> dict:
|
|
61
|
-
"""
|
|
71
|
+
"""
|
|
72
|
+
Make a request to the NCBI Datasets API.
|
|
73
|
+
|
|
74
|
+
This method constructs the full URL for the API endpoint, adds the necessary headers (including the API key if provided),
|
|
75
|
+
and makes a GET request to the API. It also enforces rate limiting before making the request.
|
|
62
76
|
|
|
63
77
|
Args:
|
|
64
78
|
endpoint (str): The endpoint to make the request to.
|
|
@@ -229,7 +243,7 @@ class NCBIHandler:
|
|
|
229
243
|
== "OK"
|
|
230
244
|
]
|
|
231
245
|
except (IndexError, KeyError, TypeError):
|
|
232
|
-
|
|
246
|
+
logger.debug(
|
|
233
247
|
f"Could not get {assembly_level.value} accessions for taxon with ID: {taxon_id}. Skipping."
|
|
234
248
|
)
|
|
235
249
|
return []
|
|
@@ -238,7 +252,20 @@ class NCBIHandler:
|
|
|
238
252
|
def get_highest_quality_accessions(
|
|
239
253
|
self, taxon_id: int, assembly_source: AssemblySource, count: int
|
|
240
254
|
) -> list[str]:
|
|
241
|
-
"""
|
|
255
|
+
"""
|
|
256
|
+
Get the highest quality accessions for a given taxon id (based on the assembly level).
|
|
257
|
+
|
|
258
|
+
This function iterates through the assembly levels in order of quality and retrieves accessions
|
|
259
|
+
until the specified count is reached. It ensures that the accessions are unique and sorted by quality.
|
|
260
|
+
|
|
261
|
+
Args:
|
|
262
|
+
taxon_id (int): The taxon id to get the accessions for.
|
|
263
|
+
assembly_source (AssemblySource): The assembly source to get the accessions for.
|
|
264
|
+
count (int): The number of accessions to get.
|
|
265
|
+
|
|
266
|
+
Returns:
|
|
267
|
+
list[str]: A list containing the highest quality accessions.
|
|
268
|
+
"""
|
|
242
269
|
accessions = []
|
|
243
270
|
for assembly_level in list(AssemblyLevel):
|
|
244
271
|
accessions += self.get_accessions(
|
|
@@ -252,7 +279,16 @@ class NCBIHandler:
|
|
|
252
279
|
return list(set(accessions))[:count] # Remove duplicates and limit to count
|
|
253
280
|
|
|
254
281
|
def download_assemblies(self, accessions: list[str], output_dir: Path) -> None:
|
|
255
|
-
"""
|
|
282
|
+
"""
|
|
283
|
+
Download assemblies for a list of accessions.
|
|
284
|
+
|
|
285
|
+
This function makes a request to the NCBI Datasets API to download the assemblies for the given accessions.
|
|
286
|
+
It saves the downloaded assemblies as a zip file in the specified output directory.
|
|
287
|
+
|
|
288
|
+
Args:
|
|
289
|
+
accessions (list[str]): A list of accessions to download.
|
|
290
|
+
output_dir (Path): The directory where the downloaded assemblies will be saved.
|
|
291
|
+
"""
|
|
256
292
|
endpoint = f"/genome/accession/{','.join(accessions)}/download?include_annotation_type=GENOME_FASTA"
|
|
257
293
|
|
|
258
294
|
self._enforce_rate_limit()
|
xspect/train.py
CHANGED
|
@@ -25,12 +25,12 @@ def train_from_directory(
|
|
|
25
25
|
display_name: str,
|
|
26
26
|
dir_path: Path,
|
|
27
27
|
meta: bool = False,
|
|
28
|
-
training_accessions: dict[str, list[str]] = None,
|
|
29
|
-
svm_accessions: list[str] = None,
|
|
28
|
+
training_accessions: dict[str, list[str]] | None = None,
|
|
29
|
+
svm_accessions: dict[str, list[str]] | None = None,
|
|
30
30
|
svm_step: int = 1,
|
|
31
|
-
translation_dict: dict[str, str] = None,
|
|
32
|
-
author: str = None,
|
|
33
|
-
author_email: str = None,
|
|
31
|
+
translation_dict: dict[str, str] | None = None,
|
|
32
|
+
author: str | None = None,
|
|
33
|
+
author_email: str | None = None,
|
|
34
34
|
):
|
|
35
35
|
"""
|
|
36
36
|
Train a model from a directory containing training data.
|
|
@@ -113,10 +113,11 @@ def train_from_directory(
|
|
|
113
113
|
species_dir = tmp_dir / "species"
|
|
114
114
|
species_dir.mkdir(parents=True, exist_ok=True)
|
|
115
115
|
|
|
116
|
-
|
|
116
|
+
logger.info("Concatenating genomes for species training...")
|
|
117
117
|
concatenate_species_fasta_files(cobs_folders, species_dir)
|
|
118
118
|
|
|
119
119
|
if svm_path.exists():
|
|
120
|
+
logger.info("Training species SVM model...")
|
|
120
121
|
species_model = ProbabilisticFilterSVMModel(
|
|
121
122
|
k=21,
|
|
122
123
|
model_display_name=display_name,
|
|
@@ -136,6 +137,7 @@ def train_from_directory(
|
|
|
136
137
|
svm_accessions=svm_accessions,
|
|
137
138
|
)
|
|
138
139
|
else:
|
|
140
|
+
logger.info("Training species model...")
|
|
139
141
|
species_model = ProbabilisticFilterModel(
|
|
140
142
|
k=21,
|
|
141
143
|
model_display_name=display_name,
|
|
@@ -153,9 +155,11 @@ def train_from_directory(
|
|
|
153
155
|
species_model.save()
|
|
154
156
|
|
|
155
157
|
if meta:
|
|
158
|
+
logger.info("Concatenating genomes for metagenome training...")
|
|
156
159
|
meta_fasta = tmp_dir / f"{display_name}.fasta"
|
|
157
160
|
concatenate_metagenome(species_dir, meta_fasta)
|
|
158
161
|
|
|
162
|
+
logger.info("Training metagenome model...")
|
|
159
163
|
genus_model = ProbabilisticSingleFilterModel(
|
|
160
164
|
k=21,
|
|
161
165
|
model_display_name=display_name,
|
|
@@ -179,10 +183,12 @@ def train_from_directory(
|
|
|
179
183
|
def train_from_ncbi(
|
|
180
184
|
genus: str,
|
|
181
185
|
svm_step: int = 1,
|
|
182
|
-
author: str = None,
|
|
183
|
-
author_email: str = None,
|
|
186
|
+
author: str | None = None,
|
|
187
|
+
author_email: str | None = None,
|
|
188
|
+
ncbi_api_key: str | None = None,
|
|
184
189
|
):
|
|
185
|
-
"""
|
|
190
|
+
"""
|
|
191
|
+
Train a model using NCBI assembly data for a given genus.
|
|
186
192
|
|
|
187
193
|
This function trains a probabilistic filter model using the assembly data from NCBI.
|
|
188
194
|
The training data is downloaded and processed, and the model is saved to the
|
|
@@ -193,6 +199,7 @@ def train_from_ncbi(
|
|
|
193
199
|
svm_step (int, optional): Step size for SVM training. Defaults to 1.
|
|
194
200
|
author (str, optional): Author of the model. Defaults to None.
|
|
195
201
|
author_email (str, optional): Author's email. Defaults to None.
|
|
202
|
+
ncbi_api_key (str, optional): NCBI API key for accessing NCBI resources. Defaults to None.
|
|
196
203
|
|
|
197
204
|
Raises:
|
|
198
205
|
TypeError: If `genus` is not a string.
|
|
@@ -205,7 +212,8 @@ def train_from_ncbi(
|
|
|
205
212
|
if not isinstance(genus, str):
|
|
206
213
|
raise TypeError("genus must be a string")
|
|
207
214
|
|
|
208
|
-
|
|
215
|
+
logger.info("Getting NCBI metadata...")
|
|
216
|
+
ncbi_handler = NCBIHandler(api_key=ncbi_api_key)
|
|
209
217
|
genus_tax_id = ncbi_handler.get_genus_taxon_id(genus)
|
|
210
218
|
species_ids = ncbi_handler.get_species(genus_tax_id)
|
|
211
219
|
species_names = ncbi_handler.get_taxon_names(species_ids)
|
|
@@ -243,7 +251,7 @@ def train_from_ncbi(
|
|
|
243
251
|
cobs_dir.mkdir(parents=True, exist_ok=True)
|
|
244
252
|
svm_dir.mkdir(parents=True, exist_ok=True)
|
|
245
253
|
|
|
246
|
-
|
|
254
|
+
logger.info("Downloading genomes from NCBI...")
|
|
247
255
|
all_accessions = sum(accessions.values(), [])
|
|
248
256
|
batch_size = 100
|
|
249
257
|
accession_paths = {}
|