blazefl 2.0.0.dev5__tar.gz → 2.0.0.dev7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117) hide show
  1. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/PKG-INFO +1 -1
  2. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/config/config.yaml +1 -0
  3. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/main.py +20 -83
  4. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/pyproject.toml +1 -1
  5. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/uv.lock +4 -4
  6. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/config/config.yaml +1 -1
  7. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/dataset/dataset.py +6 -7
  8. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/main.py +54 -32
  9. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/pyproject.toml +1 -0
  10. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/algorithm/dsfl.py +23 -7
  11. blazefl-2.0.0.dev7/examples/step-by-step-dsfl/dataset/__init__.py +3 -0
  12. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/dataset/dataset.py +14 -9
  13. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/main.py +1 -3
  14. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/pyproject.toml +1 -0
  15. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/pyproject.toml +1 -1
  16. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/contrib/__init__.py +4 -0
  17. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/contrib/fedavg.py +87 -22
  18. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/core/client_trainer.py +45 -33
  19. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/core/client_trainer.pyi +4 -3
  20. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/core/partitioned_dataset.py +7 -4
  21. blazefl-2.0.0.dev7/src/blazefl/core/partitioned_dataset.pyi +9 -0
  22. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/tests/test_contrib/test_fedavg.py +67 -59
  23. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/tests/test_core/test_client_trainer.py +14 -3
  24. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/uv.lock +11 -3
  25. blazefl-2.0.0.dev5/examples/step-by-step-dsfl/dataset/__init__.py +0 -3
  26. blazefl-2.0.0.dev5/src/blazefl/core/partitioned_dataset.pyi +0 -6
  27. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/.github/FUNDING.yml +0 -0
  28. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  29. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
  30. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/.github/dependabot.yml +0 -0
  31. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/.github/workflows/ci.yaml +0 -0
  32. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/.github/workflows/publish.yaml +0 -0
  33. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/.github/workflows/sphinx.yaml +0 -0
  34. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/.gitignore +0 -0
  35. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/.pre-commit-config.yaml +0 -0
  36. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/.python-version +0 -0
  37. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/CODE_OF_CONDUCT.md +0 -0
  38. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/LICENSE +0 -0
  39. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/Makefile +0 -0
  40. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/README.md +0 -0
  41. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/Makefile +0 -0
  42. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/imgs/architecture.png +0 -0
  43. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/imgs/benchmark_cnn.png +0 -0
  44. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/imgs/benchmark_resnet18.png +0 -0
  45. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/imgs/logo.svg +0 -0
  46. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/imgs/logo_square.svg +0 -0
  47. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/imgs/ogp.png +0 -0
  48. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/make.bat +0 -0
  49. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/source/_static/favicon.ico +0 -0
  50. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/source/_static/logo.png +0 -0
  51. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/source/_static/logo_square.png +0 -0
  52. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/source/_templates/autosummary/class.rst +0 -0
  53. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/source/_templates/autosummary/module.rst +0 -0
  54. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/source/benchmark.rst +0 -0
  55. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/source/conf.py +0 -0
  56. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/source/contribute.rst +0 -0
  57. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/source/example.rst +0 -0
  58. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/source/index.rst +0 -0
  59. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/source/install.rst +0 -0
  60. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/source/overview.rst +0 -0
  61. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/docs/source/reference.rst +0 -0
  62. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/.gitignore +0 -0
  63. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/.python-version +0 -0
  64. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/Makefile +0 -0
  65. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/README.md +0 -0
  66. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/dataset/__init__.py +0 -0
  67. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/dataset/dataset.py +0 -0
  68. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/dataset/functional.py +0 -0
  69. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/models/__init__.py +0 -0
  70. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/models/selector.py +0 -0
  71. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/.gitignore +0 -0
  72. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/.python-version +0 -0
  73. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/Makefile +0 -0
  74. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/README.md +0 -0
  75. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/dataset/__init__.py +0 -0
  76. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/dataset/functional.py +0 -0
  77. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/models/__init__.py +0 -0
  78. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/models/selector.py +0 -0
  79. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/.gitignore +0 -0
  80. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/.python-version +0 -0
  81. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/Makefile +0 -0
  82. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/README.md +0 -0
  83. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/algorithm/__init__.py +0 -0
  84. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/config/algorithm/dsfl.yaml +0 -0
  85. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/config/config.yaml +0 -0
  86. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/dataset/functional.py +0 -0
  87. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/models/__init__.py +0 -0
  88. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/models/cnn.py +0 -0
  89. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/models/selector.py +0 -0
  90. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/__init__.py +0 -0
  91. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/core/__init__.py +0 -0
  92. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/core/__init__.pyi +0 -0
  93. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/core/model_selector.py +0 -0
  94. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/core/model_selector.pyi +0 -0
  95. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/core/server_handler.py +0 -0
  96. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/core/server_handler.pyi +0 -0
  97. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/py.typed +0 -0
  98. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/utils/__init__.py +0 -0
  99. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/utils/__init__.pyi +0 -0
  100. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/utils/dataset.py +0 -0
  101. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/utils/dataset.pyi +0 -0
  102. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/utils/ipc.py +0 -0
  103. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/utils/ipc.pyi +0 -0
  104. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/utils/seed.py +0 -0
  105. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/utils/seed.pyi +0 -0
  106. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/utils/serialize.py +0 -0
  107. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/src/blazefl/utils/serialize.pyi +0 -0
  108. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/tests/__init__.py +0 -0
  109. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/tests/conftest.py +0 -0
  110. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/tests/test_contrib/__init__.py +0 -0
  111. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/tests/test_core/__init__.py +0 -0
  112. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/tests/test_core/test_model_selector.py +0 -0
  113. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/tests/test_core/test_partitioned_dataset.py +0 -0
  114. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/tests/test_utils/__init__.py +0 -0
  115. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/tests/test_utils/test_dataset.py +0 -0
  116. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/tests/test_utils/test_seed.py +0 -0
  117. {blazefl-2.0.0.dev5 → blazefl-2.0.0.dev7}/tests/test_utils/test_serialize.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: blazefl
3
- Version: 2.0.0.dev5
3
+ Version: 2.0.0.dev7
4
4
  Summary: A blazing-fast and lightweight simulation framework for Federated Learning.
5
5
  Author-email: kitsuyaazuma <kitsuyaazuma@gmail.com>
6
6
  License-File: LICENSE
@@ -15,3 +15,4 @@ dataset_split_dir: /tmp/experimental-freethreaded/split
15
15
  share_dir: /tmp/experimental-freethreaded/share
16
16
  state_dir: /tmp/experimental-freethreaded/state
17
17
  execution_mode: multi-thread
18
+ ipc_mode: storage
@@ -1,5 +1,4 @@
1
1
  import logging
2
- from copy import deepcopy
3
2
  from datetime import datetime
4
3
  from pathlib import Path
5
4
 
@@ -7,12 +6,13 @@ import hydra
7
6
  import torch
8
7
  import torch.multiprocessing as mp
9
8
  from blazefl.contrib import (
10
- FedAvgParallelClientTrainer,
11
- FedAvgSerialClientTrainer,
12
- FedAvgServerHandler,
9
+ FedAvgBaseClientTrainer,
10
+ FedAvgBaseServerHandler,
11
+ FedAvgProcessPoolClientTrainer,
12
+ )
13
+ from blazefl.contrib.fedavg import (
14
+ FedAvgThreadPoolClientTrainer,
13
15
  )
14
- from blazefl.contrib.fedavg import FedAvgDownlinkPackage, FedAvgUplinkPackage
15
- from blazefl.core import ModelSelector, MultiThreadClientTrainer, PartitionedDataset
16
16
  from blazefl.utils import seed_everything
17
17
  from omegaconf import DictConfig, OmegaConf
18
18
 
@@ -20,75 +20,13 @@ from dataset import PartitionedCIFAR10
20
20
  from models import FedAvgModelSelector
21
21
 
22
22
 
23
- class FedAvgMultiThreadClientTrainer(
24
- MultiThreadClientTrainer[
25
- FedAvgUplinkPackage,
26
- FedAvgDownlinkPackage,
27
- ]
28
- ):
29
- def __init__(
30
- self,
31
- model_selector: ModelSelector,
32
- model_name: str,
33
- dataset: PartitionedDataset,
34
- device: str,
35
- num_clients: int,
36
- epochs: int,
37
- batch_size: int,
38
- lr: float,
39
- seed: int,
40
- num_parallels: int,
41
- ) -> None:
42
- self.num_parallels = num_parallels
43
- self.device = device
44
- if self.device == "cuda":
45
- self.device_count = torch.cuda.device_count()
46
- self.cache: list[FedAvgUplinkPackage] = []
47
-
48
- self.model_selector = model_selector
49
- self.model_name = model_name
50
- self.dataset = dataset
51
- self.epochs = epochs
52
- self.batch_size = batch_size
53
- self.lr = lr
54
- self.num_clients = num_clients
55
- self.seed = seed
56
-
57
- def process_client(
58
- self,
59
- cid: int,
60
- device: str,
61
- payload: FedAvgDownlinkPackage,
62
- ) -> FedAvgUplinkPackage:
63
- model = self.model_selector.select_model(self.model_name)
64
- train_loader = self.dataset.get_dataloader(
65
- type_="train",
66
- cid=cid,
67
- batch_size=self.batch_size,
68
- )
69
- package = FedAvgParallelClientTrainer.train(
70
- model=model,
71
- model_parameters=payload.model_parameters,
72
- train_loader=train_loader,
73
- device=device,
74
- epochs=self.epochs,
75
- lr=self.lr,
76
- )
77
- return package
78
-
79
- def uplink_package(self) -> list[FedAvgUplinkPackage]:
80
- package = deepcopy(self.cache)
81
- self.cache = []
82
- return package
83
-
84
-
85
23
  class FedAvgPipeline:
86
24
  def __init__(
87
25
  self,
88
- handler: FedAvgServerHandler,
89
- trainer: FedAvgSerialClientTrainer
90
- | FedAvgParallelClientTrainer
91
- | FedAvgMultiThreadClientTrainer,
26
+ handler: FedAvgBaseServerHandler,
27
+ trainer: FedAvgBaseClientTrainer
28
+ | FedAvgProcessPoolClientTrainer
29
+ | FedAvgThreadPoolClientTrainer,
92
30
  ) -> None:
93
31
  self.handler = handler
94
32
  self.trainer = trainer
@@ -145,7 +83,7 @@ def main(cfg: DictConfig):
145
83
  )
146
84
  model_selector = FedAvgModelSelector(num_classes=10)
147
85
 
148
- handler = FedAvgServerHandler(
86
+ handler = FedAvgBaseServerHandler(
149
87
  model_selector=model_selector,
150
88
  model_name=cfg.model_name,
151
89
  dataset=dataset,
@@ -156,14 +94,14 @@ def main(cfg: DictConfig):
156
94
  batch_size=cfg.batch_size,
157
95
  )
158
96
  trainer: (
159
- FedAvgSerialClientTrainer
160
- | FedAvgParallelClientTrainer
161
- | FedAvgMultiThreadClientTrainer
97
+ FedAvgBaseClientTrainer
98
+ | FedAvgProcessPoolClientTrainer
99
+ | FedAvgThreadPoolClientTrainer
162
100
  | None
163
101
  ) = None
164
102
  match cfg.execution_mode:
165
- case "serial":
166
- trainer = FedAvgSerialClientTrainer(
103
+ case "single-thread":
104
+ trainer = FedAvgBaseClientTrainer(
167
105
  model_selector=model_selector,
168
106
  model_name=cfg.model_name,
169
107
  dataset=dataset,
@@ -174,7 +112,7 @@ def main(cfg: DictConfig):
174
112
  batch_size=cfg.batch_size,
175
113
  )
176
114
  case "multi-process":
177
- trainer = FedAvgParallelClientTrainer(
115
+ trainer = FedAvgProcessPoolClientTrainer(
178
116
  model_selector=model_selector,
179
117
  model_name=cfg.model_name,
180
118
  dataset=dataset,
@@ -187,9 +125,10 @@ def main(cfg: DictConfig):
187
125
  lr=cfg.lr,
188
126
  batch_size=cfg.batch_size,
189
127
  num_parallels=cfg.num_parallels,
128
+ ipc_mode=cfg.ipc_mode,
190
129
  )
191
130
  case "multi-thread":
192
- trainer = FedAvgMultiThreadClientTrainer(
131
+ trainer = FedAvgThreadPoolClientTrainer(
193
132
  model_selector=model_selector,
194
133
  model_name=cfg.model_name,
195
134
  dataset=dataset,
@@ -207,9 +146,7 @@ def main(cfg: DictConfig):
207
146
  try:
208
147
  pipeline.main()
209
148
  except KeyboardInterrupt:
210
- logging.info("KeyboardInterrupt: Stopping the pipeline.")
211
- except Exception as e:
212
- logging.exception(f"An error occurred: {e}")
149
+ logging.info("KeyboardInterrupt")
213
150
 
214
151
 
215
152
  if __name__ == "__main__":
@@ -5,7 +5,7 @@ description = "Add your description here"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.13"
7
7
  dependencies = [
8
- "blazefl>=2.0.0dev2",
8
+ "blazefl>=2.0.0dev6",
9
9
  "hydra-core>=1.3.2",
10
10
  "torch>=2.7.1",
11
11
  "torchvision>=0.22.1",
@@ -10,16 +10,16 @@ sdist = { url = "https://files.pythonhosted.org/packages/3e/38/7859ff46355f76f8d
10
10
 
11
11
  [[package]]
12
12
  name = "blazefl"
13
- version = "2.0.0.dev2"
13
+ version = "2.0.0.dev6"
14
14
  source = { registry = "https://pypi.org/simple" }
15
15
  dependencies = [
16
16
  { name = "numpy" },
17
17
  { name = "torch" },
18
18
  { name = "tqdm" },
19
19
  ]
20
- sdist = { url = "https://files.pythonhosted.org/packages/9b/92/3f668d7259a81d13d5d2e6089d8022629a09058ddbcae33f5873d3871656/blazefl-2.0.0.dev2.tar.gz", hash = "sha256:28a3e09d6f6cec8d8e8bb4da90adaa0c795980c74f9d4ce65024ba762726b595", size = 613563, upload_time = "2025-06-10T09:24:08.981Z" }
20
+ sdist = { url = "https://files.pythonhosted.org/packages/d7/76/d247942d051447dc1c6dbc25b2d55b4cdf146cc0ebd3594c0c33628f465c/blazefl-2.0.0.dev6.tar.gz", hash = "sha256:d4875f03872917dfd99b3a70ac2c6efc3483ee0e42e484d7328e0810fe5141b9", size = 615380, upload_time = "2025-06-16T16:26:50.929Z" }
21
21
  wheels = [
22
- { url = "https://files.pythonhosted.org/packages/e3/7c/841c8b9d22caaefc1dcc167076292fd1479309b3e11eb6fc1b24141792a8/blazefl-2.0.0.dev2-py3-none-any.whl", hash = "sha256:88639018166dafa79fde59aa5575bfb7277bf9e0e4497f22a59e9c7ec34efa4b", size = 23999, upload_time = "2025-06-10T09:24:07.553Z" },
22
+ { url = "https://files.pythonhosted.org/packages/c8/30/eba32caaed89d04c925f676d48fb50bad7bed3ce2c86c6ebcaf2eb92aff1/blazefl-2.0.0.dev6-py3-none-any.whl", hash = "sha256:e38c3739be9ef2bc7bb8ba213da3319e4be8f9ea9f27ecd661cbc70c0cedd655", size = 26215, upload_time = "2025-06-16T16:26:49.512Z" },
23
23
  ]
24
24
 
25
25
  [[package]]
@@ -50,7 +50,7 @@ dev = [
50
50
 
51
51
  [package.metadata]
52
52
  requires-dist = [
53
- { name = "blazefl", specifier = ">=2.0.0.dev2" },
53
+ { name = "blazefl", specifier = ">=2.0.0.dev6" },
54
54
  { name = "hydra-core", specifier = ">=1.3.2" },
55
55
  { name = "torch", specifier = ">=2.7.1" },
56
56
  { name = "torchvision", specifier = ">=0.22.1" },
@@ -14,5 +14,5 @@ dataset_root_dir: /tmp/quickstart-fedavg/dataset
14
14
  dataset_split_dir: /tmp/quickstart-fedavg/split
15
15
  share_dir: /tmp/quickstart-fedavg/share
16
16
  state_dir: /tmp/quickstart-fedavg/state
17
- parallel: true
17
+ execution_mode: multi-process
18
18
  ipc_mode: storage
@@ -3,6 +3,7 @@ from pathlib import Path
3
3
 
4
4
  import torch
5
5
  import torchvision
6
+ from blazefl.contrib import FedAvgPartitionType
6
7
  from blazefl.core import PartitionedDataset
7
8
  from blazefl.utils import FilteredDataset
8
9
  from torch.utils.data import DataLoader, Dataset
@@ -15,7 +16,7 @@ from dataset.functional import (
15
16
  )
16
17
 
17
18
 
18
- class PartitionedCIFAR10(PartitionedDataset):
19
+ class PartitionedCIFAR10(PartitionedDataset[FedAvgPartitionType]):
19
20
  def __init__(
20
21
  self,
21
22
  root: Path,
@@ -107,24 +108,22 @@ class PartitionedCIFAR10(PartitionedDataset):
107
108
  self.path.joinpath("test.pkl"),
108
109
  )
109
110
 
110
- def get_dataset(self, type_: str, cid: int | None) -> Dataset:
111
+ def get_dataset(self, type_: FedAvgPartitionType, cid: int | None) -> Dataset:
111
112
  match type_:
112
- case "train":
113
+ case FedAvgPartitionType.TRAIN:
113
114
  dataset = torch.load(
114
115
  self.path.joinpath(type_, f"{cid}.pkl"),
115
116
  weights_only=False,
116
117
  )
117
- case "test":
118
+ case FedAvgPartitionType.TEST:
118
119
  dataset = torch.load(
119
120
  self.path.joinpath(f"{type_}.pkl"), weights_only=False
120
121
  )
121
- case _:
122
- raise ValueError(f"Invalid dataset type: {type_}")
123
122
  assert isinstance(dataset, Dataset)
124
123
  return dataset
125
124
 
126
125
  def get_dataloader(
127
- self, type_: str, cid: int | None, batch_size: int | None = None
126
+ self, type_: FedAvgPartitionType, cid: int | None, batch_size: int | None = None
128
127
  ) -> DataLoader:
129
128
  dataset = self.get_dataset(type_, cid)
130
129
  assert isinstance(dataset, Sized)
@@ -9,6 +9,7 @@ from blazefl.contrib import (
9
9
  FedAvgBaseClientTrainer,
10
10
  FedAvgBaseServerHandler,
11
11
  FedAvgProcessPoolClientTrainer,
12
+ FedAvgThreadPoolClientTrainer,
12
13
  )
13
14
  from blazefl.utils import seed_everything
14
15
  from hydra.core import hydra_config
@@ -23,7 +24,9 @@ class FedAvgPipeline:
23
24
  def __init__(
24
25
  self,
25
26
  handler: FedAvgBaseServerHandler,
26
- trainer: FedAvgBaseClientTrainer | FedAvgProcessPoolClientTrainer,
27
+ trainer: FedAvgBaseClientTrainer
28
+ | FedAvgProcessPoolClientTrainer
29
+ | FedAvgThreadPoolClientTrainer,
27
30
  writer: SummaryWriter,
28
31
  ) -> None:
29
32
  self.handler = handler
@@ -97,41 +100,60 @@ def main(cfg: DictConfig):
97
100
  sample_ratio=cfg.sample_ratio,
98
101
  batch_size=cfg.batch_size,
99
102
  )
100
- trainer: FedAvgBaseClientTrainer | FedAvgProcessPoolClientTrainer | None = None
101
- if cfg.parallel:
102
- trainer = FedAvgProcessPoolClientTrainer(
103
- model_selector=model_selector,
104
- model_name=cfg.model_name,
105
- dataset=dataset,
106
- share_dir=share_dir,
107
- state_dir=state_dir,
108
- seed=cfg.seed,
109
- device=device,
110
- num_clients=cfg.num_clients,
111
- epochs=cfg.epochs,
112
- lr=cfg.lr,
113
- batch_size=cfg.batch_size,
114
- num_parallels=cfg.num_parallels,
115
- ipc_mode=cfg.ipc_mode,
116
- )
117
- else:
118
- trainer = FedAvgBaseClientTrainer(
119
- model_selector=model_selector,
120
- model_name=cfg.model_name,
121
- dataset=dataset,
122
- device=device,
123
- num_clients=cfg.num_clients,
124
- epochs=cfg.epochs,
125
- lr=cfg.lr,
126
- batch_size=cfg.batch_size,
127
- )
103
+ trainer: (
104
+ FedAvgBaseClientTrainer
105
+ | FedAvgProcessPoolClientTrainer
106
+ | FedAvgThreadPoolClientTrainer
107
+ | None
108
+ ) = None
109
+ match cfg.execution_mode:
110
+ case "multi-process":
111
+ trainer = FedAvgProcessPoolClientTrainer(
112
+ model_selector=model_selector,
113
+ model_name=cfg.model_name,
114
+ dataset=dataset,
115
+ share_dir=share_dir,
116
+ state_dir=state_dir,
117
+ seed=cfg.seed,
118
+ device=device,
119
+ num_clients=cfg.num_clients,
120
+ epochs=cfg.epochs,
121
+ lr=cfg.lr,
122
+ batch_size=cfg.batch_size,
123
+ num_parallels=cfg.num_parallels,
124
+ ipc_mode=cfg.ipc_mode,
125
+ )
126
+ case "single-thread":
127
+ trainer = FedAvgBaseClientTrainer(
128
+ model_selector=model_selector,
129
+ model_name=cfg.model_name,
130
+ dataset=dataset,
131
+ device=device,
132
+ num_clients=cfg.num_clients,
133
+ epochs=cfg.epochs,
134
+ lr=cfg.lr,
135
+ batch_size=cfg.batch_size,
136
+ )
137
+ case "multi-thread":
138
+ trainer = FedAvgThreadPoolClientTrainer(
139
+ model_selector=model_selector,
140
+ model_name=cfg.model_name,
141
+ dataset=dataset,
142
+ seed=cfg.seed,
143
+ device=device,
144
+ num_clients=cfg.num_clients,
145
+ epochs=cfg.epochs,
146
+ lr=cfg.lr,
147
+ batch_size=cfg.batch_size,
148
+ num_parallels=cfg.num_parallels,
149
+ )
150
+ case _:
151
+ raise ValueError(f"Invalid execution mode: {cfg.execution_mode}")
128
152
  pipeline = FedAvgPipeline(handler=handler, trainer=trainer, writer=writer)
129
153
  try:
130
154
  pipeline.main()
131
155
  except KeyboardInterrupt:
132
- logging.info("KeyboardInterrupt: Stopping the pipeline.")
133
- except Exception as e:
134
- logging.exception(f"An error occurred: {e}")
156
+ logging.info("KeyboardInterrupt")
135
157
 
136
158
 
137
159
  if __name__ == "__main__":
@@ -32,4 +32,5 @@ blazefl = { workspace = true }
32
32
  [dependency-groups]
33
33
  dev = [
34
34
  "mypy>=1.13.0",
35
+ "pre-commit>=4.2.0",
35
36
  ]
@@ -1,10 +1,12 @@
1
1
  import random
2
+ import threading
2
3
  from collections import defaultdict
3
4
  from copy import deepcopy
4
5
  from dataclasses import dataclass
5
6
  from pathlib import Path
6
7
 
7
8
  import torch
9
+ import torch.multiprocessing as mp
8
10
  import torch.nn.functional as F
9
11
  from blazefl.core import (
10
12
  BaseServerHandler,
@@ -17,7 +19,7 @@ from blazefl.utils import (
17
19
  )
18
20
  from torch.utils.data import DataLoader, Subset
19
21
 
20
- from dataset import DSFLPartitionedDataset
22
+ from dataset import DSFLPartitionedDataset, DSFLPartitionType
21
23
  from models import DSFLModelSelector
22
24
 
23
25
 
@@ -125,6 +127,7 @@ class DSFLBaseServerHandler(BaseServerHandler[DSFLUplinkPackage, DSFLDownlinkPac
125
127
  self.kd_epochs,
126
128
  self.kd_batch_size,
127
129
  self.device,
130
+ stop_event=None,
128
131
  )
129
132
 
130
133
  self.global_soft_labels = torch.stack(global_soft_labels)
@@ -140,10 +143,11 @@ class DSFLBaseServerHandler(BaseServerHandler[DSFLUplinkPackage, DSFLDownlinkPac
140
143
  kd_epochs: int,
141
144
  kd_batch_size: int,
142
145
  device: str,
146
+ stop_event: threading.Event | None,
143
147
  ) -> None:
144
148
  model.to(device)
145
149
  model.train()
146
- open_dataset = dataset.get_dataset(type_="open", cid=None)
150
+ open_dataset = dataset.get_dataset(type_=DSFLPartitionType.OPEN, cid=None)
147
151
  open_loader = DataLoader(
148
152
  Subset(open_dataset, global_indices),
149
153
  batch_size=kd_batch_size,
@@ -156,6 +160,8 @@ class DSFLBaseServerHandler(BaseServerHandler[DSFLUplinkPackage, DSFLDownlinkPac
156
160
  batch_size=kd_batch_size,
157
161
  )
158
162
  for _ in range(kd_epochs):
163
+ if stop_event is not None and stop_event.is_set():
164
+ break
159
165
  for data, soft_label in zip(
160
166
  open_loader, global_soft_label_loader, strict=True
161
167
  ):
@@ -208,7 +214,7 @@ class DSFLBaseServerHandler(BaseServerHandler[DSFLUplinkPackage, DSFLDownlinkPac
208
214
  server_loss, server_acc = DSFLBaseServerHandler.evaulate(
209
215
  self.model,
210
216
  self.dataset.get_dataloader(
211
- type_="test",
217
+ type_=DSFLPartitionType.TEST,
212
218
  cid=None,
213
219
  batch_size=self.kd_batch_size,
214
220
  ),
@@ -300,6 +306,8 @@ class DSFLProcessPoolClientTrainer(
300
306
  self.num_clients = num_clients
301
307
  self.seed = seed
302
308
  self.ipc_mode = "storage"
309
+ self.manager = mp.Manager()
310
+ self.stop_event = self.manager.Event()
303
311
 
304
312
  if self.device == "cuda":
305
313
  self.device_count = torch.cuda.device_count()
@@ -309,6 +317,7 @@ class DSFLProcessPoolClientTrainer(
309
317
  config: DSFLClientConfig | Path,
310
318
  payload: DSFLDownlinkPackage | Path,
311
319
  device: str,
320
+ stop_event: threading.Event,
312
321
  ) -> Path:
313
322
  assert isinstance(config, Path) and isinstance(payload, Path)
314
323
  config_path, payload_path = config, payload
@@ -334,7 +343,7 @@ class DSFLProcessPoolClientTrainer(
334
343
  seed_everything(c.seed, device=device)
335
344
 
336
345
  # Distill
337
- open_dataset = c.dataset.get_dataset(type_="open", cid=None)
346
+ open_dataset = c.dataset.get_dataset(type_=DSFLPartitionType.OPEN, cid=None)
338
347
  if p.indices is not None and p.soft_labels is not None:
339
348
  global_soft_labels = list(torch.unbind(p.soft_labels, dim=0))
340
349
  global_indices = p.indices.tolist()
@@ -349,11 +358,12 @@ class DSFLProcessPoolClientTrainer(
349
358
  kd_epochs=c.kd_epochs,
350
359
  kd_batch_size=c.kd_batch_size,
351
360
  device=device,
361
+ stop_event=stop_event,
352
362
  )
353
363
 
354
364
  # Train
355
365
  train_loader = c.dataset.get_dataloader(
356
- type_="train",
366
+ type_=DSFLPartitionType.TRAIN,
357
367
  cid=c.cid,
358
368
  batch_size=c.batch_size,
359
369
  )
@@ -363,6 +373,7 @@ class DSFLProcessPoolClientTrainer(
363
373
  train_loader=train_loader,
364
374
  device=device,
365
375
  epochs=c.epochs,
376
+ stop_event=stop_event,
366
377
  )
367
378
 
368
379
  # Predict
@@ -378,7 +389,7 @@ class DSFLProcessPoolClientTrainer(
378
389
 
379
390
  # Evaluate
380
391
  test_loader = c.dataset.get_dataloader(
381
- type_="test",
392
+ type_=DSFLPartitionType.TEST,
382
393
  cid=c.cid,
383
394
  batch_size=c.batch_size,
384
395
  )
@@ -411,12 +422,15 @@ class DSFLProcessPoolClientTrainer(
411
422
  train_loader: DataLoader,
412
423
  device: str,
413
424
  epochs: int,
425
+ stop_event: threading.Event,
414
426
  ) -> None:
415
427
  model.to(device)
416
428
  model.train()
417
429
  criterion = torch.nn.CrossEntropyLoss()
418
430
 
419
431
  for _ in range(epochs):
432
+ if stop_event.is_set():
433
+ break
420
434
  for data, target in train_loader:
421
435
  data = data.to(device)
422
436
  target = target.to(device)
@@ -431,7 +445,9 @@ class DSFLProcessPoolClientTrainer(
431
445
 
432
446
  @staticmethod
433
447
  def predict(
434
- model: torch.nn.Module, open_loader: DataLoader, device: str
448
+ model: torch.nn.Module,
449
+ open_loader: DataLoader,
450
+ device: str,
435
451
  ) -> torch.Tensor:
436
452
  model.to(device)
437
453
  model.eval()
@@ -0,0 +1,3 @@
1
+ from dataset.dataset import DSFLPartitionedDataset, DSFLPartitionType
2
+
3
+ __all__ = ["DSFLPartitionedDataset", "DSFLPartitionType"]
@@ -1,4 +1,5 @@
1
1
  from collections.abc import Sized
2
+ from enum import StrEnum
2
3
  from pathlib import Path
3
4
 
4
5
  import numpy as np
@@ -15,7 +16,13 @@ from dataset.functional import (
15
16
  )
16
17
 
17
18
 
18
- class DSFLPartitionedDataset(PartitionedDataset):
19
+ class DSFLPartitionType(StrEnum):
20
+ TRAIN = "train"
21
+ OPEN = "open"
22
+ TEST = "test"
23
+
24
+
25
+ class DSFLPartitionedDataset(PartitionedDataset[DSFLPartitionType]):
19
26
  def __init__(
20
27
  self,
21
28
  root: Path,
@@ -68,7 +75,7 @@ class DSFLPartitionedDataset(PartitionedDataset):
68
75
  train=False,
69
76
  download=True,
70
77
  )
71
- for type_ in ["train", "open", "test"]:
78
+ for type_ in [ds.value for ds in DSFLPartitionType]:
72
79
  self.path.joinpath(type_).mkdir(parents=True)
73
80
 
74
81
  match self.partition:
@@ -141,19 +148,19 @@ class DSFLPartitionedDataset(PartitionedDataset):
141
148
  self.path.joinpath("test", "default.pkl"),
142
149
  )
143
150
 
144
- def get_dataset(self, type_: str, cid: int | None) -> Dataset:
151
+ def get_dataset(self, type_: DSFLPartitionType, cid: int | None) -> Dataset:
145
152
  match type_:
146
- case "train":
153
+ case DSFLPartitionType.TRAIN:
147
154
  dataset = torch.load(
148
155
  self.path.joinpath(type_, f"{cid}.pkl"),
149
156
  weights_only=False,
150
157
  )
151
- case "open":
158
+ case DSFLPartitionType.OPEN:
152
159
  dataset = torch.load(
153
160
  self.path.joinpath(f"{type_}.pkl"),
154
161
  weights_only=False,
155
162
  )
156
- case "test":
163
+ case DSFLPartitionType.TEST:
157
164
  if cid is not None:
158
165
  dataset = torch.load(
159
166
  self.path.joinpath(type_, f"{cid}.pkl"),
@@ -163,13 +170,11 @@ class DSFLPartitionedDataset(PartitionedDataset):
163
170
  dataset = torch.load(
164
171
  self.path.joinpath(type_, "default.pkl"), weights_only=False
165
172
  )
166
- case _:
167
- raise ValueError(f"Invalid dataset type: {type_}")
168
173
  assert isinstance(dataset, Dataset)
169
174
  return dataset
170
175
 
171
176
  def get_dataloader(
172
- self, type_: str, cid: int | None, batch_size: int | None = None
177
+ self, type_: DSFLPartitionType, cid: int | None, batch_size: int | None = None
173
178
  ) -> DataLoader:
174
179
  dataset = self.get_dataset(type_, cid)
175
180
  assert isinstance(dataset, Sized)
@@ -120,9 +120,7 @@ def main(
120
120
  try:
121
121
  pipeline.main()
122
122
  except KeyboardInterrupt:
123
- logging.info("KeyboardInterrupt: Stopping the pipeline.")
124
- except Exception as e:
125
- logging.exception(f"An error occurred: {e}")
123
+ logging.info("KeyboardInterrupt")
126
124
 
127
125
 
128
126
  if __name__ == "__main__":
@@ -36,4 +36,5 @@ blazefl = { workspace = true }
36
36
  [dependency-groups]
37
37
  dev = [
38
38
  "mypy>=1.13.0",
39
+ "pre-commit>=4.2.0",
39
40
  ]
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "blazefl"
3
- version = "2.0.0.dev5"
3
+ version = "2.0.0.dev7"
4
4
  description = "A blazing-fast and lightweight simulation framework for Federated Learning."
5
5
  readme = "README.md"
6
6
  authors = [
@@ -9,6 +9,8 @@ from blazefl.contrib.fedavg import (
9
9
  FedAvgBaseClientTrainer,
10
10
  FedAvgBaseServerHandler,
11
11
  FedAvgDownlinkPackage,
12
+ FedAvgPartitionedDataset,
13
+ FedAvgPartitionType,
12
14
  FedAvgProcessPoolClientTrainer,
13
15
  FedAvgThreadPoolClientTrainer,
14
16
  FedAvgUplinkPackage,
@@ -21,4 +23,6 @@ __all__ = [
21
23
  "FedAvgThreadPoolClientTrainer",
22
24
  "FedAvgUplinkPackage",
23
25
  "FedAvgDownlinkPackage",
26
+ "FedAvgPartitionType",
27
+ "FedAvgPartitionedDataset",
24
28
  ]