blazefl 2.0.0.dev1__tar.gz → 2.0.0.dev2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/Makefile +1 -1
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/PKG-INFO +1 -1
- blazefl-2.0.0.dev2/examples/experimental-freethreaded/.python-version +1 -0
- blazefl-2.0.0.dev2/examples/experimental-freethreaded/Makefile +6 -0
- blazefl-2.0.0.dev2/examples/experimental-freethreaded/config/config.yaml +17 -0
- blazefl-2.0.0.dev2/examples/experimental-freethreaded/core/__init__.py +5 -0
- blazefl-2.0.0.dev2/examples/experimental-freethreaded/core/client_trainer.py +45 -0
- blazefl-2.0.0.dev2/examples/experimental-freethreaded/main.py +220 -0
- blazefl-2.0.0.dev2/examples/experimental-freethreaded/pyproject.toml +37 -0
- blazefl-2.0.0.dev2/examples/experimental-freethreaded/uv.lock +569 -0
- blazefl-2.0.0.dev2/examples/quickstart-fedavg/dataset/__init__.py +3 -0
- blazefl-2.0.0.dev2/examples/quickstart-fedavg/dataset/dataset.py +133 -0
- blazefl-2.0.0.dev2/examples/quickstart-fedavg/dataset/functional.py +136 -0
- blazefl-2.0.0.dev2/examples/quickstart-fedavg/models/__init__.py +3 -0
- blazefl-2.0.0.dev2/examples/quickstart-fedavg/models/selector.py +45 -0
- blazefl-2.0.0.dev2/examples/step-by-step-dsfl/.gitignore +3 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/pyproject.toml +1 -1
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/core/__init__.py +6 -1
- blazefl-2.0.0.dev2/src/blazefl/core/__init__.pyi +6 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/core/client_trainer.py +39 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/core/client_trainer.pyi +9 -0
- blazefl-2.0.0.dev2/tests/test_utils/__init__.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/uv.lock +1 -1
- blazefl-2.0.0.dev1/src/blazefl/core/__init__.pyi +0 -6
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.github/FUNDING.yml +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.github/dependabot.yml +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.github/workflows/ci.yaml +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.github/workflows/publish.yaml +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.github/workflows/sphinx.yaml +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.gitignore +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.pre-commit-config.yaml +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.python-version +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/CODE_OF_CONDUCT.md +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/LICENSE +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/README.md +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/Makefile +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/imgs/architecture.png +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/imgs/benchmark_cnn.png +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/imgs/benchmark_resnet18.png +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/imgs/logo.svg +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/imgs/logo_square.svg +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/imgs/ogp.png +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/make.bat +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/_static/favicon.ico +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/_static/logo.png +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/_static/logo_square.png +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/_templates/autosummary/class.rst +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/_templates/autosummary/module.rst +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/benchmark.rst +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/conf.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/contribute.rst +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/example.rst +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/index.rst +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/install.rst +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/overview.rst +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/reference.rst +0 -0
- {blazefl-2.0.0.dev1/examples/quickstart-fedavg → blazefl-2.0.0.dev2/examples/experimental-freethreaded}/.gitignore +0 -0
- /blazefl-2.0.0.dev1/src/blazefl/__init__.py → /blazefl-2.0.0.dev2/examples/experimental-freethreaded/README.md +0 -0
- {blazefl-2.0.0.dev1/examples/quickstart-fedavg → blazefl-2.0.0.dev2/examples/experimental-freethreaded}/dataset/__init__.py +0 -0
- {blazefl-2.0.0.dev1/examples/quickstart-fedavg → blazefl-2.0.0.dev2/examples/experimental-freethreaded}/dataset/dataset.py +0 -0
- {blazefl-2.0.0.dev1/examples/quickstart-fedavg → blazefl-2.0.0.dev2/examples/experimental-freethreaded}/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev1/examples/quickstart-fedavg → blazefl-2.0.0.dev2/examples/experimental-freethreaded}/models/__init__.py +0 -0
- {blazefl-2.0.0.dev1/examples/quickstart-fedavg → blazefl-2.0.0.dev2/examples/experimental-freethreaded}/models/selector.py +0 -0
- {blazefl-2.0.0.dev1/examples/step-by-step-dsfl → blazefl-2.0.0.dev2/examples/quickstart-fedavg}/.gitignore +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/quickstart-fedavg/.python-version +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/quickstart-fedavg/Makefile +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/quickstart-fedavg/README.md +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/quickstart-fedavg/config/config.yaml +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/quickstart-fedavg/main.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/quickstart-fedavg/pyproject.toml +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/.python-version +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/Makefile +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/README.md +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/algorithm/__init__.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/algorithm/dsfl.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/config/algorithm/dsfl.yaml +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/config/config.yaml +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/dataset/__init__.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/dataset/dataset.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/main.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/models/__init__.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/models/cnn.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/models/selector.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/pyproject.toml +0 -0
- {blazefl-2.0.0.dev1/tests → blazefl-2.0.0.dev2/src/blazefl}/__init__.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/contrib/__init__.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/contrib/fedavg.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/core/model_selector.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/core/model_selector.pyi +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/core/partitioned_dataset.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/core/partitioned_dataset.pyi +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/core/server_handler.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/core/server_handler.pyi +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/py.typed +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/utils/__init__.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/utils/__init__.pyi +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/utils/dataset.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/utils/dataset.pyi +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/utils/seed.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/utils/seed.pyi +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/utils/serialize.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/utils/serialize.pyi +0 -0
- {blazefl-2.0.0.dev1/tests/test_contrib → blazefl-2.0.0.dev2/tests}/__init__.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/tests/conftest.py +0 -0
- {blazefl-2.0.0.dev1/tests/test_core → blazefl-2.0.0.dev2/tests/test_contrib}/__init__.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/tests/test_contrib/test_fedavg.py +0 -0
- {blazefl-2.0.0.dev1/tests/test_utils → blazefl-2.0.0.dev2/tests/test_core}/__init__.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/tests/test_core/test_client_trainer.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/tests/test_core/test_model_selector.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/tests/test_core/test_partitioned_dataset.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/tests/test_utils/test_dataset.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/tests/test_utils/test_seed.py +0 -0
- {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/tests/test_utils/test_serialize.py +0 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
3.13t
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
model_name: cnn
|
|
2
|
+
num_clients: 100
|
|
3
|
+
global_round: 5
|
|
4
|
+
sample_ratio: 1.0
|
|
5
|
+
partition: shards
|
|
6
|
+
num_shards: 200
|
|
7
|
+
dir_alpha: 1.0
|
|
8
|
+
seed: 42
|
|
9
|
+
epochs: 5
|
|
10
|
+
lr: 0.1
|
|
11
|
+
batch_size: 50
|
|
12
|
+
num_parallels: 10
|
|
13
|
+
dataset_root_dir: /tmp/experimental-freethreaded/dataset
|
|
14
|
+
dataset_split_dir: /tmp/experimental-freethreaded/split
|
|
15
|
+
share_dir: /tmp/experimental-freethreaded/share
|
|
16
|
+
state_dir: /tmp/experimental-freethreaded/state
|
|
17
|
+
execution_mode: multi-thread
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""Thread-based parallel client training for the free-threaded example.

Defines the structural (duck-typed) interface a trainer must satisfy to run
its clients concurrently on a ``ThreadPoolExecutor``.
"""

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Protocol, TypeVar

from tqdm import tqdm

# Package produced by each client (appears in return and cache positions) and
# package received from the server (only consumed, hence contravariant).
UplinkPackage = TypeVar("UplinkPackage")
DownlinkPackage = TypeVar("DownlinkPackage", contravariant=True)


class MultiThreadClientTrainer(Protocol[UplinkPackage, DownlinkPackage]):
    """Protocol for trainers that process clients on a thread pool.

    Implementors provide ``process_client`` plus the attributes below;
    ``local_process`` then fans the sampled clients out over
    ``num_parallels`` worker threads and gathers results into ``cache``.
    """

    # Number of worker threads in the pool.
    num_parallels: int
    # Base device string ("cpu", "cuda", "mps", ...).
    device: str
    # Number of CUDA devices; consulted only when device == "cuda".
    device_count: int
    # Completed uplink packages, appended in completion order.
    cache: list[UplinkPackage]

    def process_client(
        self,
        cid: int,
        device: str,
        payload: DownlinkPackage,
    ) -> UplinkPackage:
        """Train one client and return its uplink package (implementor-defined)."""
        ...

    def get_client_device(self, cid: int) -> str:
        """Map a client id to a concrete device string.

        On CUDA, clients are spread round-robin over the available GPUs;
        otherwise every client shares the base device.
        """
        if self.device == "cuda":
            return f"cuda:{cid % self.device_count}"
        return self.device

    def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None:
        """Train every client in *cid_list* concurrently, caching the results.

        Results are appended to ``self.cache`` in completion order, which may
        differ from ``cid_list`` order.
        """
        with ThreadPoolExecutor(max_workers=self.num_parallels) as executor:
            pending = [
                executor.submit(
                    self.process_client,
                    cid,
                    self.get_client_device(cid),
                    payload,
                )
                for cid in cid_list
            ]

            progress = tqdm(
                as_completed(pending), total=len(pending), desc="Client", leave=False
            )
            for task in progress:
                self.cache.append(task.result())
|
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from copy import deepcopy
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import hydra
|
|
7
|
+
import torch
|
|
8
|
+
import torch.multiprocessing as mp
|
|
9
|
+
from blazefl.contrib import (
|
|
10
|
+
FedAvgParallelClientTrainer,
|
|
11
|
+
FedAvgSerialClientTrainer,
|
|
12
|
+
FedAvgServerHandler,
|
|
13
|
+
)
|
|
14
|
+
from blazefl.contrib.fedavg import FedAvgDownlinkPackage, FedAvgUplinkPackage
|
|
15
|
+
from blazefl.core import ModelSelector, PartitionedDataset
|
|
16
|
+
from blazefl.utils import seed_everything
|
|
17
|
+
from omegaconf import DictConfig, OmegaConf
|
|
18
|
+
|
|
19
|
+
from core.client_trainer import MultiThreadClientTrainer
|
|
20
|
+
from dataset import PartitionedCIFAR10
|
|
21
|
+
from models import FedAvgModelSelector
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class FedAvgMultiThreadClientTrainer(
    MultiThreadClientTrainer[
        FedAvgUplinkPackage,
        FedAvgDownlinkPackage,
    ]
):
    """FedAvg client trainer that fans local training out over threads.

    Satisfies the ``MultiThreadClientTrainer`` protocol: the inherited
    ``local_process`` submits ``process_client`` for each sampled client to a
    thread pool and collects the uplink packages into ``self.cache``.
    """

    def __init__(
        self,
        model_selector: ModelSelector,
        model_name: str,
        dataset: PartitionedDataset,
        device: str,
        num_clients: int,
        epochs: int,
        batch_size: int,
        lr: float,
        seed: int,
        num_parallels: int,
    ) -> None:
        self.num_parallels = num_parallels
        self.device = device
        # FIX: always define device_count. The protocol declares it as a
        # required attribute, but the original only assigned it on CUDA,
        # leaving it unset (AttributeError risk) on CPU/MPS.
        self.device_count = torch.cuda.device_count() if device == "cuda" else 1
        self.cache: list[FedAvgUplinkPackage] = []

        self.model_selector = model_selector
        self.model_name = model_name
        self.dataset = dataset
        self.epochs = epochs
        self.batch_size = batch_size
        self.lr = lr
        self.num_clients = num_clients
        self.seed = seed

    def process_client(
        self,
        cid: int,
        device: str,
        payload: FedAvgDownlinkPackage,
    ) -> FedAvgUplinkPackage:
        """Train client ``cid`` on ``device`` and return its uplink package."""
        model = self.model_selector.select_model(self.model_name)
        train_loader = self.dataset.get_dataloader(
            type_="train",
            cid=cid,
            batch_size=self.batch_size,
        )
        # Reuse the shared training routine from the multi-process trainer so
        # the serial / multi-process / multi-thread modes stay consistent.
        package = FedAvgParallelClientTrainer.train(
            model=model,
            model_parameters=payload.model_parameters,
            train_loader=train_loader,
            device=device,
            epochs=self.epochs,
            lr=self.lr,
        )
        return package

    def uplink_package(self) -> list[FedAvgUplinkPackage]:
        """Hand the cached packages to the server and reset the cache."""
        # FIX: hand over the list instead of deep-copying it. Once the cache
        # is rebound to a fresh list, the returned list has no other owner,
        # so the deepcopy of every cached package was pure overhead.
        package = self.cache
        self.cache = []
        return package
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class FedAvgPipeline:
    """Drives the FedAvg loop: broadcast, local training, aggregation."""

    def __init__(
        self,
        handler: FedAvgServerHandler,
        trainer: FedAvgSerialClientTrainer
        | FedAvgParallelClientTrainer
        | FedAvgMultiThreadClientTrainer,
    ) -> None:
        self.handler = handler
        self.trainer = trainer

    def main(self):
        """Run federated rounds until the handler signals completion."""
        while not self.handler.if_stop():
            current_round = self.handler.round

            # Server: pick participants and prepare the broadcast payload.
            sampled_clients = self.handler.sample_clients()
            broadcast = self.handler.downlink_package()

            # Clients: run local training and collect their uploads.
            self.trainer.local_process(broadcast, sampled_clients)
            uploads = self.trainer.uplink_package()

            # Server: fold every upload back into the global model.
            for pack in uploads:
                self.handler.load(pack)

            summary = self.handler.get_summary()
            formatted_summary = ", ".join(f"{k}: {v:.3f}" for k, v in summary.items())
            logging.info(f"round: {current_round}, {formatted_summary}")

        logging.info("done!")
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
@hydra.main(version_base=None, config_path="config", config_name="config")
|
|
120
|
+
def main(cfg: DictConfig):
|
|
121
|
+
print(OmegaConf.to_yaml(cfg))
|
|
122
|
+
|
|
123
|
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
124
|
+
dataset_root_dir = Path(cfg.dataset_root_dir)
|
|
125
|
+
dataset_split_dir = dataset_root_dir.joinpath(timestamp)
|
|
126
|
+
share_dir = Path(cfg.share_dir).joinpath(timestamp)
|
|
127
|
+
state_dir = Path(cfg.state_dir).joinpath(timestamp)
|
|
128
|
+
|
|
129
|
+
device = "cpu"
|
|
130
|
+
if torch.cuda.is_available():
|
|
131
|
+
device = "cuda"
|
|
132
|
+
elif torch.backends.mps.is_available():
|
|
133
|
+
device = "mps"
|
|
134
|
+
logging.info(f"device: {device}")
|
|
135
|
+
|
|
136
|
+
seed_everything(cfg.seed, device=device)
|
|
137
|
+
|
|
138
|
+
dataset = PartitionedCIFAR10(
|
|
139
|
+
root=dataset_root_dir,
|
|
140
|
+
path=dataset_split_dir,
|
|
141
|
+
num_clients=cfg.num_clients,
|
|
142
|
+
seed=cfg.seed,
|
|
143
|
+
partition=cfg.partition,
|
|
144
|
+
num_shards=cfg.num_shards,
|
|
145
|
+
dir_alpha=cfg.dir_alpha,
|
|
146
|
+
)
|
|
147
|
+
model_selector = FedAvgModelSelector(num_classes=10)
|
|
148
|
+
|
|
149
|
+
handler = FedAvgServerHandler(
|
|
150
|
+
model_selector=model_selector,
|
|
151
|
+
model_name=cfg.model_name,
|
|
152
|
+
dataset=dataset,
|
|
153
|
+
global_round=cfg.global_round,
|
|
154
|
+
num_clients=cfg.num_clients,
|
|
155
|
+
device=device,
|
|
156
|
+
sample_ratio=cfg.sample_ratio,
|
|
157
|
+
batch_size=cfg.batch_size,
|
|
158
|
+
)
|
|
159
|
+
trainer: (
|
|
160
|
+
FedAvgSerialClientTrainer
|
|
161
|
+
| FedAvgParallelClientTrainer
|
|
162
|
+
| FedAvgMultiThreadClientTrainer
|
|
163
|
+
| None
|
|
164
|
+
) = None
|
|
165
|
+
match cfg.execution_mode:
|
|
166
|
+
case "serial":
|
|
167
|
+
trainer = FedAvgSerialClientTrainer(
|
|
168
|
+
model_selector=model_selector,
|
|
169
|
+
model_name=cfg.model_name,
|
|
170
|
+
dataset=dataset,
|
|
171
|
+
device=device,
|
|
172
|
+
num_clients=cfg.num_clients,
|
|
173
|
+
epochs=cfg.epochs,
|
|
174
|
+
lr=cfg.lr,
|
|
175
|
+
batch_size=cfg.batch_size,
|
|
176
|
+
)
|
|
177
|
+
case "multi-process":
|
|
178
|
+
trainer = FedAvgParallelClientTrainer(
|
|
179
|
+
model_selector=model_selector,
|
|
180
|
+
model_name=cfg.model_name,
|
|
181
|
+
dataset=dataset,
|
|
182
|
+
share_dir=share_dir,
|
|
183
|
+
state_dir=state_dir,
|
|
184
|
+
seed=cfg.seed,
|
|
185
|
+
device=device,
|
|
186
|
+
num_clients=cfg.num_clients,
|
|
187
|
+
epochs=cfg.epochs,
|
|
188
|
+
lr=cfg.lr,
|
|
189
|
+
batch_size=cfg.batch_size,
|
|
190
|
+
num_parallels=cfg.num_parallels,
|
|
191
|
+
)
|
|
192
|
+
case "multi-thread":
|
|
193
|
+
trainer = FedAvgMultiThreadClientTrainer(
|
|
194
|
+
model_selector=model_selector,
|
|
195
|
+
model_name=cfg.model_name,
|
|
196
|
+
dataset=dataset,
|
|
197
|
+
device=device,
|
|
198
|
+
num_clients=cfg.num_clients,
|
|
199
|
+
epochs=cfg.epochs,
|
|
200
|
+
lr=cfg.lr,
|
|
201
|
+
batch_size=cfg.batch_size,
|
|
202
|
+
num_parallels=cfg.num_parallels,
|
|
203
|
+
seed=cfg.seed,
|
|
204
|
+
)
|
|
205
|
+
case _:
|
|
206
|
+
raise ValueError(f"Invalid execution mode: {cfg.execution_mode}")
|
|
207
|
+
pipeline = FedAvgPipeline(handler=handler, trainer=trainer)
|
|
208
|
+
try:
|
|
209
|
+
pipeline.main()
|
|
210
|
+
except KeyboardInterrupt:
|
|
211
|
+
logging.info("KeyboardInterrupt: Stopping the pipeline.")
|
|
212
|
+
except Exception as e:
|
|
213
|
+
logging.exception(f"An error occurred: {e}")
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
if __name__ == "__main__":
|
|
217
|
+
# NOTE: To use CUDA with multiprocessing, you must use the 'spawn' start method
|
|
218
|
+
mp.set_start_method("spawn")
|
|
219
|
+
|
|
220
|
+
main()
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "experimental-freethreaded"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Add your description here"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.13"
|
|
7
|
+
dependencies = [
|
|
8
|
+
"blazefl>=2.0.0.dev1",
|
|
9
|
+
"hydra-core>=1.3.2",
|
|
10
|
+
"torch>=2.7.1",
|
|
11
|
+
"torchvision>=0.22.1",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
[tool.basedpyright]
|
|
15
|
+
typeCheckingMode = "standard"
|
|
16
|
+
|
|
17
|
+
[[tool.mypy.overrides]]
|
|
18
|
+
module = ["torchvision.*"]
|
|
19
|
+
ignore_missing_imports = true
|
|
20
|
+
|
|
21
|
+
[tool.ruff.lint]
|
|
22
|
+
select = [
|
|
23
|
+
"E", # pycodestyle
|
|
24
|
+
"F", # Pyflakes
|
|
25
|
+
"UP", # pyupgrade
|
|
26
|
+
"B", # flake8-bugbear
|
|
27
|
+
"SIM", # flake8-simplify
|
|
28
|
+
"I", # isort
|
|
29
|
+
]
|
|
30
|
+
ignore = []
|
|
31
|
+
fixable = ["ALL"]
|
|
32
|
+
|
|
33
|
+
[dependency-groups]
|
|
34
|
+
dev = [
|
|
35
|
+
"mypy>=1.16.0",
|
|
36
|
+
"types-tqdm>=4.67.0.20250516",
|
|
37
|
+
]
|