flwr-nightly 1.22.0.dev20250916__py3-none-any.whl → 1.22.0.dev20250918__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/new/new.py +4 -2
  3. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  4. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  5. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
  6. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  7. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  8. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  9. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  10. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  11. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  12. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
  13. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
  14. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +3 -3
  15. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  16. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
  17. flwr/cli/pull.py +100 -0
  18. flwr/cli/utils.py +17 -0
  19. flwr/common/constant.py +2 -0
  20. flwr/common/exit/exit_code.py +4 -0
  21. flwr/proto/control_pb2.py +7 -3
  22. flwr/proto/control_pb2.pyi +24 -0
  23. flwr/proto/control_pb2_grpc.py +34 -0
  24. flwr/proto/control_pb2_grpc.pyi +13 -0
  25. flwr/server/app.py +13 -0
  26. flwr/serverapp/strategy/__init__.py +4 -0
  27. flwr/serverapp/strategy/fedprox.py +174 -0
  28. flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
  29. flwr/simulation/app.py +1 -1
  30. flwr/simulation/run_simulation.py +25 -30
  31. flwr/supercore/cli/flower_superexec.py +26 -1
  32. flwr/supercore/constant.py +19 -0
  33. flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
  34. flwr/supercore/superexec/run_superexec.py +16 -2
  35. flwr/superlink/artifact_provider/__init__.py +22 -0
  36. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  37. flwr/superlink/servicer/control/control_grpc.py +3 -0
  38. flwr/superlink/servicer/control/control_servicer.py +59 -2
  39. {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250918.dist-info}/METADATA +1 -1
  40. {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250918.dist-info}/RECORD +42 -33
  41. flwr/serverapp/strategy/strategy_utils_tests.py +0 -323
  42. {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250918.dist-info}/WHEEL +0 -0
  43. {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250918.dist-info}/entry_points.txt +0 -0
@@ -1,53 +1,48 @@
1
1
  """$project_name: A Flower / FlowerTune app."""
2
2
 
3
- from io import BytesIO
3
+ from collections.abc import Iterable
4
4
  from logging import INFO, WARN
5
- from typing import List, Tuple, Union
5
+ from typing import Optional
6
6
 
7
- from flwr.common import FitIns, FitRes, Parameters, log, parameters_to_ndarrays
8
- from flwr.server.client_manager import ClientManager
9
- from flwr.server.client_proxy import ClientProxy
10
- from flwr.server.strategy import FedAvg
7
+ from flwr.app import ArrayRecord, ConfigRecord, Message, MetricRecord
8
+ from flwr.common import log
9
+ from flwr.serverapp import Grid
10
+ from flwr.serverapp.strategy import FedAvg
11
11
 
12
12
 
13
13
class FlowerTuneLlm(FedAvg):
    """Customised FedAvg strategy implementation.

    This class behaves just like FedAvg but also tracks the communication
    costs associated with `train` over FL rounds.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.comm_tracker = CommunicationTracker()

    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of training.

        The Messages produced by ``FedAvg.configure_train`` are materialised
        into a list before their size is tracked. ``Iterable`` may be a
        one-shot iterator, and tracking would otherwise consume it before it
        is handed back to the caller.
        """
        messages = list(super().configure_train(server_round, arrays, config, grid))

        # Track communication costs (iterates over `messages`)
        self.comm_tracker.track(messages)

        return messages

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages.

        ``replies`` is materialised into a list so that it can be iterated
        twice: once to track communication costs and once by ``FedAvg``.
        """
        replies = list(replies)

        # Track communication costs (iterates over `replies`)
        self.comm_tracker.track(replies)

        arrays, metrics = super().aggregate_train(server_round, replies)

        return arrays, metrics
51
46
 
52
47
 
53
48
  class CommunicationTracker:
@@ -55,16 +50,16 @@ class CommunicationTracker:
55
50
  def __init__(self):
56
51
  self.curr_comm_cost = 0.0
57
52
 
58
- @staticmethod
59
- def _compute_bytes(parameters):
60
- return sum([BytesIO(t).getbuffer().nbytes for t in parameters.tensors])
61
-
62
- def track(self, fit_list: List[Union[FitIns, FitRes]]):
63
- size_bytes_list = [
64
- self._compute_bytes(fit_ele.parameters)
65
- for fit_ele in fit_list
66
- ]
67
- comm_cost = sum(size_bytes_list) / 1024**2
53
+ def track(self, messages: Iterable[Message]):
54
+ comm_cost = (
55
+ sum(
56
+ record.count_bytes()
57
+ for msg in messages
58
+ if msg.has_content()
59
+ for record in msg.content.array_records.values()
60
+ )
61
+ / 1024**2
62
+ )
68
63
 
69
64
  self.curr_comm_cost += comm_cost
70
65
  log(
@@ -1,7 +1,5 @@
1
1
  """$project_name: A Flower Baseline."""
2
2
 
3
- from collections import OrderedDict
4
-
5
3
  import torch
6
4
  import torch.nn.functional as F
7
5
  from torch import nn
@@ -66,15 +64,3 @@ def test(net, testloader, device):
66
64
  accuracy = correct / len(testloader.dataset)
67
65
  loss = loss / len(testloader)
68
66
  return loss, accuracy
69
-
70
-
71
- def get_weights(net):
72
- """Extract model parameters as numpy arrays from state_dict."""
73
- return [val.cpu().numpy() for _, val in net.state_dict().items()]
74
-
75
-
76
- def set_weights(net, parameters):
77
- """Apply parameters to an existing model."""
78
- params_dict = zip(net.state_dict().keys(), parameters)
79
- state_dict = OrderedDict({k: torch.from_numpy(v) for k, v in params_dict})
80
- net.load_state_dict(state_dict, strict=True)
@@ -1,45 +1,43 @@
1
1
  """$project_name: A Flower Baseline."""
2
2
 
3
- from flwr.common import Context, Metrics, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
3
+ import torch
4
+ from flwr.app import ArrayRecord, Context
5
+ from flwr.serverapp import Grid, ServerApp
6
+ from flwr.serverapp.strategy import FedAvg
6
7
 
7
- from $import_name.model import Net, get_weights
8
+ from $import_name.model import Net
8
9
 
10
+ # Create ServerApp
11
+ app = ServerApp()
9
12
 
10
- # Define metric aggregation function
11
- def weighted_average(metrics: list[tuple[int, Metrics]]) -> Metrics:
12
- """Do weighted average of accuracy metric."""
13
- # Multiply accuracy of each client by number of examples used
14
- accuracies = [num_examples * float(m["accuracy"]) for num_examples, m in metrics]
15
- examples = [num_examples for num_examples, _ in metrics]
16
-
17
- # Aggregate and return custom metric (weighted average)
18
- return {"accuracy": sum(accuracies) / sum(examples)}
19
13
 
14
@app.main()
def main(grid: Grid, context: Context) -> None:
    """Run FedAvg for the configured rounds and save the final model.

    The number of rounds and the training fraction are taken from the run
    config (``num-server-rounds`` and ``fraction-train``). Training starts
    from the weights of a freshly constructed ``Net``, and the aggregated
    model is written to ``final_model.pt`` with ``torch.save``.
    """
    run_config = context.run_config
    num_rounds = run_config["num-server-rounds"]
    fraction_train = run_config["fraction-train"]

    # Initial global weights come from a newly built model
    initial_arrays = ArrayRecord(Net().state_dict())

    strategy = FedAvg(
        fraction_train=fraction_train,
        fraction_evaluate=1.0,
        min_available_nodes=2,
    )

    # Drive the federated training loop for `num_rounds` rounds
    result = strategy.start(
        grid=grid,
        initial_arrays=initial_arrays,
        num_rounds=num_rounds,
    )

    # Persist the aggregated weights as a PyTorch state_dict
    print("\nSaving final model to disk...")
    torch.save(result.arrays.to_torch_state_dict(), "final_model.pt")
@@ -0,0 +1,56 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ import numpy as np
4
+ import xgboost as xgb
5
+ from flwr.app import ArrayRecord, Context
6
+ from flwr.common.config import unflatten_dict
7
+ from flwr.serverapp import Grid, ServerApp
8
+ from flwr.serverapp.strategy import FedXgbBagging
9
+
10
+ from $import_name.task import replace_keys
11
+
12
+ # Create ServerApp
13
+ app = ServerApp()
14
+
15
+
16
@app.main()
def main(grid: Grid, context: Context) -> None:
    """Run federated XGBoost bagging and save the final global model.

    Reads ``num-server-rounds``, ``fraction-train`` and ``fraction-evaluate``
    from the run config. The nested ``params`` section is used to construct
    the ``xgb.Booster`` that loads the final aggregated model. That model is
    saved to ``final_model.json``.
    """
    # Read run config
    num_rounds = context.run_config["num-server-rounds"]
    fraction_train = context.run_config["fraction-train"]
    fraction_evaluate = context.run_config["fraction-evaluate"]
    # Unflatten the dotted run-config keys into nested dicts, so that entries
    # such as ``params.max-depth`` end up under ``cfg["params"]``. Then
    # replace "-" with "_" in every key.
    cfg = replace_keys(unflatten_dict(context.run_config))
    params = cfg["params"]

    # Init global model with an empty object; the XGBoost Booster is created
    # and trained on the client side.
    # Note: the model is stored as the first item of a list in the
    # ArrayRecord, which can be accessed using index ["0"].
    initial_model = b""
    arrays = ArrayRecord([np.frombuffer(initial_model, dtype=np.uint8)])

    # Initialize FedXgbBagging strategy
    strategy = FedXgbBagging(
        fraction_train=fraction_train,
        fraction_evaluate=fraction_evaluate,
    )

    # Start strategy, run FedXgbBagging for `num_rounds`
    result = strategy.start(
        grid=grid,
        initial_arrays=arrays,
        num_rounds=num_rounds,
    )

    # If no round produced an aggregated model, the result still holds the
    # empty initial object. An empty buffer is not a serialized model, so
    # there is nothing to load or save.
    final_model = bytearray(result.arrays["0"].numpy().tobytes())
    if not final_model:
        print("\nNo trained model was received; skipping save.")
        return

    # Load global model into booster
    bst = xgb.Booster(params=params)
    bst.load_model(final_model)

    # Save model
    print("\nSaving final model to disk...")
    bst.save_model("final_model.json")
@@ -0,0 +1,67 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ import xgboost as xgb
4
+ from flwr_datasets import FederatedDataset
5
+ from flwr_datasets.partitioner import IidPartitioner
6
+
7
+
8
def train_test_split(partition, test_fraction, seed):
    """Split ``partition`` into a train and a validation subset.

    Delegates to ``partition.train_test_split`` with ``test_size`` set to
    ``test_fraction`` and the given ``seed``.

    Returns:
        A 4-tuple ``(train, test, len(train), len(test))``.
    """
    splits = partition.train_test_split(test_size=test_fraction, seed=seed)
    train_part, test_part = splits["train"], splits["test"]
    return train_part, test_part, len(train_part), len(test_part)
18
+
19
+
20
def transform_dataset_to_dmatrix(data):
    """Build an XGBoost ``DMatrix`` from a dataset split.

    Features are read from the ``"inputs"`` column and targets from the
    ``"label"`` column.
    """
    features = data["inputs"]
    targets = data["label"]
    return xgb.DMatrix(features, label=targets)
26
+
27
+
28
fds = None  # Cache FederatedDataset


def load_data(partition_id, num_clients):
    """Load one IID partition of the HIGGS data as XGBoost DMatrices.

    The ``FederatedDataset`` for ``jxie/higgs`` is created on the first call
    only and cached in the module-level ``fds``. The ``IidPartitioner`` is
    therefore sized by ``num_clients`` from that first call; later calls
    reuse the cached dataset regardless of the ``num_clients`` passed.

    The selected partition is split 80/20 (seed 42) into train and
    validation subsets.

    Returns:
        ``(train_dmatrix, valid_dmatrix, num_train, num_val)``.
    """
    global fds
    if fds is None:
        fds = FederatedDataset(
            dataset="jxie/higgs",
            partitioners={"train": IidPartitioner(num_partitions=num_clients)},
        )

    partition = fds.load_partition(partition_id, split="train")
    partition.set_format("numpy")

    train_split, valid_split, num_train, num_val = train_test_split(
        partition, test_fraction=0.2, seed=42
    )

    return (
        transform_dataset_to_dmatrix(train_split),
        transform_dataset_to_dmatrix(valid_split),
        num_train,
        num_val,
    )
56
+
57
+
58
def replace_keys(input_dict, match="-", target="_"):
    """Return a copy of ``input_dict`` with ``match`` replaced by ``target`` in keys.

    Nested dictionaries are processed recursively; all other values are
    copied over unchanged. If two keys collide after replacement, the one
    that appears later in iteration order wins.
    """
    return {
        key.replace(match, target): (
            replace_keys(value, match, target) if isinstance(value, dict) else value
        )
        for key, value in input_dict.items()
    }
@@ -16,8 +16,8 @@ license = "Apache-2.0"
16
16
  dependencies = [
17
17
  "flwr[simulation]>=1.22.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
- "torch==2.7.1",
20
- "torchvision==0.22.1",
19
+ "torch==2.8.0",
20
+ "torchvision==0.23.0",
21
21
  ]
22
22
 
23
23
  [tool.hatch.metadata]
@@ -132,7 +132,7 @@ clientapp = "$import_name.client_app:app"
132
132
  # Custom config values accessible via `context.run_config`
133
133
  [tool.flwr.app.config]
134
134
  num-server-rounds = 3
135
- fraction-fit = 0.5
135
+ fraction-train = 0.5
136
136
  local-epochs = 1
137
137
 
138
138
  # Default federation to use when running the app
@@ -61,7 +61,7 @@ train.training-arguments.save-steps = 1000
61
61
  train.training-arguments.save-total-limit = 10
62
62
  train.training-arguments.gradient-checkpointing = true
63
63
  train.training-arguments.lr-scheduler-type = "constant"
64
- strategy.fraction-fit = $fraction_fit
64
+ strategy.fraction-train = $fraction_train
65
65
  strategy.fraction-evaluate = 0.0
66
66
  num-server-rounds = 200
67
67
 
@@ -0,0 +1,61 @@
1
+ # =====================================================================
2
+ # For a full TOML configuration guide, check the Flower docs:
3
+ # https://flower.ai/docs/framework/how-to-configure-pyproject-toml.html
4
+ # =====================================================================
5
+
6
+ [build-system]
7
+ requires = ["hatchling"]
8
+ build-backend = "hatchling.build"
9
+
10
+ [project]
11
+ name = "$package_name"
12
+ version = "1.0.0"
13
+ description = ""
14
+ license = "Apache-2.0"
15
+ # Dependencies for your Flower App
16
+ dependencies = [
17
+ "flwr[simulation]>=1.22.0",
18
+ "flwr-datasets>=0.5.0",
19
+ "xgboost>=2.0.0",
20
+ ]
21
+
22
+ [tool.hatch.build.targets.wheel]
23
+ packages = ["."]
24
+
25
+ [tool.flwr.app]
26
+ publisher = "$username"
27
+
28
+ [tool.flwr.app.components]
29
+ serverapp = "$import_name.server_app:app"
30
+ clientapp = "$import_name.client_app:app"
31
+
32
+ # Custom config values accessible via `context.run_config`
33
+ [tool.flwr.app.config]
34
+ num-server-rounds = 3
35
+ fraction-train = 0.1
36
+ fraction-evaluate = 0.1
37
+ local-epochs = 1
38
+
39
+ # XGBoost parameters
40
+ params.objective = "binary:logistic"
41
+ params.eta = 0.1 # Learning rate
42
+ params.max-depth = 8
43
+ params.eval-metric = "auc"
44
+ params.nthread = 16
45
+ params.num-parallel-tree = 1
46
+ params.subsample = 1
47
+ params.tree-method = "hist"
48
+
49
+ # Default federation to use when running the app
50
+ [tool.flwr.federations]
51
+ default = "local-simulation"
52
+
53
+ # Local simulation federation with 10 virtual SuperNodes
54
+ [tool.flwr.federations.local-simulation]
55
+ options.num-supernodes = 10
56
+
57
+ # Remote federation example for use with SuperLink
58
+ [tool.flwr.federations.remote-federation]
59
+ address = "<SUPERLINK-ADDRESS>:<PORT>"
60
+ insecure = true # Remove this line to enable TLS
61
+ # root-certificates = "<PATH/TO/ca.crt>" # For TLS setup
flwr/cli/pull.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower command line interface `pull` command."""
16
+
17
+
18
+ from pathlib import Path
19
+ from typing import Annotated, Optional
20
+
21
+ import typer
22
+
23
+ from flwr.cli.config_utils import (
24
+ exit_if_no_address,
25
+ load_and_validate,
26
+ process_loaded_project_config,
27
+ validate_federation_in_project_config,
28
+ )
29
+ from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
30
+ from flwr.common.constant import FAB_CONFIG_FILE
31
+ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
32
+ PullArtifactsRequest,
33
+ PullArtifactsResponse,
34
+ )
35
+ from flwr.proto.control_pb2_grpc import ControlStub
36
+
37
+ from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
38
+
39
+
40
def pull(  # pylint: disable=R0914
    run_id: Annotated[
        int,
        typer.Option(
            "--run-id",
            help="Run ID to pull artifacts from.",
        ),
    ],
    app: Annotated[
        Path,
        typer.Argument(help="Path of the Flower App."),
    ] = Path("."),
    federation: Annotated[
        Optional[str],
        typer.Argument(help="Name of the federation."),
    ] = None,
    federation_config_overrides: Annotated[
        Optional[list[str]],
        typer.Option(
            "--federation-config",
            help=FEDERATION_CONFIG_HELP_MESSAGE,
        ),
    ] = None,
) -> None:
    """Pull artifacts from a Flower run.

    Loads the project configuration from ``app`` and resolves the target
    federation, applying any ``--federation-config`` overrides. It then asks
    the Control service for a download URL for the artifacts of ``run_id``
    and prints that URL on success. The gRPC channel is always closed before
    returning.

    Raises:
        typer.Exit: With code 1 when the response carries no download URL.
    """
    typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)

    pyproject_path = app / FAB_CONFIG_FILE if app else None
    config, errors, warnings = load_and_validate(path=pyproject_path)
    config = process_loaded_project_config(config, errors, warnings)
    federation, federation_config = validate_federation_in_project_config(
        federation, config, federation_config_overrides
    )
    exit_if_no_address(federation_config, "pull")
    channel = None
    try:
        auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
        channel = init_channel(app, federation_config, auth_plugin)
        stub = ControlStub(channel)
        # gRPC errors are translated into user-facing messages by this handler
        with flwr_cli_grpc_exc_handler():
            res: PullArtifactsResponse = stub.PullArtifacts(
                PullArtifactsRequest(run_id=run_id)
            )

        # `url` is optional in the response; treat unset and empty alike
        if not res.url:
            typer.secho(
                f"❌ A download URL for artifacts from run {run_id} couldn't be "
                "obtained.",
                fg=typer.colors.RED,
                bold=True,
            )
            raise typer.Exit(code=1)

        typer.secho(
            f"✅ Artifacts for run {run_id} can be downloaded from: {res.url}",
            fg=typer.colors.GREEN,
        )
    finally:
        if channel:
            channel.close()
flwr/cli/utils.py CHANGED
@@ -32,7 +32,9 @@ from flwr.common.constant import (
32
32
  AUTH_TYPE_JSON_KEY,
33
33
  CREDENTIALS_DIR,
34
34
  FLWR_DIR,
35
+ NO_ARTIFACT_PROVIDER_MESSAGE,
35
36
  NO_USER_AUTH_MESSAGE,
37
+ PULL_UNFINISHED_RUN_MESSAGE,
36
38
  RUN_ID_NOT_FOUND_MESSAGE,
37
39
  )
38
40
  from flwr.common.grpc import (
@@ -319,6 +321,12 @@ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
319
321
  fg=typer.colors.RED,
320
322
  bold=True,
321
323
  )
324
+ elif e.details() == NO_ARTIFACT_PROVIDER_MESSAGE: # pylint: disable=E1101
325
+ typer.secho(
326
+ "❌ The SuperLink does not support `flwr pull` command.",
327
+ fg=typer.colors.RED,
328
+ bold=True,
329
+ )
322
330
  else:
323
331
  typer.secho(
324
332
  "❌ The SuperLink cannot process this request. Please verify that "
@@ -356,4 +364,13 @@ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
356
364
  bold=True,
357
365
  )
358
366
  raise typer.Exit(code=1) from None
367
+ if e.code() == grpc.StatusCode.FAILED_PRECONDITION:
368
+ if e.details() == PULL_UNFINISHED_RUN_MESSAGE: # pylint: disable=E1101
369
+ typer.secho(
370
+ "❌ Run is not finished yet. Artifacts can only be pulled after "
371
+ "the run is finished. You can check the run status with `flwr ls`.",
372
+ fg=typer.colors.RED,
373
+ bold=True,
374
+ )
375
+ raise typer.Exit(code=1) from None
359
376
  raise
flwr/common/constant.py CHANGED
@@ -155,6 +155,8 @@ PULL_BACKOFF_CAP = 10 # Maximum backoff time for pulling objects
155
155
  # ControlServicer constants
156
156
  RUN_ID_NOT_FOUND_MESSAGE = "Run ID not found"
157
157
  NO_USER_AUTH_MESSAGE = "ControlServicer initialized without user authentication"
158
+ NO_ARTIFACT_PROVIDER_MESSAGE = "ControlServicer initialized without artifact provider"
159
+ PULL_UNFINISHED_RUN_MESSAGE = "Cannot pull artifacts for an unfinished run"
158
160
 
159
161
 
160
162
  class MessageType:
@@ -45,6 +45,7 @@ class ExitCode:
45
45
  SUPERNODE_NODE_AUTH_KEYS_INVALID = 302
46
46
 
47
47
  # SuperExec-specific exit codes (400-499)
48
+ SUPEREXEC_INVALID_PLUGIN_CONFIG = 400
48
49
 
49
50
  # Common exit codes (600-699)
50
51
  COMMON_ADDRESS_INVALID = 600
@@ -112,6 +113,9 @@ EXIT_CODE_HELP = {
112
113
  "file and try again."
113
114
  ),
114
115
  # SuperExec-specific exit codes (400-499)
116
+ ExitCode.SUPEREXEC_INVALID_PLUGIN_CONFIG: (
117
+ "The YAML configuration for the SuperExec plugin is invalid."
118
+ ),
115
119
  # Common exit codes (600-699)
116
120
  ExitCode.COMMON_ADDRESS_INVALID: (
117
121
  "Please provide a valid URL, IPv4 or IPv6 address."
flwr/proto/control_pb2.py CHANGED
@@ -18,7 +18,7 @@ from flwr.proto import recorddict_pb2 as flwr_dot_proto_dot_recorddict__pb2
18
18
  from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
19
19
 
20
20
 
21
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x14\x66lwr/proto/run.proto\"\xfa\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x34\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x18.flwr.proto.ConfigRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"2\n\x10StartRunResponse\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\"\x18\n\x16GetLoginDetailsRequest\"\x8a\x01\n\x17GetLoginDetailsResponse\x12\x11\n\tauth_type\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65vice_code\x18\x02 \x01(\t\x12!\n\x19verification_uri_complete\x18\x03 \x01(\t\x12\x12\n\nexpires_in\x18\x04 \x01(\x03\x12\x10\n\x08interval\x18\x05 \x01(\x03\"+\n\x14GetAuthTokensRequest\x12\x13\n\x0b\x64\x65vice_code\x18\x01 \x01(\t\"D\n\x15GetAuthTokensResponse\x12\x14\n\x0c\x61\x63\x63\x65ss_token\x18\x01 \x01(\t\x12\x15\n\rrefresh_token\x18\x02 \x01(\t\" \n\x0eStopRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\"\n\x0fStopRunResponse\x12\x0f\n\x07success\x18\x01 
\x01(\x08\x32\xe8\x03\n\x07\x43ontrol\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12\x44\n\x07StopRun\x12\x1a.flwr.proto.StopRunRequest\x1a\x1b.flwr.proto.StopRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x12\\\n\x0fGetLoginDetails\x12\".flwr.proto.GetLoginDetailsRequest\x1a#.flwr.proto.GetLoginDetailsResponse\"\x00\x12V\n\rGetAuthTokens\x12 .flwr.proto.GetAuthTokensRequest\x1a!.flwr.proto.GetAuthTokensResponse\"\x00\x62\x06proto3')
21
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x14\x66lwr/proto/run.proto\"\xfa\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x34\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x18.flwr.proto.ConfigRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"2\n\x10StartRunResponse\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\"\x18\n\x16GetLoginDetailsRequest\"\x8a\x01\n\x17GetLoginDetailsResponse\x12\x11\n\tauth_type\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65vice_code\x18\x02 \x01(\t\x12!\n\x19verification_uri_complete\x18\x03 \x01(\t\x12\x12\n\nexpires_in\x18\x04 \x01(\x03\x12\x10\n\x08interval\x18\x05 \x01(\x03\"+\n\x14GetAuthTokensRequest\x12\x13\n\x0b\x64\x65vice_code\x18\x01 \x01(\t\"D\n\x15GetAuthTokensResponse\x12\x14\n\x0c\x61\x63\x63\x65ss_token\x18\x01 \x01(\t\x12\x15\n\rrefresh_token\x18\x02 \x01(\t\" \n\x0eStopRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\"\n\x0fStopRunResponse\x12\x0f\n\x07success\x18\x01 
\x01(\x08\"&\n\x14PullArtifactsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"1\n\x15PullArtifactsResponse\x12\x10\n\x03url\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x06\n\x04_url2\xc0\x04\n\x07\x43ontrol\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12\x44\n\x07StopRun\x12\x1a.flwr.proto.StopRunRequest\x1a\x1b.flwr.proto.StopRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x12\\\n\x0fGetLoginDetails\x12\".flwr.proto.GetLoginDetailsRequest\x1a#.flwr.proto.GetLoginDetailsResponse\"\x00\x12V\n\rGetAuthTokens\x12 .flwr.proto.GetAuthTokensRequest\x1a!.flwr.proto.GetAuthTokensResponse\"\x00\x12V\n\rPullArtifacts\x12 .flwr.proto.PullArtifactsRequest\x1a!.flwr.proto.PullArtifactsResponse\"\x00\x62\x06proto3')
22
22
 
23
23
  _globals = globals()
24
24
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -57,6 +57,10 @@ if _descriptor._USE_C_DESCRIPTORS == False:
57
57
  _globals['_STOPRUNREQUEST']._serialized_end=1101
58
58
  _globals['_STOPRUNRESPONSE']._serialized_start=1103
59
59
  _globals['_STOPRUNRESPONSE']._serialized_end=1137
60
- _globals['_CONTROL']._serialized_start=1140
61
- _globals['_CONTROL']._serialized_end=1628
60
+ _globals['_PULLARTIFACTSREQUEST']._serialized_start=1139
61
+ _globals['_PULLARTIFACTSREQUEST']._serialized_end=1177
62
+ _globals['_PULLARTIFACTSRESPONSE']._serialized_start=1179
63
+ _globals['_PULLARTIFACTSRESPONSE']._serialized_end=1228
64
+ _globals['_CONTROL']._serialized_start=1231
65
+ _globals['_CONTROL']._serialized_end=1807
62
66
  # @@protoc_insertion_point(module_scope)
@@ -210,3 +210,27 @@ class StopRunResponse(google.protobuf.message.Message):
210
210
  ) -> None: ...
211
211
  def ClearField(self, field_name: typing_extensions.Literal["success",b"success"]) -> None: ...
212
212
  global___StopRunResponse = StopRunResponse
213
+
214
+ class PullArtifactsRequest(google.protobuf.message.Message):  # Request message of the Control.PullArtifacts RPC
215
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
216
+ RUN_ID_FIELD_NUMBER: builtins.int
217
+ run_id: builtins.int  # field 1 (uint64 in the descriptor): ID of the run whose artifacts are requested
218
+ def __init__(self,
219
+ *,
220
+ run_id: builtins.int = ...,
221
+ ) -> None: ...
222
+ def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ...
223
+ global___PullArtifactsRequest = PullArtifactsRequest
224
+
225
+ class PullArtifactsResponse(google.protobuf.message.Message):  # Response message of the Control.PullArtifacts RPC
226
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
227
+ URL_FIELD_NUMBER: builtins.int
228
+ url: typing.Text  # field 1, proto3 `optional` string (synthetic oneof `_url`); the CLI prints it as the download URL
229
+ def __init__(self,
230
+ *,
231
+ url: typing.Optional[typing.Text] = ...,
232
+ ) -> None: ...
233
+ def HasField(self, field_name: typing_extensions.Literal["_url",b"_url","url",b"url"]) -> builtins.bool: ...  # distinguishes an unset `url` from an empty one
234
+ def ClearField(self, field_name: typing_extensions.Literal["_url",b"_url","url",b"url"]) -> None: ...
235
+ def WhichOneof(self, oneof_group: typing_extensions.Literal["_url",b"_url"]) -> typing.Optional[typing_extensions.Literal["url"]]: ...
236
+ global___PullArtifactsResponse = PullArtifactsResponse