XspecT 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of XspecT might be problematic. Click here for more details.

Files changed (80) hide show
  1. xspect/classify.py +32 -0
  2. xspect/file_io.py +3 -9
  3. xspect/filter_sequences.py +56 -0
  4. xspect/main.py +52 -30
  5. xspect/mlst_feature/mlst_helper.py +102 -13
  6. xspect/mlst_feature/pub_mlst_handler.py +32 -6
  7. xspect/model_management.py +1 -15
  8. xspect/models/probabilistic_filter_mlst_model.py +160 -32
  9. xspect/models/probabilistic_filter_model.py +1 -0
  10. xspect/models/result.py +18 -6
  11. xspect/ncbi.py +8 -6
  12. xspect/train.py +13 -5
  13. xspect/web.py +173 -0
  14. xspect/xspect-web/.gitignore +24 -0
  15. xspect/xspect-web/README.md +54 -0
  16. xspect/xspect-web/components.json +21 -0
  17. xspect/xspect-web/dist/assets/index-CMG4V7fZ.js +290 -0
  18. xspect/xspect-web/dist/assets/index-jIKg1HIy.css +1 -0
  19. xspect/xspect-web/dist/index.html +14 -0
  20. xspect/xspect-web/dist/vite.svg +1 -0
  21. xspect/xspect-web/eslint.config.js +28 -0
  22. xspect/xspect-web/index.html +13 -0
  23. xspect/xspect-web/package-lock.json +6865 -0
  24. xspect/xspect-web/package.json +58 -0
  25. xspect/xspect-web/pnpm-lock.yaml +4317 -0
  26. xspect/xspect-web/public/vite.svg +1 -0
  27. xspect/xspect-web/src/App.tsx +29 -0
  28. xspect/xspect-web/src/api.tsx +62 -0
  29. xspect/xspect-web/src/assets/react.svg +1 -0
  30. xspect/xspect-web/src/components/classification-form.tsx +284 -0
  31. xspect/xspect-web/src/components/classify.tsx +18 -0
  32. xspect/xspect-web/src/components/data-table.tsx +78 -0
  33. xspect/xspect-web/src/components/dropdown-checkboxes.tsx +63 -0
  34. xspect/xspect-web/src/components/dropdown-slider.tsx +42 -0
  35. xspect/xspect-web/src/components/filter-form.tsx +423 -0
  36. xspect/xspect-web/src/components/filter.tsx +15 -0
  37. xspect/xspect-web/src/components/header.tsx +46 -0
  38. xspect/xspect-web/src/components/landing.tsx +7 -0
  39. xspect/xspect-web/src/components/models-details.tsx +138 -0
  40. xspect/xspect-web/src/components/models.tsx +53 -0
  41. xspect/xspect-web/src/components/result-chart.tsx +44 -0
  42. xspect/xspect-web/src/components/result.tsx +155 -0
  43. xspect/xspect-web/src/components/spinner.tsx +30 -0
  44. xspect/xspect-web/src/components/ui/accordion.tsx +64 -0
  45. xspect/xspect-web/src/components/ui/button.tsx +59 -0
  46. xspect/xspect-web/src/components/ui/card.tsx +92 -0
  47. xspect/xspect-web/src/components/ui/chart.tsx +351 -0
  48. xspect/xspect-web/src/components/ui/command.tsx +175 -0
  49. xspect/xspect-web/src/components/ui/dialog.tsx +135 -0
  50. xspect/xspect-web/src/components/ui/dropdown-menu.tsx +255 -0
  51. xspect/xspect-web/src/components/ui/file-upload.tsx +1459 -0
  52. xspect/xspect-web/src/components/ui/form.tsx +165 -0
  53. xspect/xspect-web/src/components/ui/input.tsx +21 -0
  54. xspect/xspect-web/src/components/ui/label.tsx +24 -0
  55. xspect/xspect-web/src/components/ui/navigation-menu.tsx +168 -0
  56. xspect/xspect-web/src/components/ui/popover.tsx +46 -0
  57. xspect/xspect-web/src/components/ui/select.tsx +183 -0
  58. xspect/xspect-web/src/components/ui/separator.tsx +26 -0
  59. xspect/xspect-web/src/components/ui/slider.tsx +61 -0
  60. xspect/xspect-web/src/components/ui/switch.tsx +29 -0
  61. xspect/xspect-web/src/components/ui/table.tsx +113 -0
  62. xspect/xspect-web/src/components/ui/tabs.tsx +64 -0
  63. xspect/xspect-web/src/index.css +120 -0
  64. xspect/xspect-web/src/lib/utils.ts +6 -0
  65. xspect/xspect-web/src/main.tsx +10 -0
  66. xspect/xspect-web/src/types.tsx +34 -0
  67. xspect/xspect-web/src/utils.tsx +6 -0
  68. xspect/xspect-web/src/vite-env.d.ts +1 -0
  69. xspect/xspect-web/tsconfig.app.json +32 -0
  70. xspect/xspect-web/tsconfig.json +13 -0
  71. xspect/xspect-web/tsconfig.node.json +24 -0
  72. xspect/xspect-web/vite.config.ts +24 -0
  73. {xspect-0.4.0.dist-info → xspect-0.5.0.dist-info}/METADATA +7 -8
  74. xspect-0.5.0.dist-info/RECORD +85 -0
  75. {xspect-0.4.0.dist-info → xspect-0.5.0.dist-info}/WHEEL +1 -1
  76. xspect/fastapi.py +0 -102
  77. xspect-0.4.0.dist-info/RECORD +0 -24
  78. {xspect-0.4.0.dist-info → xspect-0.5.0.dist-info}/entry_points.txt +0 -0
  79. {xspect-0.4.0.dist-info → xspect-0.5.0.dist-info}/licenses/LICENSE +0 -0
  80. {xspect-0.4.0.dist-info → xspect-0.5.0.dist-info}/top_level.txt +0 -0
xspect/classify.py ADDED
@@ -0,0 +1,32 @@
1
+ from pathlib import Path
2
+ from xspect.mlst_feature.mlst_helper import pick_scheme_from_models_dir
3
+ import xspect.model_management as mm
4
+ from xspect.models.probabilistic_filter_mlst_model import (
5
+ ProbabilisticFilterMlstSchemeModel,
6
+ )
7
+
8
+
9
+ def classify_genus(
10
+ model_genus: str, input_path: Path, output_path: Path, step: int = 1
11
+ ):
12
+ """Classify the input file using the genus model."""
13
+ model = mm.get_genus_model(model_genus)
14
+ result = model.predict(input_path, step=step)
15
+ result.input_source = input_path.name
16
+ result.save(output_path)
17
+
18
+
19
+ def classify_species(model_genus, input_path, output_path, step=1):
20
+ """Classify the input file using the species model."""
21
+ model = mm.get_species_model(model_genus)
22
+ result = model.predict(input_path, step=step)
23
+ result.input_source = input_path.name
24
+ result.save(output_path)
25
+
26
+
27
+ def classify_mlst(input_path, output_path):
28
+ """Classify the input file using the MLST model."""
29
+ scheme_path = pick_scheme_from_models_dir()
30
+ model = ProbabilisticFilterMlstSchemeModel.load(scheme_path)
31
+ result = model.predict(scheme_path, input_path)
32
+ result.save(output_path)
xspect/file_io.py CHANGED
@@ -20,17 +20,11 @@ def delete_zip_files(dir_path):
20
20
 
21
21
 
22
22
  def extract_zip(zip_path: Path, unzipped_path: Path):
23
- """Extracts all files from a directory with zip files."""
24
- # Make new directory.
23
+ """Extracts all files from a zip file."""
25
24
  unzipped_path.mkdir(parents=True, exist_ok=True)
26
25
 
27
- file_names = os.listdir(zip_path)
28
- for file in file_names:
29
- file_path = zip_path / file
30
- if zipfile.is_zipfile(file_path):
31
- with zipfile.ZipFile(file_path) as item:
32
- directory = unzipped_path / file.replace(".zip", "")
33
- item.extractall(directory)
26
+ with zipfile.ZipFile(zip_path) as item:
27
+ item.extractall(unzipped_path)
34
28
 
35
29
 
36
30
  def concatenate_meta(path: Path, genus: str):
@@ -0,0 +1,56 @@
1
+ from pathlib import Path
2
+ from xspect.model_management import get_genus_model, get_species_model
3
+ from xspect.file_io import filter_sequences
4
+
5
+
6
+ def filter_species(
7
+ model_genus: str,
8
+ model_species: str,
9
+ input_path: Path,
10
+ output_path: Path,
11
+ threshold: float,
12
+ ):
13
+ """Filter sequences by species.
14
+ This function filters sequences from the input file based on the species model.
15
+ It uses the genus model to identify the genus of the sequences and then applies
16
+ the species model to filter the sequences.
17
+
18
+ Args:
19
+ model_genus (str): The genus model slug.
20
+ model_species (str): The species model slug.
21
+ input_path (Path): The path to the input file containing sequences.
22
+ output_path (Path): The path to the output file where filtered sequences will be saved.
23
+ threshold (float): The threshold for filtering sequences. Only sequences with a score
24
+ above this threshold will be included in the output file.
25
+ """
26
+ species_model = get_species_model(model_genus)
27
+ result = species_model.predict(input_path)
28
+ included_ids = result.get_filtered_subsequence_labels(model_species, threshold)
29
+ if not included_ids:
30
+ print("No sequences found for the given species.")
31
+ return
32
+ filter_sequences(
33
+ input_path,
34
+ output_path,
35
+ included_ids,
36
+ )
37
+
38
+
39
+ def filter_genus(
40
+ model_genus: str,
41
+ input_path: Path,
42
+ output_path: Path,
43
+ threshold: float,
44
+ ):
45
+ genus_model = get_genus_model(model_genus)
46
+ result = genus_model.predict(Path(input_path))
47
+ included_ids = result.get_filtered_subsequence_labels(model_genus, threshold)
48
+ if not included_ids:
49
+ print("No sequences found for the given genus.")
50
+ return
51
+
52
+ filter_sequences(
53
+ input_path,
54
+ output_path,
55
+ included_ids,
56
+ )
xspect/main.py CHANGED
@@ -4,7 +4,8 @@ from pathlib import Path
4
4
  from uuid import uuid4
5
5
  import click
6
6
  import uvicorn
7
- from xspect import fastapi
7
+ from xspect import classify
8
+ from xspect.web import app
8
9
  from xspect.download_models import download_test_models
9
10
  from xspect.file_io import filter_sequences
10
11
  from xspect.train import train_from_directory, train_from_ncbi
@@ -18,6 +19,7 @@ from xspect.models.probabilistic_filter_mlst_model import (
18
19
  )
19
20
  from xspect.model_management import (
20
21
  get_genus_model,
22
+ get_model_metadata,
21
23
  get_models,
22
24
  get_species_model,
23
25
  )
@@ -32,7 +34,7 @@ def cli():
32
34
  @cli.command()
33
35
  def web():
34
36
  """Open the XspecT web application."""
35
- uvicorn.run(fastapi.app, host="0.0.0.0", port=8000)
37
+ uvicorn.run(app, host="0.0.0.0", port=8000)
36
38
 
37
39
 
38
40
  # # # # # # # # # # # # # # #
@@ -41,7 +43,6 @@ def web():
41
43
  @cli.group()
42
44
  def models():
43
45
  """Model management commands."""
44
- pass
45
46
 
46
47
 
47
48
  @models.command(
@@ -50,7 +51,7 @@ def models():
50
51
  def download():
51
52
  """Download models."""
52
53
  click.echo("Downloading models, this may take a while...")
53
- download_test_models("http://assets.adrianromberg.com/xspect-models.zip")
54
+ download_test_models("http://assets.adrianromberg.com/ake/xspect-models.zip")
54
55
 
55
56
 
56
57
  @models.command(
@@ -77,7 +78,6 @@ def list_models():
77
78
  @models.group()
78
79
  def train():
79
80
  """Train models."""
80
- pass
81
81
 
82
82
 
83
83
  @train.command(
@@ -191,16 +191,18 @@ def train_mlst(choose_schemes):
191
191
  )
192
192
  def classify_seqs():
193
193
  """Classification commands."""
194
- pass
195
194
 
196
195
 
197
- @classify_seqs.command()
196
+ @classify_seqs.command(
197
+ name="genus",
198
+ help="Classify samples using a genus model.",
199
+ )
198
200
  @click.option(
199
201
  "-g",
200
202
  "--genus",
201
203
  "model_genus",
202
204
  help="Genus of the model to classify.",
203
- type=click.Choice(get_models().get("Genus"), None),
205
+ type=click.Choice(get_models().get("Genus", [])),
204
206
  prompt=True,
205
207
  )
206
208
  @click.option(
@@ -217,22 +219,23 @@ def classify_seqs():
217
219
  type=click.Path(dir_okay=True, file_okay=True),
218
220
  default=Path(".") / f"result_{uuid4()}.json",
219
221
  )
220
- def genus(model_genus, input_path, output_path):
222
+ def classify_genus(model_genus, input_path, output_path):
221
223
  """Classify samples using a genus model."""
222
224
  click.echo("Classifying...")
223
- genus_model = get_genus_model(model_genus)
224
- result = genus_model.predict(Path(input_path))
225
- result.save(output_path)
225
+ classify.classify_genus(model_genus, Path(input_path), Path(output_path))
226
226
  click.echo(f"Result saved as {output_path}.")
227
227
 
228
228
 
229
- @classify_seqs.command()
229
+ @classify_seqs.command(
230
+ name="species",
231
+ help="Classify samples using a species model.",
232
+ )
230
233
  @click.option(
231
234
  "-g",
232
235
  "--genus",
233
236
  "model_genus",
234
237
  help="Genus of the model to classify.",
235
- type=click.Choice(get_models().get("Species"), None),
238
+ type=click.Choice(get_models().get("Species", [])),
236
239
  prompt=True,
237
240
  )
238
241
  @click.option(
@@ -252,15 +255,15 @@ def genus(model_genus, input_path, output_path):
252
255
  @click.option(
253
256
  "--sparse-sampling-step",
254
257
  type=int,
255
- help="Sparse sampling step size (e. g. only every 500th kmer for '--sparse-sampling-step 500').",
258
+ help="Sparse sampling step (e. g. only every 500th kmer for '--sparse-sampling-step 500').",
256
259
  default=1,
257
260
  )
258
- def species(model_genus, input_path, output_path, sparse_sampling_step):
261
+ def classify_species(model_genus, input_path, output_path, sparse_sampling_step):
259
262
  """Classify samples using a species model."""
260
263
  click.echo("Classifying...")
261
- species_model = get_species_model(model_genus)
262
- result = species_model.predict(Path(input_path), step=sparse_sampling_step)
263
- result.save(output_path)
264
+ classify.classify_species(
265
+ model_genus, Path(input_path), Path(output_path), sparse_sampling_step
266
+ )
264
267
  click.echo(f"Result saved as {output_path}.")
265
268
 
266
269
 
@@ -285,11 +288,7 @@ def species(model_genus, input_path, output_path, sparse_sampling_step):
285
288
  def classify_mlst(input_path, output_path):
286
289
  """MLST classify a sample."""
287
290
  click.echo("Classifying...")
288
- input_path = Path(input_path)
289
- scheme_path = pick_scheme_from_models_dir()
290
- model = ProbabilisticFilterMlstSchemeModel.load(scheme_path)
291
- result = model.predict(scheme_path, input_path)
292
- result.save(output_path)
291
+ classify.classify_mlst(Path(input_path), Path(output_path))
293
292
  click.echo(f"Result saved as {output_path}.")
294
293
 
295
294
 
@@ -302,7 +301,6 @@ def classify_mlst(input_path, output_path):
302
301
  )
303
302
  def filter_seqs():
304
303
  """Filter commands."""
305
- pass
306
304
 
307
305
 
308
306
  @filter_seqs.command(
@@ -314,7 +312,7 @@ def filter_seqs():
314
312
  "--genus",
315
313
  "model_genus",
316
314
  help="Genus of the model to use for filtering.",
317
- type=click.Choice(get_models().get("Species"), None),
315
+ type=click.Choice(get_models().get("Species", [])),
318
316
  prompt=True,
319
317
  )
320
318
  @click.option(
@@ -336,6 +334,7 @@ def filter_seqs():
336
334
  type=float,
337
335
  help="Threshold for filtering (default: 0.7).",
338
336
  default=0.7,
337
+ prompt=True,
339
338
  )
340
339
  def filter_genus(model_genus, input_path, output_path, threshold):
341
340
  """Filter samples using a genus model."""
@@ -364,16 +363,14 @@ def filter_genus(model_genus, input_path, output_path, threshold):
364
363
  "--genus",
365
364
  "model_genus",
366
365
  help="Genus of the model to use for filtering.",
367
- type=click.Choice(get_models().get("Species"), None),
366
+ type=click.Choice(get_models().get("Species", [])),
368
367
  prompt=True,
369
368
  )
370
369
  @click.option(
371
- # todo: this should be a choice of the species in the model w/ display names
372
370
  "-s",
373
371
  "--species",
374
372
  "model_species",
375
373
  help="Species of the model to filter for.",
376
- prompt=True,
377
374
  )
378
375
  @click.option(
379
376
  "-i",
@@ -392,11 +389,36 @@ def filter_genus(model_genus, input_path, output_path, threshold):
392
389
  @click.option(
393
390
  "--threshold",
394
391
  type=float,
395
- help="Threshold for filtering (default: 0.7).",
392
+ help="Threshold for filtering (default: 0.7). Use -1 to filter for the highest scoring species.",
396
393
  default=0.7,
394
+ prompt=True,
397
395
  )
398
396
  def filter_species(model_genus, model_species, input_path, output_path, threshold):
399
397
  """Filter a sample using the species model."""
398
+
399
+ available_species = get_model_metadata(f"{model_genus}-species")["display_names"]
400
+ available_species = {
401
+ id: name.replace(f"{model_genus} ", "")
402
+ for id, name in available_species.items()
403
+ }
404
+ if not model_species:
405
+ sorted_available_species = sorted(available_species.values())
406
+ model_species = click.prompt(
407
+ f"Please enter the species name: {model_genus}",
408
+ type=click.Choice(sorted_available_species, case_sensitive=False),
409
+ )
410
+ if model_species not in available_species.values():
411
+ raise click.BadParameter(
412
+ f"Species '{model_species}' not found in the {model_genus} species model."
413
+ )
414
+
415
+ # get the species ID from the name
416
+ model_species = [
417
+ id
418
+ for id, name in available_species.items()
419
+ if name.lower() == model_species.lower()
420
+ ][0]
421
+
400
422
  click.echo("Filtering...")
401
423
  species_model = get_species_model(model_genus)
402
424
  result = species_model.predict(Path(input_path))
@@ -7,11 +7,22 @@ import json
7
7
  from io import StringIO
8
8
  from pathlib import Path
9
9
  from Bio import SeqIO
10
- from xspect.definitions import get_xspect_model_path, get_xspect_runs_path
10
+ from xspect.definitions import get_xspect_model_path
11
11
 
12
12
 
13
- def create_fasta_files(locus_path: Path, fasta_batch: str):
14
- """Create Fasta-Files for every allele of a locus."""
13
+ def create_fasta_files(locus_path: Path, fasta_batch: str) -> None:
14
+ """
15
+ Create Fasta-Files for every allele of a locus.
16
+
17
+ This function creates a fasta file for each record in the batch-string of a locus.
18
+ The batch originates from an API-GET-request to PubMLST.
19
+ The files are named after the record ID.
20
+ If a fasta file already exists, it will be skipped.
21
+
22
+ Args:
23
+ locus_path (Path): The directory where the fasta-files will be saved.
24
+ fasta_batch (str): A string containing every record of a locus from PubMLST.
25
+ """
15
26
  # fasta_batch = full string of a fasta file containing every allele sequence of a locus
16
27
  for record in SeqIO.parse(StringIO(fasta_batch), "fasta"):
17
28
  number = record.id.split("_")[-1] # example id = Oxf_cpn60_263
@@ -23,7 +34,21 @@ def create_fasta_files(locus_path: Path, fasta_batch: str):
23
34
 
24
35
 
25
36
  def pick_species_number_from_db(available_species: dict) -> str:
26
- """Returns the chosen species from all available ones in the database."""
37
+ """
38
+ Get the chosen species from all available ones in the database.
39
+
40
+ This function lists all available species of PubMLST.
41
+ The user is then asked to pick a species by its associated number.
42
+
43
+ Args:
44
+ available_species (dict): A dictionary storing all available species.
45
+
46
+ Returns:
47
+ str: The name of the chosen species.
48
+
49
+ Raises:
50
+ ValueError: If the user input is not valid.
51
+ """
27
52
  # The "database" string can look like this: pubmlst_abaumannii_seqdef
28
53
  for counter, database in available_species.items():
29
54
  print(str(counter) + ":" + database.split("_")[1])
@@ -45,7 +70,21 @@ def pick_species_number_from_db(available_species: dict) -> str:
45
70
 
46
71
 
47
72
  def pick_scheme_number_from_db(available_schemes: dict) -> str:
48
- """Returns the chosen schemes from all available ones of a species."""
73
+ """
74
+ Get the chosen scheme from all available ones of a species.
75
+
76
+ This function lists all available schemes of a species.
77
+ The user is then asked to pick a scheme by its associated number.
78
+
79
+ Args:
80
+ available_schemes (dict): A dictionary storing all available schemes.
81
+
82
+ Returns:
83
+ str: The name of the chosen scheme.
84
+
85
+ Raises:
86
+ ValueError: If the user input is not valid.
87
+ """
49
88
  # List all available schemes of a species database
50
89
  for counter, scheme in available_schemes.items():
51
90
  print(str(counter) + ":" + scheme[0])
@@ -67,12 +106,28 @@ def pick_scheme_number_from_db(available_schemes: dict) -> str:
67
106
 
68
107
 
69
108
  def scheme_list_to_dict(scheme_list: list[str]):
70
- """Converts the scheme list attribute into a dictionary with a number as the key."""
109
+ """
110
+ Converts the scheme list into a dictionary.
111
+
112
+ Args:
113
+ scheme_list (list[str]): A list storing all chosen schemes.
114
+
115
+ Returns:
116
+ dict: The converted dictionary.
117
+ """
71
118
  return dict(zip(range(1, len(scheme_list) + 1), scheme_list))
72
119
 
73
120
 
74
121
  def pick_scheme_from_models_dir() -> Path:
75
- """Returns the chosen scheme from models that have been fitted prior."""
122
+ """
123
+ Get the chosen scheme from models that have been fitted prior.
124
+
125
+ This function creates a dictionary containing all trained models.
126
+ The dictionary is used as an argument for the "pick_scheme" function.
127
+
128
+ Returns:
129
+ Path: The path to the chosen model (trained).
130
+ """
76
131
  schemes = {}
77
132
  counter = 1
78
133
  for entry in sorted((get_xspect_model_path() / "MLST").iterdir()):
@@ -82,7 +137,21 @@ def pick_scheme_from_models_dir() -> Path:
82
137
 
83
138
 
84
139
  def pick_scheme(available_schemes: dict) -> Path:
85
- """Returns the chosen scheme from the scheme list."""
140
+ """
141
+ Get the chosen scheme from the scheme dictionary.
142
+
143
+ This function lists all available schemes of a species that have been downloaded.
144
+ The user is then asked to pick a scheme by its associated number.
145
+
146
+ Args:
147
+ available_schemes (dict): A dictionary storing all available schemes.
148
+
149
+ Returns:
150
+ Path: The path to the chosen model (trained).
151
+
152
+ Raises:
153
+ ValueError: If the user input is not valid or if no scheme was downloaded prior.
154
+ """
86
155
  if not available_schemes:
87
156
  raise ValueError("No scheme has been chosen for download yet!")
88
157
 
@@ -118,7 +187,7 @@ def pick_scheme(available_schemes: dict) -> Path:
118
187
 
119
188
 
120
189
  class MlstResult:
121
- """Class for storing mlst results."""
190
+ """Class for storing MLST results."""
122
191
 
123
192
  def __init__(
124
193
  self,
@@ -126,17 +195,28 @@ class MlstResult:
126
195
  steps: int,
127
196
  hits: dict[str, list[dict]],
128
197
  ):
198
+ """Initialise an MlstResult object."""
129
199
  self.scheme_model = scheme_model
130
200
  self.steps = steps
131
201
  self.hits = hits
132
202
 
133
203
  def get_results(self) -> dict:
134
- """Stores the result of a prediction in a dictionary."""
204
+ """
205
+ Stores the result of a prediction in a dictionary.
206
+
207
+ Returns:
208
+ dict: The result dictionary with s sequence ID as key and the Strain type as value.
209
+ """
135
210
  results = {seq_id: result for seq_id, result in self.hits.items()}
136
211
  return results
137
212
 
138
213
  def to_dict(self) -> dict:
139
- """Converts all attributes into one dictionary."""
214
+ """
215
+ Converts all attributes into one dictionary.
216
+
217
+ Returns:
218
+ dict: The dictionary containing all metadata of a run.
219
+ """
140
220
  result = {
141
221
  "Scheme": self.scheme_model,
142
222
  "Steps": self.steps,
@@ -144,8 +224,17 @@ class MlstResult:
144
224
  }
145
225
  return result
146
226
 
147
- def save(self, output_path: Path) -> None:
148
- """Saves the result as a JSON file."""
227
+ def save(self, output_path: Path | str) -> None:
228
+ """
229
+ Saves the result as a JSON file.
230
+
231
+ Args:
232
+ output_path (Path,str): The path where the results are saved.
233
+ """
234
+
235
+ if isinstance(output_path, str):
236
+ output_path = Path(output_path)
237
+
149
238
  output_path.parent.mkdir(exist_ok=True, parents=True)
150
239
  json_object = json.dumps(self.to_dict(), indent=4)
151
240
 
@@ -20,6 +20,7 @@ class PubMLSTHandler:
20
20
  base_url = "http://rest.pubmlst.org/db"
21
21
 
22
22
  def __init__(self):
23
+ """Initialise a PubMLSTHandler object."""
23
24
  # Default values: Oxford (1) and Pasteur (2) schemes of A.baumannii species
24
25
  self.scheme_list = [
25
26
  self.base_url + "/pubmlst_abaumannii_seqdef/schemes/1",
@@ -28,11 +29,21 @@ class PubMLSTHandler:
28
29
  self.scheme_paths = []
29
30
 
30
31
  def get_scheme_paths(self) -> dict:
31
- """Returns the scheme paths in a dictionary"""
32
+ """
33
+ Get the scheme paths in a dictionary.
34
+
35
+ Returns:
36
+ dict: A dictionary containing the scheme paths.
37
+ """
32
38
  return scheme_list_to_dict(self.scheme_paths)
33
39
 
34
40
  def choose_schemes(self) -> None:
35
- """Changes the scheme list attribute to feature other schemes from some species"""
41
+ """
42
+ Changes the scheme list attribute to feature other schemes from another species.
43
+
44
+ This function lets the user pick schemes to download all alleles that belong to it.
45
+ The scheme has to be available in the database.
46
+ """
36
47
  available_species = {}
37
48
  available_schemes = {}
38
49
  chosen_schemes = []
@@ -70,8 +81,17 @@ class PubMLSTHandler:
70
81
  break
71
82
  self.scheme_list = chosen_schemes
72
83
 
73
- def download_alleles(self, choice: False):
74
- """Downloads every allele FASTA-file from all loci of the scheme list attribute"""
84
+ def download_alleles(self, choice: False) -> None:
85
+ """
86
+ Downloads every allele FASTA-file from all loci of the scheme list attribute.
87
+
88
+ This function sends API-GET requests to PubMLST.
89
+ It downloads all alleles based on the scheme_list attribute.
90
+ The default schemes are the Oxford and Pasteur schemes of A.baumannii
91
+
92
+ Args:
93
+ choice (bool): The decision to download different schemes, defaults to False.
94
+ """
75
95
  if choice: # pick an own scheme if not Oxford or Pasteur
76
96
  self.choose_schemes() # changes the scheme_list attribute
77
97
 
@@ -98,8 +118,14 @@ class PubMLSTHandler:
98
118
  alleles = requests.get(f"{locus_url}/alleles_fasta").text
99
119
  create_fasta_files(locus_path, alleles)
100
120
 
101
- def assign_strain_type_by_db(self):
102
- """Sends an API-POST-Request to the database for MLST without bloom filters"""
121
+ def assign_strain_type_by_db(self) -> None:
122
+ """
123
+ Sends an API-POST-Request to the database for MLST without bloom filters.
124
+
125
+ This function sends API-POST requests to PubMLST.
126
+ It is a different way to determine strain types based on a BLAST-Search.
127
+ This function is only used for testing and comparing results.
128
+ """
103
129
  scheme_url = (
104
130
  str(pick_scheme(scheme_list_to_dict(self.scheme_list))) + "/sequence"
105
131
  )
@@ -2,7 +2,6 @@
2
2
 
3
3
  from json import loads, dumps
4
4
  from pathlib import Path
5
- from xspect.models.probabilistic_filter_model import ProbabilisticFilterModel
6
5
  from xspect.models.probabilistic_single_filter_model import (
7
6
  ProbabilisticSingleFilterModel,
8
7
  )
@@ -24,23 +23,10 @@ def get_species_model(genus):
24
23
  return species_filter_model
25
24
 
26
25
 
27
- def get_model_by_slug(model_slug: str):
28
- """Get a model by its slug."""
29
- model_path = get_xspect_model_path() / (model_slug + ".json")
30
- model_metadata = get_model_metadata(model_path)
31
- if model_metadata["model_class"] == "ProbabilisticSingleFilterModel":
32
- return ProbabilisticSingleFilterModel.load(model_path)
33
- if model_metadata["model_class"] == "ProbabilisticFilterSVMModel":
34
- return ProbabilisticFilterSVMModel.load(model_path)
35
- if model_metadata["model_class"] == "ProbabilisticFilterModel":
36
- return ProbabilisticFilterModel.load(model_path)
37
- raise ValueError(f"Model class {model_metadata['model_class']} not recognized.")
38
-
39
-
40
26
  def get_model_metadata(model: str | Path):
41
27
  """Get the metadata of a model."""
42
28
  if isinstance(model, str):
43
- model_path = get_xspect_model_path() / (model + ".json")
29
+ model_path = get_xspect_model_path() / (model.lower() + ".json")
44
30
  elif isinstance(model, Path):
45
31
  model_path = model
46
32
  else: