flwr-nightly 1.11.0.dev20240822__py3-none-any.whl → 1.11.1.dev20240912__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. See the package's registry page for more details.

Files changed (75)
  1. flwr/cli/app.py +0 -2
  2. flwr/cli/build.py +1 -1
  3. flwr/cli/new/new.py +41 -40
  4. flwr/cli/new/templates/app/LICENSE.tpl +202 -0
  5. flwr/cli/new/templates/app/README.baseline.md.tpl +127 -0
  6. flwr/cli/new/templates/app/README.flowertune.md.tpl +16 -6
  7. flwr/cli/new/templates/app/README.md.tpl +7 -30
  8. flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +1 -0
  9. flwr/cli/new/templates/app/code/client.baseline.py.tpl +58 -0
  10. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +19 -29
  11. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -3
  12. flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +36 -0
  13. flwr/cli/new/templates/app/code/flwr_tune/{client.py.tpl → client_app.py.tpl} +50 -40
  14. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +32 -2
  15. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -3
  16. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +95 -0
  17. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +83 -0
  18. flwr/cli/new/templates/app/code/model.baseline.py.tpl +80 -0
  19. flwr/cli/new/templates/app/code/server.baseline.py.tpl +46 -0
  20. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +18 -3
  21. flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +1 -0
  22. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -13
  23. flwr/cli/new/templates/app/code/utils.baseline.py.tpl +1 -0
  24. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +138 -0
  25. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +34 -7
  26. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +9 -1
  27. flwr/cli/run/run.py +12 -2
  28. flwr/client/__init__.py +0 -4
  29. flwr/client/app.py +3 -4
  30. flwr/client/client.py +22 -1
  31. flwr/client/client_app.py +2 -2
  32. flwr/client/grpc_rere_client/client_interceptor.py +15 -7
  33. flwr/client/numpy_client.py +22 -1
  34. flwr/client/rest_client/connection.py +1 -1
  35. flwr/client/supernode/app.py +8 -7
  36. flwr/common/address.py +43 -0
  37. flwr/common/config.py +14 -11
  38. flwr/common/constant.py +12 -1
  39. flwr/common/record/recordset.py +1 -1
  40. flwr/common/record/typeddict.py +24 -1
  41. flwr/common/telemetry.py +36 -30
  42. flwr/server/__init__.py +0 -4
  43. flwr/server/app.py +27 -22
  44. flwr/server/compat/app.py +0 -5
  45. flwr/server/driver/grpc_driver.py +3 -6
  46. flwr/server/run_serverapp.py +20 -7
  47. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +15 -2
  48. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +5 -0
  49. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +19 -8
  50. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +13 -12
  51. flwr/server/superlink/fleet/rest_rere/rest_api.py +71 -122
  52. flwr/server/superlink/fleet/vce/backend/backend.py +1 -2
  53. flwr/server/superlink/fleet/vce/backend/raybackend.py +33 -15
  54. flwr/server/superlink/fleet/vce/vce_api.py +2 -6
  55. flwr/server/superlink/state/in_memory_state.py +15 -15
  56. flwr/server/superlink/state/sqlite_state.py +10 -10
  57. flwr/server/superlink/state/state.py +8 -8
  58. flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -0
  59. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -0
  60. flwr/simulation/ray_transport/ray_actor.py +2 -2
  61. flwr/simulation/run_simulation.py +85 -25
  62. flwr/superexec/__init__.py +0 -6
  63. flwr/superexec/app.py +5 -3
  64. flwr/superexec/deployment.py +2 -2
  65. flwr/superexec/simulation.py +20 -1
  66. {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/METADATA +3 -3
  67. {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/RECORD +70 -62
  68. flwr_nightly-1.11.1.dev20240912.dist-info/entry_points.txt +10 -0
  69. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +0 -89
  70. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +0 -34
  71. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +0 -48
  72. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +0 -11
  73. flwr_nightly-1.11.0.dev20240822.dist-info/entry_points.txt +0 -10
  74. {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/LICENSE +0 -0
  75. {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/WHEEL +0 -0
@@ -0,0 +1,46 @@
1
+ """$project_name: A Flower Baseline."""
2
+
3
+ from typing import List, Tuple
4
+
5
+ from flwr.common import Context, Metrics, ndarrays_to_parameters
6
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
7
+ from flwr.server.strategy import FedAvg
8
+ from $import_name.model import Net, get_weights
9
+
10
+
11
+ # Define metric aggregation function
12
+ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
13
+ """Do weighted average of accuracy metric."""
14
+ # Multiply accuracy of each client by number of examples used
15
+ accuracies = [num_examples * float(m["accuracy"]) for num_examples, m in metrics]
16
+ examples = [num_examples for num_examples, _ in metrics]
17
+
18
+ # Aggregate and return custom metric (weighted average)
19
+ return {"accuracy": sum(accuracies) / sum(examples)}
20
+
21
+
22
+ def server_fn(context: Context):
23
+ """Construct components that set the ServerApp behaviour."""
24
+ # Read from config
25
+ num_rounds = context.run_config["num-server-rounds"]
26
+ fraction_fit = context.run_config["fraction-fit"]
27
+
28
+ # Initialize model parameters
29
+ ndarrays = get_weights(Net())
30
+ parameters = ndarrays_to_parameters(ndarrays)
31
+
32
+ # Define strategy
33
+ strategy = FedAvg(
34
+ fraction_fit=float(fraction_fit),
35
+ fraction_evaluate=1.0,
36
+ min_available_clients=2,
37
+ initial_parameters=parameters,
38
+ evaluate_metrics_aggregation_fn=weighted_average,
39
+ )
40
+ config = ServerConfig(num_rounds=int(num_rounds))
41
+
42
+ return ServerAppComponents(strategy=strategy, config=config)
43
+
44
+
45
+ # Create ServerApp
46
+ app = ServerApp(server_fn=server_fn)
@@ -1,18 +1,33 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context
4
- from flwr.server.strategy import FedAvg
3
+ from flwr.common import Context, ndarrays_to_parameters
5
4
  from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
+ from flwr.server.strategy import FedAvg
6
+ from transformers import AutoModelForSequenceClassification
7
+
8
+ from $import_name.task import get_weights
6
9
 
7
10
 
8
11
  def server_fn(context: Context):
9
12
  # Read from config
10
13
  num_rounds = context.run_config["num-server-rounds"]
14
+ fraction_fit = context.run_config["fraction-fit"]
15
+
16
+ # Initialize global model
17
+ model_name = context.run_config["model-name"]
18
+ num_labels = context.run_config["num-labels"]
19
+ net = AutoModelForSequenceClassification.from_pretrained(
20
+ model_name, num_labels=num_labels
21
+ )
22
+
23
+ weights = get_weights(net)
24
+ initial_parameters = ndarrays_to_parameters(weights)
11
25
 
12
26
  # Define strategy
13
27
  strategy = FedAvg(
14
- fraction_fit=1.0,
28
+ fraction_fit=fraction_fit,
15
29
  fraction_evaluate=1.0,
30
+ initial_parameters=initial_parameters,
16
31
  )
17
32
  config = ServerConfig(num_rounds=num_rounds)
18
33
 
@@ -0,0 +1 @@
1
+ """$project_name: A Flower Baseline."""
@@ -4,24 +4,25 @@ import warnings
4
4
  from collections import OrderedDict
5
5
 
6
6
  import torch
7
+ import transformers
8
+ from datasets.utils.logging import disable_progress_bar
7
9
  from evaluate import load as load_metric
10
+ from flwr_datasets import FederatedDataset
11
+ from flwr_datasets.partitioner import IidPartitioner
8
12
  from torch.optim import AdamW
9
13
  from torch.utils.data import DataLoader
10
14
  from transformers import AutoTokenizer, DataCollatorWithPadding
11
15
 
12
- from flwr_datasets import FederatedDataset
13
- from flwr_datasets.partitioner import IidPartitioner
14
-
15
-
16
16
  warnings.filterwarnings("ignore", category=UserWarning)
17
- DEVICE = torch.device("cpu")
18
- CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint
17
+ warnings.filterwarnings("ignore", category=FutureWarning)
18
+ disable_progress_bar()
19
+ transformers.logging.set_verbosity_error()
19
20
 
20
21
 
21
22
  fds = None # Cache FederatedDataset
22
23
 
23
24
 
24
- def load_data(partition_id: int, num_partitions: int):
25
+ def load_data(partition_id: int, num_partitions: int, model_name: str):
25
26
  """Load IMDB data (training and eval)"""
26
27
  # Only initialize `FederatedDataset` once
27
28
  global fds
@@ -35,10 +36,12 @@ def load_data(partition_id: int, num_partitions: int):
35
36
  # Divide data: 80% train, 20% test
36
37
  partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
37
38
 
38
- tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
39
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
39
40
 
40
41
  def tokenize_function(examples):
41
- return tokenizer(examples["text"], truncation=True)
42
+ return tokenizer(
43
+ examples["text"], truncation=True, add_special_tokens=True, max_length=512
44
+ )
42
45
 
43
46
  partition_train_test = partition_train_test.map(tokenize_function, batched=True)
44
47
  partition_train_test = partition_train_test.remove_columns("text")
@@ -59,12 +62,12 @@ def load_data(partition_id: int, num_partitions: int):
59
62
  return trainloader, testloader
60
63
 
61
64
 
62
- def train(net, trainloader, epochs):
65
+ def train(net, trainloader, epochs, device):
63
66
  optimizer = AdamW(net.parameters(), lr=5e-5)
64
67
  net.train()
65
68
  for _ in range(epochs):
66
69
  for batch in trainloader:
67
- batch = {k: v.to(DEVICE) for k, v in batch.items()}
70
+ batch = {k: v.to(device) for k, v in batch.items()}
68
71
  outputs = net(**batch)
69
72
  loss = outputs.loss
70
73
  loss.backward()
@@ -72,12 +75,12 @@ def train(net, trainloader, epochs):
72
75
  optimizer.zero_grad()
73
76
 
74
77
 
75
- def test(net, testloader):
78
+ def test(net, testloader, device):
76
79
  metric = load_metric("accuracy")
77
80
  loss = 0
78
81
  net.eval()
79
82
  for batch in testloader:
80
- batch = {k: v.to(DEVICE) for k, v in batch.items()}
83
+ batch = {k: v.to(device) for k, v in batch.items()}
81
84
  with torch.no_grad():
82
85
  outputs = net(**batch)
83
86
  logits = outputs.logits
@@ -0,0 +1 @@
1
+ """$project_name: A Flower Baseline."""
@@ -0,0 +1,138 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "$package_name"
7
+ version = "1.0.0"
8
+ description = ""
9
+ license = "Apache-2.0"
10
+ dependencies = [
11
+ "flwr[simulation]>=1.11.0",
12
+ "flwr-datasets[vision]>=0.3.0",
13
+ "torch==2.2.1",
14
+ "torchvision==0.17.1",
15
+ ]
16
+
17
+ [tool.hatch.metadata]
18
+ allow-direct-references = true
19
+
20
+ [project.optional-dependencies]
21
+ dev = [
22
+ "isort==5.13.2",
23
+ "black==24.2.0",
24
+ "docformatter==1.7.5",
25
+ "mypy==1.8.0",
26
+ "pylint==3.2.6",
27
+ "flake8==5.0.4",
28
+ "pytest==6.2.4",
29
+ "pytest-watch==4.2.0",
30
+ "ruff==0.1.9",
31
+ "types-requests==2.31.0.20240125",
32
+ ]
33
+
34
+ [tool.isort]
35
+ profile = "black"
36
+ known_first_party = ["flwr"]
37
+
38
+ [tool.black]
39
+ line-length = 88
40
+ target-version = ["py38", "py39", "py310", "py311"]
41
+
42
+ [tool.pytest.ini_options]
43
+ minversion = "6.2"
44
+ addopts = "-qq"
45
+ testpaths = [
46
+ "flwr_baselines",
47
+ ]
48
+
49
+ [tool.mypy]
50
+ ignore_missing_imports = true
51
+ strict = false
52
+ plugins = "numpy.typing.mypy_plugin"
53
+
54
+ [tool.pylint."MESSAGES CONTROL"]
55
+ disable = "duplicate-code,too-few-public-methods,useless-import-alias"
56
+ good-names = "i,j,k,_,x,y,X,Y,K,N"
57
+ max-args = 10
58
+ max-attributes = 15
59
+ max-locals = 36
60
+ max-branches = 20
61
+ max-statements = 55
62
+
63
+ [tool.pylint.typecheck]
64
+ generated-members = "numpy.*, torch.*, tensorflow.*"
65
+
66
+ [[tool.mypy.overrides]]
67
+ module = [
68
+ "importlib.metadata.*",
69
+ "importlib_metadata.*",
70
+ ]
71
+ follow_imports = "skip"
72
+ follow_imports_for_stubs = true
73
+ disallow_untyped_calls = false
74
+
75
+ [[tool.mypy.overrides]]
76
+ module = "torch.*"
77
+ follow_imports = "skip"
78
+ follow_imports_for_stubs = true
79
+
80
+ [tool.docformatter]
81
+ wrap-summaries = 88
82
+ wrap-descriptions = 88
83
+
84
+ [tool.ruff]
85
+ target-version = "py38"
86
+ line-length = 88
87
+ select = ["D", "E", "F", "W", "B", "ISC", "C4"]
88
+ fixable = ["D", "E", "F", "W", "B", "ISC", "C4"]
89
+ ignore = ["B024", "B027"]
90
+ exclude = [
91
+ ".bzr",
92
+ ".direnv",
93
+ ".eggs",
94
+ ".git",
95
+ ".hg",
96
+ ".mypy_cache",
97
+ ".nox",
98
+ ".pants.d",
99
+ ".pytype",
100
+ ".ruff_cache",
101
+ ".svn",
102
+ ".tox",
103
+ ".venv",
104
+ "__pypackages__",
105
+ "_build",
106
+ "buck-out",
107
+ "build",
108
+ "dist",
109
+ "node_modules",
110
+ "venv",
111
+ "proto",
112
+ ]
113
+
114
+ [tool.ruff.pydocstyle]
115
+ convention = "numpy"
116
+
117
+ [tool.hatch.build.targets.wheel]
118
+ packages = ["."]
119
+
120
+ [tool.flwr.app]
121
+ publisher = "$username"
122
+
123
+ [tool.flwr.app.components]
124
+ serverapp = "$import_name.server_app:app"
125
+ clientapp = "$import_name.client_app:app"
126
+
127
+ [tool.flwr.app.config]
128
+ num-server-rounds = 3
129
+ fraction-fit = 0.5
130
+ local-epochs = 1
131
+
132
+ [tool.flwr.federations]
133
+ default = "local-simulation"
134
+
135
+ [tool.flwr.federations.local-simulation]
136
+ options.num-supernodes = 10
137
+ options.backend.client-resources.num-cpus = 2
138
+ options.backend.client-resources.num-gpus = 0.0
@@ -8,15 +8,15 @@ version = "1.0.0"
8
8
  description = ""
9
9
  license = "Apache-2.0"
10
10
  dependencies = [
11
- "flwr[simulation]>=1.9.0,<2.0",
12
- "flwr-datasets>=0.1.0,<1.0.0",
13
- "hydra-core==1.3.2",
11
+ "flwr[simulation]>=1.10.0",
12
+ "flwr-datasets>=0.3.0",
14
13
  "trl==0.8.1",
15
14
  "bitsandbytes==0.43.0",
16
15
  "scipy==1.13.0",
17
16
  "peft==0.6.2",
18
17
  "transformers==4.39.3",
19
18
  "sentencepiece==0.2.0",
19
+ "omegaconf==2.3.0",
20
20
  ]
21
21
 
22
22
  [tool.hatch.build.targets.wheel]
@@ -26,14 +26,41 @@ packages = ["."]
26
26
  publisher = "$username"
27
27
 
28
28
  [tool.flwr.app.components]
29
- serverapp = "$import_name.app:server"
30
- clientapp = "$import_name.app:client"
29
+ serverapp = "$import_name.server_app:app"
30
+ clientapp = "$import_name.client_app:app"
31
31
 
32
32
  [tool.flwr.app.config]
33
- num-server-rounds = 3
33
+ model.name = "mistralai/Mistral-7B-v0.3"
34
+ model.quantization = 4
35
+ model.gradient-checkpointing = true
36
+ model.lora.peft-lora-r = 32
37
+ model.lora.peft-lora-alpha = 64
38
+ train.save-every-round = 5
39
+ train.learning-rate-max = 5e-5
40
+ train.learning-rate-min = 1e-6
41
+ train.seq-length = 512
42
+ train.training-arguments.output-dir = ""
43
+ train.training-arguments.learning-rate = ""
44
+ train.training-arguments.per-device-train-batch-size = 16
45
+ train.training-arguments.gradient-accumulation-steps = 1
46
+ train.training-arguments.logging-steps = 10
47
+ train.training-arguments.num-train-epochs = 3
48
+ train.training-arguments.max-steps = 10
49
+ train.training-arguments.save-steps = 1000
50
+ train.training-arguments.save-total-limit = 10
51
+ train.training-arguments.gradient-checkpointing = true
52
+ train.training-arguments.lr-scheduler-type = "constant"
53
+ strategy.fraction-fit = $fraction_fit
54
+ strategy.fraction-evaluate = 0.0
55
+ num-server-rounds = 200
56
+
57
+ [tool.flwr.app.config.static]
58
+ dataset.name = "$dataset_name"
34
59
 
35
60
  [tool.flwr.federations]
36
61
  default = "local-simulation"
37
62
 
38
63
  [tool.flwr.federations.local-simulation]
39
- options.num-supernodes = 10
64
+ options.num-supernodes = $num_clients
65
+ options.backend.client-resources.num-cpus = 6
66
+ options.backend.client-resources.num-gpus = 1.0
@@ -8,7 +8,7 @@ version = "1.0.0"
8
8
  description = ""
9
9
  license = "Apache-2.0"
10
10
  dependencies = [
11
- "flwr[simulation]>=1.10.0",
11
+ "flwr[simulation]>=1.11.0",
12
12
  "flwr-datasets>=0.3.0",
13
13
  "torch==2.2.1",
14
14
  "transformers>=4.30.0,<5.0",
@@ -29,10 +29,18 @@ clientapp = "$import_name.client_app:app"
29
29
 
30
30
  [tool.flwr.app.config]
31
31
  num-server-rounds = 3
32
+ fraction-fit = 0.5
32
33
  local-epochs = 1
34
+ model-name = "prajjwal1/bert-tiny" # Set a larger model if you have access to more GPU resources
35
+ num-labels = 2
33
36
 
34
37
  [tool.flwr.federations]
35
38
  default = "localhost"
36
39
 
37
40
  [tool.flwr.federations.localhost]
38
41
  options.num-supernodes = 10
42
+
43
+ [tool.flwr.federations.localhost-gpu]
44
+ options.num-supernodes = 10
45
+ options.backend.client-resources.num-cpus = 4 # each ClientApp assumes to use 4CPUs
46
+ options.backend.client-resources.num-gpus = 0.25 # at most 4 ClientApps will run in a given GPU
flwr/cli/run/run.py CHANGED
@@ -15,6 +15,7 @@
15
15
  """Flower command line interface `run` command."""
16
16
 
17
17
  import hashlib
18
+ import json
18
19
  import subprocess
19
20
  import sys
20
21
  from logging import DEBUG
@@ -123,14 +124,14 @@ def run(
123
124
 
124
125
 
125
126
  def _run_with_superexec(
126
- app: Optional[Path],
127
+ app: Path,
127
128
  federation_config: Dict[str, Any],
128
129
  config_overrides: Optional[List[str]],
129
130
  ) -> None:
130
131
 
131
132
  insecure_str = federation_config.get("insecure")
132
133
  if root_certificates := federation_config.get("root-certificates"):
133
- root_certificates_bytes = Path(root_certificates).read_bytes()
134
+ root_certificates_bytes = (app / root_certificates).read_bytes()
134
135
  if insecure := bool(insecure_str):
135
136
  typer.secho(
136
137
  "❌ `root_certificates` were provided but the `insecure` parameter"
@@ -192,6 +193,8 @@ def _run_without_superexec(
192
193
  ) -> None:
193
194
  try:
194
195
  num_supernodes = federation_config["options"]["num-supernodes"]
196
+ verbose: Optional[bool] = federation_config["options"].get("verbose")
197
+ backend_cfg = federation_config["options"].get("backend", {})
195
198
  except KeyError as err:
196
199
  typer.secho(
197
200
  "❌ The project's `pyproject.toml` needs to declare the number of"
@@ -212,6 +215,13 @@ def _run_without_superexec(
212
215
  f"{num_supernodes}",
213
216
  ]
214
217
 
218
+ if backend_cfg:
219
+ # Stringify as JSON
220
+ command.extend(["--backend-config", json.dumps(backend_cfg)])
221
+
222
+ if verbose:
223
+ command.extend(["--verbose"])
224
+
215
225
  if config_overrides:
216
226
  command.extend(["--run-config", f"{' '.join(config_overrides)}"])
217
227
 
flwr/client/__init__.py CHANGED
@@ -20,8 +20,6 @@ from .app import start_numpy_client as start_numpy_client
20
20
  from .client import Client as Client
21
21
  from .client_app import ClientApp as ClientApp
22
22
  from .numpy_client import NumPyClient as NumPyClient
23
- from .supernode import run_client_app as run_client_app
24
- from .supernode import run_supernode as run_supernode
25
23
  from .typing import ClientFn as ClientFn
26
24
  from .typing import ClientFnExt as ClientFnExt
27
25
 
@@ -32,8 +30,6 @@ __all__ = [
32
30
  "ClientFnExt",
33
31
  "NumPyClient",
34
32
  "mod",
35
- "run_client_app",
36
- "run_supernode",
37
33
  "start_client",
38
34
  "start_numpy_client",
39
35
  ]
flwr/client/app.py CHANGED
@@ -35,6 +35,7 @@ from flwr.client.typing import ClientFnExt
35
35
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
36
36
  from flwr.common.address import parse_address
37
37
  from flwr.common.constant import (
38
+ CLIENTAPPIO_API_DEFAULT_ADDRESS,
38
39
  MISSING_EXTRA_REST,
39
40
  RUN_ID_NUM_BYTES,
40
41
  TRANSPORT_TYPE_GRPC_ADAPTER,
@@ -60,8 +61,6 @@ from .message_handler.message_handler import handle_control_message
60
61
  from .node_state import NodeState
61
62
  from .numpy_client import NumPyClient
62
63
 
63
- ADDRESS_CLIENTAPPIO_API_GRPC_RERE = "0.0.0.0:9094"
64
-
65
64
  ISOLATION_MODE_SUBPROCESS = "subprocess"
66
65
  ISOLATION_MODE_PROCESS = "process"
67
66
 
@@ -211,7 +210,7 @@ def start_client_internal(
211
210
  max_wait_time: Optional[float] = None,
212
211
  flwr_path: Optional[Path] = None,
213
212
  isolation: Optional[str] = None,
214
- supernode_address: Optional[str] = ADDRESS_CLIENTAPPIO_API_GRPC_RERE,
213
+ supernode_address: Optional[str] = CLIENTAPPIO_API_DEFAULT_ADDRESS,
215
214
  ) -> None:
216
215
  """Start a Flower client node which connects to a Flower server.
217
216
 
@@ -266,7 +265,7 @@ def start_client_internal(
266
265
  by the SueprNode and communicates using gRPC at the address
267
266
  `supernode_address`. If `process`, the `ClientApp` runs in a separate isolated
268
267
  process and communicates using gRPC at the address `supernode_address`.
269
- supernode_address : Optional[str] (default: `ADDRESS_CLIENTAPPIO_API_GRPC_RERE`)
268
+ supernode_address : Optional[str] (default: `CLIENTAPPIO_API_DEFAULT_ADDRESS`)
270
269
  The SuperNode gRPC server address.
271
270
  """
272
271
  if insecure is None:
flwr/client/client.py CHANGED
@@ -33,12 +33,13 @@ from flwr.common import (
33
33
  Parameters,
34
34
  Status,
35
35
  )
36
+ from flwr.common.logger import warn_deprecated_feature_with_example
36
37
 
37
38
 
38
39
  class Client(ABC):
39
40
  """Abstract base class for Flower clients."""
40
41
 
41
- context: Context
42
+ _context: Context
42
43
 
43
44
  def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
44
45
  """Return set of client's properties.
@@ -141,6 +142,26 @@ class Client(ABC):
141
142
  metrics={},
142
143
  )
143
144
 
145
+ @property
146
+ def context(self) -> Context:
147
+ """Getter for `Context` client attribute."""
148
+ warn_deprecated_feature_with_example(
149
+ "Accessing the context via the client's attribute is deprecated.",
150
+ example_message="Instead, pass it to the client's "
151
+ "constructor in your `client_fn()` which already "
152
+ "receives a context object.",
153
+ code_example="def client_fn(context: Context) -> Client:\n\n"
154
+ "\t\t# Your existing client_fn\n\n"
155
+ "\t\t# Pass `context` to the constructor\n"
156
+ "\t\treturn FlowerClient(context).to_client()",
157
+ )
158
+ return self._context
159
+
160
+ @context.setter
161
+ def context(self, context: Context) -> None:
162
+ """Setter for `Context` client attribute."""
163
+ self._context = context
164
+
144
165
  def get_context(self) -> Context:
145
166
  """Get the run context from this client."""
146
167
  return self.context
flwr/client/client_app.py CHANGED
@@ -41,11 +41,11 @@ def _alert_erroneous_client_fn() -> None:
41
41
 
42
42
  def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt:
43
43
  client_fn_args = inspect.signature(client_fn).parameters
44
- first_arg = list(client_fn_args.keys())[0]
45
44
 
46
45
  if len(client_fn_args) != 1:
47
46
  _alert_erroneous_client_fn()
48
47
 
48
+ first_arg = list(client_fn_args.keys())[0]
49
49
  first_arg_type = client_fn_args[first_arg].annotation
50
50
 
51
51
  if first_arg_type is str or first_arg == "cid":
@@ -263,7 +263,7 @@ def _registration_error(fn_name: str) -> ValueError:
263
263
  >>> class FlowerClient(NumPyClient):
264
264
  >>> # ...
265
265
  >>>
266
- >>> def client_fn(cid) -> Client:
266
+ >>> def client_fn(context: Context):
267
267
  >>> return FlowerClient().to_client()
268
268
  >>>
269
269
  >>> app = ClientApp(
@@ -17,11 +17,13 @@
17
17
 
18
18
  import base64
19
19
  import collections
20
+ from logging import WARNING
20
21
  from typing import Any, Callable, Optional, Sequence, Tuple, Union
21
22
 
22
23
  import grpc
23
24
  from cryptography.hazmat.primitives.asymmetric import ec
24
25
 
26
+ from flwr.common.logger import log
25
27
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
26
28
  bytes_to_public_key,
27
29
  compute_hmac,
@@ -128,13 +130,12 @@ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type:
128
130
  if self.shared_secret is None:
129
131
  raise RuntimeError("Failure to compute hmac")
130
132
 
133
+ message_bytes = request.SerializeToString(deterministic=True)
131
134
  metadata.append(
132
135
  (
133
136
  _AUTH_TOKEN_HEADER,
134
137
  base64.urlsafe_b64encode(
135
- compute_hmac(
136
- self.shared_secret, request.SerializeToString(True)
137
- )
138
+ compute_hmac(self.shared_secret, message_bytes)
138
139
  ),
139
140
  )
140
141
  )
@@ -151,8 +152,15 @@ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type:
151
152
  server_public_key_bytes = base64.urlsafe_b64decode(
152
153
  _get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
153
154
  )
154
- self.server_public_key = bytes_to_public_key(server_public_key_bytes)
155
- self.shared_secret = generate_shared_key(
156
- self.private_key, self.server_public_key
157
- )
155
+
156
+ if server_public_key_bytes != b"":
157
+ self.server_public_key = bytes_to_public_key(server_public_key_bytes)
158
+ else:
159
+ log(WARNING, "Can't get server public key, SuperLink may be offline")
160
+
161
+ if self.server_public_key is not None:
162
+ self.shared_secret = generate_shared_key(
163
+ self.private_key, self.server_public_key
164
+ )
165
+
158
166
  return response
@@ -27,6 +27,7 @@ from flwr.common import (
27
27
  ndarrays_to_parameters,
28
28
  parameters_to_ndarrays,
29
29
  )
30
+ from flwr.common.logger import warn_deprecated_feature_with_example
30
31
  from flwr.common.typing import (
31
32
  Code,
32
33
  EvaluateIns,
@@ -70,7 +71,7 @@ Example
70
71
  class NumPyClient(ABC):
71
72
  """Abstract base class for Flower clients using NumPy."""
72
73
 
73
- context: Context
74
+ _context: Context
74
75
 
75
76
  def get_properties(self, config: Config) -> Dict[str, Scalar]:
76
77
  """Return a client's set of properties.
@@ -174,6 +175,26 @@ class NumPyClient(ABC):
174
175
  _ = (self, parameters, config)
175
176
  return 0.0, 0, {}
176
177
 
178
+ @property
179
+ def context(self) -> Context:
180
+ """Getter for `Context` client attribute."""
181
+ warn_deprecated_feature_with_example(
182
+ "Accessing the context via the client's attribute is deprecated.",
183
+ example_message="Instead, pass it to the client's "
184
+ "constructor in your `client_fn()` which already "
185
+ "receives a context object.",
186
+ code_example="def client_fn(context: Context) -> Client:\n\n"
187
+ "\t\t# Your existing client_fn\n\n"
188
+ "\t\t# Pass `context` to the constructor\n"
189
+ "\t\treturn FlowerClient(context).to_client()",
190
+ )
191
+ return self._context
192
+
193
+ @context.setter
194
+ def context(self, context: Context) -> None:
195
+ """Setter for `Context` client attribute."""
196
+ self._context = context
197
+
177
198
  def get_context(self) -> Context:
178
199
  """Get the run context from this client."""
179
200
  return self.context
@@ -275,7 +275,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
275
275
  req = DeleteNodeRequest(node=node)
276
276
 
277
277
  # Send the request
278
- res = _request(req, DeleteNodeResponse, PATH_CREATE_NODE)
278
+ res = _request(req, DeleteNodeResponse, PATH_DELETE_NODE)
279
279
  if res is None:
280
280
  return
281
281