XspecT 0.2.7__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of XspecT might be problematic. Click here for more details.

xspect/main.py CHANGED
@@ -1,28 +1,26 @@
1
1
  """Project CLI"""
2
2
 
3
3
  from pathlib import Path
4
- import datetime
5
- import uuid
4
+ from uuid import uuid4
6
5
  import click
7
6
  import uvicorn
8
7
  from xspect import fastapi
9
8
  from xspect.download_models import download_test_models
10
- from xspect.train import train_ncbi
11
- from xspect.models.result import (
12
- StepType,
13
- )
9
+ from xspect.file_io import filter_sequences
10
+ from xspect.train import train_from_directory, train_from_ncbi
14
11
  from xspect.definitions import (
15
- get_xspect_runs_path,
16
- fasta_endings,
17
- fastq_endings,
18
12
  get_xspect_model_path,
19
13
  )
20
- from xspect.pipeline import ModelExecution, Pipeline, PipelineStep
21
14
  from xspect.mlst_feature.mlst_helper import pick_scheme, pick_scheme_from_models_dir
22
15
  from xspect.mlst_feature.pub_mlst_handler import PubMLSTHandler
23
16
  from xspect.models.probabilistic_filter_mlst_model import (
24
17
  ProbabilisticFilterMlstSchemeModel,
25
18
  )
19
+ from xspect.model_management import (
20
+ get_genus_model,
21
+ get_models,
22
+ get_species_model,
23
+ )
26
24
 
27
25
 
28
26
  @click.group()
@@ -32,103 +30,133 @@ def cli():
32
30
 
33
31
 
34
32
  @cli.command()
35
- def download_models():
33
+ def web():
34
+ """Open the XspecT web application."""
35
+ uvicorn.run(fastapi.app, host="0.0.0.0", port=8000)
36
+
37
+
38
+ # # # # # # # # # # # # # # #
39
+ # Model management commands #
40
+ # # # # # # # # # # # # # # #
41
+ @cli.group()
42
+ def models():
43
+ """Model management commands."""
44
+ pass
45
+
46
+
47
+ @models.command(
48
+ help="Download models from the internet.",
49
+ )
50
+ def download():
36
51
  """Download models."""
37
52
  click.echo("Downloading models, this may take a while...")
38
- download_test_models("https://xspect2.s3.eu-central-1.amazonaws.com/models.zip")
53
+ download_test_models("http://assets.adrianromberg.com/xspect-models.zip")
39
54
 
40
55
 
41
- @cli.command()
42
- @click.argument("genus")
43
- @click.argument("path", type=click.Path(exists=True, dir_okay=True, file_okay=True))
56
+ @models.command(
57
+ name="list",
58
+ help="List all models in the model directory.",
59
+ )
60
+ def list_models():
61
+ """List models."""
62
+ available_models = get_models()
63
+ if not available_models:
64
+ click.echo("No models found.")
65
+ return
66
+ # todo: make this machine readable
67
+ click.echo("Models found:")
68
+ click.echo("--------------")
69
+ for model_type, names in available_models.items():
70
+ if not names:
71
+ continue
72
+ click.echo(f" {model_type}:")
73
+ for name in names:
74
+ click.echo(f" - {name}")
75
+
76
+
77
+ @models.group()
78
+ def train():
79
+ """Train models."""
80
+ pass
81
+
82
+
83
+ @train.command(
84
+ name="ncbi",
85
+ help="Train a species and a genus model based on NCBI data.",
86
+ )
87
+ @click.option("-g", "--genus", "model_genus", prompt=True)
88
+ @click.option("--svm_steps", type=int, default=1)
44
89
  @click.option(
45
- "-m",
46
- "--meta/--no-meta",
47
- help="Metagenome classification.",
48
- default=False,
90
+ "--author",
91
+ help="Author of the model.",
92
+ default=None,
49
93
  )
50
94
  @click.option(
51
- "-s",
52
- "--step",
53
- help="Sparse sampling step size (e. g. only every 500th kmer for step=500).",
54
- default=1,
95
+ "--author-email",
96
+ help="Email of the author.",
97
+ default=None,
55
98
  )
56
- def classify_species(genus, path, meta, step):
57
- """Classify sample(s) from file or directory PATH."""
58
- click.echo("Classifying...")
59
- click.echo(f"Step: {step}")
60
-
61
- file_paths = []
62
- if Path(path).is_dir():
63
- file_paths = [
64
- f
65
- for f in Path(path).iterdir()
66
- if f.is_file() and f.suffix[1:] in fasta_endings + fastq_endings
67
- ]
68
- else:
69
- file_paths = [Path(path)]
70
-
71
- # define pipeline
72
- pipeline = Pipeline(genus + " classification", "Test Author", "test@example.com")
73
- species_execution = ModelExecution(
74
- genus.lower() + "-species", sparse_sampling_step=step
75
- )
76
- if meta:
77
- species_filtering_step = PipelineStep(
78
- StepType.FILTERING, genus, 0.7, species_execution
79
- )
80
- genus_execution = ModelExecution(
81
- genus.lower() + "-genus", sparse_sampling_step=step
82
- )
83
- genus_execution.add_pipeline_step(species_filtering_step)
84
- pipeline.add_pipeline_step(genus_execution)
85
- else:
86
- pipeline.add_pipeline_step(species_execution)
87
-
88
- for idx, file_path in enumerate(file_paths):
89
- run = pipeline.run(file_path)
90
- time_str = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
91
- save_path = get_xspect_runs_path() / f"run_{time_str}_{uuid.uuid4()}.json"
92
- run.save(save_path)
93
- print(
94
- f"[{idx+1}/{len(file_paths)}] Run finished. Results saved to '{save_path}'."
95
- )
99
+ def train_ncbi(model_genus, svm_steps, author, author_email):
100
+ """Train a species and a genus model based on NCBI data."""
101
+ click.echo(f"Training {model_genus} species and genus metagenome model.")
102
+ try:
103
+ train_from_ncbi(model_genus, svm_steps, author, author_email)
104
+ except ValueError as e:
105
+ click.echo(f"Error: {e}")
106
+ return
107
+ click.echo(f"Training of {model_genus} model finished.")
96
108
 
97
109
 
98
- @cli.command()
99
- @click.argument("genus")
110
+ @train.command(
111
+ name="directory",
112
+ help="Train a species (and possibly a genus) model based on local data.",
113
+ )
114
+ @click.option("-g", "--genus", "model_genus", prompt=True)
100
115
  @click.option(
101
- "-bf-path",
102
- "--bf-assembly-path",
103
- help="Path to assembly directory for Bloom filter training.",
104
- type=click.Path(exists=True, dir_okay=True, file_okay=False),
116
+ "-i",
117
+ "--input-path",
118
+ type=click.Path(exists=True, dir_okay=True, file_okay=True),
119
+ prompt=True,
105
120
  )
106
121
  @click.option(
107
- "-svm-path",
108
- "--svm-assembly-path",
109
- help="Path to assembly directory for SVM training.",
110
- type=click.Path(exists=True, dir_okay=True, file_okay=False),
122
+ "--meta",
123
+ is_flag=True,
124
+ help="Train a metagenome model for the genus.",
125
+ default=True,
111
126
  )
112
127
  @click.option(
113
- "-s",
114
- "--svm-step",
128
+ "--svm-steps",
129
+ type=int,
115
130
  help="SVM Sparse sampling step size (e. g. only every 500th kmer for step=500).",
116
131
  default=1,
117
132
  )
118
- def train_species(genus, bf_assembly_path, svm_assembly_path, svm_step):
119
- """Train model."""
120
-
121
- if bf_assembly_path or svm_assembly_path:
122
- raise NotImplementedError(
123
- "Training with specific assembly paths is not yet implemented."
124
- )
125
- try:
126
- train_ncbi(genus, svm_step=svm_step)
127
- except ValueError as e:
128
- raise click.ClickException(str(e)) from e
133
+ @click.option(
134
+ "--author",
135
+ help="Author of the model.",
136
+ default=None,
137
+ )
138
+ @click.option(
139
+ "--author-email",
140
+ help="Email of the author.",
141
+ default=None,
142
+ )
143
+ def train_directory(model_genus, input_path, svm_steps, meta, author, author_email):
144
+ """Train a model based on data from a directory for a given genus."""
145
+ click.echo(f"Training {model_genus} model with {svm_steps} SVM steps.")
146
+ train_from_directory(
147
+ model_genus,
148
+ Path(input_path),
149
+ svm_step=svm_steps,
150
+ meta=meta,
151
+ author=author,
152
+ author_email=author_email,
153
+ )
129
154
 
130
155
 
131
- @cli.command()
156
+ @train.command(
157
+ name="mlst",
158
+ help="Train a MLST model based on PubMLST data.",
159
+ )
132
160
  @click.option(
133
161
  "-c",
134
162
  "--choose_schemes",
@@ -154,27 +182,234 @@ def train_mlst(choose_schemes):
154
182
  click.echo(f"Saved at {model.cobs_path}")
155
183
 
156
184
 
157
- @cli.command()
185
+ # # # # # # # # # # # # # # #
186
+ # Classification commands #
187
+ # # # # # # # # # # # # # # #
188
+ @cli.group(
189
+ name="classify",
190
+ help="Classify sequences using XspecT models.",
191
+ )
192
+ def classify_seqs():
193
+ """Classification commands."""
194
+ pass
195
+
196
+
197
+ @classify_seqs.command()
198
+ @click.option(
199
+ "-g",
200
+ "--genus",
201
+ "model_genus",
202
+ help="Genus of the model to classify.",
203
+ type=click.Choice(get_models().get("Genus"), None),
204
+ prompt=True,
205
+ )
206
+ @click.option(
207
+ "-i",
208
+ "--input-path",
209
+ help="Path to FASTA or FASTQ file for classification.",
210
+ type=click.Path(exists=True, dir_okay=True, file_okay=True),
211
+ prompt=True,
212
+ )
213
+ @click.option(
214
+ "-o",
215
+ "--output-path",
216
+ help="Path to the output file.",
217
+ type=click.Path(dir_okay=True, file_okay=True),
218
+ default=Path(".") / f"result_{uuid4()}.json",
219
+ )
220
+ def genus(model_genus, input_path, output_path):
221
+ """Classify samples using a genus model."""
222
+ click.echo("Classifying...")
223
+ genus_model = get_genus_model(model_genus)
224
+ result = genus_model.predict(Path(input_path))
225
+ result.save(output_path)
226
+ click.echo(f"Result saved as {output_path}.")
227
+
228
+
229
+ @classify_seqs.command()
158
230
  @click.option(
159
- "-p",
160
- "--path",
231
+ "-g",
232
+ "--genus",
233
+ "model_genus",
234
+ help="Genus of the model to classify.",
235
+ type=click.Choice(get_models().get("Species"), None),
236
+ prompt=True,
237
+ )
238
+ @click.option(
239
+ "-i",
240
+ "--input-path",
241
+ help="Path to FASTA or FASTQ file for classification.",
242
+ type=click.Path(exists=True, dir_okay=True, file_okay=True),
243
+ prompt=True,
244
+ )
245
+ @click.option(
246
+ "-o",
247
+ "--output-path",
248
+ help="Path to the output file.",
249
+ type=click.Path(dir_okay=True, file_okay=True),
250
+ default=Path(".") / f"result_{uuid4()}.json",
251
+ )
252
+ @click.option(
253
+ "--sparse-sampling-step",
254
+ type=int,
255
+ help="Sparse sampling step size (e. g. only every 500th kmer for '--sparse-sampling-step 500').",
256
+ default=1,
257
+ )
258
+ def species(model_genus, input_path, output_path, sparse_sampling_step):
259
+ """Classify samples using a species model."""
260
+ click.echo("Classifying...")
261
+ species_model = get_species_model(model_genus)
262
+ result = species_model.predict(Path(input_path), step=sparse_sampling_step)
263
+ result.save(output_path)
264
+ click.echo(f"Result saved as {output_path}.")
265
+
266
+
267
+ @classify_seqs.command(
268
+ name="mlst",
269
+ help="Classify samples using a MLST model.",
270
+ )
271
+ @click.option(
272
+ "-i",
273
+ "--input-path",
161
274
  help="Path to FASTA-file for mlst identification.",
162
275
  type=click.Path(exists=True, dir_okay=True, file_okay=True),
276
+ prompt=True,
277
+ )
278
+ @click.option(
279
+ "-o",
280
+ "--output-path",
281
+ help="Path to the output file.",
282
+ type=click.Path(dir_okay=True, file_okay=True),
283
+ default=Path(".") / f"result_{uuid4()}.json",
163
284
  )
164
- def classify_mlst(path):
285
+ def classify_mlst(input_path, output_path):
165
286
  """MLST classify a sample."""
166
287
  click.echo("Classifying...")
167
- path = Path(path)
288
+ input_path = Path(input_path)
168
289
  scheme_path = pick_scheme_from_models_dir()
169
290
  model = ProbabilisticFilterMlstSchemeModel.load(scheme_path)
170
- model.predict(scheme_path, path).save(model.model_display_name, path)
171
- click.echo(f"Run saved at {get_xspect_runs_path()}.")
291
+ result = model.predict(scheme_path, input_path)
292
+ result.save(output_path)
293
+ click.echo(f"Result saved as {output_path}.")
172
294
 
173
295
 
174
- @cli.command()
175
- def api():
176
- """Open the XspecT FastAPI."""
177
- uvicorn.run(fastapi.app, host="0.0.0.0", port=8000)
296
+ # # # # # # # # # # # # # # #
297
+ # Filtering commands #
298
+ # # # # # # # # # # # # # # #
299
+ @cli.group(
300
+ name="filter",
301
+ help="Filter sequences using XspecT models.",
302
+ )
303
+ def filter_seqs():
304
+ """Filter commands."""
305
+ pass
306
+
307
+
308
+ @filter_seqs.command(
309
+ name="genus",
310
+ help="Filter sequences using a genus model.",
311
+ )
312
+ @click.option(
313
+ "-g",
314
+ "--genus",
315
+ "model_genus",
316
+ help="Genus of the model to use for filtering.",
317
+ type=click.Choice(get_models().get("Species"), None),
318
+ prompt=True,
319
+ )
320
+ @click.option(
321
+ "-i",
322
+ "--input-path",
323
+ help="Path to FASTA or FASTQ file for classification.",
324
+ type=click.Path(exists=True, dir_okay=True, file_okay=True),
325
+ prompt=True,
326
+ )
327
+ @click.option(
328
+ "-o",
329
+ "--output-path",
330
+ help="Path to the output file.",
331
+ type=click.Path(dir_okay=True, file_okay=True),
332
+ prompt=True,
333
+ )
334
+ @click.option(
335
+ "--threshold",
336
+ type=float,
337
+ help="Threshold for filtering (default: 0.7).",
338
+ default=0.7,
339
+ )
340
+ def filter_genus(model_genus, input_path, output_path, threshold):
341
+ """Filter samples using a genus model."""
342
+ click.echo("Filtering...")
343
+ genus_model = get_genus_model(model_genus)
344
+ result = genus_model.predict(Path(input_path))
345
+ included_ids = result.get_filtered_subsequence_labels(model_genus, threshold)
346
+ if not included_ids:
347
+ click.echo("No sequences found for the given genus.")
348
+ return
349
+
350
+ filter_sequences(
351
+ Path(input_path),
352
+ Path(output_path),
353
+ included_ids=included_ids,
354
+ )
355
+ click.echo(f"Filtered sequences saved at {output_path}.")
356
+
357
+
358
+ @filter_seqs.command(
359
+ name="species",
360
+ help="Filter sequences using a species model.",
361
+ )
362
+ @click.option(
363
+ "-g",
364
+ "--genus",
365
+ "model_genus",
366
+ help="Genus of the model to use for filtering.",
367
+ type=click.Choice(get_models().get("Species"), None),
368
+ prompt=True,
369
+ )
370
+ @click.option(
371
+ # todo: this should be a choice of the species in the model w/ display names
372
+ "-s",
373
+ "--species",
374
+ "model_species",
375
+ help="Species of the model to filter for.",
376
+ prompt=True,
377
+ )
378
+ @click.option(
379
+ "-i",
380
+ "--input-path",
381
+ help="Path to FASTA or FASTQ file for classification.",
382
+ type=click.Path(exists=True, dir_okay=True, file_okay=True),
383
+ prompt=True,
384
+ )
385
+ @click.option(
386
+ "-o",
387
+ "--output-path",
388
+ help="Path to the output file.",
389
+ type=click.Path(dir_okay=True, file_okay=True),
390
+ prompt=True,
391
+ )
392
+ @click.option(
393
+ "--threshold",
394
+ type=float,
395
+ help="Threshold for filtering (default: 0.7).",
396
+ default=0.7,
397
+ )
398
+ def filter_species(model_genus, model_species, input_path, output_path, threshold):
399
+ """Filter a sample using the species model."""
400
+ click.echo("Filtering...")
401
+ species_model = get_species_model(model_genus)
402
+ result = species_model.predict(Path(input_path))
403
+ included_ids = result.get_filtered_subsequence_labels(model_species, threshold)
404
+ if not included_ids:
405
+ click.echo("No sequences found for the given species.")
406
+ return
407
+ filter_sequences(
408
+ Path(input_path),
409
+ Path(output_path),
410
+ included_ids=included_ids,
411
+ )
412
+ click.echo(f"Filtered sequences saved at {output_path}.")
178
413
 
179
414
 
180
415
  if __name__ == "__main__":
@@ -144,12 +144,10 @@ class MlstResult:
144
144
  }
145
145
  return result
146
146
 
147
- def save(self, display: str, file_path: Path) -> None:
148
- """Saves the result inside the "runs" directory"""
149
- file_name = str(file_path).split("/")[-1]
150
- json_path = get_xspect_runs_path() / "MLST" / f"{file_name}-{display}.json"
151
- json_path.parent.mkdir(exist_ok=True, parents=True)
147
+ def save(self, output_path: Path) -> None:
148
+ """Saves the result as a JSON file."""
149
+ output_path.parent.mkdir(exist_ok=True, parents=True)
152
150
  json_object = json.dumps(self.to_dict(), indent=4)
153
151
 
154
- with open(json_path, "w", encoding="utf-8") as file:
152
+ with open(output_path, "w", encoding="utf-8") as file:
155
153
  file.write(json_object)
@@ -85,3 +85,9 @@ def get_models():
85
85
  model_metadata["model_display_name"]
86
86
  )
87
87
  return model_dict
88
+
89
+
90
+ def get_model_display_names(model_slug: str):
91
+ """Get the display names included in a model."""
92
+ model_metadata = get_model_metadata(model_slug)
93
+ return list(model_metadata["display_names"].values())
@@ -26,6 +26,7 @@ class ProbabilisticFilterModel:
26
26
  base_path: Path,
27
27
  fpr: float = 0.01,
28
28
  num_hashes: int = 7,
29
+ training_accessions: dict[str, list[str]] = None,
29
30
  ) -> None:
30
31
  if k < 1:
31
32
  raise ValueError("Invalid k value, must be greater than 0")
@@ -46,6 +47,7 @@ class ProbabilisticFilterModel:
46
47
  self.fpr = fpr
47
48
  self.num_hashes = num_hashes
48
49
  self.index = None
50
+ self.training_accessions = training_accessions
49
51
 
50
52
  def get_cobs_index_path(self) -> Path:
51
53
  """Returns the path to the cobs index"""
@@ -63,13 +65,19 @@ class ProbabilisticFilterModel:
63
65
  "display_names": self.display_names,
64
66
  "fpr": self.fpr,
65
67
  "num_hashes": self.num_hashes,
68
+ "training_accessions": self.training_accessions,
66
69
  }
67
70
 
68
71
  def slug(self) -> str:
69
72
  """Returns a slug representation of the model"""
70
73
  return slugify(self.model_display_name + "-" + str(self.model_type))
71
74
 
72
- def fit(self, dir_path: Path, display_names: dict = None) -> None:
75
+ def fit(
76
+ self,
77
+ dir_path: Path,
78
+ display_names: dict = None,
79
+ training_accessions: dict[str, list[str]] = None,
80
+ ) -> None:
73
81
  """Adds filters to the model"""
74
82
 
75
83
  if display_names is None:
@@ -84,16 +92,18 @@ class ProbabilisticFilterModel:
84
92
  if not dir_path.is_dir():
85
93
  raise ValueError("Directory path must be a directory")
86
94
 
95
+ self.training_accessions = training_accessions
96
+
87
97
  doclist = cobs.DocumentList()
88
98
  for file in dir_path.iterdir():
89
99
  if file.is_file() and file.suffix[1:] in fasta_endings + fastq_endings:
90
100
  # cobs only uses the file name to the first "." as the document name
91
- if file.name in display_names:
92
- self.display_names[file.name.split(".")[0]] = display_names[
93
- file.name
101
+ if file.stem in display_names:
102
+ self.display_names[file.stem.split(".")[0]] = display_names[
103
+ file.stem
94
104
  ]
95
105
  else:
96
- self.display_names[file.name.split(".")[0]] = file.stem
106
+ self.display_names[file.stem.split(".")[0]] = file.stem
97
107
  doclist.add(str(file))
98
108
 
99
109
  if len(doclist) == 0:
@@ -200,6 +210,7 @@ class ProbabilisticFilterModel:
200
210
  path.parent,
201
211
  model_json["fpr"],
202
212
  model_json["num_hashes"],
213
+ model_json["training_accessions"],
203
214
  )
204
215
  model.display_names = model_json["display_names"]
205
216
 
@@ -4,7 +4,6 @@
4
4
 
5
5
  import csv
6
6
  import json
7
- from linecache import getline
8
7
  from pathlib import Path
9
8
  from sklearn.svm import SVC
10
9
  from Bio.SeqRecord import SeqRecord
@@ -30,6 +29,8 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
30
29
  c: float,
31
30
  fpr: float = 0.01,
32
31
  num_hashes: int = 7,
32
+ training_accessions: dict[str, list[str]] = None,
33
+ svm_accessions: dict[str, list[str]] = None,
33
34
  ) -> None:
34
35
  super().__init__(
35
36
  k=k,
@@ -40,14 +41,17 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
40
41
  base_path=base_path,
41
42
  fpr=fpr,
42
43
  num_hashes=num_hashes,
44
+ training_accessions=training_accessions,
43
45
  )
44
46
  self.kernel = kernel
45
47
  self.c = c
48
+ self.svm_accessions = svm_accessions
46
49
 
47
50
  def to_dict(self) -> dict:
48
51
  return super().to_dict() | {
49
52
  "kernel": self.kernel,
50
53
  "C": self.c,
54
+ "svm_accessions": self.svm_accessions,
51
55
  }
52
56
 
53
57
  def set_svm_params(self, kernel: str, c: float) -> None:
@@ -62,32 +66,41 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
62
66
  svm_path: Path,
63
67
  display_names: dict = None,
64
68
  svm_step: int = 1,
69
+ training_accessions: list[str] = None,
70
+ svm_accessions: list[str] = None,
65
71
  ) -> None:
66
72
  """Fit the SVM to the sequences and labels"""
67
73
 
68
74
  # Since the SVM works with score data, we need to train
69
75
  # the underlying data structure for score generation first
70
- super().fit(dir_path, display_names=display_names)
76
+ super().fit(
77
+ dir_path,
78
+ display_names=display_names,
79
+ training_accessions=training_accessions,
80
+ )
81
+
82
+ self.svm_accessions = svm_accessions
71
83
 
72
84
  # calculate scores for SVM training
73
85
  score_list = []
74
- for file in svm_path.iterdir():
75
- if not file.is_file():
76
- continue
77
- if file.suffix[1:] not in fasta_endings + fastq_endings:
86
+
87
+ for species_folder in svm_path.iterdir():
88
+ if not species_folder.is_dir():
78
89
  continue
79
- print(f"Calculating {file.name} scores for SVM training...")
80
- res = super().predict(file, step=svm_step)
81
- scores = res.get_scores()["total"]
82
- accession = "".join(file.name.split("_")[:2])
83
- file_header = getline(str(file), 1)
84
- label_id = file_header.replace("\n", "").replace(">", "")
85
-
86
- # format scores for csv
87
- scores = dict(sorted(scores.items()))
88
- scores = ",".join([str(score) for score in scores.values()])
89
- scores = f"{accession},{scores},{label_id}"
90
- score_list.append(scores)
90
+ for file in species_folder.iterdir():
91
+ if file.suffix[1:] not in fasta_endings + fastq_endings:
92
+ continue
93
+ print(f"Calculating {file.name} scores for SVM training...")
94
+ res = super().predict(file, step=svm_step)
95
+ scores = res.get_scores()["total"]
96
+ accession = file.stem
97
+ label_id = species_folder.name
98
+
99
+ # format scores for csv
100
+ scores = dict(sorted(scores.items()))
101
+ scores = ",".join([str(score) for score in scores.values()])
102
+ scores = f"{accession},{scores},{label_id}"
103
+ score_list.append(scores)
91
104
 
92
105
  # csv header
93
106
  keys = list(self.display_names.keys())
@@ -162,6 +175,8 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
162
175
  model_json["C"],
163
176
  fpr=model_json["fpr"],
164
177
  num_hashes=model_json["num_hashes"],
178
+ training_accessions=model_json["training_accessions"],
179
+ svm_accessions=model_json["svm_accessions"],
165
180
  )
166
181
  model.display_names = model_json["display_names"]
167
182