XspecT 0.2.7__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Note: this release has been flagged as potentially problematic.
- xspect/definitions.py +0 -7
- xspect/download_models.py +25 -24
- xspect/fastapi.py +23 -26
- xspect/file_io.py +86 -2
- xspect/main.py +333 -98
- xspect/mlst_feature/mlst_helper.py +4 -6
- xspect/model_management.py +6 -0
- xspect/models/probabilistic_filter_model.py +16 -5
- xspect/models/probabilistic_filter_svm_model.py +33 -18
- xspect/models/probabilistic_single_filter_model.py +8 -1
- xspect/models/result.py +14 -60
- xspect/ncbi.py +265 -0
- xspect/train.py +258 -242
- {xspect-0.2.7.dist-info → xspect-0.4.0.dist-info}/METADATA +14 -21
- xspect-0.4.0.dist-info/RECORD +24 -0
- {xspect-0.2.7.dist-info → xspect-0.4.0.dist-info}/WHEEL +1 -1
- xspect/pipeline.py +0 -201
- xspect/run.py +0 -38
- xspect/train_filter/__init__.py +0 -0
- xspect/train_filter/create_svm.py +0 -45
- xspect/train_filter/extract_and_concatenate.py +0 -124
- xspect/train_filter/ncbi_api/__init__.py +0 -0
- xspect/train_filter/ncbi_api/download_assemblies.py +0 -31
- xspect/train_filter/ncbi_api/ncbi_assembly_metadata.py +0 -110
- xspect/train_filter/ncbi_api/ncbi_children_tree.py +0 -53
- xspect/train_filter/ncbi_api/ncbi_taxon_metadata.py +0 -55
- xspect-0.2.7.dist-info/RECORD +0 -33
- {xspect-0.2.7.dist-info → xspect-0.4.0.dist-info}/entry_points.txt +0 -0
- {xspect-0.2.7.dist-info → xspect-0.4.0.dist-info/licenses}/LICENSE +0 -0
- {xspect-0.2.7.dist-info → xspect-0.4.0.dist-info}/top_level.txt +0 -0
xspect/models/probabilistic_single_filter_model.py
CHANGED

@@ -25,6 +25,7 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
         model_type: str,
         base_path: Path,
         fpr: float = 0.01,
+        training_accessions: list[str] = None,
     ) -> None:
         super().__init__(
             k=k,
@@ -35,11 +36,16 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
             base_path=base_path,
             fpr=fpr,
             num_hashes=1,
+            training_accessions=training_accessions,
         )
         self.bf = None

-    def fit(
+    def fit(
+        self, file_path: Path, display_name: str, training_accessions: list[str] = None
+    ) -> None:
         """Fit the cobs classic index to the sequences and labels"""
+        self.training_accessions = training_accessions
+
         # estimate number of kmers
         total_length = 0
         for record in get_record_iterator(file_path):
@@ -88,6 +94,7 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
             model_json["model_type"],
             path.parent,
             fpr=model_json["fpr"],
+            training_accessions=model_json["training_accessions"],
         )
         model.display_names = model_json["display_names"]
         bloom_path = model.base_path / model.slug() / "filter.bloom"
xspect/models/result.py
CHANGED

@@ -1,50 +1,7 @@
 """Module for storing the results of XspecT models."""

-from enum import Enum
-
-
-def get_last_processing_step(result: "ModelResult") -> "ModelResult":
-    """Get the last subprocessing step of the result. First path only."""
-
-    # traverse result tree to get last step
-    while result.subprocessing_steps:
-        result = result.subprocessing_steps[-1].result
-    return result
-
-
-class StepType(Enum):
-    """Enum for defining the type of a subprocessing step."""
-
-    PREDICTION = 1
-    FILTERING = 2
-
-    def __str__(self) -> str:
-        return self.name.lower()
-
-
-class SubprocessingStep:
-    """Class for storing a subprocessing step of an XspecT model."""
-
-    def __init__(
-        self,
-        subprocessing_type: StepType,
-        label: str,
-        treshold: float,
-        result: "ModelResult",
-    ):
-        self.subprocessing_type = subprocessing_type
-        self.label = label
-        self.treshold = treshold
-        self.result = result
-
-    def to_dict(self) -> dict:
-        """Return the subprocessing step as a dictionary."""
-        return {
-            "subprocessing_type": str(self.subprocessing_type),
-            "label": self.label,
-            "treshold": self.treshold,
-            "result": self.result.to_dict() if self.result else {},
-        }
+from json import dumps
+from pathlib import Path


 class ModelResult:
@@ -58,6 +15,7 @@ class ModelResult:
         num_kmers: dict[str, int],
         sparse_sampling_step: int = 1,
         prediction: str = None,
+        input_source: str = None,
     ):
         if "total" in hits:
             raise ValueError(
@@ -68,15 +26,7 @@ class ModelResult:
         self.num_kmers = num_kmers
         self.sparse_sampling_step = sparse_sampling_step
         self.prediction = prediction
-        self.subprocessing_steps = []
-
-    def add_subprocessing_step(self, subprocessing_step: SubprocessingStep) -> None:
-        """Add a subprocessing step to the result."""
-        if subprocessing_step.label in self.subprocessing_steps:
-            raise ValueError(
-                f"Subprocessing step {subprocessing_step.label} already exists in the result"
-            )
-        self.subprocessing_steps.append(subprocessing_step)
+        self.input_source = input_source

     def get_scores(self) -> dict:
         """Return the scores of the model."""
@@ -119,8 +69,10 @@ class ModelResult:
             for subsequence, score in scores.items()
         }

-    def
-
+    def get_filtered_subsequence_labels(
+        self, label: str, filter_threshold: float = 0.7
+    ) -> list[str]:
+        """Return the labels of filtered subsequences."""
         return [
             subsequence
             for subsequence, mask in self.get_filter_mask(
@@ -137,13 +89,15 @@ class ModelResult:
             "hits": self.hits,
             "scores": self.get_scores(),
             "num_kmers": self.num_kmers,
-            "subprocessing_steps": [
-                subprocessing_step.to_dict()
-                for subprocessing_step in self.subprocessing_steps
-            ],
+            "input_source": self.input_source,
         }

         if self.prediction is not None:
             res["prediction"] = self.prediction

         return res
+
+    def save(self, path: Path) -> None:
+        """Save the result as a JSON file."""
+        with open(path, "w", encoding="utf-8") as f:
+            f.write(dumps(self.to_dict(), indent=4))
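With the subprocessing-step machinery removed, a `ModelResult` is now a flat record that also remembers its input source and can serialize itself. A rough sketch of the new surface, assuming the constructor keywords visible in the hunks above (the `hits` shape and all values are illustrative, and any parameters not shown in this diff are omitted):

```python
from pathlib import Path

from xspect.models.result import ModelResult

# Illustrative values; the exact structure of "hits" is not visible in this
# diff, only that it must not contain a "total" key.
result = ModelResult(
    hits={"contig_1": {"Acinetobacter baumannii": 410}},
    num_kmers={"contig_1": 500},
    sparse_sampling_step=1,
    prediction="Acinetobacter baumannii",
    input_source="sample.fasta",
)

# to_dict() now emits "hits", "scores", "num_kmers", "input_source" and,
# when set, "prediction"; save() writes that dict as indented JSON.
result.save(Path("result.json"))
```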
xspect/ncbi.py
ADDED

@@ -0,0 +1,265 @@
+"""NCBI handler for the NCBI Datasets API."""
+
+from enum import Enum
+from pathlib import Path
+import requests
+import time
+
+# pylint: disable=line-too-long
+
+
+class AssemblyLevel(Enum):
+    """Enum for the assembly level."""
+
+    REFERENCE = "reference"
+    COMPLETE_GENOME = "complete_genome"
+    CHROMOSOME = "chromosome"
+    SCAFFOLD = "scaffold"
+    CONTIG = "contig"
+
+
+class AssemblySource(Enum):
+    """Enum for the assembly source."""
+
+    REFSEQ = "refseq"
+    GENBANK = "genbank"
+
+
+class NCBIHandler:
+    """This class uses the NCBI Datasets API to get the taxonomy tree of a given Taxon.
+
+    The taxonomy tree consists of only the next children to the parent taxon.
+    The children are only of the next lower rank of the parent taxon.
+    """
+
+    def __init__(
+        self,
+        api_key: str = None,
+    ):
+        """Initialise the NCBI handler."""
+        self.api_key = api_key
+        self.base_url = "https://api.ncbi.nlm.nih.gov/datasets/v2"
+        self.last_request_time = 0
+        self.min_interval = (
+            1 / 10 if api_key else 1 / 5
+        )  # NCBI allows 10 requests per second with if an API key, otherwise 5 requests per second
+
+    def _enforce_rate_limit(self):
+        """Enforce rate limiting for the NCBI Datasets API.
+
+        This method ensures that the requests to the API are limited to 5 requests per second
+        without an API key and 10 requests per second with an API key.
+        It uses a simple time-based approach to enforce the rate limit.
+        """
+        now = time.time()
+        elapsed_time = now - self.last_request_time
+        if elapsed_time < self.min_interval:
+            time.sleep(self.min_interval - elapsed_time)
+        self.last_request_time = now  # Update last request time
+
+    def _make_request(self, endpoint: str, timeout: int = 5) -> dict:
+        """Make a request to the NCBI Datasets API.
+
+        Args:
+            endpoint (str): The endpoint to make the request to.
+            timeout (int, optional): The timeout for the request in seconds. Defaults to 5.
+
+        Returns:
+            dict: The response from the API.
+        """
+        self._enforce_rate_limit()
+
+        endpoint = endpoint if endpoint.startswith("/") else "/" + endpoint
+        headers = {}
+        if self.api_key:
+            headers["api-key"] = self.api_key
+        response = requests.get(
+            self.base_url + endpoint, headers=headers, timeout=timeout
+        )
+        if response.status_code != 200:
+            response.raise_for_status()
+
+        return response.json()
+
+    def get_genus_taxon_id(self, genus: str) -> int:
+        """
+        Get the taxon id for a given genus name.
+
+        This function checks if the genus name is valid by making a request to the NCBI Datasets API.
+        If the genus name is valid, it returns the taxon id.
+        If the genus name is not valid, it raises an exception.
+
+        Args:
+            genus (str): The genus name to validate.
+
+        Returns:
+            int: The taxon id for the given genus name.
+
+        Raises:
+            ValueError: If the genus name is not valid.
+        """
+        endpoint = f"/taxonomy/taxon/{genus}"
+        response = self._make_request(endpoint)
+
+        try:
+            taxonomy = response["taxonomy_nodes"][0]["taxonomy"]
+
+            taxon_id = taxonomy["tax_id"]
+            rank = taxonomy["rank"]
+            lineage = taxonomy["lineage"]
+
+            if rank != "GENUS":
+                raise ValueError(f"Genus name {genus} is not a genus.")
+            if lineage[2] != 2:
+                raise ValueError(f"Genus name {genus} does not belong to bacteria.")
+
+            return taxon_id
+        except (IndexError, KeyError, TypeError) as e:
+            raise ValueError(f"Invalid genus name: {genus}") from e
+
+    def get_species(self, genus_id: int) -> list[int]:
+        """
+        Get the species for a given genus id.
+
+        This function makes a request to the NCBI Datasets API to get the species for a given genus id.
+        It returns a list of species taxonomy ids.
+
+        Args:
+            genus_id (int): The genus id to get the species for.
+
+        Returns:
+            list[int]: A list containing the species taxnomy ids.
+        """
+        endpoint = f"/taxonomy/taxon/{genus_id}/filtered_subtree"
+        response = self._make_request(endpoint)
+
+        try:
+            species_ids = response["edges"][str(genus_id)]["visible_children"]
+        except (IndexError, KeyError, TypeError) as e:
+            raise ValueError(f"Invalid genus id: {genus_id}") from e
+        return species_ids
+
+    def get_taxon_names(self, taxon_ids: list[int]) -> dict[int, str]:
+        """
+        Get the names for a given list of taxon ids.
+
+        This function makes a request to the NCBI Datasets API to get the names for a given list of taxon ids.
+        It returns a dictionary with the taxon ids as keys and the names as values.
+
+        Args:
+            taxon_ids (list[int]): The list of taxon ids to get the names for.
+
+        Returns:
+            dict[int, str]: A dictionary containing the taxon ids and their corresponding names.
+        """
+        if len(taxon_ids) > 1000:
+            raise ValueError("Maximum number of taxon ids is 1000.")
+        if len(taxon_ids) < 1:
+            raise ValueError("At least one taxon id is required.")
+
+        endpoint = f"/taxonomy/taxon/{','.join(map(str, taxon_ids))}?page_size=1000"
+        response = self._make_request(endpoint)
+
+        try:
+            taxon_names = {
+                int(taxonomy_node["taxonomy"]["tax_id"]): taxonomy_node["taxonomy"][
+                    "organism_name"
+                ]
+                for taxonomy_node in response["taxonomy_nodes"]
+            }
+            if len(taxon_names) != len(taxon_ids):
+                raise ValueError("Not all taxon ids were found.")
+        except (IndexError, KeyError, TypeError) as e:
+            raise ValueError(f"Invalid taxon ids: {taxon_ids}") from e
+
+        return taxon_names
+
+    def get_accessions(
+        self,
+        taxon_id: int,
+        assembly_level: AssemblyLevel,
+        assembly_source: AssemblySource,
+        count: int,
+        min_n50: int = 10000,
+        exclude_atypical: bool = True,
+        exclude_paired_reports: bool = True,
+        current_version_only: bool = True,
+    ) -> list[str]:
+        """
+        Get the accessions for a given taxon id.
+
+        This function makes a request to the NCBI Datasets API to get the accessions for a given taxon id.
+        It filters the accessions based on the assembly level, assembly source, and other parameters.
+        It returns a list with the respective accessions.
+
+        Args:
+            taxon_id int: The taxon id to get the accessions for.
+            assembly_level (AssemblyLevel): The assembly level to get the accessions for.
+            assembly_source (AssemblySource): The assembly source to get the accessions for.
+            count (int): The number of accessions to get.
+            min_n50 (int, optional): The minimum contig n50 to filter the accessions. Defaults to 10000.
+            exclude_atypical (bool, optional): Whether to exclude atypical accessions. Defaults to True.
+            exclude_paired_reports (bool, optional): Whether to exclude paired reports. Defaults to True.
+            current_version_only (bool, optional): Whether to get only the current version of the accessions. Defaults to True.
+
+        Returns:
+            list[str]: A list containing the accessions.
+        """
+        endpoint = (
+            f"/genome/taxon/{taxon_id}/dataset_report?"
+            f"filters.assembly_source={assembly_source.value}&"
+            f"filters.exclude_atypical={exclude_atypical}&"
+            f"filters.exclude_paired_reports={exclude_paired_reports}&"
+            f"filters.current_version_only={current_version_only}&"
+            f"page_size={count * 2}&"  # to avoid having less than count if n50 or ANI is not met
+        )
+        endpoint += (
+            "&filters.reference_only=true"
+            if assembly_level == AssemblyLevel.REFERENCE
+            else f"&filters.assembly_level={assembly_level.value}"
+        )
+
+        response = self._make_request(endpoint)
+        try:
+            accessions = [
+                report["accession"]
+                for report in response["reports"]
+                if report["assembly_stats"]["contig_n50"] >= min_n50
+                and report["average_nucleotide_identity"]["taxonomy_check_status"]
+                == "OK"
+            ]
+        except (IndexError, KeyError, TypeError):
+            print(f"Could not get accessions for taxon with ID: {taxon_id}. Skipping.")
+            return []
+        return accessions[:count]  # Limit to count
+
+    def get_highest_quality_accessions(
+        self, taxon_id: int, assembly_source: AssemblySource, count: int
+    ) -> list[str]:
+        """Get the highest quality accessions for a given taxon id (based on the assembly level)."""
+        accessions = []
+        for assembly_level in list(AssemblyLevel):
+            accessions += self.get_accessions(
+                taxon_id,
+                assembly_level,
+                assembly_source,
+                count,
+            )
+            if len(set(accessions)) >= count:
+                break
+        return list(set(accessions))[:count]  # Remove duplicates and limit to count
+
+    def download_assemblies(self, accessions: list[str], output_dir: Path) -> None:
+        """Download assemblies for a list of accessions."""
+        endpoint = f"/genome/accession/{','.join(accessions)}/download?include_annotation_type=GENOME_FASTA"
+
+        self._enforce_rate_limit()
+
+        response = requests.get(self.base_url + endpoint, stream=True, timeout=5)
+        if response.status_code != 200:
+            response.raise_for_status()
+
+        output_dir.mkdir(parents=True, exist_ok=True)
+        with open(output_dir / "ncbi_dataset.zip", "wb") as f:
+            for chunk in response.iter_content(chunk_size=8192):
+                f.write(chunk)