blazefl 2.0.0.dev3__tar.gz → 2.0.0.dev4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116) hide show
  1. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/PKG-INFO +1 -1
  2. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/config/config.yaml +2 -1
  3. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/main.py +8 -7
  4. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/algorithm/dsfl.py +43 -39
  5. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/pyproject.toml +1 -1
  6. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/contrib/fedavg.py +87 -47
  7. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/client_trainer.py +73 -22
  8. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/client_trainer.pyi +8 -6
  9. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/utils/__init__.py +2 -0
  10. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/utils/__init__.pyi +2 -1
  11. blazefl-2.0.0.dev4/src/blazefl/utils/ipc.py +33 -0
  12. blazefl-2.0.0.dev4/src/blazefl/utils/ipc.pyi +3 -0
  13. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/test_contrib/test_fedavg.py +4 -1
  14. blazefl-2.0.0.dev4/tests/test_core/test_client_trainer.py +126 -0
  15. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/uv.lock +1 -1
  16. blazefl-2.0.0.dev3/tests/test_core/test_client_trainer.py +0 -84
  17. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.github/FUNDING.yml +0 -0
  18. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  19. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
  20. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.github/dependabot.yml +0 -0
  21. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.github/workflows/ci.yaml +0 -0
  22. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.github/workflows/publish.yaml +0 -0
  23. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.github/workflows/sphinx.yaml +0 -0
  24. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.gitignore +0 -0
  25. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.pre-commit-config.yaml +0 -0
  26. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/.python-version +0 -0
  27. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/CODE_OF_CONDUCT.md +0 -0
  28. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/LICENSE +0 -0
  29. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/Makefile +0 -0
  30. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/README.md +0 -0
  31. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/Makefile +0 -0
  32. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/imgs/architecture.png +0 -0
  33. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/imgs/benchmark_cnn.png +0 -0
  34. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/imgs/benchmark_resnet18.png +0 -0
  35. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/imgs/logo.svg +0 -0
  36. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/imgs/logo_square.svg +0 -0
  37. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/imgs/ogp.png +0 -0
  38. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/make.bat +0 -0
  39. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/_static/favicon.ico +0 -0
  40. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/_static/logo.png +0 -0
  41. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/_static/logo_square.png +0 -0
  42. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/_templates/autosummary/class.rst +0 -0
  43. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/_templates/autosummary/module.rst +0 -0
  44. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/benchmark.rst +0 -0
  45. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/conf.py +0 -0
  46. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/contribute.rst +0 -0
  47. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/example.rst +0 -0
  48. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/index.rst +0 -0
  49. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/install.rst +0 -0
  50. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/overview.rst +0 -0
  51. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/docs/source/reference.rst +0 -0
  52. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/.gitignore +0 -0
  53. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/.python-version +0 -0
  54. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/Makefile +0 -0
  55. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/README.md +0 -0
  56. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/config/config.yaml +0 -0
  57. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/dataset/__init__.py +0 -0
  58. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/dataset/dataset.py +0 -0
  59. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/dataset/functional.py +0 -0
  60. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/main.py +0 -0
  61. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/models/__init__.py +0 -0
  62. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/models/selector.py +0 -0
  63. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/pyproject.toml +0 -0
  64. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/experimental-freethreaded/uv.lock +0 -0
  65. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/.gitignore +0 -0
  66. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/.python-version +0 -0
  67. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/Makefile +0 -0
  68. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/README.md +0 -0
  69. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/dataset/__init__.py +0 -0
  70. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/dataset/dataset.py +0 -0
  71. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/dataset/functional.py +0 -0
  72. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/models/__init__.py +0 -0
  73. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/models/selector.py +0 -0
  74. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/quickstart-fedavg/pyproject.toml +0 -0
  75. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/.gitignore +0 -0
  76. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/.python-version +0 -0
  77. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/Makefile +0 -0
  78. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/README.md +0 -0
  79. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/algorithm/__init__.py +0 -0
  80. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/config/algorithm/dsfl.yaml +0 -0
  81. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/config/config.yaml +0 -0
  82. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/dataset/__init__.py +0 -0
  83. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/dataset/dataset.py +0 -0
  84. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/dataset/functional.py +0 -0
  85. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/main.py +0 -0
  86. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/models/__init__.py +0 -0
  87. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/models/cnn.py +0 -0
  88. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/models/selector.py +0 -0
  89. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/examples/step-by-step-dsfl/pyproject.toml +0 -0
  90. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/__init__.py +0 -0
  91. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/contrib/__init__.py +0 -0
  92. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/__init__.py +0 -0
  93. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/__init__.pyi +0 -0
  94. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/model_selector.py +0 -0
  95. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/model_selector.pyi +0 -0
  96. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/partitioned_dataset.py +0 -0
  97. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/partitioned_dataset.pyi +0 -0
  98. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/server_handler.py +0 -0
  99. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/core/server_handler.pyi +0 -0
  100. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/py.typed +0 -0
  101. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/utils/dataset.py +0 -0
  102. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/utils/dataset.pyi +0 -0
  103. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/utils/seed.py +0 -0
  104. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/utils/seed.pyi +0 -0
  105. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/utils/serialize.py +0 -0
  106. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/src/blazefl/utils/serialize.pyi +0 -0
  107. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/__init__.py +0 -0
  108. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/conftest.py +0 -0
  109. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/test_contrib/__init__.py +0 -0
  110. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/test_core/__init__.py +0 -0
  111. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/test_core/test_model_selector.py +0 -0
  112. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/test_core/test_partitioned_dataset.py +0 -0
  113. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/test_utils/__init__.py +0 -0
  114. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/test_utils/test_dataset.py +0 -0
  115. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/test_utils/test_seed.py +0 -0
  116. {blazefl-2.0.0.dev3 → blazefl-2.0.0.dev4}/tests/test_utils/test_serialize.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: blazefl
3
- Version: 2.0.0.dev3
3
+ Version: 2.0.0.dev4
4
4
  Summary: A blazing-fast and lightweight simulation framework for Federated Learning.
5
5
  Author-email: kitsuyaazuma <kitsuyaazuma@gmail.com>
6
6
  License-File: LICENSE
@@ -14,4 +14,5 @@ dataset_root_dir: /tmp/quickstart-fedavg/dataset
14
14
  dataset_split_dir: /tmp/quickstart-fedavg/split
15
15
  share_dir: /tmp/quickstart-fedavg/share
16
16
  state_dir: /tmp/quickstart-fedavg/state
17
- serial: false
17
+ parallel: true
18
+ ipc_mode: storage
@@ -98,31 +98,32 @@ def main(cfg: DictConfig):
98
98
  batch_size=cfg.batch_size,
99
99
  )
100
100
  trainer: FedAvgBaseClientTrainer | FedAvgProcessPoolClientTrainer | None = None
101
- if cfg.serial:
102
- trainer = FedAvgBaseClientTrainer(
101
+ if cfg.parallel:
102
+ trainer = FedAvgProcessPoolClientTrainer(
103
103
  model_selector=model_selector,
104
104
  model_name=cfg.model_name,
105
105
  dataset=dataset,
106
+ share_dir=share_dir,
107
+ state_dir=state_dir,
108
+ seed=cfg.seed,
106
109
  device=device,
107
110
  num_clients=cfg.num_clients,
108
111
  epochs=cfg.epochs,
109
112
  lr=cfg.lr,
110
113
  batch_size=cfg.batch_size,
114
+ num_parallels=cfg.num_parallels,
115
+ ipc_mode=cfg.ipc_mode,
111
116
  )
112
117
  else:
113
- trainer = FedAvgProcessPoolClientTrainer(
118
+ trainer = FedAvgBaseClientTrainer(
114
119
  model_selector=model_selector,
115
120
  model_name=cfg.model_name,
116
121
  dataset=dataset,
117
- share_dir=share_dir,
118
- state_dir=state_dir,
119
- seed=cfg.seed,
120
122
  device=device,
121
123
  num_clients=cfg.num_clients,
122
124
  epochs=cfg.epochs,
123
125
  lr=cfg.lr,
124
126
  batch_size=cfg.batch_size,
125
- num_parallels=cfg.num_parallels,
126
127
  )
127
128
  pipeline = FedAvgPipeline(handler=handler, trainer=trainer, writer=writer)
128
129
  try:
@@ -233,7 +233,7 @@ class DSFLBaseServerHandler(BaseServerHandler[DSFLUplinkPackage, DSFLDownlinkPac
233
233
 
234
234
 
235
235
  @dataclass
236
- class DSFLDiskSharedData:
236
+ class DSFLClientConfig:
237
237
  model_selector: DSFLModelSelector
238
238
  model_name: str
239
239
  dataset: DSFLPartitionedDataset
@@ -245,7 +245,6 @@ class DSFLDiskSharedData:
245
245
  kd_lr: float
246
246
  cid: int
247
247
  seed: int
248
- payload: DSFLDownlinkPackage
249
248
  state_path: Path
250
249
 
251
250
 
@@ -258,7 +257,7 @@ class DSFLClientState:
258
257
 
259
258
 
260
259
  class DSFLProcessPoolClientTrainer(
261
- ProcessPoolClientTrainer[DSFLUplinkPackage, DSFLDownlinkPackage, DSFLDiskSharedData]
260
+ ProcessPoolClientTrainer[DSFLUplinkPackage, DSFLDownlinkPackage, DSFLClientConfig]
262
261
  ):
263
262
  def __init__(
264
263
  self,
@@ -300,68 +299,76 @@ class DSFLProcessPoolClientTrainer(
300
299
  self.device = device
301
300
  self.num_clients = num_clients
302
301
  self.seed = seed
302
+ self.ipc_mode = "storage"
303
303
 
304
304
  if self.device == "cuda":
305
305
  self.device_count = torch.cuda.device_count()
306
306
 
307
307
  @staticmethod
308
- def process_client(path: Path, device: str) -> Path:
309
- data = torch.load(path, weights_only=False)
310
- assert isinstance(data, DSFLDiskSharedData)
311
-
312
- model = data.model_selector.select_model(data.model_name)
313
- optimizer = torch.optim.SGD(model.parameters(), lr=data.lr)
308
+ def worker(
309
+ config: DSFLClientConfig | Path,
310
+ payload: DSFLDownlinkPackage | Path,
311
+ device: str,
312
+ ) -> Path:
313
+ assert isinstance(config, Path) and isinstance(payload, Path)
314
+ config_path, payload_path = config, payload
315
+ c = torch.load(config_path, weights_only=False)
316
+ p = torch.load(payload_path, weights_only=False)
317
+ assert isinstance(c, DSFLClientConfig) and isinstance(p, DSFLDownlinkPackage)
318
+
319
+ model = c.model_selector.select_model(c.model_name)
320
+ optimizer = torch.optim.SGD(model.parameters(), lr=c.lr)
314
321
  kd_optimizer: torch.optim.SGD | None = None
315
322
 
316
323
  state: DSFLClientState | None = None
317
- if data.state_path.exists():
318
- state = torch.load(data.state_path, weights_only=False)
324
+ if c.state_path.exists():
325
+ state = torch.load(c.state_path, weights_only=False)
319
326
  assert isinstance(state, DSFLClientState)
320
327
  RandomState.set_random_state(state.random)
321
328
  model.load_state_dict(state.model)
322
329
  optimizer.load_state_dict(state.optimizer)
323
330
  if state.kd_optimizer is not None:
324
- kd_optimizer = torch.optim.SGD(model.parameters(), lr=data.kd_lr)
331
+ kd_optimizer = torch.optim.SGD(model.parameters(), lr=c.kd_lr)
325
332
  kd_optimizer.load_state_dict(state.kd_optimizer)
326
333
  else:
327
- seed_everything(data.seed, device=device)
334
+ seed_everything(c.seed, device=device)
328
335
 
329
336
  # Distill
330
- open_dataset = data.dataset.get_dataset(type_="open", cid=None)
331
- if data.payload.indices is not None and data.payload.soft_labels is not None:
332
- global_soft_labels = list(torch.unbind(data.payload.soft_labels, dim=0))
333
- global_indices = data.payload.indices.tolist()
337
+ open_dataset = c.dataset.get_dataset(type_="open", cid=None)
338
+ if p.indices is not None and p.soft_labels is not None:
339
+ global_soft_labels = list(torch.unbind(p.soft_labels, dim=0))
340
+ global_indices = p.indices.tolist()
334
341
  if kd_optimizer is None:
335
- kd_optimizer = torch.optim.SGD(model.parameters(), lr=data.kd_lr)
342
+ kd_optimizer = torch.optim.SGD(model.parameters(), lr=c.kd_lr)
336
343
  DSFLBaseServerHandler.distill(
337
344
  model=model,
338
345
  optimizer=kd_optimizer,
339
- dataset=data.dataset,
346
+ dataset=c.dataset,
340
347
  global_soft_labels=global_soft_labels,
341
348
  global_indices=global_indices,
342
- kd_epochs=data.kd_epochs,
343
- kd_batch_size=data.kd_batch_size,
349
+ kd_epochs=c.kd_epochs,
350
+ kd_batch_size=c.kd_batch_size,
344
351
  device=device,
345
352
  )
346
353
 
347
354
  # Train
348
- train_loader = data.dataset.get_dataloader(
355
+ train_loader = c.dataset.get_dataloader(
349
356
  type_="train",
350
- cid=data.cid,
351
- batch_size=data.batch_size,
357
+ cid=c.cid,
358
+ batch_size=c.batch_size,
352
359
  )
353
360
  DSFLProcessPoolClientTrainer.train(
354
361
  model=model,
355
362
  optimizer=optimizer,
356
363
  train_loader=train_loader,
357
364
  device=device,
358
- epochs=data.epochs,
365
+ epochs=c.epochs,
359
366
  )
360
367
 
361
368
  # Predict
362
369
  open_loader = DataLoader(
363
- Subset(open_dataset, data.payload.next_indices.tolist()),
364
- batch_size=data.batch_size,
370
+ Subset(open_dataset, p.next_indices.tolist()),
371
+ batch_size=c.batch_size,
365
372
  )
366
373
  soft_labels = DSFLProcessPoolClientTrainer.predict(
367
374
  model=model,
@@ -370,10 +377,10 @@ class DSFLProcessPoolClientTrainer(
370
377
  )
371
378
 
372
379
  # Evaluate
373
- test_loader = data.dataset.get_dataloader(
380
+ test_loader = c.dataset.get_dataloader(
374
381
  type_="test",
375
- cid=data.cid,
376
- batch_size=data.batch_size,
382
+ cid=c.cid,
383
+ batch_size=c.batch_size,
377
384
  )
378
385
  loss, acc = DSFLBaseServerHandler.evaulate(
379
386
  model=model,
@@ -383,19 +390,19 @@ class DSFLProcessPoolClientTrainer(
383
390
 
384
391
  package = DSFLUplinkPackage(
385
392
  soft_labels=soft_labels,
386
- indices=data.payload.next_indices,
393
+ indices=p.next_indices,
387
394
  metadata={"loss": loss, "acc": acc},
388
395
  )
389
396
 
390
- torch.save(package, path)
397
+ torch.save(package, config_path)
391
398
  state = DSFLClientState(
392
399
  random=RandomState.get_random_state(device=device),
393
400
  model=model.state_dict(),
394
401
  optimizer=optimizer.state_dict(),
395
402
  kd_optimizer=kd_optimizer.state_dict() if kd_optimizer else None,
396
403
  )
397
- torch.save(state, data.state_path)
398
- return path
404
+ torch.save(state, c.state_path)
405
+ return config_path
399
406
 
400
407
  @staticmethod
401
408
  def train(
@@ -442,10 +449,8 @@ class DSFLProcessPoolClientTrainer(
442
449
  soft_labels = torch.cat(soft_labels_list, dim=0)
443
450
  return soft_labels.cpu()
444
451
 
445
- def get_shared_data(
446
- self, cid: int, payload: DSFLDownlinkPackage
447
- ) -> DSFLDiskSharedData:
448
- data = DSFLDiskSharedData(
452
+ def get_client_config(self, cid: int) -> DSFLClientConfig:
453
+ data = DSFLClientConfig(
449
454
  model_selector=self.model_selector,
450
455
  model_name=self.model_name,
451
456
  dataset=self.dataset,
@@ -457,7 +462,6 @@ class DSFLProcessPoolClientTrainer(
457
462
  kd_lr=self.kd_lr,
458
463
  cid=cid,
459
464
  seed=self.seed,
460
- payload=payload,
461
465
  state_path=self.state_dir.joinpath(f"{cid}.pt"),
462
466
  )
463
467
  return data
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "blazefl"
3
- version = "2.0.0.dev3"
3
+ version = "2.0.0.dev4"
4
4
  description = "A blazing-fast and lightweight simulation framework for Federated Learning."
5
5
  readme = "README.md"
6
6
  authors = [
@@ -2,6 +2,7 @@ import random
2
2
  from copy import deepcopy
3
3
  from dataclasses import dataclass
4
4
  from pathlib import Path
5
+ from typing import Literal
5
6
 
6
7
  import torch
7
8
  from torch.utils.data import DataLoader
@@ -431,7 +432,7 @@ class FedAvgBaseClientTrainer(
431
432
 
432
433
 
433
434
  @dataclass
434
- class FedAvgDiskSharedData:
435
+ class FedAvgClientConfig:
435
436
  """
436
437
  Data structure representing shared data for parallel client training
437
438
  in the Federated Averaging (FedAvg) algorithm.
@@ -448,7 +449,6 @@ class FedAvgDiskSharedData:
448
449
  lr (float): Learning rate for the optimizer.
449
450
  cid (int): Client ID.
450
451
  seed (int): Seed for reproducibility.
451
- payload (FedAvgDownlinkPackage): Downlink package with global model parameters.
452
452
  state_path (Path): Path to save the client's random state.
453
453
  """
454
454
 
@@ -460,13 +460,12 @@ class FedAvgDiskSharedData:
460
460
  lr: float
461
461
  cid: int
462
462
  seed: int
463
- payload: FedAvgDownlinkPackage
464
463
  state_path: Path
465
464
 
466
465
 
467
466
  class FedAvgProcessPoolClientTrainer(
468
467
  ProcessPoolClientTrainer[
469
- FedAvgUplinkPackage, FedAvgDownlinkPackage, FedAvgDiskSharedData
468
+ FedAvgUplinkPackage, FedAvgDownlinkPackage, FedAvgClientConfig
470
469
  ]
471
470
  ):
472
471
  """
@@ -488,6 +487,8 @@ class FedAvgProcessPoolClientTrainer(
488
487
  lr (float): Learning rate for the optimizer.
489
488
  seed (int): Seed for reproducibility.
490
489
  num_parallels (int): Number of parallel processes for training.
490
+ ipc_mode (Literal["storage", "shared_memory"]):
491
+ Inter-process communication mode.
491
492
  device_count (int | None): Number of CUDA devices available (if using GPU).
492
493
  """
493
494
 
@@ -505,6 +506,7 @@ class FedAvgProcessPoolClientTrainer(
505
506
  lr: float,
506
507
  seed: int,
507
508
  num_parallels: int,
509
+ ipc_mode: Literal["storage", "shared_memory"],
508
510
  ) -> None:
509
511
  """
510
512
  Initialize the FedAvgParalleClientTrainer.
@@ -542,50 +544,93 @@ class FedAvgProcessPoolClientTrainer(
542
544
  self.device = device
543
545
  self.num_clients = num_clients
544
546
  self.seed = seed
547
+ self.ipc_mode = ipc_mode
545
548
 
546
549
  @staticmethod
547
- def process_client(path: Path, device: str) -> Path:
550
+ def worker(
551
+ config: FedAvgClientConfig | Path,
552
+ payload: FedAvgDownlinkPackage | Path,
553
+ device: str,
554
+ ) -> FedAvgUplinkPackage | Path:
548
555
  """
549
556
  Process a single client's local training and evaluation.
550
557
 
551
- This method is executed by a parallel process and handles data loading,
552
- training, evaluation, and saving results to a shared file.
558
+ This method is executed by a worker process and handles loading client
559
+ configuration and payload, performing the client-specific training,
560
+ and returning the result.
553
561
 
554
562
  Args:
555
- path (Path): Path to the shared data file containing client-specific
556
- information.
557
- device (str): Device to use for processing.
563
+ config (FedAvgClientConfig | Path):
564
+ The client's configuration data, or a path to a file containing
565
+ the configuration if `ipc_mode` is "storage".
566
+ payload (FedAvgDownlinkPackage | Path):
567
+ The downlink payload from the server, or a path to a file
568
+ containing the payload if `ipc_mode` is "storage".
569
+ device (str): Device to use for processing (e.g., "cpu", "cuda:0").
558
570
 
559
571
  Returns:
560
- Path: Path to the file with the processed results.
561
- """
562
- data = torch.load(path, weights_only=False)
563
- assert isinstance(data, FedAvgDiskSharedData)
564
-
565
- if data.state_path.exists():
566
- state = torch.load(data.state_path, weights_only=False)
567
- assert isinstance(state, RandomState)
568
- RandomState.set_random_state(state)
572
+ FedAvgUplinkPackage | Path:
573
+ The uplink package containing the client's results, or a path to
574
+ a file containing the package if `ipc_mode` is "storage".
575
+ """
576
+
577
+ def _storage_worker(
578
+ config_path: Path,
579
+ payload_path: Path,
580
+ device: str,
581
+ ) -> Path:
582
+ config = torch.load(config_path, weights_only=False)
583
+ assert isinstance(config, FedAvgClientConfig)
584
+ payload = torch.load(payload_path, weights_only=False)
585
+ assert isinstance(payload, FedAvgDownlinkPackage)
586
+ package = _shared_memory_worker(
587
+ config=config,
588
+ payload=payload,
589
+ device=device,
590
+ )
591
+ torch.save(package, config_path)
592
+ return config_path
593
+
594
+ def _shared_memory_worker(
595
+ config: FedAvgClientConfig,
596
+ payload: FedAvgDownlinkPackage,
597
+ device: str,
598
+ ) -> FedAvgUplinkPackage:
599
+ if config.state_path.exists():
600
+ state = torch.load(config.state_path, weights_only=False)
601
+ assert isinstance(state, RandomState)
602
+ RandomState.set_random_state(state)
603
+ else:
604
+ seed_everything(config.seed, device=device)
605
+
606
+ model = config.model_selector.select_model(config.model_name)
607
+ train_loader = config.dataset.get_dataloader(
608
+ type_="train",
609
+ cid=config.cid,
610
+ batch_size=config.batch_size,
611
+ )
612
+ package = FedAvgProcessPoolClientTrainer.train(
613
+ model=model,
614
+ model_parameters=payload.model_parameters,
615
+ train_loader=train_loader,
616
+ device=device,
617
+ epochs=config.epochs,
618
+ lr=config.lr,
619
+ )
620
+ torch.save(RandomState.get_random_state(device=device), config.state_path)
621
+ return package
622
+
623
+ if isinstance(config, Path) and isinstance(payload, Path):
624
+ return _storage_worker(config, payload, device)
625
+ elif isinstance(config, FedAvgClientConfig) and isinstance(
626
+ payload, FedAvgDownlinkPackage
627
+ ):
628
+ return _shared_memory_worker(config, payload, device)
569
629
  else:
570
- seed_everything(data.seed, device=device)
571
-
572
- model = data.model_selector.select_model(data.model_name)
573
- train_loader = data.dataset.get_dataloader(
574
- type_="train",
575
- cid=data.cid,
576
- batch_size=data.batch_size,
577
- )
578
- package = FedAvgProcessPoolClientTrainer.train(
579
- model=model,
580
- model_parameters=data.payload.model_parameters,
581
- train_loader=train_loader,
582
- device=device,
583
- epochs=data.epochs,
584
- lr=data.lr,
585
- )
586
- torch.save(package, path)
587
- torch.save(RandomState.get_random_state(device=device), data.state_path)
588
- return path
630
+ raise TypeError(
631
+ "Invalid types for config and payload."
632
+ " Expected FedAvgClientConfig and FedAvgDownlinkPackage or Path."
633
+ )
589
634
 
590
635
  @staticmethod
591
636
  def train(
@@ -636,21 +681,17 @@ class FedAvgProcessPoolClientTrainer(
636
681
 
637
682
  return FedAvgUplinkPackage(model_parameters, data_size)
638
683
 
639
- def get_shared_data(
640
- self, cid: int, payload: FedAvgDownlinkPackage
641
- ) -> FedAvgDiskSharedData:
684
+ def get_client_config(self, cid: int) -> FedAvgClientConfig:
642
685
  """
643
- Generate the shared data for a specific client.
686
+ Generate the client configuration for a specific client.
644
687
 
645
688
  Args:
646
689
  cid (int): Client ID.
647
- payload (FedAvgDownlinkPackage): Downlink package with global model
648
- parameters.
649
690
 
650
691
  Returns:
651
- FedAvgDiskSharedData: Shared data structure for the client.
692
+ FedAvgClientConfig: Client configuration data structure.
652
693
  """
653
- data = FedAvgDiskSharedData(
694
+ data = FedAvgClientConfig(
654
695
  model_selector=self.model_selector,
655
696
  model_name=self.model_name,
656
697
  dataset=self.dataset,
@@ -659,7 +700,6 @@ class FedAvgProcessPoolClientTrainer(
659
700
  lr=self.lr,
660
701
  cid=cid,
661
702
  seed=self.seed,
662
- payload=payload,
663
703
  state_path=self.state_dir.joinpath(f"{cid}.pt"),
664
704
  )
665
705
  return data
@@ -3,11 +3,13 @@ import signal
3
3
  from concurrent.futures import ThreadPoolExecutor, as_completed
4
4
  from multiprocessing.pool import ApplyResult
5
5
  from pathlib import Path
6
- from typing import Protocol, TypeVar
6
+ from typing import Literal, Protocol, TypeVar
7
7
 
8
8
  import torch
9
9
  from tqdm import tqdm
10
10
 
11
+ from blazefl.utils import move_tensor_to_shared_memory
12
+
11
13
  UplinkPackage = TypeVar("UplinkPackage")
12
14
  DownlinkPackage = TypeVar("DownlinkPackage", contravariant=True)
13
15
 
@@ -47,12 +49,12 @@ class BaseClientTrainer(Protocol[UplinkPackage, DownlinkPackage]):
47
49
  ...
48
50
 
49
51
 
50
- DiskSharedData = TypeVar("DiskSharedData", covariant=True)
52
+ ClientConfig = TypeVar("ClientConfig")
51
53
 
52
54
 
53
55
  class ProcessPoolClientTrainer(
54
56
  BaseClientTrainer[UplinkPackage, DownlinkPackage],
55
- Protocol[UplinkPackage, DownlinkPackage, DiskSharedData],
57
+ Protocol[UplinkPackage, DownlinkPackage, ClientConfig],
56
58
  ):
57
59
  """
58
60
  Abstract base class for parallel client training in federated learning.
@@ -63,7 +65,12 @@ class ProcessPoolClientTrainer(
63
65
  Attributes:
64
66
  num_parallels (int): Number of parallel processes to use for client training.
65
67
  share_dir (Path): Directory path for sharing data between processes.
68
+ device (str): The primary device to use for computation (e.g., "cpu", "cuda").
69
+ device_count (int): The number of available CUDA devices, if `device` is "cuda".
66
70
  cache (list[UplinkPackage]): Cache to store uplink packages from clients.
71
+ ipc_mode (Literal["storage", "shared_memory"]): Inter-process communication
72
+ mode. "storage" uses disk for data exchange, "shared_memory" uses
73
+ shared memory for tensor data. Defaults to "storage".
67
74
 
68
75
  Raises:
69
76
  NotImplementedError: If the abstract methods are not implemented in a subclass.
@@ -74,17 +81,17 @@ class ProcessPoolClientTrainer(
74
81
  device: str
75
82
  device_count: int
76
83
  cache: list[UplinkPackage]
84
+ ipc_mode: Literal["storage", "shared_memory"] = "storage"
77
85
 
78
- def get_shared_data(self, cid: int, payload: DownlinkPackage) -> DiskSharedData:
86
+ def get_client_config(self, cid: int) -> ClientConfig:
79
87
  """
80
- Retrieve shared data for a given client ID and payload.
88
+ Retrieve the configuration for a given client ID.
81
89
 
82
90
  Args:
83
91
  cid (int): Client ID.
84
- payload (DownlinkPackage): The data package received from the server.
85
92
 
86
93
  Returns:
87
- DiskSharedData: The shared data associated with the client ID and payload.
94
+ ClientConfig: The configuration for the specified client.
88
95
  """
89
96
  ...
90
97
 
@@ -103,16 +110,29 @@ class ProcessPoolClientTrainer(
103
110
  return self.device
104
111
 
105
112
  @staticmethod
106
- def process_client(path: Path, device: str) -> Path:
113
+ def worker(
114
+ config: ClientConfig | Path, payload: DownlinkPackage | Path, device: str
115
+ ) -> UplinkPackage | Path:
107
116
  """
108
- Process a single client based on the provided path.
117
+ Process a single client's training task.
118
+
119
+ This method is executed by each worker process in the pool.
120
+ It handles loading client configuration and payload, performing
121
+ the client-specific operations, and returning the result.
109
122
 
110
123
  Args:
111
- path (Path): Path to the client's data file.
112
- device (str): Device to use for processing.
124
+ config (ClientConfig | Path):
125
+ The client's configuration data, or a path to a file containing
126
+ the configuration if `ipc_mode` is "storage".
127
+ payload (DownlinkPackage | Path):
128
+ The downlink payload from the server, or a path to a file
129
+ containing the payload if `ipc_mode` is "storage".
130
+ device (str): Device to use for processing (e.g., "cpu", "cuda:0").
113
131
 
114
132
  Returns:
115
- Path: Path to the processed client's data file.
133
+ UplinkPackage | Path:
134
+ The uplink package containing the client's results, or a path
135
+ to a file containing the package if `ipc_mode` is "storage".
116
136
  """
117
137
  ...
118
138
 
@@ -130,6 +150,13 @@ class ProcessPoolClientTrainer(
130
150
  Returns:
131
151
  None
132
152
  """
153
+ payload_path = Path()
154
+ if self.ipc_mode == "storage":
155
+ payload_path = self.share_dir.joinpath("payload.pkl")
156
+ torch.save(payload, payload_path)
157
+ else: # shared_memory
158
+ move_tensor_to_shared_memory(payload)
159
+
133
160
  with mp.Pool(
134
161
  processes=self.num_parallels,
135
162
  initializer=signal.signal,
@@ -137,16 +164,28 @@ class ProcessPoolClientTrainer(
137
164
  ) as pool:
138
165
  jobs: list[ApplyResult] = []
139
166
  for cid in cid_list:
140
- path = self.share_dir.joinpath(f"{cid}.pkl")
141
- data = self.get_shared_data(cid, payload)
167
+ config = self.get_client_config(cid)
142
168
  device = self.get_client_device(cid)
143
- torch.save(data, path)
144
- jobs.append(pool.apply_async(self.process_client, (path, device)))
169
+ if self.ipc_mode == "storage":
170
+ config_path = self.share_dir.joinpath(f"{cid}.pkl")
171
+ torch.save(config, config_path)
172
+ jobs.append(
173
+ pool.apply_async(
174
+ self.worker, (config_path, payload_path, device)
175
+ )
176
+ )
177
+ else: # shared_memory
178
+ jobs.append(
179
+ pool.apply_async(self.worker, (config, payload, device))
180
+ )
145
181
 
146
182
  for job in tqdm(jobs, desc="Client", leave=False):
147
- path = job.get()
148
- assert isinstance(path, Path)
149
- package = torch.load(path, weights_only=False)
183
+ result = job.get()
184
+ if self.ipc_mode == "storage":
185
+ assert isinstance(result, Path)
186
+ package = torch.load(result, weights_only=False)
187
+ else: # shared_memory
188
+ package = result
150
189
  self.cache.append(package)
151
190
 
152
191
 
@@ -159,12 +198,24 @@ class ThreadPoolClientTrainer(
159
198
  device_count: int
160
199
  cache: list[UplinkPackage]
161
200
 
162
- def process_client(
201
+ def worker(
163
202
  self,
164
203
  cid: int,
165
204
  device: str,
166
205
  payload: DownlinkPackage,
167
- ) -> UplinkPackage: ...
206
+ ) -> UplinkPackage:
207
+ """
208
+ Process a single client's training task in a thread.
209
+
210
+ Args:
211
+ cid (int): The client ID.
212
+ device (str): The device to use for processing this client.
213
+ payload (DownlinkPackage): The data package received from the server.
214
+
215
+ Returns:
216
+ UplinkPackage: The uplink package containing the client's results.
217
+ """
218
+ ...
168
219
 
169
220
  def get_client_device(self, cid: int) -> str:
170
221
  if self.device == "cuda":
@@ -177,7 +228,7 @@ class ThreadPoolClientTrainer(
177
228
  for cid in cid_list:
178
229
  device = self.get_client_device(cid)
179
230
  future = executor.submit(
180
- self.process_client,
231
+ self.worker,
181
232
  cid,
182
233
  device,
183
234
  payload,