XspecT 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of XspecT might be problematic. Click here for more details.
- {XspecT-0.1.2.dist-info → XspecT-0.2.0.dist-info}/METADATA +23 -29
- XspecT-0.2.0.dist-info/RECORD +30 -0
- {XspecT-0.1.2.dist-info → XspecT-0.2.0.dist-info}/WHEEL +1 -1
- xspect/definitions.py +42 -0
- xspect/download_filters.py +11 -26
- xspect/fastapi.py +101 -0
- xspect/file_io.py +34 -103
- xspect/main.py +70 -66
- xspect/model_management.py +88 -0
- xspect/models/__init__.py +0 -0
- xspect/models/probabilistic_filter_model.py +277 -0
- xspect/models/probabilistic_filter_svm_model.py +169 -0
- xspect/models/probabilistic_single_filter_model.py +109 -0
- xspect/models/result.py +148 -0
- xspect/pipeline.py +201 -0
- xspect/run.py +38 -0
- xspect/train.py +304 -0
- xspect/train_filter/create_svm.py +6 -183
- xspect/train_filter/extract_and_concatenate.py +117 -121
- xspect/train_filter/html_scrap.py +16 -28
- xspect/train_filter/ncbi_api/download_assemblies.py +7 -8
- xspect/train_filter/ncbi_api/ncbi_assembly_metadata.py +9 -17
- xspect/train_filter/ncbi_api/ncbi_children_tree.py +3 -2
- xspect/train_filter/ncbi_api/ncbi_taxon_metadata.py +7 -5
- XspecT-0.1.2.dist-info/RECORD +0 -48
- xspect/BF_v2.py +0 -648
- xspect/Bootstrap.py +0 -29
- xspect/Classifier.py +0 -142
- xspect/OXA_Table.py +0 -53
- xspect/WebApp.py +0 -737
- xspect/XspecT_mini.py +0 -1377
- xspect/XspecT_trainer.py +0 -611
- xspect/map_kmers.py +0 -155
- xspect/search_filter.py +0 -504
- xspect/static/How-To.png +0 -0
- xspect/static/Logo.png +0 -0
- xspect/static/Logo2.png +0 -0
- xspect/static/Workflow_AspecT.png +0 -0
- xspect/static/Workflow_ClAssT.png +0 -0
- xspect/static/js.js +0 -615
- xspect/static/main.css +0 -280
- xspect/templates/400.html +0 -64
- xspect/templates/401.html +0 -62
- xspect/templates/404.html +0 -62
- xspect/templates/500.html +0 -62
- xspect/templates/about.html +0 -544
- xspect/templates/home.html +0 -51
- xspect/templates/layoutabout.html +0 -87
- xspect/templates/layouthome.html +0 -63
- xspect/templates/layoutspecies.html +0 -468
- xspect/templates/species.html +0 -33
- xspect/train_filter/get_paths.py +0 -35
- xspect/train_filter/interface_XspecT.py +0 -204
- xspect/train_filter/k_mer_count.py +0 -162
- {XspecT-0.1.2.dist-info → XspecT-0.2.0.dist-info}/LICENSE +0 -0
- {XspecT-0.1.2.dist-info → XspecT-0.2.0.dist-info}/entry_points.txt +0 -0
- {XspecT-0.1.2.dist-info → XspecT-0.2.0.dist-info}/top_level.txt +0 -0
xspect/main.py
CHANGED
|
@@ -1,11 +1,17 @@
|
|
|
1
1
|
"""Project CLI"""
|
|
2
2
|
|
|
3
|
-
import
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
import datetime
|
|
4
5
|
import click
|
|
5
|
-
|
|
6
|
+
import uvicorn
|
|
7
|
+
from xspect import fastapi
|
|
6
8
|
from xspect.download_filters import download_test_filters
|
|
7
|
-
from xspect.
|
|
8
|
-
from xspect.
|
|
9
|
+
from xspect.train import train_ncbi
|
|
10
|
+
from xspect.models.result import (
|
|
11
|
+
StepType,
|
|
12
|
+
)
|
|
13
|
+
from xspect.definitions import get_xspect_runs_path, fasta_endings, fastq_endings
|
|
14
|
+
from xspect.pipeline import ModelExecution, Pipeline, PipelineStep
|
|
9
15
|
|
|
10
16
|
|
|
11
17
|
@click.group()
|
|
@@ -18,57 +24,60 @@ def cli():
|
|
|
18
24
|
def download_filters():
    """Download filters."""
    # Archive (models.zip) hosted in the XspecT S3 bucket.
    models_archive_url = "https://xspect2.s3.eu-central-1.amazonaws.com/models.zip"
    click.echo("Downloading filters, this may take a while...")
    download_test_filters(models_archive_url)
|
|
24
28
|
|
|
25
29
|
|
|
26
|
-
# todo: add read amount option -> why 342480?
|
|
27
30
|
@cli.command()
@click.argument("genus")
@click.argument("path", type=click.Path(exists=True, dir_okay=True, file_okay=True))
@click.option(
    "-m",
    "--meta/--no-meta",
    help="Metagenome classification.",
    default=False,
)
@click.option(
    "-s",
    "--step",
    # A step below 1 is meaningless and a step of 0 would later divide by
    # zero when k-mers are counted, so reject it at the CLI boundary.
    type=click.IntRange(min=1),
    help="Sparse sampling step size (e. g. only every 500th kmer for step=500).",
    default=1,
)
def classify(genus, path, meta, step):
    """Classify sample(s) from file or directory PATH."""
    click.echo("Classifying...")
    click.echo(f"Step: {step}")

    # Inputs: every FASTA/FASTQ file directly inside PATH (non-recursive) when
    # PATH is a directory, otherwise PATH itself. Sorted so the processing
    # order is deterministic across platforms.
    input_path = Path(path)
    if input_path.is_dir():
        file_paths = sorted(
            f
            for f in input_path.iterdir()
            if f.is_file() and f.suffix[1:] in fasta_endings + fastq_endings
        )
    else:
        file_paths = [input_path]

    if not file_paths:
        click.echo(f"No FASTA/FASTQ files found in '{path}'.")
        return

    # define pipeline
    pipeline = Pipeline(genus + " classification", "Test Author", "test@example.com")
    species_execution = ModelExecution(genus + "-species", sparse_sampling_step=step)
    if meta:
        # Metagenome mode: the genus model runs first and the species model is
        # attached as a FILTERING step (threshold 0.7) — presumably only
        # sequences passing the genus filter reach species classification;
        # see PipelineStep for the exact semantics.
        species_filtering_step = PipelineStep(
            StepType.FILTERING, genus, 0.7, species_execution
        )
        genus_execution = ModelExecution(genus + "-genus", sparse_sampling_step=step)
        genus_execution.add_pipeline_step(species_filtering_step)
        pipeline.add_pipeline_step(genus_execution)
    else:
        pipeline.add_pipeline_step(species_execution)

    runs_path = get_xspect_runs_path()
    total = len(file_paths)
    for idx, file_path in enumerate(file_paths, start=1):
        run = pipeline.run(file_path)
        time_str = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        save_path = runs_path / f"run_{time_str}.json"
        # The timestamp only has second resolution: several runs finishing in
        # the same second would overwrite each other's results. Append a
        # counter until the file name is free.
        counter = 1
        while save_path.exists():
            save_path = runs_path / f"run_{time_str}_{counter}.json"
            counter += 1
        run.save(save_path)
        click.echo(f"[{idx}/{total}] Run finished. Results saved to '{save_path}'.")
|
|
72
81
|
|
|
73
82
|
|
|
74
83
|
@cli.command()
|
|
@@ -86,33 +95,28 @@ def classify(genus, path, species, ic, oxa, metagenome, complete, save):
|
|
|
86
95
|
type=click.Path(exists=True, dir_okay=True, file_okay=False),
|
|
87
96
|
)
|
|
88
97
|
@click.option(
|
|
89
|
-
"-
|
|
90
|
-
"--
|
|
91
|
-
help="
|
|
92
|
-
|
|
93
|
-
default=False,
|
|
98
|
+
"-s",
|
|
99
|
+
"--svm-step",
|
|
100
|
+
help="SVM Sparse sampling step size (e. g. only every 500th kmer for step=500).",
|
|
101
|
+
default=1,
|
|
94
102
|
)
|
|
95
|
-
|
|
96
|
-
"--check",
|
|
97
|
-
help="Check if metagenome file was correctly created.",
|
|
98
|
-
is_flag=True,
|
|
99
|
-
default=False,
|
|
100
|
-
)
|
|
101
|
-
def train(genus, bf_assembly_path, svm_assembly_path, complete, check):
|
|
103
|
+
def train(genus, bf_assembly_path, svm_assembly_path, svm_step):
    """Train model."""
    # Custom assembly directories are accepted by the CLI but not supported by
    # the training pipeline yet. Report that as a regular CLI usage error
    # instead of an uncaught NotImplementedError traceback, consistent with
    # the ClickException handling below.
    if bf_assembly_path or svm_assembly_path:
        raise click.UsageError(
            "Training with specific assembly paths is not yet implemented."
        )
    # A sampling step below 1 is meaningless (0 would skip nothing and
    # divide by zero in k-mer counting).
    if svm_step < 1:
        raise click.BadParameter(
            "must be at least 1.", param_hint="'-s' / '--svm-step'"
        )
    try:
        train_ncbi(genus, svm_step=svm_step)
    except ValueError as e:
        # Surface validation errors from training as clean CLI errors.
        raise click.ClickException(str(e)) from e
|
|
109
114
|
|
|
110
115
|
|
|
111
116
|
@cli.command()
@click.option(
    "--host",
    default="0.0.0.0",
    show_default=True,
    help="Network interface to bind the API server to.",
)
@click.option(
    "--port",
    default=8000,
    show_default=True,
    type=click.IntRange(1, 65535),
    help="Port to serve the API on.",
)
def api(host, port):
    """Open the XspecT FastAPI."""
    # The defaults reproduce the previous hard-coded binding (all interfaces,
    # port 8000). Use e.g. --host 127.0.0.1 to keep the API local.
    uvicorn.run(fastapi.app, host=host, port=port)
|
|
116
120
|
|
|
117
121
|
|
|
118
122
|
if __name__ == "__main__":
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""This module contains functions to manage models."""
|
|
2
|
+
|
|
3
|
+
from json import loads, dumps
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from xspect.models.probabilistic_filter_model import ProbabilisticFilterModel
|
|
6
|
+
from xspect.models.probabilistic_single_filter_model import (
|
|
7
|
+
ProbabilisticSingleFilterModel,
|
|
8
|
+
)
|
|
9
|
+
from xspect.models.probabilistic_filter_svm_model import ProbabilisticFilterSVMModel
|
|
10
|
+
from xspect.definitions import get_xspect_model_path
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_genus_model(genus):
    """Load the genus-level (metagenome) model for *genus*.

    The model is read from ``<model dir>/<genus in lower case>-genus.json`` as a
    ``ProbabilisticSingleFilterModel``.
    """
    model_file = get_xspect_model_path() / f"{genus.lower()}-genus.json"
    return ProbabilisticSingleFilterModel.load(model_file)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_species_model(genus):
    """Load the species classification model for *genus*.

    The model is read from ``<model dir>/<genus in lower case>-species.json`` as
    a ``ProbabilisticFilterSVMModel``.
    """
    model_file = get_xspect_model_path() / f"{genus.lower()}-species.json"
    return ProbabilisticFilterSVMModel.load(model_file)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def get_model_by_slug(model_slug: str):
|
|
28
|
+
"""Get a model by its slug."""
|
|
29
|
+
model_path = get_xspect_model_path() / (model_slug + ".json")
|
|
30
|
+
model_metadata = get_model_metadata(model_path)
|
|
31
|
+
if model_metadata["model_class"] == "ProbabilisticSingleFilterModel":
|
|
32
|
+
return ProbabilisticSingleFilterModel.load(model_path)
|
|
33
|
+
elif model_metadata["model_class"] == "ProbabilisticFilterSVMModel":
|
|
34
|
+
return ProbabilisticFilterSVMModel.load(model_path)
|
|
35
|
+
elif model_metadata["model_class"] == "ProbabilisticFilterModel":
|
|
36
|
+
return ProbabilisticFilterModel.load(model_path)
|
|
37
|
+
else:
|
|
38
|
+
raise ValueError(f"Model class {model_metadata['model_class']} not recognized.")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def get_model_metadata(model: str | Path):
|
|
42
|
+
"""Get the metadata of a model."""
|
|
43
|
+
if isinstance(model, str):
|
|
44
|
+
model_path = get_xspect_model_path() / (model + ".json")
|
|
45
|
+
elif isinstance(model, Path):
|
|
46
|
+
model_path = model
|
|
47
|
+
else:
|
|
48
|
+
raise ValueError("Model must be a string (slug) or a Path object.")
|
|
49
|
+
|
|
50
|
+
if not model_path.exists() or not model_path.is_file():
|
|
51
|
+
raise ValueError(f"Model at {model_path} does not exist.")
|
|
52
|
+
|
|
53
|
+
with open(model_path, "r", encoding="utf-8") as file:
|
|
54
|
+
model_json = loads(file.read())
|
|
55
|
+
return model_json
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _write_model_metadata(model_slug: str, model_metadata: dict) -> None:
    """Write *model_metadata* as indented JSON to the file of model *model_slug*."""
    model_path = get_xspect_model_path() / (model_slug + ".json")
    # Serialize before opening: mode "w" truncates the file immediately, so a
    # serialization error must not leave an empty model file behind.
    serialized = dumps(model_metadata, indent=4)
    with open(model_path, "w", encoding="utf-8") as file:
        file.write(serialized)


def update_model_metadata(model_slug: str, author: str, author_email: str):
    """Set the author name and e-mail address in the metadata of a model.

    Args:
        model_slug: Slug of the model, i.e. its JSON file name without suffix.
        author: New author name.
        author_email: New author e-mail address.

    Raises:
        ValueError: If the model's JSON file does not exist.
    """
    model_metadata = get_model_metadata(model_slug)
    model_metadata["author"] = author
    model_metadata["author_email"] = author_email
    _write_model_metadata(model_slug, model_metadata)


def update_model_display_name(model_slug: str, filter_id: str, display_name: str):
    """Set the display name of a filter in a model.

    An entry for *filter_id* is created if the model does not have one yet.

    Args:
        model_slug: Slug of the model, i.e. its JSON file name without suffix.
        filter_id: Key of the filter in the model's ``display_names`` mapping.
        display_name: New display name for the filter.

    Raises:
        ValueError: If the model's JSON file does not exist.
    """
    model_metadata = get_model_metadata(model_slug)
    model_metadata["display_names"][filter_id] = display_name
    _write_model_metadata(model_slug, model_metadata)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def get_models():
    """Group the display names of all stored models by model type.

    Every ``*.json`` file in the model directory is read as model metadata.

    Returns:
        A dict mapping each ``model_type`` to the list of ``model_display_name``
        values of that type, in the order the files were found.
    """
    models_by_type = {}
    for metadata in map(get_model_metadata, get_xspect_model_path().glob("*.json")):
        type_key = metadata["model_type"]
        names = models_by_type.get(type_key)
        if names is None:
            names = models_by_type[type_key] = []
        names.append(metadata["model_display_name"])
    return models_by_type
|
|
File without changes
|
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
"""Probabilistic filter model for sequence data"""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from math import ceil
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from Bio.Seq import Seq
|
|
7
|
+
from Bio.SeqRecord import SeqRecord
|
|
8
|
+
from Bio import SeqIO
|
|
9
|
+
from slugify import slugify
|
|
10
|
+
import cobs_index as cobs
|
|
11
|
+
from xspect.file_io import get_record_iterator
|
|
12
|
+
from xspect.models.result import ModelResult
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ProbabilisticFilterModel:
    """Probabilistic filter model for sequence data.

    The model is backed by a COBS classic index stored at
    ``base_path / slug() / "index.cobs_classic"``. Its JSON metadata is stored
    at ``base_path / f"{slug()}.json"``.
    """

    def __init__(
        self,
        k: int,
        model_display_name: str,
        author: str,
        author_email: str,
        model_type: str,
        base_path: Path,
        fpr: float = 0.01,
        num_hashes: int = 7,
    ) -> None:
        """Create an unfitted model.

        Args:
            k: k-mer length of the index; must be at least 1.
            model_display_name: Human-readable model name (non-empty).
            author: Model author.
            author_email: Model author's e-mail address.
            model_type: Model type label (non-empty); also part of the slug.
            base_path: Directory holding the model's JSON file and index.
            fpr: False-positive rate passed to COBS when building the index.
            num_hashes: Number of hash functions passed to COBS.

        Raises:
            ValueError: If any argument fails validation.
        """
        if k < 1:
            raise ValueError("Invalid k value, must be greater than 0")
        if not model_display_name:
            raise ValueError("Invalid filter display name, must be a non-empty string")
        if not model_type:
            raise ValueError("Invalid filter type, must be a non-empty string")
        if not isinstance(base_path, Path):
            raise ValueError("Invalid base path, must be a pathlib.Path object")

        self.k = k
        self.model_display_name = model_display_name
        self.author = author
        self.author_email = author_email
        self.model_type = model_type
        self.base_path = base_path
        # COBS document name -> display name; filled by fit() or load().
        self.display_names = {}
        self.fpr = fpr
        self.num_hashes = num_hashes
        # cobs.Search handle; None until fit() or load() opens the index.
        self.index = None

    def get_cobs_index_path(self) -> str:
        """Return the path of the COBS index file as a string.

        The string form is what gets handed to the cobs bindings
        (``classic_construct_list`` and ``Search``).
        """
        return str(self.base_path / self.slug() / "index.cobs_classic")

    def to_dict(self) -> dict:
        """Return the JSON-serializable metadata of the model (without the index)."""
        return {
            "k": self.k,
            "model_display_name": self.model_display_name,
            "author": self.author,
            "author_email": self.author_email,
            "model_type": self.model_type,
            "model_class": self.__class__.__name__,
            "display_names": self.display_names,
            "fpr": self.fpr,
            "num_hashes": self.num_hashes,
        }

    # NOTE(review): defining a method named __dict__ shadows the instance
    # attribute dictionary, so ``model.__dict__`` / ``vars(model)`` yield this
    # bound method instead of the attributes. Kept for callers that may use
    # ``model.__dict__()``; prefer to_dict() and consider removing this.
    def __dict__(self) -> dict:
        """Return the same dictionary as to_dict()."""
        return self.to_dict()

    def slug(self) -> str:
        """Return the slugified "<display name>-<model type>" identifier."""
        return slugify(self.model_display_name + "-" + str(self.model_type))

    def fit(self, dir_path: Path, display_names: dict = None) -> None:
        """Build the COBS index from the sequence files in ``dir_path``.

        Args:
            dir_path: Directory whose files with suffix .fasta, .fna, .fa,
                .fastq or .fq are indexed (non-recursive).
            display_names: Optional mapping of file name to display name. Files
                without an entry use their stem as display name.

        Raises:
            ValueError: If ``dir_path`` is not a ``Path`` to an existing
                directory, or if it contains no supported files.
        """
        if display_names is None:
            display_names = {}

        if not isinstance(dir_path, Path):
            raise ValueError("Invalid directory path, must be a pathlib.Path object")

        if not dir_path.exists():
            raise ValueError("Directory path does not exist")

        if not dir_path.is_dir():
            raise ValueError("Directory path must be a directory")

        doclist = cobs.DocumentList()
        for file in dir_path.iterdir():
            if file.is_file() and file.suffix in [
                ".fasta",
                ".fna",
                ".fa",
                ".fastq",
                ".fq",
            ]:
                # cobs only uses the file name to the first "." as the document name
                document_name = file.name.split(".")[0]
                self.display_names[document_name] = display_names.get(
                    file.name, file.stem
                )
                doclist.add(str(file))

        if len(doclist) == 0:
            raise ValueError(
                "No valid files found in directory. Must be fasta or fastq"
            )

        index_params = cobs.ClassicIndexParameters()
        index_params.term_size = self.k
        index_params.num_hashes = self.num_hashes
        index_params.false_positive_rate = self.fpr
        # Overwrite an existing index at the target path.
        index_params.clobber = True

        cobs.classic_construct_list(doclist, self.get_cobs_index_path(), index_params)

        self.index = cobs.Search(self.get_cobs_index_path(), True)

    def calculate_hits(
        self, sequence: Seq | SeqRecord, filter_ids: list[str] = None, step: int = 1
    ) -> dict:
        """Query the index with a single sequence.

        Args:
            sequence: The query sequence. A ``SeqRecord`` is unwrapped to its
                ``.seq``.
            filter_ids: Optional document names to restrict the result to; all
                documents are returned when falsy.
            step: Sparse sampling step forwarded to the COBS search (>= 1).

        Returns:
            Mapping of COBS document name to the score COBS reports for it.

        Raises:
            ValueError: If the sequence has the wrong type, is not longer than
                k, or ``step`` is below 1.
            KeyError: If a requested filter id is not in the search result.
        """
        if isinstance(sequence, SeqRecord):
            sequence = sequence.seq

        if not isinstance(sequence, Seq):
            raise ValueError(
                "Invalid sequence, must be a Bio.Seq or a Bio.SeqRecord object"
            )

        if step < 1:
            raise ValueError("Invalid step, must be at least 1")

        if not len(sequence) > self.k:
            raise ValueError("Invalid sequence, must be longer than k")

        r = self.index.search(str(sequence), step=step)
        result_dict = self._convert_cobs_result_to_dict(r)
        if filter_ids:
            return {doc: result_dict[doc] for doc in filter_ids}
        return result_dict

    def predict(
        self,
        sequence_input: (
            SeqRecord
            | list[SeqRecord]
            | SeqIO.FastaIO.FastaIterator
            | SeqIO.QualityIO.FastqPhredIterator
            | Path
        ),
        filter_ids: list[str] = None,
        step: int = 1,
    ) -> ModelResult:
        """Score the sequence(s) against the filters of the model.

        Args:
            sequence_input: A record, a list of records, a SeqIO FASTA/FASTQ
                iterator, or a Path to a FASTA/FASTQ file.
            filter_ids: Optional document names to restrict the hits to.
            step: Sparse sampling step (>= 1).

        Returns:
            A ModelResult holding per-record hits and k-mer counts, keyed by
            record id (records sharing an id overwrite each other).

        Raises:
            ValueError: For unsupported input types or invalid sequences.
        """
        # The base implementation is called explicitly rather than via
        # self.predict — presumably so that subclasses overriding predict()
        # are not re-entered by this delegation.
        if isinstance(sequence_input, SeqRecord):
            return ProbabilisticFilterModel.predict(
                self, [sequence_input], filter_ids, step=step
            )

        if self._is_sequence_list(sequence_input) or self._is_sequence_iterator(
            sequence_input
        ):
            hits = {}
            num_kmers = {}
            for individual_sequence in sequence_input:
                individual_hits = self.calculate_hits(
                    individual_sequence.seq, filter_ids, step=step
                )
                num_kmers[individual_sequence.id] = self._count_kmers(
                    individual_sequence, step=step
                )
                hits[individual_sequence.id] = individual_hits
            return ModelResult(self.slug(), hits, num_kmers, sparse_sampling_step=step)

        if isinstance(sequence_input, Path):
            # Forward filter_ids as well; previously they were silently dropped
            # for file input.
            return ProbabilisticFilterModel.predict(
                self, get_record_iterator(sequence_input), filter_ids, step=step
            )

        raise ValueError(
            "Invalid sequence input, must be a SeqRecord object, a list of SeqRecord"
            " objects, a SeqIO FastaIterator, a SeqIO FastqPhredIterator, or a Path"
            " object to a fasta/fastq file"
        )

    def save(self) -> None:
        """Write the model metadata JSON and ensure the index directory exists."""
        json_path = self.base_path / f"{self.slug()}.json"
        filter_path = self.base_path / self.slug()
        filter_path.mkdir(exist_ok=True, parents=True)

        json_object = json.dumps(self.to_dict(), indent=4)

        with open(json_path, "w", encoding="utf-8") as file:
            file.write(json_object)

    @staticmethod
    def load(path: Path) -> "ProbabilisticFilterModel":
        """Load a model from its JSON metadata file and open its COBS index.

        Args:
            path: Path (or path string) of the model's JSON file. The index is
                expected next to it, under ``<slug>/index.cobs_classic``.

        Raises:
            FileNotFoundError: If the index file does not exist.
        """
        # Accept plain strings too; path.parent below requires a Path.
        path = Path(path)
        with open(path, "r", encoding="utf-8") as file:
            json_object = file.read()
            model_json = json.loads(json_object)
            model = ProbabilisticFilterModel(
                model_json["k"],
                model_json["model_display_name"],
                model_json["author"],
                model_json["author_email"],
                model_json["model_type"],
                path.parent,
                model_json["fpr"],
                model_json["num_hashes"],
            )
            model.display_names = model_json["display_names"]

            p = model.get_cobs_index_path()
            if not Path(p).exists():
                raise FileNotFoundError(f"Index file not found at {p}")
            model.index = cobs.Search(p, True)

            return model

    def _convert_cobs_result_to_dict(self, cobs_result: cobs.SearchResult) -> dict:
        """Map each result entry's ``doc_name`` to its ``score``."""
        return {
            individual_result.doc_name: individual_result.score
            for individual_result in cobs_result
        }

    def _count_kmers(
        self,
        sequence_input: (
            Seq
            | SeqRecord
            | list[Seq]
            | SeqIO.FastaIO.FastaIterator
            | SeqIO.QualityIO.FastqPhredIterator
        ),
        step: int = 1,
    ) -> int:
        """Count the (sparsely sampled) k-mers in the sequence(s).

        Each sequence contributes ceil((len - k + 1) / step) k-mers. Sequences
        shorter than k contribute 0.
        """
        if isinstance(sequence_input, Seq):
            return self._count_kmers([sequence_input], step=step)

        if isinstance(sequence_input, SeqRecord):
            return self._count_kmers(sequence_input.seq, step=step)

        is_sequence_list = isinstance(sequence_input, list) and all(
            isinstance(seq, Seq) for seq in sequence_input
        )
        is_iterator = isinstance(
            sequence_input,
            (SeqIO.FastaIO.FastaIterator, SeqIO.QualityIO.FastqPhredIterator),
        )

        if is_sequence_list or is_iterator:
            kmer_sum = 0
            for individual_sequence in sequence_input:
                # we need to look specifically at .seq for SeqIO iterators
                seq = individual_sequence.seq if is_iterator else individual_sequence
                # Clamp so a sequence shorter than k cannot make the total negative.
                kmer_sum += max(0, ceil((len(seq) - self.k + 1) / step))
            return kmer_sum

        raise ValueError(
            "Invalid sequence input, must be a Seq object, a list of Seq objects, a"
            " SeqIO FastaIterator, or a SeqIO FastqPhredIterator"
        )

    def _is_sequence_list(self, sequence_input):
        """Return True if the input is a list whose items are all SeqRecords."""
        return isinstance(sequence_input, list) and all(
            isinstance(seq, (SeqRecord)) for seq in sequence_input
        )

    def _is_sequence_iterator(self, sequence_input):
        """Return True if the input is a SeqIO FASTA or FASTQ (Phred) iterator."""
        return isinstance(
            sequence_input,
            (SeqIO.FastaIO.FastaIterator, SeqIO.QualityIO.FastqPhredIterator),
        )
|