flwr-nightly 1.12.0.dev20241007__py3-none-any.whl → 1.12.0.dev20241010__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +60 -29
- flwr/cli/config_utils.py +10 -0
- flwr/cli/install.py +60 -20
- flwr/cli/new/new.py +2 -0
- flwr/cli/new/templates/app/code/client.jax.py.tpl +11 -17
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +16 -36
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +4 -5
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +8 -11
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +14 -48
- flwr/cli/new/templates/app/code/server.jax.py.tpl +9 -3
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +13 -2
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +7 -2
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +13 -1
- flwr/cli/new/templates/app/code/task.jax.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +7 -0
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +3 -3
- flwr/cli/new/templates/app/code/task.sklearn.py.tpl +67 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +3 -2
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +2 -0
- flwr/cli/run/run.py +5 -5
- flwr/client/app.py +13 -3
- flwr/client/clientapp/app.py +5 -2
- flwr/client/clientapp/utils.py +11 -5
- flwr/client/grpc_rere_client/connection.py +3 -0
- flwr/common/config.py +18 -5
- flwr/common/constant.py +3 -0
- flwr/common/message.py +5 -0
- flwr/common/recordset_compat.py +10 -0
- flwr/common/retry_invoker.py +15 -0
- flwr/server/client_manager.py +2 -0
- flwr/server/compat/driver_client_proxy.py +15 -29
- flwr/server/driver/inmemory_driver.py +6 -2
- flwr/server/run_serverapp.py +11 -13
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +26 -8
- flwr/server/superlink/state/sqlite_state.py +46 -11
- flwr/server/superlink/state/state.py +1 -7
- flwr/server/superlink/state/utils.py +0 -10
- {flwr_nightly-1.12.0.dev20241007.dist-info → flwr_nightly-1.12.0.dev20241010.dist-info}/METADATA +1 -1
- {flwr_nightly-1.12.0.dev20241007.dist-info → flwr_nightly-1.12.0.dev20241010.dist-info}/RECORD +49 -47
- {flwr_nightly-1.12.0.dev20241007.dist-info → flwr_nightly-1.12.0.dev20241010.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.12.0.dev20241007.dist-info → flwr_nightly-1.12.0.dev20241010.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.12.0.dev20241007.dist-info → flwr_nightly-1.12.0.dev20241010.dist-info}/entry_points.txt +0 -0
|
@@ -1,16 +1,21 @@
|
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
|
-
from flwr.common import Context
|
|
3
|
+
from flwr.common import Context, ndarrays_to_parameters
|
|
4
4
|
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
5
5
|
from flwr.server.strategy import FedAvg
|
|
6
|
+
from $import_name.task import get_dummy_model
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
def server_fn(context: Context):
|
|
9
10
|
# Read from config
|
|
10
11
|
num_rounds = context.run_config["num-server-rounds"]
|
|
11
12
|
|
|
13
|
+
# Initial model
|
|
14
|
+
model = get_dummy_model()
|
|
15
|
+
dummy_parameters = ndarrays_to_parameters([model])
|
|
16
|
+
|
|
12
17
|
# Define strategy
|
|
13
|
-
strategy = FedAvg()
|
|
18
|
+
strategy = FedAvg(initial_parameters=dummy_parameters)
|
|
14
19
|
config = ServerConfig(num_rounds=num_rounds)
|
|
15
20
|
|
|
16
21
|
return ServerAppComponents(strategy=strategy, config=config)
|
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
from flwr.common import Context, ndarrays_to_parameters
|
|
4
4
|
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
5
5
|
from flwr.server.strategy import FedAvg
|
|
6
|
-
|
|
7
6
|
from $import_name.task import Net, get_weights
|
|
8
7
|
|
|
9
8
|
|
|
@@ -27,5 +26,6 @@ def server_fn(context: Context):
|
|
|
27
26
|
|
|
28
27
|
return ServerAppComponents(strategy=strategy, config=config)
|
|
29
28
|
|
|
29
|
+
|
|
30
30
|
# Create ServerApp
|
|
31
31
|
app = ServerApp(server_fn=server_fn)
|
|
@@ -1,19 +1,31 @@
|
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
|
-
from flwr.common import Context
|
|
3
|
+
from flwr.common import Context, ndarrays_to_parameters
|
|
4
4
|
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
5
5
|
from flwr.server.strategy import FedAvg
|
|
6
|
+
from $import_name.task import get_model, get_model_params, set_initial_params
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
def server_fn(context: Context):
|
|
9
10
|
# Read from config
|
|
10
11
|
num_rounds = context.run_config["num-server-rounds"]
|
|
11
12
|
|
|
13
|
+
# Create LogisticRegression Model
|
|
14
|
+
penalty = context.run_config["penalty"]
|
|
15
|
+
local_epochs = context.run_config["local-epochs"]
|
|
16
|
+
model = get_model(penalty, local_epochs)
|
|
17
|
+
|
|
18
|
+
# Setting initial parameters, akin to model.compile for keras models
|
|
19
|
+
set_initial_params(model)
|
|
20
|
+
|
|
21
|
+
initial_parameters = ndarrays_to_parameters(get_model_params(model))
|
|
22
|
+
|
|
12
23
|
# Define strategy
|
|
13
24
|
strategy = FedAvg(
|
|
14
25
|
fraction_fit=1.0,
|
|
15
26
|
fraction_evaluate=1.0,
|
|
16
27
|
min_available_clients=2,
|
|
28
|
+
initial_parameters=initial_parameters,
|
|
17
29
|
)
|
|
18
30
|
config = ServerConfig(num_rounds=num_rounds)
|
|
19
31
|
|
|
@@ -2,9 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
import jax
|
|
4
4
|
import jax.numpy as jnp
|
|
5
|
+
import numpy as np
|
|
5
6
|
from sklearn.datasets import make_regression
|
|
6
7
|
from sklearn.model_selection import train_test_split
|
|
7
|
-
import numpy as np
|
|
8
8
|
|
|
9
9
|
key = jax.random.PRNGKey(0)
|
|
10
10
|
|
|
@@ -33,7 +33,7 @@ def train(params, grad_fn, X, y):
|
|
|
33
33
|
num_examples = X.shape[0]
|
|
34
34
|
for epochs in range(50):
|
|
35
35
|
grads = grad_fn(params, X, y)
|
|
36
|
-
params = jax.
|
|
36
|
+
params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads)
|
|
37
37
|
loss = loss_fn(params, X, y)
|
|
38
38
|
return params, loss, num_examples
|
|
39
39
|
|
|
@@ -3,10 +3,10 @@
|
|
|
3
3
|
import mlx.core as mx
|
|
4
4
|
import mlx.nn as nn
|
|
5
5
|
import numpy as np
|
|
6
|
-
from datasets.utils.logging import disable_progress_bar
|
|
7
6
|
from flwr_datasets import FederatedDataset
|
|
8
7
|
from flwr_datasets.partitioner import IidPartitioner
|
|
9
8
|
|
|
9
|
+
from datasets.utils.logging import disable_progress_bar
|
|
10
10
|
|
|
11
11
|
disable_progress_bar()
|
|
12
12
|
|
|
@@ -5,10 +5,10 @@ from collections import OrderedDict
|
|
|
5
5
|
import torch
|
|
6
6
|
import torch.nn as nn
|
|
7
7
|
import torch.nn.functional as F
|
|
8
|
-
from torch.utils.data import DataLoader
|
|
9
|
-
from torchvision.transforms import Compose, Normalize, ToTensor
|
|
10
8
|
from flwr_datasets import FederatedDataset
|
|
11
9
|
from flwr_datasets.partitioner import IidPartitioner
|
|
10
|
+
from torch.utils.data import DataLoader
|
|
11
|
+
from torchvision.transforms import Compose, Normalize, ToTensor
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class Net(nn.Module):
|
|
@@ -67,7 +67,7 @@ def train(net, trainloader, epochs, device):
|
|
|
67
67
|
"""Train the model on the training set."""
|
|
68
68
|
net.to(device) # move model to GPU if available
|
|
69
69
|
criterion = torch.nn.CrossEntropyLoss().to(device)
|
|
70
|
-
optimizer = torch.optim.
|
|
70
|
+
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
|
|
71
71
|
net.train()
|
|
72
72
|
running_loss = 0.0
|
|
73
73
|
for _ in range(epochs):
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from flwr_datasets import FederatedDataset
|
|
5
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
6
|
+
from sklearn.linear_model import LogisticRegression
|
|
7
|
+
|
|
8
|
+
fds = None # Cache FederatedDataset
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def load_data(partition_id: int, num_partitions: int):
|
|
12
|
+
"""Load partition MNIST data."""
|
|
13
|
+
# Only initialize `FederatedDataset` once
|
|
14
|
+
global fds
|
|
15
|
+
if fds is None:
|
|
16
|
+
partitioner = IidPartitioner(num_partitions=num_partitions)
|
|
17
|
+
fds = FederatedDataset(
|
|
18
|
+
dataset="mnist",
|
|
19
|
+
partitioners={"train": partitioner},
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
dataset = fds.load_partition(partition_id, "train").with_format("numpy")
|
|
23
|
+
|
|
24
|
+
X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
|
|
25
|
+
|
|
26
|
+
# Split the on edge data: 80% train, 20% test
|
|
27
|
+
X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :]
|
|
28
|
+
y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :]
|
|
29
|
+
|
|
30
|
+
return X_train, X_test, y_train, y_test
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def get_model(penalty: str, local_epochs: int):
|
|
34
|
+
|
|
35
|
+
return LogisticRegression(
|
|
36
|
+
penalty=penalty,
|
|
37
|
+
max_iter=local_epochs,
|
|
38
|
+
warm_start=True,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def get_model_params(model):
|
|
43
|
+
if model.fit_intercept:
|
|
44
|
+
params = [
|
|
45
|
+
model.coef_,
|
|
46
|
+
model.intercept_,
|
|
47
|
+
]
|
|
48
|
+
else:
|
|
49
|
+
params = [model.coef_]
|
|
50
|
+
return params
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def set_model_params(model, params):
|
|
54
|
+
model.coef_ = params[0]
|
|
55
|
+
if model.fit_intercept:
|
|
56
|
+
model.intercept_ = params[1]
|
|
57
|
+
return model
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def set_initial_params(model):
|
|
61
|
+
n_classes = 10 # MNIST has 10 classes
|
|
62
|
+
n_features = 784 # Number of features in dataset
|
|
63
|
+
model.classes_ = np.array([i for i in range(10)])
|
|
64
|
+
|
|
65
|
+
model.coef_ = np.zeros((n_classes, n_features))
|
|
66
|
+
if model.fit_intercept:
|
|
67
|
+
model.intercept_ = np.zeros((n_classes,))
|
|
@@ -9,8 +9,8 @@ description = ""
|
|
|
9
9
|
license = "Apache-2.0"
|
|
10
10
|
dependencies = [
|
|
11
11
|
"flwr[simulation]>=1.10.0",
|
|
12
|
-
"jax==0.4.
|
|
13
|
-
"jaxlib==0.4.
|
|
12
|
+
"jax==0.4.30",
|
|
13
|
+
"jaxlib==0.4.30",
|
|
14
14
|
"scikit-learn==1.3.2",
|
|
15
15
|
]
|
|
16
16
|
|
|
@@ -26,6 +26,7 @@ clientapp = "$import_name.client_app:app"
|
|
|
26
26
|
|
|
27
27
|
[tool.flwr.app.config]
|
|
28
28
|
num-server-rounds = 3
|
|
29
|
+
input-dim = 3
|
|
29
30
|
|
|
30
31
|
[tool.flwr.federations]
|
|
31
32
|
default = "local-simulation"
|
flwr/cli/run/run.py
CHANGED
|
@@ -14,7 +14,6 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower command line interface `run` command."""
|
|
16
16
|
|
|
17
|
-
import hashlib
|
|
18
17
|
import json
|
|
19
18
|
import subprocess
|
|
20
19
|
import sys
|
|
@@ -134,6 +133,7 @@ def run(
|
|
|
134
133
|
_run_without_superexec(app, federation_config, config_overrides, federation)
|
|
135
134
|
|
|
136
135
|
|
|
136
|
+
# pylint: disable=too-many-locals
|
|
137
137
|
def _run_with_superexec(
|
|
138
138
|
app: Path,
|
|
139
139
|
federation_config: dict[str, Any],
|
|
@@ -179,9 +179,9 @@ def _run_with_superexec(
|
|
|
179
179
|
channel.subscribe(on_channel_state_change)
|
|
180
180
|
stub = ExecStub(channel)
|
|
181
181
|
|
|
182
|
-
fab_path =
|
|
183
|
-
content = fab_path.read_bytes()
|
|
184
|
-
fab = Fab(
|
|
182
|
+
fab_path, fab_hash = build(app)
|
|
183
|
+
content = Path(fab_path).read_bytes()
|
|
184
|
+
fab = Fab(fab_hash, content)
|
|
185
185
|
|
|
186
186
|
req = StartRunRequest(
|
|
187
187
|
fab=fab_to_proto(fab),
|
|
@@ -193,7 +193,7 @@ def _run_with_superexec(
|
|
|
193
193
|
res = stub.StartRun(req)
|
|
194
194
|
|
|
195
195
|
# Delete FAB file once it has been sent to the SuperExec
|
|
196
|
-
fab_path.unlink()
|
|
196
|
+
Path(fab_path).unlink()
|
|
197
197
|
typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
|
|
198
198
|
|
|
199
199
|
if stream:
|
flwr/client/app.py
CHANGED
|
@@ -132,6 +132,11 @@ def start_client(
|
|
|
132
132
|
- 'grpc-bidi': gRPC, bidirectional streaming
|
|
133
133
|
- 'grpc-rere': gRPC, request-response (experimental)
|
|
134
134
|
- 'rest': HTTP (experimental)
|
|
135
|
+
authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
|
|
136
|
+
Tuple containing the elliptic curve private key and public key for
|
|
137
|
+
authentication from the cryptography library.
|
|
138
|
+
Source: https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/
|
|
139
|
+
Used to establish an authenticated connection with the server.
|
|
135
140
|
max_retries: Optional[int] (default: None)
|
|
136
141
|
The maximum number of times the client will try to connect to the
|
|
137
142
|
server before giving up in case of a connection error. If set to None,
|
|
@@ -197,7 +202,7 @@ def start_client_internal(
|
|
|
197
202
|
*,
|
|
198
203
|
server_address: str,
|
|
199
204
|
node_config: UserConfig,
|
|
200
|
-
load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None,
|
|
205
|
+
load_client_app_fn: Optional[Callable[[str, str, str], ClientApp]] = None,
|
|
201
206
|
client_fn: Optional[ClientFnExt] = None,
|
|
202
207
|
client: Optional[Client] = None,
|
|
203
208
|
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
@@ -249,6 +254,11 @@ def start_client_internal(
|
|
|
249
254
|
- 'grpc-bidi': gRPC, bidirectional streaming
|
|
250
255
|
- 'grpc-rere': gRPC, request-response (experimental)
|
|
251
256
|
- 'rest': HTTP (experimental)
|
|
257
|
+
authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
|
|
258
|
+
Tuple containing the elliptic curve private key and public key for
|
|
259
|
+
authentication from the cryptography library.
|
|
260
|
+
Source: https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/
|
|
261
|
+
Used to establish an authenticated connection with the server.
|
|
252
262
|
max_retries: Optional[int] (default: None)
|
|
253
263
|
The maximum number of times the client will try to connect to the
|
|
254
264
|
server before giving up in case of a connection error. If set to None,
|
|
@@ -288,7 +298,7 @@ def start_client_internal(
|
|
|
288
298
|
|
|
289
299
|
client_fn = single_client_factory
|
|
290
300
|
|
|
291
|
-
def _load_client_app(_1: str, _2: str) -> ClientApp:
|
|
301
|
+
def _load_client_app(_1: str, _2: str, _3: str) -> ClientApp:
|
|
292
302
|
return ClientApp(client_fn=client_fn)
|
|
293
303
|
|
|
294
304
|
load_client_app_fn = _load_client_app
|
|
@@ -519,7 +529,7 @@ def start_client_internal(
|
|
|
519
529
|
else:
|
|
520
530
|
# Load ClientApp instance
|
|
521
531
|
client_app: ClientApp = load_client_app_fn(
|
|
522
|
-
fab_id, fab_version
|
|
532
|
+
fab_id, fab_version, run.fab_hash
|
|
523
533
|
)
|
|
524
534
|
|
|
525
535
|
# Execute ClientApp
|
flwr/client/clientapp/app.py
CHANGED
|
@@ -132,8 +132,11 @@ def run_clientapp( # pylint: disable=R0914
|
|
|
132
132
|
)
|
|
133
133
|
|
|
134
134
|
try:
|
|
135
|
-
|
|
136
|
-
|
|
135
|
+
if fab:
|
|
136
|
+
# Load ClientApp
|
|
137
|
+
client_app: ClientApp = load_client_app_fn(
|
|
138
|
+
run.fab_id, run.fab_version, fab.hash_str
|
|
139
|
+
)
|
|
137
140
|
|
|
138
141
|
# Execute ClientApp
|
|
139
142
|
reply_message = client_app(message=message, context=context)
|
flwr/client/clientapp/utils.py
CHANGED
|
@@ -34,7 +34,7 @@ def get_load_client_app_fn(
|
|
|
34
34
|
app_path: Optional[str],
|
|
35
35
|
multi_app: bool,
|
|
36
36
|
flwr_dir: Optional[str] = None,
|
|
37
|
-
) -> Callable[[str, str], ClientApp]:
|
|
37
|
+
) -> Callable[[str, str, str], ClientApp]:
|
|
38
38
|
"""Get the load_client_app_fn function.
|
|
39
39
|
|
|
40
40
|
If `multi_app` is True, this function loads the specified ClientApp
|
|
@@ -55,13 +55,14 @@ def get_load_client_app_fn(
|
|
|
55
55
|
if not valid and error_msg:
|
|
56
56
|
raise LoadClientAppError(error_msg) from None
|
|
57
57
|
|
|
58
|
-
def _load(fab_id: str, fab_version: str) -> ClientApp:
|
|
58
|
+
def _load(fab_id: str, fab_version: str, fab_hash: str) -> ClientApp:
|
|
59
59
|
runtime_app_dir = Path(app_path if app_path else "").absolute()
|
|
60
60
|
# If multi-app feature is disabled
|
|
61
61
|
if not multi_app:
|
|
62
62
|
# Set app reference
|
|
63
63
|
client_app_ref = default_app_ref
|
|
64
|
-
# If multi-app feature is enabled but app directory is provided
|
|
64
|
+
# If multi-app feature is enabled but app directory is provided.
|
|
65
|
+
# `fab_hash` is not required since the app is loaded from `runtime_app_dir`.
|
|
65
66
|
elif app_path is not None:
|
|
66
67
|
config = get_project_config(runtime_app_dir)
|
|
67
68
|
this_fab_version, this_fab_id = get_metadata_from_config(config)
|
|
@@ -81,11 +82,16 @@ def get_load_client_app_fn(
|
|
|
81
82
|
else:
|
|
82
83
|
try:
|
|
83
84
|
runtime_app_dir = get_project_dir(
|
|
84
|
-
fab_id, fab_version, get_flwr_dir(flwr_dir)
|
|
85
|
+
fab_id, fab_version, fab_hash, get_flwr_dir(flwr_dir)
|
|
85
86
|
)
|
|
86
87
|
config = get_project_config(runtime_app_dir)
|
|
87
88
|
except Exception as e:
|
|
88
|
-
raise LoadClientAppError(
|
|
89
|
+
raise LoadClientAppError(
|
|
90
|
+
"Failed to load ClientApp."
|
|
91
|
+
"Possible reasons for error include mismatched "
|
|
92
|
+
"`fab_id`, `fab_version`, or `fab_hash` in "
|
|
93
|
+
f"{str(get_flwr_dir(flwr_dir).resolve())}."
|
|
94
|
+
) from e
|
|
89
95
|
|
|
90
96
|
# Set app reference
|
|
91
97
|
client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
|
|
@@ -120,6 +120,9 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
120
120
|
authentication from the cryptography library.
|
|
121
121
|
Source: https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/
|
|
122
122
|
Used to establish an authenticated connection with the server.
|
|
123
|
+
adapter_cls: Optional[Union[type[FleetStub], type[GrpcAdapter]]] (default: None)
|
|
124
|
+
A GrpcStub Class that can be used to send messages. By default the FleetStub
|
|
125
|
+
will be used.
|
|
123
126
|
|
|
124
127
|
Returns
|
|
125
128
|
-------
|
flwr/common/config.py
CHANGED
|
@@ -22,7 +22,12 @@ from typing import Any, Optional, Union, cast, get_args
|
|
|
22
22
|
import tomli
|
|
23
23
|
|
|
24
24
|
from flwr.cli.config_utils import get_fab_config, validate_fields
|
|
25
|
-
from flwr.common.constant import
|
|
25
|
+
from flwr.common.constant import (
|
|
26
|
+
APP_DIR,
|
|
27
|
+
FAB_CONFIG_FILE,
|
|
28
|
+
FAB_HASH_TRUNCATION,
|
|
29
|
+
FLWR_HOME,
|
|
30
|
+
)
|
|
26
31
|
from flwr.common.typing import Run, UserConfig, UserConfigValue
|
|
27
32
|
|
|
28
33
|
|
|
@@ -39,7 +44,10 @@ def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
|
|
|
39
44
|
|
|
40
45
|
|
|
41
46
|
def get_project_dir(
|
|
42
|
-
fab_id: str,
|
|
47
|
+
fab_id: str,
|
|
48
|
+
fab_version: str,
|
|
49
|
+
fab_hash: str,
|
|
50
|
+
flwr_dir: Optional[Union[str, Path]] = None,
|
|
43
51
|
) -> Path:
|
|
44
52
|
"""Return the project directory based on the given fab_id and fab_version."""
|
|
45
53
|
# Check the fab_id
|
|
@@ -50,7 +58,11 @@ def get_project_dir(
|
|
|
50
58
|
publisher, project_name = fab_id.split("/")
|
|
51
59
|
if flwr_dir is None:
|
|
52
60
|
flwr_dir = get_flwr_dir()
|
|
53
|
-
return
|
|
61
|
+
return (
|
|
62
|
+
Path(flwr_dir)
|
|
63
|
+
/ APP_DIR
|
|
64
|
+
/ f"{publisher}.{project_name}.{fab_version}.{fab_hash[:FAB_HASH_TRUNCATION]}"
|
|
65
|
+
)
|
|
54
66
|
|
|
55
67
|
|
|
56
68
|
def get_project_config(project_dir: Union[str, Path]) -> dict[str, Any]:
|
|
@@ -127,7 +139,7 @@ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig:
|
|
|
127
139
|
if not run.fab_id or not run.fab_version:
|
|
128
140
|
return {}
|
|
129
141
|
|
|
130
|
-
project_dir = get_project_dir(run.fab_id, run.fab_version, flwr_dir)
|
|
142
|
+
project_dir = get_project_dir(run.fab_id, run.fab_version, run.fab_hash, flwr_dir)
|
|
131
143
|
|
|
132
144
|
# Return empty dict if project directory does not exist
|
|
133
145
|
if not project_dir.is_dir():
|
|
@@ -205,8 +217,9 @@ def parse_config_args(
|
|
|
205
217
|
matches = pattern.findall(config_line)
|
|
206
218
|
toml_str = "\n".join(f"{k} = {v}" for k, v in matches)
|
|
207
219
|
overrides.update(tomli.loads(toml_str))
|
|
220
|
+
flat_overrides = flatten_dict(overrides)
|
|
208
221
|
|
|
209
|
-
return
|
|
222
|
+
return flat_overrides
|
|
210
223
|
|
|
211
224
|
|
|
212
225
|
def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]:
|
flwr/common/constant.py
CHANGED
|
@@ -63,7 +63,10 @@ NODE_ID_NUM_BYTES = 8
|
|
|
63
63
|
|
|
64
64
|
# Constants for FAB
|
|
65
65
|
APP_DIR = "apps"
|
|
66
|
+
FAB_ALLOWED_EXTENSIONS = {".py", ".toml", ".md"}
|
|
66
67
|
FAB_CONFIG_FILE = "pyproject.toml"
|
|
68
|
+
FAB_DATE = (2024, 10, 1, 0, 0, 0)
|
|
69
|
+
FAB_HASH_TRUNCATION = 8
|
|
67
70
|
FLWR_HOME = "FLWR_HOME"
|
|
68
71
|
|
|
69
72
|
# Constants entries in Node config for Simulation
|
flwr/common/message.py
CHANGED
|
@@ -290,6 +290,11 @@ class Message:
|
|
|
290
290
|
follows the equation:
|
|
291
291
|
|
|
292
292
|
ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)
|
|
293
|
+
|
|
294
|
+
Returns
|
|
295
|
+
-------
|
|
296
|
+
message : Message
|
|
297
|
+
A Message containing only the relevant error and metadata.
|
|
293
298
|
"""
|
|
294
299
|
# If no TTL passed, use default for message creation (will update after
|
|
295
300
|
# message creation)
|
flwr/common/recordset_compat.py
CHANGED
|
@@ -59,6 +59,11 @@ def parametersrecord_to_parameters(
|
|
|
59
59
|
keep_input : bool
|
|
60
60
|
A boolean indicating whether entries in the record should be deleted from the
|
|
61
61
|
input dictionary immediately after adding them to the record.
|
|
62
|
+
|
|
63
|
+
Returns
|
|
64
|
+
-------
|
|
65
|
+
parameters : Parameters
|
|
66
|
+
The parameters in the legacy format Parameters.
|
|
62
67
|
"""
|
|
63
68
|
parameters = Parameters(tensors=[], tensor_type="")
|
|
64
69
|
|
|
@@ -94,6 +99,11 @@ def parameters_to_parametersrecord(
|
|
|
94
99
|
A boolean indicating whether parameters should be deleted from the input
|
|
95
100
|
Parameters object (i.e. a list of serialized NumPy arrays) immediately after
|
|
96
101
|
adding them to the record.
|
|
102
|
+
|
|
103
|
+
Returns
|
|
104
|
+
-------
|
|
105
|
+
ParametersRecord
|
|
106
|
+
The ParametersRecord containing the provided parameters.
|
|
97
107
|
"""
|
|
98
108
|
tensor_type = parameters.tensor_type
|
|
99
109
|
|
flwr/common/retry_invoker.py
CHANGED
|
@@ -38,6 +38,11 @@ def exponential(
|
|
|
38
38
|
Factor by which the delay is multiplied after each retry.
|
|
39
39
|
max_delay: Optional[float] (default: None)
|
|
40
40
|
The maximum delay duration between two consecutive retries.
|
|
41
|
+
|
|
42
|
+
Returns
|
|
43
|
+
-------
|
|
44
|
+
Generator[float, None, None]
|
|
45
|
+
A generator for the delay between 2 retries.
|
|
41
46
|
"""
|
|
42
47
|
delay = base_delay if max_delay is None else min(base_delay, max_delay)
|
|
43
48
|
while True:
|
|
@@ -56,6 +61,11 @@ def constant(
|
|
|
56
61
|
----------
|
|
57
62
|
interval: Union[float, Iterable[float]] (default: 1)
|
|
58
63
|
A constant value to yield or an iterable of such values.
|
|
64
|
+
|
|
65
|
+
Returns
|
|
66
|
+
-------
|
|
67
|
+
Generator[float, None, None]
|
|
68
|
+
A generator for the delay between 2 retries.
|
|
59
69
|
"""
|
|
60
70
|
if not isinstance(interval, Iterable):
|
|
61
71
|
interval = itertools.repeat(interval)
|
|
@@ -73,6 +83,11 @@ def full_jitter(max_value: float) -> float:
|
|
|
73
83
|
----------
|
|
74
84
|
max_value : float
|
|
75
85
|
The upper limit for the randomized value.
|
|
86
|
+
|
|
87
|
+
Returns
|
|
88
|
+
-------
|
|
89
|
+
float
|
|
90
|
+
A random float that is less than max_value.
|
|
76
91
|
"""
|
|
77
92
|
return random.uniform(0, max_value)
|
|
78
93
|
|
flwr/server/client_manager.py
CHANGED
|
@@ -47,6 +47,7 @@ class ClientManager(ABC):
|
|
|
47
47
|
Parameters
|
|
48
48
|
----------
|
|
49
49
|
client : flwr.server.client_proxy.ClientProxy
|
|
50
|
+
The ClientProxy of the Client to register.
|
|
50
51
|
|
|
51
52
|
Returns
|
|
52
53
|
-------
|
|
@@ -64,6 +65,7 @@ class ClientManager(ABC):
|
|
|
64
65
|
Parameters
|
|
65
66
|
----------
|
|
66
67
|
client : flwr.server.client_proxy.ClientProxy
|
|
68
|
+
The ClientProxy of the Client to unregister.
|
|
67
69
|
"""
|
|
68
70
|
|
|
69
71
|
@abstractmethod
|
|
@@ -15,7 +15,6 @@
|
|
|
15
15
|
"""Flower ClientProxy implementation for Driver API."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import time
|
|
19
18
|
from typing import Optional
|
|
20
19
|
|
|
21
20
|
from flwr import common
|
|
@@ -25,8 +24,6 @@ from flwr.server.client_proxy import ClientProxy
|
|
|
25
24
|
|
|
26
25
|
from ..driver.driver import Driver
|
|
27
26
|
|
|
28
|
-
SLEEP_TIME = 1
|
|
29
|
-
|
|
30
27
|
|
|
31
28
|
class DriverClientProxy(ClientProxy):
|
|
32
29
|
"""Flower client proxy which delegates work using the Driver API."""
|
|
@@ -122,29 +119,18 @@ class DriverClientProxy(ClientProxy):
|
|
|
122
119
|
ttl=timeout,
|
|
123
120
|
)
|
|
124
121
|
|
|
125
|
-
#
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
if
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
msg: Message = messages[0]
|
|
141
|
-
if msg.has_error():
|
|
142
|
-
raise ValueError(
|
|
143
|
-
f"Message contains an Error (reason: {msg.error.reason}). "
|
|
144
|
-
"It originated during client-side execution of a message."
|
|
145
|
-
)
|
|
146
|
-
return msg.content
|
|
147
|
-
|
|
148
|
-
if timeout is not None and time.time() > start_time + timeout:
|
|
149
|
-
raise RuntimeError("Timeout reached")
|
|
150
|
-
time.sleep(SLEEP_TIME)
|
|
122
|
+
# Send message and wait for reply
|
|
123
|
+
messages = list(self.driver.send_and_receive(messages=[message]))
|
|
124
|
+
|
|
125
|
+
# A single reply is expected
|
|
126
|
+
if len(messages) != 1:
|
|
127
|
+
raise ValueError(f"Expected one Message but got: {len(messages)}")
|
|
128
|
+
|
|
129
|
+
# Only messages without errors can be handled beyond these point
|
|
130
|
+
msg: Message = messages[0]
|
|
131
|
+
if msg.has_error():
|
|
132
|
+
raise ValueError(
|
|
133
|
+
f"Message contains an Error (reason: {msg.error.reason}). "
|
|
134
|
+
"It originated during client-side execution of a message."
|
|
135
|
+
)
|
|
136
|
+
return msg.content
|