flwr-nightly 1.12.0.dev20241007__py3-none-any.whl → 1.12.0.dev20241010__py3-none-any.whl

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.

Potentially problematic release: this version of flwr-nightly might be problematic.

Files changed (49):
  1. flwr/cli/build.py +60 -29
  2. flwr/cli/config_utils.py +10 -0
  3. flwr/cli/install.py +60 -20
  4. flwr/cli/new/new.py +2 -0
  5. flwr/cli/new/templates/app/code/client.jax.py.tpl +11 -17
  6. flwr/cli/new/templates/app/code/client.mlx.py.tpl +16 -36
  7. flwr/cli/new/templates/app/code/client.numpy.py.tpl +4 -5
  8. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +8 -11
  9. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +14 -48
  10. flwr/cli/new/templates/app/code/server.jax.py.tpl +9 -3
  11. flwr/cli/new/templates/app/code/server.mlx.py.tpl +13 -2
  12. flwr/cli/new/templates/app/code/server.numpy.py.tpl +7 -2
  13. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
  14. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +13 -1
  15. flwr/cli/new/templates/app/code/task.jax.py.tpl +2 -2
  16. flwr/cli/new/templates/app/code/task.mlx.py.tpl +1 -1
  17. flwr/cli/new/templates/app/code/task.numpy.py.tpl +7 -0
  18. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +3 -3
  19. flwr/cli/new/templates/app/code/task.sklearn.py.tpl +67 -0
  20. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  21. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +3 -2
  22. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -0
  23. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +2 -0
  24. flwr/cli/run/run.py +5 -5
  25. flwr/client/app.py +13 -3
  26. flwr/client/clientapp/app.py +5 -2
  27. flwr/client/clientapp/utils.py +11 -5
  28. flwr/client/grpc_rere_client/connection.py +3 -0
  29. flwr/common/config.py +18 -5
  30. flwr/common/constant.py +3 -0
  31. flwr/common/message.py +5 -0
  32. flwr/common/recordset_compat.py +10 -0
  33. flwr/common/retry_invoker.py +15 -0
  34. flwr/server/client_manager.py +2 -0
  35. flwr/server/compat/driver_client_proxy.py +15 -29
  36. flwr/server/driver/inmemory_driver.py +6 -2
  37. flwr/server/run_serverapp.py +11 -13
  38. flwr/server/superlink/driver/driver_servicer.py +1 -1
  39. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
  40. flwr/server/superlink/fleet/vce/vce_api.py +1 -1
  41. flwr/server/superlink/state/in_memory_state.py +26 -8
  42. flwr/server/superlink/state/sqlite_state.py +46 -11
  43. flwr/server/superlink/state/state.py +1 -7
  44. flwr/server/superlink/state/utils.py +0 -10
  45. {flwr_nightly-1.12.0.dev20241007.dist-info → flwr_nightly-1.12.0.dev20241010.dist-info}/METADATA +1 -1
  46. {flwr_nightly-1.12.0.dev20241007.dist-info → flwr_nightly-1.12.0.dev20241010.dist-info}/RECORD +49 -47
  47. {flwr_nightly-1.12.0.dev20241007.dist-info → flwr_nightly-1.12.0.dev20241010.dist-info}/LICENSE +0 -0
  48. {flwr_nightly-1.12.0.dev20241007.dist-info → flwr_nightly-1.12.0.dev20241010.dist-info}/WHEEL +0 -0
  49. {flwr_nightly-1.12.0.dev20241007.dist-info → flwr_nightly-1.12.0.dev20241010.dist-info}/entry_points.txt +0 -0
@@ -1,16 +1,21 @@
 """$project_name: A Flower / $framework_str app."""
 
-from flwr.common import Context
+from flwr.common import Context, ndarrays_to_parameters
 from flwr.server import ServerApp, ServerAppComponents, ServerConfig
 from flwr.server.strategy import FedAvg
+from $import_name.task import get_dummy_model
 
 
 def server_fn(context: Context):
     # Read from config
     num_rounds = context.run_config["num-server-rounds"]
 
+    # Initial model
+    model = get_dummy_model()
+    dummy_parameters = ndarrays_to_parameters([model])
+
     # Define strategy
-    strategy = FedAvg()
+    strategy = FedAvg(initial_parameters=dummy_parameters)
     config = ServerConfig(num_rounds=num_rounds)
 
     return ServerAppComponents(strategy=strategy, config=config)
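
The template above now seeds FedAvg with explicit initial parameters instead of leaving initialization to a randomly sampled client. A minimal sketch of the conversion helpers involved (values are illustrative, not taken from the template):

    import numpy as np
    from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays

    # ndarrays_to_parameters serializes a list of NumPy arrays into the
    # Parameters object that strategies accept via `initial_parameters`;
    # parameters_to_ndarrays reverses the conversion.
    params = ndarrays_to_parameters([np.ones((1, 1))])
    arrays = parameters_to_ndarrays(params)
    print(arrays[0])  # [[1.]]
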
@@ -3,7 +3,6 @@
 from flwr.common import Context, ndarrays_to_parameters
 from flwr.server import ServerApp, ServerAppComponents, ServerConfig
 from flwr.server.strategy import FedAvg
-
 from $import_name.task import Net, get_weights
 
 
@@ -27,5 +26,6 @@ def server_fn(context: Context):
 
     return ServerAppComponents(strategy=strategy, config=config)
 
+
 # Create ServerApp
 app = ServerApp(server_fn=server_fn)
@@ -1,19 +1,31 @@
 """$project_name: A Flower / $framework_str app."""
 
-from flwr.common import Context
+from flwr.common import Context, ndarrays_to_parameters
 from flwr.server import ServerApp, ServerAppComponents, ServerConfig
 from flwr.server.strategy import FedAvg
+from $import_name.task import get_model, get_model_params, set_initial_params
 
 
 def server_fn(context: Context):
     # Read from config
     num_rounds = context.run_config["num-server-rounds"]
 
+    # Create LogisticRegression Model
+    penalty = context.run_config["penalty"]
+    local_epochs = context.run_config["local-epochs"]
+    model = get_model(penalty, local_epochs)
+
+    # Setting initial parameters, akin to model.compile for keras models
+    set_initial_params(model)
+
+    initial_parameters = ndarrays_to_parameters(get_model_params(model))
+
     # Define strategy
     strategy = FedAvg(
         fraction_fit=1.0,
         fraction_evaluate=1.0,
         min_available_clients=2,
+        initial_parameters=initial_parameters,
     )
     config = ServerConfig(num_rounds=num_rounds)
 
@@ -2,9 +2,9 @@
 
 import jax
 import jax.numpy as jnp
+import numpy as np
 from sklearn.datasets import make_regression
 from sklearn.model_selection import train_test_split
-import numpy as np
 
 key = jax.random.PRNGKey(0)
 
@@ -33,7 +33,7 @@ def train(params, grad_fn, X, y):
     num_examples = X.shape[0]
     for epochs in range(50):
         grads = grad_fn(params, X, y)
-        params = jax.tree_map(lambda p, g: p - 0.05 * g, params, grads)
+        params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads)
         loss = loss_fn(params, X, y)
     return params, loss, num_examples
 
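
The jax.tree_map call is updated to jax.tree.map, matching the jax/jaxlib 0.4.30 pin further down; jax.tree_map is deprecated in recent JAX releases in favor of the jax.tree namespace. A small illustrative sketch (parameter values are made up):

    import jax
    import jax.numpy as jnp

    # jax.tree.map applies a function leaf-wise across matching pytrees,
    # here a plain gradient-descent update over a dict of parameters.
    params = {"w": jnp.ones(3), "b": jnp.zeros(1)}
    grads = {"w": jnp.full(3, 0.1), "b": jnp.full(1, 0.2)}
    params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads)
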
@@ -3,10 +3,10 @@
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from datasets.utils.logging import disable_progress_bar
 from flwr_datasets import FederatedDataset
 from flwr_datasets.partitioner import IidPartitioner
 
+from datasets.utils.logging import disable_progress_bar
 
 disable_progress_bar()
 
@@ -0,0 +1,7 @@
+"""$project_name: A Flower / $framework_str app."""
+
+import numpy as np
+
+
+def get_dummy_model():
+    return np.ones((1, 1))
@@ -5,10 +5,10 @@ from collections import OrderedDict
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.data import DataLoader
-from torchvision.transforms import Compose, Normalize, ToTensor
 from flwr_datasets import FederatedDataset
 from flwr_datasets.partitioner import IidPartitioner
+from torch.utils.data import DataLoader
+from torchvision.transforms import Compose, Normalize, ToTensor
 
 
 class Net(nn.Module):
@@ -67,7 +67,7 @@ def train(net, trainloader, epochs, device):
     """Train the model on the training set."""
     net.to(device)  # move model to GPU if available
     criterion = torch.nn.CrossEntropyLoss().to(device)
-    optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
+    optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
     net.train()
     running_loss = 0.0
     for _ in range(epochs):
@@ -0,0 +1,67 @@
+"""$project_name: A Flower / $framework_str app."""
+
+import numpy as np
+from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner
+from sklearn.linear_model import LogisticRegression
+
+fds = None  # Cache FederatedDataset
+
+
+def load_data(partition_id: int, num_partitions: int):
+    """Load partition MNIST data."""
+    # Only initialize `FederatedDataset` once
+    global fds
+    if fds is None:
+        partitioner = IidPartitioner(num_partitions=num_partitions)
+        fds = FederatedDataset(
+            dataset="mnist",
+            partitioners={"train": partitioner},
+        )
+
+    dataset = fds.load_partition(partition_id, "train").with_format("numpy")
+
+    X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
+
+    # Split the on edge data: 80% train, 20% test
+    X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :]
+    y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :]
+
+    return X_train, X_test, y_train, y_test
+
+
+def get_model(penalty: str, local_epochs: int):
+
+    return LogisticRegression(
+        penalty=penalty,
+        max_iter=local_epochs,
+        warm_start=True,
+    )
+
+
+def get_model_params(model):
+    if model.fit_intercept:
+        params = [
+            model.coef_,
+            model.intercept_,
+        ]
+    else:
+        params = [model.coef_]
+    return params
+
+
+def set_model_params(model, params):
+    model.coef_ = params[0]
+    if model.fit_intercept:
+        model.intercept_ = params[1]
+    return model
+
+
+def set_initial_params(model):
+    n_classes = 10  # MNIST has 10 classes
+    n_features = 784  # Number of features in dataset
+    model.classes_ = np.array([i for i in range(10)])
+
+    model.coef_ = np.zeros((n_classes, n_features))
+    if model.fit_intercept:
+        model.intercept_ = np.zeros((n_classes,))
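
A note on set_initial_params: scikit-learn's LogisticRegression only creates coef_ and intercept_ after the first fit call, so the helper zero-fills them up front, letting the server extract parameters of the right shape before any client has trained. A rough usage sketch of the helpers above (the import path is hypothetical; generated apps use $import_name.task):

    from myapp.task import get_model, get_model_params, set_initial_params

    model = get_model(penalty="l2", local_epochs=1)
    set_initial_params(model)           # zero-filled weights for 10 classes x 784 features
    ndarrays = get_model_params(model)  # [coef_, intercept_] as NumPy arrays
    print([a.shape for a in ndarrays])  # [(10, 784), (10,)]
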
@@ -14,7 +14,7 @@ dependencies = [
     "bitsandbytes==0.43.0",
     "scipy==1.13.0",
     "peft==0.6.2",
-    "transformers==4.39.3",
+    "transformers==4.43.1",
     "sentencepiece==0.2.0",
     "omegaconf==2.3.0",
     "hf_transfer==0.1.8",
@@ -9,8 +9,8 @@ description = ""
 license = "Apache-2.0"
 dependencies = [
     "flwr[simulation]>=1.10.0",
-    "jax==0.4.13",
-    "jaxlib==0.4.13",
+    "jax==0.4.30",
+    "jaxlib==0.4.30",
     "scikit-learn==1.3.2",
 ]
 
@@ -26,6 +26,7 @@ clientapp = "$import_name.client_app:app"
 
 [tool.flwr.app.config]
 num-server-rounds = 3
+input-dim = 3
 
 [tool.flwr.federations]
 default = "local-simulation"
@@ -28,6 +28,7 @@ clientapp = "$import_name.client_app:app"
 num-server-rounds = 3
 local-epochs = 1
 num-layers = 2
+input-dim = 784 # 28*28
 hidden-dim = 32
 batch-size = 256
 lr = 0.1
@@ -25,6 +25,8 @@ clientapp = "$import_name.client_app:app"
 
 [tool.flwr.app.config]
 num-server-rounds = 3
+penalty = "l2"
+local-epochs = 1
 
 [tool.flwr.federations]
 default = "local-simulation"
flwr/cli/run/run.py CHANGED
@@ -14,7 +14,6 @@
 # ==============================================================================
 """Flower command line interface `run` command."""
 
-import hashlib
 import json
 import subprocess
 import sys
@@ -134,6 +133,7 @@ def run(
         _run_without_superexec(app, federation_config, config_overrides, federation)
 
 
+# pylint: disable=too-many-locals
 def _run_with_superexec(
     app: Path,
     federation_config: dict[str, Any],
@@ -179,9 +179,9 @@ def _run_with_superexec(
     channel.subscribe(on_channel_state_change)
     stub = ExecStub(channel)
 
-    fab_path = Path(build(app))
-    content = fab_path.read_bytes()
-    fab = Fab(hashlib.sha256(content).hexdigest(), content)
+    fab_path, fab_hash = build(app)
+    content = Path(fab_path).read_bytes()
+    fab = Fab(fab_hash, content)
 
     req = StartRunRequest(
         fab=fab_to_proto(fab),
@@ -193,7 +193,7 @@ def _run_with_superexec(
     res = stub.StartRun(req)
 
     # Delete FAB file once it has been sent to the SuperExec
-    fab_path.unlink()
+    Path(fab_path).unlink()
     typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
 
     if stream:
flwr/client/app.py CHANGED
@@ -132,6 +132,11 @@ def start_client(
         - 'grpc-bidi': gRPC, bidirectional streaming
         - 'grpc-rere': gRPC, request-response (experimental)
         - 'rest': HTTP (experimental)
+    authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
+        Tuple containing the elliptic curve private key and public key for
+        authentication from the cryptography library.
+        Source: https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/
+        Used to establish an authenticated connection with the server.
     max_retries: Optional[int] (default: None)
         The maximum number of times the client will try to connect to the
         server before giving up in case of a connection error. If set to None,
@@ -197,7 +202,7 @@ def start_client_internal(
     *,
     server_address: str,
     node_config: UserConfig,
-    load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None,
+    load_client_app_fn: Optional[Callable[[str, str, str], ClientApp]] = None,
     client_fn: Optional[ClientFnExt] = None,
    client: Optional[Client] = None,
     grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
@@ -249,6 +254,11 @@ def start_client_internal(
         - 'grpc-bidi': gRPC, bidirectional streaming
         - 'grpc-rere': gRPC, request-response (experimental)
         - 'rest': HTTP (experimental)
+    authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
+        Tuple containing the elliptic curve private key and public key for
+        authentication from the cryptography library.
+        Source: https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/
+        Used to establish an authenticated connection with the server.
     max_retries: Optional[int] (default: None)
         The maximum number of times the client will try to connect to the
         server before giving up in case of a connection error. If set to None,
@@ -288,7 +298,7 @@ def start_client_internal(
 
         client_fn = single_client_factory
 
-    def _load_client_app(_1: str, _2: str) -> ClientApp:
+    def _load_client_app(_1: str, _2: str, _3: str) -> ClientApp:
        return ClientApp(client_fn=client_fn)
 
     load_client_app_fn = _load_client_app
@@ -519,7 +529,7 @@ def start_client_internal(
         else:
             # Load ClientApp instance
             client_app: ClientApp = load_client_app_fn(
-                fab_id, fab_version
+                fab_id, fab_version, run.fab_hash
             )
 
             # Execute ClientApp
@@ -132,8 +132,11 @@ def run_clientapp( # pylint: disable=R0914
         )
 
         try:
-            # Load ClientApp
-            client_app: ClientApp = load_client_app_fn(run.fab_id, run.fab_version)
+            if fab:
+                # Load ClientApp
+                client_app: ClientApp = load_client_app_fn(
+                    run.fab_id, run.fab_version, fab.hash_str
+                )
 
             # Execute ClientApp
             reply_message = client_app(message=message, context=context)
@@ -34,7 +34,7 @@ def get_load_client_app_fn(
     app_path: Optional[str],
     multi_app: bool,
     flwr_dir: Optional[str] = None,
-) -> Callable[[str, str], ClientApp]:
+) -> Callable[[str, str, str], ClientApp]:
     """Get the load_client_app_fn function.
 
     If `multi_app` is True, this function loads the specified ClientApp
@@ -55,13 +55,14 @@ def get_load_client_app_fn(
     if not valid and error_msg:
         raise LoadClientAppError(error_msg) from None
 
-    def _load(fab_id: str, fab_version: str) -> ClientApp:
+    def _load(fab_id: str, fab_version: str, fab_hash: str) -> ClientApp:
         runtime_app_dir = Path(app_path if app_path else "").absolute()
         # If multi-app feature is disabled
         if not multi_app:
             # Set app reference
             client_app_ref = default_app_ref
-        # If multi-app feature is enabled but app directory is provided
+        # If multi-app feature is enabled but app directory is provided.
+        # `fab_hash` is not required since the app is loaded from `runtime_app_dir`.
         elif app_path is not None:
             config = get_project_config(runtime_app_dir)
             this_fab_version, this_fab_id = get_metadata_from_config(config)
@@ -81,11 +82,16 @@ def get_load_client_app_fn(
         else:
            try:
                runtime_app_dir = get_project_dir(
-                    fab_id, fab_version, get_flwr_dir(flwr_dir)
+                    fab_id, fab_version, fab_hash, get_flwr_dir(flwr_dir)
                )
                config = get_project_config(runtime_app_dir)
            except Exception as e:
-                raise LoadClientAppError("Failed to load ClientApp") from e
+                raise LoadClientAppError(
+                    "Failed to load ClientApp."
+                    "Possible reasons for error include mismatched "
+                    "`fab_id`, `fab_version`, or `fab_hash` in "
+                    f"{str(get_flwr_dir(flwr_dir).resolve())}."
+                ) from e
 
         # Set app reference
         client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
@@ -120,6 +120,9 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
         authentication from the cryptography library.
         Source: https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/
         Used to establish an authenticated connection with the server.
+    adapter_cls: Optional[Union[type[FleetStub], type[GrpcAdapter]]] (default: None)
+        A GrpcStub Class that can be used to send messages. By default the FleetStub
+        will be used.
 
     Returns
     -------
flwr/common/config.py CHANGED
@@ -22,7 +22,12 @@ from typing import Any, Optional, Union, cast, get_args
 import tomli
 
 from flwr.cli.config_utils import get_fab_config, validate_fields
-from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
+from flwr.common.constant import (
+    APP_DIR,
+    FAB_CONFIG_FILE,
+    FAB_HASH_TRUNCATION,
+    FLWR_HOME,
+)
 from flwr.common.typing import Run, UserConfig, UserConfigValue
 
 
@@ -39,7 +44,10 @@ def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
 
 
 def get_project_dir(
-    fab_id: str, fab_version: str, flwr_dir: Optional[Union[str, Path]] = None
+    fab_id: str,
+    fab_version: str,
+    fab_hash: str,
+    flwr_dir: Optional[Union[str, Path]] = None,
 ) -> Path:
     """Return the project directory based on the given fab_id and fab_version."""
     # Check the fab_id
@@ -50,7 +58,11 @@
     publisher, project_name = fab_id.split("/")
     if flwr_dir is None:
         flwr_dir = get_flwr_dir()
-    return Path(flwr_dir) / APP_DIR / publisher / project_name / fab_version
+    return (
+        Path(flwr_dir)
+        / APP_DIR
+        / f"{publisher}.{project_name}.{fab_version}.{fab_hash[:FAB_HASH_TRUNCATION]}"
+    )
 
 
 def get_project_config(project_dir: Union[str, Path]) -> dict[str, Any]:
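
With FAB_HASH_TRUNCATION = 8, an installed app now lives in a single flat directory whose name embeds a truncated FAB hash, replacing the old publisher/project/version nesting. Roughly (hash and base directory below are made-up illustrations; the base is whatever get_flwr_dir resolves to):

    from pathlib import Path

    FAB_HASH_TRUNCATION = 8  # as added to flwr.common.constant in this release

    publisher, project, version = "flwrlabs", "myapp", "1.0.0"
    fab_hash = "3f786850e1a2b3c4d5e6f708192a3b4c5d6e7f80"  # illustrative hex digest

    # Old layout: <flwr_dir>/apps/flwrlabs/myapp/1.0.0
    # New layout: <flwr_dir>/apps/flwrlabs.myapp.1.0.0.3f786850
    flwr_dir = Path.home() / ".flwr"  # assumption; the real base comes from get_flwr_dir()
    project_dir = (
        flwr_dir / "apps" / f"{publisher}.{project}.{version}.{fab_hash[:FAB_HASH_TRUNCATION]}"
    )
    print(project_dir)
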
@@ -127,7 +139,7 @@ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig:
     if not run.fab_id or not run.fab_version:
         return {}
 
-    project_dir = get_project_dir(run.fab_id, run.fab_version, flwr_dir)
+    project_dir = get_project_dir(run.fab_id, run.fab_version, run.fab_hash, flwr_dir)
 
     # Return empty dict if project directory does not exist
     if not project_dir.is_dir():
@@ -205,8 +217,9 @@
         matches = pattern.findall(config_line)
         toml_str = "\n".join(f"{k} = {v}" for k, v in matches)
         overrides.update(tomli.loads(toml_str))
+    flat_overrides = flatten_dict(overrides)
 
-    return overrides
+    return flat_overrides
 
 
 def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]:
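
parse_config_args now flattens the parsed overrides before returning them, so nested TOML tables and dotted keys end up in one flat mapping. A minimal sketch of that behaviour, under the assumption that flatten_dict joins nested keys with a dot (the separator is not shown in this diff):

    import tomli

    overrides = tomli.loads('num-server-rounds = 3\nstrategy = { fraction-fit = 0.5 }')

    def flatten(d, prefix=""):
        # Assumed behaviour of flwr.common.config.flatten_dict
        out = {}
        for k, v in d.items():
            key = f"{prefix}.{k}" if prefix else k
            out.update(flatten(v, key) if isinstance(v, dict) else {key: v})
        return out

    print(flatten(overrides))  # {'num-server-rounds': 3, 'strategy.fraction-fit': 0.5}
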
flwr/common/constant.py CHANGED
@@ -63,7 +63,10 @@ NODE_ID_NUM_BYTES = 8
 
 # Constants for FAB
 APP_DIR = "apps"
+FAB_ALLOWED_EXTENSIONS = {".py", ".toml", ".md"}
 FAB_CONFIG_FILE = "pyproject.toml"
+FAB_DATE = (2024, 10, 1, 0, 0, 0)
+FAB_HASH_TRUNCATION = 8
 FLWR_HOME = "FLWR_HOME"
 
 # Constants entries in Node config for Simulation
flwr/common/message.py CHANGED
@@ -290,6 +290,11 @@
         follows the equation:
 
         ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)
+
+        Returns
+        -------
+        message : Message
+            A Message containing only the relevant error and metadata.
         """
         # If no TTL passed, use default for message creation (will update after
         # message creation)
@@ -59,6 +59,11 @@ def parametersrecord_to_parameters(
     keep_input : bool
         A boolean indicating whether entries in the record should be deleted from the
         input dictionary immediately after adding them to the record.
+
+    Returns
+    -------
+    parameters : Parameters
+        The parameters in the legacy format Parameters.
     """
     parameters = Parameters(tensors=[], tensor_type="")
 
@@ -94,6 +99,11 @@ def parameters_to_parametersrecord(
         A boolean indicating whether parameters should be deleted from the input
         Parameters object (i.e. a list of serialized NumPy arrays) immediately after
         adding them to the record.
+
+    Returns
+    -------
+    ParametersRecord
+        The ParametersRecord containing the provided parameters.
     """
     tensor_type = parameters.tensor_type
 
@@ -38,6 +38,11 @@
         Factor by which the delay is multiplied after each retry.
     max_delay: Optional[float] (default: None)
         The maximum delay duration between two consecutive retries.
+
+    Returns
+    -------
+    Generator[float, None, None]
+        A generator for the delay between 2 retries.
     """
     delay = base_delay if max_delay is None else min(base_delay, max_delay)
     while True:
@@ -56,6 +61,11 @@
     ----------
     interval: Union[float, Iterable[float]] (default: 1)
         A constant value to yield or an iterable of such values.
+
+    Returns
+    -------
+    Generator[float, None, None]
+        A generator for the delay between 2 retries.
     """
     if not isinstance(interval, Iterable):
         interval = itertools.repeat(interval)
@@ -73,6 +83,11 @@ def full_jitter(max_value: float) -> float:
     ----------
     max_value : float
         The upper limit for the randomized value.
+
+    Returns
+    -------
+    float
+        A random float that is less than max_value.
     """
     return random.uniform(0, max_value)
 
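
The new Returns sections make explicit that these wait strategies are generators. For reference, an exponential backoff generator consistent with the two implementation lines shown above would behave like this (the update step after the yield is an assumption):

    def exponential(base_delay=1.0, multiplier=2.0, max_delay=None):
        # First two lines mirror the snippet in the diff; the rest is a sketch.
        delay = base_delay if max_delay is None else min(base_delay, max_delay)
        while True:
            yield delay
            delay *= multiplier
            if max_delay is not None:
                delay = min(delay, max_delay)

    gen = exponential(1, 2, max_delay=8)
    print([next(gen) for _ in range(5)])  # [1, 2, 4, 8, 8]
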
@@ -47,6 +47,7 @@ class ClientManager(ABC):
         Parameters
         ----------
         client : flwr.server.client_proxy.ClientProxy
+            The ClientProxy of the Client to register.
 
         Returns
         -------
@@ -64,6 +65,7 @@ class ClientManager(ABC):
         Parameters
         ----------
         client : flwr.server.client_proxy.ClientProxy
+            The ClientProxy of the Client to unregister.
         """
 
     @abstractmethod
@@ -15,7 +15,6 @@
 """Flower ClientProxy implementation for Driver API."""
 
 
-import time
 from typing import Optional
 
 from flwr import common
@@ -25,8 +24,6 @@ from flwr.server.client_proxy import ClientProxy
 
 from ..driver.driver import Driver
 
-SLEEP_TIME = 1
-
 
 class DriverClientProxy(ClientProxy):
     """Flower client proxy which delegates work using the Driver API."""
@@ -122,29 +119,18 @@ class DriverClientProxy(ClientProxy):
             ttl=timeout,
         )
 
-        # Push message
-        message_ids = list(self.driver.push_messages(messages=[message]))
-        if len(message_ids) != 1:
-            raise ValueError("Unexpected number of message_ids")
-
-        message_id = message_ids[0]
-        if message_id == "":
-            raise ValueError(f"Failed to send message to node {self.node_id}")
-
-        if timeout:
-            start_time = time.time()
-
-        while True:
-            messages = list(self.driver.pull_messages(message_ids))
-            if len(messages) == 1:
-                msg: Message = messages[0]
-                if msg.has_error():
-                    raise ValueError(
-                        f"Message contains an Error (reason: {msg.error.reason}). "
-                        "It originated during client-side execution of a message."
-                    )
-                return msg.content
-
-            if timeout is not None and time.time() > start_time + timeout:
-                raise RuntimeError("Timeout reached")
-            time.sleep(SLEEP_TIME)
+        # Send message and wait for reply
+        messages = list(self.driver.send_and_receive(messages=[message]))
+
+        # A single reply is expected
+        if len(messages) != 1:
+            raise ValueError(f"Expected one Message but got: {len(messages)}")
+
+        # Only messages without errors can be handled beyond these point
+        msg: Message = messages[0]
+        if msg.has_error():
+            raise ValueError(
+                f"Message contains an Error (reason: {msg.error.reason}). "
+                "It originated during client-side execution of a message."
+            )
+        return msg.content
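
The manual push/poll loop around push_messages and pull_messages (with its SLEEP_TIME constant) is replaced by a single Driver.send_and_receive call, which pushes the message and blocks until the reply arrives. Conceptually it wraps something like the following sketch (method and field details are assumptions, not the actual Driver implementation):

    import time

    def send_and_receive_sketch(driver, messages, poll_interval=1.0):
        # Push all messages, then poll until every one of them has a reply.
        pending = set(driver.push_messages(messages))
        replies = []
        while pending:
            for reply in driver.pull_messages(list(pending)):
                replies.append(reply)
                pending.discard(reply.metadata.reply_to_message)  # field name assumed
            if pending:
                time.sleep(poll_interval)
        return replies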