blazefl 2.0.0.dev3__tar.gz → 2.0.0.dev4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/PKG-INFO +1 -1
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/config/config.yaml +2 -1
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/main.py +8 -7
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/algorithm/dsfl.py +43 -39
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/pyproject.toml +1 -1
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/contrib/fedavg.py +87 -47
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/client_trainer.py +73 -22
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/client_trainer.pyi +8 -6
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/utils/__init__.py +2 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/utils/__init__.pyi +2 -1
- blazefl-2.0.0.dev4/src/blazefl/utils/ipc.py +33 -0
- blazefl-2.0.0.dev4/src/blazefl/utils/ipc.pyi +3 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/test_contrib/test_fedavg.py +4 -1
- blazefl-2.0.0.dev4/tests/test_core/test_client_trainer.py +126 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/uv.lock +1 -1
- blazefl-2.0.0.dev3/tests/test_core/test_client_trainer.py +0 -84
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.github/FUNDING.yml +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.github/dependabot.yml +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.github/workflows/ci.yaml +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.github/workflows/publish.yaml +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.github/workflows/sphinx.yaml +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.gitignore +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.pre-commit-config.yaml +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.python-version +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/CODE_OF_CONDUCT.md +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/LICENSE +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/Makefile +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/README.md +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/Makefile +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/imgs/architecture.png +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/imgs/benchmark_cnn.png +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/imgs/benchmark_resnet18.png +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/imgs/logo.svg +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/imgs/logo_square.svg +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/imgs/ogp.png +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/make.bat +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/_static/favicon.ico +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/_static/logo.png +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/_static/logo_square.png +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/_templates/autosummary/class.rst +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/_templates/autosummary/module.rst +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/benchmark.rst +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/conf.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/contribute.rst +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/example.rst +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/index.rst +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/install.rst +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/overview.rst +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/reference.rst +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/.gitignore +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/.python-version +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/Makefile +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/README.md +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/config/config.yaml +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/dataset/__init__.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/dataset/dataset.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/main.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/models/__init__.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/models/selector.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/pyproject.toml +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/uv.lock +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/.gitignore +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/.python-version +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/Makefile +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/README.md +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/dataset/__init__.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/dataset/dataset.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/models/__init__.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/models/selector.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/pyproject.toml +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/.gitignore +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/.python-version +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/Makefile +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/README.md +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/algorithm/__init__.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/config/algorithm/dsfl.yaml +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/config/config.yaml +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/dataset/__init__.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/dataset/dataset.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/main.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/models/__init__.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/models/cnn.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/models/selector.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/pyproject.toml +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/__init__.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/contrib/__init__.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/__init__.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/__init__.pyi +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/model_selector.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/model_selector.pyi +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/partitioned_dataset.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/partitioned_dataset.pyi +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/server_handler.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/server_handler.pyi +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/py.typed +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/utils/dataset.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/utils/dataset.pyi +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/utils/seed.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/utils/seed.pyi +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/utils/serialize.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/utils/serialize.pyi +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/__init__.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/conftest.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/test_contrib/__init__.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/test_core/__init__.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/test_core/test_model_selector.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/test_core/test_partitioned_dataset.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/test_utils/__init__.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/test_utils/test_dataset.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/test_utils/test_seed.py +0 -0
- {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/test_utils/test_serialize.py +0 -0
|
@@ -98,31 +98,32 @@ def main(cfg: DictConfig):
|
|
|
98
98
|
batch_size=cfg.batch_size,
|
|
99
99
|
)
|
|
100
100
|
trainer: FedAvgBaseClientTrainer | FedAvgProcessPoolClientTrainer | None = None
|
|
101
|
-
if cfg.
|
|
102
|
-
trainer =
|
|
101
|
+
if cfg.parallel:
|
|
102
|
+
trainer = FedAvgProcessPoolClientTrainer(
|
|
103
103
|
model_selector=model_selector,
|
|
104
104
|
model_name=cfg.model_name,
|
|
105
105
|
dataset=dataset,
|
|
106
|
+
share_dir=share_dir,
|
|
107
|
+
state_dir=state_dir,
|
|
108
|
+
seed=cfg.seed,
|
|
106
109
|
device=device,
|
|
107
110
|
num_clients=cfg.num_clients,
|
|
108
111
|
epochs=cfg.epochs,
|
|
109
112
|
lr=cfg.lr,
|
|
110
113
|
batch_size=cfg.batch_size,
|
|
114
|
+
num_parallels=cfg.num_parallels,
|
|
115
|
+
ipc_mode=cfg.ipc_mode,
|
|
111
116
|
)
|
|
112
117
|
else:
|
|
113
|
-
trainer =
|
|
118
|
+
trainer = FedAvgBaseClientTrainer(
|
|
114
119
|
model_selector=model_selector,
|
|
115
120
|
model_name=cfg.model_name,
|
|
116
121
|
dataset=dataset,
|
|
117
|
-
share_dir=share_dir,
|
|
118
|
-
state_dir=state_dir,
|
|
119
|
-
seed=cfg.seed,
|
|
120
122
|
device=device,
|
|
121
123
|
num_clients=cfg.num_clients,
|
|
122
124
|
epochs=cfg.epochs,
|
|
123
125
|
lr=cfg.lr,
|
|
124
126
|
batch_size=cfg.batch_size,
|
|
125
|
-
num_parallels=cfg.num_parallels,
|
|
126
127
|
)
|
|
127
128
|
pipeline = FedAvgPipeline(handler=handler, trainer=trainer, writer=writer)
|
|
128
129
|
try:
|
|
@@ -233,7 +233,7 @@ class DSFLBaseServerHandler(BaseServerHandler[DSFLUplinkPackage, DSFLDownlinkPac
|
|
|
233
233
|
|
|
234
234
|
|
|
235
235
|
@dataclass
|
|
236
|
-
class
|
|
236
|
+
class DSFLClientConfig:
|
|
237
237
|
model_selector: DSFLModelSelector
|
|
238
238
|
model_name: str
|
|
239
239
|
dataset: DSFLPartitionedDataset
|
|
@@ -245,7 +245,6 @@ class DSFLDiskSharedData:
|
|
|
245
245
|
kd_lr: float
|
|
246
246
|
cid: int
|
|
247
247
|
seed: int
|
|
248
|
-
payload: DSFLDownlinkPackage
|
|
249
248
|
state_path: Path
|
|
250
249
|
|
|
251
250
|
|
|
@@ -258,7 +257,7 @@ class DSFLClientState:
|
|
|
258
257
|
|
|
259
258
|
|
|
260
259
|
class DSFLProcessPoolClientTrainer(
|
|
261
|
-
ProcessPoolClientTrainer[DSFLUplinkPackage, DSFLDownlinkPackage,
|
|
260
|
+
ProcessPoolClientTrainer[DSFLUplinkPackage, DSFLDownlinkPackage, DSFLClientConfig]
|
|
262
261
|
):
|
|
263
262
|
def __init__(
|
|
264
263
|
self,
|
|
@@ -300,68 +299,76 @@ class DSFLProcessPoolClientTrainer(
|
|
|
300
299
|
self.device = device
|
|
301
300
|
self.num_clients = num_clients
|
|
302
301
|
self.seed = seed
|
|
302
|
+
self.ipc_mode = "storage"
|
|
303
303
|
|
|
304
304
|
if self.device == "cuda":
|
|
305
305
|
self.device_count = torch.cuda.device_count()
|
|
306
306
|
|
|
307
307
|
@staticmethod
|
|
308
|
-
def
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
308
|
+
def worker(
|
|
309
|
+
config: DSFLClientConfig | Path,
|
|
310
|
+
payload: DSFLDownlinkPackage | Path,
|
|
311
|
+
device: str,
|
|
312
|
+
) -> Path:
|
|
313
|
+
assert isinstance(config, Path) and isinstance(payload, Path)
|
|
314
|
+
config_path, payload_path = config, payload
|
|
315
|
+
c = torch.load(config_path, weights_only=False)
|
|
316
|
+
p = torch.load(payload_path, weights_only=False)
|
|
317
|
+
assert isinstance(c, DSFLClientConfig) and isinstance(p, DSFLDownlinkPackage)
|
|
318
|
+
|
|
319
|
+
model = c.model_selector.select_model(c.model_name)
|
|
320
|
+
optimizer = torch.optim.SGD(model.parameters(), lr=c.lr)
|
|
314
321
|
kd_optimizer: torch.optim.SGD | None = None
|
|
315
322
|
|
|
316
323
|
state: DSFLClientState | None = None
|
|
317
|
-
if
|
|
318
|
-
state = torch.load(
|
|
324
|
+
if c.state_path.exists():
|
|
325
|
+
state = torch.load(c.state_path, weights_only=False)
|
|
319
326
|
assert isinstance(state, DSFLClientState)
|
|
320
327
|
RandomState.set_random_state(state.random)
|
|
321
328
|
model.load_state_dict(state.model)
|
|
322
329
|
optimizer.load_state_dict(state.optimizer)
|
|
323
330
|
if state.kd_optimizer is not None:
|
|
324
|
-
kd_optimizer = torch.optim.SGD(model.parameters(), lr=
|
|
331
|
+
kd_optimizer = torch.optim.SGD(model.parameters(), lr=c.kd_lr)
|
|
325
332
|
kd_optimizer.load_state_dict(state.kd_optimizer)
|
|
326
333
|
else:
|
|
327
|
-
seed_everything(
|
|
334
|
+
seed_everything(c.seed, device=device)
|
|
328
335
|
|
|
329
336
|
# Distill
|
|
330
|
-
open_dataset =
|
|
331
|
-
if
|
|
332
|
-
global_soft_labels = list(torch.unbind(
|
|
333
|
-
global_indices =
|
|
337
|
+
open_dataset = c.dataset.get_dataset(type_="open", cid=None)
|
|
338
|
+
if p.indices is not None and p.soft_labels is not None:
|
|
339
|
+
global_soft_labels = list(torch.unbind(p.soft_labels, dim=0))
|
|
340
|
+
global_indices = p.indices.tolist()
|
|
334
341
|
if kd_optimizer is None:
|
|
335
|
-
kd_optimizer = torch.optim.SGD(model.parameters(), lr=
|
|
342
|
+
kd_optimizer = torch.optim.SGD(model.parameters(), lr=c.kd_lr)
|
|
336
343
|
DSFLBaseServerHandler.distill(
|
|
337
344
|
model=model,
|
|
338
345
|
optimizer=kd_optimizer,
|
|
339
|
-
dataset=
|
|
346
|
+
dataset=c.dataset,
|
|
340
347
|
global_soft_labels=global_soft_labels,
|
|
341
348
|
global_indices=global_indices,
|
|
342
|
-
kd_epochs=
|
|
343
|
-
kd_batch_size=
|
|
349
|
+
kd_epochs=c.kd_epochs,
|
|
350
|
+
kd_batch_size=c.kd_batch_size,
|
|
344
351
|
device=device,
|
|
345
352
|
)
|
|
346
353
|
|
|
347
354
|
# Train
|
|
348
|
-
train_loader =
|
|
355
|
+
train_loader = c.dataset.get_dataloader(
|
|
349
356
|
type_="train",
|
|
350
|
-
cid=
|
|
351
|
-
batch_size=
|
|
357
|
+
cid=c.cid,
|
|
358
|
+
batch_size=c.batch_size,
|
|
352
359
|
)
|
|
353
360
|
DSFLProcessPoolClientTrainer.train(
|
|
354
361
|
model=model,
|
|
355
362
|
optimizer=optimizer,
|
|
356
363
|
train_loader=train_loader,
|
|
357
364
|
device=device,
|
|
358
|
-
epochs=
|
|
365
|
+
epochs=c.epochs,
|
|
359
366
|
)
|
|
360
367
|
|
|
361
368
|
# Predict
|
|
362
369
|
open_loader = DataLoader(
|
|
363
|
-
Subset(open_dataset,
|
|
364
|
-
batch_size=
|
|
370
|
+
Subset(open_dataset, p.next_indices.tolist()),
|
|
371
|
+
batch_size=c.batch_size,
|
|
365
372
|
)
|
|
366
373
|
soft_labels = DSFLProcessPoolClientTrainer.predict(
|
|
367
374
|
model=model,
|
|
@@ -370,10 +377,10 @@ class DSFLProcessPoolClientTrainer(
|
|
|
370
377
|
)
|
|
371
378
|
|
|
372
379
|
# Evaluate
|
|
373
|
-
test_loader =
|
|
380
|
+
test_loader = c.dataset.get_dataloader(
|
|
374
381
|
type_="test",
|
|
375
|
-
cid=
|
|
376
|
-
batch_size=
|
|
382
|
+
cid=c.cid,
|
|
383
|
+
batch_size=c.batch_size,
|
|
377
384
|
)
|
|
378
385
|
loss, acc = DSFLBaseServerHandler.evaulate(
|
|
379
386
|
model=model,
|
|
@@ -383,19 +390,19 @@ class DSFLProcessPoolClientTrainer(
|
|
|
383
390
|
|
|
384
391
|
package = DSFLUplinkPackage(
|
|
385
392
|
soft_labels=soft_labels,
|
|
386
|
-
indices=
|
|
393
|
+
indices=p.next_indices,
|
|
387
394
|
metadata={"loss": loss, "acc": acc},
|
|
388
395
|
)
|
|
389
396
|
|
|
390
|
-
torch.save(package,
|
|
397
|
+
torch.save(package, config_path)
|
|
391
398
|
state = DSFLClientState(
|
|
392
399
|
random=RandomState.get_random_state(device=device),
|
|
393
400
|
model=model.state_dict(),
|
|
394
401
|
optimizer=optimizer.state_dict(),
|
|
395
402
|
kd_optimizer=kd_optimizer.state_dict() if kd_optimizer else None,
|
|
396
403
|
)
|
|
397
|
-
torch.save(state,
|
|
398
|
-
return
|
|
404
|
+
torch.save(state, c.state_path)
|
|
405
|
+
return config_path
|
|
399
406
|
|
|
400
407
|
@staticmethod
|
|
401
408
|
def train(
|
|
@@ -442,10 +449,8 @@ class DSFLProcessPoolClientTrainer(
|
|
|
442
449
|
soft_labels = torch.cat(soft_labels_list, dim=0)
|
|
443
450
|
return soft_labels.cpu()
|
|
444
451
|
|
|
445
|
-
def
|
|
446
|
-
|
|
447
|
-
) -> DSFLDiskSharedData:
|
|
448
|
-
data = DSFLDiskSharedData(
|
|
452
|
+
def get_client_config(self, cid: int) -> DSFLClientConfig:
|
|
453
|
+
data = DSFLClientConfig(
|
|
449
454
|
model_selector=self.model_selector,
|
|
450
455
|
model_name=self.model_name,
|
|
451
456
|
dataset=self.dataset,
|
|
@@ -457,7 +462,6 @@ class DSFLProcessPoolClientTrainer(
|
|
|
457
462
|
kd_lr=self.kd_lr,
|
|
458
463
|
cid=cid,
|
|
459
464
|
seed=self.seed,
|
|
460
|
-
payload=payload,
|
|
461
465
|
state_path=self.state_dir.joinpath(f"{cid}.pt"),
|
|
462
466
|
)
|
|
463
467
|
return data
|
|
@@ -2,6 +2,7 @@ import random
|
|
|
2
2
|
from copy import deepcopy
|
|
3
3
|
from dataclasses import dataclass
|
|
4
4
|
from pathlib import Path
|
|
5
|
+
from typing import Literal
|
|
5
6
|
|
|
6
7
|
import torch
|
|
7
8
|
from torch.utils.data import DataLoader
|
|
@@ -431,7 +432,7 @@ class FedAvgBaseClientTrainer(
|
|
|
431
432
|
|
|
432
433
|
|
|
433
434
|
@dataclass
|
|
434
|
-
class
|
|
435
|
+
class FedAvgClientConfig:
|
|
435
436
|
"""
|
|
436
437
|
Data structure representing shared data for parallel client training
|
|
437
438
|
in the Federated Averaging (FedAvg) algorithm.
|
|
@@ -448,7 +449,6 @@ class FedAvgDiskSharedData:
|
|
|
448
449
|
lr (float): Learning rate for the optimizer.
|
|
449
450
|
cid (int): Client ID.
|
|
450
451
|
seed (int): Seed for reproducibility.
|
|
451
|
-
payload (FedAvgDownlinkPackage): Downlink package with global model parameters.
|
|
452
452
|
state_path (Path): Path to save the client's random state.
|
|
453
453
|
"""
|
|
454
454
|
|
|
@@ -460,13 +460,12 @@ class FedAvgDiskSharedData:
|
|
|
460
460
|
lr: float
|
|
461
461
|
cid: int
|
|
462
462
|
seed: int
|
|
463
|
-
payload: FedAvgDownlinkPackage
|
|
464
463
|
state_path: Path
|
|
465
464
|
|
|
466
465
|
|
|
467
466
|
class FedAvgProcessPoolClientTrainer(
|
|
468
467
|
ProcessPoolClientTrainer[
|
|
469
|
-
FedAvgUplinkPackage, FedAvgDownlinkPackage,
|
|
468
|
+
FedAvgUplinkPackage, FedAvgDownlinkPackage, FedAvgClientConfig
|
|
470
469
|
]
|
|
471
470
|
):
|
|
472
471
|
"""
|
|
@@ -488,6 +487,8 @@ class FedAvgProcessPoolClientTrainer(
|
|
|
488
487
|
lr (float): Learning rate for the optimizer.
|
|
489
488
|
seed (int): Seed for reproducibility.
|
|
490
489
|
num_parallels (int): Number of parallel processes for training.
|
|
490
|
+
ipc_mode (Literal["storage", "shared_memory"]):
|
|
491
|
+
Inter-process communication mode.
|
|
491
492
|
device_count (int | None): Number of CUDA devices available (if using GPU).
|
|
492
493
|
"""
|
|
493
494
|
|
|
@@ -505,6 +506,7 @@ class FedAvgProcessPoolClientTrainer(
|
|
|
505
506
|
lr: float,
|
|
506
507
|
seed: int,
|
|
507
508
|
num_parallels: int,
|
|
509
|
+
ipc_mode: Literal["storage", "shared_memory"],
|
|
508
510
|
) -> None:
|
|
509
511
|
"""
|
|
510
512
|
Initialize the FedAvgParalleClientTrainer.
|
|
@@ -542,50 +544,93 @@ class FedAvgProcessPoolClientTrainer(
|
|
|
542
544
|
self.device = device
|
|
543
545
|
self.num_clients = num_clients
|
|
544
546
|
self.seed = seed
|
|
547
|
+
self.ipc_mode = ipc_mode
|
|
545
548
|
|
|
546
549
|
@staticmethod
|
|
547
|
-
def
|
|
550
|
+
def worker(
|
|
551
|
+
config: FedAvgClientConfig | Path,
|
|
552
|
+
payload: FedAvgDownlinkPackage | Path,
|
|
553
|
+
device: str,
|
|
554
|
+
) -> FedAvgUplinkPackage | Path:
|
|
548
555
|
"""
|
|
549
556
|
Process a single client's local training and evaluation.
|
|
550
557
|
|
|
551
|
-
This method is executed by a
|
|
552
|
-
|
|
558
|
+
This method is executed by a worker process and handles loading client
|
|
559
|
+
configuration and payload, performing the client-specific training,
|
|
560
|
+
and returning the result.
|
|
553
561
|
|
|
554
562
|
Args:
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
563
|
+
config (FedAvgClientConfig | Path):
|
|
564
|
+
The client's configuration data, or a path to a file containing
|
|
565
|
+
the configuration if `ipc_mode` is "storage".
|
|
566
|
+
payload (FedAvgDownlinkPackage | Path):
|
|
567
|
+
The downlink payload from the server, or a path to a file
|
|
568
|
+
containing the payload if `ipc_mode` is "storage".
|
|
569
|
+
device (str): Device to use for processing (e.g., "cpu", "cuda:0").
|
|
558
570
|
|
|
559
571
|
Returns:
|
|
560
|
-
Path:
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
572
|
+
FedAvgUplinkPackage | Path:
|
|
573
|
+
The uplink package containing the client's results, or a path to
|
|
574
|
+
a file containing the package if `ipc_mode` is "storage".
|
|
575
|
+
"""
|
|
576
|
+
|
|
577
|
+
def _storage_worker(
|
|
578
|
+
config_path: Path,
|
|
579
|
+
payload_path: Path,
|
|
580
|
+
device: str,
|
|
581
|
+
) -> Path:
|
|
582
|
+
config = torch.load(config_path, weights_only=False)
|
|
583
|
+
assert isinstance(config, FedAvgClientConfig)
|
|
584
|
+
payload = torch.load(payload_path, weights_only=False)
|
|
585
|
+
assert isinstance(payload, FedAvgDownlinkPackage)
|
|
586
|
+
package = _shared_memory_worker(
|
|
587
|
+
config=config,
|
|
588
|
+
payload=payload,
|
|
589
|
+
device=device,
|
|
590
|
+
)
|
|
591
|
+
torch.save(package, config_path)
|
|
592
|
+
return config_path
|
|
593
|
+
|
|
594
|
+
def _shared_memory_worker(
|
|
595
|
+
config: FedAvgClientConfig,
|
|
596
|
+
payload: FedAvgDownlinkPackage,
|
|
597
|
+
device: str,
|
|
598
|
+
) -> FedAvgUplinkPackage:
|
|
599
|
+
if config.state_path.exists():
|
|
600
|
+
state = torch.load(config.state_path, weights_only=False)
|
|
601
|
+
assert isinstance(state, RandomState)
|
|
602
|
+
RandomState.set_random_state(state)
|
|
603
|
+
else:
|
|
604
|
+
seed_everything(config.seed, device=device)
|
|
605
|
+
|
|
606
|
+
model = config.model_selector.select_model(config.model_name)
|
|
607
|
+
train_loader = config.dataset.get_dataloader(
|
|
608
|
+
type_="train",
|
|
609
|
+
cid=config.cid,
|
|
610
|
+
batch_size=config.batch_size,
|
|
611
|
+
)
|
|
612
|
+
package = FedAvgProcessPoolClientTrainer.train(
|
|
613
|
+
model=model,
|
|
614
|
+
model_parameters=payload.model_parameters,
|
|
615
|
+
train_loader=train_loader,
|
|
616
|
+
device=device,
|
|
617
|
+
epochs=config.epochs,
|
|
618
|
+
lr=config.lr,
|
|
619
|
+
)
|
|
620
|
+
torch.save(RandomState.get_random_state(device=device), config.state_path)
|
|
621
|
+
return package
|
|
622
|
+
|
|
623
|
+
if isinstance(config, Path) and isinstance(payload, Path):
|
|
624
|
+
return _storage_worker(config, payload, device)
|
|
625
|
+
elif isinstance(config, FedAvgClientConfig) and isinstance(
|
|
626
|
+
payload, FedAvgDownlinkPackage
|
|
627
|
+
):
|
|
628
|
+
return _shared_memory_worker(config, payload, device)
|
|
569
629
|
else:
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
type_="train",
|
|
575
|
-
cid=data.cid,
|
|
576
|
-
batch_size=data.batch_size,
|
|
577
|
-
)
|
|
578
|
-
package = FedAvgProcessPoolClientTrainer.train(
|
|
579
|
-
model=model,
|
|
580
|
-
model_parameters=data.payload.model_parameters,
|
|
581
|
-
train_loader=train_loader,
|
|
582
|
-
device=device,
|
|
583
|
-
epochs=data.epochs,
|
|
584
|
-
lr=data.lr,
|
|
585
|
-
)
|
|
586
|
-
torch.save(package, path)
|
|
587
|
-
torch.save(RandomState.get_random_state(device=device), data.state_path)
|
|
588
|
-
return path
|
|
630
|
+
raise TypeError(
|
|
631
|
+
"Invalid types for config and payload."
|
|
632
|
+
" Expected FedAvgClientConfig and FedAvgDownlinkPackage or Path."
|
|
633
|
+
)
|
|
589
634
|
|
|
590
635
|
@staticmethod
|
|
591
636
|
def train(
|
|
@@ -636,21 +681,17 @@ class FedAvgProcessPoolClientTrainer(
|
|
|
636
681
|
|
|
637
682
|
return FedAvgUplinkPackage(model_parameters, data_size)
|
|
638
683
|
|
|
639
|
-
def
|
|
640
|
-
self, cid: int, payload: FedAvgDownlinkPackage
|
|
641
|
-
) -> FedAvgDiskSharedData:
|
|
684
|
+
def get_client_config(self, cid: int) -> FedAvgClientConfig:
|
|
642
685
|
"""
|
|
643
|
-
Generate the
|
|
686
|
+
Generate the client configuration for a specific client.
|
|
644
687
|
|
|
645
688
|
Args:
|
|
646
689
|
cid (int): Client ID.
|
|
647
|
-
payload (FedAvgDownlinkPackage): Downlink package with global model
|
|
648
|
-
parameters.
|
|
649
690
|
|
|
650
691
|
Returns:
|
|
651
|
-
|
|
692
|
+
FedAvgClientConfig: Client configuration data structure.
|
|
652
693
|
"""
|
|
653
|
-
data =
|
|
694
|
+
data = FedAvgClientConfig(
|
|
654
695
|
model_selector=self.model_selector,
|
|
655
696
|
model_name=self.model_name,
|
|
656
697
|
dataset=self.dataset,
|
|
@@ -659,7 +700,6 @@ class FedAvgProcessPoolClientTrainer(
|
|
|
659
700
|
lr=self.lr,
|
|
660
701
|
cid=cid,
|
|
661
702
|
seed=self.seed,
|
|
662
|
-
payload=payload,
|
|
663
703
|
state_path=self.state_dir.joinpath(f"{cid}.pt"),
|
|
664
704
|
)
|
|
665
705
|
return data
|
|
@@ -3,11 +3,13 @@ import signal
|
|
|
3
3
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
4
4
|
from multiprocessing.pool import ApplyResult
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import Protocol, TypeVar
|
|
6
|
+
from typing import Literal, Protocol, TypeVar
|
|
7
7
|
|
|
8
8
|
import torch
|
|
9
9
|
from tqdm import tqdm
|
|
10
10
|
|
|
11
|
+
from blazefl.utils import move_tensor_to_shared_memory
|
|
12
|
+
|
|
11
13
|
UplinkPackage = TypeVar("UplinkPackage")
|
|
12
14
|
DownlinkPackage = TypeVar("DownlinkPackage", contravariant=True)
|
|
13
15
|
|
|
@@ -47,12 +49,12 @@ class BaseClientTrainer(Protocol[UplinkPackage, DownlinkPackage]):
|
|
|
47
49
|
...
|
|
48
50
|
|
|
49
51
|
|
|
50
|
-
|
|
52
|
+
ClientConfig = TypeVar("ClientConfig")
|
|
51
53
|
|
|
52
54
|
|
|
53
55
|
class ProcessPoolClientTrainer(
|
|
54
56
|
BaseClientTrainer[UplinkPackage, DownlinkPackage],
|
|
55
|
-
Protocol[UplinkPackage, DownlinkPackage,
|
|
57
|
+
Protocol[UplinkPackage, DownlinkPackage, ClientConfig],
|
|
56
58
|
):
|
|
57
59
|
"""
|
|
58
60
|
Abstract base class for parallel client training in federated learning.
|
|
@@ -63,7 +65,12 @@ class ProcessPoolClientTrainer(
|
|
|
63
65
|
Attributes:
|
|
64
66
|
num_parallels (int): Number of parallel processes to use for client training.
|
|
65
67
|
share_dir (Path): Directory path for sharing data between processes.
|
|
68
|
+
device (str): The primary device to use for computation (e.g., "cpu", "cuda").
|
|
69
|
+
device_count (int): The number of available CUDA devices, if `device` is "cuda".
|
|
66
70
|
cache (list[UplinkPackage]): Cache to store uplink packages from clients.
|
|
71
|
+
ipc_mode (Literal["storage", "shared_memory"]): Inter-process communication
|
|
72
|
+
mode. "storage" uses disk for data exchange, "shared_memory" uses
|
|
73
|
+
shared memory for tensor data. Defaults to "storage".
|
|
67
74
|
|
|
68
75
|
Raises:
|
|
69
76
|
NotImplementedError: If the abstract methods are not implemented in a subclass.
|
|
@@ -74,17 +81,17 @@ class ProcessPoolClientTrainer(
|
|
|
74
81
|
device: str
|
|
75
82
|
device_count: int
|
|
76
83
|
cache: list[UplinkPackage]
|
|
84
|
+
ipc_mode: Literal["storage", "shared_memory"] = "storage"
|
|
77
85
|
|
|
78
|
-
def
|
|
86
|
+
def get_client_config(self, cid: int) -> ClientConfig:
|
|
79
87
|
"""
|
|
80
|
-
Retrieve
|
|
88
|
+
Retrieve the configuration for a given client ID.
|
|
81
89
|
|
|
82
90
|
Args:
|
|
83
91
|
cid (int): Client ID.
|
|
84
|
-
payload (DownlinkPackage): The data package received from the server.
|
|
85
92
|
|
|
86
93
|
Returns:
|
|
87
|
-
|
|
94
|
+
ClientConfig: The configuration for the specified client.
|
|
88
95
|
"""
|
|
89
96
|
...
|
|
90
97
|
|
|
@@ -103,16 +110,29 @@ class ProcessPoolClientTrainer(
|
|
|
103
110
|
return self.device
|
|
104
111
|
|
|
105
112
|
@staticmethod
|
|
106
|
-
def
|
|
113
|
+
def worker(
|
|
114
|
+
config: ClientConfig | Path, payload: DownlinkPackage | Path, device: str
|
|
115
|
+
) -> UplinkPackage | Path:
|
|
107
116
|
"""
|
|
108
|
-
Process a single client
|
|
117
|
+
Process a single client's training task.
|
|
118
|
+
|
|
119
|
+
This method is executed by each worker process in the pool.
|
|
120
|
+
It handles loading client configuration and payload, performing
|
|
121
|
+
the client-specific operations, and returning the result.
|
|
109
122
|
|
|
110
123
|
Args:
|
|
111
|
-
|
|
112
|
-
|
|
124
|
+
config (ClientConfig | Path):
|
|
125
|
+
The client's configuration data, or a path to a file containing
|
|
126
|
+
the configuration if `ipc_mode` is "storage".
|
|
127
|
+
payload (DownlinkPackage | Path):
|
|
128
|
+
The downlink payload from the server, or a path to a file
|
|
129
|
+
containing the payload if `ipc_mode` is "storage".
|
|
130
|
+
device (str): Device to use for processing (e.g., "cpu", "cuda:0").
|
|
113
131
|
|
|
114
132
|
Returns:
|
|
115
|
-
Path:
|
|
133
|
+
UplinkPackage | Path:
|
|
134
|
+
The uplink package containing the client's results, or a path
|
|
135
|
+
to a file containing the package if `ipc_mode` is "storage".
|
|
116
136
|
"""
|
|
117
137
|
...
|
|
118
138
|
|
|
@@ -130,6 +150,13 @@ class ProcessPoolClientTrainer(
|
|
|
130
150
|
Returns:
|
|
131
151
|
None
|
|
132
152
|
"""
|
|
153
|
+
payload_path = Path()
|
|
154
|
+
if self.ipc_mode == "storage":
|
|
155
|
+
payload_path = self.share_dir.joinpath("payload.pkl")
|
|
156
|
+
torch.save(payload, payload_path)
|
|
157
|
+
else: # shared_memory
|
|
158
|
+
move_tensor_to_shared_memory(payload)
|
|
159
|
+
|
|
133
160
|
with mp.Pool(
|
|
134
161
|
processes=self.num_parallels,
|
|
135
162
|
initializer=signal.signal,
|
|
@@ -137,16 +164,28 @@ class ProcessPoolClientTrainer(
|
|
|
137
164
|
) as pool:
|
|
138
165
|
jobs: list[ApplyResult] = []
|
|
139
166
|
for cid in cid_list:
|
|
140
|
-
|
|
141
|
-
data = self.get_shared_data(cid, payload)
|
|
167
|
+
config = self.get_client_config(cid)
|
|
142
168
|
device = self.get_client_device(cid)
|
|
143
|
-
|
|
144
|
-
|
|
169
|
+
if self.ipc_mode == "storage":
|
|
170
|
+
config_path = self.share_dir.joinpath(f"{cid}.pkl")
|
|
171
|
+
torch.save(config, config_path)
|
|
172
|
+
jobs.append(
|
|
173
|
+
pool.apply_async(
|
|
174
|
+
self.worker, (config_path, payload_path, device)
|
|
175
|
+
)
|
|
176
|
+
)
|
|
177
|
+
else: # shared_memory
|
|
178
|
+
jobs.append(
|
|
179
|
+
pool.apply_async(self.worker, (config, payload, device))
|
|
180
|
+
)
|
|
145
181
|
|
|
146
182
|
for job in tqdm(jobs, desc="Client", leave=False):
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
183
|
+
result = job.get()
|
|
184
|
+
if self.ipc_mode == "storage":
|
|
185
|
+
assert isinstance(result, Path)
|
|
186
|
+
package = torch.load(result, weights_only=False)
|
|
187
|
+
else: # shared_memory
|
|
188
|
+
package = result
|
|
150
189
|
self.cache.append(package)
|
|
151
190
|
|
|
152
191
|
|
|
@@ -159,12 +198,24 @@ class ThreadPoolClientTrainer(
|
|
|
159
198
|
device_count: int
|
|
160
199
|
cache: list[UplinkPackage]
|
|
161
200
|
|
|
162
|
-
def
|
|
201
|
+
def worker(
|
|
163
202
|
self,
|
|
164
203
|
cid: int,
|
|
165
204
|
device: str,
|
|
166
205
|
payload: DownlinkPackage,
|
|
167
|
-
) -> UplinkPackage:
|
|
206
|
+
) -> UplinkPackage:
|
|
207
|
+
"""
|
|
208
|
+
Process a single client's training task in a thread.
|
|
209
|
+
|
|
210
|
+
Args:
|
|
211
|
+
cid (int): The client ID.
|
|
212
|
+
device (str): The device to use for processing this client.
|
|
213
|
+
payload (DownlinkPackage): The data package received from the server.
|
|
214
|
+
|
|
215
|
+
Returns:
|
|
216
|
+
UplinkPackage: The uplink package containing the client's results.
|
|
217
|
+
"""
|
|
218
|
+
...
|
|
168
219
|
|
|
169
220
|
def get_client_device(self, cid: int) -> str:
|
|
170
221
|
if self.device == "cuda":
|
|
@@ -177,7 +228,7 @@ class ThreadPoolClientTrainer(
|
|
|
177
228
|
for cid in cid_list:
|
|
178
229
|
device = self.get_client_device(cid)
|
|
179
230
|
future = executor.submit(
|
|
180
|
-
self.
|
|
231
|
+
self.worker,
|
|
181
232
|
cid,
|
|
182
233
|
device,
|
|
183
234
|
payload,
|