flwr-nightly 1.22.0.dev20250915__py3-none-any.whl → 1.22.0.dev20250917__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/new/new.py +2 -2
  3. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  4. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  5. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  6. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  7. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  8. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  9. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  10. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  11. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +3 -3
  12. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  13. flwr/cli/pull.py +100 -0
  14. flwr/cli/utils.py +17 -0
  15. flwr/common/constant.py +2 -0
  16. flwr/proto/control_pb2.py +7 -3
  17. flwr/proto/control_pb2.pyi +24 -0
  18. flwr/proto/control_pb2_grpc.py +34 -0
  19. flwr/proto/control_pb2_grpc.pyi +13 -0
  20. flwr/server/app.py +13 -0
  21. flwr/serverapp/strategy/__init__.py +8 -0
  22. flwr/serverapp/strategy/fedavg.py +23 -2
  23. flwr/serverapp/strategy/fedavgm.py +198 -0
  24. flwr/serverapp/strategy/fedmedian.py +71 -0
  25. flwr/serverapp/strategy/fedprox.py +174 -0
  26. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  27. flwr/serverapp/strategy/strategy_utils_tests.py +20 -1
  28. flwr/simulation/app.py +1 -1
  29. flwr/simulation/run_simulation.py +25 -30
  30. flwr/superlink/artifact_provider/__init__.py +22 -0
  31. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  32. flwr/superlink/servicer/control/control_grpc.py +3 -0
  33. flwr/superlink/servicer/control/control_servicer.py +59 -2
  34. {flwr_nightly-1.22.0.dev20250915.dist-info → flwr_nightly-1.22.0.dev20250917.dist-info}/METADATA +6 -16
  35. {flwr_nightly-1.22.0.dev20250915.dist-info → flwr_nightly-1.22.0.dev20250917.dist-info}/RECORD +37 -30
  36. {flwr_nightly-1.22.0.dev20250915.dist-info → flwr_nightly-1.22.0.dev20250917.dist-info}/WHEEL +0 -0
  37. {flwr_nightly-1.22.0.dev20250915.dist-info → flwr_nightly-1.22.0.dev20250917.dist-info}/entry_points.txt +0 -0
@@ -1,45 +1,43 @@
1
1
  """$project_name: A Flower Baseline."""
2
2
 
3
- from flwr.common import Context, Metrics, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
3
+ import torch
4
+ from flwr.app import ArrayRecord, Context
5
+ from flwr.serverapp import Grid, ServerApp
6
+ from flwr.serverapp.strategy import FedAvg
6
7
 
7
- from $import_name.model import Net, get_weights
8
+ from $import_name.model import Net
8
9
 
10
+ # Create ServerApp
11
+ app = ServerApp()
9
12
 
10
- # Define metric aggregation function
11
- def weighted_average(metrics: list[tuple[int, Metrics]]) -> Metrics:
12
- """Do weighted average of accuracy metric."""
13
- # Multiply accuracy of each client by number of examples used
14
- accuracies = [num_examples * float(m["accuracy"]) for num_examples, m in metrics]
15
- examples = [num_examples for num_examples, _ in metrics]
16
-
17
- # Aggregate and return custom metric (weighted average)
18
- return {"accuracy": sum(accuracies) / sum(examples)}
19
13
 
14
+ @app.main()
15
+ def main(grid: Grid, context: Context) -> None:
16
+ """Main entry point for the ServerApp."""
20
17
 
21
- def server_fn(context: Context):
22
- """Construct components that set the ServerApp behaviour."""
23
18
  # Read from config
24
19
  num_rounds = context.run_config["num-server-rounds"]
25
- fraction_fit = context.run_config["fraction-fit"]
20
+ fraction_train = context.run_config["fraction-train"]
26
21
 
27
- # Initialize model parameters
28
- ndarrays = get_weights(Net())
29
- parameters = ndarrays_to_parameters(ndarrays)
22
+ # Load global model
23
+ global_model = Net()
24
+ arrays = ArrayRecord(global_model.state_dict())
30
25
 
31
- # Define strategy
26
+ # Initialize FedAvg strategy
32
27
  strategy = FedAvg(
33
- fraction_fit=float(fraction_fit),
28
+ fraction_train=fraction_train,
34
29
  fraction_evaluate=1.0,
35
- min_available_clients=2,
36
- initial_parameters=parameters,
37
- evaluate_metrics_aggregation_fn=weighted_average,
30
+ min_available_nodes=2,
38
31
  )
39
- config = ServerConfig(num_rounds=int(num_rounds))
40
-
41
- return ServerAppComponents(strategy=strategy, config=config)
42
32
 
33
+ # Start strategy, run FedAvg for `num_rounds`
34
+ result = strategy.start(
35
+ grid=grid,
36
+ initial_arrays=arrays,
37
+ num_rounds=num_rounds,
38
+ )
43
39
 
44
- # Create ServerApp
45
- app = ServerApp(server_fn=server_fn)
40
+ # Save final model to disk
41
+ print("\nSaving final model to disk...")
42
+ state_dict = result.arrays.to_torch_state_dict()
43
+ torch.save(state_dict, "final_model.pt")
@@ -16,8 +16,8 @@ license = "Apache-2.0"
16
16
  dependencies = [
17
17
  "flwr[simulation]>=1.22.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
- "torch==2.7.1",
20
- "torchvision==0.22.1",
19
+ "torch==2.8.0",
20
+ "torchvision==0.23.0",
21
21
  ]
22
22
 
23
23
  [tool.hatch.metadata]
@@ -132,7 +132,7 @@ clientapp = "$import_name.client_app:app"
132
132
  # Custom config values accessible via `context.run_config`
133
133
  [tool.flwr.app.config]
134
134
  num-server-rounds = 3
135
- fraction-fit = 0.5
135
+ fraction-train = 0.5
136
136
  local-epochs = 1
137
137
 
138
138
  # Default federation to use when running the app
@@ -61,7 +61,7 @@ train.training-arguments.save-steps = 1000
61
61
  train.training-arguments.save-total-limit = 10
62
62
  train.training-arguments.gradient-checkpointing = true
63
63
  train.training-arguments.lr-scheduler-type = "constant"
64
- strategy.fraction-fit = $fraction_fit
64
+ strategy.fraction-train = $fraction_train
65
65
  strategy.fraction-evaluate = 0.0
66
66
  num-server-rounds = 200
67
67
 
flwr/cli/pull.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower command line interface `pull` command."""
16
+
17
+
18
+ from pathlib import Path
19
+ from typing import Annotated, Optional
20
+
21
+ import typer
22
+
23
+ from flwr.cli.config_utils import (
24
+ exit_if_no_address,
25
+ load_and_validate,
26
+ process_loaded_project_config,
27
+ validate_federation_in_project_config,
28
+ )
29
+ from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
30
+ from flwr.common.constant import FAB_CONFIG_FILE
31
+ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
32
+ PullArtifactsRequest,
33
+ PullArtifactsResponse,
34
+ )
35
+ from flwr.proto.control_pb2_grpc import ControlStub
36
+
37
+ from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
38
+
39
+
40
+ def pull( # pylint: disable=R0914
41
+ run_id: Annotated[
42
+ int,
43
+ typer.Option(
44
+ "--run-id",
45
+ help="Run ID to pull artifacts from.",
46
+ ),
47
+ ],
48
+ app: Annotated[
49
+ Path,
50
+ typer.Argument(help="Path of the Flower App to run."),
51
+ ] = Path("."),
52
+ federation: Annotated[
53
+ Optional[str],
54
+ typer.Argument(help="Name of the federation."),
55
+ ] = None,
56
+ federation_config_overrides: Annotated[
57
+ Optional[list[str]],
58
+ typer.Option(
59
+ "--federation-config",
60
+ help=FEDERATION_CONFIG_HELP_MESSAGE,
61
+ ),
62
+ ] = None,
63
+ ) -> None:
64
+ """Pull artifacts from a Flower run."""
65
+ typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
66
+
67
+ pyproject_path = app / FAB_CONFIG_FILE if app else None
68
+ config, errors, warnings = load_and_validate(path=pyproject_path)
69
+ config = process_loaded_project_config(config, errors, warnings)
70
+ federation, federation_config = validate_federation_in_project_config(
71
+ federation, config, federation_config_overrides
72
+ )
73
+ exit_if_no_address(federation_config, "pull")
74
+ channel = None
75
+ try:
76
+
77
+ auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
78
+ channel = init_channel(app, federation_config, auth_plugin)
79
+ stub = ControlStub(channel)
80
+ with flwr_cli_grpc_exc_handler():
81
+ res: PullArtifactsResponse = stub.PullArtifacts(
82
+ PullArtifactsRequest(run_id=run_id)
83
+ )
84
+
85
+ if not res.url:
86
+ typer.secho(
87
+ f"❌ A download URL for artifacts from run {run_id} couldn't be "
88
+ "obtained.",
89
+ fg=typer.colors.RED,
90
+ bold=True,
91
+ )
92
+ raise typer.Exit(code=1)
93
+
94
+ typer.secho(
95
+ f"✅ Artifacts for run {run_id} can be downloaded from: {res.url}",
96
+ fg=typer.colors.GREEN,
97
+ )
98
+ finally:
99
+ if channel:
100
+ channel.close()
flwr/cli/utils.py CHANGED
@@ -32,7 +32,9 @@ from flwr.common.constant import (
32
32
  AUTH_TYPE_JSON_KEY,
33
33
  CREDENTIALS_DIR,
34
34
  FLWR_DIR,
35
+ NO_ARTIFACT_PROVIDER_MESSAGE,
35
36
  NO_USER_AUTH_MESSAGE,
37
+ PULL_UNFINISHED_RUN_MESSAGE,
36
38
  RUN_ID_NOT_FOUND_MESSAGE,
37
39
  )
38
40
  from flwr.common.grpc import (
@@ -319,6 +321,12 @@ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
319
321
  fg=typer.colors.RED,
320
322
  bold=True,
321
323
  )
324
+ elif e.details() == NO_ARTIFACT_PROVIDER_MESSAGE: # pylint: disable=E1101
325
+ typer.secho(
326
+ "❌ The SuperLink does not support `flwr pull` command.",
327
+ fg=typer.colors.RED,
328
+ bold=True,
329
+ )
322
330
  else:
323
331
  typer.secho(
324
332
  "❌ The SuperLink cannot process this request. Please verify that "
@@ -356,4 +364,13 @@ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
356
364
  bold=True,
357
365
  )
358
366
  raise typer.Exit(code=1) from None
367
+ if e.code() == grpc.StatusCode.FAILED_PRECONDITION:
368
+ if e.details() == PULL_UNFINISHED_RUN_MESSAGE: # pylint: disable=E1101
369
+ typer.secho(
370
+ "❌ Run is not finished yet. Artifacts can only be pulled after "
371
+ "the run is finished. You can check the run status with `flwr ls`.",
372
+ fg=typer.colors.RED,
373
+ bold=True,
374
+ )
375
+ raise typer.Exit(code=1) from None
359
376
  raise
flwr/common/constant.py CHANGED
@@ -155,6 +155,8 @@ PULL_BACKOFF_CAP = 10 # Maximum backoff time for pulling objects
155
155
  # ControlServicer constants
156
156
  RUN_ID_NOT_FOUND_MESSAGE = "Run ID not found"
157
157
  NO_USER_AUTH_MESSAGE = "ControlServicer initialized without user authentication"
158
+ NO_ARTIFACT_PROVIDER_MESSAGE = "ControlServicer initialized without artifact provider"
159
+ PULL_UNFINISHED_RUN_MESSAGE = "Cannot pull artifacts for an unfinished run"
158
160
 
159
161
 
160
162
  class MessageType:
flwr/proto/control_pb2.py CHANGED
@@ -18,7 +18,7 @@ from flwr.proto import recorddict_pb2 as flwr_dot_proto_dot_recorddict__pb2
18
18
  from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
19
19
 
20
20
 
21
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x14\x66lwr/proto/run.proto\"\xfa\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x34\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x18.flwr.proto.ConfigRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"2\n\x10StartRunResponse\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\"\x18\n\x16GetLoginDetailsRequest\"\x8a\x01\n\x17GetLoginDetailsResponse\x12\x11\n\tauth_type\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65vice_code\x18\x02 \x01(\t\x12!\n\x19verification_uri_complete\x18\x03 \x01(\t\x12\x12\n\nexpires_in\x18\x04 \x01(\x03\x12\x10\n\x08interval\x18\x05 \x01(\x03\"+\n\x14GetAuthTokensRequest\x12\x13\n\x0b\x64\x65vice_code\x18\x01 \x01(\t\"D\n\x15GetAuthTokensResponse\x12\x14\n\x0c\x61\x63\x63\x65ss_token\x18\x01 \x01(\t\x12\x15\n\rrefresh_token\x18\x02 \x01(\t\" \n\x0eStopRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\"\n\x0fStopRunResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x32\xe8\x03\n\x07\x43ontrol\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12\x44\n\x07StopRun\x12\x1a.flwr.proto.StopRunRequest\x1a\x1b.flwr.proto.StopRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x12\\\n\x0fGetLoginDetails\x12\".flwr.proto.GetLoginDetailsRequest\x1a#.flwr.proto.GetLoginDetailsResponse\"\x00\x12V\n\rGetAuthTokens\x12 .flwr.proto.GetAuthTokensRequest\x1a!.flwr.proto.GetAuthTokensResponse\"\x00\x62\x06proto3')
21
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x14\x66lwr/proto/run.proto\"\xfa\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x34\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x18.flwr.proto.ConfigRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"2\n\x10StartRunResponse\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\"\x18\n\x16GetLoginDetailsRequest\"\x8a\x01\n\x17GetLoginDetailsResponse\x12\x11\n\tauth_type\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65vice_code\x18\x02 \x01(\t\x12!\n\x19verification_uri_complete\x18\x03 \x01(\t\x12\x12\n\nexpires_in\x18\x04 \x01(\x03\x12\x10\n\x08interval\x18\x05 \x01(\x03\"+\n\x14GetAuthTokensRequest\x12\x13\n\x0b\x64\x65vice_code\x18\x01 \x01(\t\"D\n\x15GetAuthTokensResponse\x12\x14\n\x0c\x61\x63\x63\x65ss_token\x18\x01 \x01(\t\x12\x15\n\rrefresh_token\x18\x02 \x01(\t\" \n\x0eStopRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\"\n\x0fStopRunResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"&\n\x14PullArtifactsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"1\n\x15PullArtifactsResponse\x12\x10\n\x03url\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x06\n\x04_url2\xc0\x04\n\x07\x43ontrol\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12\x44\n\x07StopRun\x12\x1a.flwr.proto.StopRunRequest\x1a\x1b.flwr.proto.StopRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x12\\\n\x0fGetLoginDetails\x12\".flwr.proto.GetLoginDetailsRequest\x1a#.flwr.proto.GetLoginDetailsResponse\"\x00\x12V\n\rGetAuthTokens\x12 .flwr.proto.GetAuthTokensRequest\x1a!.flwr.proto.GetAuthTokensResponse\"\x00\x12V\n\rPullArtifacts\x12 .flwr.proto.PullArtifactsRequest\x1a!.flwr.proto.PullArtifactsResponse\"\x00\x62\x06proto3')
22
22
 
23
23
  _globals = globals()
24
24
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -57,6 +57,10 @@ if _descriptor._USE_C_DESCRIPTORS == False:
57
57
  _globals['_STOPRUNREQUEST']._serialized_end=1101
58
58
  _globals['_STOPRUNRESPONSE']._serialized_start=1103
59
59
  _globals['_STOPRUNRESPONSE']._serialized_end=1137
60
- _globals['_CONTROL']._serialized_start=1140
61
- _globals['_CONTROL']._serialized_end=1628
60
+ _globals['_PULLARTIFACTSREQUEST']._serialized_start=1139
61
+ _globals['_PULLARTIFACTSREQUEST']._serialized_end=1177
62
+ _globals['_PULLARTIFACTSRESPONSE']._serialized_start=1179
63
+ _globals['_PULLARTIFACTSRESPONSE']._serialized_end=1228
64
+ _globals['_CONTROL']._serialized_start=1231
65
+ _globals['_CONTROL']._serialized_end=1807
62
66
  # @@protoc_insertion_point(module_scope)
@@ -210,3 +210,27 @@ class StopRunResponse(google.protobuf.message.Message):
210
210
  ) -> None: ...
211
211
  def ClearField(self, field_name: typing_extensions.Literal["success",b"success"]) -> None: ...
212
212
  global___StopRunResponse = StopRunResponse
213
+
214
+ class PullArtifactsRequest(google.protobuf.message.Message):
215
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
216
+ RUN_ID_FIELD_NUMBER: builtins.int
217
+ run_id: builtins.int
218
+ def __init__(self,
219
+ *,
220
+ run_id: builtins.int = ...,
221
+ ) -> None: ...
222
+ def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ...
223
+ global___PullArtifactsRequest = PullArtifactsRequest
224
+
225
+ class PullArtifactsResponse(google.protobuf.message.Message):
226
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
227
+ URL_FIELD_NUMBER: builtins.int
228
+ url: typing.Text
229
+ def __init__(self,
230
+ *,
231
+ url: typing.Optional[typing.Text] = ...,
232
+ ) -> None: ...
233
+ def HasField(self, field_name: typing_extensions.Literal["_url",b"_url","url",b"url"]) -> builtins.bool: ...
234
+ def ClearField(self, field_name: typing_extensions.Literal["_url",b"_url","url",b"url"]) -> None: ...
235
+ def WhichOneof(self, oneof_group: typing_extensions.Literal["_url",b"_url"]) -> typing.Optional[typing_extensions.Literal["url"]]: ...
236
+ global___PullArtifactsResponse = PullArtifactsResponse
@@ -44,6 +44,11 @@ class ControlStub(object):
44
44
  request_serializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensRequest.SerializeToString,
45
45
  response_deserializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensResponse.FromString,
46
46
  )
47
+ self.PullArtifacts = channel.unary_unary(
48
+ '/flwr.proto.Control/PullArtifacts',
49
+ request_serializer=flwr_dot_proto_dot_control__pb2.PullArtifactsRequest.SerializeToString,
50
+ response_deserializer=flwr_dot_proto_dot_control__pb2.PullArtifactsResponse.FromString,
51
+ )
47
52
 
48
53
 
49
54
  class ControlServicer(object):
@@ -91,6 +96,13 @@ class ControlServicer(object):
91
96
  context.set_details('Method not implemented!')
92
97
  raise NotImplementedError('Method not implemented!')
93
98
 
99
+ def PullArtifacts(self, request, context):
100
+ """Pull artifacts generated during a run (flwr pull)
101
+ """
102
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
103
+ context.set_details('Method not implemented!')
104
+ raise NotImplementedError('Method not implemented!')
105
+
94
106
 
95
107
  def add_ControlServicer_to_server(servicer, server):
96
108
  rpc_method_handlers = {
@@ -124,6 +136,11 @@ def add_ControlServicer_to_server(servicer, server):
124
136
  request_deserializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensRequest.FromString,
125
137
  response_serializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensResponse.SerializeToString,
126
138
  ),
139
+ 'PullArtifacts': grpc.unary_unary_rpc_method_handler(
140
+ servicer.PullArtifacts,
141
+ request_deserializer=flwr_dot_proto_dot_control__pb2.PullArtifactsRequest.FromString,
142
+ response_serializer=flwr_dot_proto_dot_control__pb2.PullArtifactsResponse.SerializeToString,
143
+ ),
127
144
  }
128
145
  generic_handler = grpc.method_handlers_generic_handler(
129
146
  'flwr.proto.Control', rpc_method_handlers)
@@ -235,3 +252,20 @@ class Control(object):
235
252
  flwr_dot_proto_dot_control__pb2.GetAuthTokensResponse.FromString,
236
253
  options, channel_credentials,
237
254
  insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
255
+
256
+ @staticmethod
257
+ def PullArtifacts(request,
258
+ target,
259
+ options=(),
260
+ channel_credentials=None,
261
+ call_credentials=None,
262
+ insecure=False,
263
+ compression=None,
264
+ wait_for_ready=None,
265
+ timeout=None,
266
+ metadata=None):
267
+ return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/PullArtifacts',
268
+ flwr_dot_proto_dot_control__pb2.PullArtifactsRequest.SerializeToString,
269
+ flwr_dot_proto_dot_control__pb2.PullArtifactsResponse.FromString,
270
+ options, channel_credentials,
271
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@@ -39,6 +39,11 @@ class ControlStub:
39
39
  flwr.proto.control_pb2.GetAuthTokensResponse]
40
40
  """Get auth tokens upon request"""
41
41
 
42
+ PullArtifacts: grpc.UnaryUnaryMultiCallable[
43
+ flwr.proto.control_pb2.PullArtifactsRequest,
44
+ flwr.proto.control_pb2.PullArtifactsResponse]
45
+ """Pull artifacts generated during a run (flwr pull)"""
46
+
42
47
 
43
48
  class ControlServicer(metaclass=abc.ABCMeta):
44
49
  @abc.abstractmethod
@@ -89,5 +94,13 @@ class ControlServicer(metaclass=abc.ABCMeta):
89
94
  """Get auth tokens upon request"""
90
95
  pass
91
96
 
97
+ @abc.abstractmethod
98
+ def PullArtifacts(self,
99
+ request: flwr.proto.control_pb2.PullArtifactsRequest,
100
+ context: grpc.ServicerContext,
101
+ ) -> flwr.proto.control_pb2.PullArtifactsResponse:
102
+ """Pull artifacts generated during a run (flwr pull)"""
103
+ pass
104
+
92
105
 
93
106
  def add_ControlServicer_to_server(servicer: ControlServicer, server: grpc.Server) -> None: ...
flwr/server/app.py CHANGED
@@ -71,6 +71,7 @@ from flwr.server.fleet_event_log_interceptor import FleetEventLogInterceptor
71
71
  from flwr.supercore.ffs import FfsFactory
72
72
  from flwr.supercore.grpc_health import add_args_health, run_health_server_grpc_no_tls
73
73
  from flwr.supercore.object_store import ObjectStoreFactory
74
+ from flwr.superlink.artifact_provider import ArtifactProvider
74
75
  from flwr.superlink.servicer.control import run_control_api_grpc
75
76
 
76
77
  from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
@@ -91,6 +92,7 @@ try:
91
92
  get_control_auth_plugins,
92
93
  get_control_authz_plugins,
93
94
  get_control_event_log_writer_plugins,
95
+ get_ee_artifact_provider,
94
96
  get_fleet_event_log_writer_plugins,
95
97
  )
96
98
  except ImportError:
@@ -113,6 +115,10 @@ except ImportError:
113
115
  "No event log writer plugins are currently supported."
114
116
  )
115
117
 
118
+ def get_ee_artifact_provider(config_path: str) -> ArtifactProvider:
119
+ """Return the EE artifact provider."""
120
+ raise NotImplementedError("No artifact provider is currently supported.")
121
+
116
122
  def get_fleet_event_log_writer_plugins() -> dict[str, type[EventLogWriterPlugin]]:
117
123
  """Return all Fleet API event log writer plugins."""
118
124
  raise NotImplementedError(
@@ -199,6 +205,12 @@ def run_superlink() -> None:
199
205
  if args.enable_event_log:
200
206
  event_log_plugin = _try_obtain_control_event_log_writer_plugin()
201
207
 
208
+ # Load artifact provider if the args.artifact_provider_config is provided
209
+ artifact_provider = None
210
+ if cfg_path := getattr(args, "artifact_provider_config", None):
211
+ log(WARN, "The `--artifact-provider-config` flag is highly experimental.")
212
+ artifact_provider = get_ee_artifact_provider(cfg_path)
213
+
202
214
  # Initialize StateFactory
203
215
  state_factory = LinkStateFactory(args.database)
204
216
 
@@ -220,6 +232,7 @@ def run_superlink() -> None:
220
232
  auth_plugin=auth_plugin,
221
233
  authz_plugin=authz_plugin,
222
234
  event_log_plugin=event_log_plugin,
235
+ artifact_provider=artifact_provider,
223
236
  )
224
237
  grpc_servers = [control_server]
225
238
  bckg_threads: list[threading.Thread] = []
@@ -22,6 +22,10 @@ from .dp_fixed_clipping import (
22
22
  from .fedadagrad import FedAdagrad
23
23
  from .fedadam import FedAdam
24
24
  from .fedavg import FedAvg
25
+ from .fedavgm import FedAvgM
26
+ from .fedmedian import FedMedian
27
+ from .fedprox import FedProx
28
+ from .fedtrimmedavg import FedTrimmedAvg
25
29
  from .fedxgb_bagging import FedXgbBagging
26
30
  from .fedyogi import FedYogi
27
31
  from .result import Result
@@ -33,6 +37,10 @@ __all__ = [
33
37
  "FedAdagrad",
34
38
  "FedAdam",
35
39
  "FedAvg",
40
+ "FedAvgM",
41
+ "FedMedian",
42
+ "FedProx",
43
+ "FedTrimmedAvg",
36
44
  "FedXgbBagging",
37
45
  "FedYogi",
38
46
  "Result",
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  from collections.abc import Iterable
19
- from logging import INFO
19
+ from logging import INFO, WARNING
20
20
  from typing import Callable, Optional
21
21
 
22
22
  from flwr.common import (
@@ -67,7 +67,7 @@ class FedAvg(Strategy):
67
67
  arrayrecord_key : str (default: "arrays")
68
68
  Key used to store the ArrayRecord when constructing Messages.
69
69
  configrecord_key : str (default: "config")
70
- Key used to store the ConfigRecord when constructing Messages.
70
+ Key used to store the ConfigRecord when constructing Messages.
71
71
  train_metrics_aggr_fn : Optional[callable] (default: None)
72
72
  Function with signature (list[RecordDict], str) -> MetricRecord,
73
73
  used to aggregate MetricRecords from training round replies.
@@ -111,6 +111,20 @@ class FedAvg(Strategy):
111
111
  evaluate_metrics_aggr_fn or aggregate_metricrecords
112
112
  )
113
113
 
114
+ if self.fraction_evaluate == 0.0:
115
+ self.min_evaluate_nodes = 0
116
+ log(
117
+ WARNING,
118
+ "fraction_evaluate is set to 0.0. "
119
+ "Federated evaluation will be skipped.",
120
+ )
121
+ if self.fraction_train == 0.0:
122
+ self.min_train_nodes = 0
123
+ log(
124
+ WARNING,
125
+ "fraction_train is set to 0.0. Federated training will be skipped.",
126
+ )
127
+
114
128
  def summary(self) -> None:
115
129
  """Log summary configuration of the strategy."""
116
130
  log(INFO, "\t├──> Sampling:")
@@ -150,6 +164,9 @@ class FedAvg(Strategy):
150
164
  self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
151
165
  ) -> Iterable[Message]:
152
166
  """Configure the next round of federated training."""
167
+ # Do not configure federated train if fraction_train is 0.
168
+ if self.fraction_train == 0.0:
169
+ return []
153
170
  # Sample nodes
154
171
  num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_train)
155
172
  sample_size = max(num_nodes, self.min_train_nodes)
@@ -259,6 +276,10 @@ class FedAvg(Strategy):
259
276
  self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
260
277
  ) -> Iterable[Message]:
261
278
  """Configure the next round of federated evaluation."""
279
+ # Do not configure federated evaluation if fraction_evaluate is 0.
280
+ if self.fraction_evaluate == 0.0:
281
+ return []
282
+
262
283
  # Sample nodes
263
284
  num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_evaluate)
264
285
  sample_size = max(num_nodes, self.min_evaluate_nodes)