XspecT 0.2.7__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of XspecT might be problematic. Click here for more details.

xspect/main.py CHANGED
@@ -1,28 +1,27 @@
1
1
  """Project CLI"""
2
2
 
3
3
  from pathlib import Path
4
- import datetime
5
- import uuid
4
+ from uuid import uuid4
6
5
  import click
7
6
  import uvicorn
8
7
  from xspect import fastapi
9
8
  from xspect.download_models import download_test_models
10
- from xspect.train import train_ncbi
11
- from xspect.models.result import (
12
- StepType,
13
- )
9
+ from xspect.file_io import filter_sequences
10
+ from xspect.train import train_from_directory, train_from_ncbi
14
11
  from xspect.definitions import (
15
- get_xspect_runs_path,
16
- fasta_endings,
17
- fastq_endings,
18
12
  get_xspect_model_path,
19
13
  )
20
- from xspect.pipeline import ModelExecution, Pipeline, PipelineStep
21
14
  from xspect.mlst_feature.mlst_helper import pick_scheme, pick_scheme_from_models_dir
22
15
  from xspect.mlst_feature.pub_mlst_handler import PubMLSTHandler
23
16
  from xspect.models.probabilistic_filter_mlst_model import (
24
17
  ProbabilisticFilterMlstSchemeModel,
25
18
  )
19
+ from xspect.model_management import (
20
+ get_genus_model,
21
+ get_model_metadata,
22
+ get_models,
23
+ get_species_model,
24
+ )
26
25
 
27
26
 
28
27
  @click.group()
@@ -32,103 +31,131 @@ def cli():
32
31
 
33
32
 
34
33
  @cli.command()
35
- def download_models():
34
+ def web():
35
+ """Open the XspecT web application."""
36
+ uvicorn.run(fastapi.app, host="0.0.0.0", port=8000)
37
+
38
+
39
+ # # # # # # # # # # # # # # #
40
+ # Model management commands #
41
+ # # # # # # # # # # # # # # #
42
+ @cli.group()
43
+ def models():
44
+ """Model management commands."""
45
+
46
+
47
+ @models.command(
48
+ help="Download models from the internet.",
49
+ )
50
+ def download():
36
51
  """Download models."""
37
52
  click.echo("Downloading models, this may take a while...")
38
- download_test_models("https://xspect2.s3.eu-central-1.amazonaws.com/models.zip")
53
+ download_test_models("http://assets.adrianromberg.com/xspect-models.zip")
39
54
 
40
55
 
41
- @cli.command()
42
- @click.argument("genus")
43
- @click.argument("path", type=click.Path(exists=True, dir_okay=True, file_okay=True))
56
+ @models.command(
57
+ name="list",
58
+ help="List all models in the model directory.",
59
+ )
60
+ def list_models():
61
+ """List models."""
62
+ available_models = get_models()
63
+ if not available_models:
64
+ click.echo("No models found.")
65
+ return
66
+ # todo: make this machine readable
67
+ click.echo("Models found:")
68
+ click.echo("--------------")
69
+ for model_type, names in available_models.items():
70
+ if not names:
71
+ continue
72
+ click.echo(f" {model_type}:")
73
+ for name in names:
74
+ click.echo(f" - {name}")
75
+
76
+
77
+ @models.group()
78
+ def train():
79
+ """Train models."""
80
+
81
+
82
+ @train.command(
83
+ name="ncbi",
84
+ help="Train a species and a genus model based on NCBI data.",
85
+ )
86
+ @click.option("-g", "--genus", "model_genus", prompt=True)
87
+ @click.option("--svm_steps", type=int, default=1)
44
88
  @click.option(
45
- "-m",
46
- "--meta/--no-meta",
47
- help="Metagenome classification.",
48
- default=False,
89
+ "--author",
90
+ help="Author of the model.",
91
+ default=None,
49
92
  )
50
93
  @click.option(
51
- "-s",
52
- "--step",
53
- help="Sparse sampling step size (e. g. only every 500th kmer for step=500).",
54
- default=1,
94
+ "--author-email",
95
+ help="Email of the author.",
96
+ default=None,
55
97
  )
56
- def classify_species(genus, path, meta, step):
57
- """Classify sample(s) from file or directory PATH."""
58
- click.echo("Classifying...")
59
- click.echo(f"Step: {step}")
60
-
61
- file_paths = []
62
- if Path(path).is_dir():
63
- file_paths = [
64
- f
65
- for f in Path(path).iterdir()
66
- if f.is_file() and f.suffix[1:] in fasta_endings + fastq_endings
67
- ]
68
- else:
69
- file_paths = [Path(path)]
70
-
71
- # define pipeline
72
- pipeline = Pipeline(genus + " classification", "Test Author", "test@example.com")
73
- species_execution = ModelExecution(
74
- genus.lower() + "-species", sparse_sampling_step=step
75
- )
76
- if meta:
77
- species_filtering_step = PipelineStep(
78
- StepType.FILTERING, genus, 0.7, species_execution
79
- )
80
- genus_execution = ModelExecution(
81
- genus.lower() + "-genus", sparse_sampling_step=step
82
- )
83
- genus_execution.add_pipeline_step(species_filtering_step)
84
- pipeline.add_pipeline_step(genus_execution)
85
- else:
86
- pipeline.add_pipeline_step(species_execution)
87
-
88
- for idx, file_path in enumerate(file_paths):
89
- run = pipeline.run(file_path)
90
- time_str = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
91
- save_path = get_xspect_runs_path() / f"run_{time_str}_{uuid.uuid4()}.json"
92
- run.save(save_path)
93
- print(
94
- f"[{idx+1}/{len(file_paths)}] Run finished. Results saved to '{save_path}'."
95
- )
98
+ def train_ncbi(model_genus, svm_steps, author, author_email):
99
+ """Train a species and a genus model based on NCBI data."""
100
+ click.echo(f"Training {model_genus} species and genus metagenome model.")
101
+ try:
102
+ train_from_ncbi(model_genus, svm_steps, author, author_email)
103
+ except ValueError as e:
104
+ click.echo(f"Error: {e}")
105
+ return
106
+ click.echo(f"Training of {model_genus} model finished.")
96
107
 
97
108
 
98
- @cli.command()
99
- @click.argument("genus")
109
+ @train.command(
110
+ name="directory",
111
+ help="Train a species (and possibly a genus) model based on local data.",
112
+ )
113
+ @click.option("-g", "--genus", "model_genus", prompt=True)
100
114
  @click.option(
101
- "-bf-path",
102
- "--bf-assembly-path",
103
- help="Path to assembly directory for Bloom filter training.",
104
- type=click.Path(exists=True, dir_okay=True, file_okay=False),
115
+ "-i",
116
+ "--input-path",
117
+ type=click.Path(exists=True, dir_okay=True, file_okay=True),
118
+ prompt=True,
105
119
  )
106
120
  @click.option(
107
- "-svm-path",
108
- "--svm-assembly-path",
109
- help="Path to assembly directory for SVM training.",
110
- type=click.Path(exists=True, dir_okay=True, file_okay=False),
121
+ "--meta",
122
+ is_flag=True,
123
+ help="Train a metagenome model for the genus.",
124
+ default=True,
111
125
  )
112
126
  @click.option(
113
- "-s",
114
- "--svm-step",
127
+ "--svm-steps",
128
+ type=int,
115
129
  help="SVM Sparse sampling step size (e. g. only every 500th kmer for step=500).",
116
130
  default=1,
117
131
  )
118
- def train_species(genus, bf_assembly_path, svm_assembly_path, svm_step):
119
- """Train model."""
120
-
121
- if bf_assembly_path or svm_assembly_path:
122
- raise NotImplementedError(
123
- "Training with specific assembly paths is not yet implemented."
124
- )
125
- try:
126
- train_ncbi(genus, svm_step=svm_step)
127
- except ValueError as e:
128
- raise click.ClickException(str(e)) from e
132
+ @click.option(
133
+ "--author",
134
+ help="Author of the model.",
135
+ default=None,
136
+ )
137
+ @click.option(
138
+ "--author-email",
139
+ help="Email of the author.",
140
+ default=None,
141
+ )
142
+ def train_directory(model_genus, input_path, svm_steps, meta, author, author_email):
143
+ """Train a model based on data from a directory for a given genus."""
144
+ click.echo(f"Training {model_genus} model with {svm_steps} SVM steps.")
145
+ train_from_directory(
146
+ model_genus,
147
+ Path(input_path),
148
+ svm_step=svm_steps,
149
+ meta=meta,
150
+ author=author,
151
+ author_email=author_email,
152
+ )
129
153
 
130
154
 
131
- @cli.command()
155
+ @train.command(
156
+ name="mlst",
157
+ help="Train a MLST model based on PubMLST data.",
158
+ )
132
159
  @click.option(
133
160
  "-c",
134
161
  "--choose_schemes",
@@ -154,27 +181,262 @@ def train_mlst(choose_schemes):
154
181
  click.echo(f"Saved at {model.cobs_path}")
155
182
 
156
183
 
157
- @cli.command()
184
+ # # # # # # # # # # # # # # #
185
+ # Classification commands #
186
+ # # # # # # # # # # # # # # #
187
+ @cli.group(
188
+ name="classify",
189
+ help="Classify sequences using XspecT models.",
190
+ )
191
+ def classify_seqs():
192
+ """Classification commands."""
193
+
194
+
195
+ @classify_seqs.command(
196
+ name="genus",
197
+ help="Classify samples using a genus model.",
198
+ )
199
+ @click.option(
200
+ "-g",
201
+ "--genus",
202
+ "model_genus",
203
+ help="Genus of the model to classify.",
204
+ type=click.Choice(get_models().get("Genus"), None),
205
+ prompt=True,
206
+ )
207
+ @click.option(
208
+ "-i",
209
+ "--input-path",
210
+ help="Path to FASTA or FASTQ file for classification.",
211
+ type=click.Path(exists=True, dir_okay=True, file_okay=True),
212
+ prompt=True,
213
+ )
214
+ @click.option(
215
+ "-o",
216
+ "--output-path",
217
+ help="Path to the output file.",
218
+ type=click.Path(dir_okay=True, file_okay=True),
219
+ default=Path(".") / f"result_{uuid4()}.json",
220
+ )
221
+ def classify_genus(model_genus, input_path, output_path):
222
+ """Classify samples using a genus model."""
223
+ click.echo("Classifying...")
224
+ genus_model = get_genus_model(model_genus)
225
+ result = genus_model.predict(Path(input_path))
226
+ result.save(output_path)
227
+ click.echo(f"Result saved as {output_path}.")
228
+
229
+
230
+ @classify_seqs.command(
231
+ name="species",
232
+ help="Classify samples using a species model.",
233
+ )
234
+ @click.option(
235
+ "-g",
236
+ "--genus",
237
+ "model_genus",
238
+ help="Genus of the model to classify.",
239
+ type=click.Choice(get_models().get("Species"), None),
240
+ prompt=True,
241
+ )
242
+ @click.option(
243
+ "-i",
244
+ "--input-path",
245
+ help="Path to FASTA or FASTQ file for classification.",
246
+ type=click.Path(exists=True, dir_okay=True, file_okay=True),
247
+ prompt=True,
248
+ )
249
+ @click.option(
250
+ "-o",
251
+ "--output-path",
252
+ help="Path to the output file.",
253
+ type=click.Path(dir_okay=True, file_okay=True),
254
+ default=Path(".") / f"result_{uuid4()}.json",
255
+ )
256
+ @click.option(
257
+ "--sparse-sampling-step",
258
+ type=int,
259
+ help="Sparse sampling step (e. g. only every 500th kmer for '--sparse-sampling-step 500').",
260
+ default=1,
261
+ )
262
+ def classify_species(model_genus, input_path, output_path, sparse_sampling_step):
263
+ """Classify samples using a species model."""
264
+ click.echo("Classifying...")
265
+ species_model = get_species_model(model_genus)
266
+ result = species_model.predict(Path(input_path), step=sparse_sampling_step)
267
+ result.save(output_path)
268
+ click.echo(f"Result saved as {output_path}.")
269
+
270
+
271
+ @classify_seqs.command(
272
+ name="mlst",
273
+ help="Classify samples using a MLST model.",
274
+ )
158
275
  @click.option(
159
- "-p",
160
- "--path",
276
+ "-i",
277
+ "--input-path",
161
278
  help="Path to FASTA-file for mlst identification.",
162
279
  type=click.Path(exists=True, dir_okay=True, file_okay=True),
280
+ prompt=True,
281
+ )
282
+ @click.option(
283
+ "-o",
284
+ "--output-path",
285
+ help="Path to the output file.",
286
+ type=click.Path(dir_okay=True, file_okay=True),
287
+ default=Path(".") / f"result_{uuid4()}.json",
163
288
  )
164
- def classify_mlst(path):
289
+ def classify_mlst(input_path, output_path):
165
290
  """MLST classify a sample."""
166
291
  click.echo("Classifying...")
167
- path = Path(path)
292
+ input_path = Path(input_path)
168
293
  scheme_path = pick_scheme_from_models_dir()
169
294
  model = ProbabilisticFilterMlstSchemeModel.load(scheme_path)
170
- model.predict(scheme_path, path).save(model.model_display_name, path)
171
- click.echo(f"Run saved at {get_xspect_runs_path()}.")
295
+ result = model.predict(scheme_path, input_path)
296
+ result.save(output_path)
297
+ click.echo(f"Result saved as {output_path}.")
172
298
 
173
299
 
174
- @cli.command()
175
- def api():
176
- """Open the XspecT FastAPI."""
177
- uvicorn.run(fastapi.app, host="0.0.0.0", port=8000)
300
+ # # # # # # # # # # # # # # #
301
+ # Filtering commands #
302
+ # # # # # # # # # # # # # # #
303
+ @cli.group(
304
+ name="filter",
305
+ help="Filter sequences using XspecT models.",
306
+ )
307
+ def filter_seqs():
308
+ """Filter commands."""
309
+
310
+
311
+ @filter_seqs.command(
312
+ name="genus",
313
+ help="Filter sequences using a genus model.",
314
+ )
315
+ @click.option(
316
+ "-g",
317
+ "--genus",
318
+ "model_genus",
319
+ help="Genus of the model to use for filtering.",
320
+ type=click.Choice(get_models().get("Species"), None),
321
+ prompt=True,
322
+ )
323
+ @click.option(
324
+ "-i",
325
+ "--input-path",
326
+ help="Path to FASTA or FASTQ file for classification.",
327
+ type=click.Path(exists=True, dir_okay=True, file_okay=True),
328
+ prompt=True,
329
+ )
330
+ @click.option(
331
+ "-o",
332
+ "--output-path",
333
+ help="Path to the output file.",
334
+ type=click.Path(dir_okay=True, file_okay=True),
335
+ prompt=True,
336
+ )
337
+ @click.option(
338
+ "--threshold",
339
+ type=float,
340
+ help="Threshold for filtering (default: 0.7).",
341
+ default=0.7,
342
+ prompt=True,
343
+ )
344
+ def filter_genus(model_genus, input_path, output_path, threshold):
345
+ """Filter samples using a genus model."""
346
+ click.echo("Filtering...")
347
+ genus_model = get_genus_model(model_genus)
348
+ result = genus_model.predict(Path(input_path))
349
+ included_ids = result.get_filtered_subsequence_labels(model_genus, threshold)
350
+ if not included_ids:
351
+ click.echo("No sequences found for the given genus.")
352
+ return
353
+
354
+ filter_sequences(
355
+ Path(input_path),
356
+ Path(output_path),
357
+ included_ids=included_ids,
358
+ )
359
+ click.echo(f"Filtered sequences saved at {output_path}.")
360
+
361
+
362
+ @filter_seqs.command(
363
+ name="species",
364
+ help="Filter sequences using a species model.",
365
+ )
366
+ @click.option(
367
+ "-g",
368
+ "--genus",
369
+ "model_genus",
370
+ help="Genus of the model to use for filtering.",
371
+ type=click.Choice(get_models().get("Species"), None),
372
+ prompt=True,
373
+ )
374
+ @click.option(
375
+ "-s",
376
+ "--species",
377
+ "model_species",
378
+ help="Species of the model to filter for.",
379
+ )
380
+ @click.option(
381
+ "-i",
382
+ "--input-path",
383
+ help="Path to FASTA or FASTQ file for classification.",
384
+ type=click.Path(exists=True, dir_okay=True, file_okay=True),
385
+ prompt=True,
386
+ )
387
+ @click.option(
388
+ "-o",
389
+ "--output-path",
390
+ help="Path to the output file.",
391
+ type=click.Path(dir_okay=True, file_okay=True),
392
+ prompt=True,
393
+ )
394
+ @click.option(
395
+ "--threshold",
396
+ type=float,
397
+ help="Threshold for filtering (default: 0.7). Use -1 to filter for the highest scoring species.",
398
+ default=0.7,
399
+ prompt=True,
400
+ )
401
+ def filter_species(model_genus, model_species, input_path, output_path, threshold):
402
+ """Filter a sample using the species model."""
403
+
404
+ available_species = get_model_metadata(f"{model_genus}-species")["display_names"]
405
+ available_species = {
406
+ id: name.replace(f"{model_genus} ", "")
407
+ for id, name in available_species.items()
408
+ }
409
+ if not model_species:
410
+ sorted_available_species = sorted(available_species.values())
411
+ model_species = click.prompt(
412
+ f"Please enter the species name: {model_genus}",
413
+ type=click.Choice(sorted_available_species, case_sensitive=False),
414
+ )
415
+ if model_species not in available_species.values():
416
+ raise click.BadParameter(
417
+ f"Species '{model_species}' not found in the {model_genus} species model."
418
+ )
419
+
420
+ # get the species ID from the name
421
+ model_species = [
422
+ id
423
+ for id, name in available_species.items()
424
+ if name.lower() == model_species.lower()
425
+ ][0]
426
+
427
+ click.echo("Filtering...")
428
+ species_model = get_species_model(model_genus)
429
+ result = species_model.predict(Path(input_path))
430
+ included_ids = result.get_filtered_subsequence_labels(model_species, threshold)
431
+ if not included_ids:
432
+ click.echo("No sequences found for the given species.")
433
+ return
434
+ filter_sequences(
435
+ Path(input_path),
436
+ Path(output_path),
437
+ included_ids=included_ids,
438
+ )
439
+ click.echo(f"Filtered sequences saved at {output_path}.")
178
440
 
179
441
 
180
442
  if __name__ == "__main__":
@@ -144,12 +144,10 @@ class MlstResult:
144
144
  }
145
145
  return result
146
146
 
147
- def save(self, display: str, file_path: Path) -> None:
148
- """Saves the result inside the "runs" directory"""
149
- file_name = str(file_path).split("/")[-1]
150
- json_path = get_xspect_runs_path() / "MLST" / f"{file_name}-{display}.json"
151
- json_path.parent.mkdir(exist_ok=True, parents=True)
147
+ def save(self, output_path: Path) -> None:
148
+ """Saves the result as a JSON file."""
149
+ output_path.parent.mkdir(exist_ok=True, parents=True)
152
150
  json_object = json.dumps(self.to_dict(), indent=4)
153
151
 
154
- with open(json_path, "w", encoding="utf-8") as file:
152
+ with open(output_path, "w", encoding="utf-8") as file:
155
153
  file.write(json_object)
@@ -2,7 +2,6 @@
2
2
 
3
3
  from json import loads, dumps
4
4
  from pathlib import Path
5
- from xspect.models.probabilistic_filter_model import ProbabilisticFilterModel
6
5
  from xspect.models.probabilistic_single_filter_model import (
7
6
  ProbabilisticSingleFilterModel,
8
7
  )
@@ -24,23 +23,10 @@ def get_species_model(genus):
24
23
  return species_filter_model
25
24
 
26
25
 
27
- def get_model_by_slug(model_slug: str):
28
- """Get a model by its slug."""
29
- model_path = get_xspect_model_path() / (model_slug + ".json")
30
- model_metadata = get_model_metadata(model_path)
31
- if model_metadata["model_class"] == "ProbabilisticSingleFilterModel":
32
- return ProbabilisticSingleFilterModel.load(model_path)
33
- if model_metadata["model_class"] == "ProbabilisticFilterSVMModel":
34
- return ProbabilisticFilterSVMModel.load(model_path)
35
- if model_metadata["model_class"] == "ProbabilisticFilterModel":
36
- return ProbabilisticFilterModel.load(model_path)
37
- raise ValueError(f"Model class {model_metadata['model_class']} not recognized.")
38
-
39
-
40
26
  def get_model_metadata(model: str | Path):
41
27
  """Get the metadata of a model."""
42
28
  if isinstance(model, str):
43
- model_path = get_xspect_model_path() / (model + ".json")
29
+ model_path = get_xspect_model_path() / (model.lower() + ".json")
44
30
  elif isinstance(model, Path):
45
31
  model_path = model
46
32
  else:
@@ -85,3 +71,9 @@ def get_models():
85
71
  model_metadata["model_display_name"]
86
72
  )
87
73
  return model_dict
74
+
75
+
76
+ def get_model_display_names(model_slug: str):
77
+ """Get the display names included in a model."""
78
+ model_metadata = get_model_metadata(model_slug)
79
+ return list(model_metadata["display_names"].values())
@@ -26,6 +26,7 @@ class ProbabilisticFilterModel:
26
26
  base_path: Path,
27
27
  fpr: float = 0.01,
28
28
  num_hashes: int = 7,
29
+ training_accessions: dict[str, list[str]] = None,
29
30
  ) -> None:
30
31
  if k < 1:
31
32
  raise ValueError("Invalid k value, must be greater than 0")
@@ -46,6 +47,7 @@ class ProbabilisticFilterModel:
46
47
  self.fpr = fpr
47
48
  self.num_hashes = num_hashes
48
49
  self.index = None
50
+ self.training_accessions = training_accessions
49
51
 
50
52
  def get_cobs_index_path(self) -> Path:
51
53
  """Returns the path to the cobs index"""
@@ -63,13 +65,19 @@ class ProbabilisticFilterModel:
63
65
  "display_names": self.display_names,
64
66
  "fpr": self.fpr,
65
67
  "num_hashes": self.num_hashes,
68
+ "training_accessions": self.training_accessions,
66
69
  }
67
70
 
68
71
  def slug(self) -> str:
69
72
  """Returns a slug representation of the model"""
70
73
  return slugify(self.model_display_name + "-" + str(self.model_type))
71
74
 
72
- def fit(self, dir_path: Path, display_names: dict = None) -> None:
75
+ def fit(
76
+ self,
77
+ dir_path: Path,
78
+ display_names: dict = None,
79
+ training_accessions: dict[str, list[str]] = None,
80
+ ) -> None:
73
81
  """Adds filters to the model"""
74
82
 
75
83
  if display_names is None:
@@ -84,16 +92,18 @@ class ProbabilisticFilterModel:
84
92
  if not dir_path.is_dir():
85
93
  raise ValueError("Directory path must be a directory")
86
94
 
95
+ self.training_accessions = training_accessions
96
+
87
97
  doclist = cobs.DocumentList()
88
98
  for file in dir_path.iterdir():
89
99
  if file.is_file() and file.suffix[1:] in fasta_endings + fastq_endings:
90
100
  # cobs only uses the file name to the first "." as the document name
91
- if file.name in display_names:
92
- self.display_names[file.name.split(".")[0]] = display_names[
93
- file.name
101
+ if file.stem in display_names:
102
+ self.display_names[file.stem.split(".")[0]] = display_names[
103
+ file.stem
94
104
  ]
95
105
  else:
96
- self.display_names[file.name.split(".")[0]] = file.stem
106
+ self.display_names[file.stem.split(".")[0]] = file.stem
97
107
  doclist.add(str(file))
98
108
 
99
109
  if len(doclist) == 0:
@@ -200,6 +210,7 @@ class ProbabilisticFilterModel:
200
210
  path.parent,
201
211
  model_json["fpr"],
202
212
  model_json["num_hashes"],
213
+ model_json["training_accessions"],
203
214
  )
204
215
  model.display_names = model_json["display_names"]
205
216