flwr-nightly 1.22.0.dev20250916__py3-none-any.whl → 1.22.0.dev20250917__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/new/new.py +2 -2
  3. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  4. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  5. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  6. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  7. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  8. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  9. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  10. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  11. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +3 -3
  12. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  13. flwr/cli/pull.py +100 -0
  14. flwr/cli/utils.py +17 -0
  15. flwr/common/constant.py +2 -0
  16. flwr/proto/control_pb2.py +7 -3
  17. flwr/proto/control_pb2.pyi +24 -0
  18. flwr/proto/control_pb2_grpc.py +34 -0
  19. flwr/proto/control_pb2_grpc.pyi +13 -0
  20. flwr/server/app.py +13 -0
  21. flwr/serverapp/strategy/__init__.py +2 -0
  22. flwr/serverapp/strategy/fedprox.py +174 -0
  23. flwr/simulation/app.py +1 -1
  24. flwr/simulation/run_simulation.py +25 -30
  25. flwr/superlink/artifact_provider/__init__.py +22 -0
  26. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  27. flwr/superlink/servicer/control/control_grpc.py +3 -0
  28. flwr/superlink/servicer/control/control_servicer.py +59 -2
  29. {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250917.dist-info}/METADATA +1 -1
  30. {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250917.dist-info}/RECORD +32 -28
  31. {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250917.dist-info}/WHEEL +0 -0
  32. {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250917.dist-info}/entry_points.txt +0 -0
@@ -1,45 +1,43 @@
1
1
  """$project_name: A Flower Baseline."""
2
2
 
3
- from flwr.common import Context, Metrics, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
3
+ import torch
4
+ from flwr.app import ArrayRecord, Context
5
+ from flwr.serverapp import Grid, ServerApp
6
+ from flwr.serverapp.strategy import FedAvg
6
7
 
7
- from $import_name.model import Net, get_weights
8
+ from $import_name.model import Net
8
9
 
10
+ # Create ServerApp
11
+ app = ServerApp()
9
12
 
10
- # Define metric aggregation function
11
- def weighted_average(metrics: list[tuple[int, Metrics]]) -> Metrics:
12
- """Do weighted average of accuracy metric."""
13
- # Multiply accuracy of each client by number of examples used
14
- accuracies = [num_examples * float(m["accuracy"]) for num_examples, m in metrics]
15
- examples = [num_examples for num_examples, _ in metrics]
16
-
17
- # Aggregate and return custom metric (weighted average)
18
- return {"accuracy": sum(accuracies) / sum(examples)}
19
13
 
14
+ @app.main()
15
+ def main(grid: Grid, context: Context) -> None:
16
+ """Main entry point for the ServerApp."""
20
17
 
21
- def server_fn(context: Context):
22
- """Construct components that set the ServerApp behaviour."""
23
18
  # Read from config
24
19
  num_rounds = context.run_config["num-server-rounds"]
25
- fraction_fit = context.run_config["fraction-fit"]
20
+ fraction_train = context.run_config["fraction-train"]
26
21
 
27
- # Initialize model parameters
28
- ndarrays = get_weights(Net())
29
- parameters = ndarrays_to_parameters(ndarrays)
22
+ # Load global model
23
+ global_model = Net()
24
+ arrays = ArrayRecord(global_model.state_dict())
30
25
 
31
- # Define strategy
26
+ # Initialize FedAvg strategy
32
27
  strategy = FedAvg(
33
- fraction_fit=float(fraction_fit),
28
+ fraction_train=fraction_train,
34
29
  fraction_evaluate=1.0,
35
- min_available_clients=2,
36
- initial_parameters=parameters,
37
- evaluate_metrics_aggregation_fn=weighted_average,
30
+ min_available_nodes=2,
38
31
  )
39
- config = ServerConfig(num_rounds=int(num_rounds))
40
-
41
- return ServerAppComponents(strategy=strategy, config=config)
42
32
 
33
+ # Start strategy, run FedAvg for `num_rounds`
34
+ result = strategy.start(
35
+ grid=grid,
36
+ initial_arrays=arrays,
37
+ num_rounds=num_rounds,
38
+ )
43
39
 
44
- # Create ServerApp
45
- app = ServerApp(server_fn=server_fn)
40
+ # Save final model to disk
41
+ print("\nSaving final model to disk...")
42
+ state_dict = result.arrays.to_torch_state_dict()
43
+ torch.save(state_dict, "final_model.pt")
@@ -16,8 +16,8 @@ license = "Apache-2.0"
16
16
  dependencies = [
17
17
  "flwr[simulation]>=1.22.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
- "torch==2.7.1",
20
- "torchvision==0.22.1",
19
+ "torch==2.8.0",
20
+ "torchvision==0.23.0",
21
21
  ]
22
22
 
23
23
  [tool.hatch.metadata]
@@ -132,7 +132,7 @@ clientapp = "$import_name.client_app:app"
132
132
  # Custom config values accessible via `context.run_config`
133
133
  [tool.flwr.app.config]
134
134
  num-server-rounds = 3
135
- fraction-fit = 0.5
135
+ fraction-train = 0.5
136
136
  local-epochs = 1
137
137
 
138
138
  # Default federation to use when running the app
@@ -61,7 +61,7 @@ train.training-arguments.save-steps = 1000
61
61
  train.training-arguments.save-total-limit = 10
62
62
  train.training-arguments.gradient-checkpointing = true
63
63
  train.training-arguments.lr-scheduler-type = "constant"
64
- strategy.fraction-fit = $fraction_fit
64
+ strategy.fraction-train = $fraction_train
65
65
  strategy.fraction-evaluate = 0.0
66
66
  num-server-rounds = 200
67
67
 
flwr/cli/pull.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower command line interface `pull` command."""
16
+
17
+
18
+ from pathlib import Path
19
+ from typing import Annotated, Optional
20
+
21
+ import typer
22
+
23
+ from flwr.cli.config_utils import (
24
+ exit_if_no_address,
25
+ load_and_validate,
26
+ process_loaded_project_config,
27
+ validate_federation_in_project_config,
28
+ )
29
+ from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
30
+ from flwr.common.constant import FAB_CONFIG_FILE
31
+ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
32
+ PullArtifactsRequest,
33
+ PullArtifactsResponse,
34
+ )
35
+ from flwr.proto.control_pb2_grpc import ControlStub
36
+
37
+ from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
38
+
39
+
40
+ def pull( # pylint: disable=R0914
41
+ run_id: Annotated[
42
+ int,
43
+ typer.Option(
44
+ "--run-id",
45
+ help="Run ID to pull artifacts from.",
46
+ ),
47
+ ],
48
+ app: Annotated[
49
+ Path,
50
+ typer.Argument(help="Path of the Flower App to run."),
51
+ ] = Path("."),
52
+ federation: Annotated[
53
+ Optional[str],
54
+ typer.Argument(help="Name of the federation."),
55
+ ] = None,
56
+ federation_config_overrides: Annotated[
57
+ Optional[list[str]],
58
+ typer.Option(
59
+ "--federation-config",
60
+ help=FEDERATION_CONFIG_HELP_MESSAGE,
61
+ ),
62
+ ] = None,
63
+ ) -> None:
64
+ """Pull artifacts from a Flower run."""
65
+ typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
66
+
67
+ pyproject_path = app / FAB_CONFIG_FILE if app else None
68
+ config, errors, warnings = load_and_validate(path=pyproject_path)
69
+ config = process_loaded_project_config(config, errors, warnings)
70
+ federation, federation_config = validate_federation_in_project_config(
71
+ federation, config, federation_config_overrides
72
+ )
73
+ exit_if_no_address(federation_config, "pull")
74
+ channel = None
75
+ try:
76
+
77
+ auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
78
+ channel = init_channel(app, federation_config, auth_plugin)
79
+ stub = ControlStub(channel)
80
+ with flwr_cli_grpc_exc_handler():
81
+ res: PullArtifactsResponse = stub.PullArtifacts(
82
+ PullArtifactsRequest(run_id=run_id)
83
+ )
84
+
85
+ if not res.url:
86
+ typer.secho(
87
+ f"❌ A download URL for artifacts from run {run_id} couldn't be "
88
+ "obtained.",
89
+ fg=typer.colors.RED,
90
+ bold=True,
91
+ )
92
+ raise typer.Exit(code=1)
93
+
94
+ typer.secho(
95
+ f"✅ Artifacts for run {run_id} can be downloaded from: {res.url}",
96
+ fg=typer.colors.GREEN,
97
+ )
98
+ finally:
99
+ if channel:
100
+ channel.close()
flwr/cli/utils.py CHANGED
@@ -32,7 +32,9 @@ from flwr.common.constant import (
32
32
  AUTH_TYPE_JSON_KEY,
33
33
  CREDENTIALS_DIR,
34
34
  FLWR_DIR,
35
+ NO_ARTIFACT_PROVIDER_MESSAGE,
35
36
  NO_USER_AUTH_MESSAGE,
37
+ PULL_UNFINISHED_RUN_MESSAGE,
36
38
  RUN_ID_NOT_FOUND_MESSAGE,
37
39
  )
38
40
  from flwr.common.grpc import (
@@ -319,6 +321,12 @@ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
319
321
  fg=typer.colors.RED,
320
322
  bold=True,
321
323
  )
324
+ elif e.details() == NO_ARTIFACT_PROVIDER_MESSAGE: # pylint: disable=E1101
325
+ typer.secho(
326
+ "❌ The SuperLink does not support `flwr pull` command.",
327
+ fg=typer.colors.RED,
328
+ bold=True,
329
+ )
322
330
  else:
323
331
  typer.secho(
324
332
  "❌ The SuperLink cannot process this request. Please verify that "
@@ -356,4 +364,13 @@ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
356
364
  bold=True,
357
365
  )
358
366
  raise typer.Exit(code=1) from None
367
+ if e.code() == grpc.StatusCode.FAILED_PRECONDITION:
368
+ if e.details() == PULL_UNFINISHED_RUN_MESSAGE: # pylint: disable=E1101
369
+ typer.secho(
370
+ "❌ Run is not finished yet. Artifacts can only be pulled after "
371
+ "the run is finished. You can check the run status with `flwr ls`.",
372
+ fg=typer.colors.RED,
373
+ bold=True,
374
+ )
375
+ raise typer.Exit(code=1) from None
359
376
  raise
flwr/common/constant.py CHANGED
@@ -155,6 +155,8 @@ PULL_BACKOFF_CAP = 10 # Maximum backoff time for pulling objects
155
155
  # ControlServicer constants
156
156
  RUN_ID_NOT_FOUND_MESSAGE = "Run ID not found"
157
157
  NO_USER_AUTH_MESSAGE = "ControlServicer initialized without user authentication"
158
+ NO_ARTIFACT_PROVIDER_MESSAGE = "ControlServicer initialized without artifact provider"
159
+ PULL_UNFINISHED_RUN_MESSAGE = "Cannot pull artifacts for an unfinished run"
158
160
 
159
161
 
160
162
  class MessageType:
flwr/proto/control_pb2.py CHANGED
@@ -18,7 +18,7 @@ from flwr.proto import recorddict_pb2 as flwr_dot_proto_dot_recorddict__pb2
18
18
  from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
19
19
 
20
20
 
21
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x14\x66lwr/proto/run.proto\"\xfa\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x34\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x18.flwr.proto.ConfigRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"2\n\x10StartRunResponse\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\"\x18\n\x16GetLoginDetailsRequest\"\x8a\x01\n\x17GetLoginDetailsResponse\x12\x11\n\tauth_type\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65vice_code\x18\x02 \x01(\t\x12!\n\x19verification_uri_complete\x18\x03 \x01(\t\x12\x12\n\nexpires_in\x18\x04 \x01(\x03\x12\x10\n\x08interval\x18\x05 \x01(\x03\"+\n\x14GetAuthTokensRequest\x12\x13\n\x0b\x64\x65vice_code\x18\x01 \x01(\t\"D\n\x15GetAuthTokensResponse\x12\x14\n\x0c\x61\x63\x63\x65ss_token\x18\x01 \x01(\t\x12\x15\n\rrefresh_token\x18\x02 \x01(\t\" \n\x0eStopRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\"\n\x0fStopRunResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x32\xe8\x03\n\x07\x43ontrol\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12\x44\n\x07StopRun\x12\x1a.flwr.proto.StopRunRequest\x1a\x1b.flwr.proto.StopRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x12\\\n\x0fGetLoginDetails\x12\".flwr.proto.GetLoginDetailsRequest\x1a#.flwr.proto.GetLoginDetailsResponse\"\x00\x12V\n\rGetAuthTokens\x12 .flwr.proto.GetAuthTokensRequest\x1a!.flwr.proto.GetAuthTokensResponse\"\x00\x62\x06proto3')
21
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x14\x66lwr/proto/run.proto\"\xfa\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x34\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x18.flwr.proto.ConfigRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"2\n\x10StartRunResponse\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\"\x18\n\x16GetLoginDetailsRequest\"\x8a\x01\n\x17GetLoginDetailsResponse\x12\x11\n\tauth_type\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65vice_code\x18\x02 \x01(\t\x12!\n\x19verification_uri_complete\x18\x03 \x01(\t\x12\x12\n\nexpires_in\x18\x04 \x01(\x03\x12\x10\n\x08interval\x18\x05 \x01(\x03\"+\n\x14GetAuthTokensRequest\x12\x13\n\x0b\x64\x65vice_code\x18\x01 \x01(\t\"D\n\x15GetAuthTokensResponse\x12\x14\n\x0c\x61\x63\x63\x65ss_token\x18\x01 \x01(\t\x12\x15\n\rrefresh_token\x18\x02 \x01(\t\" \n\x0eStopRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\"\n\x0fStopRunResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"&\n\x14PullArtifactsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"1\n\x15PullArtifactsResponse\x12\x10\n\x03url\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x06\n\x04_url2\xc0\x04\n\x07\x43ontrol\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12\x44\n\x07StopRun\x12\x1a.flwr.proto.StopRunRequest\x1a\x1b.flwr.proto.StopRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x12\\\n\x0fGetLoginDetails\x12\".flwr.proto.GetLoginDetailsRequest\x1a#.flwr.proto.GetLoginDetailsResponse\"\x00\x12V\n\rGetAuthTokens\x12 .flwr.proto.GetAuthTokensRequest\x1a!.flwr.proto.GetAuthTokensResponse\"\x00\x12V\n\rPullArtifacts\x12 .flwr.proto.PullArtifactsRequest\x1a!.flwr.proto.PullArtifactsResponse\"\x00\x62\x06proto3')
22
22
 
23
23
  _globals = globals()
24
24
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -57,6 +57,10 @@ if _descriptor._USE_C_DESCRIPTORS == False:
57
57
  _globals['_STOPRUNREQUEST']._serialized_end=1101
58
58
  _globals['_STOPRUNRESPONSE']._serialized_start=1103
59
59
  _globals['_STOPRUNRESPONSE']._serialized_end=1137
60
- _globals['_CONTROL']._serialized_start=1140
61
- _globals['_CONTROL']._serialized_end=1628
60
+ _globals['_PULLARTIFACTSREQUEST']._serialized_start=1139
61
+ _globals['_PULLARTIFACTSREQUEST']._serialized_end=1177
62
+ _globals['_PULLARTIFACTSRESPONSE']._serialized_start=1179
63
+ _globals['_PULLARTIFACTSRESPONSE']._serialized_end=1228
64
+ _globals['_CONTROL']._serialized_start=1231
65
+ _globals['_CONTROL']._serialized_end=1807
62
66
  # @@protoc_insertion_point(module_scope)
@@ -210,3 +210,27 @@ class StopRunResponse(google.protobuf.message.Message):
210
210
  ) -> None: ...
211
211
  def ClearField(self, field_name: typing_extensions.Literal["success",b"success"]) -> None: ...
212
212
  global___StopRunResponse = StopRunResponse
213
+
214
+ class PullArtifactsRequest(google.protobuf.message.Message):
215
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
216
+ RUN_ID_FIELD_NUMBER: builtins.int
217
+ run_id: builtins.int
218
+ def __init__(self,
219
+ *,
220
+ run_id: builtins.int = ...,
221
+ ) -> None: ...
222
+ def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ...
223
+ global___PullArtifactsRequest = PullArtifactsRequest
224
+
225
+ class PullArtifactsResponse(google.protobuf.message.Message):
226
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
227
+ URL_FIELD_NUMBER: builtins.int
228
+ url: typing.Text
229
+ def __init__(self,
230
+ *,
231
+ url: typing.Optional[typing.Text] = ...,
232
+ ) -> None: ...
233
+ def HasField(self, field_name: typing_extensions.Literal["_url",b"_url","url",b"url"]) -> builtins.bool: ...
234
+ def ClearField(self, field_name: typing_extensions.Literal["_url",b"_url","url",b"url"]) -> None: ...
235
+ def WhichOneof(self, oneof_group: typing_extensions.Literal["_url",b"_url"]) -> typing.Optional[typing_extensions.Literal["url"]]: ...
236
+ global___PullArtifactsResponse = PullArtifactsResponse
@@ -44,6 +44,11 @@ class ControlStub(object):
44
44
  request_serializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensRequest.SerializeToString,
45
45
  response_deserializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensResponse.FromString,
46
46
  )
47
+ self.PullArtifacts = channel.unary_unary(
48
+ '/flwr.proto.Control/PullArtifacts',
49
+ request_serializer=flwr_dot_proto_dot_control__pb2.PullArtifactsRequest.SerializeToString,
50
+ response_deserializer=flwr_dot_proto_dot_control__pb2.PullArtifactsResponse.FromString,
51
+ )
47
52
 
48
53
 
49
54
  class ControlServicer(object):
@@ -91,6 +96,13 @@ class ControlServicer(object):
91
96
  context.set_details('Method not implemented!')
92
97
  raise NotImplementedError('Method not implemented!')
93
98
 
99
+ def PullArtifacts(self, request, context):
100
+ """Pull artifacts generated during a run (flwr pull)
101
+ """
102
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
103
+ context.set_details('Method not implemented!')
104
+ raise NotImplementedError('Method not implemented!')
105
+
94
106
 
95
107
  def add_ControlServicer_to_server(servicer, server):
96
108
  rpc_method_handlers = {
@@ -124,6 +136,11 @@ def add_ControlServicer_to_server(servicer, server):
124
136
  request_deserializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensRequest.FromString,
125
137
  response_serializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensResponse.SerializeToString,
126
138
  ),
139
+ 'PullArtifacts': grpc.unary_unary_rpc_method_handler(
140
+ servicer.PullArtifacts,
141
+ request_deserializer=flwr_dot_proto_dot_control__pb2.PullArtifactsRequest.FromString,
142
+ response_serializer=flwr_dot_proto_dot_control__pb2.PullArtifactsResponse.SerializeToString,
143
+ ),
127
144
  }
128
145
  generic_handler = grpc.method_handlers_generic_handler(
129
146
  'flwr.proto.Control', rpc_method_handlers)
@@ -235,3 +252,20 @@ class Control(object):
235
252
  flwr_dot_proto_dot_control__pb2.GetAuthTokensResponse.FromString,
236
253
  options, channel_credentials,
237
254
  insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
255
+
256
+ @staticmethod
257
+ def PullArtifacts(request,
258
+ target,
259
+ options=(),
260
+ channel_credentials=None,
261
+ call_credentials=None,
262
+ insecure=False,
263
+ compression=None,
264
+ wait_for_ready=None,
265
+ timeout=None,
266
+ metadata=None):
267
+ return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/PullArtifacts',
268
+ flwr_dot_proto_dot_control__pb2.PullArtifactsRequest.SerializeToString,
269
+ flwr_dot_proto_dot_control__pb2.PullArtifactsResponse.FromString,
270
+ options, channel_credentials,
271
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@@ -39,6 +39,11 @@ class ControlStub:
39
39
  flwr.proto.control_pb2.GetAuthTokensResponse]
40
40
  """Get auth tokens upon request"""
41
41
 
42
+ PullArtifacts: grpc.UnaryUnaryMultiCallable[
43
+ flwr.proto.control_pb2.PullArtifactsRequest,
44
+ flwr.proto.control_pb2.PullArtifactsResponse]
45
+ """Pull artifacts generated during a run (flwr pull)"""
46
+
42
47
 
43
48
  class ControlServicer(metaclass=abc.ABCMeta):
44
49
  @abc.abstractmethod
@@ -89,5 +94,13 @@ class ControlServicer(metaclass=abc.ABCMeta):
89
94
  """Get auth tokens upon request"""
90
95
  pass
91
96
 
97
+ @abc.abstractmethod
98
+ def PullArtifacts(self,
99
+ request: flwr.proto.control_pb2.PullArtifactsRequest,
100
+ context: grpc.ServicerContext,
101
+ ) -> flwr.proto.control_pb2.PullArtifactsResponse:
102
+ """Pull artifacts generated during a run (flwr pull)"""
103
+ pass
104
+
92
105
 
93
106
  def add_ControlServicer_to_server(servicer: ControlServicer, server: grpc.Server) -> None: ...
flwr/server/app.py CHANGED
@@ -71,6 +71,7 @@ from flwr.server.fleet_event_log_interceptor import FleetEventLogInterceptor
71
71
  from flwr.supercore.ffs import FfsFactory
72
72
  from flwr.supercore.grpc_health import add_args_health, run_health_server_grpc_no_tls
73
73
  from flwr.supercore.object_store import ObjectStoreFactory
74
+ from flwr.superlink.artifact_provider import ArtifactProvider
74
75
  from flwr.superlink.servicer.control import run_control_api_grpc
75
76
 
76
77
  from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
@@ -91,6 +92,7 @@ try:
91
92
  get_control_auth_plugins,
92
93
  get_control_authz_plugins,
93
94
  get_control_event_log_writer_plugins,
95
+ get_ee_artifact_provider,
94
96
  get_fleet_event_log_writer_plugins,
95
97
  )
96
98
  except ImportError:
@@ -113,6 +115,10 @@ except ImportError:
113
115
  "No event log writer plugins are currently supported."
114
116
  )
115
117
 
118
+ def get_ee_artifact_provider(config_path: str) -> ArtifactProvider:
119
+ """Return the EE artifact provider."""
120
+ raise NotImplementedError("No artifact provider is currently supported.")
121
+
116
122
  def get_fleet_event_log_writer_plugins() -> dict[str, type[EventLogWriterPlugin]]:
117
123
  """Return all Fleet API event log writer plugins."""
118
124
  raise NotImplementedError(
@@ -199,6 +205,12 @@ def run_superlink() -> None:
199
205
  if args.enable_event_log:
200
206
  event_log_plugin = _try_obtain_control_event_log_writer_plugin()
201
207
 
208
+ # Load artifact provider if the args.artifact_provider_config is provided
209
+ artifact_provider = None
210
+ if cfg_path := getattr(args, "artifact_provider_config", None):
211
+ log(WARN, "The `--artifact-provider-config` flag is highly experimental.")
212
+ artifact_provider = get_ee_artifact_provider(cfg_path)
213
+
202
214
  # Initialize StateFactory
203
215
  state_factory = LinkStateFactory(args.database)
204
216
 
@@ -220,6 +232,7 @@ def run_superlink() -> None:
220
232
  auth_plugin=auth_plugin,
221
233
  authz_plugin=authz_plugin,
222
234
  event_log_plugin=event_log_plugin,
235
+ artifact_provider=artifact_provider,
223
236
  )
224
237
  grpc_servers = [control_server]
225
238
  bckg_threads: list[threading.Thread] = []
@@ -24,6 +24,7 @@ from .fedadam import FedAdam
24
24
  from .fedavg import FedAvg
25
25
  from .fedavgm import FedAvgM
26
26
  from .fedmedian import FedMedian
27
+ from .fedprox import FedProx
27
28
  from .fedtrimmedavg import FedTrimmedAvg
28
29
  from .fedxgb_bagging import FedXgbBagging
29
30
  from .fedyogi import FedYogi
@@ -38,6 +39,7 @@ __all__ = [
38
39
  "FedAvg",
39
40
  "FedAvgM",
40
41
  "FedMedian",
42
+ "FedProx",
41
43
  "FedTrimmedAvg",
42
44
  "FedXgbBagging",
43
45
  "FedYogi",
@@ -0,0 +1,174 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Federated Optimization (FedProx) [Li et al., 2018] strategy.
16
+
17
+ Paper: arxiv.org/abs/1812.06127
18
+ """
19
+
20
+
21
+ from collections.abc import Iterable
22
+ from logging import INFO, WARN
23
+ from typing import Callable, Optional
24
+
25
+ from flwr.common import (
26
+ ArrayRecord,
27
+ ConfigRecord,
28
+ Message,
29
+ MetricRecord,
30
+ RecordDict,
31
+ log,
32
+ )
33
+ from flwr.server import Grid
34
+
35
+ from .fedavg import FedAvg
36
+
37
+
38
+ class FedProx(FedAvg):
39
+ r"""Federated Optimization strategy.
40
+
41
+ Implementation based on https://arxiv.org/abs/1812.06127
42
+
43
+ FedProx extends FedAvg by introducing a proximal term into the client-side
44
+ optimization objective. The strategy itself behaves identically to FedAvg
45
+ on the server side, but each client **MUST** add a proximal regularization
46
+ term to its local loss function during training:
47
+
48
+ .. math::
49
+ \frac{\mu}{2} || w - w^t ||^2
50
+
51
+ Where $w^t$ denotes the global parameters and $w$ denotes the local weights
52
+ being optimized.
53
+
54
+ This strategy sends the proximal term inside the ``ConfigRecord`` as part of the
55
+ ``configure_train`` method under key ``"proximal-mu"``. The client can then use this
56
+ value to add the proximal term to the loss function.
57
+
58
+ In PyTorch, for example, the loss would go from:
59
+
60
+ .. code:: python
61
+ loss = criterion(net(inputs), labels)
62
+
63
+ To:
64
+
65
+ .. code:: python
66
+ # Get proximal term weight from message
67
+ mu = msg.content["config"]["proximal-mu"]
68
+
69
+ # Compute proximal term
70
+ proximal_term = 0.0
71
+ for local_weights, global_weights in zip(net.parameters(), global_params):
72
+ proximal_term += (local_weights - global_weights).norm(2)
73
+
74
+ # Update loss
75
+ loss = criterion(net(inputs), labels) + (mu / 2) * proximal_term
76
+
77
+ With ``global_params`` being a copy of the model parameters, created **after**
78
+ applying the received global weights but **before** local training begins.
79
+
80
+ .. code:: python
81
+ global_params = copy.deepcopy(net).parameters()
82
+
83
+ Parameters
84
+ ----------
85
+ fraction_train : float (default: 1.0)
86
+ Fraction of nodes used during training. In case `min_train_nodes`
87
+ is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
88
+ will still be sampled.
89
+ fraction_evaluate : float (default: 1.0)
90
+ Fraction of nodes used during validation. In case `min_evaluate_nodes`
91
+ is larger than `fraction_evaluate * total_connected_nodes`,
92
+ `min_evaluate_nodes` will still be sampled.
93
+ min_train_nodes : int (default: 2)
94
+ Minimum number of nodes used during training.
95
+ min_evaluate_nodes : int (default: 2)
96
+ Minimum number of nodes used during validation.
97
+ min_available_nodes : int (default: 2)
98
+ Minimum number of total nodes in the system.
99
+ weighted_by_key : str (default: "num-examples")
100
+ The key within each MetricRecord whose value is used as the weight when
101
+ computing weighted averages for both ArrayRecords and MetricRecords.
102
+ arrayrecord_key : str (default: "arrays")
103
+ Key used to store the ArrayRecord when constructing Messages.
104
+ configrecord_key : str (default: "config")
105
+ Key used to store the ConfigRecord when constructing Messages.
106
+ train_metrics_aggr_fn : Optional[callable] (default: None)
107
+ Function with signature (list[RecordDict], str) -> MetricRecord,
108
+ used to aggregate MetricRecords from training round replies.
109
+ If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
110
+ average using the provided weight factor key.
111
+ evaluate_metrics_aggr_fn : Optional[callable] (default: None)
112
+ Function with signature (list[RecordDict], str) -> MetricRecord,
113
+ used to aggregate MetricRecords from training round replies.
114
+ If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
115
+ average using the provided weight factor key.
116
+ proximal_mu : float (default: 0.0)
117
+ The weight of the proximal term used in the optimization. 0.0 makes
118
+ this strategy equivalent to FedAvg, and the higher the coefficient, the more
119
+ regularization will be used (that is, the client parameters will need to be
120
+ closer to the server parameters during training).
121
+ """
122
+
123
+ def __init__( # pylint: disable=R0913, R0917
124
+ self,
125
+ fraction_train: float = 1.0,
126
+ fraction_evaluate: float = 1.0,
127
+ min_train_nodes: int = 2,
128
+ min_evaluate_nodes: int = 2,
129
+ min_available_nodes: int = 2,
130
+ weighted_by_key: str = "num-examples",
131
+ arrayrecord_key: str = "arrays",
132
+ configrecord_key: str = "config",
133
+ train_metrics_aggr_fn: Optional[
134
+ Callable[[list[RecordDict], str], MetricRecord]
135
+ ] = None,
136
+ evaluate_metrics_aggr_fn: Optional[
137
+ Callable[[list[RecordDict], str], MetricRecord]
138
+ ] = None,
139
+ proximal_mu: float = 0.0,
140
+ ) -> None:
141
+ super().__init__(
142
+ fraction_train=fraction_train,
143
+ fraction_evaluate=fraction_evaluate,
144
+ min_train_nodes=min_train_nodes,
145
+ min_evaluate_nodes=min_evaluate_nodes,
146
+ min_available_nodes=min_available_nodes,
147
+ weighted_by_key=weighted_by_key,
148
+ arrayrecord_key=arrayrecord_key,
149
+ configrecord_key=configrecord_key,
150
+ train_metrics_aggr_fn=train_metrics_aggr_fn,
151
+ evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
152
+ )
153
+ self.proximal_mu = proximal_mu
154
+
155
+ if self.proximal_mu == 0.0:
156
+ log(
157
+ WARN,
158
+ "FedProx initialized with `proximal_mu=0.0`. "
159
+ "This makes the strategy equivalent to FedAvg.",
160
+ )
161
+
162
+ def summary(self) -> None:
163
+ """Log summary configuration of the strategy."""
164
+ log(INFO, "\t├──> FedProx settings:")
165
+ log(INFO, "\t|\t└── Proximal mu: %s", self.proximal_mu)
166
+ super().summary()
167
+
168
+ def configure_train(
169
+ self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
170
+ ) -> Iterable[Message]:
171
+ """Configure the next round of federated training."""
172
+ # Inject proximal term weight into config
173
+ config["proximal-mu"] = self.proximal_mu
174
+ return super().configure_train(server_round, arrays, config, grid)
flwr/simulation/app.py CHANGED
@@ -245,7 +245,7 @@ def run_simulation_process( # pylint: disable=R0913, R0914, R0915, R0917, W0212
245
245
  run=run,
246
246
  enable_tf_gpu_growth=enable_tf_gpu_growth,
247
247
  verbose_logging=verbose,
248
- server_app_run_config=fused_config,
248
+ server_app_context=context,
249
249
  is_app=True,
250
250
  exit_event=EventType.FLWR_SIMULATION_RUN_LEAVE,
251
251
  )