XspecT 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of XspecT might be problematic. Click here for more details.

Files changed (58) hide show
  1. {XspecT-0.1.3.dist-info → XspecT-0.2.0.dist-info}/METADATA +23 -29
  2. XspecT-0.2.0.dist-info/RECORD +30 -0
  3. {XspecT-0.1.3.dist-info → XspecT-0.2.0.dist-info}/WHEEL +1 -1
  4. xspect/definitions.py +42 -0
  5. xspect/download_filters.py +11 -26
  6. xspect/fastapi.py +101 -0
  7. xspect/file_io.py +34 -103
  8. xspect/main.py +70 -66
  9. xspect/model_management.py +88 -0
  10. xspect/models/__init__.py +0 -0
  11. xspect/models/probabilistic_filter_model.py +277 -0
  12. xspect/models/probabilistic_filter_svm_model.py +169 -0
  13. xspect/models/probabilistic_single_filter_model.py +109 -0
  14. xspect/models/result.py +148 -0
  15. xspect/pipeline.py +201 -0
  16. xspect/run.py +38 -0
  17. xspect/train.py +304 -0
  18. xspect/train_filter/create_svm.py +6 -183
  19. xspect/train_filter/extract_and_concatenate.py +117 -121
  20. xspect/train_filter/html_scrap.py +16 -28
  21. xspect/train_filter/ncbi_api/download_assemblies.py +7 -8
  22. xspect/train_filter/ncbi_api/ncbi_assembly_metadata.py +9 -17
  23. xspect/train_filter/ncbi_api/ncbi_children_tree.py +3 -2
  24. xspect/train_filter/ncbi_api/ncbi_taxon_metadata.py +7 -5
  25. XspecT-0.1.3.dist-info/RECORD +0 -49
  26. xspect/BF_v2.py +0 -637
  27. xspect/Bootstrap.py +0 -29
  28. xspect/Classifier.py +0 -142
  29. xspect/OXA_Table.py +0 -53
  30. xspect/WebApp.py +0 -724
  31. xspect/XspecT_mini.py +0 -1363
  32. xspect/XspecT_trainer.py +0 -611
  33. xspect/map_kmers.py +0 -155
  34. xspect/search_filter.py +0 -504
  35. xspect/static/How-To.png +0 -0
  36. xspect/static/Logo.png +0 -0
  37. xspect/static/Logo2.png +0 -0
  38. xspect/static/Workflow_AspecT.png +0 -0
  39. xspect/static/Workflow_ClAssT.png +0 -0
  40. xspect/static/js.js +0 -615
  41. xspect/static/main.css +0 -280
  42. xspect/templates/400.html +0 -64
  43. xspect/templates/401.html +0 -62
  44. xspect/templates/404.html +0 -62
  45. xspect/templates/500.html +0 -62
  46. xspect/templates/about.html +0 -544
  47. xspect/templates/home.html +0 -51
  48. xspect/templates/layoutabout.html +0 -87
  49. xspect/templates/layouthome.html +0 -63
  50. xspect/templates/layoutspecies.html +0 -468
  51. xspect/templates/species.html +0 -33
  52. xspect/train_filter/README_XspecT_Erweiterung.md +0 -119
  53. xspect/train_filter/get_paths.py +0 -35
  54. xspect/train_filter/interface_XspecT.py +0 -204
  55. xspect/train_filter/k_mer_count.py +0 -162
  56. {XspecT-0.1.3.dist-info → XspecT-0.2.0.dist-info}/LICENSE +0 -0
  57. {XspecT-0.1.3.dist-info → XspecT-0.2.0.dist-info}/entry_points.txt +0 -0
  58. {XspecT-0.1.3.dist-info → XspecT-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,169 @@
1
+ """Probabilistic filter SVM model for sequence data"""
2
+
3
+ # pylint: disable=no-name-in-module, too-many-instance-attributes, arguments-renamed
4
+
5
+ import csv
6
+ import json
7
+ from linecache import getline
8
+ from pathlib import Path
9
+ from sklearn.svm import SVC
10
+ from Bio.SeqRecord import SeqRecord
11
+ from Bio import SeqIO
12
+ import cobs_index as cobs
13
+ from xspect.models.probabilistic_filter_model import ProbabilisticFilterModel
14
+ from xspect.definitions import fasta_endings, fastq_endings
15
+ from xspect.models.result import ModelResult
16
+
17
+
18
class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
    """Probabilistic filter model that turns filter scores into an SVM prediction.

    The "total" scores returned by the parent model's ``predict`` are used as
    the feature vector. Training scores are written to ``scores.csv`` in the
    model directory by ``fit`` and an SVM is trained from that file on every
    call to ``predict``.
    """

    def __init__(
        self,
        k: int,
        model_display_name: str,
        author: str,
        author_email: str,
        model_type: str,
        base_path: Path,
        kernel: str,
        c: float,
        fpr: float = 0.01,
        num_hashes: int = 7,
    ) -> None:
        """Initialize the model.

        All arguments except ``kernel`` and ``c`` are forwarded unchanged to
        the parent constructor. ``kernel`` and ``c`` are passed to
        ``sklearn.svm.SVC`` as ``kernel`` and ``C`` when the SVM is built.
        """
        super().__init__(
            k=k,
            model_display_name=model_display_name,
            author=author,
            author_email=author_email,
            model_type=model_type,
            base_path=base_path,
            fpr=fpr,
            num_hashes=num_hashes,
        )
        self.kernel = kernel
        self.c = c

    def to_dict(self) -> dict:
        """Return the parent's dictionary extended with the SVM parameters."""
        return super().to_dict() | {
            "kernel": self.kernel,
            "C": self.c,
        }

    def set_svm_params(self, kernel: str, c: float) -> None:
        """Set the parameters for the SVM and persist them via ``save()``."""
        self.kernel = kernel
        self.c = c
        self.save()

    def fit(
        self,
        dir_path: Path,
        svm_path: Path,
        display_names: dict | None = None,
        svm_step: int = 1,
    ) -> None:
        """Fit the parent model and write the SVM training scores to disk.

        Args:
            dir_path: Forwarded to the parent ``fit``.
            svm_path: Directory whose FASTA/FASTQ files are scored to build
                the SVM training set. Other entries are ignored.
            display_names: Forwarded to the parent ``fit``.
            svm_step: Sampling step used when scoring the training files.
        """
        super().fit(dir_path, display_names=display_names)

        valid_endings = fasta_endings + fastq_endings
        rows = []
        for file in svm_path.iterdir():
            if not file.is_file():
                continue
            if file.suffix[1:] not in valid_endings:
                continue
            print(f"Calculating {file.name} scores for SVM training...")
            # Use the parent's predict: the SVM cannot be used before its
            # training data exists.
            res = super().predict(file, step=svm_step)
            # Sort by label so the columns line up with the sorted header.
            scores = dict(sorted(res.get_scores()["total"].items()))
            accession = "".join(file.name.split("_")[:2])
            # The label is the first line of the file with ">" and the
            # newline removed.
            # NOTE(review): a FASTQ "@" prefix is not stripped here - confirm
            # whether FASTQ training files are expected.
            file_header = getline(str(file), 1)
            label_id = file_header.replace("\n", "").replace(">", "")
            rows.append([accession, *scores.values(), label_id])

        header = ["file", *sorted(self.display_names), "label_id"]

        scores_path = self.base_path / self.slug() / "scores.csv"
        # csv.writer quotes fields that contain commas or quotes, so a label
        # taken from a free-text header cannot shift the columns.
        with open(scores_path, "w", encoding="utf-8", newline="") as scores_file:
            writer = csv.writer(scores_file, lineterminator="\n")
            writer.writerow(header)
            writer.writerows(rows)

    def predict(
        self,
        sequence_input: (
            SeqRecord
            | list[SeqRecord]
            | SeqIO.FastaIO.FastaIterator
            | SeqIO.QualityIO.FastqPhredIterator
            | Path
        ),
        filter_ids: list[str] | None = None,
        step: int = 1,
    ) -> ModelResult:
        """Predict the label of the input with the SVM.

        Args:
            sequence_input: Input forwarded to the parent ``predict``.
            filter_ids: Forwarded to the parent ``predict`` and used to select
                the SVM training rows.
            step: Sampling step forwarded to the parent ``predict``.

        Returns:
            A ``ModelResult`` carrying the parent's hits and k-mer counts, the
            sampling step used and the SVM prediction as a string.
        """
        # get scores and format them for the SVM (sorted by label, matching
        # the column order written by fit)
        res = super().predict(sequence_input, filter_ids, step=step)
        svm_scores = dict(sorted(res.get_scores()["total"].items()))
        features = [list(svm_scores.values())]

        svm = self._get_svm(filter_ids)
        return ModelResult(
            self.slug(),
            res.hits,
            res.num_kmers,
            # Record the step that was actually used instead of the default.
            sparse_sampling_step=step,
            prediction=str(svm.predict(features)[0]),
        )

    def _get_svm(self, id_keys) -> SVC:
        """Train and return an SVM from ``scores.csv``.

        Args:
            id_keys: Labels whose training rows should be used, or ``None`` to
                use every row.

        Raises:
            ValueError: If no training row matches ``id_keys``.
        """
        # NOTE(review): rows are filtered by label but all score columns are
        # kept - confirm this matches the feature count produced by predict()
        # when filter_ids is given.
        svm = SVC(kernel=self.kernel, C=self.c)
        x_train = []
        y_train = []
        scores_path = self.base_path / self.slug() / "scores.csv"
        with open(scores_path, "r", encoding="utf-8", newline="") as scores_file:
            reader = csv.reader(scores_file)
            next(reader, None)  # skip the header row
            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines
                    continue
                if id_keys is None or row[-1] in id_keys:
                    # Convert explicitly so a malformed value fails here.
                    x_train.append([float(value) for value in row[1:-1]])
                    y_train.append(row[-1])

        if not x_train:
            raise ValueError(
                f"No SVM training data found in {scores_path} for {id_keys}"
            )

        # train svm
        svm.fit(x_train, y_train)
        return svm

    @staticmethod
    def load(path: Path) -> "ProbabilisticFilterSVMModel":
        """Load the model from the JSON file at ``path``.

        The directory containing ``path`` becomes the model's base path.

        Raises:
            FileNotFoundError: If the COBS index file of the model is missing.
        """
        with open(path, "r", encoding="utf-8") as file:
            model_json = json.load(file)

        model = ProbabilisticFilterSVMModel(
            model_json["k"],
            model_json["model_display_name"],
            model_json["author"],
            model_json["author_email"],
            model_json["model_type"],
            path.parent,
            model_json["kernel"],
            model_json["C"],
            fpr=model_json["fpr"],
            num_hashes=model_json["num_hashes"],
        )
        model.display_names = model_json["display_names"]

        index_path = model.get_cobs_index_path()
        if not Path(index_path).exists():
            raise FileNotFoundError(f"Index file not found at {index_path}")
        model.index = cobs.Search(index_path, True)

        return model
@@ -0,0 +1,109 @@
1
+ """Probabilistic single filter model for sequence data"""
2
+
3
+ # pylint: disable=no-name-in-module, too-many-instance-attributes
4
+
5
+ import json
6
+ from math import ceil
7
+ from pathlib import Path
8
+ from Bio.Seq import Seq
9
+ from Bio.SeqRecord import SeqRecord
10
+ from rbloom import Bloom
11
+ from xxhash import xxh3_64_intdigest
12
+ from xspect.models.probabilistic_filter_model import ProbabilisticFilterModel
13
+ from xspect.file_io import get_record_iterator
14
+
15
+
16
class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
    """Probabilistic filter model backed by a single Bloom filter."""

    def __init__(
        self,
        k: int,
        model_display_name: str,
        author: str,
        author_email: str,
        model_type: str,
        base_path: Path,
        fpr: float = 0.01,
        num_hashes: int = 7,
    ) -> None:
        """Forward all arguments to the parent and start without a filter."""
        super().__init__(
            k=k,
            model_display_name=model_display_name,
            author=author,
            author_email=author_email,
            model_type=model_type,
            base_path=base_path,
            fpr=fpr,
            num_hashes=num_hashes,
        )
        # Set by fit() or load().
        self.bf = None

    def fit(self, file_path: Path, display_name: str) -> None:
        """Build the Bloom filter from the sequences in a file and save it.

        The filter is stored as ``filter.bloom`` in the model directory and
        ``display_name`` is registered under the file's stem.
        """
        # The summed record length is used to size the filter.
        total_length = sum(
            len(record.seq) for record in get_record_iterator(file_path)
        )
        expected_kmers = total_length - self.k + 1

        bloom = Bloom(expected_kmers, self.fpr, hash_func=xxh3_64_intdigest)
        self.bf = bloom
        for record in get_record_iterator(file_path):
            for kmer in self._generate_kmers(record.seq):
                bloom.add(kmer)
        self.display_names[file_path.stem] = display_name

        target = self.base_path / self.slug() / "filter.bloom"
        target.parent.mkdir(parents=True, exist_ok=True)
        bloom.save(str(target))

    def calculate_hits(
        self, sequence: Seq | SeqRecord, filter_ids=None, step: int = 1
    ) -> dict:
        """Count the k-mers of the sequence that are found in the filter.

        Returns a dictionary with one entry, keyed by the first display-name
        key. ``filter_ids`` is accepted but not used.

        Raises:
            ValueError: If the input is not a ``Seq``/``SeqRecord`` or is not
                longer than ``k``.
        """
        if isinstance(sequence, SeqRecord):
            sequence = sequence.seq

        if not isinstance(sequence, Seq):
            raise ValueError("Invalid sequence, must be a Bio.Seq object")

        if len(sequence) <= self.k:
            raise ValueError("Invalid sequence, must be longer than k")

        num_hits = 0
        for kmer in self._generate_kmers(sequence, step=step):
            if kmer in self.bf:
                num_hits += 1

        first_key = next(iter(self.display_names))
        return {first_key: num_hits}

    @staticmethod
    def load(path: Path) -> "ProbabilisticSingleFilterModel":
        """Load the model description at ``path`` and its saved Bloom filter."""
        with open(path, "r", encoding="utf-8") as file:
            model_json = json.loads(file.read())

        model = ProbabilisticSingleFilterModel(
            model_json["k"],
            model_json["model_display_name"],
            model_json["author"],
            model_json["author_email"],
            model_json["model_type"],
            path.parent,
            fpr=model_json["fpr"],
            num_hashes=model_json["num_hashes"],
        )
        model.display_names = model_json["display_names"]

        stored_filter = model.base_path / model.slug() / "filter.bloom"
        model.bf = Bloom.load(str(stored_filter), hash_func=xxh3_64_intdigest)
        return model

    def _generate_kmers(self, sequence: Seq, step: int = 1):
        """Yield the canonical form of every ``step``-th k-mer as a string.

        The canonical form is the smaller of the k-mer and its reverse
        complement.
        """
        count = ceil((len(sequence) - self.k + 1) / step)
        for start_pos in (index * step for index in range(count)):
            kmer = sequence[start_pos : start_pos + self.k]
            reverse = str(kmer.reverse_complement())
            yield str(min(kmer, reverse))
@@ -0,0 +1,148 @@
1
+ """ Module for storing the results of XspecT models. """
2
+
3
+ from enum import Enum
4
+
5
+
6
def get_last_processing_step(result: "ModelResult") -> "ModelResult":
    """Get the result of the deepest subprocessing step.

    Only one path is followed: at every level the most recently added
    subprocessing step is taken. A step may carry no result (``None``); the
    walk then stops and the last available result is returned.
    """
    last_step = result
    while last_step.subprocessing_steps:
        next_result = last_step.subprocessing_steps[-1].result
        if next_result is None:
            # The step produced no result, so there is nothing to descend into.
            break
        last_step = next_result
    return last_step
12
+
13
+
14
class StepType(Enum):
    """Enum for defining the type of a subprocessing step.

    ``str(member)`` is the lower-case member name. The same string (in any
    case) is accepted by the constructor, so a serialized step type can be
    parsed back with ``StepType(text)``.
    """

    PREDICTION = 1
    FILTERING = 2

    def __str__(self) -> str:
        return self.name.lower()

    @classmethod
    def _missing_(cls, value):
        """Resolve member names given as strings; other values stay invalid."""
        if isinstance(value, str):
            # Returning None lets Enum raise its usual ValueError.
            return cls.__members__.get(value.upper())
        return None
22
+
23
+
24
class SubprocessingStep:
    """One subprocessing step of an XspecT model together with its result."""

    def __init__(
        self,
        subprocessing_type: StepType,
        label: str,
        treshold: float,
        result: "ModelResult",
    ):
        self.subprocessing_type, self.label = subprocessing_type, label
        self.treshold, self.result = treshold, result

    def to_dict(self) -> dict:
        """Serialize the step; a missing result becomes an empty dictionary."""
        serialized = {
            "subprocessing_type": str(self.subprocessing_type),
            "label": self.label,
            "treshold": self.treshold,
        }
        if self.result:
            serialized["result"] = self.result.to_dict()
        else:
            serialized["result"] = {}
        return serialized
47
+
48
+
49
+ class ModelResult:
50
+ """Class for storing an XspecT model result."""
51
+
52
+ def __init__(
53
+ self,
54
+ # we store hits depending on the subsequence as well as on the label
55
+ model_slug: str,
56
+ hits: dict[str, dict[str, int]],
57
+ num_kmers: dict[str, int],
58
+ sparse_sampling_step: int = 1,
59
+ prediction: str = None,
60
+ ):
61
+ if "total" in hits:
62
+ raise ValueError(
63
+ "'total' is a reserved key and cannot be used as a subsequence"
64
+ )
65
+ self.model_slug = model_slug
66
+ self.hits = hits
67
+ self.num_kmers = num_kmers
68
+ self.sparse_sampling_step = sparse_sampling_step
69
+ self.prediction = prediction
70
+ self.subprocessing_steps = []
71
+
72
+ def add_subprocessing_step(self, subprocessing_step: SubprocessingStep) -> None:
73
+ """Add a subprocessing step to the result."""
74
+ if subprocessing_step.label in self.subprocessing_steps:
75
+ raise ValueError(
76
+ f"Subprocessing step {subprocessing_step.label} already exists in the result"
77
+ )
78
+ self.subprocessing_steps.append(subprocessing_step)
79
+
80
+ def get_scores(self) -> dict:
81
+ """Return the scores of the model."""
82
+ scores = {
83
+ subsequence: {
84
+ label: round(hits / self.num_kmers[subsequence], 2)
85
+ for label, hits in subseuqence_hits.items()
86
+ }
87
+ for subsequence, subseuqence_hits in self.hits.items()
88
+ }
89
+
90
+ # calculate total scores
91
+ total_num_kmers = sum(self.num_kmers.values())
92
+ total_hits = self.get_total_hits()
93
+
94
+ scores["total"] = {
95
+ label: round(hits / total_num_kmers, 2)
96
+ for label, hits in total_hits.items()
97
+ }
98
+
99
+ return scores
100
+
101
+ def get_total_hits(self) -> dict[str, int]:
102
+ """Return the total hits of the model."""
103
+ total_hits = {label: 0 for label in list(self.hits.values())[0]}
104
+ for _, subseuqence_hits in self.hits.items():
105
+ for label, hits in subseuqence_hits.items():
106
+ total_hits[label] += hits
107
+ return total_hits
108
+
109
+ def get_filter_mask(self, label: str, filter_threshold: float) -> dict[str, bool]:
110
+ """Return a mask for filtered subsequences."""
111
+ if filter_threshold < 0 or filter_threshold > 1:
112
+ raise ValueError("The filter threshold must be between 0 and 1.")
113
+
114
+ scores = self.get_scores()
115
+ scores.pop("total")
116
+ return {
117
+ subsequence: score[label] >= filter_threshold
118
+ for subsequence, score in scores.items()
119
+ }
120
+
121
+ def get_filtered_subsequences(self, label: str, filter_threshold: 0.7) -> list[str]:
122
+ """Return the filtered subsequences."""
123
+ return [
124
+ subsequence
125
+ for subsequence, mask in self.get_filter_mask(
126
+ label, filter_threshold
127
+ ).items()
128
+ if mask
129
+ ]
130
+
131
+ def to_dict(self) -> dict:
132
+ """Return the result as a dictionary."""
133
+ res = {
134
+ "model_slug": self.model_slug,
135
+ "sparse_sampling_step": self.sparse_sampling_step,
136
+ "hits": self.hits,
137
+ "scores": self.get_scores(),
138
+ "num_kmers": self.num_kmers,
139
+ "subprocessing_steps": [
140
+ subprocessing_step.to_dict()
141
+ for subprocessing_step in self.subprocessing_steps
142
+ ],
143
+ }
144
+
145
+ if self.prediction is not None:
146
+ res["prediction"] = self.prediction
147
+
148
+ return res
xspect/pipeline.py ADDED
@@ -0,0 +1,201 @@
1
+ """ Module for defining the Pipeline class. """
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from Bio.SeqRecord import SeqRecord
6
+ from Bio import SeqIO
7
+ from xspect.file_io import get_records_by_id
8
+ from xspect.models.result import StepType, SubprocessingStep
9
+ from xspect.run import Run
10
+ from xspect.models.result import ModelResult
11
+ from xspect.model_management import get_model_by_slug
12
+
13
+
14
class ModelExecution:
    """A model invocation plus the pipeline steps that depend on its result."""

    def __init__(
        self,
        model_slug: str,
        sparse_sampling_step: int = 1,
    ):
        """Store the slug of the model to run and the sampling step to use."""
        self.model_slug = model_slug
        self.sparse_sampling_step = sparse_sampling_step
        self.pipeline_steps = []

    def add_pipeline_step(
        self,
        pipeline_step: "PipelineStep",
    ):
        """Add a follow-up pipeline step to this model execution."""
        self.pipeline_steps.append(pipeline_step)

    def to_dict(self) -> dict:
        """Return the model execution, including its steps, as a dictionary."""
        return {
            "model_slug": self.model_slug,
            "sparse_sampling_step": self.sparse_sampling_step,
            "pipeline_steps": [
                pipeline_step.to_dict() for pipeline_step in self.pipeline_steps
            ],
        }

    def run(
        self,
        sequence_input: (
            SeqRecord
            | list[SeqRecord]
            | SeqIO.FastaIO.FastaIterator
            | SeqIO.QualityIO.FastqPhredIterator
            | Path
        ),
    ) -> ModelResult:
        """Run the model on a given input and execute the follow-up steps.

        A PREDICTION step runs its model on the same input when the total
        score for the step's label reaches the threshold; otherwise the step
        is skipped. A FILTERING step runs its model on the records that pass
        the threshold and is recorded with no result when nothing passes.

        Raises:
            ValueError: If a pipeline step has an unknown subprocessing type.
        """
        model = get_model_by_slug(self.model_slug)
        model_result = model.predict(sequence_input, step=self.sparse_sampling_step)

        for pipeline_step in self.pipeline_steps:
            step_type = pipeline_step.subprocessing_type
            if step_type == StepType.PREDICTION:
                score = model_result.get_scores()["total"][pipeline_step.label]
                if score < pipeline_step.treshold:
                    continue
                step_result = pipeline_step.model_execution.run(sequence_input)
            elif step_type == StepType.FILTERING:
                filtered_sequence_ids = model_result.get_filtered_subsequences(
                    pipeline_step.label, pipeline_step.treshold
                )
                # Keep the filtered records in their own variable so later
                # steps still receive the input this model was run on.
                filtered_input = get_records_by_id(
                    sequence_input, filtered_sequence_ids
                )

                step_result = None
                if filtered_input:
                    step_result = pipeline_step.model_execution.run(filtered_input)
            else:
                raise ValueError(
                    f"Invalid subprocessing type {pipeline_step.subprocessing_type}"
                )

            model_result.add_subprocessing_step(
                SubprocessingStep(
                    step_type,
                    pipeline_step.label,
                    pipeline_step.treshold,
                    step_result,
                )
            )

        return model_result
98
+
99
+
100
class PipelineStep:
    """A follow-up step of a pipeline: type, label, threshold and model to run."""

    def __init__(
        self,
        subprocessing_type: StepType,
        label: str,
        treshold: float,
        model_execution: ModelExecution,
    ):
        self.subprocessing_type, self.label = subprocessing_type, label
        self.treshold, self.model_execution = treshold, model_execution

    def to_dict(self) -> dict:
        """Serialize the step together with its nested model execution."""
        serialized = {
            "subprocessing_type": str(self.subprocessing_type),
            "label": self.label,
            "treshold": self.treshold,
        }
        serialized["model_execution"] = self.model_execution.to_dict()
        return serialized
123
+
124
+
125
class Pipeline:
    """Class for storing an XspecT pipeline consisting of multiple model processing steps."""

    def __init__(self, display_name: str, author: str, author_email: str):
        """Create an empty pipeline with the given metadata."""
        self.display_name = display_name
        self.author = author
        self.author_email = author_email
        self.model_executions = []

    def add_pipeline_step(
        self,
        pipeline_step: "ModelExecution",
    ):
        """Add a top-level model execution to the pipeline."""
        self.model_executions.append(pipeline_step)

    def to_dict(self) -> dict:
        """Return the pipeline as a dictionary."""
        return {
            "display_name": self.display_name,
            "author": self.author,
            "author_email": self.author_email,
            "model_executions": [
                model_execution.to_dict() for model_execution in self.model_executions
            ],
        }

    def to_json(self) -> str:
        """Return the pipeline as a JSON string."""
        return json.dumps(self.to_dict())

    def save(self, path: Path) -> None:
        """Save the pipeline as a JSON file."""
        with open(path, "w", encoding="utf-8") as f:
            f.write(self.to_json())

    @staticmethod
    def from_file(path: Path) -> "Pipeline":
        """Load the pipeline from a JSON file as written by ``save``.

        Model executions are rebuilt recursively, so nested pipeline steps
        are restored at every depth.

        Raises:
            ValueError: If a step has an unknown subprocessing type.
        """
        with open(path, "r", encoding="utf-8") as f:
            pipeline_json = json.load(f)

        pipeline = Pipeline(
            pipeline_json["display_name"],
            pipeline_json["author"],
            pipeline_json["author_email"],
        )
        for execution_json in pipeline_json["model_executions"]:
            pipeline.add_pipeline_step(
                Pipeline._model_execution_from_dict(execution_json)
            )
        return pipeline

    @staticmethod
    def _model_execution_from_dict(execution_json: dict) -> "ModelExecution":
        """Rebuild a model execution and all of its nested pipeline steps."""
        execution = ModelExecution(
            execution_json["model_slug"],
            execution_json.get("sparse_sampling_step", 1),
        )
        for step_json in execution_json.get("pipeline_steps", []):
            execution.add_pipeline_step(
                PipelineStep(
                    Pipeline._parse_step_type(step_json["subprocessing_type"]),
                    step_json["label"],
                    step_json["treshold"],
                    Pipeline._model_execution_from_dict(step_json["model_execution"]),
                )
            )
        return execution

    @staticmethod
    def _parse_step_type(raw) -> "StepType":
        """Convert a serialized step type (member name or value) to a StepType."""
        if isinstance(raw, str):
            # to_dict() stores str(step_type), i.e. the lower-case member name.
            try:
                return StepType[raw.upper()]
            except KeyError as err:
                raise ValueError(f"Invalid subprocessing type {raw}") from err
        return StepType(raw)

    def run(self, input_file: Path) -> "Run":
        """Run every top-level model execution on the input file.

        Returns a ``Run`` holding one result per model execution.
        """
        run = Run(self.display_name, input_file)

        for model_execution in self.model_executions:
            result = model_execution.run(input_file)
            run.add_result(result)

        return run
xspect/run.py ADDED
@@ -0,0 +1,38 @@
1
+ """ Module with XspecT global run class, which summarizes individual model results. """
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from xspect.models.result import ModelResult
6
+
7
+
8
class Run:
    """Container collecting the model results of one XspecT run."""

    def __init__(self, display_name: str, input_file: str):
        self.display_name, self.input_file = display_name, input_file
        self.results = []

    def add_result(self, result: ModelResult):
        """Append a model result to the run."""
        self.results.append(result)

    def to_dict(self) -> dict:
        """Serialize the run; the input file is stored as a string."""
        collected = self.results or []
        return {
            "display_name": self.display_name,
            "input_file": str(self.input_file),
            "results": [entry.to_dict() for entry in collected],
        }

    def to_json(self) -> str:
        """Serialize the run to JSON indented by four spaces."""
        return json.dumps(self.to_dict(), indent=4)

    def save(self, path: Path) -> None:
        """Write the JSON form of the run to ``path``."""
        with open(path, "w", encoding="utf-8") as handle:
            handle.write(self.to_json())