blazefl 2.0.0.dev4__tar.gz → 2.0.0.dev5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blazefl-2.0.0.dev5/.python-version +1 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/PKG-INFO +1 -1
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/pyproject.toml +1 -1
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/contrib/__init__.py +6 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/contrib/fedavg.py +69 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/client_trainer.py +31 -17
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/client_trainer.pyi +2 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_contrib/test_fedavg.py +124 -4
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/uv.lock +1 -1
- blazefl-2.0.0.dev4/examples/step-by-step-dsfl/.python-version +0 -1
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/.github/FUNDING.yml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/.github/dependabot.yml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/.github/workflows/ci.yaml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/.github/workflows/publish.yaml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/.github/workflows/sphinx.yaml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/.gitignore +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/.pre-commit-config.yaml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/CODE_OF_CONDUCT.md +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/LICENSE +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/Makefile +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/README.md +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/Makefile +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/imgs/architecture.png +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/imgs/benchmark_cnn.png +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/imgs/benchmark_resnet18.png +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/imgs/logo.svg +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/imgs/logo_square.svg +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/imgs/ogp.png +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/make.bat +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/_static/favicon.ico +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/_static/logo.png +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/_static/logo_square.png +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/_templates/autosummary/class.rst +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/_templates/autosummary/module.rst +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/benchmark.rst +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/conf.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/contribute.rst +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/example.rst +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/index.rst +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/install.rst +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/overview.rst +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/reference.rst +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/.gitignore +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/.python-version +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/Makefile +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/README.md +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/config/config.yaml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/dataset/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/dataset/dataset.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/main.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/models/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/models/selector.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/pyproject.toml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/uv.lock +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/.gitignore +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5/examples/quickstart-fedavg}/.python-version +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/Makefile +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/README.md +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/config/config.yaml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/dataset/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/dataset/dataset.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/main.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/models/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/models/selector.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/pyproject.toml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/.gitignore +0 -0
- {blazefl-2.0.0.dev4/examples/quickstart-fedavg → blazefl-2.0.0.dev5/examples/step-by-step-dsfl}/.python-version +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/Makefile +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/README.md +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/algorithm/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/algorithm/dsfl.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/config/algorithm/dsfl.yaml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/config/config.yaml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/dataset/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/dataset/dataset.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/main.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/models/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/models/cnn.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/models/selector.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/pyproject.toml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/__init__.pyi +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/model_selector.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/model_selector.pyi +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/partitioned_dataset.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/partitioned_dataset.pyi +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/server_handler.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/server_handler.pyi +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/py.typed +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/__init__.pyi +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/dataset.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/dataset.pyi +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/ipc.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/ipc.pyi +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/seed.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/seed.pyi +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/serialize.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/serialize.pyi +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/conftest.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_contrib/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_core/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_core/test_client_trainer.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_core/test_model_selector.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_core/test_partitioned_dataset.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_utils/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_utils/test_dataset.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_utils/test_seed.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_utils/test_serialize.py +0 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
3.13
|
|
@@ -8,11 +8,17 @@ extending the core functionalities of BlazeFL.
|
|
|
8
8
|
from blazefl.contrib.fedavg import (
|
|
9
9
|
FedAvgBaseClientTrainer,
|
|
10
10
|
FedAvgBaseServerHandler,
|
|
11
|
+
FedAvgDownlinkPackage,
|
|
11
12
|
FedAvgProcessPoolClientTrainer,
|
|
13
|
+
FedAvgThreadPoolClientTrainer,
|
|
14
|
+
FedAvgUplinkPackage,
|
|
12
15
|
)
|
|
13
16
|
|
|
14
17
|
__all__ = [
|
|
15
18
|
"FedAvgBaseServerHandler",
|
|
16
19
|
"FedAvgProcessPoolClientTrainer",
|
|
17
20
|
"FedAvgBaseClientTrainer",
|
|
21
|
+
"FedAvgThreadPoolClientTrainer",
|
|
22
|
+
"FedAvgUplinkPackage",
|
|
23
|
+
"FedAvgDownlinkPackage",
|
|
18
24
|
]
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import random
|
|
2
|
+
import threading
|
|
2
3
|
from copy import deepcopy
|
|
3
4
|
from dataclasses import dataclass
|
|
4
5
|
from pathlib import Path
|
|
@@ -14,6 +15,7 @@ from blazefl.core import (
|
|
|
14
15
|
ModelSelector,
|
|
15
16
|
PartitionedDataset,
|
|
16
17
|
ProcessPoolClientTrainer,
|
|
18
|
+
ThreadPoolClientTrainer,
|
|
17
19
|
)
|
|
18
20
|
from blazefl.utils import (
|
|
19
21
|
RandomState,
|
|
@@ -640,6 +642,7 @@ class FedAvgProcessPoolClientTrainer(
|
|
|
640
642
|
device: str,
|
|
641
643
|
epochs: int,
|
|
642
644
|
lr: float,
|
|
645
|
+
stop_event: threading.Event | None = None,
|
|
643
646
|
) -> FedAvgUplinkPackage:
|
|
644
647
|
"""
|
|
645
648
|
Train the model with the given training data loader.
|
|
@@ -664,6 +667,8 @@ class FedAvgProcessPoolClientTrainer(
|
|
|
664
667
|
|
|
665
668
|
data_size = 0
|
|
666
669
|
for _ in range(epochs):
|
|
670
|
+
if stop_event is not None and stop_event.is_set():
|
|
671
|
+
break
|
|
667
672
|
for data, target in train_loader:
|
|
668
673
|
data = data.to(device)
|
|
669
674
|
target = target.to(device)
|
|
@@ -714,3 +719,67 @@ class FedAvgProcessPoolClientTrainer(
|
|
|
714
719
|
package = deepcopy(self.cache)
|
|
715
720
|
self.cache = []
|
|
716
721
|
return package
|
|
722
|
+
|
|
723
|
+
|
|
724
|
+
class FedAvgThreadPoolClientTrainer(
    ThreadPoolClientTrainer[FedAvgUplinkPackage, FedAvgDownlinkPackage]
):
    """FedAvg client trainer that trains sampled clients in worker threads.

    Each client is trained by :meth:`worker`. Finished uplink packages are
    accumulated in ``self.cache`` and handed out (and cleared) by
    :meth:`uplink_package`.
    """

    def __init__(
        self,
        model_selector: ModelSelector,
        model_name: str,
        dataset: PartitionedDataset,
        device: str,
        num_clients: int,
        epochs: int,
        batch_size: int,
        lr: float,
        seed: int,
        num_parallels: int,
    ) -> None:
        """Record the local-training configuration.

        Args:
            model_selector: Factory queried by every worker for a model.
            model_name: Name passed to ``model_selector.select_model``.
            dataset: Partitioned dataset providing per-client train loaders.
            device: Device string; when it is ``"cuda"`` the number of
                visible GPUs is stored in ``self.device_count``.
            num_clients: Total number of clients.
            epochs: Number of local epochs per client.
            batch_size: Batch size of each client's train loader.
            lr: Local learning rate.
            seed: Random seed, stored as ``self.seed`` (not read by this
                class itself).
            num_parallels: Upper bound on concurrently training clients.
        """
        # Data source and training hyper-parameters.
        self.model_selector = model_selector
        self.model_name = model_name
        self.dataset = dataset
        self.num_clients = num_clients
        self.epochs = epochs
        self.batch_size = batch_size
        self.lr = lr
        self.seed = seed

        # Parallelism and device placement.
        self.num_parallels = num_parallels
        self.device = device
        if device == "cuda":
            self.device_count = torch.cuda.device_count()

        # Results waiting to be uplinked, plus the cooperative stop flag
        # (left as None here; a threading.Event is expected to be assigned
        # before training starts).
        self.cache: list[FedAvgUplinkPackage] = []
        self.stop_event = None

    def worker(
        self,
        cid: int,
        device: str,
        payload: FedAvgDownlinkPackage,
    ) -> FedAvgUplinkPackage:
        """Train a single client and return its uplink package.

        A model is requested from the selector on every call (presumably a
        fresh instance, so threads do not share parameters -- confirm). The
        actual training loop is ``FedAvgProcessPoolClientTrainer.train``,
        which checks ``self.stop_event`` at the start of each epoch.

        Args:
            cid: Client ID whose training split is loaded.
            device: Device on which this client trains.
            payload: Downlink package carrying the global model parameters.

        Returns:
            The client's FedAvg uplink package.
        """
        net = self.model_selector.select_model(self.model_name)
        loader = self.dataset.get_dataloader(
            type_="train", cid=cid, batch_size=self.batch_size
        )
        return FedAvgProcessPoolClientTrainer.train(
            model=net,
            model_parameters=payload.model_parameters,
            train_loader=loader,
            device=device,
            epochs=self.epochs,
            lr=self.lr,
            stop_event=self.stop_event,
        )

    def uplink_package(self) -> list[FedAvgUplinkPackage]:
        """Return a deep copy of the cached packages and empty the cache."""
        drained, self.cache = deepcopy(self.cache), []
        return drained
|
|
@@ -1,5 +1,7 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
import multiprocessing as mp
|
|
2
3
|
import signal
|
|
4
|
+
import threading
|
|
3
5
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
4
6
|
from multiprocessing.pool import ApplyResult
|
|
5
7
|
from pathlib import Path
|
|
@@ -197,6 +199,7 @@ class ThreadPoolClientTrainer(
|
|
|
197
199
|
device: str
|
|
198
200
|
device_count: int
|
|
199
201
|
cache: list[UplinkPackage]
|
|
202
|
+
stop_event: threading.Event | None
|
|
200
203
|
|
|
201
204
|
def worker(
|
|
202
205
|
self,
|
|
@@ -223,20 +226,31 @@ class ThreadPoolClientTrainer(
|
|
|
223
226
|
return self.device
|
|
224
227
|
|
|
225
228
|
def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None:
    """Train every client in ``cid_list`` concurrently and cache the results.

    Clients are submitted to a ``ThreadPoolExecutor`` bounded by
    ``self.num_parallels``. Each completed uplink package is appended to
    ``self.cache`` in completion order.

    On ``KeyboardInterrupt`` a warning is logged and ``self.stop_event`` is
    set *before* the executor is shut down. Clients that have not started yet
    are cancelled, and the method then returns normally.

    Args:
        payload: Downlink package forwarded unchanged to every worker.
        cid_list: IDs of the clients to train.
    """
    if self.stop_event is None:
        self.stop_event = threading.Event()
    self.stop_event.clear()

    # The executor is managed by hand instead of with a ``with`` block.
    # ``ThreadPoolExecutor.__exit__`` calls ``shutdown(wait=True)``, which
    # would join all running workers before the ``except`` clause could set
    # the stop event. Workers would then never see the event, and an
    # interrupt would block until all training finished.
    executor = ThreadPoolExecutor(max_workers=self.num_parallels)
    interrupted = False
    try:
        futures = [
            executor.submit(self.worker, cid, self.get_client_device(cid), payload)
            for cid in cid_list
        ]
        for future in tqdm(
            as_completed(futures),
            total=len(futures),
            desc="Client",
            leave=False,
        ):
            self.cache.append(future.result())
    except KeyboardInterrupt:
        logging.warning("Training interrupted by user.")
        interrupted = True
        # Signal the workers first; the shutdown below waits for them.
        self.stop_event.set()
    finally:
        # After an interrupt, drop queued clients as well. Otherwise keep the
        # original semantics: wait for all submitted work to finish.
        executor.shutdown(wait=True, cancel_futures=interrupted)
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import threading
|
|
1
2
|
from blazefl.utils import move_tensor_to_shared_memory as move_tensor_to_shared_memory
|
|
2
3
|
from multiprocessing.pool import ApplyResult as ApplyResult
|
|
3
4
|
from pathlib import Path
|
|
@@ -29,6 +30,7 @@ class ThreadPoolClientTrainer(BaseClientTrainer[UplinkPackage, DownlinkPackage],
|
|
|
29
30
|
device: str
|
|
30
31
|
device_count: int
|
|
31
32
|
cache: list[UplinkPackage]
|
|
33
|
+
stop_event: threading.Event | None
|
|
32
34
|
def worker(self, cid: int, device: str, payload: DownlinkPackage) -> UplinkPackage: ...
|
|
33
35
|
def get_client_device(self, cid: int) -> str: ...
|
|
34
36
|
def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None: ...
|
|
@@ -13,6 +13,7 @@ from src.blazefl.contrib.fedavg import (
|
|
|
13
13
|
FedAvgBaseClientTrainer,
|
|
14
14
|
FedAvgBaseServerHandler,
|
|
15
15
|
FedAvgProcessPoolClientTrainer,
|
|
16
|
+
FedAvgThreadPoolClientTrainer,
|
|
16
17
|
)
|
|
17
18
|
from src.blazefl.core import ModelSelector, PartitionedDataset
|
|
18
19
|
|
|
@@ -86,7 +87,9 @@ def tmp_state_dir(tmp_path):
|
|
|
86
87
|
return state_dir
|
|
87
88
|
|
|
88
89
|
|
|
89
|
-
def
|
|
90
|
+
def test_base_server_and_base_trainer_integration(
|
|
91
|
+
model_selector, partitioned_dataset, device
|
|
92
|
+
):
|
|
90
93
|
model_name = "dummy"
|
|
91
94
|
global_round = 1
|
|
92
95
|
num_clients = 3
|
|
@@ -134,7 +137,7 @@ def test_server_and_base_integration(model_selector, partitioned_dataset, device
|
|
|
134
137
|
|
|
135
138
|
|
|
136
139
|
@pytest.mark.parametrize("ipc_mode", ["storage", "shared_memory"])
|
|
137
|
-
def
|
|
140
|
+
def test_base_handler_and_process_pool_trainer_integration(
|
|
138
141
|
model_selector, partitioned_dataset, device, tmp_share_dir, tmp_state_dir, ipc_mode
|
|
139
142
|
):
|
|
140
143
|
model_name = "dummy"
|
|
@@ -195,7 +198,7 @@ def run_local_process(trainer, downlink, cids):
|
|
|
195
198
|
trainer.local_process(downlink, cids)
|
|
196
199
|
|
|
197
200
|
|
|
198
|
-
def
|
|
201
|
+
def test_base_handler_and_process_pool_trainer_integration_keyboard_interrupt(
|
|
199
202
|
model_selector, partitioned_dataset, device, tmp_share_dir, tmp_state_dir
|
|
200
203
|
):
|
|
201
204
|
model_name = "dummy"
|
|
@@ -252,7 +255,9 @@ def test_server_and_process_pool_integration_keyboard_interrupt(
|
|
|
252
255
|
if len(spawned_pids) == num_parallels:
|
|
253
256
|
break
|
|
254
257
|
if time.time() - start_time > timeout:
|
|
255
|
-
|
|
258
|
+
pytest.fail(
|
|
259
|
+
f"Process did not spawn {len(spawned_pids)} processes within {timeout}s"
|
|
260
|
+
)
|
|
256
261
|
assert proc.is_alive()
|
|
257
262
|
|
|
258
263
|
os.kill(proc.pid, signal.SIGINT)
|
|
@@ -265,3 +270,118 @@ def test_server_and_process_pool_integration_keyboard_interrupt(
|
|
|
265
270
|
if psutil.pid_exists(pid):
|
|
266
271
|
orphan_pids.append(pid)
|
|
267
272
|
assert len(orphan_pids) == 0
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def test_base_handler_and_thread_pool_trainer_integration(
    model_selector, partitioned_dataset, device
):
    """Drive a complete FedAvg run with the thread-pool client trainer.

    Every round samples all clients (``sample_ratio=1.0``), trains them, and
    loads each uplink into the server. The final load of a round must return
    True and advance ``server.round``. After ``global_round`` rounds the
    server must report that it should stop.
    """
    model_name = "dummy"
    global_round = 2
    num_clients = 3
    batch_size = 2

    server = FedAvgBaseServerHandler(
        model_selector=model_selector,
        model_name=model_name,
        dataset=partitioned_dataset,
        global_round=global_round,
        num_clients=num_clients,
        sample_ratio=1.0,
        device=device,
        batch_size=batch_size,
    )
    trainer = FedAvgThreadPoolClientTrainer(
        model_selector=model_selector,
        model_name=model_name,
        dataset=partitioned_dataset,
        device=device,
        num_clients=num_clients,
        epochs=1,
        batch_size=batch_size,
        lr=0.01,
        seed=42,
        num_parallels=2,
    )

    for expected_round in range(1, global_round + 1):
        cids = server.sample_clients()
        downlink = server.downlink_package()
        trainer.local_process(downlink, cids)
        uplinks = trainer.uplink_package()
        assert len(uplinks) == num_clients

        # Load every package in order; only the last result is checked.
        load_results = [server.load(pkg) for pkg in uplinks]
        assert load_results[-1] is True
        assert server.round == expected_round

    assert server.if_stop() is True
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def test_base_handler_and_thread_pool_trainer_integration_keyboard_interrupt(
    model_selector,
    partitioned_dataset,
    device,
):
    """SIGINT during thread-pool training must make the child exit promptly.

    ``local_process`` runs in a child process with an effectively unbounded
    number of epochs. Once enough threads exist, the child receives SIGINT
    and must terminate within the timeout.

    The child is always reaped in ``finally``. multiprocessing waits for
    non-daemonic children at interpreter exit, so a failed assertion that
    leaves the child running would otherwise hang the whole test session.
    """
    model_name = "dummy"
    global_round = 1
    num_clients = 10
    sample_ratio = 1.0
    epochs = 10**5
    batch_size = 2
    lr = 0.01
    seed = 42
    num_parallels = 10

    server = FedAvgBaseServerHandler(
        model_selector=model_selector,
        model_name=model_name,
        dataset=partitioned_dataset,
        global_round=global_round,
        num_clients=num_clients,
        sample_ratio=sample_ratio,
        device=device,
        batch_size=batch_size,
    )

    trainer = FedAvgThreadPoolClientTrainer(
        model_selector=model_selector,
        model_name=model_name,
        dataset=partitioned_dataset,
        device=device,
        num_clients=num_clients,
        epochs=epochs,
        batch_size=batch_size,
        lr=lr,
        seed=seed,
        num_parallels=num_parallels,
    )

    cids = server.sample_clients()
    downlink = server.downlink_package()

    proc = Process(target=run_local_process, args=(trainer, downlink, cids))
    proc.start()
    try:
        assert proc.pid is not None

        p = psutil.Process(proc.pid)
        timeout = 5
        # Wait until at least num_parallels + 1 threads (main + workers) exist.
        while True:
            try:
                num_threads = p.num_threads()
            except psutil.NoSuchProcess:
                # Also covers psutil.ZombieProcess (a subclass).
                pytest.fail("Process exited before spawning its worker threads")
            if num_threads >= num_parallels + 1:
                break
            if not p.is_running() or time.time() - p.create_time() > timeout:
                pytest.fail(
                    f"Process did not spawn {num_parallels} threads within {timeout}s"
                )
            time.sleep(0.1)
        assert proc.is_alive()

        os.kill(proc.pid, signal.SIGINT)

        proc.join(timeout=5)
        assert not proc.is_alive()
    finally:
        # Never leak the child, even when an assertion above failed.
        if proc.is_alive():
            proc.kill()
            proc.join()
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
3.12
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/.python-version
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/config/config.yaml
RENAMED
|
File without changes
|
{blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/dataset/__init__.py
RENAMED
|
File without changes
|
{blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/dataset/dataset.py
RENAMED
|
File without changes
|
{blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/dataset/functional.py
RENAMED
|
File without changes
|
|
File without changes
|
{blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/models/__init__.py
RENAMED
|
File without changes
|
{blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/models/selector.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/config/algorithm/dsfl.yaml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|