blazefl 2.0.0.dev1__tar.gz → 2.0.0.dev2__tar.gz

This diff shows the differences between the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (116)
  1. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/Makefile +1 -1
  2. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/PKG-INFO +1 -1
  3. blazefl-2.0.0.dev2/examples/experimental-freethreaded/.python-version +1 -0
  4. blazefl-2.0.0.dev2/examples/experimental-freethreaded/Makefile +6 -0
  5. blazefl-2.0.0.dev2/examples/experimental-freethreaded/config/config.yaml +17 -0
  6. blazefl-2.0.0.dev2/examples/experimental-freethreaded/core/__init__.py +5 -0
  7. blazefl-2.0.0.dev2/examples/experimental-freethreaded/core/client_trainer.py +45 -0
  8. blazefl-2.0.0.dev2/examples/experimental-freethreaded/main.py +220 -0
  9. blazefl-2.0.0.dev2/examples/experimental-freethreaded/pyproject.toml +37 -0
  10. blazefl-2.0.0.dev2/examples/experimental-freethreaded/uv.lock +569 -0
  11. blazefl-2.0.0.dev2/examples/quickstart-fedavg/dataset/__init__.py +3 -0
  12. blazefl-2.0.0.dev2/examples/quickstart-fedavg/dataset/dataset.py +133 -0
  13. blazefl-2.0.0.dev2/examples/quickstart-fedavg/dataset/functional.py +136 -0
  14. blazefl-2.0.0.dev2/examples/quickstart-fedavg/models/__init__.py +3 -0
  15. blazefl-2.0.0.dev2/examples/quickstart-fedavg/models/selector.py +45 -0
  16. blazefl-2.0.0.dev2/examples/step-by-step-dsfl/.gitignore +3 -0
  17. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/pyproject.toml +1 -1
  18. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/core/__init__.py +6 -1
  19. blazefl-2.0.0.dev2/src/blazefl/core/__init__.pyi +6 -0
  20. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/core/client_trainer.py +39 -0
  21. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/core/client_trainer.pyi +9 -0
  22. blazefl-2.0.0.dev2/tests/test_utils/__init__.py +0 -0
  23. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/uv.lock +1 -1
  24. blazefl-2.0.0.dev1/src/blazefl/core/__init__.pyi +0 -6
  25. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.github/FUNDING.yml +0 -0
  26. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  27. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
  28. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.github/dependabot.yml +0 -0
  29. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.github/workflows/ci.yaml +0 -0
  30. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.github/workflows/publish.yaml +0 -0
  31. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.github/workflows/sphinx.yaml +0 -0
  32. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.gitignore +0 -0
  33. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.pre-commit-config.yaml +0 -0
  34. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/.python-version +0 -0
  35. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/CODE_OF_CONDUCT.md +0 -0
  36. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/LICENSE +0 -0
  37. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/README.md +0 -0
  38. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/Makefile +0 -0
  39. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/imgs/architecture.png +0 -0
  40. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/imgs/benchmark_cnn.png +0 -0
  41. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/imgs/benchmark_resnet18.png +0 -0
  42. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/imgs/logo.svg +0 -0
  43. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/imgs/logo_square.svg +0 -0
  44. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/imgs/ogp.png +0 -0
  45. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/make.bat +0 -0
  46. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/_static/favicon.ico +0 -0
  47. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/_static/logo.png +0 -0
  48. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/_static/logo_square.png +0 -0
  49. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/_templates/autosummary/class.rst +0 -0
  50. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/_templates/autosummary/module.rst +0 -0
  51. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/benchmark.rst +0 -0
  52. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/conf.py +0 -0
  53. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/contribute.rst +0 -0
  54. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/example.rst +0 -0
  55. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/index.rst +0 -0
  56. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/install.rst +0 -0
  57. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/overview.rst +0 -0
  58. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/docs/source/reference.rst +0 -0
  59. {blazefl-2.0.0.dev1/examples/quickstart-fedavg → blazefl-2.0.0.dev2/examples/experimental-freethreaded}/.gitignore +0 -0
  60. /blazefl-2.0.0.dev1/src/blazefl/__init__.py → /blazefl-2.0.0.dev2/examples/experimental-freethreaded/README.md +0 -0
  61. {blazefl-2.0.0.dev1/examples/quickstart-fedavg → blazefl-2.0.0.dev2/examples/experimental-freethreaded}/dataset/__init__.py +0 -0
  62. {blazefl-2.0.0.dev1/examples/quickstart-fedavg → blazefl-2.0.0.dev2/examples/experimental-freethreaded}/dataset/dataset.py +0 -0
  63. {blazefl-2.0.0.dev1/examples/quickstart-fedavg → blazefl-2.0.0.dev2/examples/experimental-freethreaded}/dataset/functional.py +0 -0
  64. {blazefl-2.0.0.dev1/examples/quickstart-fedavg → blazefl-2.0.0.dev2/examples/experimental-freethreaded}/models/__init__.py +0 -0
  65. {blazefl-2.0.0.dev1/examples/quickstart-fedavg → blazefl-2.0.0.dev2/examples/experimental-freethreaded}/models/selector.py +0 -0
  66. {blazefl-2.0.0.dev1/examples/step-by-step-dsfl → blazefl-2.0.0.dev2/examples/quickstart-fedavg}/.gitignore +0 -0
  67. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/quickstart-fedavg/.python-version +0 -0
  68. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/quickstart-fedavg/Makefile +0 -0
  69. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/quickstart-fedavg/README.md +0 -0
  70. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/quickstart-fedavg/config/config.yaml +0 -0
  71. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/quickstart-fedavg/main.py +0 -0
  72. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/quickstart-fedavg/pyproject.toml +0 -0
  73. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/.python-version +0 -0
  74. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/Makefile +0 -0
  75. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/README.md +0 -0
  76. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/algorithm/__init__.py +0 -0
  77. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/algorithm/dsfl.py +0 -0
  78. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/config/algorithm/dsfl.yaml +0 -0
  79. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/config/config.yaml +0 -0
  80. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/dataset/__init__.py +0 -0
  81. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/dataset/dataset.py +0 -0
  82. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/dataset/functional.py +0 -0
  83. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/main.py +0 -0
  84. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/models/__init__.py +0 -0
  85. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/models/cnn.py +0 -0
  86. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/models/selector.py +0 -0
  87. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/examples/step-by-step-dsfl/pyproject.toml +0 -0
  88. {blazefl-2.0.0.dev1/tests → blazefl-2.0.0.dev2/src/blazefl}/__init__.py +0 -0
  89. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/contrib/__init__.py +0 -0
  90. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/contrib/fedavg.py +0 -0
  91. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/core/model_selector.py +0 -0
  92. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/core/model_selector.pyi +0 -0
  93. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/core/partitioned_dataset.py +0 -0
  94. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/core/partitioned_dataset.pyi +0 -0
  95. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/core/server_handler.py +0 -0
  96. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/core/server_handler.pyi +0 -0
  97. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/py.typed +0 -0
  98. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/utils/__init__.py +0 -0
  99. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/utils/__init__.pyi +0 -0
  100. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/utils/dataset.py +0 -0
  101. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/utils/dataset.pyi +0 -0
  102. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/utils/seed.py +0 -0
  103. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/utils/seed.pyi +0 -0
  104. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/utils/serialize.py +0 -0
  105. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/src/blazefl/utils/serialize.pyi +0 -0
  106. {blazefl-2.0.0.dev1/tests/test_contrib → blazefl-2.0.0.dev2/tests}/__init__.py +0 -0
  107. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/tests/conftest.py +0 -0
  108. {blazefl-2.0.0.dev1/tests/test_core → blazefl-2.0.0.dev2/tests/test_contrib}/__init__.py +0 -0
  109. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/tests/test_contrib/test_fedavg.py +0 -0
  110. {blazefl-2.0.0.dev1/tests/test_utils → blazefl-2.0.0.dev2/tests/test_core}/__init__.py +0 -0
  111. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/tests/test_core/test_client_trainer.py +0 -0
  112. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/tests/test_core/test_model_selector.py +0 -0
  113. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/tests/test_core/test_partitioned_dataset.py +0 -0
  114. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/tests/test_utils/test_dataset.py +0 -0
  115. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/tests/test_utils/test_seed.py +0 -0
  116. {blazefl-2.0.0.dev1 → blazefl-2.0.0.dev2}/tests/test_utils/test_serialize.py +0 -0
@@ -9,4 +9,4 @@ test:
9
9
  pytest -v tests
10
10
 
11
11
  stubgen:
12
- stubgen -m blazefl.core -m blazefl.utils --no-analysis -o src
12
+ stubgen -p blazefl.core -p blazefl.utils --no-analysis -o src
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: blazefl
3
- Version: 2.0.0.dev1
3
+ Version: 2.0.0.dev2
4
4
  Summary: A blazing-fast and lightweight simulation framework for Federated Learning.
5
5
  Author-email: kitsuyaazuma <kitsuyaazuma@gmail.com>
6
6
  License-File: LICENSE
@@ -0,0 +1,6 @@
1
+ format:
2
+ ruff format .
3
+
4
+ lint:
5
+ ruff check . --fix
6
+ mypy .
@@ -0,0 +1,17 @@
1
+ model_name: cnn
2
+ num_clients: 100
3
+ global_round: 5
4
+ sample_ratio: 1.0
5
+ partition: shards
6
+ num_shards: 200
7
+ dir_alpha: 1.0
8
+ seed: 42
9
+ epochs: 5
10
+ lr: 0.1
11
+ batch_size: 50
12
+ num_parallels: 10
13
+ dataset_root_dir: /tmp/experimental-freethreaded/dataset
14
+ dataset_split_dir: /tmp/experimental-freethreaded/split
15
+ share_dir: /tmp/experimental-freethreaded/share
16
+ state_dir: /tmp/experimental-freethreaded/state
17
+ execution_mode: multi-thread
@@ -0,0 +1,5 @@
1
+ from core.client_trainer import MultiThreadClientTrainer
2
+
3
+ __all__ = [
4
+ "MultiThreadClientTrainer",
5
+ ]
@@ -0,0 +1,45 @@
1
+ from concurrent.futures import ThreadPoolExecutor, as_completed
2
+ from typing import Protocol, TypeVar
3
+
4
+ from tqdm import tqdm
5
+
6
+ UplinkPackage = TypeVar("UplinkPackage")
7
+ DownlinkPackage = TypeVar("DownlinkPackage", contravariant=True)
8
+
9
+
10
+ class MultiThreadClientTrainer(Protocol[UplinkPackage, DownlinkPackage]):
11
+ num_parallels: int
12
+ device: str
13
+ device_count: int
14
+ cache: list[UplinkPackage]
15
+
16
+ def process_client(
17
+ self,
18
+ cid: int,
19
+ device: str,
20
+ payload: DownlinkPackage,
21
+ ) -> UplinkPackage: ...
22
+
23
+ def get_client_device(self, cid: int) -> str:
24
+ if self.device == "cuda":
25
+ return f"cuda:{cid % self.device_count}"
26
+ return self.device
27
+
28
+ def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None:
29
+ with ThreadPoolExecutor(max_workers=self.num_parallels) as executor:
30
+ futures = []
31
+ for cid in cid_list:
32
+ device = self.get_client_device(cid)
33
+ future = executor.submit(
34
+ self.process_client,
35
+ cid,
36
+ device,
37
+ payload,
38
+ )
39
+ futures.append(future)
40
+
41
+ for future in tqdm(
42
+ as_completed(futures), total=len(futures), desc="Client", leave=False
43
+ ):
44
+ result = future.result()
45
+ self.cache.append(result)
@@ -0,0 +1,220 @@
1
+ import logging
2
+ from copy import deepcopy
3
+ from datetime import datetime
4
+ from pathlib import Path
5
+
6
+ import hydra
7
+ import torch
8
+ import torch.multiprocessing as mp
9
+ from blazefl.contrib import (
10
+ FedAvgParallelClientTrainer,
11
+ FedAvgSerialClientTrainer,
12
+ FedAvgServerHandler,
13
+ )
14
+ from blazefl.contrib.fedavg import FedAvgDownlinkPackage, FedAvgUplinkPackage
15
+ from blazefl.core import ModelSelector, PartitionedDataset
16
+ from blazefl.utils import seed_everything
17
+ from omegaconf import DictConfig, OmegaConf
18
+
19
+ from core.client_trainer import MultiThreadClientTrainer
20
+ from dataset import PartitionedCIFAR10
21
+ from models import FedAvgModelSelector
22
+
23
+
24
+ class FedAvgMultiThreadClientTrainer(
25
+ MultiThreadClientTrainer[
26
+ FedAvgUplinkPackage,
27
+ FedAvgDownlinkPackage,
28
+ ]
29
+ ):
30
+ def __init__(
31
+ self,
32
+ model_selector: ModelSelector,
33
+ model_name: str,
34
+ dataset: PartitionedDataset,
35
+ device: str,
36
+ num_clients: int,
37
+ epochs: int,
38
+ batch_size: int,
39
+ lr: float,
40
+ seed: int,
41
+ num_parallels: int,
42
+ ) -> None:
43
+ self.num_parallels = num_parallels
44
+ self.device = device
45
+ if self.device == "cuda":
46
+ self.device_count = torch.cuda.device_count()
47
+ self.cache: list[FedAvgUplinkPackage] = []
48
+
49
+ self.model_selector = model_selector
50
+ self.model_name = model_name
51
+ self.dataset = dataset
52
+ self.epochs = epochs
53
+ self.batch_size = batch_size
54
+ self.lr = lr
55
+ self.num_clients = num_clients
56
+ self.seed = seed
57
+
58
+ def process_client(
59
+ self,
60
+ cid: int,
61
+ device: str,
62
+ payload: FedAvgDownlinkPackage,
63
+ ) -> FedAvgUplinkPackage:
64
+ model = self.model_selector.select_model(self.model_name)
65
+ train_loader = self.dataset.get_dataloader(
66
+ type_="train",
67
+ cid=cid,
68
+ batch_size=self.batch_size,
69
+ )
70
+ package = FedAvgParallelClientTrainer.train(
71
+ model=model,
72
+ model_parameters=payload.model_parameters,
73
+ train_loader=train_loader,
74
+ device=device,
75
+ epochs=self.epochs,
76
+ lr=self.lr,
77
+ )
78
+ return package
79
+
80
+ def uplink_package(self) -> list[FedAvgUplinkPackage]:
81
+ package = deepcopy(self.cache)
82
+ self.cache = []
83
+ return package
84
+
85
+
86
+ class FedAvgPipeline:
87
+ def __init__(
88
+ self,
89
+ handler: FedAvgServerHandler,
90
+ trainer: FedAvgSerialClientTrainer
91
+ | FedAvgParallelClientTrainer
92
+ | FedAvgMultiThreadClientTrainer,
93
+ ) -> None:
94
+ self.handler = handler
95
+ self.trainer = trainer
96
+
97
+ def main(self):
98
+ while not self.handler.if_stop():
99
+ round_ = self.handler.round
100
+ # server side
101
+ sampled_clients = self.handler.sample_clients()
102
+ broadcast = self.handler.downlink_package()
103
+
104
+ # client side
105
+ self.trainer.local_process(broadcast, sampled_clients)
106
+ uploads = self.trainer.uplink_package()
107
+
108
+ # server side
109
+ for pack in uploads:
110
+ self.handler.load(pack)
111
+
112
+ summary = self.handler.get_summary()
113
+ formatted_summary = ", ".join(f"{k}: {v:.3f}" for k, v in summary.items())
114
+ logging.info(f"round: {round_}, {formatted_summary}")
115
+
116
+ logging.info("done!")
117
+
118
+
119
+ @hydra.main(version_base=None, config_path="config", config_name="config")
120
+ def main(cfg: DictConfig):
121
+ print(OmegaConf.to_yaml(cfg))
122
+
123
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
124
+ dataset_root_dir = Path(cfg.dataset_root_dir)
125
+ dataset_split_dir = dataset_root_dir.joinpath(timestamp)
126
+ share_dir = Path(cfg.share_dir).joinpath(timestamp)
127
+ state_dir = Path(cfg.state_dir).joinpath(timestamp)
128
+
129
+ device = "cpu"
130
+ if torch.cuda.is_available():
131
+ device = "cuda"
132
+ elif torch.backends.mps.is_available():
133
+ device = "mps"
134
+ logging.info(f"device: {device}")
135
+
136
+ seed_everything(cfg.seed, device=device)
137
+
138
+ dataset = PartitionedCIFAR10(
139
+ root=dataset_root_dir,
140
+ path=dataset_split_dir,
141
+ num_clients=cfg.num_clients,
142
+ seed=cfg.seed,
143
+ partition=cfg.partition,
144
+ num_shards=cfg.num_shards,
145
+ dir_alpha=cfg.dir_alpha,
146
+ )
147
+ model_selector = FedAvgModelSelector(num_classes=10)
148
+
149
+ handler = FedAvgServerHandler(
150
+ model_selector=model_selector,
151
+ model_name=cfg.model_name,
152
+ dataset=dataset,
153
+ global_round=cfg.global_round,
154
+ num_clients=cfg.num_clients,
155
+ device=device,
156
+ sample_ratio=cfg.sample_ratio,
157
+ batch_size=cfg.batch_size,
158
+ )
159
+ trainer: (
160
+ FedAvgSerialClientTrainer
161
+ | FedAvgParallelClientTrainer
162
+ | FedAvgMultiThreadClientTrainer
163
+ | None
164
+ ) = None
165
+ match cfg.execution_mode:
166
+ case "serial":
167
+ trainer = FedAvgSerialClientTrainer(
168
+ model_selector=model_selector,
169
+ model_name=cfg.model_name,
170
+ dataset=dataset,
171
+ device=device,
172
+ num_clients=cfg.num_clients,
173
+ epochs=cfg.epochs,
174
+ lr=cfg.lr,
175
+ batch_size=cfg.batch_size,
176
+ )
177
+ case "multi-process":
178
+ trainer = FedAvgParallelClientTrainer(
179
+ model_selector=model_selector,
180
+ model_name=cfg.model_name,
181
+ dataset=dataset,
182
+ share_dir=share_dir,
183
+ state_dir=state_dir,
184
+ seed=cfg.seed,
185
+ device=device,
186
+ num_clients=cfg.num_clients,
187
+ epochs=cfg.epochs,
188
+ lr=cfg.lr,
189
+ batch_size=cfg.batch_size,
190
+ num_parallels=cfg.num_parallels,
191
+ )
192
+ case "multi-thread":
193
+ trainer = FedAvgMultiThreadClientTrainer(
194
+ model_selector=model_selector,
195
+ model_name=cfg.model_name,
196
+ dataset=dataset,
197
+ device=device,
198
+ num_clients=cfg.num_clients,
199
+ epochs=cfg.epochs,
200
+ lr=cfg.lr,
201
+ batch_size=cfg.batch_size,
202
+ num_parallels=cfg.num_parallels,
203
+ seed=cfg.seed,
204
+ )
205
+ case _:
206
+ raise ValueError(f"Invalid execution mode: {cfg.execution_mode}")
207
+ pipeline = FedAvgPipeline(handler=handler, trainer=trainer)
208
+ try:
209
+ pipeline.main()
210
+ except KeyboardInterrupt:
211
+ logging.info("KeyboardInterrupt: Stopping the pipeline.")
212
+ except Exception as e:
213
+ logging.exception(f"An error occurred: {e}")
214
+
215
+
216
+ if __name__ == "__main__":
217
+ # NOTE: To use CUDA with multiprocessing, you must use the 'spawn' start method
218
+ mp.set_start_method("spawn")
219
+
220
+ main()
@@ -0,0 +1,37 @@
1
+ [project]
2
+ name = "experimental-freethreaded"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.13"
7
+ dependencies = [
8
+ "blazefl>=2.0.0dev1",
9
+ "hydra-core>=1.3.2",
10
+ "torch>=2.7.1",
11
+ "torchvision>=0.22.1",
12
+ ]
13
+
14
+ [tool.basedpyright]
15
+ typeCheckingMode = "standard"
16
+
17
+ [[tool.mypy.overrides]]
18
+ module = ["torchvision.*"]
19
+ ignore_missing_imports = true
20
+
21
+ [tool.ruff.lint]
22
+ select = [
23
+ "E", # pycodestyle
24
+ "F", # Pyflakes
25
+ "UP", # pyupgrade
26
+ "B", # flake8-bugbear
27
+ "SIM", # flake8-simplify
28
+ "I", # isort
29
+ ]
30
+ ignore = []
31
+ fixable = ["ALL"]
32
+
33
+ [dependency-groups]
34
+ dev = [
35
+ "mypy>=1.16.0",
36
+ "types-tqdm>=4.67.0.20250516",
37
+ ]