flwr-nightly 1.22.0.dev20250915__py3-none-any.whl → 1.22.0.dev20250917__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +2 -0
- flwr/cli/new/new.py +2 -2
- flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/pull.py +100 -0
- flwr/cli/utils.py +17 -0
- flwr/common/constant.py +2 -0
- flwr/proto/control_pb2.py +7 -3
- flwr/proto/control_pb2.pyi +24 -0
- flwr/proto/control_pb2_grpc.py +34 -0
- flwr/proto/control_pb2_grpc.pyi +13 -0
- flwr/server/app.py +13 -0
- flwr/serverapp/strategy/__init__.py +8 -0
- flwr/serverapp/strategy/fedavg.py +23 -2
- flwr/serverapp/strategy/fedavgm.py +198 -0
- flwr/serverapp/strategy/fedmedian.py +71 -0
- flwr/serverapp/strategy/fedprox.py +174 -0
- flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
- flwr/serverapp/strategy/strategy_utils_tests.py +20 -1
- flwr/simulation/app.py +1 -1
- flwr/simulation/run_simulation.py +25 -30
- flwr/superlink/artifact_provider/__init__.py +22 -0
- flwr/superlink/artifact_provider/artifact_provider.py +37 -0
- flwr/superlink/servicer/control/control_grpc.py +3 -0
- flwr/superlink/servicer/control/control_servicer.py +59 -2
- {flwr_nightly-1.22.0.dev20250915.dist-info → flwr_nightly-1.22.0.dev20250917.dist-info}/METADATA +6 -16
- {flwr_nightly-1.22.0.dev20250915.dist-info → flwr_nightly-1.22.0.dev20250917.dist-info}/RECORD +37 -30
- {flwr_nightly-1.22.0.dev20250915.dist-info → flwr_nightly-1.22.0.dev20250917.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.22.0.dev20250915.dist-info → flwr_nightly-1.22.0.dev20250917.dist-info}/entry_points.txt +0 -0
@@ -1,45 +1,43 @@
|
|
1
1
|
"""$project_name: A Flower Baseline."""
|
2
2
|
|
3
|
-
|
4
|
-
from flwr.
|
5
|
-
from flwr.
|
3
|
+
import torch
|
4
|
+
from flwr.app import ArrayRecord, Context
|
5
|
+
from flwr.serverapp import Grid, ServerApp
|
6
|
+
from flwr.serverapp.strategy import FedAvg
|
6
7
|
|
7
|
-
from $import_name.model import Net
|
8
|
+
from $import_name.model import Net
|
8
9
|
|
10
|
+
# Create ServerApp
|
11
|
+
app = ServerApp()
|
9
12
|
|
10
|
-
# Define metric aggregation function
|
11
|
-
def weighted_average(metrics: list[tuple[int, Metrics]]) -> Metrics:
|
12
|
-
"""Do weighted average of accuracy metric."""
|
13
|
-
# Multiply accuracy of each client by number of examples used
|
14
|
-
accuracies = [num_examples * float(m["accuracy"]) for num_examples, m in metrics]
|
15
|
-
examples = [num_examples for num_examples, _ in metrics]
|
16
|
-
|
17
|
-
# Aggregate and return custom metric (weighted average)
|
18
|
-
return {"accuracy": sum(accuracies) / sum(examples)}
|
19
13
|
|
14
|
+
@app.main()
|
15
|
+
def main(grid: Grid, context: Context) -> None:
|
16
|
+
"""Main entry point for the ServerApp."""
|
20
17
|
|
21
|
-
def server_fn(context: Context):
|
22
|
-
"""Construct components that set the ServerApp behaviour."""
|
23
18
|
# Read from config
|
24
19
|
num_rounds = context.run_config["num-server-rounds"]
|
25
|
-
|
20
|
+
fraction_train = context.run_config["fraction-train"]
|
26
21
|
|
27
|
-
#
|
28
|
-
|
29
|
-
|
22
|
+
# Load global model
|
23
|
+
global_model = Net()
|
24
|
+
arrays = ArrayRecord(global_model.state_dict())
|
30
25
|
|
31
|
-
#
|
26
|
+
# Initialize FedAvg strategy
|
32
27
|
strategy = FedAvg(
|
33
|
-
|
28
|
+
fraction_train=fraction_train,
|
34
29
|
fraction_evaluate=1.0,
|
35
|
-
|
36
|
-
initial_parameters=parameters,
|
37
|
-
evaluate_metrics_aggregation_fn=weighted_average,
|
30
|
+
min_available_nodes=2,
|
38
31
|
)
|
39
|
-
config = ServerConfig(num_rounds=int(num_rounds))
|
40
|
-
|
41
|
-
return ServerAppComponents(strategy=strategy, config=config)
|
42
32
|
|
33
|
+
# Start strategy, run FedAvg for `num_rounds`
|
34
|
+
result = strategy.start(
|
35
|
+
grid=grid,
|
36
|
+
initial_arrays=arrays,
|
37
|
+
num_rounds=num_rounds,
|
38
|
+
)
|
43
39
|
|
44
|
-
#
|
45
|
-
|
40
|
+
# Save final model to disk
|
41
|
+
print("\nSaving final model to disk...")
|
42
|
+
state_dict = result.arrays.to_torch_state_dict()
|
43
|
+
torch.save(state_dict, "final_model.pt")
|
@@ -16,8 +16,8 @@ license = "Apache-2.0"
|
|
16
16
|
dependencies = [
|
17
17
|
"flwr[simulation]>=1.22.0",
|
18
18
|
"flwr-datasets[vision]>=0.5.0",
|
19
|
-
"torch==2.
|
20
|
-
"torchvision==0.
|
19
|
+
"torch==2.8.0",
|
20
|
+
"torchvision==0.23.0",
|
21
21
|
]
|
22
22
|
|
23
23
|
[tool.hatch.metadata]
|
@@ -132,7 +132,7 @@ clientapp = "$import_name.client_app:app"
|
|
132
132
|
# Custom config values accessible via `context.run_config`
|
133
133
|
[tool.flwr.app.config]
|
134
134
|
num-server-rounds = 3
|
135
|
-
fraction-
|
135
|
+
fraction-train = 0.5
|
136
136
|
local-epochs = 1
|
137
137
|
|
138
138
|
# Default federation to use when running the app
|
@@ -61,7 +61,7 @@ train.training-arguments.save-steps = 1000
|
|
61
61
|
train.training-arguments.save-total-limit = 10
|
62
62
|
train.training-arguments.gradient-checkpointing = true
|
63
63
|
train.training-arguments.lr-scheduler-type = "constant"
|
64
|
-
strategy.fraction-
|
64
|
+
strategy.fraction-train = $fraction_train
|
65
65
|
strategy.fraction-evaluate = 0.0
|
66
66
|
num-server-rounds = 200
|
67
67
|
|
flwr/cli/pull.py
ADDED
@@ -0,0 +1,100 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Flower command line interface `pull` command."""
|
16
|
+
|
17
|
+
|
18
|
+
from pathlib import Path
|
19
|
+
from typing import Annotated, Optional
|
20
|
+
|
21
|
+
import typer
|
22
|
+
|
23
|
+
from flwr.cli.config_utils import (
|
24
|
+
exit_if_no_address,
|
25
|
+
load_and_validate,
|
26
|
+
process_loaded_project_config,
|
27
|
+
validate_federation_in_project_config,
|
28
|
+
)
|
29
|
+
from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
|
30
|
+
from flwr.common.constant import FAB_CONFIG_FILE
|
31
|
+
from flwr.proto.control_pb2 import ( # pylint: disable=E0611
|
32
|
+
PullArtifactsRequest,
|
33
|
+
PullArtifactsResponse,
|
34
|
+
)
|
35
|
+
from flwr.proto.control_pb2_grpc import ControlStub
|
36
|
+
|
37
|
+
from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
|
38
|
+
|
39
|
+
|
40
|
+
def pull( # pylint: disable=R0914
|
41
|
+
run_id: Annotated[
|
42
|
+
int,
|
43
|
+
typer.Option(
|
44
|
+
"--run-id",
|
45
|
+
help="Run ID to pull artifacts from.",
|
46
|
+
),
|
47
|
+
],
|
48
|
+
app: Annotated[
|
49
|
+
Path,
|
50
|
+
typer.Argument(help="Path of the Flower App to run."),
|
51
|
+
] = Path("."),
|
52
|
+
federation: Annotated[
|
53
|
+
Optional[str],
|
54
|
+
typer.Argument(help="Name of the federation."),
|
55
|
+
] = None,
|
56
|
+
federation_config_overrides: Annotated[
|
57
|
+
Optional[list[str]],
|
58
|
+
typer.Option(
|
59
|
+
"--federation-config",
|
60
|
+
help=FEDERATION_CONFIG_HELP_MESSAGE,
|
61
|
+
),
|
62
|
+
] = None,
|
63
|
+
) -> None:
|
64
|
+
"""Pull artifacts from a Flower run."""
|
65
|
+
typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
|
66
|
+
|
67
|
+
pyproject_path = app / FAB_CONFIG_FILE if app else None
|
68
|
+
config, errors, warnings = load_and_validate(path=pyproject_path)
|
69
|
+
config = process_loaded_project_config(config, errors, warnings)
|
70
|
+
federation, federation_config = validate_federation_in_project_config(
|
71
|
+
federation, config, federation_config_overrides
|
72
|
+
)
|
73
|
+
exit_if_no_address(federation_config, "pull")
|
74
|
+
channel = None
|
75
|
+
try:
|
76
|
+
|
77
|
+
auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
|
78
|
+
channel = init_channel(app, federation_config, auth_plugin)
|
79
|
+
stub = ControlStub(channel)
|
80
|
+
with flwr_cli_grpc_exc_handler():
|
81
|
+
res: PullArtifactsResponse = stub.PullArtifacts(
|
82
|
+
PullArtifactsRequest(run_id=run_id)
|
83
|
+
)
|
84
|
+
|
85
|
+
if not res.url:
|
86
|
+
typer.secho(
|
87
|
+
f"❌ A download URL for artifacts from run {run_id} couldn't be "
|
88
|
+
"obtained.",
|
89
|
+
fg=typer.colors.RED,
|
90
|
+
bold=True,
|
91
|
+
)
|
92
|
+
raise typer.Exit(code=1)
|
93
|
+
|
94
|
+
typer.secho(
|
95
|
+
f"✅ Artifacts for run {run_id} can be downloaded from: {res.url}",
|
96
|
+
fg=typer.colors.GREEN,
|
97
|
+
)
|
98
|
+
finally:
|
99
|
+
if channel:
|
100
|
+
channel.close()
|
flwr/cli/utils.py
CHANGED
@@ -32,7 +32,9 @@ from flwr.common.constant import (
|
|
32
32
|
AUTH_TYPE_JSON_KEY,
|
33
33
|
CREDENTIALS_DIR,
|
34
34
|
FLWR_DIR,
|
35
|
+
NO_ARTIFACT_PROVIDER_MESSAGE,
|
35
36
|
NO_USER_AUTH_MESSAGE,
|
37
|
+
PULL_UNFINISHED_RUN_MESSAGE,
|
36
38
|
RUN_ID_NOT_FOUND_MESSAGE,
|
37
39
|
)
|
38
40
|
from flwr.common.grpc import (
|
@@ -319,6 +321,12 @@ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
|
|
319
321
|
fg=typer.colors.RED,
|
320
322
|
bold=True,
|
321
323
|
)
|
324
|
+
elif e.details() == NO_ARTIFACT_PROVIDER_MESSAGE: # pylint: disable=E1101
|
325
|
+
typer.secho(
|
326
|
+
"❌ The SuperLink does not support `flwr pull` command.",
|
327
|
+
fg=typer.colors.RED,
|
328
|
+
bold=True,
|
329
|
+
)
|
322
330
|
else:
|
323
331
|
typer.secho(
|
324
332
|
"❌ The SuperLink cannot process this request. Please verify that "
|
@@ -356,4 +364,13 @@ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
|
|
356
364
|
bold=True,
|
357
365
|
)
|
358
366
|
raise typer.Exit(code=1) from None
|
367
|
+
if e.code() == grpc.StatusCode.FAILED_PRECONDITION:
|
368
|
+
if e.details() == PULL_UNFINISHED_RUN_MESSAGE: # pylint: disable=E1101
|
369
|
+
typer.secho(
|
370
|
+
"❌ Run is not finished yet. Artifacts can only be pulled after "
|
371
|
+
"the run is finished. You can check the run status with `flwr ls`.",
|
372
|
+
fg=typer.colors.RED,
|
373
|
+
bold=True,
|
374
|
+
)
|
375
|
+
raise typer.Exit(code=1) from None
|
359
376
|
raise
|
flwr/common/constant.py
CHANGED
@@ -155,6 +155,8 @@ PULL_BACKOFF_CAP = 10 # Maximum backoff time for pulling objects
|
|
155
155
|
# ControlServicer constants
|
156
156
|
RUN_ID_NOT_FOUND_MESSAGE = "Run ID not found"
|
157
157
|
NO_USER_AUTH_MESSAGE = "ControlServicer initialized without user authentication"
|
158
|
+
NO_ARTIFACT_PROVIDER_MESSAGE = "ControlServicer initialized without artifact provider"
|
159
|
+
PULL_UNFINISHED_RUN_MESSAGE = "Cannot pull artifacts for an unfinished run"
|
158
160
|
|
159
161
|
|
160
162
|
class MessageType:
|
flwr/proto/control_pb2.py
CHANGED
@@ -18,7 +18,7 @@ from flwr.proto import recorddict_pb2 as flwr_dot_proto_dot_recorddict__pb2
|
|
18
18
|
from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
|
19
19
|
|
20
20
|
|
21
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x14\x66lwr/proto/run.proto\"\xfa\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x34\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x18.flwr.proto.ConfigRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"2\n\x10StartRunResponse\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\"\x18\n\x16GetLoginDetailsRequest\"\x8a\x01\n\x17GetLoginDetailsResponse\x12\x11\n\tauth_type\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65vice_code\x18\x02 \x01(\t\x12!\n\x19verification_uri_complete\x18\x03 \x01(\t\x12\x12\n\nexpires_in\x18\x04 \x01(\x03\x12\x10\n\x08interval\x18\x05 \x01(\x03\"+\n\x14GetAuthTokensRequest\x12\x13\n\x0b\x64\x65vice_code\x18\x01 \x01(\t\"D\n\x15GetAuthTokensResponse\x12\x14\n\x0c\x61\x63\x63\x65ss_token\x18\x01 \x01(\t\x12\x15\n\rrefresh_token\x18\x02 \x01(\t\" \n\x0eStopRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\"\n\x0fStopRunResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\
|
21
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x14\x66lwr/proto/run.proto\"\xfa\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x34\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x18.flwr.proto.ConfigRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"2\n\x10StartRunResponse\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\"\x18\n\x16GetLoginDetailsRequest\"\x8a\x01\n\x17GetLoginDetailsResponse\x12\x11\n\tauth_type\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65vice_code\x18\x02 \x01(\t\x12!\n\x19verification_uri_complete\x18\x03 \x01(\t\x12\x12\n\nexpires_in\x18\x04 \x01(\x03\x12\x10\n\x08interval\x18\x05 \x01(\x03\"+\n\x14GetAuthTokensRequest\x12\x13\n\x0b\x64\x65vice_code\x18\x01 \x01(\t\"D\n\x15GetAuthTokensResponse\x12\x14\n\x0c\x61\x63\x63\x65ss_token\x18\x01 \x01(\t\x12\x15\n\rrefresh_token\x18\x02 \x01(\t\" \n\x0eStopRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\"\n\x0fStopRunResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"&\n\x14PullArtifactsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"1\n\x15PullArtifactsResponse\x12\x10\n\x03url\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x06\n\x04_url2\xc0\x04\n\x07\x43ontrol\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12\x44\n\x07StopRun\x12\x1a.flwr.proto.StopRunRequest\x1a\x1b.flwr.proto.StopRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x12\\\n\x0fGetLoginDetails\x12\".flwr.proto.GetLoginDetailsRequest\x1a#.flwr.proto.GetLoginDetailsResponse\"\x00\x12V\n\rGetAuthTokens\x12 .flwr.proto.GetAuthTokensRequest\x1a!.flwr.proto.GetAuthTokensResponse\"\x00\x12V\n\rPullArtifacts\x12 .flwr.proto.PullArtifactsRequest\x1a!.flwr.proto.PullArtifactsResponse\"\x00\x62\x06proto3')
|
22
22
|
|
23
23
|
_globals = globals()
|
24
24
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
@@ -57,6 +57,10 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
57
57
|
_globals['_STOPRUNREQUEST']._serialized_end=1101
|
58
58
|
_globals['_STOPRUNRESPONSE']._serialized_start=1103
|
59
59
|
_globals['_STOPRUNRESPONSE']._serialized_end=1137
|
60
|
-
_globals['
|
61
|
-
_globals['
|
60
|
+
_globals['_PULLARTIFACTSREQUEST']._serialized_start=1139
|
61
|
+
_globals['_PULLARTIFACTSREQUEST']._serialized_end=1177
|
62
|
+
_globals['_PULLARTIFACTSRESPONSE']._serialized_start=1179
|
63
|
+
_globals['_PULLARTIFACTSRESPONSE']._serialized_end=1228
|
64
|
+
_globals['_CONTROL']._serialized_start=1231
|
65
|
+
_globals['_CONTROL']._serialized_end=1807
|
62
66
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/control_pb2.pyi
CHANGED
@@ -210,3 +210,27 @@ class StopRunResponse(google.protobuf.message.Message):
|
|
210
210
|
) -> None: ...
|
211
211
|
def ClearField(self, field_name: typing_extensions.Literal["success",b"success"]) -> None: ...
|
212
212
|
global___StopRunResponse = StopRunResponse
|
213
|
+
|
214
|
+
class PullArtifactsRequest(google.protobuf.message.Message):
|
215
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
216
|
+
RUN_ID_FIELD_NUMBER: builtins.int
|
217
|
+
run_id: builtins.int
|
218
|
+
def __init__(self,
|
219
|
+
*,
|
220
|
+
run_id: builtins.int = ...,
|
221
|
+
) -> None: ...
|
222
|
+
def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ...
|
223
|
+
global___PullArtifactsRequest = PullArtifactsRequest
|
224
|
+
|
225
|
+
class PullArtifactsResponse(google.protobuf.message.Message):
|
226
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
227
|
+
URL_FIELD_NUMBER: builtins.int
|
228
|
+
url: typing.Text
|
229
|
+
def __init__(self,
|
230
|
+
*,
|
231
|
+
url: typing.Optional[typing.Text] = ...,
|
232
|
+
) -> None: ...
|
233
|
+
def HasField(self, field_name: typing_extensions.Literal["_url",b"_url","url",b"url"]) -> builtins.bool: ...
|
234
|
+
def ClearField(self, field_name: typing_extensions.Literal["_url",b"_url","url",b"url"]) -> None: ...
|
235
|
+
def WhichOneof(self, oneof_group: typing_extensions.Literal["_url",b"_url"]) -> typing.Optional[typing_extensions.Literal["url"]]: ...
|
236
|
+
global___PullArtifactsResponse = PullArtifactsResponse
|
flwr/proto/control_pb2_grpc.py
CHANGED
@@ -44,6 +44,11 @@ class ControlStub(object):
|
|
44
44
|
request_serializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensRequest.SerializeToString,
|
45
45
|
response_deserializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensResponse.FromString,
|
46
46
|
)
|
47
|
+
self.PullArtifacts = channel.unary_unary(
|
48
|
+
'/flwr.proto.Control/PullArtifacts',
|
49
|
+
request_serializer=flwr_dot_proto_dot_control__pb2.PullArtifactsRequest.SerializeToString,
|
50
|
+
response_deserializer=flwr_dot_proto_dot_control__pb2.PullArtifactsResponse.FromString,
|
51
|
+
)
|
47
52
|
|
48
53
|
|
49
54
|
class ControlServicer(object):
|
@@ -91,6 +96,13 @@ class ControlServicer(object):
|
|
91
96
|
context.set_details('Method not implemented!')
|
92
97
|
raise NotImplementedError('Method not implemented!')
|
93
98
|
|
99
|
+
def PullArtifacts(self, request, context):
|
100
|
+
"""Pull artifacts generated during a run (flwr pull)
|
101
|
+
"""
|
102
|
+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
103
|
+
context.set_details('Method not implemented!')
|
104
|
+
raise NotImplementedError('Method not implemented!')
|
105
|
+
|
94
106
|
|
95
107
|
def add_ControlServicer_to_server(servicer, server):
|
96
108
|
rpc_method_handlers = {
|
@@ -124,6 +136,11 @@ def add_ControlServicer_to_server(servicer, server):
|
|
124
136
|
request_deserializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensRequest.FromString,
|
125
137
|
response_serializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensResponse.SerializeToString,
|
126
138
|
),
|
139
|
+
'PullArtifacts': grpc.unary_unary_rpc_method_handler(
|
140
|
+
servicer.PullArtifacts,
|
141
|
+
request_deserializer=flwr_dot_proto_dot_control__pb2.PullArtifactsRequest.FromString,
|
142
|
+
response_serializer=flwr_dot_proto_dot_control__pb2.PullArtifactsResponse.SerializeToString,
|
143
|
+
),
|
127
144
|
}
|
128
145
|
generic_handler = grpc.method_handlers_generic_handler(
|
129
146
|
'flwr.proto.Control', rpc_method_handlers)
|
@@ -235,3 +252,20 @@ class Control(object):
|
|
235
252
|
flwr_dot_proto_dot_control__pb2.GetAuthTokensResponse.FromString,
|
236
253
|
options, channel_credentials,
|
237
254
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
255
|
+
|
256
|
+
@staticmethod
|
257
|
+
def PullArtifacts(request,
|
258
|
+
target,
|
259
|
+
options=(),
|
260
|
+
channel_credentials=None,
|
261
|
+
call_credentials=None,
|
262
|
+
insecure=False,
|
263
|
+
compression=None,
|
264
|
+
wait_for_ready=None,
|
265
|
+
timeout=None,
|
266
|
+
metadata=None):
|
267
|
+
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/PullArtifacts',
|
268
|
+
flwr_dot_proto_dot_control__pb2.PullArtifactsRequest.SerializeToString,
|
269
|
+
flwr_dot_proto_dot_control__pb2.PullArtifactsResponse.FromString,
|
270
|
+
options, channel_credentials,
|
271
|
+
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
flwr/proto/control_pb2_grpc.pyi
CHANGED
@@ -39,6 +39,11 @@ class ControlStub:
|
|
39
39
|
flwr.proto.control_pb2.GetAuthTokensResponse]
|
40
40
|
"""Get auth tokens upon request"""
|
41
41
|
|
42
|
+
PullArtifacts: grpc.UnaryUnaryMultiCallable[
|
43
|
+
flwr.proto.control_pb2.PullArtifactsRequest,
|
44
|
+
flwr.proto.control_pb2.PullArtifactsResponse]
|
45
|
+
"""Pull artifacts generated during a run (flwr pull)"""
|
46
|
+
|
42
47
|
|
43
48
|
class ControlServicer(metaclass=abc.ABCMeta):
|
44
49
|
@abc.abstractmethod
|
@@ -89,5 +94,13 @@ class ControlServicer(metaclass=abc.ABCMeta):
|
|
89
94
|
"""Get auth tokens upon request"""
|
90
95
|
pass
|
91
96
|
|
97
|
+
@abc.abstractmethod
|
98
|
+
def PullArtifacts(self,
|
99
|
+
request: flwr.proto.control_pb2.PullArtifactsRequest,
|
100
|
+
context: grpc.ServicerContext,
|
101
|
+
) -> flwr.proto.control_pb2.PullArtifactsResponse:
|
102
|
+
"""Pull artifacts generated during a run (flwr pull)"""
|
103
|
+
pass
|
104
|
+
|
92
105
|
|
93
106
|
def add_ControlServicer_to_server(servicer: ControlServicer, server: grpc.Server) -> None: ...
|
flwr/server/app.py
CHANGED
@@ -71,6 +71,7 @@ from flwr.server.fleet_event_log_interceptor import FleetEventLogInterceptor
|
|
71
71
|
from flwr.supercore.ffs import FfsFactory
|
72
72
|
from flwr.supercore.grpc_health import add_args_health, run_health_server_grpc_no_tls
|
73
73
|
from flwr.supercore.object_store import ObjectStoreFactory
|
74
|
+
from flwr.superlink.artifact_provider import ArtifactProvider
|
74
75
|
from flwr.superlink.servicer.control import run_control_api_grpc
|
75
76
|
|
76
77
|
from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
|
@@ -91,6 +92,7 @@ try:
|
|
91
92
|
get_control_auth_plugins,
|
92
93
|
get_control_authz_plugins,
|
93
94
|
get_control_event_log_writer_plugins,
|
95
|
+
get_ee_artifact_provider,
|
94
96
|
get_fleet_event_log_writer_plugins,
|
95
97
|
)
|
96
98
|
except ImportError:
|
@@ -113,6 +115,10 @@ except ImportError:
|
|
113
115
|
"No event log writer plugins are currently supported."
|
114
116
|
)
|
115
117
|
|
118
|
+
def get_ee_artifact_provider(config_path: str) -> ArtifactProvider:
|
119
|
+
"""Return the EE artifact provider."""
|
120
|
+
raise NotImplementedError("No artifact provider is currently supported.")
|
121
|
+
|
116
122
|
def get_fleet_event_log_writer_plugins() -> dict[str, type[EventLogWriterPlugin]]:
|
117
123
|
"""Return all Fleet API event log writer plugins."""
|
118
124
|
raise NotImplementedError(
|
@@ -199,6 +205,12 @@ def run_superlink() -> None:
|
|
199
205
|
if args.enable_event_log:
|
200
206
|
event_log_plugin = _try_obtain_control_event_log_writer_plugin()
|
201
207
|
|
208
|
+
# Load artifact provider if the args.artifact_provider_config is provided
|
209
|
+
artifact_provider = None
|
210
|
+
if cfg_path := getattr(args, "artifact_provider_config", None):
|
211
|
+
log(WARN, "The `--artifact-provider-config` flag is highly experimental.")
|
212
|
+
artifact_provider = get_ee_artifact_provider(cfg_path)
|
213
|
+
|
202
214
|
# Initialize StateFactory
|
203
215
|
state_factory = LinkStateFactory(args.database)
|
204
216
|
|
@@ -220,6 +232,7 @@ def run_superlink() -> None:
|
|
220
232
|
auth_plugin=auth_plugin,
|
221
233
|
authz_plugin=authz_plugin,
|
222
234
|
event_log_plugin=event_log_plugin,
|
235
|
+
artifact_provider=artifact_provider,
|
223
236
|
)
|
224
237
|
grpc_servers = [control_server]
|
225
238
|
bckg_threads: list[threading.Thread] = []
|
@@ -22,6 +22,10 @@ from .dp_fixed_clipping import (
|
|
22
22
|
from .fedadagrad import FedAdagrad
|
23
23
|
from .fedadam import FedAdam
|
24
24
|
from .fedavg import FedAvg
|
25
|
+
from .fedavgm import FedAvgM
|
26
|
+
from .fedmedian import FedMedian
|
27
|
+
from .fedprox import FedProx
|
28
|
+
from .fedtrimmedavg import FedTrimmedAvg
|
25
29
|
from .fedxgb_bagging import FedXgbBagging
|
26
30
|
from .fedyogi import FedYogi
|
27
31
|
from .result import Result
|
@@ -33,6 +37,10 @@ __all__ = [
|
|
33
37
|
"FedAdagrad",
|
34
38
|
"FedAdam",
|
35
39
|
"FedAvg",
|
40
|
+
"FedAvgM",
|
41
|
+
"FedMedian",
|
42
|
+
"FedProx",
|
43
|
+
"FedTrimmedAvg",
|
36
44
|
"FedXgbBagging",
|
37
45
|
"FedYogi",
|
38
46
|
"Result",
|
@@ -16,7 +16,7 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
from collections.abc import Iterable
|
19
|
-
from logging import INFO
|
19
|
+
from logging import INFO, WARNING
|
20
20
|
from typing import Callable, Optional
|
21
21
|
|
22
22
|
from flwr.common import (
|
@@ -67,7 +67,7 @@ class FedAvg(Strategy):
|
|
67
67
|
arrayrecord_key : str (default: "arrays")
|
68
68
|
Key used to store the ArrayRecord when constructing Messages.
|
69
69
|
configrecord_key : str (default: "config")
|
70
|
-
|
70
|
+
Key used to store the ConfigRecord when constructing Messages.
|
71
71
|
train_metrics_aggr_fn : Optional[callable] (default: None)
|
72
72
|
Function with signature (list[RecordDict], str) -> MetricRecord,
|
73
73
|
used to aggregate MetricRecords from training round replies.
|
@@ -111,6 +111,20 @@ class FedAvg(Strategy):
|
|
111
111
|
evaluate_metrics_aggr_fn or aggregate_metricrecords
|
112
112
|
)
|
113
113
|
|
114
|
+
if self.fraction_evaluate == 0.0:
|
115
|
+
self.min_evaluate_nodes = 0
|
116
|
+
log(
|
117
|
+
WARNING,
|
118
|
+
"fraction_evaluate is set to 0.0. "
|
119
|
+
"Federated evaluation will be skipped.",
|
120
|
+
)
|
121
|
+
if self.fraction_train == 0.0:
|
122
|
+
self.min_train_nodes = 0
|
123
|
+
log(
|
124
|
+
WARNING,
|
125
|
+
"fraction_train is set to 0.0. Federated training will be skipped.",
|
126
|
+
)
|
127
|
+
|
114
128
|
def summary(self) -> None:
|
115
129
|
"""Log summary configuration of the strategy."""
|
116
130
|
log(INFO, "\t├──> Sampling:")
|
@@ -150,6 +164,9 @@ class FedAvg(Strategy):
|
|
150
164
|
self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
|
151
165
|
) -> Iterable[Message]:
|
152
166
|
"""Configure the next round of federated training."""
|
167
|
+
# Do not configure federated train if fraction_train is 0.
|
168
|
+
if self.fraction_train == 0.0:
|
169
|
+
return []
|
153
170
|
# Sample nodes
|
154
171
|
num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_train)
|
155
172
|
sample_size = max(num_nodes, self.min_train_nodes)
|
@@ -259,6 +276,10 @@ class FedAvg(Strategy):
|
|
259
276
|
self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
|
260
277
|
) -> Iterable[Message]:
|
261
278
|
"""Configure the next round of federated evaluation."""
|
279
|
+
# Do not configure federated evaluation if fraction_evaluate is 0.
|
280
|
+
if self.fraction_evaluate == 0.0:
|
281
|
+
return []
|
282
|
+
|
262
283
|
# Sample nodes
|
263
284
|
num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_evaluate)
|
264
285
|
sample_size = max(num_nodes, self.min_evaluate_nodes)
|