blazefl 2.0.0.dev2__tar.gz → 2.0.0.dev4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/PKG-INFO +1 -1
  2. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/main.py +1 -2
  3. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/pyproject.toml +1 -1
  4. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/uv.lock +4 -4
  5. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/config/config.yaml +2 -1
  6. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/main.py +15 -14
  7. blazefl-2.0.0.dev4/examples/step-by-step-dsfl/algorithm/__init__.py +3 -0
  8. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/algorithm/dsfl.py +53 -49
  9. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/main.py +5 -5
  10. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/pyproject.toml +1 -1
  11. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/contrib/__init__.py +6 -6
  12. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/contrib/fedavg.py +101 -59
  13. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/core/__init__.py +8 -8
  14. blazefl-2.0.0.dev4/src/blazefl/core/__init__.pyi +6 -0
  15. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/core/client_trainer.py +80 -35
  16. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/core/client_trainer.pyi +10 -9
  17. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/core/server_handler.py +1 -1
  18. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/core/server_handler.pyi +1 -1
  19. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/utils/__init__.py +2 -0
  20. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/utils/__init__.pyi +2 -1
  21. blazefl-2.0.0.dev4/src/blazefl/utils/ipc.py +33 -0
  22. blazefl-2.0.0.dev4/src/blazefl/utils/ipc.pyi +3 -0
  23. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/test_contrib/test_fedavg.py +16 -13
  24. blazefl-2.0.0.dev4/tests/test_core/test_client_trainer.py +126 -0
  25. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/uv.lock +1 -1
  26. blazefl-2.0.0.dev2/examples/experimental-freethreaded/core/__init__.py +0 -5
  27. blazefl-2.0.0.dev2/examples/experimental-freethreaded/core/client_trainer.py +0 -45
  28. blazefl-2.0.0.dev2/examples/step-by-step-dsfl/algorithm/__init__.py +0 -3
  29. blazefl-2.0.0.dev2/src/blazefl/core/__init__.pyi +0 -6
  30. blazefl-2.0.0.dev2/tests/test_core/test_client_trainer.py +0 -84
  31. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.github/FUNDING.yml +0 -0
  32. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  33. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
  34. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.github/dependabot.yml +0 -0
  35. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.github/workflows/ci.yaml +0 -0
  36. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.github/workflows/publish.yaml +0 -0
  37. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.github/workflows/sphinx.yaml +0 -0
  38. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.gitignore +0 -0
  39. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.pre-commit-config.yaml +0 -0
  40. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/.python-version +0 -0
  41. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/CODE_OF_CONDUCT.md +0 -0
  42. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/LICENSE +0 -0
  43. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/Makefile +0 -0
  44. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/README.md +0 -0
  45. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/Makefile +0 -0
  46. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/imgs/architecture.png +0 -0
  47. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/imgs/benchmark_cnn.png +0 -0
  48. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/imgs/benchmark_resnet18.png +0 -0
  49. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/imgs/logo.svg +0 -0
  50. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/imgs/logo_square.svg +0 -0
  51. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/imgs/ogp.png +0 -0
  52. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/make.bat +0 -0
  53. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/_static/favicon.ico +0 -0
  54. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/_static/logo.png +0 -0
  55. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/_static/logo_square.png +0 -0
  56. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/_templates/autosummary/class.rst +0 -0
  57. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/_templates/autosummary/module.rst +0 -0
  58. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/benchmark.rst +0 -0
  59. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/conf.py +0 -0
  60. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/contribute.rst +0 -0
  61. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/example.rst +0 -0
  62. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/index.rst +0 -0
  63. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/install.rst +0 -0
  64. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/overview.rst +0 -0
  65. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/docs/source/reference.rst +0 -0
  66. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/.gitignore +0 -0
  67. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/.python-version +0 -0
  68. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/Makefile +0 -0
  69. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/README.md +0 -0
  70. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/config/config.yaml +0 -0
  71. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/dataset/__init__.py +0 -0
  72. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/dataset/dataset.py +0 -0
  73. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/dataset/functional.py +0 -0
  74. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/models/__init__.py +0 -0
  75. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/models/selector.py +0 -0
  76. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/.gitignore +0 -0
  77. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/.python-version +0 -0
  78. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/Makefile +0 -0
  79. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/README.md +0 -0
  80. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/dataset/__init__.py +0 -0
  81. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/dataset/dataset.py +0 -0
  82. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/dataset/functional.py +0 -0
  83. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/models/__init__.py +0 -0
  84. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/models/selector.py +0 -0
  85. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/pyproject.toml +0 -0
  86. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/.gitignore +0 -0
  87. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/.python-version +0 -0
  88. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/Makefile +0 -0
  89. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/README.md +0 -0
  90. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/config/algorithm/dsfl.yaml +0 -0
  91. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/config/config.yaml +0 -0
  92. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/dataset/__init__.py +0 -0
  93. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/dataset/dataset.py +0 -0
  94. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/dataset/functional.py +0 -0
  95. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/models/__init__.py +0 -0
  96. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/models/cnn.py +0 -0
  97. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/models/selector.py +0 -0
  98. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/pyproject.toml +0 -0
  99. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/__init__.py +0 -0
  100. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/core/model_selector.py +0 -0
  101. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/core/model_selector.pyi +0 -0
  102. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/core/partitioned_dataset.py +0 -0
  103. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/core/partitioned_dataset.pyi +0 -0
  104. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/py.typed +0 -0
  105. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/utils/dataset.py +0 -0
  106. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/utils/dataset.pyi +0 -0
  107. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/utils/seed.py +0 -0
  108. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/utils/seed.pyi +0 -0
  109. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/utils/serialize.py +0 -0
  110. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/src/blazefl/utils/serialize.pyi +0 -0
  111. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/__init__.py +0 -0
  112. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/conftest.py +0 -0
  113. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/test_contrib/__init__.py +0 -0
  114. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/test_core/__init__.py +0 -0
  115. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/test_core/test_model_selector.py +0 -0
  116. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/test_core/test_partitioned_dataset.py +0 -0
  117. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/test_utils/__init__.py +0 -0
  118. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/test_utils/test_dataset.py +0 -0
  119. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/test_utils/test_seed.py +0 -0
  120. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev4}/tests/test_utils/test_serialize.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: blazefl
3
- Version: 2.0.0.dev2
3
+ Version: 2.0.0.dev4
4
4
  Summary: A blazing-fast and lightweight simulation framework for Federated Learning.
5
5
  Author-email: kitsuyaazuma <kitsuyaazuma@gmail.com>
6
6
  License-File: LICENSE
@@ -12,11 +12,10 @@ from blazefl.contrib import (
12
12
  FedAvgServerHandler,
13
13
  )
14
14
  from blazefl.contrib.fedavg import FedAvgDownlinkPackage, FedAvgUplinkPackage
15
- from blazefl.core import ModelSelector, PartitionedDataset
15
+ from blazefl.core import ModelSelector, MultiThreadClientTrainer, PartitionedDataset
16
16
  from blazefl.utils import seed_everything
17
17
  from omegaconf import DictConfig, OmegaConf
18
18
 
19
- from core.client_trainer import MultiThreadClientTrainer
20
19
  from dataset import PartitionedCIFAR10
21
20
  from models import FedAvgModelSelector
22
21
 
@@ -5,7 +5,7 @@ description = "Add your description here"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.13"
7
7
  dependencies = [
8
- "blazefl>=2.0.0dev1",
8
+ "blazefl>=2.0.0dev2",
9
9
  "hydra-core>=1.3.2",
10
10
  "torch>=2.7.1",
11
11
  "torchvision>=0.22.1",
@@ -10,16 +10,16 @@ sdist = { url = "https://files.pythonhosted.org/packages/3e/38/7859ff46355f76f8d
10
10
 
11
11
  [[package]]
12
12
  name = "blazefl"
13
- version = "2.0.0.dev1"
13
+ version = "2.0.0.dev2"
14
14
  source = { registry = "https://pypi.org/simple" }
15
15
  dependencies = [
16
16
  { name = "numpy" },
17
17
  { name = "torch" },
18
18
  { name = "tqdm" },
19
19
  ]
20
- sdist = { url = "https://files.pythonhosted.org/packages/9d/4a/b82e557a19b3dc957e43fc29c73bb3b9ddbb78575de20e2a98a7db954c51/blazefl-2.0.0.dev1.tar.gz", hash = "sha256:aeef9d1a6835ae240b9772144a1ca8ee2691a9252a980173b1c81d1c7c453596", size = 591232, upload_time = "2025-06-10T08:28:25.786Z" }
20
+ sdist = { url = "https://files.pythonhosted.org/packages/9b/92/3f668d7259a81d13d5d2e6089d8022629a09058ddbcae33f5873d3871656/blazefl-2.0.0.dev2.tar.gz", hash = "sha256:28a3e09d6f6cec8d8e8bb4da90adaa0c795980c74f9d4ce65024ba762726b595", size = 613563, upload_time = "2025-06-10T09:24:08.981Z" }
21
21
  wheels = [
22
- { url = "https://files.pythonhosted.org/packages/99/c0/3a4e12feee7a3a9e52a3eba30816a7ce7d4a11e3b11989b72c01419e1116/blazefl-2.0.0.dev1-py3-none-any.whl", hash = "sha256:1a158f673ac8f644b0a273d6cdb6d32b0ff17a1f479dffd930a06873f023042a", size = 23714, upload_time = "2025-06-10T08:28:24.297Z" },
22
+ { url = "https://files.pythonhosted.org/packages/e3/7c/841c8b9d22caaefc1dcc167076292fd1479309b3e11eb6fc1b24141792a8/blazefl-2.0.0.dev2-py3-none-any.whl", hash = "sha256:88639018166dafa79fde59aa5575bfb7277bf9e0e4497f22a59e9c7ec34efa4b", size = 23999, upload_time = "2025-06-10T09:24:07.553Z" },
23
23
  ]
24
24
 
25
25
  [[package]]
@@ -50,7 +50,7 @@ dev = [
50
50
 
51
51
  [package.metadata]
52
52
  requires-dist = [
53
- { name = "blazefl", specifier = ">=2.0.0.dev1" },
53
+ { name = "blazefl", specifier = ">=2.0.0.dev2" },
54
54
  { name = "hydra-core", specifier = ">=1.3.2" },
55
55
  { name = "torch", specifier = ">=2.7.1" },
56
56
  { name = "torchvision", specifier = ">=0.22.1" },
@@ -14,4 +14,5 @@ dataset_root_dir: /tmp/quickstart-fedavg/dataset
14
14
  dataset_split_dir: /tmp/quickstart-fedavg/split
15
15
  share_dir: /tmp/quickstart-fedavg/share
16
16
  state_dir: /tmp/quickstart-fedavg/state
17
- serial: false
17
+ parallel: true
18
+ ipc_mode: storage
@@ -6,9 +6,9 @@ import hydra
6
6
  import torch
7
7
  import torch.multiprocessing as mp
8
8
  from blazefl.contrib import (
9
- FedAvgParallelClientTrainer,
10
- FedAvgSerialClientTrainer,
11
- FedAvgServerHandler,
9
+ FedAvgBaseClientTrainer,
10
+ FedAvgBaseServerHandler,
11
+ FedAvgProcessPoolClientTrainer,
12
12
  )
13
13
  from blazefl.utils import seed_everything
14
14
  from hydra.core import hydra_config
@@ -22,8 +22,8 @@ from models import FedAvgModelSelector
22
22
  class FedAvgPipeline:
23
23
  def __init__(
24
24
  self,
25
- handler: FedAvgServerHandler,
26
- trainer: FedAvgSerialClientTrainer | FedAvgParallelClientTrainer,
25
+ handler: FedAvgBaseServerHandler,
26
+ trainer: FedAvgBaseClientTrainer | FedAvgProcessPoolClientTrainer,
27
27
  writer: SummaryWriter,
28
28
  ) -> None:
29
29
  self.handler = handler
@@ -87,7 +87,7 @@ def main(cfg: DictConfig):
87
87
  )
88
88
  model_selector = FedAvgModelSelector(num_classes=10)
89
89
 
90
- handler = FedAvgServerHandler(
90
+ handler = FedAvgBaseServerHandler(
91
91
  model_selector=model_selector,
92
92
  model_name=cfg.model_name,
93
93
  dataset=dataset,
@@ -97,32 +97,33 @@ def main(cfg: DictConfig):
97
97
  sample_ratio=cfg.sample_ratio,
98
98
  batch_size=cfg.batch_size,
99
99
  )
100
- trainer: FedAvgSerialClientTrainer | FedAvgParallelClientTrainer | None = None
101
- if cfg.serial:
102
- trainer = FedAvgSerialClientTrainer(
100
+ trainer: FedAvgBaseClientTrainer | FedAvgProcessPoolClientTrainer | None = None
101
+ if cfg.parallel:
102
+ trainer = FedAvgProcessPoolClientTrainer(
103
103
  model_selector=model_selector,
104
104
  model_name=cfg.model_name,
105
105
  dataset=dataset,
106
+ share_dir=share_dir,
107
+ state_dir=state_dir,
108
+ seed=cfg.seed,
106
109
  device=device,
107
110
  num_clients=cfg.num_clients,
108
111
  epochs=cfg.epochs,
109
112
  lr=cfg.lr,
110
113
  batch_size=cfg.batch_size,
114
+ num_parallels=cfg.num_parallels,
115
+ ipc_mode=cfg.ipc_mode,
111
116
  )
112
117
  else:
113
- trainer = FedAvgParallelClientTrainer(
118
+ trainer = FedAvgBaseClientTrainer(
114
119
  model_selector=model_selector,
115
120
  model_name=cfg.model_name,
116
121
  dataset=dataset,
117
- share_dir=share_dir,
118
- state_dir=state_dir,
119
- seed=cfg.seed,
120
122
  device=device,
121
123
  num_clients=cfg.num_clients,
122
124
  epochs=cfg.epochs,
123
125
  lr=cfg.lr,
124
126
  batch_size=cfg.batch_size,
125
- num_parallels=cfg.num_parallels,
126
127
  )
127
128
  pipeline = FedAvgPipeline(handler=handler, trainer=trainer, writer=writer)
128
129
  try:
@@ -0,0 +1,3 @@
1
+ from algorithm.dsfl import DSFLBaseServerHandler, DSFLProcessPoolClientTrainer
2
+
3
+ __all__ = ["DSFLBaseServerHandler", "DSFLProcessPoolClientTrainer"]
@@ -7,8 +7,8 @@ from pathlib import Path
7
7
  import torch
8
8
  import torch.nn.functional as F
9
9
  from blazefl.core import (
10
- ParallelClientTrainer,
11
- ServerHandler,
10
+ BaseServerHandler,
11
+ ProcessPoolClientTrainer,
12
12
  )
13
13
  from blazefl.utils import (
14
14
  FilteredDataset,
@@ -35,7 +35,7 @@ class DSFLDownlinkPackage:
35
35
  next_indices: torch.Tensor
36
36
 
37
37
 
38
- class DSFLServerHandler(ServerHandler[DSFLUplinkPackage, DSFLDownlinkPackage]):
38
+ class DSFLBaseServerHandler(BaseServerHandler[DSFLUplinkPackage, DSFLDownlinkPackage]):
39
39
  def __init__(
40
40
  self,
41
41
  model_selector: DSFLModelSelector,
@@ -116,7 +116,7 @@ class DSFLServerHandler(ServerHandler[DSFLUplinkPackage, DSFLDownlinkPackage]):
116
116
  era_soft_labels = F.softmax(mean_soft_labels / self.era_temperature, dim=0)
117
117
  global_soft_labels.append(era_soft_labels)
118
118
 
119
- DSFLServerHandler.distill(
119
+ DSFLBaseServerHandler.distill(
120
120
  self.model,
121
121
  self.kd_optimizer,
122
122
  self.dataset,
@@ -205,7 +205,7 @@ class DSFLServerHandler(ServerHandler[DSFLUplinkPackage, DSFLDownlinkPackage]):
205
205
  return avg_loss, avg_acc
206
206
 
207
207
  def get_summary(self) -> dict[str, float]:
208
- server_loss, server_acc = DSFLServerHandler.evaulate(
208
+ server_loss, server_acc = DSFLBaseServerHandler.evaulate(
209
209
  self.model,
210
210
  self.dataset.get_dataloader(
211
211
  type_="test",
@@ -233,7 +233,7 @@ class DSFLServerHandler(ServerHandler[DSFLUplinkPackage, DSFLDownlinkPackage]):
233
233
 
234
234
 
235
235
  @dataclass
236
- class DSFLDiskSharedData:
236
+ class DSFLClientConfig:
237
237
  model_selector: DSFLModelSelector
238
238
  model_name: str
239
239
  dataset: DSFLPartitionedDataset
@@ -245,7 +245,6 @@ class DSFLDiskSharedData:
245
245
  kd_lr: float
246
246
  cid: int
247
247
  seed: int
248
- payload: DSFLDownlinkPackage
249
248
  state_path: Path
250
249
 
251
250
 
@@ -257,8 +256,8 @@ class DSFLClientState:
257
256
  kd_optimizer: dict[str, torch.Tensor] | None
258
257
 
259
258
 
260
- class DSFLParallelClientTrainer(
261
- ParallelClientTrainer[DSFLUplinkPackage, DSFLDownlinkPackage, DSFLDiskSharedData]
259
+ class DSFLProcessPoolClientTrainer(
260
+ ProcessPoolClientTrainer[DSFLUplinkPackage, DSFLDownlinkPackage, DSFLClientConfig]
262
261
  ):
263
262
  def __init__(
264
263
  self,
@@ -300,82 +299,90 @@ class DSFLParallelClientTrainer(
300
299
  self.device = device
301
300
  self.num_clients = num_clients
302
301
  self.seed = seed
302
+ self.ipc_mode = "storage"
303
303
 
304
304
  if self.device == "cuda":
305
305
  self.device_count = torch.cuda.device_count()
306
306
 
307
307
  @staticmethod
308
- def process_client(path: Path, device: str) -> Path:
309
- data = torch.load(path, weights_only=False)
310
- assert isinstance(data, DSFLDiskSharedData)
311
-
312
- model = data.model_selector.select_model(data.model_name)
313
- optimizer = torch.optim.SGD(model.parameters(), lr=data.lr)
308
+ def worker(
309
+ config: DSFLClientConfig | Path,
310
+ payload: DSFLDownlinkPackage | Path,
311
+ device: str,
312
+ ) -> Path:
313
+ assert isinstance(config, Path) and isinstance(payload, Path)
314
+ config_path, payload_path = config, payload
315
+ c = torch.load(config_path, weights_only=False)
316
+ p = torch.load(payload_path, weights_only=False)
317
+ assert isinstance(c, DSFLClientConfig) and isinstance(p, DSFLDownlinkPackage)
318
+
319
+ model = c.model_selector.select_model(c.model_name)
320
+ optimizer = torch.optim.SGD(model.parameters(), lr=c.lr)
314
321
  kd_optimizer: torch.optim.SGD | None = None
315
322
 
316
323
  state: DSFLClientState | None = None
317
- if data.state_path.exists():
318
- state = torch.load(data.state_path, weights_only=False)
324
+ if c.state_path.exists():
325
+ state = torch.load(c.state_path, weights_only=False)
319
326
  assert isinstance(state, DSFLClientState)
320
327
  RandomState.set_random_state(state.random)
321
328
  model.load_state_dict(state.model)
322
329
  optimizer.load_state_dict(state.optimizer)
323
330
  if state.kd_optimizer is not None:
324
- kd_optimizer = torch.optim.SGD(model.parameters(), lr=data.kd_lr)
331
+ kd_optimizer = torch.optim.SGD(model.parameters(), lr=c.kd_lr)
325
332
  kd_optimizer.load_state_dict(state.kd_optimizer)
326
333
  else:
327
- seed_everything(data.seed, device=device)
334
+ seed_everything(c.seed, device=device)
328
335
 
329
336
  # Distill
330
- open_dataset = data.dataset.get_dataset(type_="open", cid=None)
331
- if data.payload.indices is not None and data.payload.soft_labels is not None:
332
- global_soft_labels = list(torch.unbind(data.payload.soft_labels, dim=0))
333
- global_indices = data.payload.indices.tolist()
337
+ open_dataset = c.dataset.get_dataset(type_="open", cid=None)
338
+ if p.indices is not None and p.soft_labels is not None:
339
+ global_soft_labels = list(torch.unbind(p.soft_labels, dim=0))
340
+ global_indices = p.indices.tolist()
334
341
  if kd_optimizer is None:
335
- kd_optimizer = torch.optim.SGD(model.parameters(), lr=data.kd_lr)
336
- DSFLServerHandler.distill(
342
+ kd_optimizer = torch.optim.SGD(model.parameters(), lr=c.kd_lr)
343
+ DSFLBaseServerHandler.distill(
337
344
  model=model,
338
345
  optimizer=kd_optimizer,
339
- dataset=data.dataset,
346
+ dataset=c.dataset,
340
347
  global_soft_labels=global_soft_labels,
341
348
  global_indices=global_indices,
342
- kd_epochs=data.kd_epochs,
343
- kd_batch_size=data.kd_batch_size,
349
+ kd_epochs=c.kd_epochs,
350
+ kd_batch_size=c.kd_batch_size,
344
351
  device=device,
345
352
  )
346
353
 
347
354
  # Train
348
- train_loader = data.dataset.get_dataloader(
355
+ train_loader = c.dataset.get_dataloader(
349
356
  type_="train",
350
- cid=data.cid,
351
- batch_size=data.batch_size,
357
+ cid=c.cid,
358
+ batch_size=c.batch_size,
352
359
  )
353
- DSFLParallelClientTrainer.train(
360
+ DSFLProcessPoolClientTrainer.train(
354
361
  model=model,
355
362
  optimizer=optimizer,
356
363
  train_loader=train_loader,
357
364
  device=device,
358
- epochs=data.epochs,
365
+ epochs=c.epochs,
359
366
  )
360
367
 
361
368
  # Predict
362
369
  open_loader = DataLoader(
363
- Subset(open_dataset, data.payload.next_indices.tolist()),
364
- batch_size=data.batch_size,
370
+ Subset(open_dataset, p.next_indices.tolist()),
371
+ batch_size=c.batch_size,
365
372
  )
366
- soft_labels = DSFLParallelClientTrainer.predict(
373
+ soft_labels = DSFLProcessPoolClientTrainer.predict(
367
374
  model=model,
368
375
  open_loader=open_loader,
369
376
  device=device,
370
377
  )
371
378
 
372
379
  # Evaluate
373
- test_loader = data.dataset.get_dataloader(
380
+ test_loader = c.dataset.get_dataloader(
374
381
  type_="test",
375
- cid=data.cid,
376
- batch_size=data.batch_size,
382
+ cid=c.cid,
383
+ batch_size=c.batch_size,
377
384
  )
378
- loss, acc = DSFLServerHandler.evaulate(
385
+ loss, acc = DSFLBaseServerHandler.evaulate(
379
386
  model=model,
380
387
  test_loader=test_loader,
381
388
  device=device,
@@ -383,19 +390,19 @@ class DSFLParallelClientTrainer(
383
390
 
384
391
  package = DSFLUplinkPackage(
385
392
  soft_labels=soft_labels,
386
- indices=data.payload.next_indices,
393
+ indices=p.next_indices,
387
394
  metadata={"loss": loss, "acc": acc},
388
395
  )
389
396
 
390
- torch.save(package, path)
397
+ torch.save(package, config_path)
391
398
  state = DSFLClientState(
392
399
  random=RandomState.get_random_state(device=device),
393
400
  model=model.state_dict(),
394
401
  optimizer=optimizer.state_dict(),
395
402
  kd_optimizer=kd_optimizer.state_dict() if kd_optimizer else None,
396
403
  )
397
- torch.save(state, data.state_path)
398
- return path
404
+ torch.save(state, c.state_path)
405
+ return config_path
399
406
 
400
407
  @staticmethod
401
408
  def train(
@@ -442,10 +449,8 @@ class DSFLParallelClientTrainer(
442
449
  soft_labels = torch.cat(soft_labels_list, dim=0)
443
450
  return soft_labels.cpu()
444
451
 
445
- def get_shared_data(
446
- self, cid: int, payload: DSFLDownlinkPackage
447
- ) -> DSFLDiskSharedData:
448
- data = DSFLDiskSharedData(
452
+ def get_client_config(self, cid: int) -> DSFLClientConfig:
453
+ data = DSFLClientConfig(
449
454
  model_selector=self.model_selector,
450
455
  model_name=self.model_name,
451
456
  dataset=self.dataset,
@@ -457,7 +462,6 @@ class DSFLParallelClientTrainer(
457
462
  kd_lr=self.kd_lr,
458
463
  cid=cid,
459
464
  seed=self.seed,
460
- payload=payload,
461
465
  state_path=self.state_dir.joinpath(f"{cid}.pt"),
462
466
  )
463
467
  return data
@@ -10,7 +10,7 @@ from hydra.core import hydra_config
10
10
  from omegaconf import DictConfig, OmegaConf
11
11
  from torch.utils.tensorboard.writer import SummaryWriter
12
12
 
13
- from algorithm import DSFLParallelClientTrainer, DSFLServerHandler
13
+ from algorithm import DSFLBaseServerHandler, DSFLProcessPoolClientTrainer
14
14
  from dataset import DSFLPartitionedDataset
15
15
  from models import DSFLModelSelector
16
16
 
@@ -18,8 +18,8 @@ from models import DSFLModelSelector
18
18
  class DSFLPipeline:
19
19
  def __init__(
20
20
  self,
21
- handler: DSFLServerHandler,
22
- trainer: DSFLParallelClientTrainer,
21
+ handler: DSFLBaseServerHandler,
22
+ trainer: DSFLProcessPoolClientTrainer,
23
23
  writer: SummaryWriter,
24
24
  ) -> None:
25
25
  self.handler = handler
@@ -82,7 +82,7 @@ def main(
82
82
 
83
83
  match cfg.algorithm.name:
84
84
  case "dsfl":
85
- handler = DSFLServerHandler(
85
+ handler = DSFLBaseServerHandler(
86
86
  model_selector=model_selector,
87
87
  model_name=cfg.model_name,
88
88
  dataset=dataset,
@@ -96,7 +96,7 @@ def main(
96
96
  device=device,
97
97
  sample_ratio=cfg.sample_ratio,
98
98
  )
99
- trainer = DSFLParallelClientTrainer(
99
+ trainer = DSFLProcessPoolClientTrainer(
100
100
  model_selector=model_selector,
101
101
  model_name=cfg.model_name,
102
102
  dataset=dataset,
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "blazefl"
3
- version = "2.0.0.dev2"
3
+ version = "2.0.0.dev4"
4
4
  description = "A blazing-fast and lightweight simulation framework for Federated Learning."
5
5
  readme = "README.md"
6
6
  authors = [
@@ -6,13 +6,13 @@ extending the core functionalities of BlazeFL.
6
6
  """
7
7
 
8
8
  from blazefl.contrib.fedavg import (
9
- FedAvgParallelClientTrainer,
10
- FedAvgSerialClientTrainer,
11
- FedAvgServerHandler,
9
+ FedAvgBaseClientTrainer,
10
+ FedAvgBaseServerHandler,
11
+ FedAvgProcessPoolClientTrainer,
12
12
  )
13
13
 
14
14
  __all__ = [
15
- "FedAvgServerHandler",
16
- "FedAvgParallelClientTrainer",
17
- "FedAvgSerialClientTrainer",
15
+ "FedAvgBaseServerHandler",
16
+ "FedAvgProcessPoolClientTrainer",
17
+ "FedAvgBaseClientTrainer",
18
18
  ]