blazefl 2.0.0.dev6.tar.gz → 2.0.0.dev7.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/PKG-INFO +1 -1
  2. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/config/config.yaml +1 -0
  3. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/main.py +20 -83
  4. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/pyproject.toml +1 -1
  5. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/uv.lock +4 -4
  6. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/dataset/dataset.py +6 -7
  7. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/pyproject.toml +1 -0
  8. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/algorithm/dsfl.py +23 -7
  9. blazefl-2.0.0.dev7/examples/step-by-step-dsfl/dataset/__init__.py +3 -0
  10. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/dataset/dataset.py +14 -9
  11. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/pyproject.toml +1 -0
  12. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/pyproject.toml +1 -1
  13. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/contrib/__init__.py +4 -0
  14. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/contrib/fedavg.py +19 -15
  15. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/core/partitioned_dataset.py +7 -4
  16. blazefl-2.0.0.dev7/src/blazefl/core/partitioned_dataset.pyi +9 -0
  17. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/uv.lock +11 -3
  18. blazefl-2.0.0.dev6/examples/step-by-step-dsfl/dataset/__init__.py +0 -3
  19. blazefl-2.0.0.dev6/src/blazefl/core/partitioned_dataset.pyi +0 -6
  20. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.github/FUNDING.yml +0 -0
  21. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  22. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
  23. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.github/dependabot.yml +0 -0
  24. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.github/workflows/ci.yaml +0 -0
  25. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.github/workflows/publish.yaml +0 -0
  26. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.github/workflows/sphinx.yaml +0 -0
  27. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.gitignore +0 -0
  28. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.pre-commit-config.yaml +0 -0
  29. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/.python-version +0 -0
  30. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/CODE_OF_CONDUCT.md +0 -0
  31. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/LICENSE +0 -0
  32. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/Makefile +0 -0
  33. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/README.md +0 -0
  34. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/Makefile +0 -0
  35. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/imgs/architecture.png +0 -0
  36. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/imgs/benchmark_cnn.png +0 -0
  37. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/imgs/benchmark_resnet18.png +0 -0
  38. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/imgs/logo.svg +0 -0
  39. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/imgs/logo_square.svg +0 -0
  40. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/imgs/ogp.png +0 -0
  41. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/make.bat +0 -0
  42. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/_static/favicon.ico +0 -0
  43. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/_static/logo.png +0 -0
  44. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/_static/logo_square.png +0 -0
  45. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/_templates/autosummary/class.rst +0 -0
  46. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/_templates/autosummary/module.rst +0 -0
  47. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/benchmark.rst +0 -0
  48. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/conf.py +0 -0
  49. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/contribute.rst +0 -0
  50. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/example.rst +0 -0
  51. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/index.rst +0 -0
  52. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/install.rst +0 -0
  53. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/overview.rst +0 -0
  54. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/docs/source/reference.rst +0 -0
  55. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/.gitignore +0 -0
  56. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/.python-version +0 -0
  57. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/Makefile +0 -0
  58. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/README.md +0 -0
  59. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/dataset/__init__.py +0 -0
  60. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/dataset/dataset.py +0 -0
  61. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/dataset/functional.py +0 -0
  62. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/models/__init__.py +0 -0
  63. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/experimental-freethreaded/models/selector.py +0 -0
  64. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/.gitignore +0 -0
  65. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/.python-version +0 -0
  66. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/Makefile +0 -0
  67. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/README.md +0 -0
  68. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/config/config.yaml +0 -0
  69. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/dataset/__init__.py +0 -0
  70. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/dataset/functional.py +0 -0
  71. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/main.py +0 -0
  72. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/models/__init__.py +0 -0
  73. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/quickstart-fedavg/models/selector.py +0 -0
  74. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/.gitignore +0 -0
  75. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/.python-version +0 -0
  76. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/Makefile +0 -0
  77. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/README.md +0 -0
  78. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/algorithm/__init__.py +0 -0
  79. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/config/algorithm/dsfl.yaml +0 -0
  80. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/config/config.yaml +0 -0
  81. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/dataset/functional.py +0 -0
  82. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/main.py +0 -0
  83. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/models/__init__.py +0 -0
  84. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/models/cnn.py +0 -0
  85. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/examples/step-by-step-dsfl/models/selector.py +0 -0
  86. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/__init__.py +0 -0
  87. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/core/__init__.py +0 -0
  88. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/core/__init__.pyi +0 -0
  89. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/core/client_trainer.py +0 -0
  90. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/core/client_trainer.pyi +0 -0
  91. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/core/model_selector.py +0 -0
  92. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/core/model_selector.pyi +0 -0
  93. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/core/server_handler.py +0 -0
  94. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/core/server_handler.pyi +0 -0
  95. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/py.typed +0 -0
  96. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/__init__.py +0 -0
  97. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/__init__.pyi +0 -0
  98. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/dataset.py +0 -0
  99. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/dataset.pyi +0 -0
  100. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/ipc.py +0 -0
  101. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/ipc.pyi +0 -0
  102. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/seed.py +0 -0
  103. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/seed.pyi +0 -0
  104. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/serialize.py +0 -0
  105. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/src/blazefl/utils/serialize.pyi +0 -0
  106. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/__init__.py +0 -0
  107. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/conftest.py +0 -0
  108. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_contrib/__init__.py +0 -0
  109. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_contrib/test_fedavg.py +0 -0
  110. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_core/__init__.py +0 -0
  111. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_core/test_client_trainer.py +0 -0
  112. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_core/test_model_selector.py +0 -0
  113. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_core/test_partitioned_dataset.py +0 -0
  114. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_utils/__init__.py +0 -0
  115. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_utils/test_dataset.py +0 -0
  116. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_utils/test_seed.py +0 -0
  117. {blazefl-2.0.0.dev6 → blazefl-2.0.0.dev7}/tests/test_utils/test_serialize.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: blazefl
3
- Version: 2.0.0.dev6
3
+ Version: 2.0.0.dev7
4
4
  Summary: A blazing-fast and lightweight simulation framework for Federated Learning.
5
5
  Author-email: kitsuyaazuma <kitsuyaazuma@gmail.com>
6
6
  License-File: LICENSE
@@ -15,3 +15,4 @@ dataset_split_dir: /tmp/experimental-freethreaded/split
15
15
  share_dir: /tmp/experimental-freethreaded/share
16
16
  state_dir: /tmp/experimental-freethreaded/state
17
17
  execution_mode: multi-thread
18
+ ipc_mode: storage
@@ -1,5 +1,4 @@
1
1
  import logging
2
- from copy import deepcopy
3
2
  from datetime import datetime
4
3
  from pathlib import Path
5
4
 
@@ -7,12 +6,13 @@ import hydra
7
6
  import torch
8
7
  import torch.multiprocessing as mp
9
8
  from blazefl.contrib import (
10
- FedAvgParallelClientTrainer,
11
- FedAvgSerialClientTrainer,
12
- FedAvgServerHandler,
9
+ FedAvgBaseClientTrainer,
10
+ FedAvgBaseServerHandler,
11
+ FedAvgProcessPoolClientTrainer,
12
+ )
13
+ from blazefl.contrib.fedavg import (
14
+ FedAvgThreadPoolClientTrainer,
13
15
  )
14
- from blazefl.contrib.fedavg import FedAvgDownlinkPackage, FedAvgUplinkPackage
15
- from blazefl.core import ModelSelector, MultiThreadClientTrainer, PartitionedDataset
16
16
  from blazefl.utils import seed_everything
17
17
  from omegaconf import DictConfig, OmegaConf
18
18
 
@@ -20,75 +20,13 @@ from dataset import PartitionedCIFAR10
20
20
  from models import FedAvgModelSelector
21
21
 
22
22
 
23
- class FedAvgMultiThreadClientTrainer(
24
- MultiThreadClientTrainer[
25
- FedAvgUplinkPackage,
26
- FedAvgDownlinkPackage,
27
- ]
28
- ):
29
- def __init__(
30
- self,
31
- model_selector: ModelSelector,
32
- model_name: str,
33
- dataset: PartitionedDataset,
34
- device: str,
35
- num_clients: int,
36
- epochs: int,
37
- batch_size: int,
38
- lr: float,
39
- seed: int,
40
- num_parallels: int,
41
- ) -> None:
42
- self.num_parallels = num_parallels
43
- self.device = device
44
- if self.device == "cuda":
45
- self.device_count = torch.cuda.device_count()
46
- self.cache: list[FedAvgUplinkPackage] = []
47
-
48
- self.model_selector = model_selector
49
- self.model_name = model_name
50
- self.dataset = dataset
51
- self.epochs = epochs
52
- self.batch_size = batch_size
53
- self.lr = lr
54
- self.num_clients = num_clients
55
- self.seed = seed
56
-
57
- def process_client(
58
- self,
59
- cid: int,
60
- device: str,
61
- payload: FedAvgDownlinkPackage,
62
- ) -> FedAvgUplinkPackage:
63
- model = self.model_selector.select_model(self.model_name)
64
- train_loader = self.dataset.get_dataloader(
65
- type_="train",
66
- cid=cid,
67
- batch_size=self.batch_size,
68
- )
69
- package = FedAvgParallelClientTrainer.train(
70
- model=model,
71
- model_parameters=payload.model_parameters,
72
- train_loader=train_loader,
73
- device=device,
74
- epochs=self.epochs,
75
- lr=self.lr,
76
- )
77
- return package
78
-
79
- def uplink_package(self) -> list[FedAvgUplinkPackage]:
80
- package = deepcopy(self.cache)
81
- self.cache = []
82
- return package
83
-
84
-
85
23
  class FedAvgPipeline:
86
24
  def __init__(
87
25
  self,
88
- handler: FedAvgServerHandler,
89
- trainer: FedAvgSerialClientTrainer
90
- | FedAvgParallelClientTrainer
91
- | FedAvgMultiThreadClientTrainer,
26
+ handler: FedAvgBaseServerHandler,
27
+ trainer: FedAvgBaseClientTrainer
28
+ | FedAvgProcessPoolClientTrainer
29
+ | FedAvgThreadPoolClientTrainer,
92
30
  ) -> None:
93
31
  self.handler = handler
94
32
  self.trainer = trainer
@@ -145,7 +83,7 @@ def main(cfg: DictConfig):
145
83
  )
146
84
  model_selector = FedAvgModelSelector(num_classes=10)
147
85
 
148
- handler = FedAvgServerHandler(
86
+ handler = FedAvgBaseServerHandler(
149
87
  model_selector=model_selector,
150
88
  model_name=cfg.model_name,
151
89
  dataset=dataset,
@@ -156,14 +94,14 @@ def main(cfg: DictConfig):
156
94
  batch_size=cfg.batch_size,
157
95
  )
158
96
  trainer: (
159
- FedAvgSerialClientTrainer
160
- | FedAvgParallelClientTrainer
161
- | FedAvgMultiThreadClientTrainer
97
+ FedAvgBaseClientTrainer
98
+ | FedAvgProcessPoolClientTrainer
99
+ | FedAvgThreadPoolClientTrainer
162
100
  | None
163
101
  ) = None
164
102
  match cfg.execution_mode:
165
- case "serial":
166
- trainer = FedAvgSerialClientTrainer(
103
+ case "single-thread":
104
+ trainer = FedAvgBaseClientTrainer(
167
105
  model_selector=model_selector,
168
106
  model_name=cfg.model_name,
169
107
  dataset=dataset,
@@ -174,7 +112,7 @@ def main(cfg: DictConfig):
174
112
  batch_size=cfg.batch_size,
175
113
  )
176
114
  case "multi-process":
177
- trainer = FedAvgParallelClientTrainer(
115
+ trainer = FedAvgProcessPoolClientTrainer(
178
116
  model_selector=model_selector,
179
117
  model_name=cfg.model_name,
180
118
  dataset=dataset,
@@ -187,9 +125,10 @@ def main(cfg: DictConfig):
187
125
  lr=cfg.lr,
188
126
  batch_size=cfg.batch_size,
189
127
  num_parallels=cfg.num_parallels,
128
+ ipc_mode=cfg.ipc_mode,
190
129
  )
191
130
  case "multi-thread":
192
- trainer = FedAvgMultiThreadClientTrainer(
131
+ trainer = FedAvgThreadPoolClientTrainer(
193
132
  model_selector=model_selector,
194
133
  model_name=cfg.model_name,
195
134
  dataset=dataset,
@@ -207,9 +146,7 @@ def main(cfg: DictConfig):
207
146
  try:
208
147
  pipeline.main()
209
148
  except KeyboardInterrupt:
210
- logging.info("KeyboardInterrupt: Stopping the pipeline.")
211
- except Exception as e:
212
- logging.exception(f"An error occurred: {e}")
149
+ logging.info("KeyboardInterrupt")
213
150
 
214
151
 
215
152
  if __name__ == "__main__":
@@ -5,7 +5,7 @@ description = "Add your description here"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.13"
7
7
  dependencies = [
8
- "blazefl>=2.0.0dev2",
8
+ "blazefl>=2.0.0dev6",
9
9
  "hydra-core>=1.3.2",
10
10
  "torch>=2.7.1",
11
11
  "torchvision>=0.22.1",
@@ -10,16 +10,16 @@ sdist = { url = "https://files.pythonhosted.org/packages/3e/38/7859ff46355f76f8d
10
10
 
11
11
  [[package]]
12
12
  name = "blazefl"
13
- version = "2.0.0.dev2"
13
+ version = "2.0.0.dev6"
14
14
  source = { registry = "https://pypi.org/simple" }
15
15
  dependencies = [
16
16
  { name = "numpy" },
17
17
  { name = "torch" },
18
18
  { name = "tqdm" },
19
19
  ]
20
- sdist = { url = "https://files.pythonhosted.org/packages/9b/92/3f668d7259a81d13d5d2e6089d8022629a09058ddbcae33f5873d3871656/blazefl-2.0.0.dev2.tar.gz", hash = "sha256:28a3e09d6f6cec8d8e8bb4da90adaa0c795980c74f9d4ce65024ba762726b595", size = 613563, upload_time = "2025-06-10T09:24:08.981Z" }
20
+ sdist = { url = "https://files.pythonhosted.org/packages/d7/76/d247942d051447dc1c6dbc25b2d55b4cdf146cc0ebd3594c0c33628f465c/blazefl-2.0.0.dev6.tar.gz", hash = "sha256:d4875f03872917dfd99b3a70ac2c6efc3483ee0e42e484d7328e0810fe5141b9", size = 615380, upload_time = "2025-06-16T16:26:50.929Z" }
21
21
  wheels = [
22
- { url = "https://files.pythonhosted.org/packages/e3/7c/841c8b9d22caaefc1dcc167076292fd1479309b3e11eb6fc1b24141792a8/blazefl-2.0.0.dev2-py3-none-any.whl", hash = "sha256:88639018166dafa79fde59aa5575bfb7277bf9e0e4497f22a59e9c7ec34efa4b", size = 23999, upload_time = "2025-06-10T09:24:07.553Z" },
22
+ { url = "https://files.pythonhosted.org/packages/c8/30/eba32caaed89d04c925f676d48fb50bad7bed3ce2c86c6ebcaf2eb92aff1/blazefl-2.0.0.dev6-py3-none-any.whl", hash = "sha256:e38c3739be9ef2bc7bb8ba213da3319e4be8f9ea9f27ecd661cbc70c0cedd655", size = 26215, upload_time = "2025-06-16T16:26:49.512Z" },
23
23
  ]
24
24
 
25
25
  [[package]]
@@ -50,7 +50,7 @@ dev = [
50
50
 
51
51
  [package.metadata]
52
52
  requires-dist = [
53
- { name = "blazefl", specifier = ">=2.0.0.dev2" },
53
+ { name = "blazefl", specifier = ">=2.0.0.dev6" },
54
54
  { name = "hydra-core", specifier = ">=1.3.2" },
55
55
  { name = "torch", specifier = ">=2.7.1" },
56
56
  { name = "torchvision", specifier = ">=0.22.1" },
@@ -3,6 +3,7 @@ from pathlib import Path
3
3
 
4
4
  import torch
5
5
  import torchvision
6
+ from blazefl.contrib import FedAvgPartitionType
6
7
  from blazefl.core import PartitionedDataset
7
8
  from blazefl.utils import FilteredDataset
8
9
  from torch.utils.data import DataLoader, Dataset
@@ -15,7 +16,7 @@ from dataset.functional import (
15
16
  )
16
17
 
17
18
 
18
- class PartitionedCIFAR10(PartitionedDataset):
19
+ class PartitionedCIFAR10(PartitionedDataset[FedAvgPartitionType]):
19
20
  def __init__(
20
21
  self,
21
22
  root: Path,
@@ -107,24 +108,22 @@ class PartitionedCIFAR10(PartitionedDataset):
107
108
  self.path.joinpath("test.pkl"),
108
109
  )
109
110
 
110
- def get_dataset(self, type_: str, cid: int | None) -> Dataset:
111
+ def get_dataset(self, type_: FedAvgPartitionType, cid: int | None) -> Dataset:
111
112
  match type_:
112
- case "train":
113
+ case FedAvgPartitionType.TRAIN:
113
114
  dataset = torch.load(
114
115
  self.path.joinpath(type_, f"{cid}.pkl"),
115
116
  weights_only=False,
116
117
  )
117
- case "test":
118
+ case FedAvgPartitionType.TEST:
118
119
  dataset = torch.load(
119
120
  self.path.joinpath(f"{type_}.pkl"), weights_only=False
120
121
  )
121
- case _:
122
- raise ValueError(f"Invalid dataset type: {type_}")
123
122
  assert isinstance(dataset, Dataset)
124
123
  return dataset
125
124
 
126
125
  def get_dataloader(
127
- self, type_: str, cid: int | None, batch_size: int | None = None
126
+ self, type_: FedAvgPartitionType, cid: int | None, batch_size: int | None = None
128
127
  ) -> DataLoader:
129
128
  dataset = self.get_dataset(type_, cid)
130
129
  assert isinstance(dataset, Sized)
@@ -32,4 +32,5 @@ blazefl = { workspace = true }
32
32
  [dependency-groups]
33
33
  dev = [
34
34
  "mypy>=1.13.0",
35
+ "pre-commit>=4.2.0",
35
36
  ]
@@ -1,10 +1,12 @@
1
1
  import random
2
+ import threading
2
3
  from collections import defaultdict
3
4
  from copy import deepcopy
4
5
  from dataclasses import dataclass
5
6
  from pathlib import Path
6
7
 
7
8
  import torch
9
+ import torch.multiprocessing as mp
8
10
  import torch.nn.functional as F
9
11
  from blazefl.core import (
10
12
  BaseServerHandler,
@@ -17,7 +19,7 @@ from blazefl.utils import (
17
19
  )
18
20
  from torch.utils.data import DataLoader, Subset
19
21
 
20
- from dataset import DSFLPartitionedDataset
22
+ from dataset import DSFLPartitionedDataset, DSFLPartitionType
21
23
  from models import DSFLModelSelector
22
24
 
23
25
 
@@ -125,6 +127,7 @@ class DSFLBaseServerHandler(BaseServerHandler[DSFLUplinkPackage, DSFLDownlinkPac
125
127
  self.kd_epochs,
126
128
  self.kd_batch_size,
127
129
  self.device,
130
+ stop_event=None,
128
131
  )
129
132
 
130
133
  self.global_soft_labels = torch.stack(global_soft_labels)
@@ -140,10 +143,11 @@ class DSFLBaseServerHandler(BaseServerHandler[DSFLUplinkPackage, DSFLDownlinkPac
140
143
  kd_epochs: int,
141
144
  kd_batch_size: int,
142
145
  device: str,
146
+ stop_event: threading.Event | None,
143
147
  ) -> None:
144
148
  model.to(device)
145
149
  model.train()
146
- open_dataset = dataset.get_dataset(type_="open", cid=None)
150
+ open_dataset = dataset.get_dataset(type_=DSFLPartitionType.OPEN, cid=None)
147
151
  open_loader = DataLoader(
148
152
  Subset(open_dataset, global_indices),
149
153
  batch_size=kd_batch_size,
@@ -156,6 +160,8 @@ class DSFLBaseServerHandler(BaseServerHandler[DSFLUplinkPackage, DSFLDownlinkPac
156
160
  batch_size=kd_batch_size,
157
161
  )
158
162
  for _ in range(kd_epochs):
163
+ if stop_event is not None and stop_event.is_set():
164
+ break
159
165
  for data, soft_label in zip(
160
166
  open_loader, global_soft_label_loader, strict=True
161
167
  ):
@@ -208,7 +214,7 @@ class DSFLBaseServerHandler(BaseServerHandler[DSFLUplinkPackage, DSFLDownlinkPac
208
214
  server_loss, server_acc = DSFLBaseServerHandler.evaulate(
209
215
  self.model,
210
216
  self.dataset.get_dataloader(
211
- type_="test",
217
+ type_=DSFLPartitionType.TEST,
212
218
  cid=None,
213
219
  batch_size=self.kd_batch_size,
214
220
  ),
@@ -300,6 +306,8 @@ class DSFLProcessPoolClientTrainer(
300
306
  self.num_clients = num_clients
301
307
  self.seed = seed
302
308
  self.ipc_mode = "storage"
309
+ self.manager = mp.Manager()
310
+ self.stop_event = self.manager.Event()
303
311
 
304
312
  if self.device == "cuda":
305
313
  self.device_count = torch.cuda.device_count()
@@ -309,6 +317,7 @@ class DSFLProcessPoolClientTrainer(
309
317
  config: DSFLClientConfig | Path,
310
318
  payload: DSFLDownlinkPackage | Path,
311
319
  device: str,
320
+ stop_event: threading.Event,
312
321
  ) -> Path:
313
322
  assert isinstance(config, Path) and isinstance(payload, Path)
314
323
  config_path, payload_path = config, payload
@@ -334,7 +343,7 @@ class DSFLProcessPoolClientTrainer(
334
343
  seed_everything(c.seed, device=device)
335
344
 
336
345
  # Distill
337
- open_dataset = c.dataset.get_dataset(type_="open", cid=None)
346
+ open_dataset = c.dataset.get_dataset(type_=DSFLPartitionType.OPEN, cid=None)
338
347
  if p.indices is not None and p.soft_labels is not None:
339
348
  global_soft_labels = list(torch.unbind(p.soft_labels, dim=0))
340
349
  global_indices = p.indices.tolist()
@@ -349,11 +358,12 @@ class DSFLProcessPoolClientTrainer(
349
358
  kd_epochs=c.kd_epochs,
350
359
  kd_batch_size=c.kd_batch_size,
351
360
  device=device,
361
+ stop_event=stop_event,
352
362
  )
353
363
 
354
364
  # Train
355
365
  train_loader = c.dataset.get_dataloader(
356
- type_="train",
366
+ type_=DSFLPartitionType.TRAIN,
357
367
  cid=c.cid,
358
368
  batch_size=c.batch_size,
359
369
  )
@@ -363,6 +373,7 @@ class DSFLProcessPoolClientTrainer(
363
373
  train_loader=train_loader,
364
374
  device=device,
365
375
  epochs=c.epochs,
376
+ stop_event=stop_event,
366
377
  )
367
378
 
368
379
  # Predict
@@ -378,7 +389,7 @@ class DSFLProcessPoolClientTrainer(
378
389
 
379
390
  # Evaluate
380
391
  test_loader = c.dataset.get_dataloader(
381
- type_="test",
392
+ type_=DSFLPartitionType.TEST,
382
393
  cid=c.cid,
383
394
  batch_size=c.batch_size,
384
395
  )
@@ -411,12 +422,15 @@ class DSFLProcessPoolClientTrainer(
411
422
  train_loader: DataLoader,
412
423
  device: str,
413
424
  epochs: int,
425
+ stop_event: threading.Event,
414
426
  ) -> None:
415
427
  model.to(device)
416
428
  model.train()
417
429
  criterion = torch.nn.CrossEntropyLoss()
418
430
 
419
431
  for _ in range(epochs):
432
+ if stop_event.is_set():
433
+ break
420
434
  for data, target in train_loader:
421
435
  data = data.to(device)
422
436
  target = target.to(device)
@@ -431,7 +445,9 @@ class DSFLProcessPoolClientTrainer(
431
445
 
432
446
  @staticmethod
433
447
  def predict(
434
- model: torch.nn.Module, open_loader: DataLoader, device: str
448
+ model: torch.nn.Module,
449
+ open_loader: DataLoader,
450
+ device: str,
435
451
  ) -> torch.Tensor:
436
452
  model.to(device)
437
453
  model.eval()
@@ -0,0 +1,3 @@
1
+ from dataset.dataset import DSFLPartitionedDataset, DSFLPartitionType
2
+
3
+ __all__ = ["DSFLPartitionedDataset", "DSFLPartitionType"]
@@ -1,4 +1,5 @@
1
1
  from collections.abc import Sized
2
+ from enum import StrEnum
2
3
  from pathlib import Path
3
4
 
4
5
  import numpy as np
@@ -15,7 +16,13 @@ from dataset.functional import (
15
16
  )
16
17
 
17
18
 
18
- class DSFLPartitionedDataset(PartitionedDataset):
19
+ class DSFLPartitionType(StrEnum):
20
+ TRAIN = "train"
21
+ OPEN = "open"
22
+ TEST = "test"
23
+
24
+
25
+ class DSFLPartitionedDataset(PartitionedDataset[DSFLPartitionType]):
19
26
  def __init__(
20
27
  self,
21
28
  root: Path,
@@ -68,7 +75,7 @@ class DSFLPartitionedDataset(PartitionedDataset):
68
75
  train=False,
69
76
  download=True,
70
77
  )
71
- for type_ in ["train", "open", "test"]:
78
+ for type_ in [ds.value for ds in DSFLPartitionType]:
72
79
  self.path.joinpath(type_).mkdir(parents=True)
73
80
 
74
81
  match self.partition:
@@ -141,19 +148,19 @@ class DSFLPartitionedDataset(PartitionedDataset):
141
148
  self.path.joinpath("test", "default.pkl"),
142
149
  )
143
150
 
144
- def get_dataset(self, type_: str, cid: int | None) -> Dataset:
151
+ def get_dataset(self, type_: DSFLPartitionType, cid: int | None) -> Dataset:
145
152
  match type_:
146
- case "train":
153
+ case DSFLPartitionType.TRAIN:
147
154
  dataset = torch.load(
148
155
  self.path.joinpath(type_, f"{cid}.pkl"),
149
156
  weights_only=False,
150
157
  )
151
- case "open":
158
+ case DSFLPartitionType.OPEN:
152
159
  dataset = torch.load(
153
160
  self.path.joinpath(f"{type_}.pkl"),
154
161
  weights_only=False,
155
162
  )
156
- case "test":
163
+ case DSFLPartitionType.TEST:
157
164
  if cid is not None:
158
165
  dataset = torch.load(
159
166
  self.path.joinpath(type_, f"{cid}.pkl"),
@@ -163,13 +170,11 @@ class DSFLPartitionedDataset(PartitionedDataset):
163
170
  dataset = torch.load(
164
171
  self.path.joinpath(type_, "default.pkl"), weights_only=False
165
172
  )
166
- case _:
167
- raise ValueError(f"Invalid dataset type: {type_}")
168
173
  assert isinstance(dataset, Dataset)
169
174
  return dataset
170
175
 
171
176
  def get_dataloader(
172
- self, type_: str, cid: int | None, batch_size: int | None = None
177
+ self, type_: DSFLPartitionType, cid: int | None, batch_size: int | None = None
173
178
  ) -> DataLoader:
174
179
  dataset = self.get_dataset(type_, cid)
175
180
  assert isinstance(dataset, Sized)
@@ -36,4 +36,5 @@ blazefl = { workspace = true }
36
36
  [dependency-groups]
37
37
  dev = [
38
38
  "mypy>=1.13.0",
39
+ "pre-commit>=4.2.0",
39
40
  ]
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "blazefl"
3
- version = "2.0.0.dev6"
3
+ version = "2.0.0.dev7"
4
4
  description = "A blazing-fast and lightweight simulation framework for Federated Learning."
5
5
  readme = "README.md"
6
6
  authors = [
@@ -9,6 +9,8 @@ from blazefl.contrib.fedavg import (
9
9
  FedAvgBaseClientTrainer,
10
10
  FedAvgBaseServerHandler,
11
11
  FedAvgDownlinkPackage,
12
+ FedAvgPartitionedDataset,
13
+ FedAvgPartitionType,
12
14
  FedAvgProcessPoolClientTrainer,
13
15
  FedAvgThreadPoolClientTrainer,
14
16
  FedAvgUplinkPackage,
@@ -21,4 +23,6 @@ __all__ = [
21
23
  "FedAvgThreadPoolClientTrainer",
22
24
  "FedAvgUplinkPackage",
23
25
  "FedAvgDownlinkPackage",
26
+ "FedAvgPartitionType",
27
+ "FedAvgPartitionedDataset",
24
28
  ]
@@ -2,6 +2,7 @@ import random
2
2
  import threading
3
3
  from copy import deepcopy
4
4
  from dataclasses import dataclass
5
+ from enum import StrEnum
5
6
  from pathlib import Path
6
7
  from typing import Literal
7
8
 
@@ -57,8 +58,16 @@ class FedAvgDownlinkPackage:
57
58
  model_parameters: torch.Tensor
58
59
 
59
60
 
61
+ class FedAvgPartitionType(StrEnum):
62
+ TRAIN = "train"
63
+ TEST = "test"
64
+
65
+
66
+ FedAvgPartitionedDataset = PartitionedDataset[FedAvgPartitionType]
67
+
68
+
60
69
  class FedAvgBaseServerHandler(
61
- BaseServerHandler[FedAvgUplinkPackage, FedAvgDownlinkPackage]
70
+ BaseServerHandler[FedAvgUplinkPackage, FedAvgDownlinkPackage],
62
71
  ):
63
72
  """
64
73
  Server-side handler for the Federated Averaging (FedAvg) algorithm.
@@ -83,7 +92,7 @@ class FedAvgBaseServerHandler(
83
92
  self,
84
93
  model_selector: ModelSelector,
85
94
  model_name: str,
86
- dataset: PartitionedDataset,
95
+ dataset: FedAvgPartitionedDataset,
87
96
  global_round: int,
88
97
  num_clients: int,
89
98
  sample_ratio: float,
@@ -241,7 +250,7 @@ class FedAvgBaseServerHandler(
241
250
  server_loss, server_acc = FedAvgBaseServerHandler.evaluate(
242
251
  self.model,
243
252
  self.dataset.get_dataloader(
244
- type_="test",
253
+ type_=FedAvgPartitionType.TEST,
245
254
  cid=None,
246
255
  batch_size=self.batch_size,
247
256
  ),
@@ -289,7 +298,7 @@ class FedAvgBaseClientTrainer(
289
298
  self,
290
299
  model_selector: ModelSelector,
291
300
  model_name: str,
292
- dataset: PartitionedDataset,
301
+ dataset: FedAvgPartitionedDataset,
293
302
  device: str,
294
303
  num_clients: int,
295
304
  epochs: int,
@@ -339,14 +348,9 @@ class FedAvgBaseClientTrainer(
339
348
  model_parameters = payload.model_parameters
340
349
  for cid in tqdm(cid_list, desc="Client", leave=False):
341
350
  data_loader = self.dataset.get_dataloader(
342
- type_="train", cid=cid, batch_size=self.batch_size
351
+ type_=FedAvgPartitionType.TRAIN, cid=cid, batch_size=self.batch_size
343
352
  )
344
353
  pack = self.train(model_parameters, data_loader)
345
- val_loader = self.dataset.get_dataloader(
346
- type_="val", cid=cid, batch_size=self.batch_size
347
- )
348
- loss, acc = self.evaluate(val_loader)
349
- pack.metadata = {"loss": loss, "acc": acc}
350
354
  self.cache.append(pack)
351
355
 
352
356
  def train(
@@ -457,7 +461,7 @@ class FedAvgClientConfig:
457
461
 
458
462
  model_selector: ModelSelector
459
463
  model_name: str
460
- dataset: PartitionedDataset
464
+ dataset: FedAvgPartitionedDataset
461
465
  epochs: int
462
466
  batch_size: int
463
467
  lr: float
@@ -501,7 +505,7 @@ class FedAvgProcessPoolClientTrainer(
501
505
  model_name: str,
502
506
  share_dir: Path,
503
507
  state_dir: Path,
504
- dataset: PartitionedDataset,
508
+ dataset: FedAvgPartitionedDataset,
505
509
  device: str,
506
510
  num_clients: int,
507
511
  epochs: int,
@@ -614,7 +618,7 @@ class FedAvgProcessPoolClientTrainer(
614
618
 
615
619
  model = config.model_selector.select_model(config.model_name)
616
620
  train_loader = config.dataset.get_dataloader(
617
- type_="train",
621
+ type_=FedAvgPartitionType.TRAIN,
618
622
  cid=config.cid,
619
623
  batch_size=config.batch_size,
620
624
  )
@@ -739,7 +743,7 @@ class FedAvgThreadPoolClientTrainer(
739
743
  self,
740
744
  model_selector: ModelSelector,
741
745
  model_name: str,
742
- dataset: PartitionedDataset,
746
+ dataset: FedAvgPartitionedDataset,
743
747
  device: str,
744
748
  num_clients: int,
745
749
  epochs: int,
@@ -773,7 +777,7 @@ class FedAvgThreadPoolClientTrainer(
773
777
  ) -> FedAvgUplinkPackage:
774
778
  model = self.model_selector.select_model(self.model_name)
775
779
  train_loader = self.dataset.get_dataloader(
776
- type_="train",
780
+ type_=FedAvgPartitionType.TRAIN,
777
781
  cid=cid,
778
782
  batch_size=self.batch_size,
779
783
  )
@@ -1,9 +1,12 @@
1
- from typing import Protocol
1
+ from enum import StrEnum
2
+ from typing import Protocol, TypeVar
2
3
 
3
4
  from torch.utils.data import DataLoader, Dataset
4
5
 
6
+ PartitionType = TypeVar("PartitionType", bound=StrEnum, contravariant=True)
5
7
 
6
- class PartitionedDataset(Protocol):
8
+
9
+ class PartitionedDataset(Protocol[PartitionType]):
7
10
  """
8
11
  Abstract base class for partitioned datasets in federated learning.
9
12
 
@@ -14,7 +17,7 @@ class PartitionedDataset(Protocol):
14
17
  NotImplementedError: If the methods are not implemented in a subclass.
15
18
  """
16
19
 
17
- def get_dataset(self, type_: str, cid: int | None) -> Dataset:
20
+ def get_dataset(self, type_: PartitionType, cid: int | None) -> Dataset:
18
21
  """
19
22
  Retrieve a dataset for a specific type and client ID.
20
23
 
@@ -28,7 +31,7 @@ class PartitionedDataset(Protocol):
28
31
  ...
29
32
 
30
33
  def get_dataloader(
31
- self, type_: str, cid: int | None, batch_size: int | None
34
+ self, type_: PartitionType, cid: int | None, batch_size: int | None
32
35
  ) -> DataLoader:
33
36
  """
34
37
  Retrieve a DataLoader for a specific type, client ID, and batch size.
@@ -0,0 +1,9 @@
1
+ from enum import StrEnum
2
+ from torch.utils.data import DataLoader, Dataset
3
+ from typing import Protocol, TypeVar
4
+
5
+ PartitionType = TypeVar('PartitionType', bound=StrEnum, contravariant=True)
6
+
7
+ class PartitionedDataset(Protocol[PartitionType]):
8
+ def get_dataset(self, type_: PartitionType, cid: int | None) -> Dataset: ...
9
+ def get_dataloader(self, type_: PartitionType, cid: int | None, batch_size: int | None) -> DataLoader: ...
@@ -82,7 +82,7 @@ wheels = [
82
82
 
83
83
  [[package]]
84
84
  name = "blazefl"
85
- version = "2.0.0.dev6"
85
+ version = "2.0.0.dev7"
86
86
  source = { editable = "." }
87
87
  dependencies = [
88
88
  { name = "numpy" },
@@ -828,6 +828,7 @@ dependencies = [
828
828
  [package.dev-dependencies]
829
829
  dev = [
830
830
  { name = "mypy" },
831
+ { name = "pre-commit" },
831
832
  ]
832
833
 
833
834
  [package.metadata]
@@ -839,7 +840,10 @@ requires-dist = [
839
840
  ]
840
841
 
841
842
  [package.metadata.requires-dev]
842
- dev = [{ name = "mypy", specifier = ">=1.13.0" }]
843
+ dev = [
844
+ { name = "mypy", specifier = ">=1.13.0" },
845
+ { name = "pre-commit", specifier = ">=4.2.0" },
846
+ ]
843
847
 
844
848
  [[package]]
845
849
  name = "requests"
@@ -1063,6 +1067,7 @@ dependencies = [
1063
1067
  [package.dev-dependencies]
1064
1068
  dev = [
1065
1069
  { name = "mypy" },
1070
+ { name = "pre-commit" },
1066
1071
  ]
1067
1072
 
1068
1073
  [package.metadata]
@@ -1074,7 +1079,10 @@ requires-dist = [
1074
1079
  ]
1075
1080
 
1076
1081
  [package.metadata.requires-dev]
1077
- dev = [{ name = "mypy", specifier = ">=1.13.0" }]
1082
+ dev = [
1083
+ { name = "mypy", specifier = ">=1.13.0" },
1084
+ { name = "pre-commit", specifier = ">=4.2.0" },
1085
+ ]
1078
1086
 
1079
1087
  [[package]]
1080
1088
  name = "sympy"
@@ -1,3 +0,0 @@
1
- from dataset.dataset import DSFLPartitionedDataset
2
-
3
- __all__ = ["DSFLPartitionedDataset"]
@@ -1,6 +0,0 @@
1
- from torch.utils.data import DataLoader, Dataset
2
- from typing import Protocol
3
-
4
- class PartitionedDataset(Protocol):
5
- def get_dataset(self, type_: str, cid: int | None) -> Dataset: ...
6
- def get_dataloader(self, type_: str, cid: int | None, batch_size: int | None) -> DataLoader: ...
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes