blazefl 2.0.0.dev2__tar.gz → 2.0.0.dev3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/PKG-INFO +1 -1
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/main.py +1 -2
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/pyproject.toml +1 -1
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/uv.lock +4 -4
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/main.py +9 -9
- blazefl-2.0.0.dev3/examples/step-by-step-dsfl/algorithm/__init__.py +3 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/algorithm/dsfl.py +11 -11
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/main.py +5 -5
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/pyproject.toml +1 -1
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/contrib/__init__.py +6 -6
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/contrib/fedavg.py +15 -13
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/core/__init__.py +8 -8
- blazefl-2.0.0.dev3/src/blazefl/core/__init__.pyi +6 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/core/client_trainer.py +7 -13
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/core/client_trainer.pyi +3 -4
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/core/server_handler.py +1 -1
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/core/server_handler.pyi +1 -1
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_contrib/test_fedavg.py +12 -12
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_core/test_client_trainer.py +5 -5
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/uv.lock +1 -1
- blazefl-2.0.0.dev2/examples/experimental-freethreaded/core/__init__.py +0 -5
- blazefl-2.0.0.dev2/examples/experimental-freethreaded/core/client_trainer.py +0 -45
- blazefl-2.0.0.dev2/examples/step-by-step-dsfl/algorithm/__init__.py +0 -3
- blazefl-2.0.0.dev2/src/blazefl/core/__init__.pyi +0 -6
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.github/FUNDING.yml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.github/dependabot.yml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.github/workflows/ci.yaml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.github/workflows/publish.yaml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.github/workflows/sphinx.yaml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.gitignore +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.pre-commit-config.yaml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.python-version +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/CODE_OF_CONDUCT.md +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/LICENSE +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/Makefile +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/README.md +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/Makefile +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/imgs/architecture.png +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/imgs/benchmark_cnn.png +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/imgs/benchmark_resnet18.png +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/imgs/logo.svg +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/imgs/logo_square.svg +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/imgs/ogp.png +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/make.bat +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/_static/favicon.ico +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/_static/logo.png +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/_static/logo_square.png +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/_templates/autosummary/class.rst +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/_templates/autosummary/module.rst +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/benchmark.rst +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/conf.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/contribute.rst +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/example.rst +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/index.rst +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/install.rst +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/overview.rst +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/reference.rst +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/.gitignore +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/.python-version +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/Makefile +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/README.md +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/config/config.yaml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/dataset/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/dataset/dataset.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/models/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/models/selector.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/.gitignore +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/.python-version +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/Makefile +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/README.md +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/config/config.yaml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/dataset/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/dataset/dataset.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/models/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/models/selector.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/pyproject.toml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/.gitignore +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/.python-version +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/Makefile +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/README.md +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/config/algorithm/dsfl.yaml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/config/config.yaml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/dataset/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/dataset/dataset.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/models/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/models/cnn.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/models/selector.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/pyproject.toml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/core/model_selector.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/core/model_selector.pyi +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/core/partitioned_dataset.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/core/partitioned_dataset.pyi +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/py.typed +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/utils/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/utils/__init__.pyi +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/utils/dataset.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/utils/dataset.pyi +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/utils/seed.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/utils/seed.pyi +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/utils/serialize.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/utils/serialize.pyi +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/conftest.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_contrib/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_core/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_core/test_model_selector.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_core/test_partitioned_dataset.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_utils/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_utils/test_dataset.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_utils/test_seed.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_utils/test_serialize.py +0 -0
|
@@ -12,11 +12,10 @@ from blazefl.contrib import (
|
|
|
12
12
|
FedAvgServerHandler,
|
|
13
13
|
)
|
|
14
14
|
from blazefl.contrib.fedavg import FedAvgDownlinkPackage, FedAvgUplinkPackage
|
|
15
|
-
from blazefl.core import ModelSelector, PartitionedDataset
|
|
15
|
+
from blazefl.core import ModelSelector, MultiThreadClientTrainer, PartitionedDataset
|
|
16
16
|
from blazefl.utils import seed_everything
|
|
17
17
|
from omegaconf import DictConfig, OmegaConf
|
|
18
18
|
|
|
19
|
-
from core.client_trainer import MultiThreadClientTrainer
|
|
20
19
|
from dataset import PartitionedCIFAR10
|
|
21
20
|
from models import FedAvgModelSelector
|
|
22
21
|
|
|
@@ -10,16 +10,16 @@ sdist = { url = "https://files.pythonhosted.org/packages/3e/38/7859ff46355f76f8d
|
|
|
10
10
|
|
|
11
11
|
[[package]]
|
|
12
12
|
name = "blazefl"
|
|
13
|
-
version = "2.0.0.
|
|
13
|
+
version = "2.0.0.dev2"
|
|
14
14
|
source = { registry = "https://pypi.org/simple" }
|
|
15
15
|
dependencies = [
|
|
16
16
|
{ name = "numpy" },
|
|
17
17
|
{ name = "torch" },
|
|
18
18
|
{ name = "tqdm" },
|
|
19
19
|
]
|
|
20
|
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
|
20
|
+
sdist = { url = "https://files.pythonhosted.org/packages/9b/92/3f668d7259a81d13d5d2e6089d8022629a09058ddbcae33f5873d3871656/blazefl-2.0.0.dev2.tar.gz", hash = "sha256:28a3e09d6f6cec8d8e8bb4da90adaa0c795980c74f9d4ce65024ba762726b595", size = 613563, upload_time = "2025-06-10T09:24:08.981Z" }
|
|
21
21
|
wheels = [
|
|
22
|
-
{ url = "https://files.pythonhosted.org/packages/
|
|
22
|
+
{ url = "https://files.pythonhosted.org/packages/e3/7c/841c8b9d22caaefc1dcc167076292fd1479309b3e11eb6fc1b24141792a8/blazefl-2.0.0.dev2-py3-none-any.whl", hash = "sha256:88639018166dafa79fde59aa5575bfb7277bf9e0e4497f22a59e9c7ec34efa4b", size = 23999, upload_time = "2025-06-10T09:24:07.553Z" },
|
|
23
23
|
]
|
|
24
24
|
|
|
25
25
|
[[package]]
|
|
@@ -50,7 +50,7 @@ dev = [
|
|
|
50
50
|
|
|
51
51
|
[package.metadata]
|
|
52
52
|
requires-dist = [
|
|
53
|
-
{ name = "blazefl", specifier = ">=2.0.0.
|
|
53
|
+
{ name = "blazefl", specifier = ">=2.0.0.dev2" },
|
|
54
54
|
{ name = "hydra-core", specifier = ">=1.3.2" },
|
|
55
55
|
{ name = "torch", specifier = ">=2.7.1" },
|
|
56
56
|
{ name = "torchvision", specifier = ">=0.22.1" },
|
|
@@ -6,9 +6,9 @@ import hydra
|
|
|
6
6
|
import torch
|
|
7
7
|
import torch.multiprocessing as mp
|
|
8
8
|
from blazefl.contrib import (
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
9
|
+
FedAvgBaseClientTrainer,
|
|
10
|
+
FedAvgBaseServerHandler,
|
|
11
|
+
FedAvgProcessPoolClientTrainer,
|
|
12
12
|
)
|
|
13
13
|
from blazefl.utils import seed_everything
|
|
14
14
|
from hydra.core import hydra_config
|
|
@@ -22,8 +22,8 @@ from models import FedAvgModelSelector
|
|
|
22
22
|
class FedAvgPipeline:
|
|
23
23
|
def __init__(
|
|
24
24
|
self,
|
|
25
|
-
handler:
|
|
26
|
-
trainer:
|
|
25
|
+
handler: FedAvgBaseServerHandler,
|
|
26
|
+
trainer: FedAvgBaseClientTrainer | FedAvgProcessPoolClientTrainer,
|
|
27
27
|
writer: SummaryWriter,
|
|
28
28
|
) -> None:
|
|
29
29
|
self.handler = handler
|
|
@@ -87,7 +87,7 @@ def main(cfg: DictConfig):
|
|
|
87
87
|
)
|
|
88
88
|
model_selector = FedAvgModelSelector(num_classes=10)
|
|
89
89
|
|
|
90
|
-
handler =
|
|
90
|
+
handler = FedAvgBaseServerHandler(
|
|
91
91
|
model_selector=model_selector,
|
|
92
92
|
model_name=cfg.model_name,
|
|
93
93
|
dataset=dataset,
|
|
@@ -97,9 +97,9 @@ def main(cfg: DictConfig):
|
|
|
97
97
|
sample_ratio=cfg.sample_ratio,
|
|
98
98
|
batch_size=cfg.batch_size,
|
|
99
99
|
)
|
|
100
|
-
trainer:
|
|
100
|
+
trainer: FedAvgBaseClientTrainer | FedAvgProcessPoolClientTrainer | None = None
|
|
101
101
|
if cfg.serial:
|
|
102
|
-
trainer =
|
|
102
|
+
trainer = FedAvgBaseClientTrainer(
|
|
103
103
|
model_selector=model_selector,
|
|
104
104
|
model_name=cfg.model_name,
|
|
105
105
|
dataset=dataset,
|
|
@@ -110,7 +110,7 @@ def main(cfg: DictConfig):
|
|
|
110
110
|
batch_size=cfg.batch_size,
|
|
111
111
|
)
|
|
112
112
|
else:
|
|
113
|
-
trainer =
|
|
113
|
+
trainer = FedAvgProcessPoolClientTrainer(
|
|
114
114
|
model_selector=model_selector,
|
|
115
115
|
model_name=cfg.model_name,
|
|
116
116
|
dataset=dataset,
|
|
@@ -7,8 +7,8 @@ from pathlib import Path
|
|
|
7
7
|
import torch
|
|
8
8
|
import torch.nn.functional as F
|
|
9
9
|
from blazefl.core import (
|
|
10
|
-
|
|
11
|
-
|
|
10
|
+
BaseServerHandler,
|
|
11
|
+
ProcessPoolClientTrainer,
|
|
12
12
|
)
|
|
13
13
|
from blazefl.utils import (
|
|
14
14
|
FilteredDataset,
|
|
@@ -35,7 +35,7 @@ class DSFLDownlinkPackage:
|
|
|
35
35
|
next_indices: torch.Tensor
|
|
36
36
|
|
|
37
37
|
|
|
38
|
-
class
|
|
38
|
+
class DSFLBaseServerHandler(BaseServerHandler[DSFLUplinkPackage, DSFLDownlinkPackage]):
|
|
39
39
|
def __init__(
|
|
40
40
|
self,
|
|
41
41
|
model_selector: DSFLModelSelector,
|
|
@@ -116,7 +116,7 @@ class DSFLServerHandler(ServerHandler[DSFLUplinkPackage, DSFLDownlinkPackage]):
|
|
|
116
116
|
era_soft_labels = F.softmax(mean_soft_labels / self.era_temperature, dim=0)
|
|
117
117
|
global_soft_labels.append(era_soft_labels)
|
|
118
118
|
|
|
119
|
-
|
|
119
|
+
DSFLBaseServerHandler.distill(
|
|
120
120
|
self.model,
|
|
121
121
|
self.kd_optimizer,
|
|
122
122
|
self.dataset,
|
|
@@ -205,7 +205,7 @@ class DSFLServerHandler(ServerHandler[DSFLUplinkPackage, DSFLDownlinkPackage]):
|
|
|
205
205
|
return avg_loss, avg_acc
|
|
206
206
|
|
|
207
207
|
def get_summary(self) -> dict[str, float]:
|
|
208
|
-
server_loss, server_acc =
|
|
208
|
+
server_loss, server_acc = DSFLBaseServerHandler.evaulate(
|
|
209
209
|
self.model,
|
|
210
210
|
self.dataset.get_dataloader(
|
|
211
211
|
type_="test",
|
|
@@ -257,8 +257,8 @@ class DSFLClientState:
|
|
|
257
257
|
kd_optimizer: dict[str, torch.Tensor] | None
|
|
258
258
|
|
|
259
259
|
|
|
260
|
-
class
|
|
261
|
-
|
|
260
|
+
class DSFLProcessPoolClientTrainer(
|
|
261
|
+
ProcessPoolClientTrainer[DSFLUplinkPackage, DSFLDownlinkPackage, DSFLDiskSharedData]
|
|
262
262
|
):
|
|
263
263
|
def __init__(
|
|
264
264
|
self,
|
|
@@ -333,7 +333,7 @@ class DSFLParallelClientTrainer(
|
|
|
333
333
|
global_indices = data.payload.indices.tolist()
|
|
334
334
|
if kd_optimizer is None:
|
|
335
335
|
kd_optimizer = torch.optim.SGD(model.parameters(), lr=data.kd_lr)
|
|
336
|
-
|
|
336
|
+
DSFLBaseServerHandler.distill(
|
|
337
337
|
model=model,
|
|
338
338
|
optimizer=kd_optimizer,
|
|
339
339
|
dataset=data.dataset,
|
|
@@ -350,7 +350,7 @@ class DSFLParallelClientTrainer(
|
|
|
350
350
|
cid=data.cid,
|
|
351
351
|
batch_size=data.batch_size,
|
|
352
352
|
)
|
|
353
|
-
|
|
353
|
+
DSFLProcessPoolClientTrainer.train(
|
|
354
354
|
model=model,
|
|
355
355
|
optimizer=optimizer,
|
|
356
356
|
train_loader=train_loader,
|
|
@@ -363,7 +363,7 @@ class DSFLParallelClientTrainer(
|
|
|
363
363
|
Subset(open_dataset, data.payload.next_indices.tolist()),
|
|
364
364
|
batch_size=data.batch_size,
|
|
365
365
|
)
|
|
366
|
-
soft_labels =
|
|
366
|
+
soft_labels = DSFLProcessPoolClientTrainer.predict(
|
|
367
367
|
model=model,
|
|
368
368
|
open_loader=open_loader,
|
|
369
369
|
device=device,
|
|
@@ -375,7 +375,7 @@ class DSFLParallelClientTrainer(
|
|
|
375
375
|
cid=data.cid,
|
|
376
376
|
batch_size=data.batch_size,
|
|
377
377
|
)
|
|
378
|
-
loss, acc =
|
|
378
|
+
loss, acc = DSFLBaseServerHandler.evaulate(
|
|
379
379
|
model=model,
|
|
380
380
|
test_loader=test_loader,
|
|
381
381
|
device=device,
|
|
@@ -10,7 +10,7 @@ from hydra.core import hydra_config
|
|
|
10
10
|
from omegaconf import DictConfig, OmegaConf
|
|
11
11
|
from torch.utils.tensorboard.writer import SummaryWriter
|
|
12
12
|
|
|
13
|
-
from algorithm import
|
|
13
|
+
from algorithm import DSFLBaseServerHandler, DSFLProcessPoolClientTrainer
|
|
14
14
|
from dataset import DSFLPartitionedDataset
|
|
15
15
|
from models import DSFLModelSelector
|
|
16
16
|
|
|
@@ -18,8 +18,8 @@ from models import DSFLModelSelector
|
|
|
18
18
|
class DSFLPipeline:
|
|
19
19
|
def __init__(
|
|
20
20
|
self,
|
|
21
|
-
handler:
|
|
22
|
-
trainer:
|
|
21
|
+
handler: DSFLBaseServerHandler,
|
|
22
|
+
trainer: DSFLProcessPoolClientTrainer,
|
|
23
23
|
writer: SummaryWriter,
|
|
24
24
|
) -> None:
|
|
25
25
|
self.handler = handler
|
|
@@ -82,7 +82,7 @@ def main(
|
|
|
82
82
|
|
|
83
83
|
match cfg.algorithm.name:
|
|
84
84
|
case "dsfl":
|
|
85
|
-
handler =
|
|
85
|
+
handler = DSFLBaseServerHandler(
|
|
86
86
|
model_selector=model_selector,
|
|
87
87
|
model_name=cfg.model_name,
|
|
88
88
|
dataset=dataset,
|
|
@@ -96,7 +96,7 @@ def main(
|
|
|
96
96
|
device=device,
|
|
97
97
|
sample_ratio=cfg.sample_ratio,
|
|
98
98
|
)
|
|
99
|
-
trainer =
|
|
99
|
+
trainer = DSFLProcessPoolClientTrainer(
|
|
100
100
|
model_selector=model_selector,
|
|
101
101
|
model_name=cfg.model_name,
|
|
102
102
|
dataset=dataset,
|
|
@@ -6,13 +6,13 @@ extending the core functionalities of BlazeFL.
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
from blazefl.contrib.fedavg import (
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
9
|
+
FedAvgBaseClientTrainer,
|
|
10
|
+
FedAvgBaseServerHandler,
|
|
11
|
+
FedAvgProcessPoolClientTrainer,
|
|
12
12
|
)
|
|
13
13
|
|
|
14
14
|
__all__ = [
|
|
15
|
-
"
|
|
16
|
-
"
|
|
17
|
-
"
|
|
15
|
+
"FedAvgBaseServerHandler",
|
|
16
|
+
"FedAvgProcessPoolClientTrainer",
|
|
17
|
+
"FedAvgBaseClientTrainer",
|
|
18
18
|
]
|
|
@@ -8,11 +8,11 @@ from torch.utils.data import DataLoader
|
|
|
8
8
|
from tqdm import tqdm
|
|
9
9
|
|
|
10
10
|
from blazefl.core import (
|
|
11
|
+
BaseClientTrainer,
|
|
12
|
+
BaseServerHandler,
|
|
11
13
|
ModelSelector,
|
|
12
|
-
ParallelClientTrainer,
|
|
13
14
|
PartitionedDataset,
|
|
14
|
-
|
|
15
|
-
ServerHandler,
|
|
15
|
+
ProcessPoolClientTrainer,
|
|
16
16
|
)
|
|
17
17
|
from blazefl.utils import (
|
|
18
18
|
RandomState,
|
|
@@ -53,7 +53,9 @@ class FedAvgDownlinkPackage:
|
|
|
53
53
|
model_parameters: torch.Tensor
|
|
54
54
|
|
|
55
55
|
|
|
56
|
-
class
|
|
56
|
+
class FedAvgBaseServerHandler(
|
|
57
|
+
BaseServerHandler[FedAvgUplinkPackage, FedAvgDownlinkPackage]
|
|
58
|
+
):
|
|
57
59
|
"""
|
|
58
60
|
Server-side handler for the Federated Averaging (FedAvg) algorithm.
|
|
59
61
|
|
|
@@ -85,7 +87,7 @@ class FedAvgServerHandler(ServerHandler[FedAvgUplinkPackage, FedAvgDownlinkPacka
|
|
|
85
87
|
batch_size: int,
|
|
86
88
|
) -> None:
|
|
87
89
|
"""
|
|
88
|
-
Initialize the
|
|
90
|
+
Initialize the FedAvgBaseServerHandler.
|
|
89
91
|
|
|
90
92
|
Args:
|
|
91
93
|
model_selector (ModelSelector): Selector for initializing the model.
|
|
@@ -232,7 +234,7 @@ class FedAvgServerHandler(ServerHandler[FedAvgUplinkPackage, FedAvgDownlinkPacka
|
|
|
232
234
|
return avg_loss, avg_acc
|
|
233
235
|
|
|
234
236
|
def get_summary(self) -> dict[str, float]:
|
|
235
|
-
server_loss, server_acc =
|
|
237
|
+
server_loss, server_acc = FedAvgBaseServerHandler.evaluate(
|
|
236
238
|
self.model,
|
|
237
239
|
self.dataset.get_dataloader(
|
|
238
240
|
type_="test",
|
|
@@ -258,11 +260,11 @@ class FedAvgServerHandler(ServerHandler[FedAvgUplinkPackage, FedAvgDownlinkPacka
|
|
|
258
260
|
return FedAvgDownlinkPackage(model_parameters)
|
|
259
261
|
|
|
260
262
|
|
|
261
|
-
class
|
|
262
|
-
|
|
263
|
+
class FedAvgBaseClientTrainer(
|
|
264
|
+
BaseClientTrainer[FedAvgUplinkPackage, FedAvgDownlinkPackage]
|
|
263
265
|
):
|
|
264
266
|
"""
|
|
265
|
-
|
|
267
|
+
Base client trainer for the Federated Averaging (FedAvg) algorithm.
|
|
266
268
|
|
|
267
269
|
This trainer processes clients sequentially, training and evaluating a local model
|
|
268
270
|
for each client based on the server-provided model parameters.
|
|
@@ -291,7 +293,7 @@ class FedAvgSerialClientTrainer(
|
|
|
291
293
|
lr: float,
|
|
292
294
|
) -> None:
|
|
293
295
|
"""
|
|
294
|
-
Initialize the
|
|
296
|
+
Initialize the FedAvgBaseClientTrainer.
|
|
295
297
|
|
|
296
298
|
Args:
|
|
297
299
|
model_selector (ModelSelector): Selector for initializing the local model.
|
|
@@ -462,8 +464,8 @@ class FedAvgDiskSharedData:
|
|
|
462
464
|
state_path: Path
|
|
463
465
|
|
|
464
466
|
|
|
465
|
-
class
|
|
466
|
-
|
|
467
|
+
class FedAvgProcessPoolClientTrainer(
|
|
468
|
+
ProcessPoolClientTrainer[
|
|
467
469
|
FedAvgUplinkPackage, FedAvgDownlinkPackage, FedAvgDiskSharedData
|
|
468
470
|
]
|
|
469
471
|
):
|
|
@@ -573,7 +575,7 @@ class FedAvgParallelClientTrainer(
|
|
|
573
575
|
cid=data.cid,
|
|
574
576
|
batch_size=data.batch_size,
|
|
575
577
|
)
|
|
576
|
-
package =
|
|
578
|
+
package = FedAvgProcessPoolClientTrainer.train(
|
|
577
579
|
model=model,
|
|
578
580
|
model_parameters=data.payload.model_parameters,
|
|
579
581
|
train_loader=train_loader,
|
|
@@ -6,19 +6,19 @@ including client trainers, model selectors, partitioned datasets, and server han
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
from blazefl.core.client_trainer import (
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
9
|
+
BaseClientTrainer,
|
|
10
|
+
ProcessPoolClientTrainer,
|
|
11
|
+
ThreadPoolClientTrainer,
|
|
12
12
|
)
|
|
13
13
|
from blazefl.core.model_selector import ModelSelector
|
|
14
14
|
from blazefl.core.partitioned_dataset import PartitionedDataset
|
|
15
|
-
from blazefl.core.server_handler import
|
|
15
|
+
from blazefl.core.server_handler import BaseServerHandler
|
|
16
16
|
|
|
17
17
|
__all__ = [
|
|
18
|
-
"
|
|
19
|
-
"
|
|
20
|
-
"
|
|
18
|
+
"BaseClientTrainer",
|
|
19
|
+
"ProcessPoolClientTrainer",
|
|
20
|
+
"ThreadPoolClientTrainer",
|
|
21
21
|
"ModelSelector",
|
|
22
22
|
"PartitionedDataset",
|
|
23
|
-
"
|
|
23
|
+
"BaseServerHandler",
|
|
24
24
|
]
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
from blazefl.core.client_trainer import BaseClientTrainer as BaseClientTrainer, ProcessPoolClientTrainer as ProcessPoolClientTrainer, ThreadPoolClientTrainer as ThreadPoolClientTrainer
|
|
2
|
+
from blazefl.core.model_selector import ModelSelector as ModelSelector
|
|
3
|
+
from blazefl.core.partitioned_dataset import PartitionedDataset as PartitionedDataset
|
|
4
|
+
from blazefl.core.server_handler import BaseServerHandler as BaseServerHandler
|
|
5
|
+
|
|
6
|
+
__all__ = ['BaseClientTrainer', 'ProcessPoolClientTrainer', 'ThreadPoolClientTrainer', 'ModelSelector', 'PartitionedDataset', 'BaseServerHandler']
|
|
@@ -12,7 +12,7 @@ UplinkPackage = TypeVar("UplinkPackage")
|
|
|
12
12
|
DownlinkPackage = TypeVar("DownlinkPackage", contravariant=True)
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
class
|
|
15
|
+
class BaseClientTrainer(Protocol[UplinkPackage, DownlinkPackage]):
|
|
16
16
|
"""
|
|
17
17
|
Abstract base class for serial client training in federated learning.
|
|
18
18
|
|
|
@@ -50,7 +50,8 @@ class SerialClientTrainer(Protocol[UplinkPackage, DownlinkPackage]):
|
|
|
50
50
|
DiskSharedData = TypeVar("DiskSharedData", covariant=True)
|
|
51
51
|
|
|
52
52
|
|
|
53
|
-
class
|
|
53
|
+
class ProcessPoolClientTrainer(
|
|
54
|
+
BaseClientTrainer[UplinkPackage, DownlinkPackage],
|
|
54
55
|
Protocol[UplinkPackage, DownlinkPackage, DiskSharedData],
|
|
55
56
|
):
|
|
56
57
|
"""
|
|
@@ -74,16 +75,6 @@ class ParallelClientTrainer(
|
|
|
74
75
|
device_count: int
|
|
75
76
|
cache: list[UplinkPackage]
|
|
76
77
|
|
|
77
|
-
def uplink_package(self) -> list[UplinkPackage]:
|
|
78
|
-
"""
|
|
79
|
-
Prepare the data package to be sent from the client to the server.
|
|
80
|
-
|
|
81
|
-
Returns:
|
|
82
|
-
list[UplinkPackage]: A list of data packages prepared for uplink
|
|
83
|
-
transmission.
|
|
84
|
-
"""
|
|
85
|
-
...
|
|
86
|
-
|
|
87
78
|
def get_shared_data(self, cid: int, payload: DownlinkPackage) -> DiskSharedData:
|
|
88
79
|
"""
|
|
89
80
|
Retrieve shared data for a given client ID and payload.
|
|
@@ -159,7 +150,10 @@ class ParallelClientTrainer(
|
|
|
159
150
|
self.cache.append(package)
|
|
160
151
|
|
|
161
152
|
|
|
162
|
-
class
|
|
153
|
+
class ThreadPoolClientTrainer(
|
|
154
|
+
BaseClientTrainer[UplinkPackage, DownlinkPackage],
|
|
155
|
+
Protocol[UplinkPackage, DownlinkPackage],
|
|
156
|
+
):
|
|
163
157
|
num_parallels: int
|
|
164
158
|
device: str
|
|
165
159
|
device_count: int
|
|
@@ -5,25 +5,24 @@ from typing import Protocol, TypeVar
|
|
|
5
5
|
UplinkPackage = TypeVar('UplinkPackage')
|
|
6
6
|
DownlinkPackage = TypeVar('DownlinkPackage', contravariant=True)
|
|
7
7
|
|
|
8
|
-
class
|
|
8
|
+
class BaseClientTrainer(Protocol[UplinkPackage, DownlinkPackage]):
|
|
9
9
|
def uplink_package(self) -> list[UplinkPackage]: ...
|
|
10
10
|
def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None: ...
|
|
11
11
|
DiskSharedData = TypeVar('DiskSharedData', covariant=True)
|
|
12
12
|
|
|
13
|
-
class
|
|
13
|
+
class ProcessPoolClientTrainer(BaseClientTrainer[UplinkPackage, DownlinkPackage], Protocol[UplinkPackage, DownlinkPackage, DiskSharedData]):
|
|
14
14
|
num_parallels: int
|
|
15
15
|
share_dir: Path
|
|
16
16
|
device: str
|
|
17
17
|
device_count: int
|
|
18
18
|
cache: list[UplinkPackage]
|
|
19
|
-
def uplink_package(self) -> list[UplinkPackage]: ...
|
|
20
19
|
def get_shared_data(self, cid: int, payload: DownlinkPackage) -> DiskSharedData: ...
|
|
21
20
|
def get_client_device(self, cid: int) -> str: ...
|
|
22
21
|
@staticmethod
|
|
23
22
|
def process_client(path: Path, device: str) -> Path: ...
|
|
24
23
|
def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None: ...
|
|
25
24
|
|
|
26
|
-
class
|
|
25
|
+
class ThreadPoolClientTrainer(BaseClientTrainer[UplinkPackage, DownlinkPackage], Protocol[UplinkPackage, DownlinkPackage]):
|
|
27
26
|
num_parallels: int
|
|
28
27
|
device: str
|
|
29
28
|
device_count: int
|
|
@@ -4,7 +4,7 @@ UplinkPackage = TypeVar("UplinkPackage")
|
|
|
4
4
|
DownlinkPackage = TypeVar("DownlinkPackage", covariant=True)
|
|
5
5
|
|
|
6
6
|
|
|
7
|
-
class
|
|
7
|
+
class BaseServerHandler(Protocol[UplinkPackage, DownlinkPackage]):
|
|
8
8
|
"""
|
|
9
9
|
Abstract base class for server-side operations in federated learning.
|
|
10
10
|
|
|
@@ -3,7 +3,7 @@ from typing import Protocol, TypeVar
|
|
|
3
3
|
UplinkPackage = TypeVar('UplinkPackage')
|
|
4
4
|
DownlinkPackage = TypeVar('DownlinkPackage', covariant=True)
|
|
5
5
|
|
|
6
|
-
class
|
|
6
|
+
class BaseServerHandler(Protocol[UplinkPackage, DownlinkPackage]):
|
|
7
7
|
def downlink_package(self) -> DownlinkPackage: ...
|
|
8
8
|
def sample_clients(self) -> list[int]: ...
|
|
9
9
|
def if_stop(self) -> bool: ...
|
|
@@ -10,9 +10,9 @@ import torch
|
|
|
10
10
|
from torch.utils.data import DataLoader, Dataset
|
|
11
11
|
|
|
12
12
|
from src.blazefl.contrib.fedavg import (
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
13
|
+
FedAvgBaseClientTrainer,
|
|
14
|
+
FedAvgBaseServerHandler,
|
|
15
|
+
FedAvgProcessPoolClientTrainer,
|
|
16
16
|
)
|
|
17
17
|
from src.blazefl.core import ModelSelector, PartitionedDataset
|
|
18
18
|
|
|
@@ -86,7 +86,7 @@ def tmp_state_dir(tmp_path):
|
|
|
86
86
|
return state_dir
|
|
87
87
|
|
|
88
88
|
|
|
89
|
-
def
|
|
89
|
+
def test_server_and_base_integration(model_selector, partitioned_dataset, device):
|
|
90
90
|
model_name = "dummy"
|
|
91
91
|
global_round = 1
|
|
92
92
|
num_clients = 3
|
|
@@ -95,7 +95,7 @@ def test_server_and_serial_integration(model_selector, partitioned_dataset, devi
|
|
|
95
95
|
batch_size = 2
|
|
96
96
|
lr = 0.01
|
|
97
97
|
|
|
98
|
-
server =
|
|
98
|
+
server = FedAvgBaseServerHandler(
|
|
99
99
|
model_selector=model_selector,
|
|
100
100
|
model_name=model_name,
|
|
101
101
|
dataset=partitioned_dataset,
|
|
@@ -106,7 +106,7 @@ def test_server_and_serial_integration(model_selector, partitioned_dataset, devi
|
|
|
106
106
|
batch_size=batch_size,
|
|
107
107
|
)
|
|
108
108
|
|
|
109
|
-
trainer =
|
|
109
|
+
trainer = FedAvgBaseClientTrainer(
|
|
110
110
|
model_selector=model_selector,
|
|
111
111
|
model_name=model_name,
|
|
112
112
|
dataset=partitioned_dataset,
|
|
@@ -133,7 +133,7 @@ def test_server_and_serial_integration(model_selector, partitioned_dataset, devi
|
|
|
133
133
|
assert server.if_stop() is True
|
|
134
134
|
|
|
135
135
|
|
|
136
|
-
def
|
|
136
|
+
def test_server_and_process_pool_integration(
|
|
137
137
|
model_selector, partitioned_dataset, device, tmp_share_dir, tmp_state_dir
|
|
138
138
|
):
|
|
139
139
|
model_name = "dummy"
|
|
@@ -146,7 +146,7 @@ def test_server_and_parallel_integration(
|
|
|
146
146
|
seed = 42
|
|
147
147
|
num_parallels = 2
|
|
148
148
|
|
|
149
|
-
server =
|
|
149
|
+
server = FedAvgBaseServerHandler(
|
|
150
150
|
model_selector=model_selector,
|
|
151
151
|
model_name=model_name,
|
|
152
152
|
dataset=partitioned_dataset,
|
|
@@ -157,7 +157,7 @@ def test_server_and_parallel_integration(
|
|
|
157
157
|
batch_size=batch_size,
|
|
158
158
|
)
|
|
159
159
|
|
|
160
|
-
trainer =
|
|
160
|
+
trainer = FedAvgProcessPoolClientTrainer(
|
|
161
161
|
model_selector=model_selector,
|
|
162
162
|
model_name=model_name,
|
|
163
163
|
share_dir=tmp_share_dir,
|
|
@@ -193,7 +193,7 @@ def run_local_process(trainer, downlink, cids):
|
|
|
193
193
|
trainer.local_process(downlink, cids)
|
|
194
194
|
|
|
195
195
|
|
|
196
|
-
def
|
|
196
|
+
def test_server_and_process_pool_integration_keyboard_interrupt(
|
|
197
197
|
model_selector, partitioned_dataset, device, tmp_share_dir, tmp_state_dir
|
|
198
198
|
):
|
|
199
199
|
model_name = "dummy"
|
|
@@ -206,7 +206,7 @@ def test_server_and_parallel_integration_keyboard_interrupt(
|
|
|
206
206
|
seed = 42
|
|
207
207
|
num_parallels = 10
|
|
208
208
|
|
|
209
|
-
server =
|
|
209
|
+
server = FedAvgBaseServerHandler(
|
|
210
210
|
model_selector=model_selector,
|
|
211
211
|
model_name=model_name,
|
|
212
212
|
dataset=partitioned_dataset,
|
|
@@ -217,7 +217,7 @@ def test_server_and_parallel_integration_keyboard_interrupt(
|
|
|
217
217
|
batch_size=batch_size,
|
|
218
218
|
)
|
|
219
219
|
|
|
220
|
-
trainer =
|
|
220
|
+
trainer = FedAvgProcessPoolClientTrainer(
|
|
221
221
|
model_selector=model_selector,
|
|
222
222
|
model_name=model_name,
|
|
223
223
|
share_dir=tmp_share_dir,
|
|
@@ -4,7 +4,7 @@ from pathlib import Path
|
|
|
4
4
|
import pytest
|
|
5
5
|
import torch
|
|
6
6
|
|
|
7
|
-
from src.blazefl.core import
|
|
7
|
+
from src.blazefl.core import ProcessPoolClientTrainer
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
@dataclass
|
|
@@ -24,8 +24,8 @@ class DiskSharedData:
|
|
|
24
24
|
payload: DownlinkPackage
|
|
25
25
|
|
|
26
26
|
|
|
27
|
-
class
|
|
28
|
-
|
|
27
|
+
class DummyProcessPoolClientTrainer(
|
|
28
|
+
ProcessPoolClientTrainer[UplinkPackage, DownlinkPackage, DiskSharedData]
|
|
29
29
|
):
|
|
30
30
|
def __init__(self, num_parallels: int, share_dir: Path, device: str):
|
|
31
31
|
self.num_parallels = num_parallels
|
|
@@ -58,10 +58,10 @@ class DummyParallelClientTrainer(
|
|
|
58
58
|
|
|
59
59
|
@pytest.mark.parametrize("num_parallels", [1, 2, 4])
|
|
60
60
|
@pytest.mark.parametrize("cid_list", [[], [42], [0, 1, 2]])
|
|
61
|
-
def
|
|
61
|
+
def test_process_pool_client_trainer(
|
|
62
62
|
tmp_path: Path, num_parallels: int, cid_list: list[int]
|
|
63
63
|
) -> None:
|
|
64
|
-
trainer =
|
|
64
|
+
trainer = DummyProcessPoolClientTrainer(
|
|
65
65
|
num_parallels=num_parallels, share_dir=tmp_path, device="cpu"
|
|
66
66
|
)
|
|
67
67
|
|
|
@@ -1,45 +0,0 @@
|
|
|
1
|
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
2
|
-
from typing import Protocol, TypeVar
|
|
3
|
-
|
|
4
|
-
from tqdm import tqdm
|
|
5
|
-
|
|
6
|
-
UplinkPackage = TypeVar("UplinkPackage")
|
|
7
|
-
DownlinkPackage = TypeVar("DownlinkPackage", contravariant=True)
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class MultiThreadClientTrainer(Protocol[UplinkPackage, DownlinkPackage]):
|
|
11
|
-
num_parallels: int
|
|
12
|
-
device: str
|
|
13
|
-
device_count: int
|
|
14
|
-
cache: list[UplinkPackage]
|
|
15
|
-
|
|
16
|
-
def process_client(
|
|
17
|
-
self,
|
|
18
|
-
cid: int,
|
|
19
|
-
device: str,
|
|
20
|
-
payload: DownlinkPackage,
|
|
21
|
-
) -> UplinkPackage: ...
|
|
22
|
-
|
|
23
|
-
def get_client_device(self, cid: int) -> str:
|
|
24
|
-
if self.device == "cuda":
|
|
25
|
-
return f"cuda:{cid % self.device_count}"
|
|
26
|
-
return self.device
|
|
27
|
-
|
|
28
|
-
def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None:
|
|
29
|
-
with ThreadPoolExecutor(max_workers=self.num_parallels) as executor:
|
|
30
|
-
futures = []
|
|
31
|
-
for cid in cid_list:
|
|
32
|
-
device = self.get_client_device(cid)
|
|
33
|
-
future = executor.submit(
|
|
34
|
-
self.process_client,
|
|
35
|
-
cid,
|
|
36
|
-
device,
|
|
37
|
-
payload,
|
|
38
|
-
)
|
|
39
|
-
futures.append(future)
|
|
40
|
-
|
|
41
|
-
for future in tqdm(
|
|
42
|
-
as_completed(futures), total=len(futures), desc="Client", leave=False
|
|
43
|
-
):
|
|
44
|
-
result = future.result()
|
|
45
|
-
self.cache.append(result)
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
from blazefl.core.client_trainer import MultiThreadClientTrainer as MultiThreadClientTrainer, ParallelClientTrainer as ParallelClientTrainer, SerialClientTrainer as SerialClientTrainer
|
|
2
|
-
from blazefl.core.model_selector import ModelSelector as ModelSelector
|
|
3
|
-
from blazefl.core.partitioned_dataset import PartitionedDataset as PartitionedDataset
|
|
4
|
-
from blazefl.core.server_handler import ServerHandler as ServerHandler
|
|
5
|
-
|
|
6
|
-
__all__ = ['SerialClientTrainer', 'ParallelClientTrainer', 'MultiThreadClientTrainer', 'ModelSelector', 'PartitionedDataset', 'ServerHandler']
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/.python-version
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/config/config.yaml
RENAMED
|
File without changes
|
{blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/dataset/__init__.py
RENAMED
|
File without changes
|
{blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/dataset/dataset.py
RENAMED
|
File without changes
|
{blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/dataset/functional.py
RENAMED
|
File without changes
|
{blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/models/__init__.py
RENAMED
|
File without changes
|
{blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/models/selector.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/config/algorithm/dsfl.yaml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|