flwr-nightly 1.22.0.dev20250916__py3-none-any.whl → 1.22.0.dev20250918__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +2 -0
- flwr/cli/new/new.py +4 -2
- flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
- flwr/cli/pull.py +100 -0
- flwr/cli/utils.py +17 -0
- flwr/common/constant.py +2 -0
- flwr/common/exit/exit_code.py +4 -0
- flwr/proto/control_pb2.py +7 -3
- flwr/proto/control_pb2.pyi +24 -0
- flwr/proto/control_pb2_grpc.py +34 -0
- flwr/proto/control_pb2_grpc.pyi +13 -0
- flwr/server/app.py +13 -0
- flwr/serverapp/strategy/__init__.py +4 -0
- flwr/serverapp/strategy/fedprox.py +174 -0
- flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
- flwr/simulation/app.py +1 -1
- flwr/simulation/run_simulation.py +25 -30
- flwr/supercore/cli/flower_superexec.py +26 -1
- flwr/supercore/constant.py +19 -0
- flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
- flwr/supercore/superexec/run_superexec.py +16 -2
- flwr/superlink/artifact_provider/__init__.py +22 -0
- flwr/superlink/artifact_provider/artifact_provider.py +37 -0
- flwr/superlink/servicer/control/control_grpc.py +3 -0
- flwr/superlink/servicer/control/control_servicer.py +59 -2
- {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250918.dist-info}/METADATA +1 -1
- {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250918.dist-info}/RECORD +42 -33
- flwr/serverapp/strategy/strategy_utils_tests.py +0 -323
- {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250918.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250918.dist-info}/entry_points.txt +0 -0
@@ -1,53 +1,48 @@
|
|
1
1
|
"""$project_name: A Flower / FlowerTune app."""
|
2
2
|
|
3
|
-
from
|
3
|
+
from collections.abc import Iterable
|
4
4
|
from logging import INFO, WARN
|
5
|
-
from typing import
|
5
|
+
from typing import Optional
|
6
6
|
|
7
|
-
from flwr.
|
8
|
-
from flwr.
|
9
|
-
from flwr.
|
10
|
-
from flwr.
|
7
|
+
from flwr.app import ArrayRecord, ConfigRecord, Message, MetricRecord
|
8
|
+
from flwr.common import log
|
9
|
+
from flwr.serverapp import Grid
|
10
|
+
from flwr.serverapp.strategy import FedAvg
|
11
11
|
|
12
12
|
|
13
13
|
class FlowerTuneLlm(FedAvg):
|
14
14
|
"""Customised FedAvg strategy implementation.
|
15
|
-
|
15
|
+
|
16
16
|
This class behaves just like FedAvg but also tracks the communication
|
17
|
-
costs associated with `
|
17
|
+
costs associated with `train` over FL rounds.
|
18
18
|
"""
|
19
19
|
def __init__(self, **kwargs):
|
20
20
|
super().__init__(**kwargs)
|
21
21
|
self.comm_tracker = CommunicationTracker()
|
22
22
|
|
23
|
-
def
|
24
|
-
|
25
|
-
):
|
23
|
+
def configure_train(
|
24
|
+
self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
|
25
|
+
) -> Iterable[Message]:
|
26
26
|
"""Configure the next round of training."""
|
27
|
-
|
28
|
-
|
29
|
-
#
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
# Test communication costs
|
43
|
-
fit_res_list = [fit_res for _, fit_res in results]
|
44
|
-
self.comm_tracker.track(fit_res_list)
|
45
|
-
|
46
|
-
parameters_aggregated, metrics_aggregated = super().aggregate_fit(
|
47
|
-
server_round, results, failures
|
48
|
-
)
|
27
|
+
messages = super().configure_train(server_round, arrays, config, grid)
|
28
|
+
|
29
|
+
# Track communication costs
|
30
|
+
self.comm_tracker.track(messages)
|
31
|
+
|
32
|
+
return messages
|
33
|
+
|
34
|
+
def aggregate_train(
|
35
|
+
self,
|
36
|
+
server_round: int,
|
37
|
+
replies: Iterable[Message],
|
38
|
+
) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
|
39
|
+
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
40
|
+
# Track communication costs
|
41
|
+
self.comm_tracker.track(replies)
|
49
42
|
|
50
|
-
|
43
|
+
arrays, metrics = super().aggregate_train(server_round, replies)
|
44
|
+
|
45
|
+
return arrays, metrics
|
51
46
|
|
52
47
|
|
53
48
|
class CommunicationTracker:
|
@@ -55,16 +50,16 @@ class CommunicationTracker:
|
|
55
50
|
def __init__(self):
|
56
51
|
self.curr_comm_cost = 0.0
|
57
52
|
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
53
|
+
def track(self, messages: Iterable[Message]):
|
54
|
+
comm_cost = (
|
55
|
+
sum(
|
56
|
+
record.count_bytes()
|
57
|
+
for msg in messages
|
58
|
+
if msg.has_content()
|
59
|
+
for record in msg.content.array_records.values()
|
60
|
+
)
|
61
|
+
/ 1024**2
|
62
|
+
)
|
68
63
|
|
69
64
|
self.curr_comm_cost += comm_cost
|
70
65
|
log(
|
@@ -1,7 +1,5 @@
|
|
1
1
|
"""$project_name: A Flower Baseline."""
|
2
2
|
|
3
|
-
from collections import OrderedDict
|
4
|
-
|
5
3
|
import torch
|
6
4
|
import torch.nn.functional as F
|
7
5
|
from torch import nn
|
@@ -66,15 +64,3 @@ def test(net, testloader, device):
|
|
66
64
|
accuracy = correct / len(testloader.dataset)
|
67
65
|
loss = loss / len(testloader)
|
68
66
|
return loss, accuracy
|
69
|
-
|
70
|
-
|
71
|
-
def get_weights(net):
|
72
|
-
"""Extract model parameters as numpy arrays from state_dict."""
|
73
|
-
return [val.cpu().numpy() for _, val in net.state_dict().items()]
|
74
|
-
|
75
|
-
|
76
|
-
def set_weights(net, parameters):
|
77
|
-
"""Apply parameters to an existing model."""
|
78
|
-
params_dict = zip(net.state_dict().keys(), parameters)
|
79
|
-
state_dict = OrderedDict({k: torch.from_numpy(v) for k, v in params_dict})
|
80
|
-
net.load_state_dict(state_dict, strict=True)
|
@@ -1,45 +1,43 @@
|
|
1
1
|
"""$project_name: A Flower Baseline."""
|
2
2
|
|
3
|
-
|
4
|
-
from flwr.
|
5
|
-
from flwr.
|
3
|
+
import torch
|
4
|
+
from flwr.app import ArrayRecord, Context
|
5
|
+
from flwr.serverapp import Grid, ServerApp
|
6
|
+
from flwr.serverapp.strategy import FedAvg
|
6
7
|
|
7
|
-
from $import_name.model import Net
|
8
|
+
from $import_name.model import Net
|
8
9
|
|
10
|
+
# Create ServerApp
|
11
|
+
app = ServerApp()
|
9
12
|
|
10
|
-
# Define metric aggregation function
|
11
|
-
def weighted_average(metrics: list[tuple[int, Metrics]]) -> Metrics:
|
12
|
-
"""Do weighted average of accuracy metric."""
|
13
|
-
# Multiply accuracy of each client by number of examples used
|
14
|
-
accuracies = [num_examples * float(m["accuracy"]) for num_examples, m in metrics]
|
15
|
-
examples = [num_examples for num_examples, _ in metrics]
|
16
|
-
|
17
|
-
# Aggregate and return custom metric (weighted average)
|
18
|
-
return {"accuracy": sum(accuracies) / sum(examples)}
|
19
13
|
|
14
|
+
@app.main()
|
15
|
+
def main(grid: Grid, context: Context) -> None:
|
16
|
+
"""Main entry point for the ServerApp."""
|
20
17
|
|
21
|
-
def server_fn(context: Context):
|
22
|
-
"""Construct components that set the ServerApp behaviour."""
|
23
18
|
# Read from config
|
24
19
|
num_rounds = context.run_config["num-server-rounds"]
|
25
|
-
|
20
|
+
fraction_train = context.run_config["fraction-train"]
|
26
21
|
|
27
|
-
#
|
28
|
-
|
29
|
-
|
22
|
+
# Load global model
|
23
|
+
global_model = Net()
|
24
|
+
arrays = ArrayRecord(global_model.state_dict())
|
30
25
|
|
31
|
-
#
|
26
|
+
# Initialize FedAvg strategy
|
32
27
|
strategy = FedAvg(
|
33
|
-
|
28
|
+
fraction_train=fraction_train,
|
34
29
|
fraction_evaluate=1.0,
|
35
|
-
|
36
|
-
initial_parameters=parameters,
|
37
|
-
evaluate_metrics_aggregation_fn=weighted_average,
|
30
|
+
min_available_nodes=2,
|
38
31
|
)
|
39
|
-
config = ServerConfig(num_rounds=int(num_rounds))
|
40
|
-
|
41
|
-
return ServerAppComponents(strategy=strategy, config=config)
|
42
32
|
|
33
|
+
# Start strategy, run FedAvg for `num_rounds`
|
34
|
+
result = strategy.start(
|
35
|
+
grid=grid,
|
36
|
+
initial_arrays=arrays,
|
37
|
+
num_rounds=num_rounds,
|
38
|
+
)
|
43
39
|
|
44
|
-
#
|
45
|
-
|
40
|
+
# Save final model to disk
|
41
|
+
print("\nSaving final model to disk...")
|
42
|
+
state_dict = result.arrays.to_torch_state_dict()
|
43
|
+
torch.save(state_dict, "final_model.pt")
|
@@ -0,0 +1,56 @@
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
import xgboost as xgb
|
5
|
+
from flwr.app import ArrayRecord, Context
|
6
|
+
from flwr.common.config import unflatten_dict
|
7
|
+
from flwr.serverapp import Grid, ServerApp
|
8
|
+
from flwr.serverapp.strategy import FedXgbBagging
|
9
|
+
|
10
|
+
from $import_name.task import replace_keys
|
11
|
+
|
12
|
+
# Create ServerApp
|
13
|
+
app = ServerApp()
|
14
|
+
|
15
|
+
|
16
|
+
@app.main()
|
17
|
+
def main(grid: Grid, context: Context) -> None:
|
18
|
+
# Read run config
|
19
|
+
num_rounds = context.run_config["num-server-rounds"]
|
20
|
+
fraction_train = context.run_config["fraction-train"]
|
21
|
+
fraction_evaluate = context.run_config["fraction-evaluate"]
|
22
|
+
# Flatted config dict and replace "-" with "_"
|
23
|
+
cfg = replace_keys(unflatten_dict(context.run_config))
|
24
|
+
params = cfg["params"]
|
25
|
+
|
26
|
+
# Init global model
|
27
|
+
# Init with an empty object; the XGBooster will be created
|
28
|
+
# and trained on the client side.
|
29
|
+
global_model = b""
|
30
|
+
# Note: we store the model as the first item in a list into ArrayRecord,
|
31
|
+
# which can be accessed using index ["0"].
|
32
|
+
arrays = ArrayRecord([np.frombuffer(global_model, dtype=np.uint8)])
|
33
|
+
|
34
|
+
# Initialize FedXgbBagging strategy
|
35
|
+
strategy = FedXgbBagging(
|
36
|
+
fraction_train=fraction_train,
|
37
|
+
fraction_evaluate=fraction_evaluate,
|
38
|
+
)
|
39
|
+
|
40
|
+
# Start strategy, run FedXgbBagging for `num_rounds`
|
41
|
+
result = strategy.start(
|
42
|
+
grid=grid,
|
43
|
+
initial_arrays=arrays,
|
44
|
+
num_rounds=num_rounds,
|
45
|
+
)
|
46
|
+
|
47
|
+
# Save final model to disk
|
48
|
+
bst = xgb.Booster(params=params)
|
49
|
+
global_model = bytearray(result.arrays["0"].numpy().tobytes())
|
50
|
+
|
51
|
+
# Load global model into booster
|
52
|
+
bst.load_model(global_model)
|
53
|
+
|
54
|
+
# Save model
|
55
|
+
print("\nSaving final model to disk...")
|
56
|
+
bst.save_model("final_model.json")
|
@@ -0,0 +1,67 @@
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
2
|
+
|
3
|
+
import xgboost as xgb
|
4
|
+
from flwr_datasets import FederatedDataset
|
5
|
+
from flwr_datasets.partitioner import IidPartitioner
|
6
|
+
|
7
|
+
|
8
|
+
def train_test_split(partition, test_fraction, seed):
|
9
|
+
"""Split the data into train and validation set given split rate."""
|
10
|
+
train_test = partition.train_test_split(test_size=test_fraction, seed=seed)
|
11
|
+
partition_train = train_test["train"]
|
12
|
+
partition_test = train_test["test"]
|
13
|
+
|
14
|
+
num_train = len(partition_train)
|
15
|
+
num_test = len(partition_test)
|
16
|
+
|
17
|
+
return partition_train, partition_test, num_train, num_test
|
18
|
+
|
19
|
+
|
20
|
+
def transform_dataset_to_dmatrix(data):
|
21
|
+
"""Transform dataset to DMatrix format for xgboost."""
|
22
|
+
x = data["inputs"]
|
23
|
+
y = data["label"]
|
24
|
+
new_data = xgb.DMatrix(x, label=y)
|
25
|
+
return new_data
|
26
|
+
|
27
|
+
|
28
|
+
fds = None # Cache FederatedDataset
|
29
|
+
|
30
|
+
|
31
|
+
def load_data(partition_id, num_clients):
|
32
|
+
"""Load partition HIGGS data."""
|
33
|
+
# Only initialize `FederatedDataset` once
|
34
|
+
global fds
|
35
|
+
if fds is None:
|
36
|
+
partitioner = IidPartitioner(num_partitions=num_clients)
|
37
|
+
fds = FederatedDataset(
|
38
|
+
dataset="jxie/higgs",
|
39
|
+
partitioners={"train": partitioner},
|
40
|
+
)
|
41
|
+
|
42
|
+
# Load the partition for this `partition_id`
|
43
|
+
partition = fds.load_partition(partition_id, split="train")
|
44
|
+
partition.set_format("numpy")
|
45
|
+
|
46
|
+
# Train/test splitting
|
47
|
+
train_data, valid_data, num_train, num_val = train_test_split(
|
48
|
+
partition, test_fraction=0.2, seed=42
|
49
|
+
)
|
50
|
+
|
51
|
+
# Reformat data to DMatrix for xgboost
|
52
|
+
train_dmatrix = transform_dataset_to_dmatrix(train_data)
|
53
|
+
valid_dmatrix = transform_dataset_to_dmatrix(valid_data)
|
54
|
+
|
55
|
+
return train_dmatrix, valid_dmatrix, num_train, num_val
|
56
|
+
|
57
|
+
|
58
|
+
def replace_keys(input_dict, match="-", target="_"):
|
59
|
+
"""Recursively replace match string with target string in dictionary keys."""
|
60
|
+
new_dict = {}
|
61
|
+
for key, value in input_dict.items():
|
62
|
+
new_key = key.replace(match, target)
|
63
|
+
if isinstance(value, dict):
|
64
|
+
new_dict[new_key] = replace_keys(value, match, target)
|
65
|
+
else:
|
66
|
+
new_dict[new_key] = value
|
67
|
+
return new_dict
|
@@ -16,8 +16,8 @@ license = "Apache-2.0"
|
|
16
16
|
dependencies = [
|
17
17
|
"flwr[simulation]>=1.22.0",
|
18
18
|
"flwr-datasets[vision]>=0.5.0",
|
19
|
-
"torch==2.
|
20
|
-
"torchvision==0.
|
19
|
+
"torch==2.8.0",
|
20
|
+
"torchvision==0.23.0",
|
21
21
|
]
|
22
22
|
|
23
23
|
[tool.hatch.metadata]
|
@@ -132,7 +132,7 @@ clientapp = "$import_name.client_app:app"
|
|
132
132
|
# Custom config values accessible via `context.run_config`
|
133
133
|
[tool.flwr.app.config]
|
134
134
|
num-server-rounds = 3
|
135
|
-
fraction-
|
135
|
+
fraction-train = 0.5
|
136
136
|
local-epochs = 1
|
137
137
|
|
138
138
|
# Default federation to use when running the app
|
@@ -61,7 +61,7 @@ train.training-arguments.save-steps = 1000
|
|
61
61
|
train.training-arguments.save-total-limit = 10
|
62
62
|
train.training-arguments.gradient-checkpointing = true
|
63
63
|
train.training-arguments.lr-scheduler-type = "constant"
|
64
|
-
strategy.fraction-
|
64
|
+
strategy.fraction-train = $fraction_train
|
65
65
|
strategy.fraction-evaluate = 0.0
|
66
66
|
num-server-rounds = 200
|
67
67
|
|
@@ -0,0 +1,61 @@
|
|
1
|
+
# =====================================================================
|
2
|
+
# For a full TOML configuration guide, check the Flower docs:
|
3
|
+
# https://flower.ai/docs/framework/how-to-configure-pyproject-toml.html
|
4
|
+
# =====================================================================
|
5
|
+
|
6
|
+
[build-system]
|
7
|
+
requires = ["hatchling"]
|
8
|
+
build-backend = "hatchling.build"
|
9
|
+
|
10
|
+
[project]
|
11
|
+
name = "$package_name"
|
12
|
+
version = "1.0.0"
|
13
|
+
description = ""
|
14
|
+
license = "Apache-2.0"
|
15
|
+
# Dependencies for your Flower App
|
16
|
+
dependencies = [
|
17
|
+
"flwr[simulation]>=1.22.0",
|
18
|
+
"flwr-datasets>=0.5.0",
|
19
|
+
"xgboost>=2.0.0",
|
20
|
+
]
|
21
|
+
|
22
|
+
[tool.hatch.build.targets.wheel]
|
23
|
+
packages = ["."]
|
24
|
+
|
25
|
+
[tool.flwr.app]
|
26
|
+
publisher = "$username"
|
27
|
+
|
28
|
+
[tool.flwr.app.components]
|
29
|
+
serverapp = "$import_name.server_app:app"
|
30
|
+
clientapp = "$import_name.client_app:app"
|
31
|
+
|
32
|
+
# Custom config values accessible via `context.run_config`
|
33
|
+
[tool.flwr.app.config]
|
34
|
+
num-server-rounds = 3
|
35
|
+
fraction-train = 0.1
|
36
|
+
fraction-evaluate = 0.1
|
37
|
+
local-epochs = 1
|
38
|
+
|
39
|
+
# XGBoost parameters
|
40
|
+
params.objective = "binary:logistic"
|
41
|
+
params.eta = 0.1 # Learning rate
|
42
|
+
params.max-depth = 8
|
43
|
+
params.eval-metric = "auc"
|
44
|
+
params.nthread = 16
|
45
|
+
params.num-parallel-tree = 1
|
46
|
+
params.subsample = 1
|
47
|
+
params.tree-method = "hist"
|
48
|
+
|
49
|
+
# Default federation to use when running the app
|
50
|
+
[tool.flwr.federations]
|
51
|
+
default = "local-simulation"
|
52
|
+
|
53
|
+
# Local simulation federation with 10 virtual SuperNodes
|
54
|
+
[tool.flwr.federations.local-simulation]
|
55
|
+
options.num-supernodes = 10
|
56
|
+
|
57
|
+
# Remote federation example for use with SuperLink
|
58
|
+
[tool.flwr.federations.remote-federation]
|
59
|
+
address = "<SUPERLINK-ADDRESS>:<PORT>"
|
60
|
+
insecure = true # Remove this line to enable TLS
|
61
|
+
# root-certificates = "<PATH/TO/ca.crt>" # For TLS setup
|
flwr/cli/pull.py
ADDED
@@ -0,0 +1,100 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Flower command line interface `pull` command."""
|
16
|
+
|
17
|
+
|
18
|
+
from pathlib import Path
|
19
|
+
from typing import Annotated, Optional
|
20
|
+
|
21
|
+
import typer
|
22
|
+
|
23
|
+
from flwr.cli.config_utils import (
|
24
|
+
exit_if_no_address,
|
25
|
+
load_and_validate,
|
26
|
+
process_loaded_project_config,
|
27
|
+
validate_federation_in_project_config,
|
28
|
+
)
|
29
|
+
from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
|
30
|
+
from flwr.common.constant import FAB_CONFIG_FILE
|
31
|
+
from flwr.proto.control_pb2 import ( # pylint: disable=E0611
|
32
|
+
PullArtifactsRequest,
|
33
|
+
PullArtifactsResponse,
|
34
|
+
)
|
35
|
+
from flwr.proto.control_pb2_grpc import ControlStub
|
36
|
+
|
37
|
+
from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
|
38
|
+
|
39
|
+
|
40
|
+
def pull( # pylint: disable=R0914
|
41
|
+
run_id: Annotated[
|
42
|
+
int,
|
43
|
+
typer.Option(
|
44
|
+
"--run-id",
|
45
|
+
help="Run ID to pull artifacts from.",
|
46
|
+
),
|
47
|
+
],
|
48
|
+
app: Annotated[
|
49
|
+
Path,
|
50
|
+
typer.Argument(help="Path of the Flower App to run."),
|
51
|
+
] = Path("."),
|
52
|
+
federation: Annotated[
|
53
|
+
Optional[str],
|
54
|
+
typer.Argument(help="Name of the federation."),
|
55
|
+
] = None,
|
56
|
+
federation_config_overrides: Annotated[
|
57
|
+
Optional[list[str]],
|
58
|
+
typer.Option(
|
59
|
+
"--federation-config",
|
60
|
+
help=FEDERATION_CONFIG_HELP_MESSAGE,
|
61
|
+
),
|
62
|
+
] = None,
|
63
|
+
) -> None:
|
64
|
+
"""Pull artifacts from a Flower run."""
|
65
|
+
typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
|
66
|
+
|
67
|
+
pyproject_path = app / FAB_CONFIG_FILE if app else None
|
68
|
+
config, errors, warnings = load_and_validate(path=pyproject_path)
|
69
|
+
config = process_loaded_project_config(config, errors, warnings)
|
70
|
+
federation, federation_config = validate_federation_in_project_config(
|
71
|
+
federation, config, federation_config_overrides
|
72
|
+
)
|
73
|
+
exit_if_no_address(federation_config, "pull")
|
74
|
+
channel = None
|
75
|
+
try:
|
76
|
+
|
77
|
+
auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
|
78
|
+
channel = init_channel(app, federation_config, auth_plugin)
|
79
|
+
stub = ControlStub(channel)
|
80
|
+
with flwr_cli_grpc_exc_handler():
|
81
|
+
res: PullArtifactsResponse = stub.PullArtifacts(
|
82
|
+
PullArtifactsRequest(run_id=run_id)
|
83
|
+
)
|
84
|
+
|
85
|
+
if not res.url:
|
86
|
+
typer.secho(
|
87
|
+
f"❌ A download URL for artifacts from run {run_id} couldn't be "
|
88
|
+
"obtained.",
|
89
|
+
fg=typer.colors.RED,
|
90
|
+
bold=True,
|
91
|
+
)
|
92
|
+
raise typer.Exit(code=1)
|
93
|
+
|
94
|
+
typer.secho(
|
95
|
+
f"✅ Artifacts for run {run_id} can be downloaded from: {res.url}",
|
96
|
+
fg=typer.colors.GREEN,
|
97
|
+
)
|
98
|
+
finally:
|
99
|
+
if channel:
|
100
|
+
channel.close()
|
flwr/cli/utils.py
CHANGED
@@ -32,7 +32,9 @@ from flwr.common.constant import (
|
|
32
32
|
AUTH_TYPE_JSON_KEY,
|
33
33
|
CREDENTIALS_DIR,
|
34
34
|
FLWR_DIR,
|
35
|
+
NO_ARTIFACT_PROVIDER_MESSAGE,
|
35
36
|
NO_USER_AUTH_MESSAGE,
|
37
|
+
PULL_UNFINISHED_RUN_MESSAGE,
|
36
38
|
RUN_ID_NOT_FOUND_MESSAGE,
|
37
39
|
)
|
38
40
|
from flwr.common.grpc import (
|
@@ -319,6 +321,12 @@ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
|
|
319
321
|
fg=typer.colors.RED,
|
320
322
|
bold=True,
|
321
323
|
)
|
324
|
+
elif e.details() == NO_ARTIFACT_PROVIDER_MESSAGE: # pylint: disable=E1101
|
325
|
+
typer.secho(
|
326
|
+
"❌ The SuperLink does not support `flwr pull` command.",
|
327
|
+
fg=typer.colors.RED,
|
328
|
+
bold=True,
|
329
|
+
)
|
322
330
|
else:
|
323
331
|
typer.secho(
|
324
332
|
"❌ The SuperLink cannot process this request. Please verify that "
|
@@ -356,4 +364,13 @@ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
|
|
356
364
|
bold=True,
|
357
365
|
)
|
358
366
|
raise typer.Exit(code=1) from None
|
367
|
+
if e.code() == grpc.StatusCode.FAILED_PRECONDITION:
|
368
|
+
if e.details() == PULL_UNFINISHED_RUN_MESSAGE: # pylint: disable=E1101
|
369
|
+
typer.secho(
|
370
|
+
"❌ Run is not finished yet. Artifacts can only be pulled after "
|
371
|
+
"the run is finished. You can check the run status with `flwr ls`.",
|
372
|
+
fg=typer.colors.RED,
|
373
|
+
bold=True,
|
374
|
+
)
|
375
|
+
raise typer.Exit(code=1) from None
|
359
376
|
raise
|
flwr/common/constant.py
CHANGED
@@ -155,6 +155,8 @@ PULL_BACKOFF_CAP = 10 # Maximum backoff time for pulling objects
|
|
155
155
|
# ControlServicer constants
|
156
156
|
RUN_ID_NOT_FOUND_MESSAGE = "Run ID not found"
|
157
157
|
NO_USER_AUTH_MESSAGE = "ControlServicer initialized without user authentication"
|
158
|
+
NO_ARTIFACT_PROVIDER_MESSAGE = "ControlServicer initialized without artifact provider"
|
159
|
+
PULL_UNFINISHED_RUN_MESSAGE = "Cannot pull artifacts for an unfinished run"
|
158
160
|
|
159
161
|
|
160
162
|
class MessageType:
|
flwr/common/exit/exit_code.py
CHANGED
@@ -45,6 +45,7 @@ class ExitCode:
|
|
45
45
|
SUPERNODE_NODE_AUTH_KEYS_INVALID = 302
|
46
46
|
|
47
47
|
# SuperExec-specific exit codes (400-499)
|
48
|
+
SUPEREXEC_INVALID_PLUGIN_CONFIG = 400
|
48
49
|
|
49
50
|
# Common exit codes (600-699)
|
50
51
|
COMMON_ADDRESS_INVALID = 600
|
@@ -112,6 +113,9 @@ EXIT_CODE_HELP = {
|
|
112
113
|
"file and try again."
|
113
114
|
),
|
114
115
|
# SuperExec-specific exit codes (400-499)
|
116
|
+
ExitCode.SUPEREXEC_INVALID_PLUGIN_CONFIG: (
|
117
|
+
"The YAML configuration for the SuperExec plugin is invalid."
|
118
|
+
),
|
115
119
|
# Common exit codes (600-699)
|
116
120
|
ExitCode.COMMON_ADDRESS_INVALID: (
|
117
121
|
"Please provide a valid URL, IPv4 or IPv6 address."
|
flwr/proto/control_pb2.py
CHANGED
@@ -18,7 +18,7 @@ from flwr.proto import recorddict_pb2 as flwr_dot_proto_dot_recorddict__pb2
|
|
18
18
|
from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
|
19
19
|
|
20
20
|
|
21
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x14\x66lwr/proto/run.proto\"\xfa\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x34\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x18.flwr.proto.ConfigRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"2\n\x10StartRunResponse\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\"\x18\n\x16GetLoginDetailsRequest\"\x8a\x01\n\x17GetLoginDetailsResponse\x12\x11\n\tauth_type\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65vice_code\x18\x02 \x01(\t\x12!\n\x19verification_uri_complete\x18\x03 \x01(\t\x12\x12\n\nexpires_in\x18\x04 \x01(\x03\x12\x10\n\x08interval\x18\x05 \x01(\x03\"+\n\x14GetAuthTokensRequest\x12\x13\n\x0b\x64\x65vice_code\x18\x01 \x01(\t\"D\n\x15GetAuthTokensResponse\x12\x14\n\x0c\x61\x63\x63\x65ss_token\x18\x01 \x01(\t\x12\x15\n\rrefresh_token\x18\x02 \x01(\t\" \n\x0eStopRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\"\n\x0fStopRunResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\
|
21
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x14\x66lwr/proto/run.proto\"\xfa\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x34\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x18.flwr.proto.ConfigRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"2\n\x10StartRunResponse\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\"\x18\n\x16GetLoginDetailsRequest\"\x8a\x01\n\x17GetLoginDetailsResponse\x12\x11\n\tauth_type\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65vice_code\x18\x02 \x01(\t\x12!\n\x19verification_uri_complete\x18\x03 \x01(\t\x12\x12\n\nexpires_in\x18\x04 \x01(\x03\x12\x10\n\x08interval\x18\x05 \x01(\x03\"+\n\x14GetAuthTokensRequest\x12\x13\n\x0b\x64\x65vice_code\x18\x01 \x01(\t\"D\n\x15GetAuthTokensResponse\x12\x14\n\x0c\x61\x63\x63\x65ss_token\x18\x01 \x01(\t\x12\x15\n\rrefresh_token\x18\x02 \x01(\t\" \n\x0eStopRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\"\n\x0fStopRunResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"&\n\x14PullArtifactsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"1\n\x15PullArtifactsResponse\x12\x10\n\x03url\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x06\n\x04_url2\xc0\x04\n\x07\x43ontrol\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12\x44\n\x07StopRun\x12\x1a.flwr.proto.StopRunRequest\x1a\x1b.flwr.proto.StopRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x12\\\n\x0fGetLoginDetails\x12\".flwr.proto.GetLoginDetailsRequest\x1a#.flwr.proto.GetLoginDetailsResponse\"\x00\x12V\n\rGetAuthTokens\x12 .flwr.proto.GetAuthTokensRequest\x1a!.flwr.proto.GetAuthTokensResponse\"\x00\x12V\n\rPullArtifacts\x12 .flwr.proto.PullArtifactsRequest\x1a!.flwr.proto.PullArtifactsResponse\"\x00\x62\x06proto3')
|
22
22
|
|
23
23
|
_globals = globals()
|
24
24
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
@@ -57,6 +57,10 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
57
57
|
_globals['_STOPRUNREQUEST']._serialized_end=1101
|
58
58
|
_globals['_STOPRUNRESPONSE']._serialized_start=1103
|
59
59
|
_globals['_STOPRUNRESPONSE']._serialized_end=1137
|
60
|
-
_globals['
|
61
|
-
_globals['
|
60
|
+
_globals['_PULLARTIFACTSREQUEST']._serialized_start=1139
|
61
|
+
_globals['_PULLARTIFACTSREQUEST']._serialized_end=1177
|
62
|
+
_globals['_PULLARTIFACTSRESPONSE']._serialized_start=1179
|
63
|
+
_globals['_PULLARTIFACTSRESPONSE']._serialized_end=1228
|
64
|
+
_globals['_CONTROL']._serialized_start=1231
|
65
|
+
_globals['_CONTROL']._serialized_end=1807
|
62
66
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/control_pb2.pyi
CHANGED
@@ -210,3 +210,27 @@ class StopRunResponse(google.protobuf.message.Message):
|
|
210
210
|
) -> None: ...
|
211
211
|
def ClearField(self, field_name: typing_extensions.Literal["success",b"success"]) -> None: ...
|
212
212
|
global___StopRunResponse = StopRunResponse
|
213
|
+
|
214
|
+
class PullArtifactsRequest(google.protobuf.message.Message):
|
215
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
216
|
+
RUN_ID_FIELD_NUMBER: builtins.int
|
217
|
+
run_id: builtins.int
|
218
|
+
def __init__(self,
|
219
|
+
*,
|
220
|
+
run_id: builtins.int = ...,
|
221
|
+
) -> None: ...
|
222
|
+
def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ...
|
223
|
+
global___PullArtifactsRequest = PullArtifactsRequest
|
224
|
+
|
225
|
+
class PullArtifactsResponse(google.protobuf.message.Message):
|
226
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
227
|
+
URL_FIELD_NUMBER: builtins.int
|
228
|
+
url: typing.Text
|
229
|
+
def __init__(self,
|
230
|
+
*,
|
231
|
+
url: typing.Optional[typing.Text] = ...,
|
232
|
+
) -> None: ...
|
233
|
+
def HasField(self, field_name: typing_extensions.Literal["_url",b"_url","url",b"url"]) -> builtins.bool: ...
|
234
|
+
def ClearField(self, field_name: typing_extensions.Literal["_url",b"_url","url",b"url"]) -> None: ...
|
235
|
+
def WhichOneof(self, oneof_group: typing_extensions.Literal["_url",b"_url"]) -> typing.Optional[typing_extensions.Literal["url"]]: ...
|
236
|
+
global___PullArtifactsResponse = PullArtifactsResponse
|