flwr 1.24.0__py3-none-any.whl → 1.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (204) hide show
  1. flwr/__init__.py +1 -1
  2. flwr/app/__init__.py +4 -1
  3. flwr/app/message_type.py +29 -0
  4. flwr/app/metadata.py +5 -2
  5. flwr/app/user_config.py +19 -0
  6. flwr/cli/app.py +37 -19
  7. flwr/cli/app_cmd/publish.py +25 -75
  8. flwr/cli/app_cmd/review.py +25 -66
  9. flwr/cli/auth_plugin/auth_plugin.py +5 -10
  10. flwr/cli/auth_plugin/noop_auth_plugin.py +1 -2
  11. flwr/cli/auth_plugin/oidc_cli_plugin.py +38 -38
  12. flwr/cli/build.py +15 -28
  13. flwr/cli/config/__init__.py +21 -0
  14. flwr/cli/config/ls.py +71 -0
  15. flwr/cli/config_migration.py +297 -0
  16. flwr/cli/config_utils.py +63 -156
  17. flwr/cli/constant.py +71 -0
  18. flwr/cli/federation/__init__.py +0 -2
  19. flwr/cli/federation/ls.py +256 -64
  20. flwr/cli/flower_config.py +429 -0
  21. flwr/cli/install.py +23 -62
  22. flwr/cli/log.py +23 -37
  23. flwr/cli/login/login.py +29 -63
  24. flwr/cli/ls.py +72 -61
  25. flwr/cli/new/new.py +98 -309
  26. flwr/cli/pull.py +19 -37
  27. flwr/cli/run/run.py +87 -100
  28. flwr/cli/run_utils.py +23 -5
  29. flwr/cli/stop.py +33 -74
  30. flwr/cli/supernode/ls.py +35 -62
  31. flwr/cli/supernode/register.py +31 -80
  32. flwr/cli/supernode/unregister.py +24 -70
  33. flwr/cli/typing.py +200 -0
  34. flwr/cli/utils.py +160 -412
  35. flwr/client/grpc_adapter_client/connection.py +2 -2
  36. flwr/client/grpc_rere_client/connection.py +9 -6
  37. flwr/client/grpc_rere_client/grpc_adapter.py +1 -1
  38. flwr/client/message_handler/message_handler.py +2 -1
  39. flwr/client/mod/centraldp_mods.py +1 -1
  40. flwr/client/mod/localdp_mod.py +1 -1
  41. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  42. flwr/client/rest_client/connection.py +6 -4
  43. flwr/client/run_info_store.py +2 -1
  44. flwr/clientapp/client_app.py +2 -1
  45. flwr/common/__init__.py +3 -2
  46. flwr/common/args.py +5 -5
  47. flwr/common/config.py +12 -17
  48. flwr/common/constant.py +3 -16
  49. flwr/common/context.py +2 -1
  50. flwr/common/exit/exit.py +4 -4
  51. flwr/common/exit/exit_code.py +6 -0
  52. flwr/common/grpc.py +2 -1
  53. flwr/common/logger.py +1 -1
  54. flwr/common/message.py +1 -1
  55. flwr/common/retry_invoker.py +13 -5
  56. flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -2
  57. flwr/common/serde.py +13 -5
  58. flwr/common/telemetry.py +1 -1
  59. flwr/common/typing.py +10 -3
  60. flwr/compat/client/app.py +6 -9
  61. flwr/compat/client/grpc_client/connection.py +2 -1
  62. flwr/compat/common/constant.py +29 -0
  63. flwr/compat/server/app.py +1 -1
  64. flwr/proto/clientappio_pb2.py +2 -2
  65. flwr/proto/clientappio_pb2_grpc.py +104 -88
  66. flwr/proto/clientappio_pb2_grpc.pyi +140 -80
  67. flwr/proto/federation_pb2.py +5 -3
  68. flwr/proto/federation_pb2.pyi +32 -2
  69. flwr/proto/fleet_pb2.py +10 -10
  70. flwr/proto/fleet_pb2.pyi +5 -1
  71. flwr/proto/run_pb2.py +18 -26
  72. flwr/proto/run_pb2.pyi +10 -58
  73. flwr/proto/serverappio_pb2.py +2 -2
  74. flwr/proto/serverappio_pb2_grpc.py +138 -207
  75. flwr/proto/serverappio_pb2_grpc.pyi +189 -155
  76. flwr/proto/simulationio_pb2.py +2 -2
  77. flwr/proto/simulationio_pb2_grpc.py +62 -90
  78. flwr/proto/simulationio_pb2_grpc.pyi +95 -55
  79. flwr/server/app.py +7 -13
  80. flwr/server/compat/grid_client_proxy.py +2 -1
  81. flwr/server/grid/grpc_grid.py +5 -5
  82. flwr/server/serverapp/app.py +11 -4
  83. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -1
  84. flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +13 -12
  85. flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
  86. flwr/server/superlink/linkstate/__init__.py +2 -2
  87. flwr/server/superlink/linkstate/in_memory_linkstate.py +36 -10
  88. flwr/server/superlink/linkstate/linkstate.py +34 -21
  89. flwr/server/superlink/linkstate/linkstate_factory.py +16 -8
  90. flwr/server/superlink/linkstate/{sqlite_linkstate.py → sql_linkstate.py} +471 -516
  91. flwr/server/superlink/linkstate/utils.py +49 -2
  92. flwr/server/superlink/serverappio/serverappio_servicer.py +1 -33
  93. flwr/server/superlink/simulation/simulationio_servicer.py +0 -19
  94. flwr/server/utils/validator.py +1 -1
  95. flwr/server/workflow/default_workflows.py +2 -1
  96. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -1
  97. flwr/serverapp/strategy/bulyan.py +7 -1
  98. flwr/serverapp/strategy/dp_fixed_clipping.py +9 -1
  99. flwr/serverapp/strategy/fedavg.py +1 -1
  100. flwr/serverapp/strategy/fedxgb_cyclic.py +1 -1
  101. flwr/simulation/ray_transport/ray_client_proxy.py +2 -6
  102. flwr/simulation/run_simulation.py +3 -12
  103. flwr/simulation/simulationio_connection.py +3 -3
  104. flwr/{common → supercore}/address.py +7 -33
  105. flwr/supercore/app_utils.py +2 -1
  106. flwr/supercore/constant.py +27 -2
  107. flwr/supercore/corestate/{sqlite_corestate.py → sql_corestate.py} +19 -23
  108. flwr/supercore/credential_store/__init__.py +33 -0
  109. flwr/supercore/credential_store/credential_store.py +34 -0
  110. flwr/supercore/credential_store/file_credential_store.py +76 -0
  111. flwr/{common → supercore}/date.py +0 -11
  112. flwr/supercore/ffs/disk_ffs.py +1 -1
  113. flwr/supercore/object_store/object_store_factory.py +14 -6
  114. flwr/supercore/object_store/{sqlite_object_store.py → sql_object_store.py} +115 -117
  115. flwr/supercore/sql_mixin.py +315 -0
  116. flwr/{cli/new/templates → supercore/state}/__init__.py +2 -2
  117. flwr/{cli/new/templates/app/code/flwr_tune → supercore/state/alembic}/__init__.py +2 -2
  118. flwr/supercore/state/alembic/env.py +103 -0
  119. flwr/supercore/state/alembic/script.py.mako +43 -0
  120. flwr/supercore/state/alembic/utils.py +239 -0
  121. flwr/{cli/new/templates/app → supercore/state/alembic/versions}/__init__.py +2 -2
  122. flwr/supercore/state/alembic/versions/rev_2026_01_28_initialize_migration_of_state_tables.py +200 -0
  123. flwr/supercore/state/schema/README.md +121 -0
  124. flwr/{cli/new/templates/app/code → supercore/state/schema}/__init__.py +2 -2
  125. flwr/supercore/state/schema/corestate_tables.py +36 -0
  126. flwr/supercore/state/schema/linkstate_tables.py +152 -0
  127. flwr/supercore/state/schema/objectstore_tables.py +90 -0
  128. flwr/supercore/superexec/run_superexec.py +2 -2
  129. flwr/supercore/utils.py +225 -0
  130. flwr/superlink/federation/federation_manager.py +2 -2
  131. flwr/superlink/federation/noop_federation_manager.py +8 -6
  132. flwr/superlink/servicer/control/control_grpc.py +2 -0
  133. flwr/superlink/servicer/control/control_servicer.py +106 -21
  134. flwr/supernode/cli/flower_supernode.py +2 -1
  135. flwr/supernode/nodestate/in_memory_nodestate.py +62 -1
  136. flwr/supernode/nodestate/nodestate.py +45 -0
  137. flwr/supernode/runtime/run_clientapp.py +14 -14
  138. flwr/supernode/servicer/clientappio/clientappio_servicer.py +13 -5
  139. flwr/supernode/start_client_internal.py +17 -10
  140. {flwr-1.24.0.dist-info → flwr-1.26.0.dist-info}/METADATA +8 -8
  141. {flwr-1.24.0.dist-info → flwr-1.26.0.dist-info}/RECORD +144 -184
  142. flwr/cli/federation/show.py +0 -317
  143. flwr/cli/new/templates/app/.gitignore.tpl +0 -163
  144. flwr/cli/new/templates/app/LICENSE.tpl +0 -202
  145. flwr/cli/new/templates/app/README.baseline.md.tpl +0 -127
  146. flwr/cli/new/templates/app/README.flowertune.md.tpl +0 -68
  147. flwr/cli/new/templates/app/README.md.tpl +0 -37
  148. flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +0 -1
  149. flwr/cli/new/templates/app/code/__init__.py.tpl +0 -1
  150. flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +0 -1
  151. flwr/cli/new/templates/app/code/client.baseline.py.tpl +0 -75
  152. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +0 -93
  153. flwr/cli/new/templates/app/code/client.jax.py.tpl +0 -71
  154. flwr/cli/new/templates/app/code/client.mlx.py.tpl +0 -102
  155. flwr/cli/new/templates/app/code/client.numpy.py.tpl +0 -46
  156. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +0 -80
  157. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +0 -55
  158. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +0 -108
  159. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -82
  160. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +0 -110
  161. flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +0 -36
  162. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +0 -92
  163. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +0 -87
  164. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -56
  165. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +0 -73
  166. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +0 -78
  167. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -66
  168. flwr/cli/new/templates/app/code/server.baseline.py.tpl +0 -43
  169. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +0 -42
  170. flwr/cli/new/templates/app/code/server.jax.py.tpl +0 -39
  171. flwr/cli/new/templates/app/code/server.mlx.py.tpl +0 -41
  172. flwr/cli/new/templates/app/code/server.numpy.py.tpl +0 -38
  173. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +0 -41
  174. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +0 -31
  175. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +0 -44
  176. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +0 -38
  177. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +0 -56
  178. flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +0 -1
  179. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +0 -98
  180. flwr/cli/new/templates/app/code/task.jax.py.tpl +0 -57
  181. flwr/cli/new/templates/app/code/task.mlx.py.tpl +0 -102
  182. flwr/cli/new/templates/app/code/task.numpy.py.tpl +0 -7
  183. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +0 -99
  184. flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +0 -111
  185. flwr/cli/new/templates/app/code/task.sklearn.py.tpl +0 -67
  186. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +0 -52
  187. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +0 -67
  188. flwr/cli/new/templates/app/code/utils.baseline.py.tpl +0 -1
  189. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +0 -146
  190. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +0 -80
  191. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +0 -65
  192. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +0 -52
  193. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +0 -56
  194. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +0 -49
  195. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +0 -53
  196. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +0 -53
  197. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +0 -52
  198. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +0 -53
  199. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +0 -61
  200. flwr/common/pyproject.py +0 -42
  201. flwr/supercore/sqlite_mixin.py +0 -159
  202. /flwr/{common → supercore}/version.py +0 -0
  203. {flwr-1.24.0.dist-info → flwr-1.26.0.dist-info}/WHEEL +0 -0
  204. {flwr-1.24.0.dist-info → flwr-1.26.0.dist-info}/entry_points.txt +0 -0
@@ -1,56 +0,0 @@
1
- """$project_name: A Flower / FlowerTune app."""
2
-
3
- import math
4
-
5
- import torch
6
- from omegaconf import DictConfig
7
- from peft import LoraConfig, get_peft_model
8
- from peft.utils import prepare_model_for_kbit_training
9
- from transformers import AutoModelForCausalLM, BitsAndBytesConfig
10
-
11
-
12
- def cosine_annealing(
13
- current_round: int,
14
- total_round: int,
15
- lrate_max: float = 0.001,
16
- lrate_min: float = 0.0,
17
- ) -> float:
18
- """Implement cosine annealing learning rate schedule."""
19
- cos_inner = math.pi * current_round / total_round
20
- return lrate_min + 0.5 * (lrate_max - lrate_min) * (1 + math.cos(cos_inner))
21
-
22
-
23
- def get_model(model_cfg: DictConfig):
24
- """Load model with appropriate quantization config and other optimizations.
25
- """
26
- if model_cfg.quantization == 4:
27
- quantization_config = BitsAndBytesConfig(load_in_4bit=True)
28
- elif model_cfg.quantization == 8:
29
- quantization_config = BitsAndBytesConfig(load_in_8bit=True)
30
- else:
31
- raise ValueError(
32
- f"Use 4-bit or 8-bit quantization. You passed: {model_cfg.quantization}/"
33
- )
34
-
35
- model = AutoModelForCausalLM.from_pretrained(
36
- model_cfg.name,
37
- quantization_config=quantization_config,
38
- torch_dtype=torch.bfloat16,
39
- low_cpu_mem_usage=True,
40
- )
41
-
42
- model = prepare_model_for_kbit_training(
43
- model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
44
- )
45
-
46
- peft_config = LoraConfig(
47
- r=model_cfg.lora.peft_lora_r,
48
- lora_alpha=model_cfg.lora.peft_lora_alpha,
49
- lora_dropout=0.075,
50
- task_type="CAUSAL_LM",
51
- )
52
-
53
- if model_cfg.gradient_checkpointing:
54
- model.config.use_cache = False
55
-
56
- return get_peft_model(model, peft_config)
@@ -1,73 +0,0 @@
1
- """$project_name: A Flower / FlowerTune app."""
2
-
3
- import os
4
- from datetime import datetime
5
-
6
- from flwr.app import ArrayRecord, ConfigRecord, Context, MetricRecord
7
- from flwr.common.config import unflatten_dict
8
- from flwr.serverapp import Grid, ServerApp
9
- from omegaconf import DictConfig
10
- from peft import get_peft_model_state_dict, set_peft_model_state_dict
11
-
12
- from $import_name.dataset import replace_keys
13
- from $import_name.models import get_model
14
- from $import_name.strategy import FlowerTuneLlm
15
-
16
- # Create ServerApp
17
- app = ServerApp()
18
-
19
-
20
- @app.main()
21
- def main(grid: Grid, context: Context) -> None:
22
- """Main entry point for the ServerApp."""
23
- # Create output directory given current timestamp
24
- current_time = datetime.now()
25
- folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
26
- save_path = os.path.join(os.getcwd(), f"results/{folder_name}")
27
- os.makedirs(save_path, exist_ok=True)
28
-
29
- # Read from config
30
- num_rounds = context.run_config["num-server-rounds"]
31
- cfg = DictConfig(replace_keys(unflatten_dict(context.run_config)))
32
-
33
- # Get initial model weights
34
- init_model = get_model(cfg.model)
35
- arrays = ArrayRecord(get_peft_model_state_dict(init_model))
36
-
37
- # Define strategy
38
- strategy = FlowerTuneLlm(
39
- fraction_train=cfg.strategy.fraction_train,
40
- fraction_evaluate=cfg.strategy.fraction_evaluate,
41
- )
42
-
43
- # Start strategy, run FedAvg for `num_rounds`
44
- strategy.start(
45
- grid=grid,
46
- initial_arrays=arrays,
47
- train_config=ConfigRecord({"save_path": save_path}),
48
- num_rounds=num_rounds,
49
- evaluate_fn=get_evaluate_fn(
50
- cfg.model, cfg.train.save_every_round, num_rounds, save_path
51
- ),
52
- )
53
-
54
-
55
- # Get function that will be executed by the strategy
56
- # Here we use it to save global model checkpoints
57
- def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
58
- """Return an evaluation function for saving global model."""
59
-
60
- def evaluate(server_round: int, arrays: ArrayRecord) -> MetricRecord:
61
- # Save model
62
- if server_round != 0 and (
63
- server_round == total_round or server_round % save_every_round == 0
64
- ):
65
- # Init model
66
- model = get_model(model_cfg)
67
- set_peft_model_state_dict(model, arrays.to_torch_state_dict())
68
-
69
- model.save_pretrained(f"{save_path}/peft_{server_round}")
70
-
71
- return MetricRecord()
72
-
73
- return evaluate
@@ -1,78 +0,0 @@
1
- """$project_name: A Flower / FlowerTune app."""
2
-
3
- from collections.abc import Iterable
4
- from logging import INFO, WARN
5
- from typing import Optional
6
-
7
- from flwr.app import ArrayRecord, ConfigRecord, Message, MetricRecord
8
- from flwr.common import log
9
- from flwr.serverapp import Grid
10
- from flwr.serverapp.strategy import FedAvg
11
-
12
-
13
- class FlowerTuneLlm(FedAvg):
14
- """Customised FedAvg strategy implementation.
15
-
16
- This class behaves just like FedAvg but also tracks the communication
17
- costs associated with `train` over FL rounds.
18
- """
19
- def __init__(self, **kwargs):
20
- super().__init__(**kwargs)
21
- self.comm_tracker = CommunicationTracker()
22
-
23
- def configure_train(
24
- self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
25
- ) -> Iterable[Message]:
26
- """Configure the next round of training."""
27
- messages = super().configure_train(server_round, arrays, config, grid)
28
-
29
- # Track communication costs
30
- self.comm_tracker.track(messages)
31
-
32
- return messages
33
-
34
- def aggregate_train(
35
- self,
36
- server_round: int,
37
- replies: Iterable[Message],
38
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
39
- """Aggregate ArrayRecords and MetricRecords in the received Messages."""
40
- # Track communication costs
41
- self.comm_tracker.track(replies)
42
-
43
- arrays, metrics = super().aggregate_train(server_round, replies)
44
-
45
- return arrays, metrics
46
-
47
-
48
- class CommunicationTracker:
49
- """Communication costs tracker over FL rounds."""
50
- def __init__(self):
51
- self.curr_comm_cost = 0.0
52
-
53
- def track(self, messages: Iterable[Message]):
54
- comm_cost = (
55
- sum(
56
- record.count_bytes()
57
- for msg in messages
58
- if msg.has_content()
59
- for record in msg.content.array_records.values()
60
- )
61
- / 1024**2
62
- )
63
-
64
- self.curr_comm_cost += comm_cost
65
- log(
66
- INFO,
67
- "Communication budget: used %.2f MB (+%.2f MB this round) / 200,000 MB",
68
- self.curr_comm_cost,
69
- comm_cost,
70
- )
71
-
72
- if self.curr_comm_cost > 2e5:
73
- log(
74
- WARN,
75
- "The accumulated communication cost has exceeded 200,000 MB. "
76
- "Please consider reducing it if you plan to participate "
77
- "FlowerTune LLM Leaderboard.",
78
- )
@@ -1,66 +0,0 @@
1
- """$project_name: A Flower Baseline."""
2
-
3
- import torch
4
- import torch.nn.functional as F
5
- from torch import nn
6
-
7
-
8
- class Net(nn.Module):
9
- """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')."""
10
-
11
- def __init__(self):
12
- super().__init__()
13
- self.conv1 = nn.Conv2d(3, 6, 5)
14
- self.pool = nn.MaxPool2d(2, 2)
15
- self.conv2 = nn.Conv2d(6, 16, 5)
16
- self.fc1 = nn.Linear(16 * 5 * 5, 120)
17
- self.fc2 = nn.Linear(120, 84)
18
- self.fc3 = nn.Linear(84, 10)
19
-
20
- def forward(self, x):
21
- """Do forward."""
22
- x = self.pool(F.relu(self.conv1(x)))
23
- x = self.pool(F.relu(self.conv2(x)))
24
- x = x.view(-1, 16 * 5 * 5)
25
- x = F.relu(self.fc1(x))
26
- x = F.relu(self.fc2(x))
27
- return self.fc3(x)
28
-
29
-
30
- def train(net, trainloader, epochs, device):
31
- """Train the model on the training set."""
32
- net.to(device) # move model to GPU if available
33
- criterion = torch.nn.CrossEntropyLoss()
34
- criterion.to(device)
35
- optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
36
- net.train()
37
- running_loss = 0.0
38
- for _ in range(epochs):
39
- for batch in trainloader:
40
- images = batch["img"]
41
- labels = batch["label"]
42
- optimizer.zero_grad()
43
- loss = criterion(net(images.to(device)), labels.to(device))
44
- loss.backward()
45
- optimizer.step()
46
- running_loss += loss.item()
47
-
48
- avg_trainloss = running_loss / len(trainloader)
49
- return avg_trainloss
50
-
51
-
52
- def test(net, testloader, device):
53
- """Validate the model on the test set."""
54
- net.to(device)
55
- criterion = torch.nn.CrossEntropyLoss()
56
- correct, loss = 0, 0.0
57
- with torch.no_grad():
58
- for batch in testloader:
59
- images = batch["img"].to(device)
60
- labels = batch["label"].to(device)
61
- outputs = net(images)
62
- loss += criterion(outputs, labels).item()
63
- correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
64
- accuracy = correct / len(testloader.dataset)
65
- loss = loss / len(testloader)
66
- return loss, accuracy
@@ -1,43 +0,0 @@
1
- """$project_name: A Flower Baseline."""
2
-
3
- import torch
4
- from flwr.app import ArrayRecord, Context
5
- from flwr.serverapp import Grid, ServerApp
6
- from flwr.serverapp.strategy import FedAvg
7
-
8
- from $import_name.model import Net
9
-
10
- # Create ServerApp
11
- app = ServerApp()
12
-
13
-
14
- @app.main()
15
- def main(grid: Grid, context: Context) -> None:
16
- """Main entry point for the ServerApp."""
17
-
18
- # Read from config
19
- num_rounds = context.run_config["num-server-rounds"]
20
- fraction_train = context.run_config["fraction-train"]
21
-
22
- # Load global model
23
- global_model = Net()
24
- arrays = ArrayRecord(global_model.state_dict())
25
-
26
- # Initialize FedAvg strategy
27
- strategy = FedAvg(
28
- fraction_train=fraction_train,
29
- fraction_evaluate=1.0,
30
- min_available_nodes=2,
31
- )
32
-
33
- # Start strategy, run FedAvg for `num_rounds`
34
- result = strategy.start(
35
- grid=grid,
36
- initial_arrays=arrays,
37
- num_rounds=num_rounds,
38
- )
39
-
40
- # Save final model to disk
41
- print("\nSaving final model to disk...")
42
- state_dict = result.arrays.to_torch_state_dict()
43
- torch.save(state_dict, "final_model.pt")
@@ -1,42 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import torch
4
- from flwr.app import ArrayRecord, Context
5
- from flwr.serverapp import Grid, ServerApp
6
- from flwr.serverapp.strategy import FedAvg
7
- from transformers import AutoModelForSequenceClassification
8
-
9
- # Create ServerApp
10
- app = ServerApp()
11
-
12
-
13
- @app.main()
14
- def main(grid: Grid, context: Context) -> None:
15
- """Main entry point for the ServerApp."""
16
-
17
- # Read from config
18
- num_rounds = context.run_config["num-server-rounds"]
19
- fraction_train = context.run_config["fraction-train"]
20
-
21
- # Initialize global model
22
- model_name = context.run_config["model-name"]
23
- num_labels = context.run_config["num-labels"]
24
- net = AutoModelForSequenceClassification.from_pretrained(
25
- model_name, num_labels=num_labels
26
- )
27
- arrays = ArrayRecord(net.state_dict())
28
-
29
- # Initialize FedAvg strategy
30
- strategy = FedAvg(fraction_train=fraction_train)
31
-
32
- # Start strategy, run FedAvg for `num_rounds`
33
- result = strategy.start(
34
- grid=grid,
35
- initial_arrays=arrays,
36
- num_rounds=num_rounds,
37
- )
38
-
39
- # Save final model to disk
40
- print("\nSaving final model to disk...")
41
- state_dict = result.arrays.to_torch_state_dict()
42
- torch.save(state_dict, "final_model.pt")
@@ -1,39 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import numpy as np
4
- from flwr.app import ArrayRecord, Context
5
- from flwr.serverapp import Grid, ServerApp
6
- from flwr.serverapp.strategy import FedAvg
7
-
8
- from $import_name.task import get_params, load_model
9
-
10
- # Create ServerApp
11
- app = ServerApp()
12
-
13
-
14
- @app.main()
15
- def main(grid: Grid, context: Context) -> None:
16
- """Main entry point for the ServerApp."""
17
-
18
- # Read from config
19
- num_rounds = context.run_config["num-server-rounds"]
20
- input_dim = context.run_config["input-dim"]
21
-
22
- # Load global model
23
- model = load_model((input_dim,))
24
- arrays = ArrayRecord(get_params(model))
25
-
26
- # Initialize FedAvg strategy
27
- strategy = FedAvg()
28
-
29
- # Start strategy, run FedAvg for `num_rounds`
30
- result = strategy.start(
31
- grid=grid,
32
- initial_arrays=arrays,
33
- num_rounds=num_rounds,
34
- )
35
-
36
- # Save final model to disk
37
- print("\nSaving final model to disk...")
38
- ndarrays = result.arrays.to_numpy_ndarrays()
39
- np.savez("final_model.npz", *ndarrays)
@@ -1,41 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- from flwr.app import ArrayRecord, Context
4
- from flwr.serverapp import Grid, ServerApp
5
- from flwr.serverapp.strategy import FedAvg
6
-
7
- from $import_name.task import MLP, get_params, set_params
8
-
9
- # Create ServerApp
10
- app = ServerApp()
11
-
12
-
13
- @app.main()
14
- def main(grid: Grid, context: Context) -> None:
15
- """Main entry point for the ServerApp."""
16
- # Read from config
17
- num_rounds = context.run_config["num-server-rounds"]
18
- num_layers = context.run_config["num-layers"]
19
- input_dim = context.run_config["input-dim"]
20
- hidden_dim = context.run_config["hidden-dim"]
21
-
22
- # Initialize global model
23
- model = MLP(num_layers, input_dim, hidden_dim, output_dim=10)
24
- params = get_params(model)
25
- arrays = ArrayRecord(params)
26
-
27
- # Initialize FedAvg strategy
28
- strategy = FedAvg()
29
-
30
- # Start strategy, run FedAvg for `num_rounds`
31
- result = strategy.start(
32
- grid=grid,
33
- initial_arrays=arrays,
34
- num_rounds=num_rounds,
35
- )
36
-
37
- # Save final model to disk
38
- print("\nSaving final model to disk...")
39
- ndarrays = result.arrays.to_numpy_ndarrays()
40
- set_params(model, ndarrays)
41
- model.save_weights("final_model.npz")
@@ -1,38 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import numpy as np
4
- from flwr.app import ArrayRecord, Context
5
- from flwr.serverapp import Grid, ServerApp
6
- from flwr.serverapp.strategy import FedAvg
7
-
8
- from $import_name.task import get_dummy_model
9
-
10
- # Create ServerApp
11
- app = ServerApp()
12
-
13
-
14
- @app.main()
15
- def main(grid: Grid, context: Context) -> None:
16
- """Main entry point for the ServerApp."""
17
-
18
- # Read run config
19
- num_rounds: int = context.run_config["num-server-rounds"]
20
-
21
- # Load global model
22
- model = get_dummy_model()
23
- arrays = ArrayRecord(model)
24
-
25
- # Initialize FedAvg strategy
26
- strategy = FedAvg()
27
-
28
- # Start strategy, run FedAvg for `num_rounds`
29
- result = strategy.start(
30
- grid=grid,
31
- initial_arrays=arrays,
32
- num_rounds=num_rounds,
33
- )
34
-
35
- # Save final model to disk
36
- print("\nSaving final model to disk...")
37
- ndarrays = result.arrays.to_numpy_ndarrays()
38
- np.savez("final_model", *ndarrays)
@@ -1,41 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import torch
4
- from flwr.app import ArrayRecord, ConfigRecord, Context
5
- from flwr.serverapp import Grid, ServerApp
6
- from flwr.serverapp.strategy import FedAvg
7
-
8
- from $import_name.task import Net
9
-
10
- # Create ServerApp
11
- app = ServerApp()
12
-
13
-
14
- @app.main()
15
- def main(grid: Grid, context: Context) -> None:
16
- """Main entry point for the ServerApp."""
17
-
18
- # Read run config
19
- fraction_train: float = context.run_config["fraction-train"]
20
- num_rounds: int = context.run_config["num-server-rounds"]
21
- lr: float = context.run_config["lr"]
22
-
23
- # Load global model
24
- global_model = Net()
25
- arrays = ArrayRecord(global_model.state_dict())
26
-
27
- # Initialize FedAvg strategy
28
- strategy = FedAvg(fraction_train=fraction_train)
29
-
30
- # Start strategy, run FedAvg for `num_rounds`
31
- result = strategy.start(
32
- grid=grid,
33
- initial_arrays=arrays,
34
- train_config=ConfigRecord({"lr": lr}),
35
- num_rounds=num_rounds,
36
- )
37
-
38
- # Save final model to disk
39
- print("\nSaving final model to disk...")
40
- state_dict = result.arrays.to_torch_state_dict()
41
- torch.save(state_dict, "final_model.pt")
@@ -1,31 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
6
- from $import_name.task import Net, get_weights
7
-
8
-
9
- def server_fn(context: Context):
10
- # Read from config
11
- num_rounds = context.run_config["num-server-rounds"]
12
- fraction_fit = context.run_config["fraction-fit"]
13
-
14
- # Initialize model parameters
15
- ndarrays = get_weights(Net())
16
- parameters = ndarrays_to_parameters(ndarrays)
17
-
18
- # Define strategy
19
- strategy = FedAvg(
20
- fraction_fit=fraction_fit,
21
- fraction_evaluate=1.0,
22
- min_available_clients=2,
23
- initial_parameters=parameters,
24
- )
25
- config = ServerConfig(num_rounds=num_rounds)
26
-
27
- return ServerAppComponents(strategy=strategy, config=config)
28
-
29
-
30
- # Create ServerApp
31
- app = ServerApp(server_fn=server_fn)
@@ -1,44 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import joblib
4
- from flwr.app import ArrayRecord, Context
5
- from flwr.serverapp import Grid, ServerApp
6
- from flwr.serverapp.strategy import FedAvg
7
-
8
- from $import_name.task import get_model, get_model_params, set_initial_params, set_model_params
9
-
10
- # Create ServerApp
11
- app = ServerApp()
12
-
13
-
14
- @app.main()
15
- def main(grid: Grid, context: Context) -> None:
16
- """Main entry point for the ServerApp."""
17
-
18
- # Read run config
19
- num_rounds: int = context.run_config["num-server-rounds"]
20
-
21
- # Create LogisticRegression Model
22
- penalty = context.run_config["penalty"]
23
- local_epochs = context.run_config["local-epochs"]
24
- model = get_model(penalty, local_epochs)
25
- # Setting initial parameters, akin to model.compile for keras models
26
- set_initial_params(model)
27
- # Construct ArrayRecord representation
28
- arrays = ArrayRecord(get_model_params(model))
29
-
30
- # Initialize FedAvg strategy
31
- strategy = FedAvg(fraction_train=1.0, fraction_evaluate=1.0)
32
-
33
- # Start strategy, run FedAvg for `num_rounds`
34
- result = strategy.start(
35
- grid=grid,
36
- initial_arrays=arrays,
37
- num_rounds=num_rounds,
38
- )
39
-
40
- # Save final model parameters
41
- print("\nSaving final model to disk...")
42
- ndarrays = result.arrays.to_numpy_ndarrays()
43
- set_model_params(model, ndarrays)
44
- joblib.dump(model, "logreg_model.pkl")
@@ -1,38 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- from flwr.app import ArrayRecord, Context
4
- from flwr.serverapp import Grid, ServerApp
5
- from flwr.serverapp.strategy import FedAvg
6
-
7
- from $import_name.task import load_model
8
-
9
- # Create ServerApp
10
- app = ServerApp()
11
-
12
-
13
- @app.main()
14
- def main(grid: Grid, context: Context) -> None:
15
- """Main entry point for the ServerApp."""
16
-
17
- # Read run config
18
- num_rounds: int = context.run_config["num-server-rounds"]
19
-
20
- # Load global model
21
- model = load_model()
22
- arrays = ArrayRecord(model.get_weights())
23
-
24
- # Initialize FedAvg strategy
25
- strategy = FedAvg(fraction_train=1.0, fraction_evaluate=1.0)
26
-
27
- # Start strategy, run FedAvg for `num_rounds`
28
- result = strategy.start(
29
- grid=grid,
30
- initial_arrays=arrays,
31
- num_rounds=num_rounds,
32
- )
33
-
34
- # Save final model to disk
35
- print("\nSaving final model to disk...")
36
- ndarrays = result.arrays.to_numpy_ndarrays()
37
- model.set_weights(ndarrays)
38
- model.save("final_model.keras")