blazefl 2.0.0.dev2__tar.gz → 2.0.0.dev3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/PKG-INFO +1 -1
  2. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/main.py +1 -2
  3. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/pyproject.toml +1 -1
  4. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/uv.lock +4 -4
  5. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/main.py +9 -9
  6. blazefl-2.0.0.dev3/examples/step-by-step-dsfl/algorithm/__init__.py +3 -0
  7. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/algorithm/dsfl.py +11 -11
  8. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/main.py +5 -5
  9. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/pyproject.toml +1 -1
  10. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/contrib/__init__.py +6 -6
  11. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/contrib/fedavg.py +15 -13
  12. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/core/__init__.py +8 -8
  13. blazefl-2.0.0.dev3/src/blazefl/core/__init__.pyi +6 -0
  14. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/core/client_trainer.py +7 -13
  15. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/core/client_trainer.pyi +3 -4
  16. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/core/server_handler.py +1 -1
  17. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/core/server_handler.pyi +1 -1
  18. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_contrib/test_fedavg.py +12 -12
  19. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_core/test_client_trainer.py +5 -5
  20. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/uv.lock +1 -1
  21. blazefl-2.0.0.dev2/examples/experimental-freethreaded/core/__init__.py +0 -5
  22. blazefl-2.0.0.dev2/examples/experimental-freethreaded/core/client_trainer.py +0 -45
  23. blazefl-2.0.0.dev2/examples/step-by-step-dsfl/algorithm/__init__.py +0 -3
  24. blazefl-2.0.0.dev2/src/blazefl/core/__init__.pyi +0 -6
  25. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.github/FUNDING.yml +0 -0
  26. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  27. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
  28. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.github/dependabot.yml +0 -0
  29. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.github/workflows/ci.yaml +0 -0
  30. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.github/workflows/publish.yaml +0 -0
  31. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.github/workflows/sphinx.yaml +0 -0
  32. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.gitignore +0 -0
  33. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.pre-commit-config.yaml +0 -0
  34. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/.python-version +0 -0
  35. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/CODE_OF_CONDUCT.md +0 -0
  36. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/LICENSE +0 -0
  37. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/Makefile +0 -0
  38. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/README.md +0 -0
  39. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/Makefile +0 -0
  40. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/imgs/architecture.png +0 -0
  41. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/imgs/benchmark_cnn.png +0 -0
  42. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/imgs/benchmark_resnet18.png +0 -0
  43. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/imgs/logo.svg +0 -0
  44. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/imgs/logo_square.svg +0 -0
  45. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/imgs/ogp.png +0 -0
  46. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/make.bat +0 -0
  47. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/_static/favicon.ico +0 -0
  48. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/_static/logo.png +0 -0
  49. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/_static/logo_square.png +0 -0
  50. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/_templates/autosummary/class.rst +0 -0
  51. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/_templates/autosummary/module.rst +0 -0
  52. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/benchmark.rst +0 -0
  53. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/conf.py +0 -0
  54. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/contribute.rst +0 -0
  55. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/example.rst +0 -0
  56. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/index.rst +0 -0
  57. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/install.rst +0 -0
  58. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/overview.rst +0 -0
  59. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/docs/source/reference.rst +0 -0
  60. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/.gitignore +0 -0
  61. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/.python-version +0 -0
  62. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/Makefile +0 -0
  63. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/README.md +0 -0
  64. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/config/config.yaml +0 -0
  65. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/dataset/__init__.py +0 -0
  66. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/dataset/dataset.py +0 -0
  67. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/dataset/functional.py +0 -0
  68. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/models/__init__.py +0 -0
  69. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/experimental-freethreaded/models/selector.py +0 -0
  70. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/.gitignore +0 -0
  71. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/.python-version +0 -0
  72. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/Makefile +0 -0
  73. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/README.md +0 -0
  74. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/config/config.yaml +0 -0
  75. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/dataset/__init__.py +0 -0
  76. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/dataset/dataset.py +0 -0
  77. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/dataset/functional.py +0 -0
  78. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/models/__init__.py +0 -0
  79. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/models/selector.py +0 -0
  80. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/quickstart-fedavg/pyproject.toml +0 -0
  81. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/.gitignore +0 -0
  82. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/.python-version +0 -0
  83. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/Makefile +0 -0
  84. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/README.md +0 -0
  85. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/config/algorithm/dsfl.yaml +0 -0
  86. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/config/config.yaml +0 -0
  87. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/dataset/__init__.py +0 -0
  88. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/dataset/dataset.py +0 -0
  89. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/dataset/functional.py +0 -0
  90. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/models/__init__.py +0 -0
  91. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/models/cnn.py +0 -0
  92. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/models/selector.py +0 -0
  93. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/examples/step-by-step-dsfl/pyproject.toml +0 -0
  94. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/__init__.py +0 -0
  95. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/core/model_selector.py +0 -0
  96. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/core/model_selector.pyi +0 -0
  97. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/core/partitioned_dataset.py +0 -0
  98. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/core/partitioned_dataset.pyi +0 -0
  99. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/py.typed +0 -0
  100. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/utils/__init__.py +0 -0
  101. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/utils/__init__.pyi +0 -0
  102. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/utils/dataset.py +0 -0
  103. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/utils/dataset.pyi +0 -0
  104. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/utils/seed.py +0 -0
  105. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/utils/seed.pyi +0 -0
  106. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/utils/serialize.py +0 -0
  107. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/src/blazefl/utils/serialize.pyi +0 -0
  108. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/__init__.py +0 -0
  109. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/conftest.py +0 -0
  110. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_contrib/__init__.py +0 -0
  111. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_core/__init__.py +0 -0
  112. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_core/test_model_selector.py +0 -0
  113. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_core/test_partitioned_dataset.py +0 -0
  114. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_utils/__init__.py +0 -0
  115. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_utils/test_dataset.py +0 -0
  116. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_utils/test_seed.py +0 -0
  117. {blazefl-2.0.0.dev2 → blazefl-2.0.0.dev3}/tests/test_utils/test_serialize.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: blazefl
3
- Version: 2.0.0.dev2
3
+ Version: 2.0.0.dev3
4
4
  Summary: A blazing-fast and lightweight simulation framework for Federated Learning.
5
5
  Author-email: kitsuyaazuma <kitsuyaazuma@gmail.com>
6
6
  License-File: LICENSE
@@ -12,11 +12,10 @@ from blazefl.contrib import (
12
12
  FedAvgServerHandler,
13
13
  )
14
14
  from blazefl.contrib.fedavg import FedAvgDownlinkPackage, FedAvgUplinkPackage
15
- from blazefl.core import ModelSelector, PartitionedDataset
15
+ from blazefl.core import ModelSelector, MultiThreadClientTrainer, PartitionedDataset
16
16
  from blazefl.utils import seed_everything
17
17
  from omegaconf import DictConfig, OmegaConf
18
18
 
19
- from core.client_trainer import MultiThreadClientTrainer
20
19
  from dataset import PartitionedCIFAR10
21
20
  from models import FedAvgModelSelector
22
21
 
@@ -5,7 +5,7 @@ description = "Add your description here"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.13"
7
7
  dependencies = [
8
- "blazefl>=2.0.0dev1",
8
+ "blazefl>=2.0.0dev2",
9
9
  "hydra-core>=1.3.2",
10
10
  "torch>=2.7.1",
11
11
  "torchvision>=0.22.1",
@@ -10,16 +10,16 @@ sdist = { url = "https://files.pythonhosted.org/packages/3e/38/7859ff46355f76f8d
10
10
 
11
11
  [[package]]
12
12
  name = "blazefl"
13
- version = "2.0.0.dev1"
13
+ version = "2.0.0.dev2"
14
14
  source = { registry = "https://pypi.org/simple" }
15
15
  dependencies = [
16
16
  { name = "numpy" },
17
17
  { name = "torch" },
18
18
  { name = "tqdm" },
19
19
  ]
20
- sdist = { url = "https://files.pythonhosted.org/packages/9d/4a/b82e557a19b3dc957e43fc29c73bb3b9ddbb78575de20e2a98a7db954c51/blazefl-2.0.0.dev1.tar.gz", hash = "sha256:aeef9d1a6835ae240b9772144a1ca8ee2691a9252a980173b1c81d1c7c453596", size = 591232, upload_time = "2025-06-10T08:28:25.786Z" }
20
+ sdist = { url = "https://files.pythonhosted.org/packages/9b/92/3f668d7259a81d13d5d2e6089d8022629a09058ddbcae33f5873d3871656/blazefl-2.0.0.dev2.tar.gz", hash = "sha256:28a3e09d6f6cec8d8e8bb4da90adaa0c795980c74f9d4ce65024ba762726b595", size = 613563, upload_time = "2025-06-10T09:24:08.981Z" }
21
21
  wheels = [
22
- { url = "https://files.pythonhosted.org/packages/99/c0/3a4e12feee7a3a9e52a3eba30816a7ce7d4a11e3b11989b72c01419e1116/blazefl-2.0.0.dev1-py3-none-any.whl", hash = "sha256:1a158f673ac8f644b0a273d6cdb6d32b0ff17a1f479dffd930a06873f023042a", size = 23714, upload_time = "2025-06-10T08:28:24.297Z" },
22
+ { url = "https://files.pythonhosted.org/packages/e3/7c/841c8b9d22caaefc1dcc167076292fd1479309b3e11eb6fc1b24141792a8/blazefl-2.0.0.dev2-py3-none-any.whl", hash = "sha256:88639018166dafa79fde59aa5575bfb7277bf9e0e4497f22a59e9c7ec34efa4b", size = 23999, upload_time = "2025-06-10T09:24:07.553Z" },
23
23
  ]
24
24
 
25
25
  [[package]]
@@ -50,7 +50,7 @@ dev = [
50
50
 
51
51
  [package.metadata]
52
52
  requires-dist = [
53
- { name = "blazefl", specifier = ">=2.0.0.dev1" },
53
+ { name = "blazefl", specifier = ">=2.0.0.dev2" },
54
54
  { name = "hydra-core", specifier = ">=1.3.2" },
55
55
  { name = "torch", specifier = ">=2.7.1" },
56
56
  { name = "torchvision", specifier = ">=0.22.1" },
@@ -6,9 +6,9 @@ import hydra
6
6
  import torch
7
7
  import torch.multiprocessing as mp
8
8
  from blazefl.contrib import (
9
- FedAvgParallelClientTrainer,
10
- FedAvgSerialClientTrainer,
11
- FedAvgServerHandler,
9
+ FedAvgBaseClientTrainer,
10
+ FedAvgBaseServerHandler,
11
+ FedAvgProcessPoolClientTrainer,
12
12
  )
13
13
  from blazefl.utils import seed_everything
14
14
  from hydra.core import hydra_config
@@ -22,8 +22,8 @@ from models import FedAvgModelSelector
22
22
  class FedAvgPipeline:
23
23
  def __init__(
24
24
  self,
25
- handler: FedAvgServerHandler,
26
- trainer: FedAvgSerialClientTrainer | FedAvgParallelClientTrainer,
25
+ handler: FedAvgBaseServerHandler,
26
+ trainer: FedAvgBaseClientTrainer | FedAvgProcessPoolClientTrainer,
27
27
  writer: SummaryWriter,
28
28
  ) -> None:
29
29
  self.handler = handler
@@ -87,7 +87,7 @@ def main(cfg: DictConfig):
87
87
  )
88
88
  model_selector = FedAvgModelSelector(num_classes=10)
89
89
 
90
- handler = FedAvgServerHandler(
90
+ handler = FedAvgBaseServerHandler(
91
91
  model_selector=model_selector,
92
92
  model_name=cfg.model_name,
93
93
  dataset=dataset,
@@ -97,9 +97,9 @@ def main(cfg: DictConfig):
97
97
  sample_ratio=cfg.sample_ratio,
98
98
  batch_size=cfg.batch_size,
99
99
  )
100
- trainer: FedAvgSerialClientTrainer | FedAvgParallelClientTrainer | None = None
100
+ trainer: FedAvgBaseClientTrainer | FedAvgProcessPoolClientTrainer | None = None
101
101
  if cfg.serial:
102
- trainer = FedAvgSerialClientTrainer(
102
+ trainer = FedAvgBaseClientTrainer(
103
103
  model_selector=model_selector,
104
104
  model_name=cfg.model_name,
105
105
  dataset=dataset,
@@ -110,7 +110,7 @@ def main(cfg: DictConfig):
110
110
  batch_size=cfg.batch_size,
111
111
  )
112
112
  else:
113
- trainer = FedAvgParallelClientTrainer(
113
+ trainer = FedAvgProcessPoolClientTrainer(
114
114
  model_selector=model_selector,
115
115
  model_name=cfg.model_name,
116
116
  dataset=dataset,
@@ -0,0 +1,3 @@
1
+ from algorithm.dsfl import DSFLBaseServerHandler, DSFLProcessPoolClientTrainer
2
+
3
+ __all__ = ["DSFLBaseServerHandler", "DSFLProcessPoolClientTrainer"]
@@ -7,8 +7,8 @@ from pathlib import Path
7
7
  import torch
8
8
  import torch.nn.functional as F
9
9
  from blazefl.core import (
10
- ParallelClientTrainer,
11
- ServerHandler,
10
+ BaseServerHandler,
11
+ ProcessPoolClientTrainer,
12
12
  )
13
13
  from blazefl.utils import (
14
14
  FilteredDataset,
@@ -35,7 +35,7 @@ class DSFLDownlinkPackage:
35
35
  next_indices: torch.Tensor
36
36
 
37
37
 
38
- class DSFLServerHandler(ServerHandler[DSFLUplinkPackage, DSFLDownlinkPackage]):
38
+ class DSFLBaseServerHandler(BaseServerHandler[DSFLUplinkPackage, DSFLDownlinkPackage]):
39
39
  def __init__(
40
40
  self,
41
41
  model_selector: DSFLModelSelector,
@@ -116,7 +116,7 @@ class DSFLServerHandler(ServerHandler[DSFLUplinkPackage, DSFLDownlinkPackage]):
116
116
  era_soft_labels = F.softmax(mean_soft_labels / self.era_temperature, dim=0)
117
117
  global_soft_labels.append(era_soft_labels)
118
118
 
119
- DSFLServerHandler.distill(
119
+ DSFLBaseServerHandler.distill(
120
120
  self.model,
121
121
  self.kd_optimizer,
122
122
  self.dataset,
@@ -205,7 +205,7 @@ class DSFLServerHandler(ServerHandler[DSFLUplinkPackage, DSFLDownlinkPackage]):
205
205
  return avg_loss, avg_acc
206
206
 
207
207
  def get_summary(self) -> dict[str, float]:
208
- server_loss, server_acc = DSFLServerHandler.evaulate(
208
+ server_loss, server_acc = DSFLBaseServerHandler.evaulate(
209
209
  self.model,
210
210
  self.dataset.get_dataloader(
211
211
  type_="test",
@@ -257,8 +257,8 @@ class DSFLClientState:
257
257
  kd_optimizer: dict[str, torch.Tensor] | None
258
258
 
259
259
 
260
- class DSFLParallelClientTrainer(
261
- ParallelClientTrainer[DSFLUplinkPackage, DSFLDownlinkPackage, DSFLDiskSharedData]
260
+ class DSFLProcessPoolClientTrainer(
261
+ ProcessPoolClientTrainer[DSFLUplinkPackage, DSFLDownlinkPackage, DSFLDiskSharedData]
262
262
  ):
263
263
  def __init__(
264
264
  self,
@@ -333,7 +333,7 @@ class DSFLParallelClientTrainer(
333
333
  global_indices = data.payload.indices.tolist()
334
334
  if kd_optimizer is None:
335
335
  kd_optimizer = torch.optim.SGD(model.parameters(), lr=data.kd_lr)
336
- DSFLServerHandler.distill(
336
+ DSFLBaseServerHandler.distill(
337
337
  model=model,
338
338
  optimizer=kd_optimizer,
339
339
  dataset=data.dataset,
@@ -350,7 +350,7 @@ class DSFLParallelClientTrainer(
350
350
  cid=data.cid,
351
351
  batch_size=data.batch_size,
352
352
  )
353
- DSFLParallelClientTrainer.train(
353
+ DSFLProcessPoolClientTrainer.train(
354
354
  model=model,
355
355
  optimizer=optimizer,
356
356
  train_loader=train_loader,
@@ -363,7 +363,7 @@ class DSFLParallelClientTrainer(
363
363
  Subset(open_dataset, data.payload.next_indices.tolist()),
364
364
  batch_size=data.batch_size,
365
365
  )
366
- soft_labels = DSFLParallelClientTrainer.predict(
366
+ soft_labels = DSFLProcessPoolClientTrainer.predict(
367
367
  model=model,
368
368
  open_loader=open_loader,
369
369
  device=device,
@@ -375,7 +375,7 @@ class DSFLParallelClientTrainer(
375
375
  cid=data.cid,
376
376
  batch_size=data.batch_size,
377
377
  )
378
- loss, acc = DSFLServerHandler.evaulate(
378
+ loss, acc = DSFLBaseServerHandler.evaulate(
379
379
  model=model,
380
380
  test_loader=test_loader,
381
381
  device=device,
@@ -10,7 +10,7 @@ from hydra.core import hydra_config
10
10
  from omegaconf import DictConfig, OmegaConf
11
11
  from torch.utils.tensorboard.writer import SummaryWriter
12
12
 
13
- from algorithm import DSFLParallelClientTrainer, DSFLServerHandler
13
+ from algorithm import DSFLBaseServerHandler, DSFLProcessPoolClientTrainer
14
14
  from dataset import DSFLPartitionedDataset
15
15
  from models import DSFLModelSelector
16
16
 
@@ -18,8 +18,8 @@ from models import DSFLModelSelector
18
18
  class DSFLPipeline:
19
19
  def __init__(
20
20
  self,
21
- handler: DSFLServerHandler,
22
- trainer: DSFLParallelClientTrainer,
21
+ handler: DSFLBaseServerHandler,
22
+ trainer: DSFLProcessPoolClientTrainer,
23
23
  writer: SummaryWriter,
24
24
  ) -> None:
25
25
  self.handler = handler
@@ -82,7 +82,7 @@ def main(
82
82
 
83
83
  match cfg.algorithm.name:
84
84
  case "dsfl":
85
- handler = DSFLServerHandler(
85
+ handler = DSFLBaseServerHandler(
86
86
  model_selector=model_selector,
87
87
  model_name=cfg.model_name,
88
88
  dataset=dataset,
@@ -96,7 +96,7 @@ def main(
96
96
  device=device,
97
97
  sample_ratio=cfg.sample_ratio,
98
98
  )
99
- trainer = DSFLParallelClientTrainer(
99
+ trainer = DSFLProcessPoolClientTrainer(
100
100
  model_selector=model_selector,
101
101
  model_name=cfg.model_name,
102
102
  dataset=dataset,
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "blazefl"
3
- version = "2.0.0.dev2"
3
+ version = "2.0.0.dev3"
4
4
  description = "A blazing-fast and lightweight simulation framework for Federated Learning."
5
5
  readme = "README.md"
6
6
  authors = [
@@ -6,13 +6,13 @@ extending the core functionalities of BlazeFL.
6
6
  """
7
7
 
8
8
  from blazefl.contrib.fedavg import (
9
- FedAvgParallelClientTrainer,
10
- FedAvgSerialClientTrainer,
11
- FedAvgServerHandler,
9
+ FedAvgBaseClientTrainer,
10
+ FedAvgBaseServerHandler,
11
+ FedAvgProcessPoolClientTrainer,
12
12
  )
13
13
 
14
14
  __all__ = [
15
- "FedAvgServerHandler",
16
- "FedAvgParallelClientTrainer",
17
- "FedAvgSerialClientTrainer",
15
+ "FedAvgBaseServerHandler",
16
+ "FedAvgProcessPoolClientTrainer",
17
+ "FedAvgBaseClientTrainer",
18
18
  ]
@@ -8,11 +8,11 @@ from torch.utils.data import DataLoader
8
8
  from tqdm import tqdm
9
9
 
10
10
  from blazefl.core import (
11
+ BaseClientTrainer,
12
+ BaseServerHandler,
11
13
  ModelSelector,
12
- ParallelClientTrainer,
13
14
  PartitionedDataset,
14
- SerialClientTrainer,
15
- ServerHandler,
15
+ ProcessPoolClientTrainer,
16
16
  )
17
17
  from blazefl.utils import (
18
18
  RandomState,
@@ -53,7 +53,9 @@ class FedAvgDownlinkPackage:
53
53
  model_parameters: torch.Tensor
54
54
 
55
55
 
56
- class FedAvgServerHandler(ServerHandler[FedAvgUplinkPackage, FedAvgDownlinkPackage]):
56
+ class FedAvgBaseServerHandler(
57
+ BaseServerHandler[FedAvgUplinkPackage, FedAvgDownlinkPackage]
58
+ ):
57
59
  """
58
60
  Server-side handler for the Federated Averaging (FedAvg) algorithm.
59
61
 
@@ -85,7 +87,7 @@ class FedAvgServerHandler(ServerHandler[FedAvgUplinkPackage, FedAvgDownlinkPacka
85
87
  batch_size: int,
86
88
  ) -> None:
87
89
  """
88
- Initialize the FedAvgServerHandler.
90
+ Initialize the FedAvgBaseServerHandler.
89
91
 
90
92
  Args:
91
93
  model_selector (ModelSelector): Selector for initializing the model.
@@ -232,7 +234,7 @@ class FedAvgServerHandler(ServerHandler[FedAvgUplinkPackage, FedAvgDownlinkPacka
232
234
  return avg_loss, avg_acc
233
235
 
234
236
  def get_summary(self) -> dict[str, float]:
235
- server_loss, server_acc = FedAvgServerHandler.evaluate(
237
+ server_loss, server_acc = FedAvgBaseServerHandler.evaluate(
236
238
  self.model,
237
239
  self.dataset.get_dataloader(
238
240
  type_="test",
@@ -258,11 +260,11 @@ class FedAvgServerHandler(ServerHandler[FedAvgUplinkPackage, FedAvgDownlinkPacka
258
260
  return FedAvgDownlinkPackage(model_parameters)
259
261
 
260
262
 
261
- class FedAvgSerialClientTrainer(
262
- SerialClientTrainer[FedAvgUplinkPackage, FedAvgDownlinkPackage]
263
+ class FedAvgBaseClientTrainer(
264
+ BaseClientTrainer[FedAvgUplinkPackage, FedAvgDownlinkPackage]
263
265
  ):
264
266
  """
265
- Serial client trainer for the Federated Averaging (FedAvg) algorithm.
267
+ Base client trainer for the Federated Averaging (FedAvg) algorithm.
266
268
 
267
269
  This trainer processes clients sequentially, training and evaluating a local model
268
270
  for each client based on the server-provided model parameters.
@@ -291,7 +293,7 @@ class FedAvgSerialClientTrainer(
291
293
  lr: float,
292
294
  ) -> None:
293
295
  """
294
- Initialize the FedAvgSerialClientTrainer.
296
+ Initialize the FedAvgBaseClientTrainer.
295
297
 
296
298
  Args:
297
299
  model_selector (ModelSelector): Selector for initializing the local model.
@@ -462,8 +464,8 @@ class FedAvgDiskSharedData:
462
464
  state_path: Path
463
465
 
464
466
 
465
- class FedAvgParallelClientTrainer(
466
- ParallelClientTrainer[
467
+ class FedAvgProcessPoolClientTrainer(
468
+ ProcessPoolClientTrainer[
467
469
  FedAvgUplinkPackage, FedAvgDownlinkPackage, FedAvgDiskSharedData
468
470
  ]
469
471
  ):
@@ -573,7 +575,7 @@ class FedAvgParallelClientTrainer(
573
575
  cid=data.cid,
574
576
  batch_size=data.batch_size,
575
577
  )
576
- package = FedAvgParallelClientTrainer.train(
578
+ package = FedAvgProcessPoolClientTrainer.train(
577
579
  model=model,
578
580
  model_parameters=data.payload.model_parameters,
579
581
  train_loader=train_loader,
@@ -6,19 +6,19 @@ including client trainers, model selectors, partitioned datasets, and server han
6
6
  """
7
7
 
8
8
  from blazefl.core.client_trainer import (
9
- MultiThreadClientTrainer,
10
- ParallelClientTrainer,
11
- SerialClientTrainer,
9
+ BaseClientTrainer,
10
+ ProcessPoolClientTrainer,
11
+ ThreadPoolClientTrainer,
12
12
  )
13
13
  from blazefl.core.model_selector import ModelSelector
14
14
  from blazefl.core.partitioned_dataset import PartitionedDataset
15
- from blazefl.core.server_handler import ServerHandler
15
+ from blazefl.core.server_handler import BaseServerHandler
16
16
 
17
17
  __all__ = [
18
- "SerialClientTrainer",
19
- "ParallelClientTrainer",
20
- "MultiThreadClientTrainer",
18
+ "BaseClientTrainer",
19
+ "ProcessPoolClientTrainer",
20
+ "ThreadPoolClientTrainer",
21
21
  "ModelSelector",
22
22
  "PartitionedDataset",
23
- "ServerHandler",
23
+ "BaseServerHandler",
24
24
  ]
@@ -0,0 +1,6 @@
1
+ from blazefl.core.client_trainer import BaseClientTrainer as BaseClientTrainer, ProcessPoolClientTrainer as ProcessPoolClientTrainer, ThreadPoolClientTrainer as ThreadPoolClientTrainer
2
+ from blazefl.core.model_selector import ModelSelector as ModelSelector
3
+ from blazefl.core.partitioned_dataset import PartitionedDataset as PartitionedDataset
4
+ from blazefl.core.server_handler import BaseServerHandler as BaseServerHandler
5
+
6
+ __all__ = ['BaseClientTrainer', 'ProcessPoolClientTrainer', 'ThreadPoolClientTrainer', 'ModelSelector', 'PartitionedDataset', 'BaseServerHandler']
@@ -12,7 +12,7 @@ UplinkPackage = TypeVar("UplinkPackage")
12
12
  DownlinkPackage = TypeVar("DownlinkPackage", contravariant=True)
13
13
 
14
14
 
15
- class SerialClientTrainer(Protocol[UplinkPackage, DownlinkPackage]):
15
+ class BaseClientTrainer(Protocol[UplinkPackage, DownlinkPackage]):
16
16
  """
17
17
  Abstract base class for serial client training in federated learning.
18
18
 
@@ -50,7 +50,8 @@ class SerialClientTrainer(Protocol[UplinkPackage, DownlinkPackage]):
50
50
  DiskSharedData = TypeVar("DiskSharedData", covariant=True)
51
51
 
52
52
 
53
- class ParallelClientTrainer(
53
+ class ProcessPoolClientTrainer(
54
+ BaseClientTrainer[UplinkPackage, DownlinkPackage],
54
55
  Protocol[UplinkPackage, DownlinkPackage, DiskSharedData],
55
56
  ):
56
57
  """
@@ -74,16 +75,6 @@ class ParallelClientTrainer(
74
75
  device_count: int
75
76
  cache: list[UplinkPackage]
76
77
 
77
- def uplink_package(self) -> list[UplinkPackage]:
78
- """
79
- Prepare the data package to be sent from the client to the server.
80
-
81
- Returns:
82
- list[UplinkPackage]: A list of data packages prepared for uplink
83
- transmission.
84
- """
85
- ...
86
-
87
78
  def get_shared_data(self, cid: int, payload: DownlinkPackage) -> DiskSharedData:
88
79
  """
89
80
  Retrieve shared data for a given client ID and payload.
@@ -159,7 +150,10 @@ class ParallelClientTrainer(
159
150
  self.cache.append(package)
160
151
 
161
152
 
162
- class MultiThreadClientTrainer(Protocol[UplinkPackage, DownlinkPackage]):
153
+ class ThreadPoolClientTrainer(
154
+ BaseClientTrainer[UplinkPackage, DownlinkPackage],
155
+ Protocol[UplinkPackage, DownlinkPackage],
156
+ ):
163
157
  num_parallels: int
164
158
  device: str
165
159
  device_count: int
@@ -5,25 +5,24 @@ from typing import Protocol, TypeVar
5
5
  UplinkPackage = TypeVar('UplinkPackage')
6
6
  DownlinkPackage = TypeVar('DownlinkPackage', contravariant=True)
7
7
 
8
- class SerialClientTrainer(Protocol[UplinkPackage, DownlinkPackage]):
8
+ class BaseClientTrainer(Protocol[UplinkPackage, DownlinkPackage]):
9
9
  def uplink_package(self) -> list[UplinkPackage]: ...
10
10
  def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None: ...
11
11
  DiskSharedData = TypeVar('DiskSharedData', covariant=True)
12
12
 
13
- class ParallelClientTrainer(Protocol[UplinkPackage, DownlinkPackage, DiskSharedData]):
13
+ class ProcessPoolClientTrainer(BaseClientTrainer[UplinkPackage, DownlinkPackage], Protocol[UplinkPackage, DownlinkPackage, DiskSharedData]):
14
14
  num_parallels: int
15
15
  share_dir: Path
16
16
  device: str
17
17
  device_count: int
18
18
  cache: list[UplinkPackage]
19
- def uplink_package(self) -> list[UplinkPackage]: ...
20
19
  def get_shared_data(self, cid: int, payload: DownlinkPackage) -> DiskSharedData: ...
21
20
  def get_client_device(self, cid: int) -> str: ...
22
21
  @staticmethod
23
22
  def process_client(path: Path, device: str) -> Path: ...
24
23
  def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None: ...
25
24
 
26
- class MultiThreadClientTrainer(Protocol[UplinkPackage, DownlinkPackage]):
25
+ class ThreadPoolClientTrainer(BaseClientTrainer[UplinkPackage, DownlinkPackage], Protocol[UplinkPackage, DownlinkPackage]):
27
26
  num_parallels: int
28
27
  device: str
29
28
  device_count: int
@@ -4,7 +4,7 @@ UplinkPackage = TypeVar("UplinkPackage")
4
4
  DownlinkPackage = TypeVar("DownlinkPackage", covariant=True)
5
5
 
6
6
 
7
- class ServerHandler(Protocol[UplinkPackage, DownlinkPackage]):
7
+ class BaseServerHandler(Protocol[UplinkPackage, DownlinkPackage]):
8
8
  """
9
9
  Abstract base class for server-side operations in federated learning.
10
10
 
@@ -3,7 +3,7 @@ from typing import Protocol, TypeVar
3
3
  UplinkPackage = TypeVar('UplinkPackage')
4
4
  DownlinkPackage = TypeVar('DownlinkPackage', covariant=True)
5
5
 
6
- class ServerHandler(Protocol[UplinkPackage, DownlinkPackage]):
6
+ class BaseServerHandler(Protocol[UplinkPackage, DownlinkPackage]):
7
7
  def downlink_package(self) -> DownlinkPackage: ...
8
8
  def sample_clients(self) -> list[int]: ...
9
9
  def if_stop(self) -> bool: ...
@@ -10,9 +10,9 @@ import torch
10
10
  from torch.utils.data import DataLoader, Dataset
11
11
 
12
12
  from src.blazefl.contrib.fedavg import (
13
- FedAvgParallelClientTrainer,
14
- FedAvgSerialClientTrainer,
15
- FedAvgServerHandler,
13
+ FedAvgBaseClientTrainer,
14
+ FedAvgBaseServerHandler,
15
+ FedAvgProcessPoolClientTrainer,
16
16
  )
17
17
  from src.blazefl.core import ModelSelector, PartitionedDataset
18
18
 
@@ -86,7 +86,7 @@ def tmp_state_dir(tmp_path):
86
86
  return state_dir
87
87
 
88
88
 
89
- def test_server_and_serial_integration(model_selector, partitioned_dataset, device):
89
+ def test_server_and_base_integration(model_selector, partitioned_dataset, device):
90
90
  model_name = "dummy"
91
91
  global_round = 1
92
92
  num_clients = 3
@@ -95,7 +95,7 @@ def test_server_and_serial_integration(model_selector, partitioned_dataset, devi
95
95
  batch_size = 2
96
96
  lr = 0.01
97
97
 
98
- server = FedAvgServerHandler(
98
+ server = FedAvgBaseServerHandler(
99
99
  model_selector=model_selector,
100
100
  model_name=model_name,
101
101
  dataset=partitioned_dataset,
@@ -106,7 +106,7 @@ def test_server_and_serial_integration(model_selector, partitioned_dataset, devi
106
106
  batch_size=batch_size,
107
107
  )
108
108
 
109
- trainer = FedAvgSerialClientTrainer(
109
+ trainer = FedAvgBaseClientTrainer(
110
110
  model_selector=model_selector,
111
111
  model_name=model_name,
112
112
  dataset=partitioned_dataset,
@@ -133,7 +133,7 @@ def test_server_and_serial_integration(model_selector, partitioned_dataset, devi
133
133
  assert server.if_stop() is True
134
134
 
135
135
 
136
- def test_server_and_parallel_integration(
136
+ def test_server_and_process_pool_integration(
137
137
  model_selector, partitioned_dataset, device, tmp_share_dir, tmp_state_dir
138
138
  ):
139
139
  model_name = "dummy"
@@ -146,7 +146,7 @@ def test_server_and_parallel_integration(
146
146
  seed = 42
147
147
  num_parallels = 2
148
148
 
149
- server = FedAvgServerHandler(
149
+ server = FedAvgBaseServerHandler(
150
150
  model_selector=model_selector,
151
151
  model_name=model_name,
152
152
  dataset=partitioned_dataset,
@@ -157,7 +157,7 @@ def test_server_and_parallel_integration(
157
157
  batch_size=batch_size,
158
158
  )
159
159
 
160
- trainer = FedAvgParallelClientTrainer(
160
+ trainer = FedAvgProcessPoolClientTrainer(
161
161
  model_selector=model_selector,
162
162
  model_name=model_name,
163
163
  share_dir=tmp_share_dir,
@@ -193,7 +193,7 @@ def run_local_process(trainer, downlink, cids):
193
193
  trainer.local_process(downlink, cids)
194
194
 
195
195
 
196
- def test_server_and_parallel_integration_keyboard_interrupt(
196
+ def test_server_and_process_pool_integration_keyboard_interrupt(
197
197
  model_selector, partitioned_dataset, device, tmp_share_dir, tmp_state_dir
198
198
  ):
199
199
  model_name = "dummy"
@@ -206,7 +206,7 @@ def test_server_and_parallel_integration_keyboard_interrupt(
206
206
  seed = 42
207
207
  num_parallels = 10
208
208
 
209
- server = FedAvgServerHandler(
209
+ server = FedAvgBaseServerHandler(
210
210
  model_selector=model_selector,
211
211
  model_name=model_name,
212
212
  dataset=partitioned_dataset,
@@ -217,7 +217,7 @@ def test_server_and_parallel_integration_keyboard_interrupt(
217
217
  batch_size=batch_size,
218
218
  )
219
219
 
220
- trainer = FedAvgParallelClientTrainer(
220
+ trainer = FedAvgProcessPoolClientTrainer(
221
221
  model_selector=model_selector,
222
222
  model_name=model_name,
223
223
  share_dir=tmp_share_dir,
@@ -4,7 +4,7 @@ from pathlib import Path
4
4
  import pytest
5
5
  import torch
6
6
 
7
- from src.blazefl.core import ParallelClientTrainer
7
+ from src.blazefl.core import ProcessPoolClientTrainer
8
8
 
9
9
 
10
10
  @dataclass
@@ -24,8 +24,8 @@ class DiskSharedData:
24
24
  payload: DownlinkPackage
25
25
 
26
26
 
27
- class DummyParallelClientTrainer(
28
- ParallelClientTrainer[UplinkPackage, DownlinkPackage, DiskSharedData]
27
+ class DummyProcessPoolClientTrainer(
28
+ ProcessPoolClientTrainer[UplinkPackage, DownlinkPackage, DiskSharedData]
29
29
  ):
30
30
  def __init__(self, num_parallels: int, share_dir: Path, device: str):
31
31
  self.num_parallels = num_parallels
@@ -58,10 +58,10 @@ class DummyParallelClientTrainer(
58
58
 
59
59
  @pytest.mark.parametrize("num_parallels", [1, 2, 4])
60
60
  @pytest.mark.parametrize("cid_list", [[], [42], [0, 1, 2]])
61
- def test_parallel_client_trainer(
61
+ def test_process_pool_client_trainer(
62
62
  tmp_path: Path, num_parallels: int, cid_list: list[int]
63
63
  ) -> None:
64
- trainer = DummyParallelClientTrainer(
64
+ trainer = DummyProcessPoolClientTrainer(
65
65
  num_parallels=num_parallels, share_dir=tmp_path, device="cpu"
66
66
  )
67
67
 
@@ -82,7 +82,7 @@ wheels = [
82
82
 
83
83
  [[package]]
84
84
  name = "blazefl"
85
- version = "2.0.0.dev2"
85
+ version = "2.0.0.dev3"
86
86
  source = { editable = "." }
87
87
  dependencies = [
88
88
  { name = "numpy" },
@@ -1,5 +0,0 @@
1
- from core.client_trainer import MultiThreadClientTrainer
2
-
3
- __all__ = [
4
- "MultiThreadClientTrainer",
5
- ]
@@ -1,45 +0,0 @@
1
- from concurrent.futures import ThreadPoolExecutor, as_completed
2
- from typing import Protocol, TypeVar
3
-
4
- from tqdm import tqdm
5
-
6
- UplinkPackage = TypeVar("UplinkPackage")
7
- DownlinkPackage = TypeVar("DownlinkPackage", contravariant=True)
8
-
9
-
10
- class MultiThreadClientTrainer(Protocol[UplinkPackage, DownlinkPackage]):
11
- num_parallels: int
12
- device: str
13
- device_count: int
14
- cache: list[UplinkPackage]
15
-
16
- def process_client(
17
- self,
18
- cid: int,
19
- device: str,
20
- payload: DownlinkPackage,
21
- ) -> UplinkPackage: ...
22
-
23
- def get_client_device(self, cid: int) -> str:
24
- if self.device == "cuda":
25
- return f"cuda:{cid % self.device_count}"
26
- return self.device
27
-
28
- def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None:
29
- with ThreadPoolExecutor(max_workers=self.num_parallels) as executor:
30
- futures = []
31
- for cid in cid_list:
32
- device = self.get_client_device(cid)
33
- future = executor.submit(
34
- self.process_client,
35
- cid,
36
- device,
37
- payload,
38
- )
39
- futures.append(future)
40
-
41
- for future in tqdm(
42
- as_completed(futures), total=len(futures), desc="Client", leave=False
43
- ):
44
- result = future.result()
45
- self.cache.append(result)
@@ -1,3 +0,0 @@
1
- from algorithm.dsfl import DSFLParallelClientTrainer, DSFLServerHandler
2
-
3
- __all__ = ["DSFLServerHandler", "DSFLParallelClientTrainer"]
@@ -1,6 +0,0 @@
1
- from blazefl.core.client_trainer import MultiThreadClientTrainer as MultiThreadClientTrainer, ParallelClientTrainer as ParallelClientTrainer, SerialClientTrainer as SerialClientTrainer
2
- from blazefl.core.model_selector import ModelSelector as ModelSelector
3
- from blazefl.core.partitioned_dataset import PartitionedDataset as PartitionedDataset
4
- from blazefl.core.server_handler import ServerHandler as ServerHandler
5
-
6
- __all__ = ['SerialClientTrainer', 'ParallelClientTrainer', 'MultiThreadClientTrainer', 'ModelSelector', 'PartitionedDataset', 'ServerHandler']
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes