XspecT 0.5.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of XspecT might be problematic. Click here for more details.
- xspect/classify.py +51 -38
- xspect/definitions.py +50 -10
- xspect/download_models.py +10 -2
- xspect/file_io.py +115 -48
- xspect/filter_sequences.py +36 -66
- xspect/main.py +41 -10
- xspect/mlst_feature/mlst_helper.py +3 -0
- xspect/mlst_feature/pub_mlst_handler.py +43 -1
- xspect/model_management.py +84 -14
- xspect/models/probabilistic_filter_mlst_model.py +75 -37
- xspect/models/probabilistic_filter_model.py +194 -12
- xspect/models/probabilistic_filter_svm_model.py +99 -6
- xspect/models/probabilistic_single_filter_model.py +66 -5
- xspect/models/result.py +77 -10
- xspect/ncbi.py +45 -10
- xspect/train.py +2 -1
- xspect/web.py +68 -12
- xspect/xspect-web/dist/assets/index-Ceo58xui.css +1 -0
- xspect/xspect-web/dist/assets/{index-CMG4V7fZ.js → index-Dt_UlbgE.js} +82 -77
- xspect/xspect-web/dist/index.html +2 -2
- xspect/xspect-web/src/App.tsx +4 -2
- xspect/xspect-web/src/api.tsx +23 -1
- xspect/xspect-web/src/components/filter-form.tsx +16 -3
- xspect/xspect-web/src/components/filtering-result.tsx +65 -0
- xspect/xspect-web/src/components/result.tsx +2 -2
- xspect/xspect-web/src/types.tsx +5 -0
- {xspect-0.5.1.dist-info → xspect-0.5.2.dist-info}/METADATA +1 -1
- {xspect-0.5.1.dist-info → xspect-0.5.2.dist-info}/RECORD +32 -31
- {xspect-0.5.1.dist-info → xspect-0.5.2.dist-info}/WHEEL +1 -1
- xspect/xspect-web/dist/assets/index-jIKg1HIy.css +0 -1
- {xspect-0.5.1.dist-info → xspect-0.5.2.dist-info}/entry_points.txt +0 -0
- {xspect-0.5.1.dist-info → xspect-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {xspect-0.5.1.dist-info → xspect-0.5.2.dist-info}/top_level.txt +0 -0
xspect/main.py
CHANGED
|
@@ -18,10 +18,8 @@ from xspect.models.probabilistic_filter_mlst_model import (
|
|
|
18
18
|
ProbabilisticFilterMlstSchemeModel,
|
|
19
19
|
)
|
|
20
20
|
from xspect.model_management import (
|
|
21
|
-
get_genus_model,
|
|
22
21
|
get_model_metadata,
|
|
23
22
|
get_models,
|
|
24
|
-
get_species_model,
|
|
25
23
|
)
|
|
26
24
|
|
|
27
25
|
|
|
@@ -173,8 +171,9 @@ def train_mlst(choose_schemes):
|
|
|
173
171
|
scheme_path = pick_scheme(handler.get_scheme_paths())
|
|
174
172
|
species_name = str(scheme_path).split("/")[-2]
|
|
175
173
|
scheme_name = str(scheme_path).split("/")[-1]
|
|
174
|
+
scheme_url = handler.scheme_mapping[str(scheme_path)]
|
|
176
175
|
model = ProbabilisticFilterMlstSchemeModel(
|
|
177
|
-
31, f"{species_name}:{scheme_name}", get_xspect_model_path()
|
|
176
|
+
31, f"{species_name}:{scheme_name}", get_xspect_model_path(), scheme_url
|
|
178
177
|
)
|
|
179
178
|
click.echo("Creating mlst model")
|
|
180
179
|
model.fit(scheme_path)
|
|
@@ -220,10 +219,18 @@ def classify_seqs():
|
|
|
220
219
|
type=click.Path(dir_okay=False, file_okay=True),
|
|
221
220
|
default=Path(".") / f"result_{uuid4()}.json",
|
|
222
221
|
)
|
|
223
|
-
|
|
222
|
+
@click.option(
|
|
223
|
+
"--sparse-sampling-step",
|
|
224
|
+
type=int,
|
|
225
|
+
help="Sparse sampling step (e. g. only every 500th kmer for '--sparse-sampling-step 500').",
|
|
226
|
+
default=1,
|
|
227
|
+
)
|
|
228
|
+
def classify_genus(model_genus, input_path, output_path, sparse_sampling_step):
|
|
224
229
|
"""Classify samples using a genus model."""
|
|
225
230
|
click.echo("Classifying...")
|
|
226
|
-
classify.classify_genus(
|
|
231
|
+
classify.classify_genus(
|
|
232
|
+
model_genus, Path(input_path), Path(output_path), sparse_sampling_step
|
|
233
|
+
)
|
|
227
234
|
|
|
228
235
|
|
|
229
236
|
@classify_seqs.command(
|
|
@@ -275,20 +282,24 @@ def classify_species(model_genus, input_path, output_path, sparse_sampling_step)
|
|
|
275
282
|
"-i",
|
|
276
283
|
"--input-path",
|
|
277
284
|
help="Path to FASTA-file for mlst identification.",
|
|
278
|
-
type=click.Path(exists=True, dir_okay=
|
|
285
|
+
type=click.Path(exists=True, dir_okay=True, file_okay=True),
|
|
279
286
|
prompt=True,
|
|
287
|
+
default=Path("."),
|
|
280
288
|
)
|
|
281
289
|
@click.option(
|
|
282
290
|
"-o",
|
|
283
291
|
"--output-path",
|
|
284
292
|
help="Path to the output file.",
|
|
285
293
|
type=click.Path(dir_okay=False, file_okay=True),
|
|
294
|
+
default=Path(".") / f"MLST_result_{uuid4()}.json",
|
|
295
|
+
)
|
|
296
|
+
@click.option(
|
|
297
|
+
"-l", "--limit", is_flag=True, help="Limit the output to 5 results for each locus."
|
|
286
298
|
)
|
|
287
|
-
def classify_mlst(input_path, output_path):
|
|
299
|
+
def classify_mlst(input_path, output_path, limit):
|
|
288
300
|
"""MLST classify a sample."""
|
|
289
301
|
click.echo("Classifying...")
|
|
290
|
-
classify.classify_mlst(Path(input_path), Path(output_path))
|
|
291
|
-
click.echo(f"Result saved as {output_path}.")
|
|
302
|
+
classify.classify_mlst(Path(input_path), Path(output_path), limit)
|
|
292
303
|
|
|
293
304
|
|
|
294
305
|
# # # # # # # # # # # # # # #
|
|
@@ -343,8 +354,19 @@ def filter_seqs():
|
|
|
343
354
|
default=0.7,
|
|
344
355
|
prompt=True,
|
|
345
356
|
)
|
|
357
|
+
@click.option(
|
|
358
|
+
"--sparse-sampling-step",
|
|
359
|
+
type=int,
|
|
360
|
+
help="Sparse sampling step (e. g. only every 500th kmer for '--sparse-sampling-step 500').",
|
|
361
|
+
default=1,
|
|
362
|
+
)
|
|
346
363
|
def filter_genus(
|
|
347
|
-
model_genus,
|
|
364
|
+
model_genus,
|
|
365
|
+
input_path,
|
|
366
|
+
output_path,
|
|
367
|
+
classification_output_path,
|
|
368
|
+
threshold,
|
|
369
|
+
sparse_sampling_step,
|
|
348
370
|
):
|
|
349
371
|
"""Filter samples using a genus model."""
|
|
350
372
|
click.echo("Filtering...")
|
|
@@ -355,6 +377,7 @@ def filter_genus(
|
|
|
355
377
|
Path(output_path),
|
|
356
378
|
threshold,
|
|
357
379
|
Path(classification_output_path) if classification_output_path else None,
|
|
380
|
+
sparse_sampling_step=sparse_sampling_step,
|
|
358
381
|
)
|
|
359
382
|
|
|
360
383
|
|
|
@@ -405,6 +428,12 @@ def filter_genus(
|
|
|
405
428
|
default=0.7,
|
|
406
429
|
prompt=True,
|
|
407
430
|
)
|
|
431
|
+
@click.option(
|
|
432
|
+
"--sparse-sampling-step",
|
|
433
|
+
type=int,
|
|
434
|
+
help="Sparse sampling step (e. g. only every 500th kmer for '--sparse-sampling-step 500').",
|
|
435
|
+
default=1,
|
|
436
|
+
)
|
|
408
437
|
def filter_species(
|
|
409
438
|
model_genus,
|
|
410
439
|
model_species,
|
|
@@ -412,6 +441,7 @@ def filter_species(
|
|
|
412
441
|
output_path,
|
|
413
442
|
threshold,
|
|
414
443
|
classification_output_path,
|
|
444
|
+
sparse_sampling_step,
|
|
415
445
|
):
|
|
416
446
|
"""Filter a sample using the species model."""
|
|
417
447
|
|
|
@@ -451,6 +481,7 @@ def filter_species(
|
|
|
451
481
|
Path(output_path),
|
|
452
482
|
threshold,
|
|
453
483
|
Path(classification_output_path) if classification_output_path else None,
|
|
484
|
+
sparse_sampling_step=sparse_sampling_step,
|
|
454
485
|
)
|
|
455
486
|
|
|
456
487
|
|
|
@@ -194,11 +194,13 @@ class MlstResult:
|
|
|
194
194
|
scheme_model: str,
|
|
195
195
|
steps: int,
|
|
196
196
|
hits: dict[str, list[dict]],
|
|
197
|
+
input_source: str = None,
|
|
197
198
|
):
|
|
198
199
|
"""Initialise an MlstResult object."""
|
|
199
200
|
self.scheme_model = scheme_model
|
|
200
201
|
self.steps = steps
|
|
201
202
|
self.hits = hits
|
|
203
|
+
self.input_source = input_source
|
|
202
204
|
|
|
203
205
|
def get_results(self) -> dict:
|
|
204
206
|
"""
|
|
@@ -221,6 +223,7 @@ class MlstResult:
|
|
|
221
223
|
"Scheme": self.scheme_model,
|
|
222
224
|
"Steps": self.steps,
|
|
223
225
|
"Results": self.get_results(),
|
|
226
|
+
"Input_source": self.input_source,
|
|
224
227
|
}
|
|
225
228
|
return result
|
|
226
229
|
|
|
@@ -17,7 +17,7 @@ from xspect.definitions import get_xspect_mlst_path, get_xspect_upload_path
|
|
|
17
17
|
class PubMLSTHandler:
|
|
18
18
|
"""Class for communicating with PubMLST and downloading alleles (FASTA-Format) from all loci."""
|
|
19
19
|
|
|
20
|
-
base_url = "
|
|
20
|
+
base_url = "https://rest.pubmlst.org/db"
|
|
21
21
|
|
|
22
22
|
def __init__(self):
|
|
23
23
|
"""Initialise a PubMLSTHandler object."""
|
|
@@ -27,6 +27,7 @@ class PubMLSTHandler:
|
|
|
27
27
|
self.base_url + "/pubmlst_abaumannii_seqdef/schemes/2",
|
|
28
28
|
]
|
|
29
29
|
self.scheme_paths = []
|
|
30
|
+
self.scheme_mapping = {}
|
|
30
31
|
|
|
31
32
|
def get_scheme_paths(self) -> dict:
|
|
32
33
|
"""
|
|
@@ -103,6 +104,7 @@ class PubMLSTHandler:
|
|
|
103
104
|
|
|
104
105
|
species_name = scheme.split("_")[1] # name = pubmlst_abaumannii_seqdef
|
|
105
106
|
scheme_path = get_xspect_mlst_path() / species_name / scheme_name
|
|
107
|
+
self.scheme_mapping[str(scheme_path)] = scheme
|
|
106
108
|
self.scheme_paths.append(scheme_path)
|
|
107
109
|
|
|
108
110
|
for locus_url in locus_list:
|
|
@@ -143,3 +145,43 @@ class PubMLSTHandler:
|
|
|
143
145
|
# Example: 'Pas_fusA': [{'href': some URL, 'allele_id': '2'}]
|
|
144
146
|
print(locus + ":" + meta_data[0]["allele_id"], end="; ")
|
|
145
147
|
print("\nStrain Type:", response["fields"])
|
|
148
|
+
|
|
149
|
+
def get_strain_type_name(
    self, highest_results: dict, post_url: str, timeout: float = 30.0
) -> str:
    """
    Ask PubMLST for the strain type that matches the best allele of each locus.

    The allele ids in highest_results are wrapped into a "designations" payload
    and sent as a POST request to the "designations" endpoint of the scheme.

    Example of post_url for the oxford scheme of A.baumannii:
    https://rest.pubmlst.org/db/pubmlst_abaumannii_seqdef/schemes/1
    ("/designations" is appended here; a URL that already ends with it is used as is.)

    Args:
        highest_results (dict): Maps each locus to the allele id with the highest kmer score.
        post_url (str): The URL of the scheme of a species.
        timeout (float, optional): Seconds to wait for PubMLST before giving up; defaults to 30.

    Returns:
        str: The "fields" entry of the response if a strain type was found, otherwise
        a message that no strain type matched, or a text starting with "Error:".
        NOTE(review): "fields" is returned unchanged and may be a mapping rather than
        a plain string -- confirm against the PubMLST API.
    """
    # Accept both the plain scheme URL and one that already names the endpoint,
    # so the suffix is never appended twice.
    url = post_url.rstrip("/")
    if not url.endswith("/designations"):
        url += "/designations"

    payload = {
        "designations": {
            locus: [{"allele": str(allele)}]
            for locus, allele in highest_results.items()
        }
    }

    try:
        # Without a timeout a stalled connection would block the classification.
        response = requests.post(url, json=payload, timeout=timeout)
    except requests.RequestException as error:
        # Report connection problems in-band, the same way as HTTP errors below.
        return "Error:" + str(error)

    if response.status_code != 200:
        return "Error:" + str(response.status_code) + response.text

    data = response.json()
    if "fields" in data:
        return data["fields"]
    return (
        "No matching Strain Type found in the database. "
        "Possibly a novel Strain Type."
    )
|
xspect/model_management.py
CHANGED
|
@@ -9,22 +9,55 @@ from xspect.models.probabilistic_filter_svm_model import ProbabilisticFilterSVMM
|
|
|
9
9
|
from xspect.definitions import get_xspect_model_path
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
def get_genus_model(genus):
|
|
13
|
-
"""
|
|
12
|
+
def get_genus_model(genus) -> ProbabilisticSingleFilterModel:
    """
    Load the genus model that belongs to a genus.

    The model is read from the file "<genus in lower case>-genus.json" inside the
    directory returned by get_xspect_model_path().

    Args:
        genus (str): Name of the genus whose model should be loaded.

    Returns:
        ProbabilisticSingleFilterModel: The model loaded from that file.
    """
    model_dir = get_xspect_model_path()
    return ProbabilisticSingleFilterModel.load(model_dir / f"{genus.lower()}-genus.json")
|
|
17
27
|
|
|
18
28
|
|
|
19
|
-
def get_species_model(genus):
|
|
20
|
-
"""
|
|
29
|
+
def get_species_model(genus) -> ProbabilisticFilterSVMModel:
    """
    Load the species model that belongs to a genus.

    The model is read from the file "<genus in lower case>-species.json" inside the
    directory returned by get_xspect_model_path().

    Args:
        genus (str): Name of the genus whose species model should be loaded.

    Returns:
        ProbabilisticFilterSVMModel: The model loaded from that file.
    """
    model_dir = get_xspect_model_path()
    return ProbabilisticFilterSVMModel.load(model_dir / f"{genus.lower()}-species.json")
|
|
24
44
|
|
|
25
45
|
|
|
26
|
-
def get_model_metadata(model: str | Path):
|
|
27
|
-
"""
|
|
46
|
+
def get_model_metadata(model: str | Path) -> dict:
|
|
47
|
+
"""
|
|
48
|
+
Get metadata of a specified model.
|
|
49
|
+
|
|
50
|
+
This function retrieves the metadata of a model from its JSON file.
|
|
51
|
+
|
|
52
|
+
Args:
|
|
53
|
+
model (str | Path): The slug of the model (as a string) or the path to the model JSON file.
|
|
54
|
+
|
|
55
|
+
Returns:
|
|
56
|
+
dict: A dictionary containing the model metadata.
|
|
57
|
+
|
|
58
|
+
Raises:
|
|
59
|
+
ValueError: If the model does not exist or is not a valid file.
|
|
60
|
+
"""
|
|
28
61
|
if isinstance(model, str):
|
|
29
62
|
model_path = get_xspect_model_path() / (model.lower() + ".json")
|
|
30
63
|
elif isinstance(model, Path):
|
|
@@ -40,8 +73,17 @@ def get_model_metadata(model: str | Path):
|
|
|
40
73
|
return model_json
|
|
41
74
|
|
|
42
75
|
|
|
43
|
-
def update_model_metadata(model_slug: str, author: str, author_email: str):
|
|
44
|
-
"""
|
|
76
|
+
def update_model_metadata(model_slug: str, author: str, author_email: str) -> None:
|
|
77
|
+
"""
|
|
78
|
+
Update the metadata of a model.
|
|
79
|
+
|
|
80
|
+
This function updates the author and author email in the model's metadata JSON file.
|
|
81
|
+
|
|
82
|
+
Args:
|
|
83
|
+
model_slug (str): The slug of the model to update.
|
|
84
|
+
author (str): The name of the author to set in the metadata.
|
|
85
|
+
author_email (str): The email of the author to set in the metadata.
|
|
86
|
+
"""
|
|
45
87
|
model_metadata = get_model_metadata(model_slug)
|
|
46
88
|
model_metadata["author"] = author
|
|
47
89
|
model_metadata["author_email"] = author_email
|
|
@@ -51,8 +93,19 @@ def update_model_metadata(model_slug: str, author: str, author_email: str):
|
|
|
51
93
|
file.write(dumps(model_metadata, indent=4))
|
|
52
94
|
|
|
53
95
|
|
|
54
|
-
def update_model_display_name(
|
|
55
|
-
|
|
96
|
+
def update_model_display_name(
|
|
97
|
+
model_slug: str, filter_id: str, display_name: str
|
|
98
|
+
) -> None:
|
|
99
|
+
"""
|
|
100
|
+
Update the display name of a filter in a model.
|
|
101
|
+
|
|
102
|
+
This function updates the display name of a specific filter in the model's metadata JSON file.
|
|
103
|
+
|
|
104
|
+
Args:
|
|
105
|
+
model_slug (str): The slug of the model to update.
|
|
106
|
+
filter_id (str): The ID of the filter whose display name is to be updated.
|
|
107
|
+
display_name (str): The new display name for the filter.
|
|
108
|
+
"""
|
|
56
109
|
model_metadata = get_model_metadata(model_slug)
|
|
57
110
|
model_metadata["display_names"][filter_id] = display_name
|
|
58
111
|
|
|
@@ -61,8 +114,15 @@ def update_model_display_name(model_slug: str, filter_id: str, display_name: str
|
|
|
61
114
|
file.write(dumps(model_metadata, indent=4))
|
|
62
115
|
|
|
63
116
|
|
|
64
|
-
def get_models():
|
|
65
|
-
"""
|
|
117
|
+
def get_models() -> dict[str, list[dict]]:
|
|
118
|
+
"""
|
|
119
|
+
Get a list of all available models in a dictionary by type.
|
|
120
|
+
|
|
121
|
+
This function scans the model directory for JSON files and organizes them by their model type.
|
|
122
|
+
|
|
123
|
+
Returns:
|
|
124
|
+
dict[str, list[dict]]: A dictionary where keys are model types and values are lists of model display names.
|
|
125
|
+
"""
|
|
66
126
|
model_dict = {}
|
|
67
127
|
for model_file in get_xspect_model_path().glob("*.json"):
|
|
68
128
|
model_metadata = get_model_metadata(model_file)
|
|
@@ -73,7 +133,17 @@ def get_models():
|
|
|
73
133
|
return model_dict
|
|
74
134
|
|
|
75
135
|
|
|
76
|
-
def get_model_display_names(model_slug: str):
|
|
77
|
-
"""
|
|
136
|
+
def get_model_display_names(model_slug: str) -> list[str]:
    """
    Return the display names stored in the metadata of a model.

    The metadata is fetched with get_model_metadata() and the values of its
    "display_names" entry are returned.

    Args:
        model_slug (str): The slug of the model whose display names are wanted.

    Returns:
        list[str]: The display names of the individual filters in the model.
    """
    display_names = get_model_metadata(model_slug)["display_names"]
    return [*display_names.values()]
|
|
@@ -12,6 +12,7 @@ from cobs_index import DocumentList
|
|
|
12
12
|
from collections import defaultdict
|
|
13
13
|
from xspect.file_io import get_record_iterator
|
|
14
14
|
from xspect.mlst_feature.mlst_helper import MlstResult
|
|
15
|
+
from xspect.mlst_feature.pub_mlst_handler import PubMLSTHandler
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
class ProbabilisticFilterMlstSchemeModel:
|
|
@@ -19,20 +20,22 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
19
20
|
|
|
20
21
|
def __init__(
|
|
21
22
|
self,
|
|
22
|
-
|
|
23
|
-
|
|
23
|
+
k_value: int,
|
|
24
|
+
model_name: str,
|
|
24
25
|
base_path: Path,
|
|
26
|
+
scheme_url: str,
|
|
25
27
|
fpr: float = 0.001,
|
|
26
28
|
) -> None:
|
|
27
29
|
"""Initialise a ProbabilisticFilterMlstSchemeModel object."""
|
|
28
|
-
if
|
|
30
|
+
if k_value < 1:
|
|
29
31
|
raise ValueError("Invalid k value, must be greater than 0")
|
|
30
32
|
if not isinstance(base_path, Path):
|
|
31
33
|
raise ValueError("Invalid base path, must be a pathlib.Path object")
|
|
32
34
|
|
|
33
|
-
self.
|
|
34
|
-
self.
|
|
35
|
+
self.k_value = k_value
|
|
36
|
+
self.model_name = model_name
|
|
35
37
|
self.base_path = base_path / "MLST"
|
|
38
|
+
self.scheme_url = scheme_url
|
|
36
39
|
self.fpr = fpr
|
|
37
40
|
self.model_type = "Strain"
|
|
38
41
|
self.loci = {}
|
|
@@ -49,9 +52,10 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
49
52
|
dict: The dictionary containing all metadata of an object.
|
|
50
53
|
"""
|
|
51
54
|
return {
|
|
52
|
-
"
|
|
53
|
-
"
|
|
55
|
+
"k_value": self.k_value,
|
|
56
|
+
"model_name": self.model_name,
|
|
54
57
|
"model_type": self.model_type,
|
|
58
|
+
"scheme_url": str(self.scheme_url),
|
|
55
59
|
"fpr": self.fpr,
|
|
56
60
|
"scheme_path": str(self.scheme_path),
|
|
57
61
|
"cobs_path": str(self.cobs_path),
|
|
@@ -115,7 +119,7 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
115
119
|
# COBS only accepts strings as paths
|
|
116
120
|
doclist = DocumentList(str(locus_path))
|
|
117
121
|
index_params = cobs_index.CompactIndexParameters()
|
|
118
|
-
index_params.term_size = self.
|
|
122
|
+
index_params.term_size = self.k_value # k-mer size
|
|
119
123
|
index_params.clobber = True # overwrite output and temporary files
|
|
120
124
|
index_params.false_positive_rate = self.fpr
|
|
121
125
|
|
|
@@ -130,9 +134,7 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
130
134
|
|
|
131
135
|
def save(self) -> None:
|
|
132
136
|
"""Saves the model to disk"""
|
|
133
|
-
scheme = str(self.scheme_path).split("/")[
|
|
134
|
-
-1
|
|
135
|
-
] # [-1] -> contains the scheme name
|
|
137
|
+
scheme = str(self.scheme_path).split("/")[-1] # [-1] contains the scheme name
|
|
136
138
|
json_path = self.base_path / scheme / f"{scheme}.json"
|
|
137
139
|
json_object = json.dumps(self.to_dict(), indent=4)
|
|
138
140
|
|
|
@@ -156,9 +158,10 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
156
158
|
json_object = file.read()
|
|
157
159
|
model_json = json.loads(json_object)
|
|
158
160
|
model = ProbabilisticFilterMlstSchemeModel(
|
|
159
|
-
model_json["
|
|
160
|
-
model_json["
|
|
161
|
+
model_json["k_value"],
|
|
162
|
+
model_json["model_name"],
|
|
161
163
|
json_path.parent,
|
|
164
|
+
model_json["scheme_url"],
|
|
162
165
|
model_json["fpr"],
|
|
163
166
|
)
|
|
164
167
|
model.scheme_path = model_json["scheme_path"]
|
|
@@ -175,7 +178,12 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
175
178
|
return model
|
|
176
179
|
|
|
177
180
|
def calculate_hits(
|
|
178
|
-
self,
|
|
181
|
+
self,
|
|
182
|
+
cobs_path: Path,
|
|
183
|
+
sequence: Seq,
|
|
184
|
+
step: int = 1,
|
|
185
|
+
limit: bool = False,
|
|
186
|
+
limit_number: int = 5,
|
|
179
187
|
) -> list[dict]:
|
|
180
188
|
"""
|
|
181
189
|
Calculates the hits for a sequence.
|
|
@@ -189,6 +197,8 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
189
197
|
cobs_path (Path): The path of the COBS-structure directory.
|
|
190
198
|
sequence (Seq): The input sequence for classification.
|
|
191
199
|
step (int, optional): The amount of kmers that are passed; defaults to one.
|
|
200
|
+
limit (bool): Applying a filter that limits the best result.
|
|
201
|
+
limit_number (int): The amount of results when the filter is set to true.
|
|
192
202
|
|
|
193
203
|
Returns:
|
|
194
204
|
list[dict]: The results of the prediction.
|
|
@@ -201,7 +211,7 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
201
211
|
if not isinstance(sequence, Seq):
|
|
202
212
|
raise ValueError("Invalid sequence, must be a Bio.Seq object")
|
|
203
213
|
|
|
204
|
-
if not len(sequence) > self.
|
|
214
|
+
if not len(sequence) > self.k_value:
|
|
205
215
|
raise ValueError("Invalid sequence, must be longer than k")
|
|
206
216
|
|
|
207
217
|
if not self.indices:
|
|
@@ -239,6 +249,10 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
239
249
|
sorted_counts = dict(
|
|
240
250
|
sorted(all_counts.items(), key=lambda item: -item[1])
|
|
241
251
|
)
|
|
252
|
+
|
|
253
|
+
if limit:
|
|
254
|
+
sorted_counts = dict(list(sorted_counts.items())[:limit_number])
|
|
255
|
+
|
|
242
256
|
if not sorted_counts:
|
|
243
257
|
result_dict = "A Strain type could not be detected because of no kmer matches!"
|
|
244
258
|
highest_results[scheme_path_list[counter]] = {"N/A": 0}
|
|
@@ -250,25 +264,37 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
250
264
|
first_key: highest_result
|
|
251
265
|
}
|
|
252
266
|
counter += 1
|
|
253
|
-
else:
|
|
267
|
+
else: # No split procedure is needed, when the sequence is short
|
|
254
268
|
for index in self.indices:
|
|
255
|
-
res = index.search(
|
|
269
|
+
res = index.search( # COBS can't handle Seq-Objects
|
|
256
270
|
str(sequence), step=step
|
|
257
|
-
) # COBS can't handle Seq-Objects
|
|
258
|
-
result_dict[scheme_path_list[counter]] = self.get_cobs_result(
|
|
259
|
-
res, False
|
|
260
271
|
)
|
|
261
|
-
|
|
262
|
-
|
|
272
|
+
result = self.get_cobs_result(res, False)
|
|
273
|
+
result = (
|
|
274
|
+
dict(sorted(result.items(), key=lambda x: -x[1])[:limit_number])
|
|
275
|
+
if limit
|
|
276
|
+
else result
|
|
263
277
|
)
|
|
278
|
+
result_dict[scheme_path_list[counter]] = result
|
|
279
|
+
first_key, highest_result = next(iter(result.items()))
|
|
264
280
|
highest_results[scheme_path_list[counter]] = {first_key: highest_result}
|
|
265
281
|
counter += 1
|
|
282
|
+
|
|
266
283
|
# check if the strain type has sufficient amount of kmer hits
|
|
267
284
|
is_valid = self.has_sufficient_score(highest_results, self.avg_locus_bp_size)
|
|
268
285
|
if not is_valid:
|
|
269
286
|
highest_results["Attention:"] = (
|
|
270
287
|
"This strain type is not reliable due to low kmer hit rates!"
|
|
271
288
|
)
|
|
289
|
+
else:
|
|
290
|
+
handler = PubMLSTHandler()
|
|
291
|
+
# allele_id is of type dict
|
|
292
|
+
flattened = {
|
|
293
|
+
locus: int(list(allele_id.keys())[0].split("_")[-1])
|
|
294
|
+
for locus, allele_id in highest_results.items()
|
|
295
|
+
}
|
|
296
|
+
strain_type_name = handler.get_strain_type_name(flattened, self.scheme_url)
|
|
297
|
+
highest_results["ST_Name"] = strain_type_name
|
|
272
298
|
return [{"Strain type": highest_results}, {"All results": result_dict}]
|
|
273
299
|
|
|
274
300
|
def predict(
|
|
@@ -282,6 +308,7 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
282
308
|
| Path
|
|
283
309
|
),
|
|
284
310
|
step: int = 1,
|
|
311
|
+
limit: bool = False,
|
|
285
312
|
) -> MlstResult:
|
|
286
313
|
"""
|
|
287
314
|
Get scores for the sequence(s) based on the filters in the model.
|
|
@@ -290,6 +317,7 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
290
317
|
cobs_path (Path): The path of the COBS-structure directory.
|
|
291
318
|
sequence_input (Seq): The input sequence for classification
|
|
292
319
|
step (int, optional): The amount of kmers that are passed; defaults to one
|
|
320
|
+
limit (bool, optional): Applying a filter that limits the best result.
|
|
293
321
|
|
|
294
322
|
Returns:
|
|
295
323
|
MlstResult: The results of the prediction.
|
|
@@ -301,13 +329,19 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
301
329
|
if sequence_input.id == "<unknown id>":
|
|
302
330
|
sequence_input.id = "test"
|
|
303
331
|
hits = {
|
|
304
|
-
sequence_input.id: self.calculate_hits(
|
|
332
|
+
sequence_input.id: self.calculate_hits(
|
|
333
|
+
cobs_path, sequence_input.seq, step, limit
|
|
334
|
+
)
|
|
305
335
|
}
|
|
306
|
-
return MlstResult(self.
|
|
336
|
+
return MlstResult(self.model_name, step, hits, None)
|
|
307
337
|
|
|
308
338
|
if isinstance(sequence_input, Path):
|
|
309
339
|
return ProbabilisticFilterMlstSchemeModel.predict(
|
|
310
|
-
self,
|
|
340
|
+
self,
|
|
341
|
+
cobs_path,
|
|
342
|
+
get_record_iterator(sequence_input),
|
|
343
|
+
step=step,
|
|
344
|
+
limit=limit,
|
|
311
345
|
)
|
|
312
346
|
|
|
313
347
|
if isinstance(
|
|
@@ -317,33 +351,35 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
317
351
|
hits = {}
|
|
318
352
|
# individual_seq is a SeqRecord-Object
|
|
319
353
|
for individual_seq in sequence_input:
|
|
320
|
-
individual_hits = self.calculate_hits(
|
|
354
|
+
individual_hits = self.calculate_hits(
|
|
355
|
+
cobs_path, individual_seq.seq, step, limit
|
|
356
|
+
)
|
|
321
357
|
hits[individual_seq.id] = individual_hits
|
|
322
|
-
return MlstResult(self.
|
|
323
|
-
|
|
358
|
+
return MlstResult(self.model_name, step, hits, None)
|
|
324
359
|
raise ValueError(
|
|
325
360
|
"Invalid sequence input, must be a Seq object, a list of Seq objects, a"
|
|
326
361
|
" SeqIO FastaIterator, or a SeqIO FastqPhredIterator"
|
|
327
362
|
)
|
|
328
363
|
|
|
329
364
|
def get_cobs_result(
    self,
    cobs_result: "cobs_index.SearchResult",
    kmer_threshold: bool,
    min_score: int = 50,
) -> dict:
    """
    Get every entry in a COBS search result.

    Args:
        cobs_result (SearchResult): The result of the prediction.
        kmer_threshold (bool): Applying a kmer threshold to mitigate false positives.
        min_score (int, optional): Score an entry has to exceed to be kept when
            kmer_threshold is set; defaults to 50. Ignored otherwise.

    Returns:
        dict: A dictionary storing the allele id of locus as key and the score as value.
    """
    # One pass over the search result; entries keep the order COBS returned them in.
    return {
        hit.doc_name: hit.score
        for hit in cobs_result
        if not kmer_threshold or hit.score > min_score
    }
|
|
347
383
|
|
|
348
384
|
def sequence_splitter(self, input_sequence: str, allele_len: int) -> list[str]:
|
|
349
385
|
"""
|
|
@@ -379,13 +415,15 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
379
415
|
|
|
380
416
|
while start + substring_length <= sequence_len:
|
|
381
417
|
substring_list.append(input_sequence[start : start + substring_length])
|
|
382
|
-
start +=
|
|
418
|
+
start += (
|
|
419
|
+
substring_length - self.k_value + 1
|
|
420
|
+
) # To not lose kmers when dividing
|
|
383
421
|
|
|
384
422
|
# The remaining string is either appended to the list or added to the last entry.
|
|
385
423
|
if start < len(input_sequence):
|
|
386
424
|
remaining_substring = input_sequence[start:]
|
|
387
425
|
# A substring needs to be at least of size k for COBS.
|
|
388
|
-
if len(remaining_substring) < self.
|
|
426
|
+
if len(remaining_substring) < self.k_value:
|
|
389
427
|
substring_list[-1] += remaining_substring
|
|
390
428
|
else:
|
|
391
429
|
substring_list.append(remaining_substring)
|