XspecT 0.5.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of XspecT might be problematic. Click here for more details.

Files changed (33)
  1. xspect/classify.py +51 -38
  2. xspect/definitions.py +50 -10
  3. xspect/download_models.py +10 -2
  4. xspect/file_io.py +115 -48
  5. xspect/filter_sequences.py +36 -66
  6. xspect/main.py +41 -10
  7. xspect/mlst_feature/mlst_helper.py +3 -0
  8. xspect/mlst_feature/pub_mlst_handler.py +43 -1
  9. xspect/model_management.py +84 -14
  10. xspect/models/probabilistic_filter_mlst_model.py +75 -37
  11. xspect/models/probabilistic_filter_model.py +194 -12
  12. xspect/models/probabilistic_filter_svm_model.py +99 -6
  13. xspect/models/probabilistic_single_filter_model.py +66 -5
  14. xspect/models/result.py +77 -10
  15. xspect/ncbi.py +45 -10
  16. xspect/train.py +2 -1
  17. xspect/web.py +68 -12
  18. xspect/xspect-web/dist/assets/index-Ceo58xui.css +1 -0
  19. xspect/xspect-web/dist/assets/{index-CMG4V7fZ.js → index-Dt_UlbgE.js} +82 -77
  20. xspect/xspect-web/dist/index.html +2 -2
  21. xspect/xspect-web/src/App.tsx +4 -2
  22. xspect/xspect-web/src/api.tsx +23 -1
  23. xspect/xspect-web/src/components/filter-form.tsx +16 -3
  24. xspect/xspect-web/src/components/filtering-result.tsx +65 -0
  25. xspect/xspect-web/src/components/result.tsx +2 -2
  26. xspect/xspect-web/src/types.tsx +5 -0
  27. {xspect-0.5.1.dist-info → xspect-0.5.2.dist-info}/METADATA +1 -1
  28. {xspect-0.5.1.dist-info → xspect-0.5.2.dist-info}/RECORD +32 -31
  29. {xspect-0.5.1.dist-info → xspect-0.5.2.dist-info}/WHEEL +1 -1
  30. xspect/xspect-web/dist/assets/index-jIKg1HIy.css +0 -1
  31. {xspect-0.5.1.dist-info → xspect-0.5.2.dist-info}/entry_points.txt +0 -0
  32. {xspect-0.5.1.dist-info → xspect-0.5.2.dist-info}/licenses/LICENSE +0 -0
  33. {xspect-0.5.1.dist-info → xspect-0.5.2.dist-info}/top_level.txt +0 -0
xspect/main.py CHANGED
@@ -18,10 +18,8 @@ from xspect.models.probabilistic_filter_mlst_model import (
18
18
  ProbabilisticFilterMlstSchemeModel,
19
19
  )
20
20
  from xspect.model_management import (
21
- get_genus_model,
22
21
  get_model_metadata,
23
22
  get_models,
24
- get_species_model,
25
23
  )
26
24
 
27
25
 
@@ -173,8 +171,9 @@ def train_mlst(choose_schemes):
173
171
  scheme_path = pick_scheme(handler.get_scheme_paths())
174
172
  species_name = str(scheme_path).split("/")[-2]
175
173
  scheme_name = str(scheme_path).split("/")[-1]
174
+ scheme_url = handler.scheme_mapping[str(scheme_path)]
176
175
  model = ProbabilisticFilterMlstSchemeModel(
177
- 31, f"{species_name}:{scheme_name}", get_xspect_model_path()
176
+ 31, f"{species_name}:{scheme_name}", get_xspect_model_path(), scheme_url
178
177
  )
179
178
  click.echo("Creating mlst model")
180
179
  model.fit(scheme_path)
@@ -220,10 +219,18 @@ def classify_seqs():
220
219
  type=click.Path(dir_okay=False, file_okay=True),
221
220
  default=Path(".") / f"result_{uuid4()}.json",
222
221
  )
223
- def classify_genus(model_genus, input_path, output_path):
222
+ @click.option(
223
+ "--sparse-sampling-step",
224
+ type=int,
225
+ help="Sparse sampling step (e. g. only every 500th kmer for '--sparse-sampling-step 500').",
226
+ default=1,
227
+ )
228
+ def classify_genus(model_genus, input_path, output_path, sparse_sampling_step):
224
229
  """Classify samples using a genus model."""
225
230
  click.echo("Classifying...")
226
- classify.classify_genus(model_genus, Path(input_path), Path(output_path))
231
+ classify.classify_genus(
232
+ model_genus, Path(input_path), Path(output_path), sparse_sampling_step
233
+ )
227
234
 
228
235
 
229
236
  @classify_seqs.command(
@@ -275,20 +282,24 @@ def classify_species(model_genus, input_path, output_path, sparse_sampling_step)
275
282
  "-i",
276
283
  "--input-path",
277
284
  help="Path to FASTA-file for mlst identification.",
278
- type=click.Path(exists=True, dir_okay=False, file_okay=True),
285
+ type=click.Path(exists=True, dir_okay=True, file_okay=True),
279
286
  prompt=True,
287
+ default=Path("."),
280
288
  )
281
289
  @click.option(
282
290
  "-o",
283
291
  "--output-path",
284
292
  help="Path to the output file.",
285
293
  type=click.Path(dir_okay=False, file_okay=True),
294
+ default=Path(".") / f"MLST_result_{uuid4()}.json",
295
+ )
296
+ @click.option(
297
+ "-l", "--limit", is_flag=True, help="Limit the output to 5 results for each locus."
286
298
  )
287
- def classify_mlst(input_path, output_path):
299
+ def classify_mlst(input_path, output_path, limit):
288
300
  """MLST classify a sample."""
289
301
  click.echo("Classifying...")
290
- classify.classify_mlst(Path(input_path), Path(output_path))
291
- click.echo(f"Result saved as {output_path}.")
302
+ classify.classify_mlst(Path(input_path), Path(output_path), limit)
292
303
 
293
304
 
294
305
  # # # # # # # # # # # # # # #
@@ -343,8 +354,19 @@ def filter_seqs():
343
354
  default=0.7,
344
355
  prompt=True,
345
356
  )
357
+ @click.option(
358
+ "--sparse-sampling-step",
359
+ type=int,
360
+ help="Sparse sampling step (e. g. only every 500th kmer for '--sparse-sampling-step 500').",
361
+ default=1,
362
+ )
346
363
  def filter_genus(
347
- model_genus, input_path, output_path, classification_output_path, threshold
364
+ model_genus,
365
+ input_path,
366
+ output_path,
367
+ classification_output_path,
368
+ threshold,
369
+ sparse_sampling_step,
348
370
  ):
349
371
  """Filter samples using a genus model."""
350
372
  click.echo("Filtering...")
@@ -355,6 +377,7 @@ def filter_genus(
355
377
  Path(output_path),
356
378
  threshold,
357
379
  Path(classification_output_path) if classification_output_path else None,
380
+ sparse_sampling_step=sparse_sampling_step,
358
381
  )
359
382
 
360
383
 
@@ -405,6 +428,12 @@ def filter_genus(
405
428
  default=0.7,
406
429
  prompt=True,
407
430
  )
431
+ @click.option(
432
+ "--sparse-sampling-step",
433
+ type=int,
434
+ help="Sparse sampling step (e. g. only every 500th kmer for '--sparse-sampling-step 500').",
435
+ default=1,
436
+ )
408
437
  def filter_species(
409
438
  model_genus,
410
439
  model_species,
@@ -412,6 +441,7 @@ def filter_species(
412
441
  output_path,
413
442
  threshold,
414
443
  classification_output_path,
444
+ sparse_sampling_step,
415
445
  ):
416
446
  """Filter a sample using the species model."""
417
447
 
@@ -451,6 +481,7 @@ def filter_species(
451
481
  Path(output_path),
452
482
  threshold,
453
483
  Path(classification_output_path) if classification_output_path else None,
484
+ sparse_sampling_step=sparse_sampling_step,
454
485
  )
455
486
 
456
487
 
xspect/mlst_feature/mlst_helper.py CHANGED
@@ -194,11 +194,13 @@ class MlstResult:
194
194
  scheme_model: str,
195
195
  steps: int,
196
196
  hits: dict[str, list[dict]],
197
+ input_source: str = None,
197
198
  ):
198
199
  """Initialise an MlstResult object."""
199
200
  self.scheme_model = scheme_model
200
201
  self.steps = steps
201
202
  self.hits = hits
203
+ self.input_source = input_source
202
204
 
203
205
  def get_results(self) -> dict:
204
206
  """
@@ -221,6 +223,7 @@ class MlstResult:
221
223
  "Scheme": self.scheme_model,
222
224
  "Steps": self.steps,
223
225
  "Results": self.get_results(),
226
+ "Input_source": self.input_source,
224
227
  }
225
228
  return result
226
229
 
xspect/mlst_feature/pub_mlst_handler.py CHANGED
@@ -17,7 +17,7 @@ from xspect.definitions import get_xspect_mlst_path, get_xspect_upload_path
17
17
  class PubMLSTHandler:
18
18
  """Class for communicating with PubMLST and downloading alleles (FASTA-Format) from all loci."""
19
19
 
20
- base_url = "http://rest.pubmlst.org/db"
20
+ base_url = "https://rest.pubmlst.org/db"
21
21
 
22
22
  def __init__(self):
23
23
  """Initialise a PubMLSTHandler object."""
@@ -27,6 +27,7 @@ class PubMLSTHandler:
27
27
  self.base_url + "/pubmlst_abaumannii_seqdef/schemes/2",
28
28
  ]
29
29
  self.scheme_paths = []
30
+ self.scheme_mapping = {}
30
31
 
31
32
  def get_scheme_paths(self) -> dict:
32
33
  """
@@ -103,6 +104,7 @@ class PubMLSTHandler:
103
104
 
104
105
  species_name = scheme.split("_")[1] # name = pubmlst_abaumannii_seqdef
105
106
  scheme_path = get_xspect_mlst_path() / species_name / scheme_name
107
+ self.scheme_mapping[str(scheme_path)] = scheme
106
108
  self.scheme_paths.append(scheme_path)
107
109
 
108
110
  for locus_url in locus_list:
@@ -143,3 +145,43 @@ class PubMLSTHandler:
143
145
  # Example: 'Pas_fusA': [{'href': some URL, 'allele_id': '2'}]
144
146
  print(locus + ":" + meta_data[0]["allele_id"], end="; ")
145
147
  print("\nStrain Type:", response["fields"])
148
+
149
+ def get_strain_type_name(self, highest_results: dict, post_url: str) -> str:
150
+ """
151
+ Send an API-POST request to PubMLST with the highest result of each locus as payload.
152
+
153
+ This function formats the highest_result dict into an accepted input for the request.
154
+ It gets a response from the site which is the strain type name.
155
+ The name is based on the allele id with the highest score for each locus.
156
+ Example of post_url for the oxford scheme of A.baumannii:
157
+ https://rest.pubmlst.org/db/pubmlst_abaumannii_seqdef/schemes/1/designations
158
+
159
+ Args:
160
+ highest_results (dict): The allele ids with the highest kmer matches.
161
+ post_url (str): The specific url for the scheme of a species
162
+
163
+ Returns:
164
+ str: The response (ST name or No ST found) of the POST request.
165
+ """
166
+ payload = {
167
+ "designations": {
168
+ locus: [{"allele": str(allele)}]
169
+ for locus, allele in highest_results.items()
170
+ }
171
+ }
172
+
173
+ response = requests.post(post_url + "/designations", json=payload)
174
+
175
+ if response.status_code == 200:
176
+ data = response.json()
177
+ if "fields" in data:
178
+ post_response = data["fields"]
179
+ return post_response
180
+ else:
181
+ post_response = "No matching Strain Type found in the database. "
182
+ post_response += "Possibly a novel Strain Type."
183
+ return post_response
184
+ else:
185
+ post_response = "Error:" + str(response.status_code)
186
+ post_response += response.text
187
+ return post_response
xspect/model_management.py CHANGED
@@ -9,22 +9,55 @@ from xspect.models.probabilistic_filter_svm_model import ProbabilisticFilterSVMM
9
9
  from xspect.definitions import get_xspect_model_path
10
10
 
11
11
 
12
- def get_genus_model(genus):
13
- """Get a metagenomic model for the specified genus."""
12
+ def get_genus_model(genus) -> ProbabilisticSingleFilterModel:
13
+ """
14
+ Get a genus model for the specified genus.
15
+
16
+ This function retrieves a pre-trained genus classification model based on the provided genus name.
17
+
18
+ Args:
19
+ genus (str): The genus name for which the model is to be retrieved.
20
+
21
+ Returns:
22
+ ProbabilisticSingleFilterModel: An instance of the genus classification model.
23
+ """
14
24
  genus_model_path = get_xspect_model_path() / (genus.lower() + "-genus.json")
15
25
  genus_filter_model = ProbabilisticSingleFilterModel.load(genus_model_path)
16
26
  return genus_filter_model
17
27
 
18
28
 
19
- def get_species_model(genus):
20
- """Get a species classification model for the specified genus."""
29
+ def get_species_model(genus) -> ProbabilisticFilterSVMModel:
30
+ """
31
+ Get a species classification model for the specified genus.
32
+
33
+ This function retrieves a pre-trained species classification model based on the provided genus name.
34
+
35
+ Args:
36
+ genus (str): The genus name for which the species model is to be retrieved.
37
+
38
+ Returns:
39
+ ProbabilisticFilterSVMModel: An instance of the species classification model.
40
+ """
21
41
  species_model_path = get_xspect_model_path() / (genus.lower() + "-species.json")
22
42
  species_filter_model = ProbabilisticFilterSVMModel.load(species_model_path)
23
43
  return species_filter_model
24
44
 
25
45
 
26
- def get_model_metadata(model: str | Path):
27
- """Get the metadata of a model."""
46
+ def get_model_metadata(model: str | Path) -> dict:
47
+ """
48
+ Get metadata of a specified model.
49
+
50
+ This function retrieves the metadata of a model from its JSON file.
51
+
52
+ Args:
53
+ model (str | Path): The slug of the model (as a string) or the path to the model JSON file.
54
+
55
+ Returns:
56
+ dict: A dictionary containing the model metadata.
57
+
58
+ Raises:
59
+ ValueError: If the model does not exist or is not a valid file.
60
+ """
28
61
  if isinstance(model, str):
29
62
  model_path = get_xspect_model_path() / (model.lower() + ".json")
30
63
  elif isinstance(model, Path):
@@ -40,8 +73,17 @@ def get_model_metadata(model: str | Path):
40
73
  return model_json
41
74
 
42
75
 
43
- def update_model_metadata(model_slug: str, author: str, author_email: str):
44
- """Update the metadata of a model."""
76
+ def update_model_metadata(model_slug: str, author: str, author_email: str) -> None:
77
+ """
78
+ Update the metadata of a model.
79
+
80
+ This function updates the author and author email in the model's metadata JSON file.
81
+
82
+ Args:
83
+ model_slug (str): The slug of the model to update.
84
+ author (str): The name of the author to set in the metadata.
85
+ author_email (str): The email of the author to set in the metadata.
86
+ """
45
87
  model_metadata = get_model_metadata(model_slug)
46
88
  model_metadata["author"] = author
47
89
  model_metadata["author_email"] = author_email
@@ -51,8 +93,19 @@ def update_model_metadata(model_slug: str, author: str, author_email: str):
51
93
  file.write(dumps(model_metadata, indent=4))
52
94
 
53
95
 
54
- def update_model_display_name(model_slug: str, filter_id: str, display_name: str):
55
- """Update the display name of a filter in a model."""
96
+ def update_model_display_name(
97
+ model_slug: str, filter_id: str, display_name: str
98
+ ) -> None:
99
+ """
100
+ Update the display name of a filter in a model.
101
+
102
+ This function updates the display name of a specific filter in the model's metadata JSON file.
103
+
104
+ Args:
105
+ model_slug (str): The slug of the model to update.
106
+ filter_id (str): The ID of the filter whose display name is to be updated.
107
+ display_name (str): The new display name for the filter.
108
+ """
56
109
  model_metadata = get_model_metadata(model_slug)
57
110
  model_metadata["display_names"][filter_id] = display_name
58
111
 
@@ -61,8 +114,15 @@ def update_model_display_name(model_slug: str, filter_id: str, display_name: str
61
114
  file.write(dumps(model_metadata, indent=4))
62
115
 
63
116
 
64
- def get_models():
65
- """Get a list of all available models in a dictionary by type."""
117
+ def get_models() -> dict[str, list[dict]]:
118
+ """
119
+ Get a list of all available models in a dictionary by type.
120
+
121
+ This function scans the model directory for JSON files and organizes them by their model type.
122
+
123
+ Returns:
124
+ dict[str, list[dict]]: A dictionary where keys are model types and values are lists of model display names.
125
+ """
66
126
  model_dict = {}
67
127
  for model_file in get_xspect_model_path().glob("*.json"):
68
128
  model_metadata = get_model_metadata(model_file)
@@ -73,7 +133,17 @@ def get_models():
73
133
  return model_dict
74
134
 
75
135
 
76
- def get_model_display_names(model_slug: str):
77
- """Get the display names included in a model."""
136
+ def get_model_display_names(model_slug: str) -> list[str]:
137
+ """
138
+ Get the display names included in a model.
139
+
140
+ This function retrieves the display names of individual filters from the model's metadata.
141
+
142
+ Args:
143
+ model_slug (str): The slug of the model for which to retrieve display names.
144
+
145
+ Returns:
146
+ list[str]: A list of display names for the individual filters in the model.
147
+ """
78
148
  model_metadata = get_model_metadata(model_slug)
79
149
  return list(model_metadata["display_names"].values())
xspect/models/probabilistic_filter_mlst_model.py CHANGED
@@ -12,6 +12,7 @@ from cobs_index import DocumentList
12
12
  from collections import defaultdict
13
13
  from xspect.file_io import get_record_iterator
14
14
  from xspect.mlst_feature.mlst_helper import MlstResult
15
+ from xspect.mlst_feature.pub_mlst_handler import PubMLSTHandler
15
16
 
16
17
 
17
18
  class ProbabilisticFilterMlstSchemeModel:
@@ -19,20 +20,22 @@ class ProbabilisticFilterMlstSchemeModel:
19
20
 
20
21
  def __init__(
21
22
  self,
22
- k: int,
23
- model_display_name: str,
23
+ k_value: int,
24
+ model_name: str,
24
25
  base_path: Path,
26
+ scheme_url: str,
25
27
  fpr: float = 0.001,
26
28
  ) -> None:
27
29
  """Initialise a ProbabilisticFilterMlstSchemeModel object."""
28
- if k < 1:
30
+ if k_value < 1:
29
31
  raise ValueError("Invalid k value, must be greater than 0")
30
32
  if not isinstance(base_path, Path):
31
33
  raise ValueError("Invalid base path, must be a pathlib.Path object")
32
34
 
33
- self.k = k
34
- self.model_display_name = model_display_name
35
+ self.k_value = k_value
36
+ self.model_name = model_name
35
37
  self.base_path = base_path / "MLST"
38
+ self.scheme_url = scheme_url
36
39
  self.fpr = fpr
37
40
  self.model_type = "Strain"
38
41
  self.loci = {}
@@ -49,9 +52,10 @@ class ProbabilisticFilterMlstSchemeModel:
49
52
  dict: The dictionary containing all metadata of an object.
50
53
  """
51
54
  return {
52
- "k": self.k,
53
- "model_display_name": self.model_display_name,
55
+ "k_value": self.k_value,
56
+ "model_name": self.model_name,
54
57
  "model_type": self.model_type,
58
+ "scheme_url": str(self.scheme_url),
55
59
  "fpr": self.fpr,
56
60
  "scheme_path": str(self.scheme_path),
57
61
  "cobs_path": str(self.cobs_path),
@@ -115,7 +119,7 @@ class ProbabilisticFilterMlstSchemeModel:
115
119
  # COBS only accepts strings as paths
116
120
  doclist = DocumentList(str(locus_path))
117
121
  index_params = cobs_index.CompactIndexParameters()
118
- index_params.term_size = self.k # k-mer size
122
+ index_params.term_size = self.k_value # k-mer size
119
123
  index_params.clobber = True # overwrite output and temporary files
120
124
  index_params.false_positive_rate = self.fpr
121
125
 
@@ -130,9 +134,7 @@ class ProbabilisticFilterMlstSchemeModel:
130
134
 
131
135
  def save(self) -> None:
132
136
  """Saves the model to disk"""
133
- scheme = str(self.scheme_path).split("/")[
134
- -1
135
- ] # [-1] -> contains the scheme name
137
+ scheme = str(self.scheme_path).split("/")[-1] # [-1] contains the scheme name
136
138
  json_path = self.base_path / scheme / f"{scheme}.json"
137
139
  json_object = json.dumps(self.to_dict(), indent=4)
138
140
 
@@ -156,9 +158,10 @@ class ProbabilisticFilterMlstSchemeModel:
156
158
  json_object = file.read()
157
159
  model_json = json.loads(json_object)
158
160
  model = ProbabilisticFilterMlstSchemeModel(
159
- model_json["k"],
160
- model_json["model_display_name"],
161
+ model_json["k_value"],
162
+ model_json["model_name"],
161
163
  json_path.parent,
164
+ model_json["scheme_url"],
162
165
  model_json["fpr"],
163
166
  )
164
167
  model.scheme_path = model_json["scheme_path"]
@@ -175,7 +178,12 @@ class ProbabilisticFilterMlstSchemeModel:
175
178
  return model
176
179
 
177
180
  def calculate_hits(
178
- self, cobs_path: Path, sequence: Seq, step: int = 1
181
+ self,
182
+ cobs_path: Path,
183
+ sequence: Seq,
184
+ step: int = 1,
185
+ limit: bool = False,
186
+ limit_number: int = 5,
179
187
  ) -> list[dict]:
180
188
  """
181
189
  Calculates the hits for a sequence.
@@ -189,6 +197,8 @@ class ProbabilisticFilterMlstSchemeModel:
189
197
  cobs_path (Path): The path of the COBS-structure directory.
190
198
  sequence (Seq): The input sequence for classification.
191
199
  step (int, optional): The amount of kmers that are passed; defaults to one.
200
+ limit (bool): Applying a filter that limits the best result.
201
+ limit_number (int): The amount of results when the filter is set to true.
192
202
 
193
203
  Returns:
194
204
  list[dict]: The results of the prediction.
@@ -201,7 +211,7 @@ class ProbabilisticFilterMlstSchemeModel:
201
211
  if not isinstance(sequence, Seq):
202
212
  raise ValueError("Invalid sequence, must be a Bio.Seq object")
203
213
 
204
- if not len(sequence) > self.k:
214
+ if not len(sequence) > self.k_value:
205
215
  raise ValueError("Invalid sequence, must be longer than k")
206
216
 
207
217
  if not self.indices:
@@ -239,6 +249,10 @@ class ProbabilisticFilterMlstSchemeModel:
239
249
  sorted_counts = dict(
240
250
  sorted(all_counts.items(), key=lambda item: -item[1])
241
251
  )
252
+
253
+ if limit:
254
+ sorted_counts = dict(list(sorted_counts.items())[:limit_number])
255
+
242
256
  if not sorted_counts:
243
257
  result_dict = "A Strain type could not be detected because of no kmer matches!"
244
258
  highest_results[scheme_path_list[counter]] = {"N/A": 0}
@@ -250,25 +264,37 @@ class ProbabilisticFilterMlstSchemeModel:
250
264
  first_key: highest_result
251
265
  }
252
266
  counter += 1
253
- else:
267
+ else: # No split procedure is needed, when the sequence is short
254
268
  for index in self.indices:
255
- res = index.search(
269
+ res = index.search( # COBS can't handle Seq-Objects
256
270
  str(sequence), step=step
257
- ) # COBS can't handle Seq-Objects
258
- result_dict[scheme_path_list[counter]] = self.get_cobs_result(
259
- res, False
260
271
  )
261
- first_key, highest_result = next(
262
- iter(result_dict[scheme_path_list[counter]].items())
272
+ result = self.get_cobs_result(res, False)
273
+ result = (
274
+ dict(sorted(result.items(), key=lambda x: -x[1])[:limit_number])
275
+ if limit
276
+ else result
263
277
  )
278
+ result_dict[scheme_path_list[counter]] = result
279
+ first_key, highest_result = next(iter(result.items()))
264
280
  highest_results[scheme_path_list[counter]] = {first_key: highest_result}
265
281
  counter += 1
282
+
266
283
  # check if the strain type has sufficient amount of kmer hits
267
284
  is_valid = self.has_sufficient_score(highest_results, self.avg_locus_bp_size)
268
285
  if not is_valid:
269
286
  highest_results["Attention:"] = (
270
287
  "This strain type is not reliable due to low kmer hit rates!"
271
288
  )
289
+ else:
290
+ handler = PubMLSTHandler()
291
+ # allele_id is of type dict
292
+ flattened = {
293
+ locus: int(list(allele_id.keys())[0].split("_")[-1])
294
+ for locus, allele_id in highest_results.items()
295
+ }
296
+ strain_type_name = handler.get_strain_type_name(flattened, self.scheme_url)
297
+ highest_results["ST_Name"] = strain_type_name
272
298
  return [{"Strain type": highest_results}, {"All results": result_dict}]
273
299
 
274
300
  def predict(
@@ -282,6 +308,7 @@ class ProbabilisticFilterMlstSchemeModel:
282
308
  | Path
283
309
  ),
284
310
  step: int = 1,
311
+ limit: bool = False,
285
312
  ) -> MlstResult:
286
313
  """
287
314
  Get scores for the sequence(s) based on the filters in the model.
@@ -290,6 +317,7 @@ class ProbabilisticFilterMlstSchemeModel:
290
317
  cobs_path (Path): The path of the COBS-structure directory.
291
318
  sequence_input (Seq): The input sequence for classification
292
319
  step (int, optional): The amount of kmers that are passed; defaults to one
320
+ limit (bool, optional): Applying a filter that limits the best result.
293
321
 
294
322
  Returns:
295
323
  MlstResult: The results of the prediction.
@@ -301,13 +329,19 @@ class ProbabilisticFilterMlstSchemeModel:
301
329
  if sequence_input.id == "<unknown id>":
302
330
  sequence_input.id = "test"
303
331
  hits = {
304
- sequence_input.id: self.calculate_hits(cobs_path, sequence_input.seq)
332
+ sequence_input.id: self.calculate_hits(
333
+ cobs_path, sequence_input.seq, step, limit
334
+ )
305
335
  }
306
- return MlstResult(self.model_display_name, step, hits)
336
+ return MlstResult(self.model_name, step, hits, None)
307
337
 
308
338
  if isinstance(sequence_input, Path):
309
339
  return ProbabilisticFilterMlstSchemeModel.predict(
310
- self, cobs_path, get_record_iterator(sequence_input), step=step
340
+ self,
341
+ cobs_path,
342
+ get_record_iterator(sequence_input),
343
+ step=step,
344
+ limit=limit,
311
345
  )
312
346
 
313
347
  if isinstance(
@@ -317,33 +351,35 @@ class ProbabilisticFilterMlstSchemeModel:
317
351
  hits = {}
318
352
  # individual_seq is a SeqRecord-Object
319
353
  for individual_seq in sequence_input:
320
- individual_hits = self.calculate_hits(cobs_path, individual_seq.seq)
354
+ individual_hits = self.calculate_hits(
355
+ cobs_path, individual_seq.seq, step, limit
356
+ )
321
357
  hits[individual_seq.id] = individual_hits
322
- return MlstResult(self.model_display_name, step, hits)
323
-
358
+ return MlstResult(self.model_name, step, hits, None)
324
359
  raise ValueError(
325
360
  "Invalid sequence input, must be a Seq object, a list of Seq objects, a"
326
361
  " SeqIO FastaIterator, or a SeqIO FastqPhredIterator"
327
362
  )
328
363
 
329
364
  def get_cobs_result(
330
- self, cobs_result: cobs_index.SearchResult, kmer_threshold: bool
365
+ self,
366
+ cobs_result: cobs_index.SearchResult,
367
+ kmer_threshold: bool,
331
368
  ) -> dict:
332
369
  """
333
370
  Get every entry in a COBS search result.
334
371
 
335
372
  Args:
336
373
  cobs_result (SearchResult): The result of the prediction.
337
- kmer_threshold (bool): Applying a kmer threshold to mitigate false positives
374
+ kmer_threshold (bool): Applying a kmer threshold to mitigate false positives.
338
375
 
339
376
  Returns:
340
377
  dict: A dictionary storing the allele id of locus as key and the score as value.
341
378
  """
342
- return {
343
- individual_result.doc_name: individual_result.score
344
- for individual_result in cobs_result
345
- if not kmer_threshold or individual_result.score > 50
346
- }
379
+ hits = [
380
+ result for result in cobs_result if not kmer_threshold or result.score > 50
381
+ ]
382
+ return {result.doc_name: result.score for result in hits}
347
383
 
348
384
  def sequence_splitter(self, input_sequence: str, allele_len: int) -> list[str]:
349
385
  """
@@ -379,13 +415,15 @@ class ProbabilisticFilterMlstSchemeModel:
379
415
 
380
416
  while start + substring_length <= sequence_len:
381
417
  substring_list.append(input_sequence[start : start + substring_length])
382
- start += substring_length - self.k + 1 # To not lose kmers when dividing
418
+ start += (
419
+ substring_length - self.k_value + 1
420
+ ) # To not lose kmers when dividing
383
421
 
384
422
  # The remaining string is either appended to the list or added to the last entry.
385
423
  if start < len(input_sequence):
386
424
  remaining_substring = input_sequence[start:]
387
425
  # A substring needs to be at least of size k for COBS.
388
- if len(remaining_substring) < self.k:
426
+ if len(remaining_substring) < self.k_value:
389
427
  substring_list[-1] += remaining_substring
390
428
  else:
391
429
  substring_list.append(remaining_substring)