XspecT 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of XspecT might be problematic. Click here for more details.
- xspect/classify.py +51 -38
- xspect/definitions.py +50 -10
- xspect/download_models.py +10 -2
- xspect/file_io.py +115 -48
- xspect/filter_sequences.py +36 -66
- xspect/main.py +44 -11
- xspect/mlst_feature/mlst_helper.py +3 -0
- xspect/mlst_feature/pub_mlst_handler.py +43 -1
- xspect/model_management.py +84 -14
- xspect/models/probabilistic_filter_mlst_model.py +75 -37
- xspect/models/probabilistic_filter_model.py +194 -12
- xspect/models/probabilistic_filter_svm_model.py +99 -6
- xspect/models/probabilistic_single_filter_model.py +66 -5
- xspect/models/result.py +77 -10
- xspect/ncbi.py +48 -12
- xspect/train.py +2 -1
- xspect/web.py +71 -13
- xspect/xspect-web/dist/assets/index-Ceo58xui.css +1 -0
- xspect/xspect-web/dist/assets/{index-CMG4V7fZ.js → index-Dt_UlbgE.js} +82 -77
- xspect/xspect-web/dist/index.html +2 -2
- xspect/xspect-web/src/App.tsx +4 -2
- xspect/xspect-web/src/api.tsx +23 -1
- xspect/xspect-web/src/components/filter-form.tsx +16 -3
- xspect/xspect-web/src/components/filtering-result.tsx +65 -0
- xspect/xspect-web/src/components/result.tsx +2 -2
- xspect/xspect-web/src/types.tsx +5 -0
- {xspect-0.5.1.dist-info → xspect-0.5.3.dist-info}/METADATA +1 -1
- {xspect-0.5.1.dist-info → xspect-0.5.3.dist-info}/RECORD +32 -31
- {xspect-0.5.1.dist-info → xspect-0.5.3.dist-info}/WHEEL +1 -1
- xspect/xspect-web/dist/assets/index-jIKg1HIy.css +0 -1
- {xspect-0.5.1.dist-info → xspect-0.5.3.dist-info}/entry_points.txt +0 -0
- {xspect-0.5.1.dist-info → xspect-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {xspect-0.5.1.dist-info → xspect-0.5.3.dist-info}/top_level.txt +0 -0
xspect/main.py
CHANGED
|
@@ -18,10 +18,8 @@ from xspect.models.probabilistic_filter_mlst_model import (
|
|
|
18
18
|
ProbabilisticFilterMlstSchemeModel,
|
|
19
19
|
)
|
|
20
20
|
from xspect.model_management import (
|
|
21
|
-
get_genus_model,
|
|
22
21
|
get_model_metadata,
|
|
23
22
|
get_models,
|
|
24
|
-
get_species_model,
|
|
25
23
|
)
|
|
26
24
|
|
|
27
25
|
|
|
@@ -51,7 +49,9 @@ def models():
|
|
|
51
49
|
def download():
|
|
52
50
|
"""Download models."""
|
|
53
51
|
click.echo("Downloading models, this may take a while...")
|
|
54
|
-
download_test_models(
|
|
52
|
+
download_test_models(
|
|
53
|
+
"https://assets.adrianromberg.com/science/xspect-models-07-08-2025.zip"
|
|
54
|
+
)
|
|
55
55
|
|
|
56
56
|
|
|
57
57
|
@models.command(
|
|
@@ -173,8 +173,9 @@ def train_mlst(choose_schemes):
|
|
|
173
173
|
scheme_path = pick_scheme(handler.get_scheme_paths())
|
|
174
174
|
species_name = str(scheme_path).split("/")[-2]
|
|
175
175
|
scheme_name = str(scheme_path).split("/")[-1]
|
|
176
|
+
scheme_url = handler.scheme_mapping[str(scheme_path)]
|
|
176
177
|
model = ProbabilisticFilterMlstSchemeModel(
|
|
177
|
-
31, f"{species_name}:{scheme_name}", get_xspect_model_path()
|
|
178
|
+
31, f"{species_name}:{scheme_name}", get_xspect_model_path(), scheme_url
|
|
178
179
|
)
|
|
179
180
|
click.echo("Creating mlst model")
|
|
180
181
|
model.fit(scheme_path)
|
|
@@ -220,10 +221,18 @@ def classify_seqs():
|
|
|
220
221
|
type=click.Path(dir_okay=False, file_okay=True),
|
|
221
222
|
default=Path(".") / f"result_{uuid4()}.json",
|
|
222
223
|
)
|
|
223
|
-
|
|
224
|
+
@click.option(
|
|
225
|
+
"--sparse-sampling-step",
|
|
226
|
+
type=int,
|
|
227
|
+
help="Sparse sampling step (e. g. only every 500th kmer for '--sparse-sampling-step 500').",
|
|
228
|
+
default=1,
|
|
229
|
+
)
|
|
230
|
+
def classify_genus(model_genus, input_path, output_path, sparse_sampling_step):
|
|
224
231
|
"""Classify samples using a genus model."""
|
|
225
232
|
click.echo("Classifying...")
|
|
226
|
-
classify.classify_genus(
|
|
233
|
+
classify.classify_genus(
|
|
234
|
+
model_genus, Path(input_path), Path(output_path), sparse_sampling_step
|
|
235
|
+
)
|
|
227
236
|
|
|
228
237
|
|
|
229
238
|
@classify_seqs.command(
|
|
@@ -275,20 +284,24 @@ def classify_species(model_genus, input_path, output_path, sparse_sampling_step)
|
|
|
275
284
|
"-i",
|
|
276
285
|
"--input-path",
|
|
277
286
|
help="Path to FASTA-file for mlst identification.",
|
|
278
|
-
type=click.Path(exists=True, dir_okay=
|
|
287
|
+
type=click.Path(exists=True, dir_okay=True, file_okay=True),
|
|
279
288
|
prompt=True,
|
|
289
|
+
default=Path("."),
|
|
280
290
|
)
|
|
281
291
|
@click.option(
|
|
282
292
|
"-o",
|
|
283
293
|
"--output-path",
|
|
284
294
|
help="Path to the output file.",
|
|
285
295
|
type=click.Path(dir_okay=False, file_okay=True),
|
|
296
|
+
default=Path(".") / f"MLST_result_{uuid4()}.json",
|
|
286
297
|
)
|
|
287
|
-
|
|
298
|
+
@click.option(
|
|
299
|
+
"-l", "--limit", is_flag=True, help="Limit the output to 5 results for each locus."
|
|
300
|
+
)
|
|
301
|
+
def classify_mlst(input_path, output_path, limit):
|
|
288
302
|
"""MLST classify a sample."""
|
|
289
303
|
click.echo("Classifying...")
|
|
290
|
-
classify.classify_mlst(Path(input_path), Path(output_path))
|
|
291
|
-
click.echo(f"Result saved as {output_path}.")
|
|
304
|
+
classify.classify_mlst(Path(input_path), Path(output_path), limit)
|
|
292
305
|
|
|
293
306
|
|
|
294
307
|
# # # # # # # # # # # # # # #
|
|
@@ -343,8 +356,19 @@ def filter_seqs():
|
|
|
343
356
|
default=0.7,
|
|
344
357
|
prompt=True,
|
|
345
358
|
)
|
|
359
|
+
@click.option(
|
|
360
|
+
"--sparse-sampling-step",
|
|
361
|
+
type=int,
|
|
362
|
+
help="Sparse sampling step (e. g. only every 500th kmer for '--sparse-sampling-step 500').",
|
|
363
|
+
default=1,
|
|
364
|
+
)
|
|
346
365
|
def filter_genus(
|
|
347
|
-
model_genus,
|
|
366
|
+
model_genus,
|
|
367
|
+
input_path,
|
|
368
|
+
output_path,
|
|
369
|
+
classification_output_path,
|
|
370
|
+
threshold,
|
|
371
|
+
sparse_sampling_step,
|
|
348
372
|
):
|
|
349
373
|
"""Filter samples using a genus model."""
|
|
350
374
|
click.echo("Filtering...")
|
|
@@ -355,6 +379,7 @@ def filter_genus(
|
|
|
355
379
|
Path(output_path),
|
|
356
380
|
threshold,
|
|
357
381
|
Path(classification_output_path) if classification_output_path else None,
|
|
382
|
+
sparse_sampling_step=sparse_sampling_step,
|
|
358
383
|
)
|
|
359
384
|
|
|
360
385
|
|
|
@@ -405,6 +430,12 @@ def filter_genus(
|
|
|
405
430
|
default=0.7,
|
|
406
431
|
prompt=True,
|
|
407
432
|
)
|
|
433
|
+
@click.option(
|
|
434
|
+
"--sparse-sampling-step",
|
|
435
|
+
type=int,
|
|
436
|
+
help="Sparse sampling step (e. g. only every 500th kmer for '--sparse-sampling-step 500').",
|
|
437
|
+
default=1,
|
|
438
|
+
)
|
|
408
439
|
def filter_species(
|
|
409
440
|
model_genus,
|
|
410
441
|
model_species,
|
|
@@ -412,6 +443,7 @@ def filter_species(
|
|
|
412
443
|
output_path,
|
|
413
444
|
threshold,
|
|
414
445
|
classification_output_path,
|
|
446
|
+
sparse_sampling_step,
|
|
415
447
|
):
|
|
416
448
|
"""Filter a sample using the species model."""
|
|
417
449
|
|
|
@@ -451,6 +483,7 @@ def filter_species(
|
|
|
451
483
|
Path(output_path),
|
|
452
484
|
threshold,
|
|
453
485
|
Path(classification_output_path) if classification_output_path else None,
|
|
486
|
+
sparse_sampling_step=sparse_sampling_step,
|
|
454
487
|
)
|
|
455
488
|
|
|
456
489
|
|
|
@@ -194,11 +194,13 @@ class MlstResult:
|
|
|
194
194
|
scheme_model: str,
|
|
195
195
|
steps: int,
|
|
196
196
|
hits: dict[str, list[dict]],
|
|
197
|
+
input_source: str = None,
|
|
197
198
|
):
|
|
198
199
|
"""Initialise an MlstResult object."""
|
|
199
200
|
self.scheme_model = scheme_model
|
|
200
201
|
self.steps = steps
|
|
201
202
|
self.hits = hits
|
|
203
|
+
self.input_source = input_source
|
|
202
204
|
|
|
203
205
|
def get_results(self) -> dict:
|
|
204
206
|
"""
|
|
@@ -221,6 +223,7 @@ class MlstResult:
|
|
|
221
223
|
"Scheme": self.scheme_model,
|
|
222
224
|
"Steps": self.steps,
|
|
223
225
|
"Results": self.get_results(),
|
|
226
|
+
"Input_source": self.input_source,
|
|
224
227
|
}
|
|
225
228
|
return result
|
|
226
229
|
|
|
@@ -17,7 +17,7 @@ from xspect.definitions import get_xspect_mlst_path, get_xspect_upload_path
|
|
|
17
17
|
class PubMLSTHandler:
|
|
18
18
|
"""Class for communicating with PubMLST and downloading alleles (FASTA-Format) from all loci."""
|
|
19
19
|
|
|
20
|
-
base_url = "
|
|
20
|
+
base_url = "https://rest.pubmlst.org/db"
|
|
21
21
|
|
|
22
22
|
def __init__(self):
|
|
23
23
|
"""Initialise a PubMLSTHandler object."""
|
|
@@ -27,6 +27,7 @@ class PubMLSTHandler:
|
|
|
27
27
|
self.base_url + "/pubmlst_abaumannii_seqdef/schemes/2",
|
|
28
28
|
]
|
|
29
29
|
self.scheme_paths = []
|
|
30
|
+
self.scheme_mapping = {}
|
|
30
31
|
|
|
31
32
|
def get_scheme_paths(self) -> dict:
|
|
32
33
|
"""
|
|
@@ -103,6 +104,7 @@ class PubMLSTHandler:
|
|
|
103
104
|
|
|
104
105
|
species_name = scheme.split("_")[1] # name = pubmlst_abaumannii_seqdef
|
|
105
106
|
scheme_path = get_xspect_mlst_path() / species_name / scheme_name
|
|
107
|
+
self.scheme_mapping[str(scheme_path)] = scheme
|
|
106
108
|
self.scheme_paths.append(scheme_path)
|
|
107
109
|
|
|
108
110
|
for locus_url in locus_list:
|
|
@@ -143,3 +145,43 @@ class PubMLSTHandler:
|
|
|
143
145
|
# Example: 'Pas_fusA': [{'href': some URL, 'allele_id': '2'}]
|
|
144
146
|
print(locus + ":" + meta_data[0]["allele_id"], end="; ")
|
|
145
147
|
print("\nStrain Type:", response["fields"])
|
|
148
|
+
|
|
149
|
+
def get_strain_type_name(self, highest_results: dict, post_url: str) -> str:
|
|
150
|
+
"""
|
|
151
|
+
Send an API-POST request to PubMLST with the highest result of each locus as payload.
|
|
152
|
+
|
|
153
|
+
This function formats the highest_result dict into an accepted input for the request.
|
|
154
|
+
It gets a response from the site which is the strain type name.
|
|
155
|
+
The name is based on the allele id with the highest score for each locus.
|
|
156
|
+
Example of post_url for the oxford scheme of A.baumannii:
|
|
157
|
+
https://rest.pubmlst.org/db/pubmlst_abaumannii_seqdef/schemes/1/designations
|
|
158
|
+
|
|
159
|
+
Args:
|
|
160
|
+
highest_results (dict): The allele ids with the highest kmer matches.
|
|
161
|
+
post_url (str): The specific url for the scheme of a species
|
|
162
|
+
|
|
163
|
+
Returns:
|
|
164
|
+
str: The response (ST name or No ST found) of the POST request.
|
|
165
|
+
"""
|
|
166
|
+
payload = {
|
|
167
|
+
"designations": {
|
|
168
|
+
locus: [{"allele": str(allele)}]
|
|
169
|
+
for locus, allele in highest_results.items()
|
|
170
|
+
}
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
response = requests.post(post_url + "/designations", json=payload)
|
|
174
|
+
|
|
175
|
+
if response.status_code == 200:
|
|
176
|
+
data = response.json()
|
|
177
|
+
if "fields" in data:
|
|
178
|
+
post_response = data["fields"]
|
|
179
|
+
return post_response
|
|
180
|
+
else:
|
|
181
|
+
post_response = "No matching Strain Type found in the database. "
|
|
182
|
+
post_response += "Possibly a novel Strain Type."
|
|
183
|
+
return post_response
|
|
184
|
+
else:
|
|
185
|
+
post_response = "Error:" + str(response.status_code)
|
|
186
|
+
post_response += response.text
|
|
187
|
+
return post_response
|
xspect/model_management.py
CHANGED
|
@@ -9,22 +9,55 @@ from xspect.models.probabilistic_filter_svm_model import ProbabilisticFilterSVMM
|
|
|
9
9
|
from xspect.definitions import get_xspect_model_path
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
def get_genus_model(genus):
|
|
13
|
-
"""
|
|
12
|
+
def get_genus_model(genus) -> ProbabilisticSingleFilterModel:
|
|
13
|
+
"""
|
|
14
|
+
Get a genus model for the specified genus.
|
|
15
|
+
|
|
16
|
+
This function retrieves a pre-trained genus classification model based on the provided genus name.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
genus (str): The genus name for which the model is to be retrieved.
|
|
20
|
+
|
|
21
|
+
Returns:
|
|
22
|
+
ProbabilisticSingleFilterModel: An instance of the genus classification model.
|
|
23
|
+
"""
|
|
14
24
|
genus_model_path = get_xspect_model_path() / (genus.lower() + "-genus.json")
|
|
15
25
|
genus_filter_model = ProbabilisticSingleFilterModel.load(genus_model_path)
|
|
16
26
|
return genus_filter_model
|
|
17
27
|
|
|
18
28
|
|
|
19
|
-
def get_species_model(genus):
|
|
20
|
-
"""
|
|
29
|
+
def get_species_model(genus) -> ProbabilisticFilterSVMModel:
|
|
30
|
+
"""
|
|
31
|
+
Get a species classification model for the specified genus.
|
|
32
|
+
|
|
33
|
+
This function retrieves a pre-trained species classification model based on the provided genus name.
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
genus (str): The genus name for which the species model is to be retrieved.
|
|
37
|
+
|
|
38
|
+
Returns:
|
|
39
|
+
ProbabilisticFilterSVMModel: An instance of the species classification model.
|
|
40
|
+
"""
|
|
21
41
|
species_model_path = get_xspect_model_path() / (genus.lower() + "-species.json")
|
|
22
42
|
species_filter_model = ProbabilisticFilterSVMModel.load(species_model_path)
|
|
23
43
|
return species_filter_model
|
|
24
44
|
|
|
25
45
|
|
|
26
|
-
def get_model_metadata(model: str | Path):
|
|
27
|
-
"""
|
|
46
|
+
def get_model_metadata(model: str | Path) -> dict:
|
|
47
|
+
"""
|
|
48
|
+
Get metadata of a specified model.
|
|
49
|
+
|
|
50
|
+
This function retrieves the metadata of a model from its JSON file.
|
|
51
|
+
|
|
52
|
+
Args:
|
|
53
|
+
model (str | Path): The slug of the model (as a string) or the path to the model JSON file.
|
|
54
|
+
|
|
55
|
+
Returns:
|
|
56
|
+
dict: A dictionary containing the model metadata.
|
|
57
|
+
|
|
58
|
+
Raises:
|
|
59
|
+
ValueError: If the model does not exist or is not a valid file.
|
|
60
|
+
"""
|
|
28
61
|
if isinstance(model, str):
|
|
29
62
|
model_path = get_xspect_model_path() / (model.lower() + ".json")
|
|
30
63
|
elif isinstance(model, Path):
|
|
@@ -40,8 +73,17 @@ def get_model_metadata(model: str | Path):
|
|
|
40
73
|
return model_json
|
|
41
74
|
|
|
42
75
|
|
|
43
|
-
def update_model_metadata(model_slug: str, author: str, author_email: str):
|
|
44
|
-
"""
|
|
76
|
+
def update_model_metadata(model_slug: str, author: str, author_email: str) -> None:
|
|
77
|
+
"""
|
|
78
|
+
Update the metadata of a model.
|
|
79
|
+
|
|
80
|
+
This function updates the author and author email in the model's metadata JSON file.
|
|
81
|
+
|
|
82
|
+
Args:
|
|
83
|
+
model_slug (str): The slug of the model to update.
|
|
84
|
+
author (str): The name of the author to set in the metadata.
|
|
85
|
+
author_email (str): The email of the author to set in the metadata.
|
|
86
|
+
"""
|
|
45
87
|
model_metadata = get_model_metadata(model_slug)
|
|
46
88
|
model_metadata["author"] = author
|
|
47
89
|
model_metadata["author_email"] = author_email
|
|
@@ -51,8 +93,19 @@ def update_model_metadata(model_slug: str, author: str, author_email: str):
|
|
|
51
93
|
file.write(dumps(model_metadata, indent=4))
|
|
52
94
|
|
|
53
95
|
|
|
54
|
-
def update_model_display_name(
|
|
55
|
-
|
|
96
|
+
def update_model_display_name(
|
|
97
|
+
model_slug: str, filter_id: str, display_name: str
|
|
98
|
+
) -> None:
|
|
99
|
+
"""
|
|
100
|
+
Update the display name of a filter in a model.
|
|
101
|
+
|
|
102
|
+
This function updates the display name of a specific filter in the model's metadata JSON file.
|
|
103
|
+
|
|
104
|
+
Args:
|
|
105
|
+
model_slug (str): The slug of the model to update.
|
|
106
|
+
filter_id (str): The ID of the filter whose display name is to be updated.
|
|
107
|
+
display_name (str): The new display name for the filter.
|
|
108
|
+
"""
|
|
56
109
|
model_metadata = get_model_metadata(model_slug)
|
|
57
110
|
model_metadata["display_names"][filter_id] = display_name
|
|
58
111
|
|
|
@@ -61,8 +114,15 @@ def update_model_display_name(model_slug: str, filter_id: str, display_name: str
|
|
|
61
114
|
file.write(dumps(model_metadata, indent=4))
|
|
62
115
|
|
|
63
116
|
|
|
64
|
-
def get_models():
|
|
65
|
-
"""
|
|
117
|
+
def get_models() -> dict[str, list[dict]]:
|
|
118
|
+
"""
|
|
119
|
+
Get a list of all available models in a dictionary by type.
|
|
120
|
+
|
|
121
|
+
This function scans the model directory for JSON files and organizes them by their model type.
|
|
122
|
+
|
|
123
|
+
Returns:
|
|
124
|
+
dict[str, list[dict]]: A dictionary where keys are model types and values are lists of model display names.
|
|
125
|
+
"""
|
|
66
126
|
model_dict = {}
|
|
67
127
|
for model_file in get_xspect_model_path().glob("*.json"):
|
|
68
128
|
model_metadata = get_model_metadata(model_file)
|
|
@@ -73,7 +133,17 @@ def get_models():
|
|
|
73
133
|
return model_dict
|
|
74
134
|
|
|
75
135
|
|
|
76
|
-
def get_model_display_names(model_slug: str):
|
|
77
|
-
"""
|
|
136
|
+
def get_model_display_names(model_slug: str) -> list[str]:
|
|
137
|
+
"""
|
|
138
|
+
Get the display names included in a model.
|
|
139
|
+
|
|
140
|
+
This function retrieves the display names of individual filters from the model's metadata.
|
|
141
|
+
|
|
142
|
+
Args:
|
|
143
|
+
model_slug (str): The slug of the model for which to retrieve display names.
|
|
144
|
+
|
|
145
|
+
Returns:
|
|
146
|
+
list[str]: A list of display names for the individual filters in the model.
|
|
147
|
+
"""
|
|
78
148
|
model_metadata = get_model_metadata(model_slug)
|
|
79
149
|
return list(model_metadata["display_names"].values())
|
|
@@ -12,6 +12,7 @@ from cobs_index import DocumentList
|
|
|
12
12
|
from collections import defaultdict
|
|
13
13
|
from xspect.file_io import get_record_iterator
|
|
14
14
|
from xspect.mlst_feature.mlst_helper import MlstResult
|
|
15
|
+
from xspect.mlst_feature.pub_mlst_handler import PubMLSTHandler
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
class ProbabilisticFilterMlstSchemeModel:
|
|
@@ -19,20 +20,22 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
19
20
|
|
|
20
21
|
def __init__(
|
|
21
22
|
self,
|
|
22
|
-
|
|
23
|
-
|
|
23
|
+
k_value: int,
|
|
24
|
+
model_name: str,
|
|
24
25
|
base_path: Path,
|
|
26
|
+
scheme_url: str,
|
|
25
27
|
fpr: float = 0.001,
|
|
26
28
|
) -> None:
|
|
27
29
|
"""Initialise a ProbabilisticFilterMlstSchemeModel object."""
|
|
28
|
-
if
|
|
30
|
+
if k_value < 1:
|
|
29
31
|
raise ValueError("Invalid k value, must be greater than 0")
|
|
30
32
|
if not isinstance(base_path, Path):
|
|
31
33
|
raise ValueError("Invalid base path, must be a pathlib.Path object")
|
|
32
34
|
|
|
33
|
-
self.
|
|
34
|
-
self.
|
|
35
|
+
self.k_value = k_value
|
|
36
|
+
self.model_name = model_name
|
|
35
37
|
self.base_path = base_path / "MLST"
|
|
38
|
+
self.scheme_url = scheme_url
|
|
36
39
|
self.fpr = fpr
|
|
37
40
|
self.model_type = "Strain"
|
|
38
41
|
self.loci = {}
|
|
@@ -49,9 +52,10 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
49
52
|
dict: The dictionary containing all metadata of an object.
|
|
50
53
|
"""
|
|
51
54
|
return {
|
|
52
|
-
"
|
|
53
|
-
"
|
|
55
|
+
"k_value": self.k_value,
|
|
56
|
+
"model_name": self.model_name,
|
|
54
57
|
"model_type": self.model_type,
|
|
58
|
+
"scheme_url": str(self.scheme_url),
|
|
55
59
|
"fpr": self.fpr,
|
|
56
60
|
"scheme_path": str(self.scheme_path),
|
|
57
61
|
"cobs_path": str(self.cobs_path),
|
|
@@ -115,7 +119,7 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
115
119
|
# COBS only accepts strings as paths
|
|
116
120
|
doclist = DocumentList(str(locus_path))
|
|
117
121
|
index_params = cobs_index.CompactIndexParameters()
|
|
118
|
-
index_params.term_size = self.
|
|
122
|
+
index_params.term_size = self.k_value # k-mer size
|
|
119
123
|
index_params.clobber = True # overwrite output and temporary files
|
|
120
124
|
index_params.false_positive_rate = self.fpr
|
|
121
125
|
|
|
@@ -130,9 +134,7 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
130
134
|
|
|
131
135
|
def save(self) -> None:
|
|
132
136
|
"""Saves the model to disk"""
|
|
133
|
-
scheme = str(self.scheme_path).split("/")[
|
|
134
|
-
-1
|
|
135
|
-
] # [-1] -> contains the scheme name
|
|
137
|
+
scheme = str(self.scheme_path).split("/")[-1] # [-1] contains the scheme name
|
|
136
138
|
json_path = self.base_path / scheme / f"{scheme}.json"
|
|
137
139
|
json_object = json.dumps(self.to_dict(), indent=4)
|
|
138
140
|
|
|
@@ -156,9 +158,10 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
156
158
|
json_object = file.read()
|
|
157
159
|
model_json = json.loads(json_object)
|
|
158
160
|
model = ProbabilisticFilterMlstSchemeModel(
|
|
159
|
-
model_json["
|
|
160
|
-
model_json["
|
|
161
|
+
model_json["k_value"],
|
|
162
|
+
model_json["model_name"],
|
|
161
163
|
json_path.parent,
|
|
164
|
+
model_json["scheme_url"],
|
|
162
165
|
model_json["fpr"],
|
|
163
166
|
)
|
|
164
167
|
model.scheme_path = model_json["scheme_path"]
|
|
@@ -175,7 +178,12 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
175
178
|
return model
|
|
176
179
|
|
|
177
180
|
def calculate_hits(
|
|
178
|
-
self,
|
|
181
|
+
self,
|
|
182
|
+
cobs_path: Path,
|
|
183
|
+
sequence: Seq,
|
|
184
|
+
step: int = 1,
|
|
185
|
+
limit: bool = False,
|
|
186
|
+
limit_number: int = 5,
|
|
179
187
|
) -> list[dict]:
|
|
180
188
|
"""
|
|
181
189
|
Calculates the hits for a sequence.
|
|
@@ -189,6 +197,8 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
189
197
|
cobs_path (Path): The path of the COBS-structure directory.
|
|
190
198
|
sequence (Seq): The input sequence for classification.
|
|
191
199
|
step (int, optional): The amount of kmers that are passed; defaults to one.
|
|
200
|
+
limit (bool): Applying a filter that limits the best result.
|
|
201
|
+
limit_number (int): The amount of results when the filter is set to true.
|
|
192
202
|
|
|
193
203
|
Returns:
|
|
194
204
|
list[dict]: The results of the prediction.
|
|
@@ -201,7 +211,7 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
201
211
|
if not isinstance(sequence, Seq):
|
|
202
212
|
raise ValueError("Invalid sequence, must be a Bio.Seq object")
|
|
203
213
|
|
|
204
|
-
if not len(sequence) > self.
|
|
214
|
+
if not len(sequence) > self.k_value:
|
|
205
215
|
raise ValueError("Invalid sequence, must be longer than k")
|
|
206
216
|
|
|
207
217
|
if not self.indices:
|
|
@@ -239,6 +249,10 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
239
249
|
sorted_counts = dict(
|
|
240
250
|
sorted(all_counts.items(), key=lambda item: -item[1])
|
|
241
251
|
)
|
|
252
|
+
|
|
253
|
+
if limit:
|
|
254
|
+
sorted_counts = dict(list(sorted_counts.items())[:limit_number])
|
|
255
|
+
|
|
242
256
|
if not sorted_counts:
|
|
243
257
|
result_dict = "A Strain type could not be detected because of no kmer matches!"
|
|
244
258
|
highest_results[scheme_path_list[counter]] = {"N/A": 0}
|
|
@@ -250,25 +264,37 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
250
264
|
first_key: highest_result
|
|
251
265
|
}
|
|
252
266
|
counter += 1
|
|
253
|
-
else:
|
|
267
|
+
else: # No split procedure is needed, when the sequence is short
|
|
254
268
|
for index in self.indices:
|
|
255
|
-
res = index.search(
|
|
269
|
+
res = index.search( # COBS can't handle Seq-Objects
|
|
256
270
|
str(sequence), step=step
|
|
257
|
-
) # COBS can't handle Seq-Objects
|
|
258
|
-
result_dict[scheme_path_list[counter]] = self.get_cobs_result(
|
|
259
|
-
res, False
|
|
260
271
|
)
|
|
261
|
-
|
|
262
|
-
|
|
272
|
+
result = self.get_cobs_result(res, False)
|
|
273
|
+
result = (
|
|
274
|
+
dict(sorted(result.items(), key=lambda x: -x[1])[:limit_number])
|
|
275
|
+
if limit
|
|
276
|
+
else result
|
|
263
277
|
)
|
|
278
|
+
result_dict[scheme_path_list[counter]] = result
|
|
279
|
+
first_key, highest_result = next(iter(result.items()))
|
|
264
280
|
highest_results[scheme_path_list[counter]] = {first_key: highest_result}
|
|
265
281
|
counter += 1
|
|
282
|
+
|
|
266
283
|
# check if the strain type has sufficient amount of kmer hits
|
|
267
284
|
is_valid = self.has_sufficient_score(highest_results, self.avg_locus_bp_size)
|
|
268
285
|
if not is_valid:
|
|
269
286
|
highest_results["Attention:"] = (
|
|
270
287
|
"This strain type is not reliable due to low kmer hit rates!"
|
|
271
288
|
)
|
|
289
|
+
else:
|
|
290
|
+
handler = PubMLSTHandler()
|
|
291
|
+
# allele_id is of type dict
|
|
292
|
+
flattened = {
|
|
293
|
+
locus: int(list(allele_id.keys())[0].split("_")[-1])
|
|
294
|
+
for locus, allele_id in highest_results.items()
|
|
295
|
+
}
|
|
296
|
+
strain_type_name = handler.get_strain_type_name(flattened, self.scheme_url)
|
|
297
|
+
highest_results["ST_Name"] = strain_type_name
|
|
272
298
|
return [{"Strain type": highest_results}, {"All results": result_dict}]
|
|
273
299
|
|
|
274
300
|
def predict(
|
|
@@ -282,6 +308,7 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
282
308
|
| Path
|
|
283
309
|
),
|
|
284
310
|
step: int = 1,
|
|
311
|
+
limit: bool = False,
|
|
285
312
|
) -> MlstResult:
|
|
286
313
|
"""
|
|
287
314
|
Get scores for the sequence(s) based on the filters in the model.
|
|
@@ -290,6 +317,7 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
290
317
|
cobs_path (Path): The path of the COBS-structure directory.
|
|
291
318
|
sequence_input (Seq): The input sequence for classification
|
|
292
319
|
step (int, optional): The amount of kmers that are passed; defaults to one
|
|
320
|
+
limit (bool, optional): Applying a filter that limits the best result.
|
|
293
321
|
|
|
294
322
|
Returns:
|
|
295
323
|
MlstResult: The results of the prediction.
|
|
@@ -301,13 +329,19 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
301
329
|
if sequence_input.id == "<unknown id>":
|
|
302
330
|
sequence_input.id = "test"
|
|
303
331
|
hits = {
|
|
304
|
-
sequence_input.id: self.calculate_hits(
|
|
332
|
+
sequence_input.id: self.calculate_hits(
|
|
333
|
+
cobs_path, sequence_input.seq, step, limit
|
|
334
|
+
)
|
|
305
335
|
}
|
|
306
|
-
return MlstResult(self.
|
|
336
|
+
return MlstResult(self.model_name, step, hits, None)
|
|
307
337
|
|
|
308
338
|
if isinstance(sequence_input, Path):
|
|
309
339
|
return ProbabilisticFilterMlstSchemeModel.predict(
|
|
310
|
-
self,
|
|
340
|
+
self,
|
|
341
|
+
cobs_path,
|
|
342
|
+
get_record_iterator(sequence_input),
|
|
343
|
+
step=step,
|
|
344
|
+
limit=limit,
|
|
311
345
|
)
|
|
312
346
|
|
|
313
347
|
if isinstance(
|
|
@@ -317,33 +351,35 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
317
351
|
hits = {}
|
|
318
352
|
# individual_seq is a SeqRecord-Object
|
|
319
353
|
for individual_seq in sequence_input:
|
|
320
|
-
individual_hits = self.calculate_hits(
|
|
354
|
+
individual_hits = self.calculate_hits(
|
|
355
|
+
cobs_path, individual_seq.seq, step, limit
|
|
356
|
+
)
|
|
321
357
|
hits[individual_seq.id] = individual_hits
|
|
322
|
-
return MlstResult(self.
|
|
323
|
-
|
|
358
|
+
return MlstResult(self.model_name, step, hits, None)
|
|
324
359
|
raise ValueError(
|
|
325
360
|
"Invalid sequence input, must be a Seq object, a list of Seq objects, a"
|
|
326
361
|
" SeqIO FastaIterator, or a SeqIO FastqPhredIterator"
|
|
327
362
|
)
|
|
328
363
|
|
|
329
364
|
def get_cobs_result(
|
|
330
|
-
self,
|
|
365
|
+
self,
|
|
366
|
+
cobs_result: cobs_index.SearchResult,
|
|
367
|
+
kmer_threshold: bool,
|
|
331
368
|
) -> dict:
|
|
332
369
|
"""
|
|
333
370
|
Get every entry in a COBS search result.
|
|
334
371
|
|
|
335
372
|
Args:
|
|
336
373
|
cobs_result (SearchResult): The result of the prediction.
|
|
337
|
-
kmer_threshold (bool): Applying a kmer threshold to mitigate false positives
|
|
374
|
+
kmer_threshold (bool): Applying a kmer threshold to mitigate false positives.
|
|
338
375
|
|
|
339
376
|
Returns:
|
|
340
377
|
dict: A dictionary storing the allele id of locus as key and the score as value.
|
|
341
378
|
"""
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
}
|
|
379
|
+
hits = [
|
|
380
|
+
result for result in cobs_result if not kmer_threshold or result.score > 50
|
|
381
|
+
]
|
|
382
|
+
return {result.doc_name: result.score for result in hits}
|
|
347
383
|
|
|
348
384
|
def sequence_splitter(self, input_sequence: str, allele_len: int) -> list[str]:
|
|
349
385
|
"""
|
|
@@ -379,13 +415,15 @@ class ProbabilisticFilterMlstSchemeModel:
|
|
|
379
415
|
|
|
380
416
|
while start + substring_length <= sequence_len:
|
|
381
417
|
substring_list.append(input_sequence[start : start + substring_length])
|
|
382
|
-
start +=
|
|
418
|
+
start += (
|
|
419
|
+
substring_length - self.k_value + 1
|
|
420
|
+
) # To not lose kmers when dividing
|
|
383
421
|
|
|
384
422
|
# The remaining string is either appended to the list or added to the last entry.
|
|
385
423
|
if start < len(input_sequence):
|
|
386
424
|
remaining_substring = input_sequence[start:]
|
|
387
425
|
# A substring needs to be at least of size k for COBS.
|
|
388
|
-
if len(remaining_substring) < self.
|
|
426
|
+
if len(remaining_substring) < self.k_value:
|
|
389
427
|
substring_list[-1] += remaining_substring
|
|
390
428
|
else:
|
|
391
429
|
substring_list.append(remaining_substring)
|