blazefl 2.0.0.dev4__tar.gz → 2.0.0.dev6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blazefl-2.0.0.dev6/.python-version +1 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/PKG-INFO +1 -1
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/config/config.yaml +1 -1
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/main.py +54 -32
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/main.py +1 -3
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/pyproject.toml +1 -1
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/contrib/__init__.py +6 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/contrib/fedavg.py +132 -2
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/client_trainer.py +34 -8
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/client_trainer.pyi +5 -2
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_contrib/test_fedavg.py +161 -33
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_core/test_client_trainer.py +14 -3
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/uv.lock +1 -1
- blazefl-2.0.0.dev4/examples/step-by-step-dsfl/.python-version +0 -1
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/.github/FUNDING.yml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/.github/dependabot.yml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/.github/workflows/ci.yaml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/.github/workflows/publish.yaml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/.github/workflows/sphinx.yaml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/.gitignore +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/.pre-commit-config.yaml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/CODE_OF_CONDUCT.md +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/LICENSE +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/Makefile +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/README.md +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/Makefile +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/imgs/architecture.png +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/imgs/benchmark_cnn.png +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/imgs/benchmark_resnet18.png +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/imgs/logo.svg +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/imgs/logo_square.svg +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/imgs/ogp.png +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/make.bat +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/_static/favicon.ico +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/_static/logo.png +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/_static/logo_square.png +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/_templates/autosummary/class.rst +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/_templates/autosummary/module.rst +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/benchmark.rst +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/conf.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/contribute.rst +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/example.rst +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/index.rst +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/install.rst +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/overview.rst +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/reference.rst +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/.gitignore +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/.python-version +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/Makefile +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/README.md +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/config/config.yaml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/dataset/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/dataset/dataset.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/main.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/models/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/models/selector.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/pyproject.toml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/uv.lock +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/.gitignore +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6/examples/quickstart-fedavg}/.python-version +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/Makefile +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/README.md +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/dataset/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/dataset/dataset.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/models/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/models/selector.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/pyproject.toml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/.gitignore +0 -0
- {blazefl-2.0.0.dev4/examples/quickstart-fedavg → blazefl-2.0.0.dev6/examples/step-by-step-dsfl}/.python-version +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/Makefile +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/README.md +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/algorithm/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/algorithm/dsfl.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/config/algorithm/dsfl.yaml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/config/config.yaml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/dataset/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/dataset/dataset.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/models/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/models/cnn.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/models/selector.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/pyproject.toml +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/__init__.pyi +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/model_selector.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/model_selector.pyi +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/partitioned_dataset.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/partitioned_dataset.pyi +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/server_handler.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/server_handler.pyi +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/py.typed +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/__init__.pyi +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/dataset.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/dataset.pyi +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/ipc.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/ipc.pyi +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/seed.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/seed.pyi +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/serialize.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/serialize.pyi +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/conftest.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_contrib/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_core/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_core/test_model_selector.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_core/test_partitioned_dataset.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_utils/__init__.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_utils/test_dataset.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_utils/test_seed.py +0 -0
- {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_utils/test_serialize.py +0 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
3.13
|
|
@@ -9,6 +9,7 @@ from blazefl.contrib import (
|
|
|
9
9
|
FedAvgBaseClientTrainer,
|
|
10
10
|
FedAvgBaseServerHandler,
|
|
11
11
|
FedAvgProcessPoolClientTrainer,
|
|
12
|
+
FedAvgThreadPoolClientTrainer,
|
|
12
13
|
)
|
|
13
14
|
from blazefl.utils import seed_everything
|
|
14
15
|
from hydra.core import hydra_config
|
|
@@ -23,7 +24,9 @@ class FedAvgPipeline:
|
|
|
23
24
|
def __init__(
|
|
24
25
|
self,
|
|
25
26
|
handler: FedAvgBaseServerHandler,
|
|
26
|
-
trainer: FedAvgBaseClientTrainer
|
|
27
|
+
trainer: FedAvgBaseClientTrainer
|
|
28
|
+
| FedAvgProcessPoolClientTrainer
|
|
29
|
+
| FedAvgThreadPoolClientTrainer,
|
|
27
30
|
writer: SummaryWriter,
|
|
28
31
|
) -> None:
|
|
29
32
|
self.handler = handler
|
|
@@ -97,41 +100,60 @@ def main(cfg: DictConfig):
|
|
|
97
100
|
sample_ratio=cfg.sample_ratio,
|
|
98
101
|
batch_size=cfg.batch_size,
|
|
99
102
|
)
|
|
100
|
-
trainer:
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
103
|
+
trainer: (
|
|
104
|
+
FedAvgBaseClientTrainer
|
|
105
|
+
| FedAvgProcessPoolClientTrainer
|
|
106
|
+
| FedAvgThreadPoolClientTrainer
|
|
107
|
+
| None
|
|
108
|
+
) = None
|
|
109
|
+
match cfg.execution_mode:
|
|
110
|
+
case "multi-process":
|
|
111
|
+
trainer = FedAvgProcessPoolClientTrainer(
|
|
112
|
+
model_selector=model_selector,
|
|
113
|
+
model_name=cfg.model_name,
|
|
114
|
+
dataset=dataset,
|
|
115
|
+
share_dir=share_dir,
|
|
116
|
+
state_dir=state_dir,
|
|
117
|
+
seed=cfg.seed,
|
|
118
|
+
device=device,
|
|
119
|
+
num_clients=cfg.num_clients,
|
|
120
|
+
epochs=cfg.epochs,
|
|
121
|
+
lr=cfg.lr,
|
|
122
|
+
batch_size=cfg.batch_size,
|
|
123
|
+
num_parallels=cfg.num_parallels,
|
|
124
|
+
ipc_mode=cfg.ipc_mode,
|
|
125
|
+
)
|
|
126
|
+
case "single-thread":
|
|
127
|
+
trainer = FedAvgBaseClientTrainer(
|
|
128
|
+
model_selector=model_selector,
|
|
129
|
+
model_name=cfg.model_name,
|
|
130
|
+
dataset=dataset,
|
|
131
|
+
device=device,
|
|
132
|
+
num_clients=cfg.num_clients,
|
|
133
|
+
epochs=cfg.epochs,
|
|
134
|
+
lr=cfg.lr,
|
|
135
|
+
batch_size=cfg.batch_size,
|
|
136
|
+
)
|
|
137
|
+
case "multi-thread":
|
|
138
|
+
trainer = FedAvgThreadPoolClientTrainer(
|
|
139
|
+
model_selector=model_selector,
|
|
140
|
+
model_name=cfg.model_name,
|
|
141
|
+
dataset=dataset,
|
|
142
|
+
seed=cfg.seed,
|
|
143
|
+
device=device,
|
|
144
|
+
num_clients=cfg.num_clients,
|
|
145
|
+
epochs=cfg.epochs,
|
|
146
|
+
lr=cfg.lr,
|
|
147
|
+
batch_size=cfg.batch_size,
|
|
148
|
+
num_parallels=cfg.num_parallels,
|
|
149
|
+
)
|
|
150
|
+
case _:
|
|
151
|
+
raise ValueError(f"Invalid execution mode: {cfg.execution_mode}")
|
|
128
152
|
pipeline = FedAvgPipeline(handler=handler, trainer=trainer, writer=writer)
|
|
129
153
|
try:
|
|
130
154
|
pipeline.main()
|
|
131
155
|
except KeyboardInterrupt:
|
|
132
|
-
logging.info("KeyboardInterrupt
|
|
133
|
-
except Exception as e:
|
|
134
|
-
logging.exception(f"An error occurred: {e}")
|
|
156
|
+
logging.info("KeyboardInterrupt")
|
|
135
157
|
|
|
136
158
|
|
|
137
159
|
if __name__ == "__main__":
|
|
@@ -120,9 +120,7 @@ def main(
|
|
|
120
120
|
try:
|
|
121
121
|
pipeline.main()
|
|
122
122
|
except KeyboardInterrupt:
|
|
123
|
-
logging.info("KeyboardInterrupt
|
|
124
|
-
except Exception as e:
|
|
125
|
-
logging.exception(f"An error occurred: {e}")
|
|
123
|
+
logging.info("KeyboardInterrupt")
|
|
126
124
|
|
|
127
125
|
|
|
128
126
|
if __name__ == "__main__":
|
|
@@ -8,11 +8,17 @@ extending the core functionalities of BlazeFL.
|
|
|
8
8
|
from blazefl.contrib.fedavg import (
|
|
9
9
|
FedAvgBaseClientTrainer,
|
|
10
10
|
FedAvgBaseServerHandler,
|
|
11
|
+
FedAvgDownlinkPackage,
|
|
11
12
|
FedAvgProcessPoolClientTrainer,
|
|
13
|
+
FedAvgThreadPoolClientTrainer,
|
|
14
|
+
FedAvgUplinkPackage,
|
|
12
15
|
)
|
|
13
16
|
|
|
14
17
|
__all__ = [
|
|
15
18
|
"FedAvgBaseServerHandler",
|
|
16
19
|
"FedAvgProcessPoolClientTrainer",
|
|
17
20
|
"FedAvgBaseClientTrainer",
|
|
21
|
+
"FedAvgThreadPoolClientTrainer",
|
|
22
|
+
"FedAvgUplinkPackage",
|
|
23
|
+
"FedAvgDownlinkPackage",
|
|
18
24
|
]
|
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
import random
|
|
2
|
+
import threading
|
|
2
3
|
from copy import deepcopy
|
|
3
4
|
from dataclasses import dataclass
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
from typing import Literal
|
|
6
7
|
|
|
7
8
|
import torch
|
|
9
|
+
import torch.multiprocessing as mp
|
|
8
10
|
from torch.utils.data import DataLoader
|
|
9
11
|
from tqdm import tqdm
|
|
10
12
|
|
|
@@ -14,6 +16,7 @@ from blazefl.core import (
|
|
|
14
16
|
ModelSelector,
|
|
15
17
|
PartitionedDataset,
|
|
16
18
|
ProcessPoolClientTrainer,
|
|
19
|
+
ThreadPoolClientTrainer,
|
|
17
20
|
)
|
|
18
21
|
from blazefl.utils import (
|
|
19
22
|
RandomState,
|
|
@@ -545,12 +548,15 @@ class FedAvgProcessPoolClientTrainer(
|
|
|
545
548
|
self.num_clients = num_clients
|
|
546
549
|
self.seed = seed
|
|
547
550
|
self.ipc_mode = ipc_mode
|
|
551
|
+
self.manager = mp.Manager()
|
|
552
|
+
self.stop_event = self.manager.Event()
|
|
548
553
|
|
|
549
554
|
@staticmethod
|
|
550
555
|
def worker(
|
|
551
556
|
config: FedAvgClientConfig | Path,
|
|
552
557
|
payload: FedAvgDownlinkPackage | Path,
|
|
553
558
|
device: str,
|
|
559
|
+
stop_event: threading.Event,
|
|
554
560
|
) -> FedAvgUplinkPackage | Path:
|
|
555
561
|
"""
|
|
556
562
|
Process a single client's local training and evaluation.
|
|
@@ -578,6 +584,7 @@ class FedAvgProcessPoolClientTrainer(
|
|
|
578
584
|
config_path: Path,
|
|
579
585
|
payload_path: Path,
|
|
580
586
|
device: str,
|
|
587
|
+
stop_event: threading.Event,
|
|
581
588
|
) -> Path:
|
|
582
589
|
config = torch.load(config_path, weights_only=False)
|
|
583
590
|
assert isinstance(config, FedAvgClientConfig)
|
|
@@ -587,6 +594,7 @@ class FedAvgProcessPoolClientTrainer(
|
|
|
587
594
|
config=config,
|
|
588
595
|
payload=payload,
|
|
589
596
|
device=device,
|
|
597
|
+
stop_event=stop_event,
|
|
590
598
|
)
|
|
591
599
|
torch.save(package, config_path)
|
|
592
600
|
return config_path
|
|
@@ -595,6 +603,7 @@ class FedAvgProcessPoolClientTrainer(
|
|
|
595
603
|
config: FedAvgClientConfig,
|
|
596
604
|
payload: FedAvgDownlinkPackage,
|
|
597
605
|
device: str,
|
|
606
|
+
stop_event: threading.Event,
|
|
598
607
|
) -> FedAvgUplinkPackage:
|
|
599
608
|
if config.state_path.exists():
|
|
600
609
|
state = torch.load(config.state_path, weights_only=False)
|
|
@@ -616,16 +625,17 @@ class FedAvgProcessPoolClientTrainer(
|
|
|
616
625
|
device=device,
|
|
617
626
|
epochs=config.epochs,
|
|
618
627
|
lr=config.lr,
|
|
628
|
+
stop_event=stop_event,
|
|
619
629
|
)
|
|
620
630
|
torch.save(RandomState.get_random_state(device=device), config.state_path)
|
|
621
631
|
return package
|
|
622
632
|
|
|
623
633
|
if isinstance(config, Path) and isinstance(payload, Path):
|
|
624
|
-
return _storage_worker(config, payload, device)
|
|
634
|
+
return _storage_worker(config, payload, device, stop_event)
|
|
625
635
|
elif isinstance(config, FedAvgClientConfig) and isinstance(
|
|
626
636
|
payload, FedAvgDownlinkPackage
|
|
627
637
|
):
|
|
628
|
-
return _shared_memory_worker(config, payload, device)
|
|
638
|
+
return _shared_memory_worker(config, payload, device, stop_event)
|
|
629
639
|
else:
|
|
630
640
|
raise TypeError(
|
|
631
641
|
"Invalid types for config and payload."
|
|
@@ -640,6 +650,7 @@ class FedAvgProcessPoolClientTrainer(
|
|
|
640
650
|
device: str,
|
|
641
651
|
epochs: int,
|
|
642
652
|
lr: float,
|
|
653
|
+
stop_event: threading.Event,
|
|
643
654
|
) -> FedAvgUplinkPackage:
|
|
644
655
|
"""
|
|
645
656
|
Train the model with the given training data loader.
|
|
@@ -664,6 +675,8 @@ class FedAvgProcessPoolClientTrainer(
|
|
|
664
675
|
|
|
665
676
|
data_size = 0
|
|
666
677
|
for _ in range(epochs):
|
|
678
|
+
if stop_event.is_set():
|
|
679
|
+
break
|
|
667
680
|
for data, target in train_loader:
|
|
668
681
|
data = data.to(device)
|
|
669
682
|
target = target.to(device)
|
|
@@ -714,3 +727,120 @@ class FedAvgProcessPoolClientTrainer(
|
|
|
714
727
|
package = deepcopy(self.cache)
|
|
715
728
|
self.cache = []
|
|
716
729
|
return package
|
|
730
|
+
|
|
731
|
+
|
|
732
|
+
class FedAvgThreadPoolClientTrainer(
|
|
733
|
+
ThreadPoolClientTrainer[
|
|
734
|
+
FedAvgUplinkPackage,
|
|
735
|
+
FedAvgDownlinkPackage,
|
|
736
|
+
]
|
|
737
|
+
):
|
|
738
|
+
def __init__(
|
|
739
|
+
self,
|
|
740
|
+
model_selector: ModelSelector,
|
|
741
|
+
model_name: str,
|
|
742
|
+
dataset: PartitionedDataset,
|
|
743
|
+
device: str,
|
|
744
|
+
num_clients: int,
|
|
745
|
+
epochs: int,
|
|
746
|
+
batch_size: int,
|
|
747
|
+
lr: float,
|
|
748
|
+
seed: int,
|
|
749
|
+
num_parallels: int,
|
|
750
|
+
) -> None:
|
|
751
|
+
self.num_parallels = num_parallels
|
|
752
|
+
self.device = device
|
|
753
|
+
if self.device == "cuda":
|
|
754
|
+
self.device_count = torch.cuda.device_count()
|
|
755
|
+
self.cache: list[FedAvgUplinkPackage] = []
|
|
756
|
+
|
|
757
|
+
self.model_selector = model_selector
|
|
758
|
+
self.model_name = model_name
|
|
759
|
+
self.dataset = dataset
|
|
760
|
+
self.epochs = epochs
|
|
761
|
+
self.batch_size = batch_size
|
|
762
|
+
self.lr = lr
|
|
763
|
+
self.num_clients = num_clients
|
|
764
|
+
self.seed = seed
|
|
765
|
+
self.stop_event = threading.Event()
|
|
766
|
+
|
|
767
|
+
def worker(
|
|
768
|
+
self,
|
|
769
|
+
cid: int,
|
|
770
|
+
device: str,
|
|
771
|
+
payload: FedAvgDownlinkPackage,
|
|
772
|
+
stop_event: threading.Event,
|
|
773
|
+
) -> FedAvgUplinkPackage:
|
|
774
|
+
model = self.model_selector.select_model(self.model_name)
|
|
775
|
+
train_loader = self.dataset.get_dataloader(
|
|
776
|
+
type_="train",
|
|
777
|
+
cid=cid,
|
|
778
|
+
batch_size=self.batch_size,
|
|
779
|
+
)
|
|
780
|
+
package = self.train(
|
|
781
|
+
model=model,
|
|
782
|
+
model_parameters=payload.model_parameters,
|
|
783
|
+
train_loader=train_loader,
|
|
784
|
+
device=device,
|
|
785
|
+
epochs=self.epochs,
|
|
786
|
+
lr=self.lr,
|
|
787
|
+
stop_event=stop_event,
|
|
788
|
+
)
|
|
789
|
+
return package
|
|
790
|
+
|
|
791
|
+
def train(
|
|
792
|
+
self,
|
|
793
|
+
model: torch.nn.Module,
|
|
794
|
+
model_parameters: torch.Tensor,
|
|
795
|
+
train_loader: DataLoader,
|
|
796
|
+
device: str,
|
|
797
|
+
epochs: int,
|
|
798
|
+
lr: float,
|
|
799
|
+
stop_event: threading.Event,
|
|
800
|
+
) -> FedAvgUplinkPackage:
|
|
801
|
+
"""
|
|
802
|
+
Train the model with the given training data loader.
|
|
803
|
+
|
|
804
|
+
Args:
|
|
805
|
+
model (torch.nn.Module): The model to train.
|
|
806
|
+
model_parameters (torch.Tensor): Initial global model parameters.
|
|
807
|
+
train_loader (DataLoader): DataLoader for the training data.
|
|
808
|
+
device (str): Device to run the training on.
|
|
809
|
+
epochs (int): Number of local training epochs.
|
|
810
|
+
lr (float): Learning rate for the optimizer.
|
|
811
|
+
|
|
812
|
+
Returns:
|
|
813
|
+
FedAvgUplinkPackage: Uplink package containing updated model parameters
|
|
814
|
+
and data size.
|
|
815
|
+
"""
|
|
816
|
+
model.to(device)
|
|
817
|
+
deserialize_model(model, model_parameters)
|
|
818
|
+
model.train()
|
|
819
|
+
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
|
|
820
|
+
criterion = torch.nn.CrossEntropyLoss()
|
|
821
|
+
|
|
822
|
+
data_size = 0
|
|
823
|
+
for _ in range(epochs):
|
|
824
|
+
if stop_event.is_set():
|
|
825
|
+
break
|
|
826
|
+
for data, target in train_loader:
|
|
827
|
+
data = data.to(device)
|
|
828
|
+
target = target.to(device)
|
|
829
|
+
|
|
830
|
+
output = model(data)
|
|
831
|
+
loss = criterion(output, target)
|
|
832
|
+
|
|
833
|
+
data_size += len(target)
|
|
834
|
+
|
|
835
|
+
optimizer.zero_grad()
|
|
836
|
+
loss.backward()
|
|
837
|
+
optimizer.step()
|
|
838
|
+
|
|
839
|
+
model_parameters = serialize_model(model)
|
|
840
|
+
|
|
841
|
+
return FedAvgUplinkPackage(model_parameters, data_size)
|
|
842
|
+
|
|
843
|
+
def uplink_package(self) -> list[FedAvgUplinkPackage]:
|
|
844
|
+
package = deepcopy(self.cache)
|
|
845
|
+
self.cache = []
|
|
846
|
+
return package
|
|
@@ -1,11 +1,12 @@
|
|
|
1
|
-
import multiprocessing as mp
|
|
2
1
|
import signal
|
|
2
|
+
import threading
|
|
3
3
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
4
4
|
from multiprocessing.pool import ApplyResult
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
from typing import Literal, Protocol, TypeVar
|
|
7
7
|
|
|
8
8
|
import torch
|
|
9
|
+
import torch.multiprocessing as mp
|
|
9
10
|
from tqdm import tqdm
|
|
10
11
|
|
|
11
12
|
from blazefl.utils import move_tensor_to_shared_memory
|
|
@@ -82,6 +83,7 @@ class ProcessPoolClientTrainer(
|
|
|
82
83
|
device_count: int
|
|
83
84
|
cache: list[UplinkPackage]
|
|
84
85
|
ipc_mode: Literal["storage", "shared_memory"] = "storage"
|
|
86
|
+
stop_event: threading.Event
|
|
85
87
|
|
|
86
88
|
def get_client_config(self, cid: int) -> ClientConfig:
|
|
87
89
|
"""
|
|
@@ -111,7 +113,10 @@ class ProcessPoolClientTrainer(
|
|
|
111
113
|
|
|
112
114
|
@staticmethod
|
|
113
115
|
def worker(
|
|
114
|
-
config: ClientConfig | Path,
|
|
116
|
+
config: ClientConfig | Path,
|
|
117
|
+
payload: DownlinkPackage | Path,
|
|
118
|
+
device: str,
|
|
119
|
+
stop_event: threading.Event,
|
|
115
120
|
) -> UplinkPackage | Path:
|
|
116
121
|
"""
|
|
117
122
|
Process a single client's training task.
|
|
@@ -157,11 +162,13 @@ class ProcessPoolClientTrainer(
|
|
|
157
162
|
else: # shared_memory
|
|
158
163
|
move_tensor_to_shared_memory(payload)
|
|
159
164
|
|
|
160
|
-
|
|
165
|
+
self.stop_event.clear()
|
|
166
|
+
pool = mp.Pool(
|
|
161
167
|
processes=self.num_parallels,
|
|
162
168
|
initializer=signal.signal,
|
|
163
169
|
initargs=(signal.SIGINT, signal.SIG_IGN),
|
|
164
|
-
)
|
|
170
|
+
)
|
|
171
|
+
try:
|
|
165
172
|
jobs: list[ApplyResult] = []
|
|
166
173
|
for cid in cid_list:
|
|
167
174
|
config = self.get_client_config(cid)
|
|
@@ -171,12 +178,15 @@ class ProcessPoolClientTrainer(
|
|
|
171
178
|
torch.save(config, config_path)
|
|
172
179
|
jobs.append(
|
|
173
180
|
pool.apply_async(
|
|
174
|
-
self.worker,
|
|
181
|
+
self.worker,
|
|
182
|
+
(config_path, payload_path, device, self.stop_event),
|
|
175
183
|
)
|
|
176
184
|
)
|
|
177
185
|
else: # shared_memory
|
|
178
186
|
jobs.append(
|
|
179
|
-
pool.apply_async(
|
|
187
|
+
pool.apply_async(
|
|
188
|
+
self.worker, (config, payload, device, self.stop_event)
|
|
189
|
+
)
|
|
180
190
|
)
|
|
181
191
|
|
|
182
192
|
for job in tqdm(jobs, desc="Client", leave=False):
|
|
@@ -187,6 +197,10 @@ class ProcessPoolClientTrainer(
|
|
|
187
197
|
else: # shared_memory
|
|
188
198
|
package = result
|
|
189
199
|
self.cache.append(package)
|
|
200
|
+
finally:
|
|
201
|
+
self.stop_event.set()
|
|
202
|
+
pool.close()
|
|
203
|
+
pool.join()
|
|
190
204
|
|
|
191
205
|
|
|
192
206
|
class ThreadPoolClientTrainer(
|
|
@@ -197,12 +211,14 @@ class ThreadPoolClientTrainer(
|
|
|
197
211
|
device: str
|
|
198
212
|
device_count: int
|
|
199
213
|
cache: list[UplinkPackage]
|
|
214
|
+
stop_event: threading.Event
|
|
200
215
|
|
|
201
216
|
def worker(
|
|
202
217
|
self,
|
|
203
218
|
cid: int,
|
|
204
219
|
device: str,
|
|
205
220
|
payload: DownlinkPackage,
|
|
221
|
+
stop_event: threading.Event,
|
|
206
222
|
) -> UplinkPackage:
|
|
207
223
|
"""
|
|
208
224
|
Process a single client's training task in a thread.
|
|
@@ -211,6 +227,7 @@ class ThreadPoolClientTrainer(
|
|
|
211
227
|
cid (int): The client ID.
|
|
212
228
|
device (str): The device to use for processing this client.
|
|
213
229
|
payload (DownlinkPackage): The data package received from the server.
|
|
230
|
+
stop_event (threading.Event): Event to signal stopping the worker.
|
|
214
231
|
|
|
215
232
|
Returns:
|
|
216
233
|
UplinkPackage: The uplink package containing the client's results.
|
|
@@ -223,7 +240,9 @@ class ThreadPoolClientTrainer(
|
|
|
223
240
|
return self.device
|
|
224
241
|
|
|
225
242
|
def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None:
|
|
226
|
-
|
|
243
|
+
self.stop_event.clear()
|
|
244
|
+
executor = ThreadPoolExecutor(max_workers=self.num_parallels)
|
|
245
|
+
try:
|
|
227
246
|
futures = []
|
|
228
247
|
for cid in cid_list:
|
|
229
248
|
device = self.get_client_device(cid)
|
|
@@ -232,11 +251,18 @@ class ThreadPoolClientTrainer(
|
|
|
232
251
|
cid,
|
|
233
252
|
device,
|
|
234
253
|
payload,
|
|
254
|
+
self.stop_event,
|
|
235
255
|
)
|
|
236
256
|
futures.append(future)
|
|
237
257
|
|
|
238
258
|
for future in tqdm(
|
|
239
|
-
as_completed(futures),
|
|
259
|
+
as_completed(futures),
|
|
260
|
+
total=len(futures),
|
|
261
|
+
desc="Client",
|
|
262
|
+
leave=False,
|
|
240
263
|
):
|
|
241
264
|
result = future.result()
|
|
242
265
|
self.cache.append(result)
|
|
266
|
+
finally:
|
|
267
|
+
self.stop_event.set()
|
|
268
|
+
executor.shutdown(wait=True, cancel_futures=True)
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import threading
|
|
1
2
|
from blazefl.utils import move_tensor_to_shared_memory as move_tensor_to_shared_memory
|
|
2
3
|
from multiprocessing.pool import ApplyResult as ApplyResult
|
|
3
4
|
from pathlib import Path
|
|
@@ -18,10 +19,11 @@ class ProcessPoolClientTrainer(BaseClientTrainer[UplinkPackage, DownlinkPackage]
|
|
|
18
19
|
device_count: int
|
|
19
20
|
cache: list[UplinkPackage]
|
|
20
21
|
ipc_mode: Literal['storage', 'shared_memory']
|
|
22
|
+
stop_event: threading.Event
|
|
21
23
|
def get_client_config(self, cid: int) -> ClientConfig: ...
|
|
22
24
|
def get_client_device(self, cid: int) -> str: ...
|
|
23
25
|
@staticmethod
|
|
24
|
-
def worker(config: ClientConfig | Path, payload: DownlinkPackage | Path, device: str) -> UplinkPackage | Path: ...
|
|
26
|
+
def worker(config: ClientConfig | Path, payload: DownlinkPackage | Path, device: str, stop_event: threading.Event) -> UplinkPackage | Path: ...
|
|
25
27
|
def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None: ...
|
|
26
28
|
|
|
27
29
|
class ThreadPoolClientTrainer(BaseClientTrainer[UplinkPackage, DownlinkPackage], Protocol[UplinkPackage, DownlinkPackage]):
|
|
@@ -29,6 +31,7 @@ class ThreadPoolClientTrainer(BaseClientTrainer[UplinkPackage, DownlinkPackage],
|
|
|
29
31
|
device: str
|
|
30
32
|
device_count: int
|
|
31
33
|
cache: list[UplinkPackage]
|
|
32
|
-
|
|
34
|
+
stop_event: threading.Event
|
|
35
|
+
def worker(self, cid: int, device: str, payload: DownlinkPackage, stop_event: threading.Event) -> UplinkPackage: ...
|
|
33
36
|
def get_client_device(self, cid: int) -> str: ...
|
|
34
37
|
def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None: ...
|