blazefl 2.0.0.dev6__tar.gz → 2.0.0.dev7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/PKG-INFO +1 -1
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/config/config.yaml +1 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/main.py +20 -83
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/pyproject.toml +1 -1
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/uv.lock +4 -4
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/dataset/dataset.py +6 -7
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/pyproject.toml +1 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/algorithm/dsfl.py +23 -7
- blazefl-2.0.0.dev7/examples/step-by-step-dsfl/dataset/__init__.py +3 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/dataset/dataset.py +14 -9
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/pyproject.toml +1 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/pyproject.toml +1 -1
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/contrib/__init__.py +4 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/contrib/fedavg.py +19 -15
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/core/partitioned_dataset.py +7 -4
- blazefl-2.0.0.dev7/src/blazefl/core/partitioned_dataset.pyi +9 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/uv.lock +11 -3
- blazefl-2.0.0.dev6/examples/step-by-step-dsfl/dataset/__init__.py +0 -3
- blazefl-2.0.0.dev6/src/blazefl/core/partitioned_dataset.pyi +0 -6
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.github/FUNDING.yml +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.github/dependabot.yml +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.github/workflows/ci.yaml +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.github/workflows/publish.yaml +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.github/workflows/sphinx.yaml +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.gitignore +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.pre-commit-config.yaml +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.python-version +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/CODE_OF_CONDUCT.md +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/LICENSE +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/Makefile +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/README.md +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/Makefile +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/imgs/architecture.png +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/imgs/benchmark_cnn.png +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/imgs/benchmark_resnet18.png +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/imgs/logo.svg +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/imgs/logo_square.svg +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/imgs/ogp.png +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/make.bat +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/_static/favicon.ico +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/_static/logo.png +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/_static/logo_square.png +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/_templates/autosummary/class.rst +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/_templates/autosummary/module.rst +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/benchmark.rst +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/conf.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/contribute.rst +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/example.rst +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/index.rst +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/install.rst +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/overview.rst +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/reference.rst +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/.gitignore +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/.python-version +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/Makefile +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/README.md +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/dataset/__init__.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/dataset/dataset.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/models/__init__.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/models/selector.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/.gitignore +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/.python-version +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/Makefile +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/README.md +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/config/config.yaml +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/dataset/__init__.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/main.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/models/__init__.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/models/selector.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/.gitignore +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/.python-version +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/Makefile +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/README.md +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/algorithm/__init__.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/config/algorithm/dsfl.yaml +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/config/config.yaml +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/dataset/functional.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/main.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/models/__init__.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/models/cnn.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/models/selector.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/__init__.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/core/__init__.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/core/__init__.pyi +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/core/client_trainer.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/core/client_trainer.pyi +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/core/model_selector.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/core/model_selector.pyi +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/core/server_handler.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/core/server_handler.pyi +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/py.typed +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/__init__.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/__init__.pyi +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/dataset.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/dataset.pyi +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/ipc.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/ipc.pyi +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/seed.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/seed.pyi +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/serialize.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/serialize.pyi +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/__init__.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/conftest.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_contrib/__init__.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_contrib/test_fedavg.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_core/__init__.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_core/test_client_trainer.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_core/test_model_selector.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_core/test_partitioned_dataset.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_utils/__init__.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_utils/test_dataset.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_utils/test_seed.py +0 -0
- {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_utils/test_serialize.py +0 -0
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import logging
|
|
2
|
-
from copy import deepcopy
|
|
3
2
|
from datetime import datetime
|
|
4
3
|
from pathlib import Path
|
|
5
4
|
|
|
@@ -7,12 +6,13 @@ import hydra
|
|
|
7
6
|
import torch
|
|
8
7
|
import torch.multiprocessing as mp
|
|
9
8
|
from blazefl.contrib import (
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
9
|
+
FedAvgBaseClientTrainer,
|
|
10
|
+
FedAvgBaseServerHandler,
|
|
11
|
+
FedAvgProcessPoolClientTrainer,
|
|
12
|
+
)
|
|
13
|
+
from blazefl.contrib.fedavg import (
|
|
14
|
+
FedAvgThreadPoolClientTrainer,
|
|
13
15
|
)
|
|
14
|
-
from blazefl.contrib.fedavg import FedAvgDownlinkPackage, FedAvgUplinkPackage
|
|
15
|
-
from blazefl.core import ModelSelector, MultiThreadClientTrainer, PartitionedDataset
|
|
16
16
|
from blazefl.utils import seed_everything
|
|
17
17
|
from omegaconf import DictConfig, OmegaConf
|
|
18
18
|
|
|
@@ -20,75 +20,13 @@ from dataset import PartitionedCIFAR10
|
|
|
20
20
|
from models import FedAvgModelSelector
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
class FedAvgMultiThreadClientTrainer(
|
|
24
|
-
MultiThreadClientTrainer[
|
|
25
|
-
FedAvgUplinkPackage,
|
|
26
|
-
FedAvgDownlinkPackage,
|
|
27
|
-
]
|
|
28
|
-
):
|
|
29
|
-
def __init__(
|
|
30
|
-
self,
|
|
31
|
-
model_selector: ModelSelector,
|
|
32
|
-
model_name: str,
|
|
33
|
-
dataset: PartitionedDataset,
|
|
34
|
-
device: str,
|
|
35
|
-
num_clients: int,
|
|
36
|
-
epochs: int,
|
|
37
|
-
batch_size: int,
|
|
38
|
-
lr: float,
|
|
39
|
-
seed: int,
|
|
40
|
-
num_parallels: int,
|
|
41
|
-
) -> None:
|
|
42
|
-
self.num_parallels = num_parallels
|
|
43
|
-
self.device = device
|
|
44
|
-
if self.device == "cuda":
|
|
45
|
-
self.device_count = torch.cuda.device_count()
|
|
46
|
-
self.cache: list[FedAvgUplinkPackage] = []
|
|
47
|
-
|
|
48
|
-
self.model_selector = model_selector
|
|
49
|
-
self.model_name = model_name
|
|
50
|
-
self.dataset = dataset
|
|
51
|
-
self.epochs = epochs
|
|
52
|
-
self.batch_size = batch_size
|
|
53
|
-
self.lr = lr
|
|
54
|
-
self.num_clients = num_clients
|
|
55
|
-
self.seed = seed
|
|
56
|
-
|
|
57
|
-
def process_client(
|
|
58
|
-
self,
|
|
59
|
-
cid: int,
|
|
60
|
-
device: str,
|
|
61
|
-
payload: FedAvgDownlinkPackage,
|
|
62
|
-
) -> FedAvgUplinkPackage:
|
|
63
|
-
model = self.model_selector.select_model(self.model_name)
|
|
64
|
-
train_loader = self.dataset.get_dataloader(
|
|
65
|
-
type_="train",
|
|
66
|
-
cid=cid,
|
|
67
|
-
batch_size=self.batch_size,
|
|
68
|
-
)
|
|
69
|
-
package = FedAvgParallelClientTrainer.train(
|
|
70
|
-
model=model,
|
|
71
|
-
model_parameters=payload.model_parameters,
|
|
72
|
-
train_loader=train_loader,
|
|
73
|
-
device=device,
|
|
74
|
-
epochs=self.epochs,
|
|
75
|
-
lr=self.lr,
|
|
76
|
-
)
|
|
77
|
-
return package
|
|
78
|
-
|
|
79
|
-
def uplink_package(self) -> list[FedAvgUplinkPackage]:
|
|
80
|
-
package = deepcopy(self.cache)
|
|
81
|
-
self.cache = []
|
|
82
|
-
return package
|
|
83
|
-
|
|
84
|
-
|
|
85
23
|
class FedAvgPipeline:
|
|
86
24
|
def __init__(
|
|
87
25
|
self,
|
|
88
|
-
handler:
|
|
89
|
-
trainer:
|
|
90
|
-
|
|
|
91
|
-
|
|
|
26
|
+
handler: FedAvgBaseServerHandler,
|
|
27
|
+
trainer: FedAvgBaseClientTrainer
|
|
28
|
+
| FedAvgProcessPoolClientTrainer
|
|
29
|
+
| FedAvgThreadPoolClientTrainer,
|
|
92
30
|
) -> None:
|
|
93
31
|
self.handler = handler
|
|
94
32
|
self.trainer = trainer
|
|
@@ -145,7 +83,7 @@ def main(cfg: DictConfig):
|
|
|
145
83
|
)
|
|
146
84
|
model_selector = FedAvgModelSelector(num_classes=10)
|
|
147
85
|
|
|
148
|
-
handler =
|
|
86
|
+
handler = FedAvgBaseServerHandler(
|
|
149
87
|
model_selector=model_selector,
|
|
150
88
|
model_name=cfg.model_name,
|
|
151
89
|
dataset=dataset,
|
|
@@ -156,14 +94,14 @@ def main(cfg: DictConfig):
|
|
|
156
94
|
batch_size=cfg.batch_size,
|
|
157
95
|
)
|
|
158
96
|
trainer: (
|
|
159
|
-
|
|
160
|
-
|
|
|
161
|
-
|
|
|
97
|
+
FedAvgBaseClientTrainer
|
|
98
|
+
| FedAvgProcessPoolClientTrainer
|
|
99
|
+
| FedAvgThreadPoolClientTrainer
|
|
162
100
|
| None
|
|
163
101
|
) = None
|
|
164
102
|
match cfg.execution_mode:
|
|
165
|
-
case "
|
|
166
|
-
trainer =
|
|
103
|
+
case "single-thread":
|
|
104
|
+
trainer = FedAvgBaseClientTrainer(
|
|
167
105
|
model_selector=model_selector,
|
|
168
106
|
model_name=cfg.model_name,
|
|
169
107
|
dataset=dataset,
|
|
@@ -174,7 +112,7 @@ def main(cfg: DictConfig):
|
|
|
174
112
|
batch_size=cfg.batch_size,
|
|
175
113
|
)
|
|
176
114
|
case "multi-process":
|
|
177
|
-
trainer =
|
|
115
|
+
trainer = FedAvgProcessPoolClientTrainer(
|
|
178
116
|
model_selector=model_selector,
|
|
179
117
|
model_name=cfg.model_name,
|
|
180
118
|
dataset=dataset,
|
|
@@ -187,9 +125,10 @@ def main(cfg: DictConfig):
|
|
|
187
125
|
lr=cfg.lr,
|
|
188
126
|
batch_size=cfg.batch_size,
|
|
189
127
|
num_parallels=cfg.num_parallels,
|
|
128
|
+
ipc_mode=cfg.ipc_mode,
|
|
190
129
|
)
|
|
191
130
|
case "multi-thread":
|
|
192
|
-
trainer =
|
|
131
|
+
trainer = FedAvgThreadPoolClientTrainer(
|
|
193
132
|
model_selector=model_selector,
|
|
194
133
|
model_name=cfg.model_name,
|
|
195
134
|
dataset=dataset,
|
|
@@ -207,9 +146,7 @@ def main(cfg: DictConfig):
|
|
|
207
146
|
try:
|
|
208
147
|
pipeline.main()
|
|
209
148
|
except KeyboardInterrupt:
|
|
210
|
-
logging.info("KeyboardInterrupt
|
|
211
|
-
except Exception as e:
|
|
212
|
-
logging.exception(f"An error occurred: {e}")
|
|
149
|
+
logging.info("KeyboardInterrupt")
|
|
213
150
|
|
|
214
151
|
|
|
215
152
|
if __name__ == "__main__":
|
|
@@ -10,16 +10,16 @@ sdist = { url = "https://files.pythonhosted.org/packages/3e/38/7859ff46355f76f8d
|
|
|
10
10
|
|
|
11
11
|
[[package]]
|
|
12
12
|
name = "blazefl"
|
|
13
|
-
version = "2.0.0.
|
|
13
|
+
version = "2.0.0.dev6"
|
|
14
14
|
source = { registry = "https://pypi.org/simple" }
|
|
15
15
|
dependencies = [
|
|
16
16
|
{ name = "numpy" },
|
|
17
17
|
{ name = "torch" },
|
|
18
18
|
{ name = "tqdm" },
|
|
19
19
|
]
|
|
20
|
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
|
20
|
+
sdist = { url = "https://files.pythonhosted.org/packages/d7/76/d247942d051447dc1c6dbc25b2d55b4cdf146cc0ebd3594c0c33628f465c/blazefl-2.0.0.dev6.tar.gz", hash = "sha256:d4875f03872917dfd99b3a70ac2c6efc3483ee0e42e484d7328e0810fe5141b9", size = 615380, upload_time = "2025-06-16T16:26:50.929Z" }
|
|
21
21
|
wheels = [
|
|
22
|
-
{ url = "https://files.pythonhosted.org/packages/
|
|
22
|
+
{ url = "https://files.pythonhosted.org/packages/c8/30/eba32caaed89d04c925f676d48fb50bad7bed3ce2c86c6ebcaf2eb92aff1/blazefl-2.0.0.dev6-py3-none-any.whl", hash = "sha256:e38c3739be9ef2bc7bb8ba213da3319e4be8f9ea9f27ecd661cbc70c0cedd655", size = 26215, upload_time = "2025-06-16T16:26:49.512Z" },
|
|
23
23
|
]
|
|
24
24
|
|
|
25
25
|
[[package]]
|
|
@@ -50,7 +50,7 @@ dev = [
|
|
|
50
50
|
|
|
51
51
|
[package.metadata]
|
|
52
52
|
requires-dist = [
|
|
53
|
-
{ name = "blazefl", specifier = ">=2.0.0.
|
|
53
|
+
{ name = "blazefl", specifier = ">=2.0.0.dev6" },
|
|
54
54
|
{ name = "hydra-core", specifier = ">=1.3.2" },
|
|
55
55
|
{ name = "torch", specifier = ">=2.7.1" },
|
|
56
56
|
{ name = "torchvision", specifier = ">=0.22.1" },
|
|
@@ -3,6 +3,7 @@ from pathlib import Path
|
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
5
|
import torchvision
|
|
6
|
+
from blazefl.contrib import FedAvgPartitionType
|
|
6
7
|
from blazefl.core import PartitionedDataset
|
|
7
8
|
from blazefl.utils import FilteredDataset
|
|
8
9
|
from torch.utils.data import DataLoader, Dataset
|
|
@@ -15,7 +16,7 @@ from dataset.functional import (
|
|
|
15
16
|
)
|
|
16
17
|
|
|
17
18
|
|
|
18
|
-
class PartitionedCIFAR10(PartitionedDataset):
|
|
19
|
+
class PartitionedCIFAR10(PartitionedDataset[FedAvgPartitionType]):
|
|
19
20
|
def __init__(
|
|
20
21
|
self,
|
|
21
22
|
root: Path,
|
|
@@ -107,24 +108,22 @@ class PartitionedCIFAR10(PartitionedDataset):
|
|
|
107
108
|
self.path.joinpath("test.pkl"),
|
|
108
109
|
)
|
|
109
110
|
|
|
110
|
-
def get_dataset(self, type_:
|
|
111
|
+
def get_dataset(self, type_: FedAvgPartitionType, cid: int | None) -> Dataset:
|
|
111
112
|
match type_:
|
|
112
|
-
case
|
|
113
|
+
case FedAvgPartitionType.TRAIN:
|
|
113
114
|
dataset = torch.load(
|
|
114
115
|
self.path.joinpath(type_, f"{cid}.pkl"),
|
|
115
116
|
weights_only=False,
|
|
116
117
|
)
|
|
117
|
-
case
|
|
118
|
+
case FedAvgPartitionType.TEST:
|
|
118
119
|
dataset = torch.load(
|
|
119
120
|
self.path.joinpath(f"{type_}.pkl"), weights_only=False
|
|
120
121
|
)
|
|
121
|
-
case _:
|
|
122
|
-
raise ValueError(f"Invalid dataset type: {type_}")
|
|
123
122
|
assert isinstance(dataset, Dataset)
|
|
124
123
|
return dataset
|
|
125
124
|
|
|
126
125
|
def get_dataloader(
|
|
127
|
-
self, type_:
|
|
126
|
+
self, type_: FedAvgPartitionType, cid: int | None, batch_size: int | None = None
|
|
128
127
|
) -> DataLoader:
|
|
129
128
|
dataset = self.get_dataset(type_, cid)
|
|
130
129
|
assert isinstance(dataset, Sized)
|
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
import random
|
|
2
|
+
import threading
|
|
2
3
|
from collections import defaultdict
|
|
3
4
|
from copy import deepcopy
|
|
4
5
|
from dataclasses import dataclass
|
|
5
6
|
from pathlib import Path
|
|
6
7
|
|
|
7
8
|
import torch
|
|
9
|
+
import torch.multiprocessing as mp
|
|
8
10
|
import torch.nn.functional as F
|
|
9
11
|
from blazefl.core import (
|
|
10
12
|
BaseServerHandler,
|
|
@@ -17,7 +19,7 @@ from blazefl.utils import (
|
|
|
17
19
|
)
|
|
18
20
|
from torch.utils.data import DataLoader, Subset
|
|
19
21
|
|
|
20
|
-
from dataset import DSFLPartitionedDataset
|
|
22
|
+
from dataset import DSFLPartitionedDataset, DSFLPartitionType
|
|
21
23
|
from models import DSFLModelSelector
|
|
22
24
|
|
|
23
25
|
|
|
@@ -125,6 +127,7 @@ class DSFLBaseServerHandler(BaseServerHandler[DSFLUplinkPackage, DSFLDownlinkPac
|
|
|
125
127
|
self.kd_epochs,
|
|
126
128
|
self.kd_batch_size,
|
|
127
129
|
self.device,
|
|
130
|
+
stop_event=None,
|
|
128
131
|
)
|
|
129
132
|
|
|
130
133
|
self.global_soft_labels = torch.stack(global_soft_labels)
|
|
@@ -140,10 +143,11 @@ class DSFLBaseServerHandler(BaseServerHandler[DSFLUplinkPackage, DSFLDownlinkPac
|
|
|
140
143
|
kd_epochs: int,
|
|
141
144
|
kd_batch_size: int,
|
|
142
145
|
device: str,
|
|
146
|
+
stop_event: threading.Event | None,
|
|
143
147
|
) -> None:
|
|
144
148
|
model.to(device)
|
|
145
149
|
model.train()
|
|
146
|
-
open_dataset = dataset.get_dataset(type_=
|
|
150
|
+
open_dataset = dataset.get_dataset(type_=DSFLPartitionType.OPEN, cid=None)
|
|
147
151
|
open_loader = DataLoader(
|
|
148
152
|
Subset(open_dataset, global_indices),
|
|
149
153
|
batch_size=kd_batch_size,
|
|
@@ -156,6 +160,8 @@ class DSFLBaseServerHandler(BaseServerHandler[DSFLUplinkPackage, DSFLDownlinkPac
|
|
|
156
160
|
batch_size=kd_batch_size,
|
|
157
161
|
)
|
|
158
162
|
for _ in range(kd_epochs):
|
|
163
|
+
if stop_event is not None and stop_event.is_set():
|
|
164
|
+
break
|
|
159
165
|
for data, soft_label in zip(
|
|
160
166
|
open_loader, global_soft_label_loader, strict=True
|
|
161
167
|
):
|
|
@@ -208,7 +214,7 @@ class DSFLBaseServerHandler(BaseServerHandler[DSFLUplinkPackage, DSFLDownlinkPac
|
|
|
208
214
|
server_loss, server_acc = DSFLBaseServerHandler.evaulate(
|
|
209
215
|
self.model,
|
|
210
216
|
self.dataset.get_dataloader(
|
|
211
|
-
type_=
|
|
217
|
+
type_=DSFLPartitionType.TEST,
|
|
212
218
|
cid=None,
|
|
213
219
|
batch_size=self.kd_batch_size,
|
|
214
220
|
),
|
|
@@ -300,6 +306,8 @@ class DSFLProcessPoolClientTrainer(
|
|
|
300
306
|
self.num_clients = num_clients
|
|
301
307
|
self.seed = seed
|
|
302
308
|
self.ipc_mode = "storage"
|
|
309
|
+
self.manager = mp.Manager()
|
|
310
|
+
self.stop_event = self.manager.Event()
|
|
303
311
|
|
|
304
312
|
if self.device == "cuda":
|
|
305
313
|
self.device_count = torch.cuda.device_count()
|
|
@@ -309,6 +317,7 @@ class DSFLProcessPoolClientTrainer(
|
|
|
309
317
|
config: DSFLClientConfig | Path,
|
|
310
318
|
payload: DSFLDownlinkPackage | Path,
|
|
311
319
|
device: str,
|
|
320
|
+
stop_event: threading.Event,
|
|
312
321
|
) -> Path:
|
|
313
322
|
assert isinstance(config, Path) and isinstance(payload, Path)
|
|
314
323
|
config_path, payload_path = config, payload
|
|
@@ -334,7 +343,7 @@ class DSFLProcessPoolClientTrainer(
|
|
|
334
343
|
seed_everything(c.seed, device=device)
|
|
335
344
|
|
|
336
345
|
# Distill
|
|
337
|
-
open_dataset = c.dataset.get_dataset(type_=
|
|
346
|
+
open_dataset = c.dataset.get_dataset(type_=DSFLPartitionType.OPEN, cid=None)
|
|
338
347
|
if p.indices is not None and p.soft_labels is not None:
|
|
339
348
|
global_soft_labels = list(torch.unbind(p.soft_labels, dim=0))
|
|
340
349
|
global_indices = p.indices.tolist()
|
|
@@ -349,11 +358,12 @@ class DSFLProcessPoolClientTrainer(
|
|
|
349
358
|
kd_epochs=c.kd_epochs,
|
|
350
359
|
kd_batch_size=c.kd_batch_size,
|
|
351
360
|
device=device,
|
|
361
|
+
stop_event=stop_event,
|
|
352
362
|
)
|
|
353
363
|
|
|
354
364
|
# Train
|
|
355
365
|
train_loader = c.dataset.get_dataloader(
|
|
356
|
-
type_=
|
|
366
|
+
type_=DSFLPartitionType.TRAIN,
|
|
357
367
|
cid=c.cid,
|
|
358
368
|
batch_size=c.batch_size,
|
|
359
369
|
)
|
|
@@ -363,6 +373,7 @@ class DSFLProcessPoolClientTrainer(
|
|
|
363
373
|
train_loader=train_loader,
|
|
364
374
|
device=device,
|
|
365
375
|
epochs=c.epochs,
|
|
376
|
+
stop_event=stop_event,
|
|
366
377
|
)
|
|
367
378
|
|
|
368
379
|
# Predict
|
|
@@ -378,7 +389,7 @@ class DSFLProcessPoolClientTrainer(
|
|
|
378
389
|
|
|
379
390
|
# Evaluate
|
|
380
391
|
test_loader = c.dataset.get_dataloader(
|
|
381
|
-
type_=
|
|
392
|
+
type_=DSFLPartitionType.TEST,
|
|
382
393
|
cid=c.cid,
|
|
383
394
|
batch_size=c.batch_size,
|
|
384
395
|
)
|
|
@@ -411,12 +422,15 @@ class DSFLProcessPoolClientTrainer(
|
|
|
411
422
|
train_loader: DataLoader,
|
|
412
423
|
device: str,
|
|
413
424
|
epochs: int,
|
|
425
|
+
stop_event: threading.Event,
|
|
414
426
|
) -> None:
|
|
415
427
|
model.to(device)
|
|
416
428
|
model.train()
|
|
417
429
|
criterion = torch.nn.CrossEntropyLoss()
|
|
418
430
|
|
|
419
431
|
for _ in range(epochs):
|
|
432
|
+
if stop_event.is_set():
|
|
433
|
+
break
|
|
420
434
|
for data, target in train_loader:
|
|
421
435
|
data = data.to(device)
|
|
422
436
|
target = target.to(device)
|
|
@@ -431,7 +445,9 @@ class DSFLProcessPoolClientTrainer(
|
|
|
431
445
|
|
|
432
446
|
@staticmethod
|
|
433
447
|
def predict(
|
|
434
|
-
model: torch.nn.Module,
|
|
448
|
+
model: torch.nn.Module,
|
|
449
|
+
open_loader: DataLoader,
|
|
450
|
+
device: str,
|
|
435
451
|
) -> torch.Tensor:
|
|
436
452
|
model.to(device)
|
|
437
453
|
model.eval()
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from collections.abc import Sized
|
|
2
|
+
from enum import StrEnum
|
|
2
3
|
from pathlib import Path
|
|
3
4
|
|
|
4
5
|
import numpy as np
|
|
@@ -15,7 +16,13 @@ from dataset.functional import (
|
|
|
15
16
|
)
|
|
16
17
|
|
|
17
18
|
|
|
18
|
-
class
|
|
19
|
+
class DSFLPartitionType(StrEnum):
|
|
20
|
+
TRAIN = "train"
|
|
21
|
+
OPEN = "open"
|
|
22
|
+
TEST = "test"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class DSFLPartitionedDataset(PartitionedDataset[DSFLPartitionType]):
|
|
19
26
|
def __init__(
|
|
20
27
|
self,
|
|
21
28
|
root: Path,
|
|
@@ -68,7 +75,7 @@ class DSFLPartitionedDataset(PartitionedDataset):
|
|
|
68
75
|
train=False,
|
|
69
76
|
download=True,
|
|
70
77
|
)
|
|
71
|
-
for type_ in [
|
|
78
|
+
for type_ in [ds.value for ds in DSFLPartitionType]:
|
|
72
79
|
self.path.joinpath(type_).mkdir(parents=True)
|
|
73
80
|
|
|
74
81
|
match self.partition:
|
|
@@ -141,19 +148,19 @@ class DSFLPartitionedDataset(PartitionedDataset):
|
|
|
141
148
|
self.path.joinpath("test", "default.pkl"),
|
|
142
149
|
)
|
|
143
150
|
|
|
144
|
-
def get_dataset(self, type_:
|
|
151
|
+
def get_dataset(self, type_: DSFLPartitionType, cid: int | None) -> Dataset:
|
|
145
152
|
match type_:
|
|
146
|
-
case
|
|
153
|
+
case DSFLPartitionType.TRAIN:
|
|
147
154
|
dataset = torch.load(
|
|
148
155
|
self.path.joinpath(type_, f"{cid}.pkl"),
|
|
149
156
|
weights_only=False,
|
|
150
157
|
)
|
|
151
|
-
case
|
|
158
|
+
case DSFLPartitionType.OPEN:
|
|
152
159
|
dataset = torch.load(
|
|
153
160
|
self.path.joinpath(f"{type_}.pkl"),
|
|
154
161
|
weights_only=False,
|
|
155
162
|
)
|
|
156
|
-
case
|
|
163
|
+
case DSFLPartitionType.TEST:
|
|
157
164
|
if cid is not None:
|
|
158
165
|
dataset = torch.load(
|
|
159
166
|
self.path.joinpath(type_, f"{cid}.pkl"),
|
|
@@ -163,13 +170,11 @@ class DSFLPartitionedDataset(PartitionedDataset):
|
|
|
163
170
|
dataset = torch.load(
|
|
164
171
|
self.path.joinpath(type_, "default.pkl"), weights_only=False
|
|
165
172
|
)
|
|
166
|
-
case _:
|
|
167
|
-
raise ValueError(f"Invalid dataset type: {type_}")
|
|
168
173
|
assert isinstance(dataset, Dataset)
|
|
169
174
|
return dataset
|
|
170
175
|
|
|
171
176
|
def get_dataloader(
|
|
172
|
-
self, type_:
|
|
177
|
+
self, type_: DSFLPartitionType, cid: int | None, batch_size: int | None = None
|
|
173
178
|
) -> DataLoader:
|
|
174
179
|
dataset = self.get_dataset(type_, cid)
|
|
175
180
|
assert isinstance(dataset, Sized)
|
|
@@ -9,6 +9,8 @@ from blazefl.contrib.fedavg import (
|
|
|
9
9
|
FedAvgBaseClientTrainer,
|
|
10
10
|
FedAvgBaseServerHandler,
|
|
11
11
|
FedAvgDownlinkPackage,
|
|
12
|
+
FedAvgPartitionedDataset,
|
|
13
|
+
FedAvgPartitionType,
|
|
12
14
|
FedAvgProcessPoolClientTrainer,
|
|
13
15
|
FedAvgThreadPoolClientTrainer,
|
|
14
16
|
FedAvgUplinkPackage,
|
|
@@ -21,4 +23,6 @@ __all__ = [
|
|
|
21
23
|
"FedAvgThreadPoolClientTrainer",
|
|
22
24
|
"FedAvgUplinkPackage",
|
|
23
25
|
"FedAvgDownlinkPackage",
|
|
26
|
+
"FedAvgPartitionType",
|
|
27
|
+
"FedAvgPartitionedDataset",
|
|
24
28
|
]
|
|
@@ -2,6 +2,7 @@ import random
|
|
|
2
2
|
import threading
|
|
3
3
|
from copy import deepcopy
|
|
4
4
|
from dataclasses import dataclass
|
|
5
|
+
from enum import StrEnum
|
|
5
6
|
from pathlib import Path
|
|
6
7
|
from typing import Literal
|
|
7
8
|
|
|
@@ -57,8 +58,16 @@ class FedAvgDownlinkPackage:
|
|
|
57
58
|
model_parameters: torch.Tensor
|
|
58
59
|
|
|
59
60
|
|
|
61
|
+
class FedAvgPartitionType(StrEnum):
|
|
62
|
+
TRAIN = "train"
|
|
63
|
+
TEST = "test"
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
FedAvgPartitionedDataset = PartitionedDataset[FedAvgPartitionType]
|
|
67
|
+
|
|
68
|
+
|
|
60
69
|
class FedAvgBaseServerHandler(
|
|
61
|
-
BaseServerHandler[FedAvgUplinkPackage, FedAvgDownlinkPackage]
|
|
70
|
+
BaseServerHandler[FedAvgUplinkPackage, FedAvgDownlinkPackage],
|
|
62
71
|
):
|
|
63
72
|
"""
|
|
64
73
|
Server-side handler for the Federated Averaging (FedAvg) algorithm.
|
|
@@ -83,7 +92,7 @@ class FedAvgBaseServerHandler(
|
|
|
83
92
|
self,
|
|
84
93
|
model_selector: ModelSelector,
|
|
85
94
|
model_name: str,
|
|
86
|
-
dataset:
|
|
95
|
+
dataset: FedAvgPartitionedDataset,
|
|
87
96
|
global_round: int,
|
|
88
97
|
num_clients: int,
|
|
89
98
|
sample_ratio: float,
|
|
@@ -241,7 +250,7 @@ class FedAvgBaseServerHandler(
|
|
|
241
250
|
server_loss, server_acc = FedAvgBaseServerHandler.evaluate(
|
|
242
251
|
self.model,
|
|
243
252
|
self.dataset.get_dataloader(
|
|
244
|
-
type_=
|
|
253
|
+
type_=FedAvgPartitionType.TEST,
|
|
245
254
|
cid=None,
|
|
246
255
|
batch_size=self.batch_size,
|
|
247
256
|
),
|
|
@@ -289,7 +298,7 @@ class FedAvgBaseClientTrainer(
|
|
|
289
298
|
self,
|
|
290
299
|
model_selector: ModelSelector,
|
|
291
300
|
model_name: str,
|
|
292
|
-
dataset:
|
|
301
|
+
dataset: FedAvgPartitionedDataset,
|
|
293
302
|
device: str,
|
|
294
303
|
num_clients: int,
|
|
295
304
|
epochs: int,
|
|
@@ -339,14 +348,9 @@ class FedAvgBaseClientTrainer(
|
|
|
339
348
|
model_parameters = payload.model_parameters
|
|
340
349
|
for cid in tqdm(cid_list, desc="Client", leave=False):
|
|
341
350
|
data_loader = self.dataset.get_dataloader(
|
|
342
|
-
type_=
|
|
351
|
+
type_=FedAvgPartitionType.TRAIN, cid=cid, batch_size=self.batch_size
|
|
343
352
|
)
|
|
344
353
|
pack = self.train(model_parameters, data_loader)
|
|
345
|
-
val_loader = self.dataset.get_dataloader(
|
|
346
|
-
type_="val", cid=cid, batch_size=self.batch_size
|
|
347
|
-
)
|
|
348
|
-
loss, acc = self.evaluate(val_loader)
|
|
349
|
-
pack.metadata = {"loss": loss, "acc": acc}
|
|
350
354
|
self.cache.append(pack)
|
|
351
355
|
|
|
352
356
|
def train(
|
|
@@ -457,7 +461,7 @@ class FedAvgClientConfig:
|
|
|
457
461
|
|
|
458
462
|
model_selector: ModelSelector
|
|
459
463
|
model_name: str
|
|
460
|
-
dataset:
|
|
464
|
+
dataset: FedAvgPartitionedDataset
|
|
461
465
|
epochs: int
|
|
462
466
|
batch_size: int
|
|
463
467
|
lr: float
|
|
@@ -501,7 +505,7 @@ class FedAvgProcessPoolClientTrainer(
|
|
|
501
505
|
model_name: str,
|
|
502
506
|
share_dir: Path,
|
|
503
507
|
state_dir: Path,
|
|
504
|
-
dataset:
|
|
508
|
+
dataset: FedAvgPartitionedDataset,
|
|
505
509
|
device: str,
|
|
506
510
|
num_clients: int,
|
|
507
511
|
epochs: int,
|
|
@@ -614,7 +618,7 @@ class FedAvgProcessPoolClientTrainer(
|
|
|
614
618
|
|
|
615
619
|
model = config.model_selector.select_model(config.model_name)
|
|
616
620
|
train_loader = config.dataset.get_dataloader(
|
|
617
|
-
type_=
|
|
621
|
+
type_=FedAvgPartitionType.TRAIN,
|
|
618
622
|
cid=config.cid,
|
|
619
623
|
batch_size=config.batch_size,
|
|
620
624
|
)
|
|
@@ -739,7 +743,7 @@ class FedAvgThreadPoolClientTrainer(
|
|
|
739
743
|
self,
|
|
740
744
|
model_selector: ModelSelector,
|
|
741
745
|
model_name: str,
|
|
742
|
-
dataset:
|
|
746
|
+
dataset: FedAvgPartitionedDataset,
|
|
743
747
|
device: str,
|
|
744
748
|
num_clients: int,
|
|
745
749
|
epochs: int,
|
|
@@ -773,7 +777,7 @@ class FedAvgThreadPoolClientTrainer(
|
|
|
773
777
|
) -> FedAvgUplinkPackage:
|
|
774
778
|
model = self.model_selector.select_model(self.model_name)
|
|
775
779
|
train_loader = self.dataset.get_dataloader(
|
|
776
|
-
type_=
|
|
780
|
+
type_=FedAvgPartitionType.TRAIN,
|
|
777
781
|
cid=cid,
|
|
778
782
|
batch_size=self.batch_size,
|
|
779
783
|
)
|
|
@@ -1,9 +1,12 @@
|
|
|
1
|
-
from typing import Protocol
|
|
1
|
+
from enum import StrEnum
|
|
2
|
+
from typing import Protocol, TypeVar
|
|
2
3
|
|
|
3
4
|
from torch.utils.data import DataLoader, Dataset
|
|
4
5
|
|
|
6
|
+
PartitionType = TypeVar("PartitionType", bound=StrEnum, contravariant=True)
|
|
5
7
|
|
|
6
|
-
|
|
8
|
+
|
|
9
|
+
class PartitionedDataset(Protocol[PartitionType]):
|
|
7
10
|
"""
|
|
8
11
|
Abstract base class for partitioned datasets in federated learning.
|
|
9
12
|
|
|
@@ -14,7 +17,7 @@ class PartitionedDataset(Protocol):
|
|
|
14
17
|
NotImplementedError: If the methods are not implemented in a subclass.
|
|
15
18
|
"""
|
|
16
19
|
|
|
17
|
-
def get_dataset(self, type_: str, cid: int | None) -> Dataset:
|
|
20
|
+
def get_dataset(self, type_: PartitionType, cid: int | None) -> Dataset:
|
|
18
21
|
"""
|
|
19
22
|
Retrieve a dataset for a specific type and client ID.
|
|
20
23
|
|
|
@@ -28,7 +31,7 @@ class PartitionedDataset(Protocol):
|
|
|
28
31
|
...
|
|
29
32
|
|
|
30
33
|
def get_dataloader(
|
|
31
|
-
self, type_: str, cid: int | None, batch_size: int | None
|
|
34
|
+
self, type_: PartitionType, cid: int | None, batch_size: int | None
|
|
32
35
|
) -> DataLoader:
|
|
33
36
|
"""
|
|
34
37
|
Retrieve a DataLoader for a specific type, client ID, and batch size.
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from enum import StrEnum
|
|
2
|
+
from torch.utils.data import DataLoader, Dataset
|
|
3
|
+
from typing import Protocol, TypeVar
|
|
4
|
+
|
|
5
|
+
PartitionType = TypeVar('PartitionType', bound=StrEnum, contravariant=True)
|
|
6
|
+
|
|
7
|
+
class PartitionedDataset(Protocol[PartitionType]):
|
|
8
|
+
def get_dataset(self, type_: PartitionType, cid: int | None) -> Dataset: ...
|
|
9
|
+
def get_dataloader(self, type_: PartitionType, cid: int | None, batch_size: int | None) -> DataLoader: ...
|
|
@@ -82,7 +82,7 @@ wheels = [
|
|
|
82
82
|
|
|
83
83
|
[[package]]
|
|
84
84
|
name = "blazefl"
|
|
85
|
-
version = "2.0.0.dev6"
|
|
85
|
+
version = "2.0.0.dev7"
|
|
86
86
|
source = { editable = "." }
|
|
87
87
|
dependencies = [
|
|
88
88
|
{ name = "numpy" },
|
|
@@ -828,6 +828,7 @@ dependencies = [
|
|
|
828
828
|
[package.dev-dependencies]
|
|
829
829
|
dev = [
|
|
830
830
|
{ name = "mypy" },
|
|
831
|
+
{ name = "pre-commit" },
|
|
831
832
|
]
|
|
832
833
|
|
|
833
834
|
[package.metadata]
|
|
@@ -839,7 +840,10 @@ requires-dist = [
|
|
|
839
840
|
]
|
|
840
841
|
|
|
841
842
|
[package.metadata.requires-dev]
|
|
842
|
-
dev = [{ name = "mypy", specifier = ">=1.13.0" }]
|
|
843
|
+
dev = [
|
|
844
|
+
{ name = "mypy", specifier = ">=1.13.0" },
|
|
845
|
+
{ name = "pre-commit", specifier = ">=4.2.0" },
|
|
846
|
+
]
|
|
843
847
|
|
|
844
848
|
[[package]]
|
|
845
849
|
name = "requests"
|
|
@@ -1063,6 +1067,7 @@ dependencies = [
|
|
|
1063
1067
|
[package.dev-dependencies]
|
|
1064
1068
|
dev = [
|
|
1065
1069
|
{ name = "mypy" },
|
|
1070
|
+
{ name = "pre-commit" },
|
|
1066
1071
|
]
|
|
1067
1072
|
|
|
1068
1073
|
[package.metadata]
|
|
@@ -1074,7 +1079,10 @@ requires-dist = [
|
|
|
1074
1079
|
]
|
|
1075
1080
|
|
|
1076
1081
|
[package.metadata.requires-dev]
|
|
1077
|
-
dev = [{ name = "mypy", specifier = ">=1.13.0" }]
|
|
1082
|
+
dev = [
|
|
1083
|
+
{ name = "mypy", specifier = ">=1.13.0" },
|
|
1084
|
+
{ name = "pre-commit", specifier = ">=4.2.0" },
|
|
1085
|
+
]
|
|
1078
1086
|
|
|
1079
1087
|
[[package]]
|
|
1080
1088
|
name = "sympy"
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
from torch.utils.data import DataLoader, Dataset
|
|
2
|
-
from typing import Protocol
|
|
3
|
-
|
|
4
|
-
class PartitionedDataset(Protocol):
|
|
5
|
-
def get_dataset(self, type_: str, cid: int | None) -> Dataset: ...
|
|
6
|
-
def get_dataloader(self, type_: str, cid: int | None, batch_size: int | None) -> DataLoader: ...
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/.python-version
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/dataset/__init__.py
RENAMED
|
File without changes
|
{blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/dataset/dataset.py
RENAMED
|
File without changes
|
{blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/dataset/functional.py
RENAMED
|
File without changes
|
{blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/models/__init__.py
RENAMED
|
File without changes
|
{blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/models/selector.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/config/algorithm/dsfl.yaml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|