blazefl 2.0.0.dev4__tar.gz → 2.0.0.dev5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. blazefl-2.0.0.dev5/.python-version +1 -0
  2. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/PKG-INFO +1 -1
  3. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/pyproject.toml +1 -1
  4. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/contrib/__init__.py +6 -0
  5. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/contrib/fedavg.py +69 -0
  6. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/client_trainer.py +31 -17
  7. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/client_trainer.pyi +2 -0
  8. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_contrib/test_fedavg.py +124 -4
  9. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/uv.lock +1 -1
  10. blazefl-2.0.0.dev4/examples/step-by-step-dsfl/.python-version +0 -1
  11. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/.github/FUNDING.yml +0 -0
  12. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  13. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
  14. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/.github/dependabot.yml +0 -0
  15. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/.github/workflows/ci.yaml +0 -0
  16. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/.github/workflows/publish.yaml +0 -0
  17. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/.github/workflows/sphinx.yaml +0 -0
  18. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/.gitignore +0 -0
  19. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/.pre-commit-config.yaml +0 -0
  20. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/CODE_OF_CONDUCT.md +0 -0
  21. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/LICENSE +0 -0
  22. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/Makefile +0 -0
  23. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/README.md +0 -0
  24. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/Makefile +0 -0
  25. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/imgs/architecture.png +0 -0
  26. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/imgs/benchmark_cnn.png +0 -0
  27. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/imgs/benchmark_resnet18.png +0 -0
  28. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/imgs/logo.svg +0 -0
  29. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/imgs/logo_square.svg +0 -0
  30. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/imgs/ogp.png +0 -0
  31. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/make.bat +0 -0
  32. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/_static/favicon.ico +0 -0
  33. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/_static/logo.png +0 -0
  34. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/_static/logo_square.png +0 -0
  35. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/_templates/autosummary/class.rst +0 -0
  36. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/_templates/autosummary/module.rst +0 -0
  37. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/benchmark.rst +0 -0
  38. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/conf.py +0 -0
  39. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/contribute.rst +0 -0
  40. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/example.rst +0 -0
  41. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/index.rst +0 -0
  42. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/install.rst +0 -0
  43. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/overview.rst +0 -0
  44. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/docs/source/reference.rst +0 -0
  45. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/.gitignore +0 -0
  46. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/.python-version +0 -0
  47. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/Makefile +0 -0
  48. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/README.md +0 -0
  49. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/config/config.yaml +0 -0
  50. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/dataset/__init__.py +0 -0
  51. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/dataset/dataset.py +0 -0
  52. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/dataset/functional.py +0 -0
  53. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/main.py +0 -0
  54. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/models/__init__.py +0 -0
  55. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/models/selector.py +0 -0
  56. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/pyproject.toml +0 -0
  57. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/experimental-freethreaded/uv.lock +0 -0
  58. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/.gitignore +0 -0
  59. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5/examples/quickstart-fedavg}/.python-version +0 -0
  60. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/Makefile +0 -0
  61. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/README.md +0 -0
  62. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/config/config.yaml +0 -0
  63. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/dataset/__init__.py +0 -0
  64. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/dataset/dataset.py +0 -0
  65. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/dataset/functional.py +0 -0
  66. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/main.py +0 -0
  67. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/models/__init__.py +0 -0
  68. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/models/selector.py +0 -0
  69. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/quickstart-fedavg/pyproject.toml +0 -0
  70. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/.gitignore +0 -0
  71. {blazefl-2.0.0.dev4/examples/quickstart-fedavg → blazefl-2.0.0.dev5/examples/step-by-step-dsfl}/.python-version +0 -0
  72. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/Makefile +0 -0
  73. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/README.md +0 -0
  74. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/algorithm/__init__.py +0 -0
  75. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/algorithm/dsfl.py +0 -0
  76. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/config/algorithm/dsfl.yaml +0 -0
  77. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/config/config.yaml +0 -0
  78. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/dataset/__init__.py +0 -0
  79. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/dataset/dataset.py +0 -0
  80. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/dataset/functional.py +0 -0
  81. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/main.py +0 -0
  82. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/models/__init__.py +0 -0
  83. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/models/cnn.py +0 -0
  84. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/models/selector.py +0 -0
  85. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/examples/step-by-step-dsfl/pyproject.toml +0 -0
  86. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/__init__.py +0 -0
  87. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/__init__.py +0 -0
  88. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/__init__.pyi +0 -0
  89. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/model_selector.py +0 -0
  90. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/model_selector.pyi +0 -0
  91. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/partitioned_dataset.py +0 -0
  92. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/partitioned_dataset.pyi +0 -0
  93. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/server_handler.py +0 -0
  94. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/core/server_handler.pyi +0 -0
  95. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/py.typed +0 -0
  96. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/__init__.py +0 -0
  97. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/__init__.pyi +0 -0
  98. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/dataset.py +0 -0
  99. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/dataset.pyi +0 -0
  100. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/ipc.py +0 -0
  101. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/ipc.pyi +0 -0
  102. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/seed.py +0 -0
  103. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/seed.pyi +0 -0
  104. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/serialize.py +0 -0
  105. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/src/blazefl/utils/serialize.pyi +0 -0
  106. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/__init__.py +0 -0
  107. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/conftest.py +0 -0
  108. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_contrib/__init__.py +0 -0
  109. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_core/__init__.py +0 -0
  110. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_core/test_client_trainer.py +0 -0
  111. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_core/test_model_selector.py +0 -0
  112. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_core/test_partitioned_dataset.py +0 -0
  113. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_utils/__init__.py +0 -0
  114. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_utils/test_dataset.py +0 -0
  115. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_utils/test_seed.py +0 -0
  116. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev5}/tests/test_utils/test_serialize.py +0 -0
@@ -0,0 +1 @@
1
+ 3.13
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: blazefl
3
- Version: 2.0.0.dev4
3
+ Version: 2.0.0.dev5
4
4
  Summary: A blazing-fast and lightweight simulation framework for Federated Learning.
5
5
  Author-email: kitsuyaazuma <kitsuyaazuma@gmail.com>
6
6
  License-File: LICENSE
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "blazefl"
3
- version = "2.0.0.dev4"
3
+ version = "2.0.0.dev5"
4
4
  description = "A blazing-fast and lightweight simulation framework for Federated Learning."
5
5
  readme = "README.md"
6
6
  authors = [
@@ -8,11 +8,17 @@ extending the core functionalities of BlazeFL.
8
8
  from blazefl.contrib.fedavg import (
9
9
  FedAvgBaseClientTrainer,
10
10
  FedAvgBaseServerHandler,
11
+ FedAvgDownlinkPackage,
11
12
  FedAvgProcessPoolClientTrainer,
13
+ FedAvgThreadPoolClientTrainer,
14
+ FedAvgUplinkPackage,
12
15
  )
13
16
 
14
17
  __all__ = [
15
18
  "FedAvgBaseServerHandler",
16
19
  "FedAvgProcessPoolClientTrainer",
17
20
  "FedAvgBaseClientTrainer",
21
+ "FedAvgThreadPoolClientTrainer",
22
+ "FedAvgUplinkPackage",
23
+ "FedAvgDownlinkPackage",
18
24
  ]
@@ -1,4 +1,5 @@
1
1
  import random
2
+ import threading
2
3
  from copy import deepcopy
3
4
  from dataclasses import dataclass
4
5
  from pathlib import Path
@@ -14,6 +15,7 @@ from blazefl.core import (
14
15
  ModelSelector,
15
16
  PartitionedDataset,
16
17
  ProcessPoolClientTrainer,
18
+ ThreadPoolClientTrainer,
17
19
  )
18
20
  from blazefl.utils import (
19
21
  RandomState,
@@ -640,6 +642,7 @@ class FedAvgProcessPoolClientTrainer(
640
642
  device: str,
641
643
  epochs: int,
642
644
  lr: float,
645
+ stop_event: threading.Event | None = None,
643
646
  ) -> FedAvgUplinkPackage:
644
647
  """
645
648
  Train the model with the given training data loader.
@@ -664,6 +667,8 @@ class FedAvgProcessPoolClientTrainer(
664
667
 
665
668
  data_size = 0
666
669
  for _ in range(epochs):
670
+ if stop_event is not None and stop_event.is_set():
671
+ break
667
672
  for data, target in train_loader:
668
673
  data = data.to(device)
669
674
  target = target.to(device)
@@ -714,3 +719,67 @@ class FedAvgProcessPoolClientTrainer(
714
719
  package = deepcopy(self.cache)
715
720
  self.cache = []
716
721
  return package
722
+
723
+
724
class FedAvgThreadPoolClientTrainer(
    ThreadPoolClientTrainer[
        FedAvgUplinkPackage,
        FedAvgDownlinkPackage,
    ]
):
    """
    FedAvg client trainer that trains sampled clients on a thread pool.

    Each client is handled by ``worker``, which delegates the local training
    loop to ``FedAvgProcessPoolClientTrainer.train``. Finished uplink packages
    accumulate in ``self.cache`` and are handed out (and cleared) by
    ``uplink_package``.
    """

    def __init__(
        self,
        model_selector: ModelSelector,
        model_name: str,
        dataset: PartitionedDataset,
        device: str,
        num_clients: int,
        epochs: int,
        batch_size: int,
        lr: float,
        seed: int,
        num_parallels: int,
    ) -> None:
        """
        Store the training configuration.

        Args:
            model_selector: Used by ``worker`` to build the model by name.
            model_name: Name passed to ``model_selector.select_model``.
            dataset: Source of each client's training dataloader.
            device: Device string; when ``"cuda"``, the visible GPU count is
                recorded in ``self.device_count``.
            num_clients: Number of clients (stored as ``self.num_clients``).
            epochs: Local epochs forwarded to ``train``.
            batch_size: Batch size for the client training dataloaders.
            lr: Learning rate forwarded to ``train``.
            seed: Stored as ``self.seed``.
            num_parallels: Thread-pool size used by ``local_process``.
        """
        # Model / data configuration consumed by ``worker``.
        self.model_selector = model_selector
        self.model_name = model_name
        self.dataset = dataset
        self.epochs = epochs
        self.batch_size = batch_size
        self.lr = lr
        self.num_clients = num_clients
        # NOTE(review): ``seed`` is stored but not read by ``worker`` in this
        # class, so no per-client seeding happens here — confirm intended.
        self.seed = seed

        # Execution state used by ``ThreadPoolClientTrainer.local_process``.
        self.num_parallels = num_parallels
        self.device = device
        if device == "cuda":
            self.device_count = torch.cuda.device_count()
        self.cache: list[FedAvgUplinkPackage] = []
        # Created lazily by ``local_process``; ``None`` until the first round.
        self.stop_event = None

    def worker(
        self,
        cid: int,
        device: str,
        payload: FedAvgDownlinkPackage,
    ) -> FedAvgUplinkPackage:
        """
        Run local training for a single client.

        Args:
            cid: Client ID whose training split is loaded.
            device: Device the model is trained on.
            payload: Downlink package carrying the global model parameters.

        Returns:
            The uplink package produced by
            ``FedAvgProcessPoolClientTrainer.train``.
        """
        net = self.model_selector.select_model(self.model_name)
        loader = self.dataset.get_dataloader(
            type_="train",
            cid=cid,
            batch_size=self.batch_size,
        )
        # Forwarding ``stop_event`` lets an interrupted ``local_process``
        # end this client's training early.
        return FedAvgProcessPoolClientTrainer.train(
            model=net,
            model_parameters=payload.model_parameters,
            train_loader=loader,
            device=device,
            epochs=self.epochs,
            lr=self.lr,
            stop_event=self.stop_event,
        )

    def uplink_package(self) -> list[FedAvgUplinkPackage]:
        """Return a deep copy of the cached uplink packages and empty the cache."""
        drained, self.cache = deepcopy(self.cache), []
        return drained
@@ -1,5 +1,7 @@
1
+ import logging
1
2
  import multiprocessing as mp
2
3
  import signal
4
+ import threading
3
5
  from concurrent.futures import ThreadPoolExecutor, as_completed
4
6
  from multiprocessing.pool import ApplyResult
5
7
  from pathlib import Path
@@ -197,6 +199,7 @@ class ThreadPoolClientTrainer(
197
199
  device: str
198
200
  device_count: int
199
201
  cache: list[UplinkPackage]
202
+ stop_event: threading.Event | None
200
203
 
201
204
  def worker(
202
205
  self,
@@ -223,20 +226,31 @@ class ThreadPoolClientTrainer(
223
226
  return self.device
224
227
 
225
228
  def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None:
226
- with ThreadPoolExecutor(max_workers=self.num_parallels) as executor:
227
- futures = []
228
- for cid in cid_list:
229
- device = self.get_client_device(cid)
230
- future = executor.submit(
231
- self.worker,
232
- cid,
233
- device,
234
- payload,
235
- )
236
- futures.append(future)
237
-
238
- for future in tqdm(
239
- as_completed(futures), total=len(futures), desc="Client", leave=False
240
- ):
241
- result = future.result()
242
- self.cache.append(result)
229
+ if self.stop_event is None:
230
+ self.stop_event = threading.Event()
231
+ self.stop_event.clear()
232
+ try:
233
+ with ThreadPoolExecutor(max_workers=self.num_parallels) as executor:
234
+ futures = []
235
+ for cid in cid_list:
236
+ device = self.get_client_device(cid)
237
+ future = executor.submit(
238
+ self.worker,
239
+ cid,
240
+ device,
241
+ payload,
242
+ )
243
+ futures.append(future)
244
+
245
+ for future in tqdm(
246
+ as_completed(futures),
247
+ total=len(futures),
248
+ desc="Client",
249
+ leave=False,
250
+ ):
251
+ result = future.result()
252
+ self.cache.append(result)
253
+ except KeyboardInterrupt:
254
+ logging.warning("Training interrupted by user.")
255
+ self.stop_event.set()
256
+ return
@@ -1,3 +1,4 @@
1
+ import threading
1
2
  from blazefl.utils import move_tensor_to_shared_memory as move_tensor_to_shared_memory
2
3
  from multiprocessing.pool import ApplyResult as ApplyResult
3
4
  from pathlib import Path
@@ -29,6 +30,7 @@ class ThreadPoolClientTrainer(BaseClientTrainer[UplinkPackage, DownlinkPackage],
29
30
  device: str
30
31
  device_count: int
31
32
  cache: list[UplinkPackage]
33
+ stop_event: threading.Event | None
32
34
  def worker(self, cid: int, device: str, payload: DownlinkPackage) -> UplinkPackage: ...
33
35
  def get_client_device(self, cid: int) -> str: ...
34
36
  def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None: ...
@@ -13,6 +13,7 @@ from src.blazefl.contrib.fedavg import (
13
13
  FedAvgBaseClientTrainer,
14
14
  FedAvgBaseServerHandler,
15
15
  FedAvgProcessPoolClientTrainer,
16
+ FedAvgThreadPoolClientTrainer,
16
17
  )
17
18
  from src.blazefl.core import ModelSelector, PartitionedDataset
18
19
 
@@ -86,7 +87,9 @@ def tmp_state_dir(tmp_path):
86
87
  return state_dir
87
88
 
88
89
 
89
- def test_server_and_base_integration(model_selector, partitioned_dataset, device):
90
+ def test_base_server_and_base_trainer_integration(
91
+ model_selector, partitioned_dataset, device
92
+ ):
90
93
  model_name = "dummy"
91
94
  global_round = 1
92
95
  num_clients = 3
@@ -134,7 +137,7 @@ def test_server_and_base_integration(model_selector, partitioned_dataset, device
134
137
 
135
138
 
136
139
  @pytest.mark.parametrize("ipc_mode", ["storage", "shared_memory"])
137
- def test_server_and_process_pool_integration(
140
+ def test_base_handler_and_process_pool_trainer_integration(
138
141
  model_selector, partitioned_dataset, device, tmp_share_dir, tmp_state_dir, ipc_mode
139
142
  ):
140
143
  model_name = "dummy"
@@ -195,7 +198,7 @@ def run_local_process(trainer, downlink, cids):
195
198
  trainer.local_process(downlink, cids)
196
199
 
197
200
 
198
- def test_server_and_process_pool_integration_keyboard_interrupt(
201
+ def test_base_handler_and_process_pool_trainer_integration_keyboard_interrupt(
199
202
  model_selector, partitioned_dataset, device, tmp_share_dir, tmp_state_dir
200
203
  ):
201
204
  model_name = "dummy"
@@ -252,7 +255,9 @@ def test_server_and_process_pool_integration_keyboard_interrupt(
252
255
  if len(spawned_pids) == num_parallels:
253
256
  break
254
257
  if time.time() - start_time > timeout:
255
- raise AssertionError("Timeout reached while waiting for spawned processes.")
258
+ pytest.fail(
259
+ f"Process did not spawn {len(spawned_pids)} processes within {timeout}s"
260
+ )
256
261
  assert proc.is_alive()
257
262
 
258
263
  os.kill(proc.pid, signal.SIGINT)
@@ -265,3 +270,118 @@ def test_server_and_process_pool_integration_keyboard_interrupt(
265
270
  if psutil.pid_exists(pid):
266
271
  orphan_pids.append(pid)
267
272
  assert len(orphan_pids) == 0
273
+
274
+
275
def test_base_handler_and_thread_pool_trainer_integration(
    model_selector, partitioned_dataset, device
):
    """
    Drive two FedAvg rounds end to end with the thread-pool trainer.

    Every round must yield one uplink per client, ``load`` must report
    completion after the last uplink, and ``server.round`` must track the
    loop counter. After the final round the server must report it is done.
    """
    model_name = "dummy"
    num_clients = 3
    batch_size = 2
    global_round = 2

    server = FedAvgBaseServerHandler(
        model_selector=model_selector,
        model_name=model_name,
        dataset=partitioned_dataset,
        global_round=global_round,
        num_clients=num_clients,
        sample_ratio=1.0,
        device=device,
        batch_size=batch_size,
    )
    trainer = FedAvgThreadPoolClientTrainer(
        model_selector=model_selector,
        model_name=model_name,
        dataset=partitioned_dataset,
        device=device,
        num_clients=num_clients,
        epochs=1,
        batch_size=batch_size,
        lr=0.01,
        seed=42,
        num_parallels=2,
    )

    for expected_round in range(1, global_round + 1):
        sampled = server.sample_clients()
        trainer.local_process(server.downlink_package(), sampled)

        packages = trainer.uplink_package()
        assert len(packages) == num_clients

        aggregated = False
        for package in packages:
            aggregated = server.load(package)
        assert aggregated is True
        assert server.round == expected_round

    assert server.if_stop() is True
326
+
327
+
328
def test_base_handler_and_thread_pool_trainer_integration_keyboard_interrupt(
    model_selector,
    partitioned_dataset,
    device,
):
    """
    SIGINT during thread-pool ``local_process`` must make the process exit.

    ``local_process`` runs in a child ``Process`` with a very large epoch
    count, so it cannot finish on its own. Once the child shows at least
    ``num_parallels + 1`` threads, SIGINT is delivered and the child must
    terminate within the join timeout. The child is always killed on the way
    out, so a failing assertion cannot leave it training in the background.
    """
    model_name = "dummy"
    global_round = 1
    num_clients = 10
    sample_ratio = 1.0
    epochs = 10**5
    batch_size = 2
    lr = 0.01
    seed = 42
    num_parallels = 10

    server = FedAvgBaseServerHandler(
        model_selector=model_selector,
        model_name=model_name,
        dataset=partitioned_dataset,
        global_round=global_round,
        num_clients=num_clients,
        sample_ratio=sample_ratio,
        device=device,
        batch_size=batch_size,
    )

    trainer = FedAvgThreadPoolClientTrainer(
        model_selector=model_selector,
        model_name=model_name,
        dataset=partitioned_dataset,
        device=device,
        num_clients=num_clients,
        epochs=epochs,
        batch_size=batch_size,
        lr=lr,
        seed=seed,
        num_parallels=num_parallels,
    )

    cids = server.sample_clients()
    downlink = server.downlink_package()

    proc = Process(target=run_local_process, args=(trainer, downlink, cids))
    proc.start()
    try:
        assert proc.pid is not None

        p = psutil.Process(proc.pid)
        timeout = 5
        # NOTE(review): num_threads() counts every OS thread in the child,
        # which may include threads not created by the executor; treat the
        # threshold as a readiness heuristic, not an exact worker count.
        while True:
            try:
                if p.num_threads() >= num_parallels + 1:
                    break
                elapsed = time.time() - p.create_time()
            except psutil.NoSuchProcess:
                # Also covers ZombieProcess: the child died before the
                # workers were started.
                pytest.fail("Process exited before spawning worker threads")
            if not p.is_running() or elapsed > timeout:
                pytest.fail(
                    f"Process did not spawn {num_parallels} threads within {timeout}s"
                )
            time.sleep(0.1)
        assert proc.is_alive()

        os.kill(proc.pid, signal.SIGINT)

        proc.join(timeout=5)
        assert not proc.is_alive()
    finally:
        # Never leak a child still running 10**5 epochs if anything above failed.
        if proc.is_alive():
            proc.kill()
            proc.join()
@@ -82,7 +82,7 @@ wheels = [
82
82
 
83
83
  [[package]]
84
84
  name = "blazefl"
85
- version = "2.0.0.dev4"
85
+ version = "2.0.0.dev5"
86
86
  source = { editable = "." }
87
87
  dependencies = [
88
88
  { name = "numpy" },
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes