sai-pg 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sai/__init__.py +18 -0
- sai/__main__.py +73 -0
- sai/parsers/__init__.py +18 -0
- sai/parsers/argument_validation.py +169 -0
- sai/parsers/outlier_parser.py +76 -0
- sai/parsers/plot_parser.py +152 -0
- sai/parsers/score_parser.py +241 -0
- sai/sai.py +315 -0
- sai/stats/__init__.py +18 -0
- sai/stats/features.py +302 -0
- sai/utils/__init__.py +22 -0
- sai/utils/generators/__init__.py +23 -0
- sai/utils/generators/chunk_generator.py +148 -0
- sai/utils/generators/data_generator.py +49 -0
- sai/utils/generators/window_generator.py +250 -0
- sai/utils/genomic_dataclasses.py +46 -0
- sai/utils/multiprocessing/__init__.py +22 -0
- sai/utils/multiprocessing/mp_manager.py +251 -0
- sai/utils/multiprocessing/mp_pool.py +73 -0
- sai/utils/preprocessors/__init__.py +23 -0
- sai/utils/preprocessors/chunk_preprocessor.py +152 -0
- sai/utils/preprocessors/data_preprocessor.py +94 -0
- sai/utils/preprocessors/feature_preprocessor.py +211 -0
- sai/utils/utils.py +689 -0
- sai_pg-1.0.0.dist-info/METADATA +44 -0
- sai_pg-1.0.0.dist-info/RECORD +30 -0
- sai_pg-1.0.0.dist-info/WHEEL +5 -0
- sai_pg-1.0.0.dist-info/entry_points.txt +2 -0
- sai_pg-1.0.0.dist-info/licenses/LICENSE +674 -0
- sai_pg-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,250 @@
|
|
1
|
+
# Copyright 2025 Xin Huang
|
2
|
+
#
|
3
|
+
# GNU General Public License v3.0
|
4
|
+
#
|
5
|
+
# This program is free software: you can redistribute it and/or modify
|
6
|
+
# it under the terms of the GNU General Public License as published by
|
7
|
+
# the Free Software Foundation, either version 3 of the License, or
|
8
|
+
# (at your option) any later version.
|
9
|
+
#
|
10
|
+
# This program is distributed in the hope that it will be useful,
|
11
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13
|
+
# GNU General Public License for more details.
|
14
|
+
#
|
15
|
+
# You should have received a copy of the GNU General Public License
|
16
|
+
# along with this program. If not, please see
|
17
|
+
#
|
18
|
+
# https://www.gnu.org/licenses/gpl-3.0.en.html
|
19
|
+
|
20
|
+
|
21
|
+
from itertools import combinations, product
|
22
|
+
from typing import Iterator, Any
|
23
|
+
from sai.utils import read_data, split_genome
|
24
|
+
from sai.utils.generators import DataGenerator
|
25
|
+
|
26
|
+
|
27
|
+
class WindowGenerator(DataGenerator):
    """
    Generates genomic data for each specified window from VCF and other related files,
    allowing the user to select the number of source populations.

    One item is produced for every combination of reference population,
    target population, source-population combination and window.
    """

    def __init__(
        self,
        vcf_file: str,
        chr_name: str,
        ref_ind_file: str,
        tgt_ind_file: str,
        src_ind_file: str,
        win_len: int,
        win_step: int,
        start: int = None,
        end: int = None,
        anc_allele_file: str = None,
        num_src: int = 1,
    ):
        """
        Initializes a new instance of WindowGenerator.

        Parameters
        ----------
        vcf_file : str
            The path to the VCF file containing variant data.
        chr_name: str
            The chromosome name to read from the VCF file.
        ref_ind_file : str
            The path to the file containing identifiers for reference populations.
        tgt_ind_file : str
            The path to the file containing identifiers for target populations.
        src_ind_file : str
            The path to the file containing identifiers for source populations.
        win_len : int
            The length of each window in base pairs.
        win_step : int
            The step size between windows in base pairs.
        start: int, optional
            The starting position (1-based, inclusive) on the chromosome. Default: None.
        end: int, optional
            The ending position (1-based, inclusive) on the chromosome. Default: None.
        anc_allele_file: str, optional
            Path to the file containing ancestral allele information. Default: None.
        num_src : int, optional
            The number of source populations to include in each combination. Default: 1.

        Raises
        ------
        ValueError
            If `win_len` is less than or equal to 0, if `win_step` is negative,
            or if `num_src` is less than 1.
        """
        if win_len <= 0:
            raise ValueError("`win_len` must be greater than 0.")
        if win_step < 0:
            raise ValueError("`win_step` must be non-negative.")
        if num_src < 1:
            raise ValueError("`num_src` must be at least 1.")

        self.win_len = win_len
        self.win_step = win_step
        self.num_src = num_src
        self.chr_name = chr_name

        # Load data
        (
            self.ref_data,
            self.ref_samples,
            self.tgt_data,
            self.tgt_samples,
            self.src_data,
            self.src_samples,
            self.ploidy,
        ) = read_data(
            vcf_file=vcf_file,
            chr_name=self.chr_name,
            start=start,
            end=end,
            ref_ind_file=ref_ind_file,
            tgt_ind_file=tgt_ind_file,
            src_ind_file=src_ind_file,
            anc_allele_file=anc_allele_file,
            is_phased=False,
            filter_ref=False,
            filter_tgt=False,
            filter_src=False,
        )

        # Every unordered choice of `num_src` source populations (tuples).
        self.src_combinations = list(
            combinations(self.src_samples.keys(), self.num_src)
        )

        # Windows are derived from the observed target positions only when
        # neither bound was given; otherwise from the requested interval.
        # NOTE(review): if exactly one of `start`/`end` is None the interval
        # branch is still taken and `end - win_len + win_step` is evaluated
        # with the value as passed -- confirm callers always supply both
        # bounds or neither.
        self.tgt_windows = {
            tgt_pop: split_genome(
                pos=(
                    self.tgt_data[tgt_pop].POS
                    if (start is None) and (end is None)
                    else [start, end - win_len + win_step]
                ),
                window_size=self.win_len,
                step_size=self.win_step,
            )
            for tgt_pop in self.tgt_samples
        }

        # Precomputed so that `__len__` does not have to walk the generator.
        self.total_windows = sum(
            len(windows) * len(self.ref_samples) * len(self.src_combinations)
            for windows in self.tgt_windows.values()
        )

    def _window_generator(self) -> Iterator[dict[str, Any]]:
        """
        Generator function that yields genomic data for each window for each
        population combination, including specified source population combinations.

        Yields
        ------
        dict
            A dictionary with the keys "chr_name", "ref_pop", "tgt_pop",
            "src_pop_list", "start", "end", "pos", "ref_gts", "tgt_gts",
            "src_gts_list" and "ploidy". Genotypes and positions are restricted
            to entries whose position lies in the half-open interval
            ``[start, end)``; "pos" holds the target positions.
        """
        for ref_pop, tgt_pop, src_comb in product(
            self.ref_samples, self.tgt_samples, self.src_combinations
        ):
            tgt_data = self.tgt_data[tgt_pop]
            tgt_pos = tgt_data.POS
            for start, end in self.tgt_windows[tgt_pop]:
                ref_gts = self.ref_data[ref_pop].GT[
                    (self.ref_data[ref_pop].POS >= start)
                    & (self.ref_data[ref_pop].POS < end)
                ]

                # The same mask selects both the target genotypes and the
                # reported positions, so build it once per window.
                tgt_mask = (tgt_pos >= start) & (tgt_pos < end)
                tgt_gts = tgt_data.GT[tgt_mask]
                sub_pos = tgt_pos[tgt_mask]

                src_gts_list = [
                    self.src_data[src_pop].GT[
                        (self.src_data[src_pop].POS >= start)
                        & (self.src_data[src_pop].POS < end)
                    ]
                    for src_pop in src_comb
                ]

                yield {
                    "chr_name": self.chr_name,
                    "ref_pop": ref_pop,
                    "tgt_pop": tgt_pop,
                    "src_pop_list": src_comb,  # Tuple of source populations in this combination
                    "start": start,
                    "end": end,
                    "pos": sub_pos,
                    "ref_gts": ref_gts,
                    "tgt_gts": tgt_gts,
                    "src_gts_list": src_gts_list,  # Genotypes for each source population in src_comb
                    "ploidy": self.ploidy,
                }

    def _none_window_generator(self) -> Iterator[dict[str, Any]]:
        """
        Generates empty window data when reference, target, or source data is missing.

        Yields
        ------
        dict[str, Any]
            A dictionary containing the following keys:
            - "chr_name" (str): The chromosome name.
            - "ref_pop" (str): Reference population name.
            - "tgt_pop" (str): Target population name.
            - "src_pop_list" (tuple[str, ...]): Source populations in this combination.
            - "start" (int): Start position of the window.
            - "end" (int): End position of the window.
            - "pos" (list[int]): Empty list, since no positions are available.
            - "ref_gts" (None): Placeholder for missing reference genotypes.
            - "tgt_gts" (None): Placeholder for missing target genotypes.
            - "src_gts_list" (None): Placeholder for missing source genotypes.
            - "ploidy" (None): Placeholder for missing ploidy information.
        """
        for ref_pop, tgt_pop, src_comb in product(
            self.ref_samples, self.tgt_samples, self.src_combinations
        ):
            for start, end in self.tgt_windows[tgt_pop]:
                yield {
                    "chr_name": self.chr_name,
                    "ref_pop": ref_pop,
                    "tgt_pop": tgt_pop,
                    "src_pop_list": src_comb,
                    "start": start,
                    "end": end,
                    "pos": [],
                    "ref_gts": None,
                    "tgt_gts": None,
                    "src_gts_list": None,
                    "ploidy": None,
                }

    def get(self) -> Iterator[dict[str, Any]]:
        """
        Returns the generator for window data.

        Returns
        -------
        generator
            A generator yielding genomic data for each window. Placeholder
            windows are yielded instead when any of the reference, target or
            source data sets is None.
        """
        if (
            (self.ref_data is None)
            or (self.tgt_data is None)
            or (self.src_data is None)
        ):
            return self._none_window_generator()
        else:
            return self._window_generator()

    def __len__(self) -> int:
        """
        Returns the precomputed total number of windows across all population combinations.

        Returns
        -------
        int
            Total number of windows.
        """
        return self.total_windows
|
@@ -0,0 +1,46 @@
|
|
1
|
+
# Copyright 2025 Xin Huang
|
2
|
+
#
|
3
|
+
# GNU General Public License v3.0
|
4
|
+
#
|
5
|
+
# This program is free software: you can redistribute it and/or modify
|
6
|
+
# it under the terms of the GNU General Public License as published by
|
7
|
+
# the Free Software Foundation, either version 3 of the License, or
|
8
|
+
# (at your option) any later version.
|
9
|
+
#
|
10
|
+
# This program is distributed in the hope that it will be useful,
|
11
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13
|
+
# GNU General Public License for more details.
|
14
|
+
#
|
15
|
+
# You should have received a copy of the GNU General Public License
|
16
|
+
# along with this program. If not, please see
|
17
|
+
#
|
18
|
+
# https://www.gnu.org/licenses/gpl-3.0.en.html
|
19
|
+
|
20
|
+
|
21
|
+
import allel
|
22
|
+
from dataclasses import dataclass
|
23
|
+
|
24
|
+
|
25
|
+
@dataclass
class ChromosomeData:
    """
    Container bundling the per-variant records of a single chromosome.

    The four fields are parallel: entry ``i`` of each one refers to the same
    variant.

    Attributes
    ----------
    REF : list[str]
        Reference allele of every variant.
    ALT : list[str]
        Alternate allele of every variant.
    POS : list[int]
        Genomic position of every variant.
    GT : list[allel.GenotypeVector]
        Genotype vector of every variant, i.e. the genotype calls recorded
        at that position.
    """

    REF: list[str]
    ALT: list[str]
    POS: list[int]
    GT: list[allel.GenotypeVector]
|
@@ -0,0 +1,22 @@
|
|
1
|
+
# Copyright 2024 Xin Huang
|
2
|
+
#
|
3
|
+
# GNU General Public License v3.0
|
4
|
+
#
|
5
|
+
# This program is free software: you can redistribute it and/or modify
|
6
|
+
# it under the terms of the GNU General Public License as published by
|
7
|
+
# the Free Software Foundation, either version 3 of the License, or
|
8
|
+
# (at your option) any later version.
|
9
|
+
#
|
10
|
+
# This program is distributed in the hope that it will be useful,
|
11
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13
|
+
# GNU General Public License for more details.
|
14
|
+
#
|
15
|
+
# You should have received a copy of the GNU General Public License
|
16
|
+
# along with this program. If not, please see
|
17
|
+
#
|
18
|
+
# https://www.gnu.org/licenses/gpl-3.0.en.html
|
19
|
+
|
20
|
+
|
21
|
+
from .mp_manager import mp_manager
|
22
|
+
from .mp_pool import mp_pool
|
@@ -0,0 +1,251 @@
|
|
1
|
+
# Copyright 2025 Xin Huang
|
2
|
+
#
|
3
|
+
# GNU General Public License v3.0
|
4
|
+
#
|
5
|
+
# This program is free software: you can redistribute it and/or modify
|
6
|
+
# it under the terms of the GNU General Public License as published by
|
7
|
+
# the Free Software Foundation, either version 3 of the License, or
|
8
|
+
# (at your option) any later version.
|
9
|
+
#
|
10
|
+
# This program is distributed in the hope that it will be useful,
|
11
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13
|
+
# GNU General Public License for more details.
|
14
|
+
#
|
15
|
+
# You should have received a copy of the GNU General Public License
|
16
|
+
# along with this program. If not, please see
|
17
|
+
#
|
18
|
+
# https://www.gnu.org/licenses/gpl-3.0.en.html
|
19
|
+
|
20
|
+
|
21
|
+
import multiprocessing
|
22
|
+
import queue
|
23
|
+
import time
|
24
|
+
from multiprocessing import current_process
|
25
|
+
from multiprocessing import Manager
|
26
|
+
from multiprocessing import Process
|
27
|
+
from threading import Thread
|
28
|
+
from sai.utils.generators import DataGenerator
|
29
|
+
from sai.utils.preprocessors import DataPreprocessor
|
30
|
+
|
31
|
+
|
32
|
+
def monitor(shared_dict: dict, workers: list[multiprocessing.Process]) -> None:
    """
    Watch a group of worker processes until they all finish or one of them dies.

    The loop polls once per second. It ends normally as soon as every worker
    has recorded the status 'Completed' in `shared_dict`. If a worker is found
    to be no longer alive without having recorded 'Completed', a message is
    printed, `terminate_all_workers` is called on the whole group and the
    function returns.

    Parameters
    ----------
    shared_dict : dict
        Mapping from worker process name to a status string; only the value
        'Completed' is treated as success.
    workers : list[multiprocessing.Process]
        The worker processes to watch.

    Notes
    -----
    - Liveness is checked before the status is read, so a worker that records
      'Completed' and exits between two polls is not reported as failed.
    - An empty `workers` list counts as "all completed" and returns at once.
    """

    def _finished(proc) -> bool:
        # Status is re-read on every call: workers update it concurrently.
        return shared_dict.get(proc.name) == "Completed"

    while not all(_finished(proc) for proc in workers):
        for proc in workers:
            if proc.is_alive() or _finished(proc):
                continue

            # Dead without reporting success: bring the whole group down.
            print(
                f"{proc.name} did not complete successfully. Initiating shutdown."
            )
            terminate_all_workers(workers)
            print("All workers are terminated.")
            return

        time.sleep(1)  # Poll interval
|
76
|
+
|
77
|
+
|
78
|
+
def terminate_all_workers(workers: list[multiprocessing.Process]) -> None:
    """
    Stop every worker process and wait until each one has exited.

    Parameters
    ----------
    workers : list[multiprocessing.Process]
        The worker processes to shut down.

    Notes
    -----
    The work is done in two passes: `terminate` is called on every worker
    first, and only afterwards is each worker joined. This way no worker keeps
    running while the caller is blocked waiting for another one.
    """
    # Pass 1: request termination of the whole group.
    for proc in workers:
        proc.terminate()

    # Pass 2: block until each process has actually gone away.
    for proc in workers:
        proc.join()
|
101
|
+
|
102
|
+
|
103
|
+
def mp_manager(
    data_processor: DataPreprocessor,
    data_generator: DataGenerator,
    nprocess: int,
    **kwargs,
) -> None:
    """
    Runs `data_processor.run` over every parameter set produced by
    `data_generator` using a group of worker processes, then hands the
    collected results to `data_processor.process_items`.

    Parameters
    ----------
    data_processor : DataPreprocessor
        Object whose `run` method is executed by the workers for each task and
        whose `process_items` method receives the collected results.
    data_generator : DataGenerator
        Object whose `get` method yields one parameter dictionary per task and
        whose `len()` gives the number of tasks. The dictionaries are unpacked
        as keyword arguments of `data_processor.run`.
    nprocess : int
        The number of worker processes to start.
    **kwargs : dict
        Additional keyword arguments forwarded unchanged to every
        `data_processor.run` call.

    Notes
    -----
    - Tasks and results travel through queues created by a multiprocessing
      `Manager`; worker status is tracked in a manager dictionary that a
      monitoring thread (`monitor`) polls.
    - If a worker reports an error, collection stops and only the results
      gathered so far are passed to `process_items`.
    - Results are collected with a polling `get` so that the function cannot
      block forever when a worker dies without posting anything: once the
      monitoring thread has exited and the output queue is empty, collection
      stops.
    - When `pytest_cov` is installed, `cleanup_on_sigterm` is called so that
      coverage data of terminated workers is not lost.
    """
    try:
        from pytest_cov.embed import cleanup_on_sigterm
    except ImportError:
        pass
    else:
        cleanup_on_sigterm()

    with Manager() as manager:
        in_queue, out_queue = manager.Queue(), manager.Queue()
        shared_dict = manager.dict()
        workers = [
            Process(
                target=mp_worker, args=(in_queue, out_queue, shared_dict), kwargs=kwargs
            )
            for _ in range(nprocess)
        ]

        # All tasks are queued before any worker starts.
        for params in data_generator.get():
            in_queue.put((data_processor, params))

        # Bound before the `try` so that the `finally` clause can tell whether
        # the monitoring thread was ever created.
        monitor_thread = None
        try:
            for w in workers:
                w.start()

            monitor_thread = Thread(target=monitor, args=(shared_dict, workers))
            monitor_thread.start()

            results = []
            remaining = len(data_generator)

            while remaining > 0:
                try:
                    items = out_queue.get(timeout=1)
                except queue.Empty:
                    if monitor_thread.is_alive():
                        continue
                    # The monitor has exited, so every worker has either
                    # completed or been terminated. Pick up anything posted
                    # just before that, otherwise stop waiting.
                    try:
                        items = out_queue.get_nowait()
                    except queue.Empty:
                        break

                remaining -= 1
                if items is None:
                    continue
                if isinstance(items, tuple) and "error" in items:
                    break

                results.extend(items)

            if results:
                data_processor.process_items(results)

            for w in workers:
                w.join()
        finally:
            for w in workers:
                # `pid` is None for a process that was never started;
                # terminating such a process would raise.
                if w.pid is not None:
                    w.terminate()
            if monitor_thread is not None:
                monitor_thread.join()
|
193
|
+
|
194
|
+
|
195
|
+
def mp_worker(
    in_queue: queue.Queue, out_queue: queue.Queue, shared_dict: dict, **kwargs
) -> None:
    """
    A multiprocessing worker function that processes tasks from an input queue
    and reports results and status through an output queue and a shared dictionary.

    Each task taken from `in_queue` is a ``(data_processor, params)`` pair. The
    worker calls ``data_processor.run(**params, **kwargs)`` and puts the return
    value on `out_queue`. It keeps doing so until `in_queue` stays empty for
    the duration of the timeout.

    Parameters
    ----------
    in_queue : multiprocessing.managers.SyncManager.Queue
        The input queue from which the worker fetches tasks.
    out_queue : multiprocessing.managers.SyncManager.Queue
        The output queue where the worker posts the results of processed tasks.
    shared_dict : dict
        A shared dictionary where the worker updates its status. Uses the worker's process name
        as the key and the status ('Started', 'Completed', 'Failed') as the value.
    **kwargs : dict
        Additional keyword arguments passed to every `run` call.

    Raises
    ------
    Exception
        If fetching a task, running it, or posting its result raises. Before
        re-raising, the worker sets its status to 'Failed' and puts
        ``("error", str(e))`` on `out_queue`.

    Notes
    -----
    - Tasks are fetched with a timeout of 5 seconds so the worker does not
      block indefinitely on an empty queue.
    - Only a timeout of that fetch marks the worker as 'Completed'. A
      `queue.Empty` raised from inside `run` is treated like any other task
      failure, so it cannot be mistaken for "no more work".
    """
    process_name = current_process().name
    shared_dict[process_name] = "Started"

    while True:
        try:
            try:
                data_processor, params = in_queue.get(timeout=5)
            except queue.Empty:
                shared_dict[process_name] = "Completed"
                return  # No more work: end the worker process

            items = data_processor.run(**params, **kwargs)
            out_queue.put(items)
        except Exception as e:
            shared_dict[process_name] = "Failed"
            out_queue.put(("error", str(e)))
            raise Exception(
                f"Worker {process_name} encountered an exception: {e}"
            ) from e
|
@@ -0,0 +1,73 @@
|
|
1
|
+
# Copyright 2025 Xin Huang
|
2
|
+
#
|
3
|
+
# GNU General Public License v3.0
|
4
|
+
#
|
5
|
+
# This program is free software: you can redistribute it and/or modify
|
6
|
+
# it under the terms of the GNU General Public License as published by
|
7
|
+
# the Free Software Foundation, either version 3 of the License, or
|
8
|
+
# (at your option) any later version.
|
9
|
+
#
|
10
|
+
# This program is distributed in the hope that it will be useful,
|
11
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13
|
+
# GNU General Public License for more details.
|
14
|
+
#
|
15
|
+
# You should have received a copy of the GNU General Public License
|
16
|
+
# along with this program. If not, please see
|
17
|
+
#
|
18
|
+
# https://www.gnu.org/licenses/gpl-3.0.en.html
|
19
|
+
|
20
|
+
|
21
|
+
from multiprocessing import Pool
|
22
|
+
from typing import Any
|
23
|
+
from sai.utils.generators import DataGenerator
|
24
|
+
from sai.utils.preprocessors import DataPreprocessor
|
25
|
+
|
26
|
+
|
27
|
+
def mp_worker(params: tuple[DataPreprocessor, dict]) -> Any:
    """
    Run one task: call the preprocessor's `run` with the task's keyword arguments.

    Parameters
    ----------
    params : tuple of (DataPreprocessor, dict)
        A two-element tuple: the `DataPreprocessor` instance to use and the
        dictionary that is unpacked as keyword arguments of its `run` method.

    Returns
    -------
    Any
        Whatever `run` returns for these arguments.
    """
    processor, run_kwargs = params
    outcome = processor.run(**run_kwargs)
    return outcome
|
43
|
+
|
44
|
+
|
45
|
+
def mp_pool(
    data_processor: DataPreprocessor,
    data_generator: DataGenerator,
    nprocess: int,
) -> None:
    """
    Run every task of a data generator in a process pool and post-process the results.

    Parameters
    ----------
    data_processor : DataPreprocessor
        The preprocessor whose `run` method is applied to each task and whose
        `process_items` method receives the list of all results.
    data_generator : DataGenerator
        Source of the tasks; its `get` method yields one parameter dictionary
        per task.
    nprocess : int
        Size of the process pool.

    Returns
    -------
    None
        Nothing is returned; the results are passed to
        `data_processor.process_items()`.
    """
    # Pair the processor with each parameter set so that a single picklable
    # argument can be shipped to the pool workers.
    job_args: list[tuple[DataPreprocessor, dict]] = []
    for params in data_generator.get():
        job_args.append((data_processor, params))

    with Pool(processes=nprocess) as workers:
        outputs = workers.map(mp_worker, job_args)

    data_processor.process_items(outputs)
|
@@ -0,0 +1,23 @@
|
|
1
|
+
# Copyright 2025 Xin Huang
|
2
|
+
#
|
3
|
+
# GNU General Public License v3.0
|
4
|
+
#
|
5
|
+
# This program is free software: you can redistribute it and/or modify
|
6
|
+
# it under the terms of the GNU General Public License as published by
|
7
|
+
# the Free Software Foundation, either version 3 of the License, or
|
8
|
+
# (at your option) any later version.
|
9
|
+
#
|
10
|
+
# This program is distributed in the hope that it will be useful,
|
11
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13
|
+
# GNU General Public License for more details.
|
14
|
+
#
|
15
|
+
# You should have received a copy of the GNU General Public License
|
16
|
+
# along with this program. If not, please see
|
17
|
+
#
|
18
|
+
# https://www.gnu.org/licenses/gpl-3.0.en.html
|
19
|
+
|
20
|
+
|
21
|
+
from .data_preprocessor import DataPreprocessor
|
22
|
+
from .chunk_preprocessor import ChunkPreprocessor
|
23
|
+
from .feature_preprocessor import FeaturePreprocessor
|