blazefl 2.0.0.dev4__tar.gz → 2.0.0.dev6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. blazefl-2.0.0.dev6/.python-version +1 -0
  2. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/PKG-INFO +1 -1
  3. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/config/config.yaml +1 -1
  4. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/main.py +54 -32
  5. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/main.py +1 -3
  6. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/pyproject.toml +1 -1
  7. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/contrib/__init__.py +6 -0
  8. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/contrib/fedavg.py +132 -2
  9. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/client_trainer.py +34 -8
  10. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/client_trainer.pyi +5 -2
  11. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_contrib/test_fedavg.py +161 -33
  12. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_core/test_client_trainer.py +14 -3
  13. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/uv.lock +1 -1
  14. blazefl-2.0.0.dev4/examples/step-by-step-dsfl/.python-version +0 -1
  15. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/.github/FUNDING.yml +0 -0
  16. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  17. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
  18. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/.github/dependabot.yml +0 -0
  19. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/.github/workflows/ci.yaml +0 -0
  20. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/.github/workflows/publish.yaml +0 -0
  21. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/.github/workflows/sphinx.yaml +0 -0
  22. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/.gitignore +0 -0
  23. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/.pre-commit-config.yaml +0 -0
  24. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/CODE_OF_CONDUCT.md +0 -0
  25. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/LICENSE +0 -0
  26. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/Makefile +0 -0
  27. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/README.md +0 -0
  28. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/Makefile +0 -0
  29. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/imgs/architecture.png +0 -0
  30. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/imgs/benchmark_cnn.png +0 -0
  31. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/imgs/benchmark_resnet18.png +0 -0
  32. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/imgs/logo.svg +0 -0
  33. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/imgs/logo_square.svg +0 -0
  34. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/imgs/ogp.png +0 -0
  35. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/make.bat +0 -0
  36. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/_static/favicon.ico +0 -0
  37. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/_static/logo.png +0 -0
  38. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/_static/logo_square.png +0 -0
  39. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/_templates/autosummary/class.rst +0 -0
  40. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/_templates/autosummary/module.rst +0 -0
  41. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/benchmark.rst +0 -0
  42. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/conf.py +0 -0
  43. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/contribute.rst +0 -0
  44. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/example.rst +0 -0
  45. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/index.rst +0 -0
  46. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/install.rst +0 -0
  47. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/overview.rst +0 -0
  48. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/docs/source/reference.rst +0 -0
  49. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/.gitignore +0 -0
  50. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/.python-version +0 -0
  51. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/Makefile +0 -0
  52. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/README.md +0 -0
  53. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/config/config.yaml +0 -0
  54. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/dataset/__init__.py +0 -0
  55. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/dataset/dataset.py +0 -0
  56. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/dataset/functional.py +0 -0
  57. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/main.py +0 -0
  58. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/models/__init__.py +0 -0
  59. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/models/selector.py +0 -0
  60. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/pyproject.toml +0 -0
  61. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/experimental-freethreaded/uv.lock +0 -0
  62. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/.gitignore +0 -0
  63. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6/examples/quickstart-fedavg}/.python-version +0 -0
  64. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/Makefile +0 -0
  65. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/README.md +0 -0
  66. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/dataset/__init__.py +0 -0
  67. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/dataset/dataset.py +0 -0
  68. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/dataset/functional.py +0 -0
  69. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/models/__init__.py +0 -0
  70. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/models/selector.py +0 -0
  71. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/quickstart-fedavg/pyproject.toml +0 -0
  72. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/.gitignore +0 -0
  73. {blazefl-2.0.0.dev4/examples/quickstart-fedavg → blazefl-2.0.0.dev6/examples/step-by-step-dsfl}/.python-version +0 -0
  74. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/Makefile +0 -0
  75. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/README.md +0 -0
  76. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/algorithm/__init__.py +0 -0
  77. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/algorithm/dsfl.py +0 -0
  78. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/config/algorithm/dsfl.yaml +0 -0
  79. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/config/config.yaml +0 -0
  80. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/dataset/__init__.py +0 -0
  81. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/dataset/dataset.py +0 -0
  82. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/dataset/functional.py +0 -0
  83. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/models/__init__.py +0 -0
  84. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/models/cnn.py +0 -0
  85. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/models/selector.py +0 -0
  86. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/examples/step-by-step-dsfl/pyproject.toml +0 -0
  87. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/__init__.py +0 -0
  88. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/__init__.py +0 -0
  89. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/__init__.pyi +0 -0
  90. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/model_selector.py +0 -0
  91. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/model_selector.pyi +0 -0
  92. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/partitioned_dataset.py +0 -0
  93. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/partitioned_dataset.pyi +0 -0
  94. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/server_handler.py +0 -0
  95. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/core/server_handler.pyi +0 -0
  96. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/py.typed +0 -0
  97. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/__init__.py +0 -0
  98. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/__init__.pyi +0 -0
  99. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/dataset.py +0 -0
  100. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/dataset.pyi +0 -0
  101. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/ipc.py +0 -0
  102. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/ipc.pyi +0 -0
  103. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/seed.py +0 -0
  104. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/seed.pyi +0 -0
  105. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/serialize.py +0 -0
  106. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/src/blazefl/utils/serialize.pyi +0 -0
  107. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/__init__.py +0 -0
  108. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/conftest.py +0 -0
  109. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_contrib/__init__.py +0 -0
  110. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_core/__init__.py +0 -0
  111. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_core/test_model_selector.py +0 -0
  112. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_core/test_partitioned_dataset.py +0 -0
  113. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_utils/__init__.py +0 -0
  114. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_utils/test_dataset.py +0 -0
  115. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_utils/test_seed.py +0 -0
  116. {blazefl-2.0.0.dev4 → blazefl-2.0.0.dev6}/tests/test_utils/test_serialize.py +0 -0
@@ -0,0 +1 @@
1
+ 3.13
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: blazefl
3
- Version: 2.0.0.dev4
3
+ Version: 2.0.0.dev6
4
4
  Summary: A blazing-fast and lightweight simulation framework for Federated Learning.
5
5
  Author-email: kitsuyaazuma <kitsuyaazuma@gmail.com>
6
6
  License-File: LICENSE
@@ -14,5 +14,5 @@ dataset_root_dir: /tmp/quickstart-fedavg/dataset
14
14
  dataset_split_dir: /tmp/quickstart-fedavg/split
15
15
  share_dir: /tmp/quickstart-fedavg/share
16
16
  state_dir: /tmp/quickstart-fedavg/state
17
- parallel: true
17
+ execution_mode: multi-process
18
18
  ipc_mode: storage
@@ -9,6 +9,7 @@ from blazefl.contrib import (
9
9
  FedAvgBaseClientTrainer,
10
10
  FedAvgBaseServerHandler,
11
11
  FedAvgProcessPoolClientTrainer,
12
+ FedAvgThreadPoolClientTrainer,
12
13
  )
13
14
  from blazefl.utils import seed_everything
14
15
  from hydra.core import hydra_config
@@ -23,7 +24,9 @@ class FedAvgPipeline:
23
24
  def __init__(
24
25
  self,
25
26
  handler: FedAvgBaseServerHandler,
26
- trainer: FedAvgBaseClientTrainer | FedAvgProcessPoolClientTrainer,
27
+ trainer: FedAvgBaseClientTrainer
28
+ | FedAvgProcessPoolClientTrainer
29
+ | FedAvgThreadPoolClientTrainer,
27
30
  writer: SummaryWriter,
28
31
  ) -> None:
29
32
  self.handler = handler
@@ -97,41 +100,60 @@ def main(cfg: DictConfig):
97
100
  sample_ratio=cfg.sample_ratio,
98
101
  batch_size=cfg.batch_size,
99
102
  )
100
- trainer: FedAvgBaseClientTrainer | FedAvgProcessPoolClientTrainer | None = None
101
- if cfg.parallel:
102
- trainer = FedAvgProcessPoolClientTrainer(
103
- model_selector=model_selector,
104
- model_name=cfg.model_name,
105
- dataset=dataset,
106
- share_dir=share_dir,
107
- state_dir=state_dir,
108
- seed=cfg.seed,
109
- device=device,
110
- num_clients=cfg.num_clients,
111
- epochs=cfg.epochs,
112
- lr=cfg.lr,
113
- batch_size=cfg.batch_size,
114
- num_parallels=cfg.num_parallels,
115
- ipc_mode=cfg.ipc_mode,
116
- )
117
- else:
118
- trainer = FedAvgBaseClientTrainer(
119
- model_selector=model_selector,
120
- model_name=cfg.model_name,
121
- dataset=dataset,
122
- device=device,
123
- num_clients=cfg.num_clients,
124
- epochs=cfg.epochs,
125
- lr=cfg.lr,
126
- batch_size=cfg.batch_size,
127
- )
103
+ trainer: (
104
+ FedAvgBaseClientTrainer
105
+ | FedAvgProcessPoolClientTrainer
106
+ | FedAvgThreadPoolClientTrainer
107
+ | None
108
+ ) = None
109
+ match cfg.execution_mode:
110
+ case "multi-process":
111
+ trainer = FedAvgProcessPoolClientTrainer(
112
+ model_selector=model_selector,
113
+ model_name=cfg.model_name,
114
+ dataset=dataset,
115
+ share_dir=share_dir,
116
+ state_dir=state_dir,
117
+ seed=cfg.seed,
118
+ device=device,
119
+ num_clients=cfg.num_clients,
120
+ epochs=cfg.epochs,
121
+ lr=cfg.lr,
122
+ batch_size=cfg.batch_size,
123
+ num_parallels=cfg.num_parallels,
124
+ ipc_mode=cfg.ipc_mode,
125
+ )
126
+ case "single-thread":
127
+ trainer = FedAvgBaseClientTrainer(
128
+ model_selector=model_selector,
129
+ model_name=cfg.model_name,
130
+ dataset=dataset,
131
+ device=device,
132
+ num_clients=cfg.num_clients,
133
+ epochs=cfg.epochs,
134
+ lr=cfg.lr,
135
+ batch_size=cfg.batch_size,
136
+ )
137
+ case "multi-thread":
138
+ trainer = FedAvgThreadPoolClientTrainer(
139
+ model_selector=model_selector,
140
+ model_name=cfg.model_name,
141
+ dataset=dataset,
142
+ seed=cfg.seed,
143
+ device=device,
144
+ num_clients=cfg.num_clients,
145
+ epochs=cfg.epochs,
146
+ lr=cfg.lr,
147
+ batch_size=cfg.batch_size,
148
+ num_parallels=cfg.num_parallels,
149
+ )
150
+ case _:
151
+ raise ValueError(f"Invalid execution mode: {cfg.execution_mode}")
128
152
  pipeline = FedAvgPipeline(handler=handler, trainer=trainer, writer=writer)
129
153
  try:
130
154
  pipeline.main()
131
155
  except KeyboardInterrupt:
132
- logging.info("KeyboardInterrupt: Stopping the pipeline.")
133
- except Exception as e:
134
- logging.exception(f"An error occurred: {e}")
156
+ logging.info("KeyboardInterrupt")
135
157
 
136
158
 
137
159
  if __name__ == "__main__":
@@ -120,9 +120,7 @@ def main(
120
120
  try:
121
121
  pipeline.main()
122
122
  except KeyboardInterrupt:
123
- logging.info("KeyboardInterrupt: Stopping the pipeline.")
124
- except Exception as e:
125
- logging.exception(f"An error occurred: {e}")
123
+ logging.info("KeyboardInterrupt")
126
124
 
127
125
 
128
126
  if __name__ == "__main__":
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "blazefl"
3
- version = "2.0.0.dev4"
3
+ version = "2.0.0.dev6"
4
4
  description = "A blazing-fast and lightweight simulation framework for Federated Learning."
5
5
  readme = "README.md"
6
6
  authors = [
@@ -8,11 +8,17 @@ extending the core functionalities of BlazeFL.
8
8
  from blazefl.contrib.fedavg import (
9
9
  FedAvgBaseClientTrainer,
10
10
  FedAvgBaseServerHandler,
11
+ FedAvgDownlinkPackage,
11
12
  FedAvgProcessPoolClientTrainer,
13
+ FedAvgThreadPoolClientTrainer,
14
+ FedAvgUplinkPackage,
12
15
  )
13
16
 
14
17
  __all__ = [
15
18
  "FedAvgBaseServerHandler",
16
19
  "FedAvgProcessPoolClientTrainer",
17
20
  "FedAvgBaseClientTrainer",
21
+ "FedAvgThreadPoolClientTrainer",
22
+ "FedAvgUplinkPackage",
23
+ "FedAvgDownlinkPackage",
18
24
  ]
@@ -1,10 +1,12 @@
1
1
  import random
2
+ import threading
2
3
  from copy import deepcopy
3
4
  from dataclasses import dataclass
4
5
  from pathlib import Path
5
6
  from typing import Literal
6
7
 
7
8
  import torch
9
+ import torch.multiprocessing as mp
8
10
  from torch.utils.data import DataLoader
9
11
  from tqdm import tqdm
10
12
 
@@ -14,6 +16,7 @@ from blazefl.core import (
14
16
  ModelSelector,
15
17
  PartitionedDataset,
16
18
  ProcessPoolClientTrainer,
19
+ ThreadPoolClientTrainer,
17
20
  )
18
21
  from blazefl.utils import (
19
22
  RandomState,
@@ -545,12 +548,15 @@ class FedAvgProcessPoolClientTrainer(
545
548
  self.num_clients = num_clients
546
549
  self.seed = seed
547
550
  self.ipc_mode = ipc_mode
551
+ self.manager = mp.Manager()
552
+ self.stop_event = self.manager.Event()
548
553
 
549
554
  @staticmethod
550
555
  def worker(
551
556
  config: FedAvgClientConfig | Path,
552
557
  payload: FedAvgDownlinkPackage | Path,
553
558
  device: str,
559
+ stop_event: threading.Event,
554
560
  ) -> FedAvgUplinkPackage | Path:
555
561
  """
556
562
  Process a single client's local training and evaluation.
@@ -578,6 +584,7 @@ class FedAvgProcessPoolClientTrainer(
578
584
  config_path: Path,
579
585
  payload_path: Path,
580
586
  device: str,
587
+ stop_event: threading.Event,
581
588
  ) -> Path:
582
589
  config = torch.load(config_path, weights_only=False)
583
590
  assert isinstance(config, FedAvgClientConfig)
@@ -587,6 +594,7 @@ class FedAvgProcessPoolClientTrainer(
587
594
  config=config,
588
595
  payload=payload,
589
596
  device=device,
597
+ stop_event=stop_event,
590
598
  )
591
599
  torch.save(package, config_path)
592
600
  return config_path
@@ -595,6 +603,7 @@ class FedAvgProcessPoolClientTrainer(
595
603
  config: FedAvgClientConfig,
596
604
  payload: FedAvgDownlinkPackage,
597
605
  device: str,
606
+ stop_event: threading.Event,
598
607
  ) -> FedAvgUplinkPackage:
599
608
  if config.state_path.exists():
600
609
  state = torch.load(config.state_path, weights_only=False)
@@ -616,16 +625,17 @@ class FedAvgProcessPoolClientTrainer(
616
625
  device=device,
617
626
  epochs=config.epochs,
618
627
  lr=config.lr,
628
+ stop_event=stop_event,
619
629
  )
620
630
  torch.save(RandomState.get_random_state(device=device), config.state_path)
621
631
  return package
622
632
 
623
633
  if isinstance(config, Path) and isinstance(payload, Path):
624
- return _storage_worker(config, payload, device)
634
+ return _storage_worker(config, payload, device, stop_event)
625
635
  elif isinstance(config, FedAvgClientConfig) and isinstance(
626
636
  payload, FedAvgDownlinkPackage
627
637
  ):
628
- return _shared_memory_worker(config, payload, device)
638
+ return _shared_memory_worker(config, payload, device, stop_event)
629
639
  else:
630
640
  raise TypeError(
631
641
  "Invalid types for config and payload."
@@ -640,6 +650,7 @@ class FedAvgProcessPoolClientTrainer(
640
650
  device: str,
641
651
  epochs: int,
642
652
  lr: float,
653
+ stop_event: threading.Event,
643
654
  ) -> FedAvgUplinkPackage:
644
655
  """
645
656
  Train the model with the given training data loader.
@@ -664,6 +675,8 @@ class FedAvgProcessPoolClientTrainer(
664
675
 
665
676
  data_size = 0
666
677
  for _ in range(epochs):
678
+ if stop_event.is_set():
679
+ break
667
680
  for data, target in train_loader:
668
681
  data = data.to(device)
669
682
  target = target.to(device)
@@ -714,3 +727,120 @@ class FedAvgProcessPoolClientTrainer(
714
727
  package = deepcopy(self.cache)
715
728
  self.cache = []
716
729
  return package
730
+
731
+
732
+ class FedAvgThreadPoolClientTrainer(
733
+ ThreadPoolClientTrainer[
734
+ FedAvgUplinkPackage,
735
+ FedAvgDownlinkPackage,
736
+ ]
737
+ ):
738
+ def __init__(
739
+ self,
740
+ model_selector: ModelSelector,
741
+ model_name: str,
742
+ dataset: PartitionedDataset,
743
+ device: str,
744
+ num_clients: int,
745
+ epochs: int,
746
+ batch_size: int,
747
+ lr: float,
748
+ seed: int,
749
+ num_parallels: int,
750
+ ) -> None:
751
+ self.num_parallels = num_parallels
752
+ self.device = device
753
+ if self.device == "cuda":
754
+ self.device_count = torch.cuda.device_count()
755
+ self.cache: list[FedAvgUplinkPackage] = []
756
+
757
+ self.model_selector = model_selector
758
+ self.model_name = model_name
759
+ self.dataset = dataset
760
+ self.epochs = epochs
761
+ self.batch_size = batch_size
762
+ self.lr = lr
763
+ self.num_clients = num_clients
764
+ self.seed = seed
765
+ self.stop_event = threading.Event()
766
+
767
+ def worker(
768
+ self,
769
+ cid: int,
770
+ device: str,
771
+ payload: FedAvgDownlinkPackage,
772
+ stop_event: threading.Event,
773
+ ) -> FedAvgUplinkPackage:
774
+ model = self.model_selector.select_model(self.model_name)
775
+ train_loader = self.dataset.get_dataloader(
776
+ type_="train",
777
+ cid=cid,
778
+ batch_size=self.batch_size,
779
+ )
780
+ package = self.train(
781
+ model=model,
782
+ model_parameters=payload.model_parameters,
783
+ train_loader=train_loader,
784
+ device=device,
785
+ epochs=self.epochs,
786
+ lr=self.lr,
787
+ stop_event=stop_event,
788
+ )
789
+ return package
790
+
791
+ def train(
792
+ self,
793
+ model: torch.nn.Module,
794
+ model_parameters: torch.Tensor,
795
+ train_loader: DataLoader,
796
+ device: str,
797
+ epochs: int,
798
+ lr: float,
799
+ stop_event: threading.Event,
800
+ ) -> FedAvgUplinkPackage:
801
+ """
802
+ Train the model with the given training data loader.
803
+
804
+ Args:
805
+ model (torch.nn.Module): The model to train.
806
+ model_parameters (torch.Tensor): Initial global model parameters.
807
+ train_loader (DataLoader): DataLoader for the training data.
808
+ device (str): Device to run the training on.
809
+ epochs (int): Number of local training epochs.
810
+ lr (float): Learning rate for the optimizer.
811
+
812
+ Returns:
813
+ FedAvgUplinkPackage: Uplink package containing updated model parameters
814
+ and data size.
815
+ """
816
+ model.to(device)
817
+ deserialize_model(model, model_parameters)
818
+ model.train()
819
+ optimizer = torch.optim.SGD(model.parameters(), lr=lr)
820
+ criterion = torch.nn.CrossEntropyLoss()
821
+
822
+ data_size = 0
823
+ for _ in range(epochs):
824
+ if stop_event.is_set():
825
+ break
826
+ for data, target in train_loader:
827
+ data = data.to(device)
828
+ target = target.to(device)
829
+
830
+ output = model(data)
831
+ loss = criterion(output, target)
832
+
833
+ data_size += len(target)
834
+
835
+ optimizer.zero_grad()
836
+ loss.backward()
837
+ optimizer.step()
838
+
839
+ model_parameters = serialize_model(model)
840
+
841
+ return FedAvgUplinkPackage(model_parameters, data_size)
842
+
843
+ def uplink_package(self) -> list[FedAvgUplinkPackage]:
844
+ package = deepcopy(self.cache)
845
+ self.cache = []
846
+ return package
@@ -1,11 +1,12 @@
1
- import multiprocessing as mp
2
1
  import signal
2
+ import threading
3
3
  from concurrent.futures import ThreadPoolExecutor, as_completed
4
4
  from multiprocessing.pool import ApplyResult
5
5
  from pathlib import Path
6
6
  from typing import Literal, Protocol, TypeVar
7
7
 
8
8
  import torch
9
+ import torch.multiprocessing as mp
9
10
  from tqdm import tqdm
10
11
 
11
12
  from blazefl.utils import move_tensor_to_shared_memory
@@ -82,6 +83,7 @@ class ProcessPoolClientTrainer(
82
83
  device_count: int
83
84
  cache: list[UplinkPackage]
84
85
  ipc_mode: Literal["storage", "shared_memory"] = "storage"
86
+ stop_event: threading.Event
85
87
 
86
88
  def get_client_config(self, cid: int) -> ClientConfig:
87
89
  """
@@ -111,7 +113,10 @@ class ProcessPoolClientTrainer(
111
113
 
112
114
  @staticmethod
113
115
  def worker(
114
- config: ClientConfig | Path, payload: DownlinkPackage | Path, device: str
116
+ config: ClientConfig | Path,
117
+ payload: DownlinkPackage | Path,
118
+ device: str,
119
+ stop_event: threading.Event,
115
120
  ) -> UplinkPackage | Path:
116
121
  """
117
122
  Process a single client's training task.
@@ -157,11 +162,13 @@ class ProcessPoolClientTrainer(
157
162
  else: # shared_memory
158
163
  move_tensor_to_shared_memory(payload)
159
164
 
160
- with mp.Pool(
165
+ self.stop_event.clear()
166
+ pool = mp.Pool(
161
167
  processes=self.num_parallels,
162
168
  initializer=signal.signal,
163
169
  initargs=(signal.SIGINT, signal.SIG_IGN),
164
- ) as pool:
170
+ )
171
+ try:
165
172
  jobs: list[ApplyResult] = []
166
173
  for cid in cid_list:
167
174
  config = self.get_client_config(cid)
@@ -171,12 +178,15 @@ class ProcessPoolClientTrainer(
171
178
  torch.save(config, config_path)
172
179
  jobs.append(
173
180
  pool.apply_async(
174
- self.worker, (config_path, payload_path, device)
181
+ self.worker,
182
+ (config_path, payload_path, device, self.stop_event),
175
183
  )
176
184
  )
177
185
  else: # shared_memory
178
186
  jobs.append(
179
- pool.apply_async(self.worker, (config, payload, device))
187
+ pool.apply_async(
188
+ self.worker, (config, payload, device, self.stop_event)
189
+ )
180
190
  )
181
191
 
182
192
  for job in tqdm(jobs, desc="Client", leave=False):
@@ -187,6 +197,10 @@ class ProcessPoolClientTrainer(
187
197
  else: # shared_memory
188
198
  package = result
189
199
  self.cache.append(package)
200
+ finally:
201
+ self.stop_event.set()
202
+ pool.close()
203
+ pool.join()
190
204
 
191
205
 
192
206
  class ThreadPoolClientTrainer(
@@ -197,12 +211,14 @@ class ThreadPoolClientTrainer(
197
211
  device: str
198
212
  device_count: int
199
213
  cache: list[UplinkPackage]
214
+ stop_event: threading.Event
200
215
 
201
216
  def worker(
202
217
  self,
203
218
  cid: int,
204
219
  device: str,
205
220
  payload: DownlinkPackage,
221
+ stop_event: threading.Event,
206
222
  ) -> UplinkPackage:
207
223
  """
208
224
  Process a single client's training task in a thread.
@@ -211,6 +227,7 @@ class ThreadPoolClientTrainer(
211
227
  cid (int): The client ID.
212
228
  device (str): The device to use for processing this client.
213
229
  payload (DownlinkPackage): The data package received from the server.
230
+ stop_event (threading.Event): Event to signal stopping the worker.
214
231
 
215
232
  Returns:
216
233
  UplinkPackage: The uplink package containing the client's results.
@@ -223,7 +240,9 @@ class ThreadPoolClientTrainer(
223
240
  return self.device
224
241
 
225
242
  def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None:
226
- with ThreadPoolExecutor(max_workers=self.num_parallels) as executor:
243
+ self.stop_event.clear()
244
+ executor = ThreadPoolExecutor(max_workers=self.num_parallels)
245
+ try:
227
246
  futures = []
228
247
  for cid in cid_list:
229
248
  device = self.get_client_device(cid)
@@ -232,11 +251,18 @@ class ThreadPoolClientTrainer(
232
251
  cid,
233
252
  device,
234
253
  payload,
254
+ self.stop_event,
235
255
  )
236
256
  futures.append(future)
237
257
 
238
258
  for future in tqdm(
239
- as_completed(futures), total=len(futures), desc="Client", leave=False
259
+ as_completed(futures),
260
+ total=len(futures),
261
+ desc="Client",
262
+ leave=False,
240
263
  ):
241
264
  result = future.result()
242
265
  self.cache.append(result)
266
+ finally:
267
+ self.stop_event.set()
268
+ executor.shutdown(wait=True, cancel_futures=True)
@@ -1,3 +1,4 @@
1
+ import threading
1
2
  from blazefl.utils import move_tensor_to_shared_memory as move_tensor_to_shared_memory
2
3
  from multiprocessing.pool import ApplyResult as ApplyResult
3
4
  from pathlib import Path
@@ -18,10 +19,11 @@ class ProcessPoolClientTrainer(BaseClientTrainer[UplinkPackage, DownlinkPackage]
18
19
  device_count: int
19
20
  cache: list[UplinkPackage]
20
21
  ipc_mode: Literal['storage', 'shared_memory']
22
+ stop_event: threading.Event
21
23
  def get_client_config(self, cid: int) -> ClientConfig: ...
22
24
  def get_client_device(self, cid: int) -> str: ...
23
25
  @staticmethod
24
- def worker(config: ClientConfig | Path, payload: DownlinkPackage | Path, device: str) -> UplinkPackage | Path: ...
26
+ def worker(config: ClientConfig | Path, payload: DownlinkPackage | Path, device: str, stop_event: threading.Event) -> UplinkPackage | Path: ...
25
27
  def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None: ...
26
28
 
27
29
  class ThreadPoolClientTrainer(BaseClientTrainer[UplinkPackage, DownlinkPackage], Protocol[UplinkPackage, DownlinkPackage]):
@@ -29,6 +31,7 @@ class ThreadPoolClientTrainer(BaseClientTrainer[UplinkPackage, DownlinkPackage],
29
31
  device: str
30
32
  device_count: int
31
33
  cache: list[UplinkPackage]
32
- def worker(self, cid: int, device: str, payload: DownlinkPackage) -> UplinkPackage: ...
34
+ stop_event: threading.Event
35
+ def worker(self, cid: int, device: str, payload: DownlinkPackage, stop_event: threading.Event) -> UplinkPackage: ...
33
36
  def get_client_device(self, cid: int) -> str: ...
34
37
  def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None: ...