XspecT 0.2.7__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of XspecT might be problematic. Click here for more details.
- xspect/definitions.py +0 -7
- xspect/download_models.py +25 -24
- xspect/fastapi.py +23 -26
- xspect/file_io.py +86 -2
- xspect/main.py +360 -98
- xspect/mlst_feature/mlst_helper.py +4 -6
- xspect/model_management.py +7 -15
- xspect/models/probabilistic_filter_model.py +16 -5
- xspect/models/probabilistic_filter_svm_model.py +33 -18
- xspect/models/probabilistic_single_filter_model.py +8 -1
- xspect/models/result.py +32 -66
- xspect/ncbi.py +265 -0
- xspect/train.py +258 -242
- {xspect-0.2.7.dist-info → xspect-0.4.1.dist-info}/METADATA +15 -21
- xspect-0.4.1.dist-info/RECORD +24 -0
- {xspect-0.2.7.dist-info → xspect-0.4.1.dist-info}/WHEEL +1 -1
- xspect/pipeline.py +0 -201
- xspect/run.py +0 -38
- xspect/train_filter/__init__.py +0 -0
- xspect/train_filter/create_svm.py +0 -45
- xspect/train_filter/extract_and_concatenate.py +0 -124
- xspect/train_filter/ncbi_api/__init__.py +0 -0
- xspect/train_filter/ncbi_api/download_assemblies.py +0 -31
- xspect/train_filter/ncbi_api/ncbi_assembly_metadata.py +0 -110
- xspect/train_filter/ncbi_api/ncbi_children_tree.py +0 -53
- xspect/train_filter/ncbi_api/ncbi_taxon_metadata.py +0 -55
- xspect-0.2.7.dist-info/RECORD +0 -33
- {xspect-0.2.7.dist-info → xspect-0.4.1.dist-info}/entry_points.txt +0 -0
- {xspect-0.2.7.dist-info → xspect-0.4.1.dist-info/licenses}/LICENSE +0 -0
- {xspect-0.2.7.dist-info → xspect-0.4.1.dist-info}/top_level.txt +0 -0
xspect/main.py
CHANGED
|
@@ -1,28 +1,27 @@
|
|
|
1
1
|
"""Project CLI"""
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
import
|
|
5
|
-
import uuid
|
|
4
|
+
from uuid import uuid4
|
|
6
5
|
import click
|
|
7
6
|
import uvicorn
|
|
8
7
|
from xspect import fastapi
|
|
9
8
|
from xspect.download_models import download_test_models
|
|
10
|
-
from xspect.
|
|
11
|
-
from xspect.
|
|
12
|
-
StepType,
|
|
13
|
-
)
|
|
9
|
+
from xspect.file_io import filter_sequences
|
|
10
|
+
from xspect.train import train_from_directory, train_from_ncbi
|
|
14
11
|
from xspect.definitions import (
|
|
15
|
-
get_xspect_runs_path,
|
|
16
|
-
fasta_endings,
|
|
17
|
-
fastq_endings,
|
|
18
12
|
get_xspect_model_path,
|
|
19
13
|
)
|
|
20
|
-
from xspect.pipeline import ModelExecution, Pipeline, PipelineStep
|
|
21
14
|
from xspect.mlst_feature.mlst_helper import pick_scheme, pick_scheme_from_models_dir
|
|
22
15
|
from xspect.mlst_feature.pub_mlst_handler import PubMLSTHandler
|
|
23
16
|
from xspect.models.probabilistic_filter_mlst_model import (
|
|
24
17
|
ProbabilisticFilterMlstSchemeModel,
|
|
25
18
|
)
|
|
19
|
+
from xspect.model_management import (
|
|
20
|
+
get_genus_model,
|
|
21
|
+
get_model_metadata,
|
|
22
|
+
get_models,
|
|
23
|
+
get_species_model,
|
|
24
|
+
)
|
|
26
25
|
|
|
27
26
|
|
|
28
27
|
@click.group()
|
|
@@ -32,103 +31,131 @@ def cli():
|
|
|
32
31
|
|
|
33
32
|
|
|
34
33
|
@cli.command()
|
|
35
|
-
def
|
|
34
|
+
def web():
|
|
35
|
+
"""Open the XspecT web application."""
|
|
36
|
+
uvicorn.run(fastapi.app, host="0.0.0.0", port=8000)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# # # # # # # # # # # # # # #
|
|
40
|
+
# Model management commands #
|
|
41
|
+
# # # # # # # # # # # # # # #
|
|
42
|
+
@cli.group()
|
|
43
|
+
def models():
|
|
44
|
+
"""Model management commands."""
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@models.command(
|
|
48
|
+
help="Download models from the internet.",
|
|
49
|
+
)
|
|
50
|
+
def download():
|
|
36
51
|
"""Download models."""
|
|
37
52
|
click.echo("Downloading models, this may take a while...")
|
|
38
|
-
download_test_models("
|
|
53
|
+
download_test_models("http://assets.adrianromberg.com/xspect-models.zip")
|
|
39
54
|
|
|
40
55
|
|
|
41
|
-
@
|
|
42
|
-
|
|
43
|
-
|
|
56
|
+
@models.command(
|
|
57
|
+
name="list",
|
|
58
|
+
help="List all models in the model directory.",
|
|
59
|
+
)
|
|
60
|
+
def list_models():
|
|
61
|
+
"""List models."""
|
|
62
|
+
available_models = get_models()
|
|
63
|
+
if not available_models:
|
|
64
|
+
click.echo("No models found.")
|
|
65
|
+
return
|
|
66
|
+
# todo: make this machine readable
|
|
67
|
+
click.echo("Models found:")
|
|
68
|
+
click.echo("--------------")
|
|
69
|
+
for model_type, names in available_models.items():
|
|
70
|
+
if not names:
|
|
71
|
+
continue
|
|
72
|
+
click.echo(f" {model_type}:")
|
|
73
|
+
for name in names:
|
|
74
|
+
click.echo(f" - {name}")
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@models.group()
|
|
78
|
+
def train():
|
|
79
|
+
"""Train models."""
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@train.command(
|
|
83
|
+
name="ncbi",
|
|
84
|
+
help="Train a species and a genus model based on NCBI data.",
|
|
85
|
+
)
|
|
86
|
+
@click.option("-g", "--genus", "model_genus", prompt=True)
|
|
87
|
+
@click.option("--svm_steps", type=int, default=1)
|
|
44
88
|
@click.option(
|
|
45
|
-
"
|
|
46
|
-
"
|
|
47
|
-
|
|
48
|
-
default=False,
|
|
89
|
+
"--author",
|
|
90
|
+
help="Author of the model.",
|
|
91
|
+
default=None,
|
|
49
92
|
)
|
|
50
93
|
@click.option(
|
|
51
|
-
"-
|
|
52
|
-
"
|
|
53
|
-
|
|
54
|
-
default=1,
|
|
94
|
+
"--author-email",
|
|
95
|
+
help="Email of the author.",
|
|
96
|
+
default=None,
|
|
55
97
|
)
|
|
56
|
-
def
|
|
57
|
-
"""
|
|
58
|
-
click.echo("
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
for f in Path(path).iterdir()
|
|
66
|
-
if f.is_file() and f.suffix[1:] in fasta_endings + fastq_endings
|
|
67
|
-
]
|
|
68
|
-
else:
|
|
69
|
-
file_paths = [Path(path)]
|
|
70
|
-
|
|
71
|
-
# define pipeline
|
|
72
|
-
pipeline = Pipeline(genus + " classification", "Test Author", "test@example.com")
|
|
73
|
-
species_execution = ModelExecution(
|
|
74
|
-
genus.lower() + "-species", sparse_sampling_step=step
|
|
75
|
-
)
|
|
76
|
-
if meta:
|
|
77
|
-
species_filtering_step = PipelineStep(
|
|
78
|
-
StepType.FILTERING, genus, 0.7, species_execution
|
|
79
|
-
)
|
|
80
|
-
genus_execution = ModelExecution(
|
|
81
|
-
genus.lower() + "-genus", sparse_sampling_step=step
|
|
82
|
-
)
|
|
83
|
-
genus_execution.add_pipeline_step(species_filtering_step)
|
|
84
|
-
pipeline.add_pipeline_step(genus_execution)
|
|
85
|
-
else:
|
|
86
|
-
pipeline.add_pipeline_step(species_execution)
|
|
87
|
-
|
|
88
|
-
for idx, file_path in enumerate(file_paths):
|
|
89
|
-
run = pipeline.run(file_path)
|
|
90
|
-
time_str = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
|
91
|
-
save_path = get_xspect_runs_path() / f"run_{time_str}_{uuid.uuid4()}.json"
|
|
92
|
-
run.save(save_path)
|
|
93
|
-
print(
|
|
94
|
-
f"[{idx+1}/{len(file_paths)}] Run finished. Results saved to '{save_path}'."
|
|
95
|
-
)
|
|
98
|
+
def train_ncbi(model_genus, svm_steps, author, author_email):
|
|
99
|
+
"""Train a species and a genus model based on NCBI data."""
|
|
100
|
+
click.echo(f"Training {model_genus} species and genus metagenome model.")
|
|
101
|
+
try:
|
|
102
|
+
train_from_ncbi(model_genus, svm_steps, author, author_email)
|
|
103
|
+
except ValueError as e:
|
|
104
|
+
click.echo(f"Error: {e}")
|
|
105
|
+
return
|
|
106
|
+
click.echo(f"Training of {model_genus} model finished.")
|
|
96
107
|
|
|
97
108
|
|
|
98
|
-
@
|
|
99
|
-
|
|
109
|
+
@train.command(
|
|
110
|
+
name="directory",
|
|
111
|
+
help="Train a species (and possibly a genus) model based on local data.",
|
|
112
|
+
)
|
|
113
|
+
@click.option("-g", "--genus", "model_genus", prompt=True)
|
|
100
114
|
@click.option(
|
|
101
|
-
"-
|
|
102
|
-
"--
|
|
103
|
-
|
|
104
|
-
|
|
115
|
+
"-i",
|
|
116
|
+
"--input-path",
|
|
117
|
+
type=click.Path(exists=True, dir_okay=True, file_okay=True),
|
|
118
|
+
prompt=True,
|
|
105
119
|
)
|
|
106
120
|
@click.option(
|
|
107
|
-
"
|
|
108
|
-
|
|
109
|
-
help="
|
|
110
|
-
|
|
121
|
+
"--meta",
|
|
122
|
+
is_flag=True,
|
|
123
|
+
help="Train a metagenome model for the genus.",
|
|
124
|
+
default=True,
|
|
111
125
|
)
|
|
112
126
|
@click.option(
|
|
113
|
-
"-
|
|
114
|
-
|
|
127
|
+
"--svm-steps",
|
|
128
|
+
type=int,
|
|
115
129
|
help="SVM Sparse sampling step size (e. g. only every 500th kmer for step=500).",
|
|
116
130
|
default=1,
|
|
117
131
|
)
|
|
118
|
-
|
|
119
|
-
""
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
132
|
+
@click.option(
|
|
133
|
+
"--author",
|
|
134
|
+
help="Author of the model.",
|
|
135
|
+
default=None,
|
|
136
|
+
)
|
|
137
|
+
@click.option(
|
|
138
|
+
"--author-email",
|
|
139
|
+
help="Email of the author.",
|
|
140
|
+
default=None,
|
|
141
|
+
)
|
|
142
|
+
def train_directory(model_genus, input_path, svm_steps, meta, author, author_email):
|
|
143
|
+
"""Train a model based on data from a directory for a given genus."""
|
|
144
|
+
click.echo(f"Training {model_genus} model with {svm_steps} SVM steps.")
|
|
145
|
+
train_from_directory(
|
|
146
|
+
model_genus,
|
|
147
|
+
Path(input_path),
|
|
148
|
+
svm_step=svm_steps,
|
|
149
|
+
meta=meta,
|
|
150
|
+
author=author,
|
|
151
|
+
author_email=author_email,
|
|
152
|
+
)
|
|
129
153
|
|
|
130
154
|
|
|
131
|
-
@
|
|
155
|
+
@train.command(
|
|
156
|
+
name="mlst",
|
|
157
|
+
help="Train a MLST model based on PubMLST data.",
|
|
158
|
+
)
|
|
132
159
|
@click.option(
|
|
133
160
|
"-c",
|
|
134
161
|
"--choose_schemes",
|
|
@@ -154,27 +181,262 @@ def train_mlst(choose_schemes):
|
|
|
154
181
|
click.echo(f"Saved at {model.cobs_path}")
|
|
155
182
|
|
|
156
183
|
|
|
157
|
-
|
|
184
|
+
# # # # # # # # # # # # # # #
|
|
185
|
+
# Classification commands #
|
|
186
|
+
# # # # # # # # # # # # # # #
|
|
187
|
+
@cli.group(
|
|
188
|
+
name="classify",
|
|
189
|
+
help="Classify sequences using XspecT models.",
|
|
190
|
+
)
|
|
191
|
+
def classify_seqs():
|
|
192
|
+
"""Classification commands."""
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
@classify_seqs.command(
|
|
196
|
+
name="genus",
|
|
197
|
+
help="Classify samples using a genus model.",
|
|
198
|
+
)
|
|
199
|
+
@click.option(
|
|
200
|
+
"-g",
|
|
201
|
+
"--genus",
|
|
202
|
+
"model_genus",
|
|
203
|
+
help="Genus of the model to classify.",
|
|
204
|
+
type=click.Choice(get_models().get("Genus"), None),
|
|
205
|
+
prompt=True,
|
|
206
|
+
)
|
|
207
|
+
@click.option(
|
|
208
|
+
"-i",
|
|
209
|
+
"--input-path",
|
|
210
|
+
help="Path to FASTA or FASTQ file for classification.",
|
|
211
|
+
type=click.Path(exists=True, dir_okay=True, file_okay=True),
|
|
212
|
+
prompt=True,
|
|
213
|
+
)
|
|
214
|
+
@click.option(
|
|
215
|
+
"-o",
|
|
216
|
+
"--output-path",
|
|
217
|
+
help="Path to the output file.",
|
|
218
|
+
type=click.Path(dir_okay=True, file_okay=True),
|
|
219
|
+
default=Path(".") / f"result_{uuid4()}.json",
|
|
220
|
+
)
|
|
221
|
+
def classify_genus(model_genus, input_path, output_path):
|
|
222
|
+
"""Classify samples using a genus model."""
|
|
223
|
+
click.echo("Classifying...")
|
|
224
|
+
genus_model = get_genus_model(model_genus)
|
|
225
|
+
result = genus_model.predict(Path(input_path))
|
|
226
|
+
result.save(output_path)
|
|
227
|
+
click.echo(f"Result saved as {output_path}.")
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
@classify_seqs.command(
|
|
231
|
+
name="species",
|
|
232
|
+
help="Classify samples using a species model.",
|
|
233
|
+
)
|
|
234
|
+
@click.option(
|
|
235
|
+
"-g",
|
|
236
|
+
"--genus",
|
|
237
|
+
"model_genus",
|
|
238
|
+
help="Genus of the model to classify.",
|
|
239
|
+
type=click.Choice(get_models().get("Species"), None),
|
|
240
|
+
prompt=True,
|
|
241
|
+
)
|
|
242
|
+
@click.option(
|
|
243
|
+
"-i",
|
|
244
|
+
"--input-path",
|
|
245
|
+
help="Path to FASTA or FASTQ file for classification.",
|
|
246
|
+
type=click.Path(exists=True, dir_okay=True, file_okay=True),
|
|
247
|
+
prompt=True,
|
|
248
|
+
)
|
|
249
|
+
@click.option(
|
|
250
|
+
"-o",
|
|
251
|
+
"--output-path",
|
|
252
|
+
help="Path to the output file.",
|
|
253
|
+
type=click.Path(dir_okay=True, file_okay=True),
|
|
254
|
+
default=Path(".") / f"result_{uuid4()}.json",
|
|
255
|
+
)
|
|
256
|
+
@click.option(
|
|
257
|
+
"--sparse-sampling-step",
|
|
258
|
+
type=int,
|
|
259
|
+
help="Sparse sampling step (e. g. only every 500th kmer for '--sparse-sampling-step 500').",
|
|
260
|
+
default=1,
|
|
261
|
+
)
|
|
262
|
+
def classify_species(model_genus, input_path, output_path, sparse_sampling_step):
|
|
263
|
+
"""Classify samples using a species model."""
|
|
264
|
+
click.echo("Classifying...")
|
|
265
|
+
species_model = get_species_model(model_genus)
|
|
266
|
+
result = species_model.predict(Path(input_path), step=sparse_sampling_step)
|
|
267
|
+
result.save(output_path)
|
|
268
|
+
click.echo(f"Result saved as {output_path}.")
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
@classify_seqs.command(
|
|
272
|
+
name="mlst",
|
|
273
|
+
help="Classify samples using a MLST model.",
|
|
274
|
+
)
|
|
158
275
|
@click.option(
|
|
159
|
-
"-
|
|
160
|
-
"--path",
|
|
276
|
+
"-i",
|
|
277
|
+
"--input-path",
|
|
161
278
|
help="Path to FASTA-file for mlst identification.",
|
|
162
279
|
type=click.Path(exists=True, dir_okay=True, file_okay=True),
|
|
280
|
+
prompt=True,
|
|
281
|
+
)
|
|
282
|
+
@click.option(
|
|
283
|
+
"-o",
|
|
284
|
+
"--output-path",
|
|
285
|
+
help="Path to the output file.",
|
|
286
|
+
type=click.Path(dir_okay=True, file_okay=True),
|
|
287
|
+
default=Path(".") / f"result_{uuid4()}.json",
|
|
163
288
|
)
|
|
164
|
-
def classify_mlst(
|
|
289
|
+
def classify_mlst(input_path, output_path):
|
|
165
290
|
"""MLST classify a sample."""
|
|
166
291
|
click.echo("Classifying...")
|
|
167
|
-
|
|
292
|
+
input_path = Path(input_path)
|
|
168
293
|
scheme_path = pick_scheme_from_models_dir()
|
|
169
294
|
model = ProbabilisticFilterMlstSchemeModel.load(scheme_path)
|
|
170
|
-
model.predict(scheme_path,
|
|
171
|
-
|
|
295
|
+
result = model.predict(scheme_path, input_path)
|
|
296
|
+
result.save(output_path)
|
|
297
|
+
click.echo(f"Result saved as {output_path}.")
|
|
172
298
|
|
|
173
299
|
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
300
|
+
# # # # # # # # # # # # # # #
|
|
301
|
+
# Filtering commands #
|
|
302
|
+
# # # # # # # # # # # # # # #
|
|
303
|
+
@cli.group(
|
|
304
|
+
name="filter",
|
|
305
|
+
help="Filter sequences using XspecT models.",
|
|
306
|
+
)
|
|
307
|
+
def filter_seqs():
|
|
308
|
+
"""Filter commands."""
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
@filter_seqs.command(
|
|
312
|
+
name="genus",
|
|
313
|
+
help="Filter sequences using a genus model.",
|
|
314
|
+
)
|
|
315
|
+
@click.option(
|
|
316
|
+
"-g",
|
|
317
|
+
"--genus",
|
|
318
|
+
"model_genus",
|
|
319
|
+
help="Genus of the model to use for filtering.",
|
|
320
|
+
type=click.Choice(get_models().get("Species"), None),
|
|
321
|
+
prompt=True,
|
|
322
|
+
)
|
|
323
|
+
@click.option(
|
|
324
|
+
"-i",
|
|
325
|
+
"--input-path",
|
|
326
|
+
help="Path to FASTA or FASTQ file for classification.",
|
|
327
|
+
type=click.Path(exists=True, dir_okay=True, file_okay=True),
|
|
328
|
+
prompt=True,
|
|
329
|
+
)
|
|
330
|
+
@click.option(
|
|
331
|
+
"-o",
|
|
332
|
+
"--output-path",
|
|
333
|
+
help="Path to the output file.",
|
|
334
|
+
type=click.Path(dir_okay=True, file_okay=True),
|
|
335
|
+
prompt=True,
|
|
336
|
+
)
|
|
337
|
+
@click.option(
|
|
338
|
+
"--threshold",
|
|
339
|
+
type=float,
|
|
340
|
+
help="Threshold for filtering (default: 0.7).",
|
|
341
|
+
default=0.7,
|
|
342
|
+
prompt=True,
|
|
343
|
+
)
|
|
344
|
+
def filter_genus(model_genus, input_path, output_path, threshold):
|
|
345
|
+
"""Filter samples using a genus model."""
|
|
346
|
+
click.echo("Filtering...")
|
|
347
|
+
genus_model = get_genus_model(model_genus)
|
|
348
|
+
result = genus_model.predict(Path(input_path))
|
|
349
|
+
included_ids = result.get_filtered_subsequence_labels(model_genus, threshold)
|
|
350
|
+
if not included_ids:
|
|
351
|
+
click.echo("No sequences found for the given genus.")
|
|
352
|
+
return
|
|
353
|
+
|
|
354
|
+
filter_sequences(
|
|
355
|
+
Path(input_path),
|
|
356
|
+
Path(output_path),
|
|
357
|
+
included_ids=included_ids,
|
|
358
|
+
)
|
|
359
|
+
click.echo(f"Filtered sequences saved at {output_path}.")
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
@filter_seqs.command(
|
|
363
|
+
name="species",
|
|
364
|
+
help="Filter sequences using a species model.",
|
|
365
|
+
)
|
|
366
|
+
@click.option(
|
|
367
|
+
"-g",
|
|
368
|
+
"--genus",
|
|
369
|
+
"model_genus",
|
|
370
|
+
help="Genus of the model to use for filtering.",
|
|
371
|
+
type=click.Choice(get_models().get("Species"), None),
|
|
372
|
+
prompt=True,
|
|
373
|
+
)
|
|
374
|
+
@click.option(
|
|
375
|
+
"-s",
|
|
376
|
+
"--species",
|
|
377
|
+
"model_species",
|
|
378
|
+
help="Species of the model to filter for.",
|
|
379
|
+
)
|
|
380
|
+
@click.option(
|
|
381
|
+
"-i",
|
|
382
|
+
"--input-path",
|
|
383
|
+
help="Path to FASTA or FASTQ file for classification.",
|
|
384
|
+
type=click.Path(exists=True, dir_okay=True, file_okay=True),
|
|
385
|
+
prompt=True,
|
|
386
|
+
)
|
|
387
|
+
@click.option(
|
|
388
|
+
"-o",
|
|
389
|
+
"--output-path",
|
|
390
|
+
help="Path to the output file.",
|
|
391
|
+
type=click.Path(dir_okay=True, file_okay=True),
|
|
392
|
+
prompt=True,
|
|
393
|
+
)
|
|
394
|
+
@click.option(
|
|
395
|
+
"--threshold",
|
|
396
|
+
type=float,
|
|
397
|
+
help="Threshold for filtering (default: 0.7). Use -1 to filter for the highest scoring species.",
|
|
398
|
+
default=0.7,
|
|
399
|
+
prompt=True,
|
|
400
|
+
)
|
|
401
|
+
def filter_species(model_genus, model_species, input_path, output_path, threshold):
|
|
402
|
+
"""Filter a sample using the species model."""
|
|
403
|
+
|
|
404
|
+
available_species = get_model_metadata(f"{model_genus}-species")["display_names"]
|
|
405
|
+
available_species = {
|
|
406
|
+
id: name.replace(f"{model_genus} ", "")
|
|
407
|
+
for id, name in available_species.items()
|
|
408
|
+
}
|
|
409
|
+
if not model_species:
|
|
410
|
+
sorted_available_species = sorted(available_species.values())
|
|
411
|
+
model_species = click.prompt(
|
|
412
|
+
f"Please enter the species name: {model_genus}",
|
|
413
|
+
type=click.Choice(sorted_available_species, case_sensitive=False),
|
|
414
|
+
)
|
|
415
|
+
if model_species not in available_species.values():
|
|
416
|
+
raise click.BadParameter(
|
|
417
|
+
f"Species '{model_species}' not found in the {model_genus} species model."
|
|
418
|
+
)
|
|
419
|
+
|
|
420
|
+
# get the species ID from the name
|
|
421
|
+
model_species = [
|
|
422
|
+
id
|
|
423
|
+
for id, name in available_species.items()
|
|
424
|
+
if name.lower() == model_species.lower()
|
|
425
|
+
][0]
|
|
426
|
+
|
|
427
|
+
click.echo("Filtering...")
|
|
428
|
+
species_model = get_species_model(model_genus)
|
|
429
|
+
result = species_model.predict(Path(input_path))
|
|
430
|
+
included_ids = result.get_filtered_subsequence_labels(model_species, threshold)
|
|
431
|
+
if not included_ids:
|
|
432
|
+
click.echo("No sequences found for the given species.")
|
|
433
|
+
return
|
|
434
|
+
filter_sequences(
|
|
435
|
+
Path(input_path),
|
|
436
|
+
Path(output_path),
|
|
437
|
+
included_ids=included_ids,
|
|
438
|
+
)
|
|
439
|
+
click.echo(f"Filtered sequences saved at {output_path}.")
|
|
178
440
|
|
|
179
441
|
|
|
180
442
|
if __name__ == "__main__":
|
|
@@ -144,12 +144,10 @@ class MlstResult:
|
|
|
144
144
|
}
|
|
145
145
|
return result
|
|
146
146
|
|
|
147
|
-
def save(self,
|
|
148
|
-
"""Saves the result
|
|
149
|
-
|
|
150
|
-
json_path = get_xspect_runs_path() / "MLST" / f"{file_name}-{display}.json"
|
|
151
|
-
json_path.parent.mkdir(exist_ok=True, parents=True)
|
|
147
|
+
def save(self, output_path: Path) -> None:
|
|
148
|
+
"""Saves the result as a JSON file."""
|
|
149
|
+
output_path.parent.mkdir(exist_ok=True, parents=True)
|
|
152
150
|
json_object = json.dumps(self.to_dict(), indent=4)
|
|
153
151
|
|
|
154
|
-
with open(
|
|
152
|
+
with open(output_path, "w", encoding="utf-8") as file:
|
|
155
153
|
file.write(json_object)
|
xspect/model_management.py
CHANGED
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
|
|
3
3
|
from json import loads, dumps
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from xspect.models.probabilistic_filter_model import ProbabilisticFilterModel
|
|
6
5
|
from xspect.models.probabilistic_single_filter_model import (
|
|
7
6
|
ProbabilisticSingleFilterModel,
|
|
8
7
|
)
|
|
@@ -24,23 +23,10 @@ def get_species_model(genus):
|
|
|
24
23
|
return species_filter_model
|
|
25
24
|
|
|
26
25
|
|
|
27
|
-
def get_model_by_slug(model_slug: str):
|
|
28
|
-
"""Get a model by its slug."""
|
|
29
|
-
model_path = get_xspect_model_path() / (model_slug + ".json")
|
|
30
|
-
model_metadata = get_model_metadata(model_path)
|
|
31
|
-
if model_metadata["model_class"] == "ProbabilisticSingleFilterModel":
|
|
32
|
-
return ProbabilisticSingleFilterModel.load(model_path)
|
|
33
|
-
if model_metadata["model_class"] == "ProbabilisticFilterSVMModel":
|
|
34
|
-
return ProbabilisticFilterSVMModel.load(model_path)
|
|
35
|
-
if model_metadata["model_class"] == "ProbabilisticFilterModel":
|
|
36
|
-
return ProbabilisticFilterModel.load(model_path)
|
|
37
|
-
raise ValueError(f"Model class {model_metadata['model_class']} not recognized.")
|
|
38
|
-
|
|
39
|
-
|
|
40
26
|
def get_model_metadata(model: str | Path):
|
|
41
27
|
"""Get the metadata of a model."""
|
|
42
28
|
if isinstance(model, str):
|
|
43
|
-
model_path = get_xspect_model_path() / (model + ".json")
|
|
29
|
+
model_path = get_xspect_model_path() / (model.lower() + ".json")
|
|
44
30
|
elif isinstance(model, Path):
|
|
45
31
|
model_path = model
|
|
46
32
|
else:
|
|
@@ -85,3 +71,9 @@ def get_models():
|
|
|
85
71
|
model_metadata["model_display_name"]
|
|
86
72
|
)
|
|
87
73
|
return model_dict
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def get_model_display_names(model_slug: str):
|
|
77
|
+
"""Get the display names included in a model."""
|
|
78
|
+
model_metadata = get_model_metadata(model_slug)
|
|
79
|
+
return list(model_metadata["display_names"].values())
|
|
@@ -26,6 +26,7 @@ class ProbabilisticFilterModel:
|
|
|
26
26
|
base_path: Path,
|
|
27
27
|
fpr: float = 0.01,
|
|
28
28
|
num_hashes: int = 7,
|
|
29
|
+
training_accessions: dict[str, list[str]] = None,
|
|
29
30
|
) -> None:
|
|
30
31
|
if k < 1:
|
|
31
32
|
raise ValueError("Invalid k value, must be greater than 0")
|
|
@@ -46,6 +47,7 @@ class ProbabilisticFilterModel:
|
|
|
46
47
|
self.fpr = fpr
|
|
47
48
|
self.num_hashes = num_hashes
|
|
48
49
|
self.index = None
|
|
50
|
+
self.training_accessions = training_accessions
|
|
49
51
|
|
|
50
52
|
def get_cobs_index_path(self) -> Path:
|
|
51
53
|
"""Returns the path to the cobs index"""
|
|
@@ -63,13 +65,19 @@ class ProbabilisticFilterModel:
|
|
|
63
65
|
"display_names": self.display_names,
|
|
64
66
|
"fpr": self.fpr,
|
|
65
67
|
"num_hashes": self.num_hashes,
|
|
68
|
+
"training_accessions": self.training_accessions,
|
|
66
69
|
}
|
|
67
70
|
|
|
68
71
|
def slug(self) -> str:
|
|
69
72
|
"""Returns a slug representation of the model"""
|
|
70
73
|
return slugify(self.model_display_name + "-" + str(self.model_type))
|
|
71
74
|
|
|
72
|
-
def fit(
|
|
75
|
+
def fit(
|
|
76
|
+
self,
|
|
77
|
+
dir_path: Path,
|
|
78
|
+
display_names: dict = None,
|
|
79
|
+
training_accessions: dict[str, list[str]] = None,
|
|
80
|
+
) -> None:
|
|
73
81
|
"""Adds filters to the model"""
|
|
74
82
|
|
|
75
83
|
if display_names is None:
|
|
@@ -84,16 +92,18 @@ class ProbabilisticFilterModel:
|
|
|
84
92
|
if not dir_path.is_dir():
|
|
85
93
|
raise ValueError("Directory path must be a directory")
|
|
86
94
|
|
|
95
|
+
self.training_accessions = training_accessions
|
|
96
|
+
|
|
87
97
|
doclist = cobs.DocumentList()
|
|
88
98
|
for file in dir_path.iterdir():
|
|
89
99
|
if file.is_file() and file.suffix[1:] in fasta_endings + fastq_endings:
|
|
90
100
|
# cobs only uses the file name to the first "." as the document name
|
|
91
|
-
if file.
|
|
92
|
-
self.display_names[file.
|
|
93
|
-
file.
|
|
101
|
+
if file.stem in display_names:
|
|
102
|
+
self.display_names[file.stem.split(".")[0]] = display_names[
|
|
103
|
+
file.stem
|
|
94
104
|
]
|
|
95
105
|
else:
|
|
96
|
-
self.display_names[file.
|
|
106
|
+
self.display_names[file.stem.split(".")[0]] = file.stem
|
|
97
107
|
doclist.add(str(file))
|
|
98
108
|
|
|
99
109
|
if len(doclist) == 0:
|
|
@@ -200,6 +210,7 @@ class ProbabilisticFilterModel:
|
|
|
200
210
|
path.parent,
|
|
201
211
|
model_json["fpr"],
|
|
202
212
|
model_json["num_hashes"],
|
|
213
|
+
model_json["training_accessions"],
|
|
203
214
|
)
|
|
204
215
|
model.display_names = model_json["display_names"]
|
|
205
216
|
|