blazefl 2.0.0.dev2__tar.gz → 2.0.0.dev4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/PKG-INFO +1 -1
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/main.py +1 -2
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/pyproject.toml +1 -1
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/uv.lock +4 -4
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/config/config.yaml +2 -1
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/main.py +15 -14
- blazefl-2.0.0.dev4/examples/step-by-step-dsfl/algorithm/__init__.py +3 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/algorithm/dsfl.py +53 -49
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/main.py +5 -5
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/pyproject.toml +1 -1
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/contrib/__init__.py +6 -6
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/contrib/fedavg.py +101 -59
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/core/__init__.py +8 -8
- blazefl-2.0.0.dev4/src/blazefl/core/__init__.pyi +6 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/core/client_trainer.py +80 -35
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/core/client_trainer.pyi +10 -9
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/core/server_handler.py +1 -1
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/core/server_handler.pyi +1 -1
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/utils/__init__.py +2 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/utils/__init__.pyi +2 -1
- blazefl-2.0.0.dev4/src/blazefl/utils/ipc.py +33 -0
- blazefl-2.0.0.dev4/src/blazefl/utils/ipc.pyi +3 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/test_contrib/test_fedavg.py +16 -13
- blazefl-2.0.0.dev4/tests/test_core/test_client_trainer.py +126 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/uv.lock +1 -1
- blazefl-2.0.0.dev2/examples/experimental-freethreaded/core/__init__.py +0 -5
- blazefl-2.0.0.dev2/examples/experimental-freethreaded/core/client_trainer.py +0 -45
- blazefl-2.0.0.dev2/examples/step-by-step-dsfl/algorithm/__init__.py +0 -3
- blazefl-2.0.0.dev2/src/blazefl/core/__init__.pyi +0 -6
- blazefl-2.0.0.dev2/tests/test_core/test_client_trainer.py +0 -84
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.github/FUNDING.yml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.github/dependabot.yml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.github/workflows/ci.yaml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.github/workflows/publish.yaml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.github/workflows/sphinx.yaml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.gitignore +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.pre-commit-config.yaml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.python-version +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/CODE_OF_CONDUCT.md +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/LICENSE +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/Makefile +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/README.md +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/Makefile +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/imgs/architecture.png +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/imgs/benchmark_cnn.png +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/imgs/benchmark_resnet18.png +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/imgs/logo.svg +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/imgs/logo_square.svg +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/imgs/ogp.png +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/make.bat +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/_static/favicon.ico +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/_static/logo.png +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/_static/logo_square.png +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/_templates/autosummary/class.rst +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/_templates/autosummary/module.rst +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/benchmark.rst +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/conf.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/contribute.rst +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/example.rst +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/index.rst +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/install.rst +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/overview.rst +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/reference.rst +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/.gitignore +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/.python-version +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/Makefile +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/README.md +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/config/config.yaml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/dataset/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/dataset/dataset.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/models/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/models/selector.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/.gitignore +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/.python-version +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/Makefile +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/README.md +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/dataset/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/dataset/dataset.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/models/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/models/selector.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/pyproject.toml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/.gitignore +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/.python-version +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/Makefile +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/README.md +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/config/algorithm/dsfl.yaml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/config/config.yaml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/dataset/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/dataset/dataset.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/models/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/models/cnn.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/models/selector.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/pyproject.toml +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/core/model_selector.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/core/model_selector.pyi +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/core/partitioned_dataset.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/core/partitioned_dataset.pyi +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/py.typed +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/utils/dataset.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/utils/dataset.pyi +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/utils/seed.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/utils/seed.pyi +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/utils/serialize.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/utils/serialize.pyi +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/conftest.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/test_contrib/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/test_core/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/test_core/test_model_selector.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/test_core/test_partitioned_dataset.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/test_utils/__init__.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/test_utils/test_dataset.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/test_utils/test_seed.py +0 -0
- {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/test_utils/test_serialize.py +0 -0
|
@@ -12,11 +12,10 @@ from blazefl.contrib import (
|
|
|
12
12
|
FedAvgServerHandler,
|
|
13
13
|
)
|
|
14
14
|
from blazefl.contrib.fedavg import FedAvgDownlinkPackage, FedAvgUplinkPackage
|
|
15
|
-
from blazefl.core import ModelSelector, PartitionedDataset
|
|
15
|
+
from blazefl.core import ModelSelector, MultiThreadClientTrainer, PartitionedDataset
|
|
16
16
|
from blazefl.utils import seed_everything
|
|
17
17
|
from omegaconf import DictConfig, OmegaConf
|
|
18
18
|
|
|
19
|
-
from core.client_trainer import MultiThreadClientTrainer
|
|
20
19
|
from dataset import PartitionedCIFAR10
|
|
21
20
|
from models import FedAvgModelSelector
|
|
22
21
|
|
|
@@ -10,16 +10,16 @@ sdist = { url = "https://files.pythonhosted.org/packages/3e/38/7859ff46355f76f8d
|
|
|
10
10
|
|
|
11
11
|
[[package]]
|
|
12
12
|
name = "blazefl"
|
|
13
|
-
version = "2.0.0.
|
|
13
|
+
version = "2.0.0.dev2"
|
|
14
14
|
source = { registry = "https://pypi.org/simple" }
|
|
15
15
|
dependencies = [
|
|
16
16
|
{ name = "numpy" },
|
|
17
17
|
{ name = "torch" },
|
|
18
18
|
{ name = "tqdm" },
|
|
19
19
|
]
|
|
20
|
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
|
20
|
+
sdist = { url = "https://files.pythonhosted.org/packages/9b/92/3f668d7259a81d13d5d2e6089d8022629a09058ddbcae33f5873d3871656/blazefl-2.0.0.dev2.tar.gz", hash = "sha256:28a3e09d6f6cec8d8e8bb4da90adaa0c795980c74f9d4ce65024ba762726b595", size = 613563, upload_time = "2025-06-10T09:24:08.981Z" }
|
|
21
21
|
wheels = [
|
|
22
|
-
{ url = "https://files.pythonhosted.org/packages/
|
|
22
|
+
{ url = "https://files.pythonhosted.org/packages/e3/7c/841c8b9d22caaefc1dcc167076292fd1479309b3e11eb6fc1b24141792a8/blazefl-2.0.0.dev2-py3-none-any.whl", hash = "sha256:88639018166dafa79fde59aa5575bfb7277bf9e0e4497f22a59e9c7ec34efa4b", size = 23999, upload_time = "2025-06-10T09:24:07.553Z" },
|
|
23
23
|
]
|
|
24
24
|
|
|
25
25
|
[[package]]
|
|
@@ -50,7 +50,7 @@ dev = [
|
|
|
50
50
|
|
|
51
51
|
[package.metadata]
|
|
52
52
|
requires-dist = [
|
|
53
|
-
{ name = "blazefl", specifier = ">=2.0.0.
|
|
53
|
+
{ name = "blazefl", specifier = ">=2.0.0.dev2" },
|
|
54
54
|
{ name = "hydra-core", specifier = ">=1.3.2" },
|
|
55
55
|
{ name = "torch", specifier = ">=2.7.1" },
|
|
56
56
|
{ name = "torchvision", specifier = ">=0.22.1" },
|
|
@@ -6,9 +6,9 @@ import hydra
|
|
|
6
6
|
import torch
|
|
7
7
|
import torch.multiprocessing as mp
|
|
8
8
|
from blazefl.contrib import (
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
9
|
+
FedAvgBaseClientTrainer,
|
|
10
|
+
FedAvgBaseServerHandler,
|
|
11
|
+
FedAvgProcessPoolClientTrainer,
|
|
12
12
|
)
|
|
13
13
|
from blazefl.utils import seed_everything
|
|
14
14
|
from hydra.core import hydra_config
|
|
@@ -22,8 +22,8 @@ from models import FedAvgModelSelector
|
|
|
22
22
|
class FedAvgPipeline:
|
|
23
23
|
def __init__(
|
|
24
24
|
self,
|
|
25
|
-
handler:
|
|
26
|
-
trainer:
|
|
25
|
+
handler: FedAvgBaseServerHandler,
|
|
26
|
+
trainer: FedAvgBaseClientTrainer | FedAvgProcessPoolClientTrainer,
|
|
27
27
|
writer: SummaryWriter,
|
|
28
28
|
) -> None:
|
|
29
29
|
self.handler = handler
|
|
@@ -87,7 +87,7 @@ def main(cfg: DictConfig):
|
|
|
87
87
|
)
|
|
88
88
|
model_selector = FedAvgModelSelector(num_classes=10)
|
|
89
89
|
|
|
90
|
-
handler =
|
|
90
|
+
handler = FedAvgBaseServerHandler(
|
|
91
91
|
model_selector=model_selector,
|
|
92
92
|
model_name=cfg.model_name,
|
|
93
93
|
dataset=dataset,
|
|
@@ -97,32 +97,33 @@ def main(cfg: DictConfig):
|
|
|
97
97
|
sample_ratio=cfg.sample_ratio,
|
|
98
98
|
batch_size=cfg.batch_size,
|
|
99
99
|
)
|
|
100
|
-
trainer:
|
|
101
|
-
if cfg.
|
|
102
|
-
trainer =
|
|
100
|
+
trainer: FedAvgBaseClientTrainer | FedAvgProcessPoolClientTrainer | None = None
|
|
101
|
+
if cfg.parallel:
|
|
102
|
+
trainer = FedAvgProcessPoolClientTrainer(
|
|
103
103
|
model_selector=model_selector,
|
|
104
104
|
model_name=cfg.model_name,
|
|
105
105
|
dataset=dataset,
|
|
106
|
+
share_dir=share_dir,
|
|
107
|
+
state_dir=state_dir,
|
|
108
|
+
seed=cfg.seed,
|
|
106
109
|
device=device,
|
|
107
110
|
num_clients=cfg.num_clients,
|
|
108
111
|
epochs=cfg.epochs,
|
|
109
112
|
lr=cfg.lr,
|
|
110
113
|
batch_size=cfg.batch_size,
|
|
114
|
+
num_parallels=cfg.num_parallels,
|
|
115
|
+
ipc_mode=cfg.ipc_mode,
|
|
111
116
|
)
|
|
112
117
|
else:
|
|
113
|
-
trainer =
|
|
118
|
+
trainer = FedAvgBaseClientTrainer(
|
|
114
119
|
model_selector=model_selector,
|
|
115
120
|
model_name=cfg.model_name,
|
|
116
121
|
dataset=dataset,
|
|
117
|
-
share_dir=share_dir,
|
|
118
|
-
state_dir=state_dir,
|
|
119
|
-
seed=cfg.seed,
|
|
120
122
|
device=device,
|
|
121
123
|
num_clients=cfg.num_clients,
|
|
122
124
|
epochs=cfg.epochs,
|
|
123
125
|
lr=cfg.lr,
|
|
124
126
|
batch_size=cfg.batch_size,
|
|
125
|
-
num_parallels=cfg.num_parallels,
|
|
126
127
|
)
|
|
127
128
|
pipeline = FedAvgPipeline(handler=handler, trainer=trainer, writer=writer)
|
|
128
129
|
try:
|
|
@@ -7,8 +7,8 @@ from pathlib import Path
|
|
|
7
7
|
import torch
|
|
8
8
|
import torch.nn.functional as F
|
|
9
9
|
from blazefl.core import (
|
|
10
|
-
|
|
11
|
-
|
|
10
|
+
BaseServerHandler,
|
|
11
|
+
ProcessPoolClientTrainer,
|
|
12
12
|
)
|
|
13
13
|
from blazefl.utils import (
|
|
14
14
|
FilteredDataset,
|
|
@@ -35,7 +35,7 @@ class DSFLDownlinkPackage:
|
|
|
35
35
|
next_indices: torch.Tensor
|
|
36
36
|
|
|
37
37
|
|
|
38
|
-
class
|
|
38
|
+
class DSFLBaseServerHandler(BaseServerHandler[DSFLUplinkPackage, DSFLDownlinkPackage]):
|
|
39
39
|
def __init__(
|
|
40
40
|
self,
|
|
41
41
|
model_selector: DSFLModelSelector,
|
|
@@ -116,7 +116,7 @@ class DSFLServerHandler(ServerHandler[DSFLUplinkPackage, DSFLDownlinkPackage]):
|
|
|
116
116
|
era_soft_labels = F.softmax(mean_soft_labels / self.era_temperature, dim=0)
|
|
117
117
|
global_soft_labels.append(era_soft_labels)
|
|
118
118
|
|
|
119
|
-
|
|
119
|
+
DSFLBaseServerHandler.distill(
|
|
120
120
|
self.model,
|
|
121
121
|
self.kd_optimizer,
|
|
122
122
|
self.dataset,
|
|
@@ -205,7 +205,7 @@ class DSFLServerHandler(ServerHandler[DSFLUplinkPackage, DSFLDownlinkPackage]):
|
|
|
205
205
|
return avg_loss, avg_acc
|
|
206
206
|
|
|
207
207
|
def get_summary(self) -> dict[str, float]:
|
|
208
|
-
server_loss, server_acc =
|
|
208
|
+
server_loss, server_acc = DSFLBaseServerHandler.evaulate(
|
|
209
209
|
self.model,
|
|
210
210
|
self.dataset.get_dataloader(
|
|
211
211
|
type_="test",
|
|
@@ -233,7 +233,7 @@ class DSFLServerHandler(ServerHandler[DSFLUplinkPackage, DSFLDownlinkPackage]):
|
|
|
233
233
|
|
|
234
234
|
|
|
235
235
|
@dataclass
|
|
236
|
-
class
|
|
236
|
+
class DSFLClientConfig:
|
|
237
237
|
model_selector: DSFLModelSelector
|
|
238
238
|
model_name: str
|
|
239
239
|
dataset: DSFLPartitionedDataset
|
|
@@ -245,7 +245,6 @@ class DSFLDiskSharedData:
|
|
|
245
245
|
kd_lr: float
|
|
246
246
|
cid: int
|
|
247
247
|
seed: int
|
|
248
|
-
payload: DSFLDownlinkPackage
|
|
249
248
|
state_path: Path
|
|
250
249
|
|
|
251
250
|
|
|
@@ -257,8 +256,8 @@ class DSFLClientState:
|
|
|
257
256
|
kd_optimizer: dict[str, torch.Tensor] | None
|
|
258
257
|
|
|
259
258
|
|
|
260
|
-
class
|
|
261
|
-
|
|
259
|
+
class DSFLProcessPoolClientTrainer(
|
|
260
|
+
ProcessPoolClientTrainer[DSFLUplinkPackage, DSFLDownlinkPackage, DSFLClientConfig]
|
|
262
261
|
):
|
|
263
262
|
def __init__(
|
|
264
263
|
self,
|
|
@@ -300,82 +299,90 @@ class DSFLParallelClientTrainer(
|
|
|
300
299
|
self.device = device
|
|
301
300
|
self.num_clients = num_clients
|
|
302
301
|
self.seed = seed
|
|
302
|
+
self.ipc_mode = "storage"
|
|
303
303
|
|
|
304
304
|
if self.device == "cuda":
|
|
305
305
|
self.device_count = torch.cuda.device_count()
|
|
306
306
|
|
|
307
307
|
@staticmethod
|
|
308
|
-
def
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
308
|
+
def worker(
|
|
309
|
+
config: DSFLClientConfig | Path,
|
|
310
|
+
payload: DSFLDownlinkPackage | Path,
|
|
311
|
+
device: str,
|
|
312
|
+
) -> Path:
|
|
313
|
+
assert isinstance(config, Path) and isinstance(payload, Path)
|
|
314
|
+
config_path, payload_path = config, payload
|
|
315
|
+
c = torch.load(config_path, weights_only=False)
|
|
316
|
+
p = torch.load(payload_path, weights_only=False)
|
|
317
|
+
assert isinstance(c, DSFLClientConfig) and isinstance(p, DSFLDownlinkPackage)
|
|
318
|
+
|
|
319
|
+
model = c.model_selector.select_model(c.model_name)
|
|
320
|
+
optimizer = torch.optim.SGD(model.parameters(), lr=c.lr)
|
|
314
321
|
kd_optimizer: torch.optim.SGD | None = None
|
|
315
322
|
|
|
316
323
|
state: DSFLClientState | None = None
|
|
317
|
-
if
|
|
318
|
-
state = torch.load(
|
|
324
|
+
if c.state_path.exists():
|
|
325
|
+
state = torch.load(c.state_path, weights_only=False)
|
|
319
326
|
assert isinstance(state, DSFLClientState)
|
|
320
327
|
RandomState.set_random_state(state.random)
|
|
321
328
|
model.load_state_dict(state.model)
|
|
322
329
|
optimizer.load_state_dict(state.optimizer)
|
|
323
330
|
if state.kd_optimizer is not None:
|
|
324
|
-
kd_optimizer = torch.optim.SGD(model.parameters(), lr=
|
|
331
|
+
kd_optimizer = torch.optim.SGD(model.parameters(), lr=c.kd_lr)
|
|
325
332
|
kd_optimizer.load_state_dict(state.kd_optimizer)
|
|
326
333
|
else:
|
|
327
|
-
seed_everything(
|
|
334
|
+
seed_everything(c.seed, device=device)
|
|
328
335
|
|
|
329
336
|
# Distill
|
|
330
|
-
open_dataset =
|
|
331
|
-
if
|
|
332
|
-
global_soft_labels = list(torch.unbind(
|
|
333
|
-
global_indices =
|
|
337
|
+
open_dataset = c.dataset.get_dataset(type_="open", cid=None)
|
|
338
|
+
if p.indices is not None and p.soft_labels is not None:
|
|
339
|
+
global_soft_labels = list(torch.unbind(p.soft_labels, dim=0))
|
|
340
|
+
global_indices = p.indices.tolist()
|
|
334
341
|
if kd_optimizer is None:
|
|
335
|
-
kd_optimizer = torch.optim.SGD(model.parameters(), lr=
|
|
336
|
-
|
|
342
|
+
kd_optimizer = torch.optim.SGD(model.parameters(), lr=c.kd_lr)
|
|
343
|
+
DSFLBaseServerHandler.distill(
|
|
337
344
|
model=model,
|
|
338
345
|
optimizer=kd_optimizer,
|
|
339
|
-
dataset=
|
|
346
|
+
dataset=c.dataset,
|
|
340
347
|
global_soft_labels=global_soft_labels,
|
|
341
348
|
global_indices=global_indices,
|
|
342
|
-
kd_epochs=
|
|
343
|
-
kd_batch_size=
|
|
349
|
+
kd_epochs=c.kd_epochs,
|
|
350
|
+
kd_batch_size=c.kd_batch_size,
|
|
344
351
|
device=device,
|
|
345
352
|
)
|
|
346
353
|
|
|
347
354
|
# Train
|
|
348
|
-
train_loader =
|
|
355
|
+
train_loader = c.dataset.get_dataloader(
|
|
349
356
|
type_="train",
|
|
350
|
-
cid=
|
|
351
|
-
batch_size=
|
|
357
|
+
cid=c.cid,
|
|
358
|
+
batch_size=c.batch_size,
|
|
352
359
|
)
|
|
353
|
-
|
|
360
|
+
DSFLProcessPoolClientTrainer.train(
|
|
354
361
|
model=model,
|
|
355
362
|
optimizer=optimizer,
|
|
356
363
|
train_loader=train_loader,
|
|
357
364
|
device=device,
|
|
358
|
-
epochs=
|
|
365
|
+
epochs=c.epochs,
|
|
359
366
|
)
|
|
360
367
|
|
|
361
368
|
# Predict
|
|
362
369
|
open_loader = DataLoader(
|
|
363
|
-
Subset(open_dataset,
|
|
364
|
-
batch_size=
|
|
370
|
+
Subset(open_dataset, p.next_indices.tolist()),
|
|
371
|
+
batch_size=c.batch_size,
|
|
365
372
|
)
|
|
366
|
-
soft_labels =
|
|
373
|
+
soft_labels = DSFLProcessPoolClientTrainer.predict(
|
|
367
374
|
model=model,
|
|
368
375
|
open_loader=open_loader,
|
|
369
376
|
device=device,
|
|
370
377
|
)
|
|
371
378
|
|
|
372
379
|
# Evaluate
|
|
373
|
-
test_loader =
|
|
380
|
+
test_loader = c.dataset.get_dataloader(
|
|
374
381
|
type_="test",
|
|
375
|
-
cid=
|
|
376
|
-
batch_size=
|
|
382
|
+
cid=c.cid,
|
|
383
|
+
batch_size=c.batch_size,
|
|
377
384
|
)
|
|
378
|
-
loss, acc =
|
|
385
|
+
loss, acc = DSFLBaseServerHandler.evaulate(
|
|
379
386
|
model=model,
|
|
380
387
|
test_loader=test_loader,
|
|
381
388
|
device=device,
|
|
@@ -383,19 +390,19 @@ class DSFLParallelClientTrainer(
|
|
|
383
390
|
|
|
384
391
|
package = DSFLUplinkPackage(
|
|
385
392
|
soft_labels=soft_labels,
|
|
386
|
-
indices=
|
|
393
|
+
indices=p.next_indices,
|
|
387
394
|
metadata={"loss": loss, "acc": acc},
|
|
388
395
|
)
|
|
389
396
|
|
|
390
|
-
torch.save(package,
|
|
397
|
+
torch.save(package, config_path)
|
|
391
398
|
state = DSFLClientState(
|
|
392
399
|
random=RandomState.get_random_state(device=device),
|
|
393
400
|
model=model.state_dict(),
|
|
394
401
|
optimizer=optimizer.state_dict(),
|
|
395
402
|
kd_optimizer=kd_optimizer.state_dict() if kd_optimizer else None,
|
|
396
403
|
)
|
|
397
|
-
torch.save(state,
|
|
398
|
-
return
|
|
404
|
+
torch.save(state, c.state_path)
|
|
405
|
+
return config_path
|
|
399
406
|
|
|
400
407
|
@staticmethod
|
|
401
408
|
def train(
|
|
@@ -442,10 +449,8 @@ class DSFLParallelClientTrainer(
|
|
|
442
449
|
soft_labels = torch.cat(soft_labels_list, dim=0)
|
|
443
450
|
return soft_labels.cpu()
|
|
444
451
|
|
|
445
|
-
def
|
|
446
|
-
|
|
447
|
-
) -> DSFLDiskSharedData:
|
|
448
|
-
data = DSFLDiskSharedData(
|
|
452
|
+
def get_client_config(self, cid: int) -> DSFLClientConfig:
|
|
453
|
+
data = DSFLClientConfig(
|
|
449
454
|
model_selector=self.model_selector,
|
|
450
455
|
model_name=self.model_name,
|
|
451
456
|
dataset=self.dataset,
|
|
@@ -457,7 +462,6 @@ class DSFLParallelClientTrainer(
|
|
|
457
462
|
kd_lr=self.kd_lr,
|
|
458
463
|
cid=cid,
|
|
459
464
|
seed=self.seed,
|
|
460
|
-
payload=payload,
|
|
461
465
|
state_path=self.state_dir.joinpath(f"{cid}.pt"),
|
|
462
466
|
)
|
|
463
467
|
return data
|
|
@@ -10,7 +10,7 @@ from hydra.core import hydra_config
|
|
|
10
10
|
from omegaconf import DictConfig, OmegaConf
|
|
11
11
|
from torch.utils.tensorboard.writer import SummaryWriter
|
|
12
12
|
|
|
13
|
-
from algorithm import
|
|
13
|
+
from algorithm import DSFLBaseServerHandler, DSFLProcessPoolClientTrainer
|
|
14
14
|
from dataset import DSFLPartitionedDataset
|
|
15
15
|
from models import DSFLModelSelector
|
|
16
16
|
|
|
@@ -18,8 +18,8 @@ from models import DSFLModelSelector
|
|
|
18
18
|
class DSFLPipeline:
|
|
19
19
|
def __init__(
|
|
20
20
|
self,
|
|
21
|
-
handler:
|
|
22
|
-
trainer:
|
|
21
|
+
handler: DSFLBaseServerHandler,
|
|
22
|
+
trainer: DSFLProcessPoolClientTrainer,
|
|
23
23
|
writer: SummaryWriter,
|
|
24
24
|
) -> None:
|
|
25
25
|
self.handler = handler
|
|
@@ -82,7 +82,7 @@ def main(
|
|
|
82
82
|
|
|
83
83
|
match cfg.algorithm.name:
|
|
84
84
|
case "dsfl":
|
|
85
|
-
handler =
|
|
85
|
+
handler = DSFLBaseServerHandler(
|
|
86
86
|
model_selector=model_selector,
|
|
87
87
|
model_name=cfg.model_name,
|
|
88
88
|
dataset=dataset,
|
|
@@ -96,7 +96,7 @@ def main(
|
|
|
96
96
|
device=device,
|
|
97
97
|
sample_ratio=cfg.sample_ratio,
|
|
98
98
|
)
|
|
99
|
-
trainer =
|
|
99
|
+
trainer = DSFLProcessPoolClientTrainer(
|
|
100
100
|
model_selector=model_selector,
|
|
101
101
|
model_name=cfg.model_name,
|
|
102
102
|
dataset=dataset,
|
|
@@ -6,13 +6,13 @@ extending the core functionalities of BlazeFL.
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
from blazefl.contrib.fedavg import (
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
9
|
+
FedAvgBaseClientTrainer,
|
|
10
|
+
FedAvgBaseServerHandler,
|
|
11
|
+
FedAvgProcessPoolClientTrainer,
|
|
12
12
|
)
|
|
13
13
|
|
|
14
14
|
__all__ = [
|
|
15
|
-
"
|
|
16
|
-
"
|
|
17
|
-
"
|
|
15
|
+
"FedAvgBaseServerHandler",
|
|
16
|
+
"FedAvgProcessPoolClientTrainer",
|
|
17
|
+
"FedAvgBaseClientTrainer",
|
|
18
18
|
]
|