flwr-nightly 1.8.0.dev20240309__py3-none-any.whl → 1.8.0.dev20240311__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/flower_toml.py +4 -48
- flwr/cli/new/new.py +6 -3
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -3
- flwr/cli/new/templates/app/pyproject.toml.tpl +1 -1
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +2 -2
- flwr/cli/utils.py +14 -1
- flwr/client/app.py +39 -5
- flwr/client/client_app.py +1 -47
- flwr/client/mod/__init__.py +2 -1
- flwr/client/mod/secure_aggregation/__init__.py +2 -0
- flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +73 -57
- flwr/common/grpc.py +3 -3
- flwr/common/logger.py +78 -15
- flwr/common/object_ref.py +140 -0
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -5
- flwr/common/secure_aggregation/secaggplus_constants.py +7 -6
- flwr/common/secure_aggregation/secaggplus_utils.py +15 -15
- flwr/server/compat/app.py +2 -1
- flwr/server/driver/grpc_driver.py +4 -4
- flwr/server/history.py +22 -15
- flwr/server/run_serverapp.py +22 -4
- flwr/server/server.py +27 -23
- flwr/server/server_app.py +1 -47
- flwr/server/server_config.py +9 -0
- flwr/server/strategy/fedavg.py +2 -0
- flwr/server/superlink/fleet/vce/vce_api.py +9 -2
- flwr/server/superlink/state/in_memory_state.py +34 -32
- flwr/server/workflow/__init__.py +3 -0
- flwr/server/workflow/constant.py +32 -0
- flwr/server/workflow/default_workflows.py +52 -57
- flwr/server/workflow/secure_aggregation/__init__.py +24 -0
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +676 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/RECORD +39 -33
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/entry_points.txt +0 -0
@@ -19,6 +19,7 @@ from __future__ import annotations
|
|
19
19
|
|
20
20
|
RECORD_KEY_STATE = "secaggplus_state"
|
21
21
|
RECORD_KEY_CONFIGS = "secaggplus_configs"
|
22
|
+
RATIO_QUANTIZATION_RANGE = 1073741824 # 1 << 30
|
22
23
|
|
23
24
|
|
24
25
|
class Stage:
|
@@ -26,9 +27,9 @@ class Stage:
|
|
26
27
|
|
27
28
|
SETUP = "setup"
|
28
29
|
SHARE_KEYS = "share_keys"
|
29
|
-
|
30
|
+
COLLECT_MASKED_VECTORS = "collect_masked_vectors"
|
30
31
|
UNMASK = "unmask"
|
31
|
-
_stages = (SETUP, SHARE_KEYS,
|
32
|
+
_stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_VECTORS, UNMASK)
|
32
33
|
|
33
34
|
@classmethod
|
34
35
|
def all(cls) -> tuple[str, str, str, str]:
|
@@ -45,12 +46,12 @@ class Key:
|
|
45
46
|
|
46
47
|
STAGE = "stage"
|
47
48
|
SAMPLE_NUMBER = "sample_num"
|
48
|
-
SECURE_ID = "secure_id"
|
49
49
|
SHARE_NUMBER = "share_num"
|
50
50
|
THRESHOLD = "threshold"
|
51
51
|
CLIPPING_RANGE = "clipping_range"
|
52
52
|
TARGET_RANGE = "target_range"
|
53
53
|
MOD_RANGE = "mod_range"
|
54
|
+
MAX_WEIGHT = "max_weight"
|
54
55
|
PUBLIC_KEY_1 = "pk1"
|
55
56
|
PUBLIC_KEY_2 = "pk2"
|
56
57
|
DESTINATION_LIST = "dsts"
|
@@ -58,9 +59,9 @@ class Key:
|
|
58
59
|
SOURCE_LIST = "srcs"
|
59
60
|
PARAMETERS = "params"
|
60
61
|
MASKED_PARAMETERS = "masked_params"
|
61
|
-
|
62
|
-
|
63
|
-
|
62
|
+
ACTIVE_NODE_ID_LIST = "active_nids"
|
63
|
+
DEAD_NODE_ID_LIST = "dead_nids"
|
64
|
+
NODE_ID_LIST = "nids"
|
64
65
|
SHARE_LIST = "shares"
|
65
66
|
|
66
67
|
def __new__(cls) -> Key:
|
@@ -23,16 +23,16 @@ from flwr.common.typing import NDArrayInt
|
|
23
23
|
|
24
24
|
|
25
25
|
def share_keys_plaintext_concat(
|
26
|
-
|
26
|
+
src_node_id: int, dst_node_id: int, b_share: bytes, sk_share: bytes
|
27
27
|
) -> bytes:
|
28
28
|
"""Combine arguments to bytes.
|
29
29
|
|
30
30
|
Parameters
|
31
31
|
----------
|
32
|
-
|
33
|
-
the
|
34
|
-
|
35
|
-
the
|
32
|
+
src_node_id : int
|
33
|
+
the node ID of the source.
|
34
|
+
dst_node_id : int
|
35
|
+
the node ID of the destination.
|
36
36
|
b_share : bytes
|
37
37
|
the private key share of the source sent to the destination.
|
38
38
|
sk_share : bytes
|
@@ -45,8 +45,8 @@ def share_keys_plaintext_concat(
|
|
45
45
|
"""
|
46
46
|
return b"".join(
|
47
47
|
[
|
48
|
-
int.to_bytes(
|
49
|
-
int.to_bytes(
|
48
|
+
int.to_bytes(src_node_id, 8, "little", signed=True),
|
49
|
+
int.to_bytes(dst_node_id, 8, "little", signed=True),
|
50
50
|
int.to_bytes(len(b_share), 4, "little"),
|
51
51
|
b_share,
|
52
52
|
sk_share,
|
@@ -64,21 +64,21 @@ def share_keys_plaintext_separate(plaintext: bytes) -> Tuple[int, int, bytes, by
|
|
64
64
|
|
65
65
|
Returns
|
66
66
|
-------
|
67
|
-
|
68
|
-
the
|
69
|
-
|
70
|
-
the
|
67
|
+
src_node_id : int
|
68
|
+
the node ID of the source.
|
69
|
+
dst_node_id : int
|
70
|
+
the node ID of the destination.
|
71
71
|
b_share : bytes
|
72
72
|
the private key share of the source sent to the destination.
|
73
73
|
sk_share : bytes
|
74
74
|
the secret key share of the source sent to the destination.
|
75
75
|
"""
|
76
76
|
src, dst, mark = (
|
77
|
-
int.from_bytes(plaintext[:
|
78
|
-
int.from_bytes(plaintext[
|
79
|
-
int.from_bytes(plaintext[
|
77
|
+
int.from_bytes(plaintext[:8], "little", signed=True),
|
78
|
+
int.from_bytes(plaintext[8:16], "little", signed=True),
|
79
|
+
int.from_bytes(plaintext[16:20], "little"),
|
80
80
|
)
|
81
|
-
ret = (src, dst, plaintext[
|
81
|
+
ret = (src, dst, plaintext[20 : 20 + mark], plaintext[20 + mark :])
|
82
82
|
return ret
|
83
83
|
|
84
84
|
|
flwr/server/compat/app.py
CHANGED
@@ -127,9 +127,10 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
127
127
|
)
|
128
128
|
log(
|
129
129
|
INFO,
|
130
|
-
"Starting Flower
|
130
|
+
"Starting Flower ServerApp, config: %s",
|
131
131
|
initialized_config,
|
132
132
|
)
|
133
|
+
log(INFO, "")
|
133
134
|
|
134
135
|
# Start the thread updating nodes
|
135
136
|
thread, f_stop = start_update_client_manager_thread(
|
@@ -15,7 +15,7 @@
|
|
15
15
|
"""Flower driver service client."""
|
16
16
|
|
17
17
|
|
18
|
-
from logging import
|
18
|
+
from logging import DEBUG, ERROR, WARNING
|
19
19
|
from typing import Optional
|
20
20
|
|
21
21
|
import grpc
|
@@ -70,19 +70,19 @@ class GrpcDriver:
|
|
70
70
|
root_certificates=self.root_certificates,
|
71
71
|
)
|
72
72
|
self.stub = DriverStub(self.channel)
|
73
|
-
log(
|
73
|
+
log(DEBUG, "[Driver] Connected to %s", self.driver_service_address)
|
74
74
|
|
75
75
|
def disconnect(self) -> None:
|
76
76
|
"""Disconnect from the Driver API."""
|
77
77
|
event(EventType.DRIVER_DISCONNECT)
|
78
78
|
if self.channel is None or self.stub is None:
|
79
|
-
log(
|
79
|
+
log(DEBUG, "Already disconnected")
|
80
80
|
return
|
81
81
|
channel = self.channel
|
82
82
|
self.channel = None
|
83
83
|
self.stub = None
|
84
84
|
channel.close()
|
85
|
-
log(
|
85
|
+
log(DEBUG, "[Driver] Disconnected")
|
86
86
|
|
87
87
|
def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
|
88
88
|
"""Request for run ID."""
|
flwr/server/history.py
CHANGED
@@ -15,6 +15,7 @@
|
|
15
15
|
"""Training history."""
|
16
16
|
|
17
17
|
|
18
|
+
import pprint
|
18
19
|
from functools import reduce
|
19
20
|
from typing import Dict, List, Tuple
|
20
21
|
|
@@ -90,29 +91,35 @@ class History:
|
|
90
91
|
"""
|
91
92
|
rep = ""
|
92
93
|
if self.losses_distributed:
|
93
|
-
rep += "History (loss, distributed):\n" +
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
94
|
+
rep += "History (loss, distributed):\n" + pprint.pformat(
|
95
|
+
reduce(
|
96
|
+
lambda a, b: a + b,
|
97
|
+
[
|
98
|
+
f"\tround {server_round}: {loss}\n"
|
99
|
+
for server_round, loss in self.losses_distributed
|
100
|
+
],
|
101
|
+
)
|
99
102
|
)
|
100
103
|
if self.losses_centralized:
|
101
|
-
rep += "History (loss, centralized):\n" +
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
104
|
+
rep += "History (loss, centralized):\n" + pprint.pformat(
|
105
|
+
reduce(
|
106
|
+
lambda a, b: a + b,
|
107
|
+
[
|
108
|
+
f"\tround {server_round}: {loss}\n"
|
109
|
+
for server_round, loss in self.losses_centralized
|
110
|
+
],
|
111
|
+
)
|
107
112
|
)
|
108
113
|
if self.metrics_distributed_fit:
|
109
|
-
rep += "History (metrics, distributed, fit):\n" +
|
114
|
+
rep += "History (metrics, distributed, fit):\n" + pprint.pformat(
|
110
115
|
self.metrics_distributed_fit
|
111
116
|
)
|
112
117
|
if self.metrics_distributed:
|
113
|
-
rep += "History (metrics, distributed, evaluate):\n" +
|
118
|
+
rep += "History (metrics, distributed, evaluate):\n" + pprint.pformat(
|
114
119
|
self.metrics_distributed
|
115
120
|
)
|
116
121
|
if self.metrics_centralized:
|
117
|
-
rep += "History (metrics, centralized):\n" +
|
122
|
+
rep += "History (metrics, centralized):\n" + pprint.pformat(
|
123
|
+
self.metrics_centralized
|
124
|
+
)
|
118
125
|
return rep
|
flwr/server/run_serverapp.py
CHANGED
@@ -17,15 +17,16 @@
|
|
17
17
|
|
18
18
|
import argparse
|
19
19
|
import sys
|
20
|
-
from logging import DEBUG, WARN
|
20
|
+
from logging import DEBUG, INFO, WARN
|
21
21
|
from pathlib import Path
|
22
22
|
from typing import Optional
|
23
23
|
|
24
24
|
from flwr.common import Context, EventType, RecordSet, event
|
25
|
-
from flwr.common.logger import log
|
25
|
+
from flwr.common.logger import log, update_console_handler
|
26
|
+
from flwr.common.object_ref import load_app
|
26
27
|
|
27
28
|
from .driver.driver import Driver
|
28
|
-
from .server_app import
|
29
|
+
from .server_app import LoadServerAppError, ServerApp
|
29
30
|
|
30
31
|
|
31
32
|
def run(
|
@@ -47,7 +48,13 @@ def run(
|
|
47
48
|
# Load ServerApp if needed
|
48
49
|
def _load() -> ServerApp:
|
49
50
|
if server_app_attr:
|
50
|
-
server_app: ServerApp =
|
51
|
+
server_app: ServerApp = load_app(server_app_attr, LoadServerAppError)
|
52
|
+
|
53
|
+
if not isinstance(server_app, ServerApp):
|
54
|
+
raise LoadServerAppError(
|
55
|
+
f"Attribute {server_app_attr} is not of type {ServerApp}",
|
56
|
+
) from None
|
57
|
+
|
51
58
|
if loaded_server_app:
|
52
59
|
server_app = loaded_server_app
|
53
60
|
return server_app
|
@@ -69,6 +76,12 @@ def run_server_app() -> None:
|
|
69
76
|
|
70
77
|
args = _parse_args_run_server_app().parse_args()
|
71
78
|
|
79
|
+
update_console_handler(
|
80
|
+
level=DEBUG if args.verbose else INFO,
|
81
|
+
timestamps=args.verbose,
|
82
|
+
colored=True,
|
83
|
+
)
|
84
|
+
|
72
85
|
# Obtain certificates
|
73
86
|
if args.insecure:
|
74
87
|
if args.root_certificates is not None:
|
@@ -146,6 +159,11 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
|
|
146
159
|
help="Run the server app without HTTPS. By default, the app runs with "
|
147
160
|
"HTTPS enabled. Use this flag only if you understand the risks.",
|
148
161
|
)
|
162
|
+
parser.add_argument(
|
163
|
+
"--verbose",
|
164
|
+
action="store_true",
|
165
|
+
help="Set the logging to `DEBUG`.",
|
166
|
+
)
|
149
167
|
parser.add_argument(
|
150
168
|
"--root-certificates",
|
151
169
|
metavar="ROOT_CERT",
|
flwr/server/server.py
CHANGED
@@ -16,6 +16,7 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
import concurrent.futures
|
19
|
+
import io
|
19
20
|
import timeit
|
20
21
|
from logging import INFO, WARN
|
21
22
|
from typing import Dict, List, Optional, Tuple, Union
|
@@ -83,14 +84,14 @@ class Server:
|
|
83
84
|
return self._client_manager
|
84
85
|
|
85
86
|
# pylint: disable=too-many-locals
|
86
|
-
def fit(self, num_rounds: int, timeout: Optional[float]) -> History:
|
87
|
+
def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]:
|
87
88
|
"""Run federated averaging for a number of rounds."""
|
88
89
|
history = History()
|
89
90
|
|
90
91
|
# Initialize parameters
|
91
|
-
log(INFO, "
|
92
|
+
log(INFO, "[INIT]")
|
92
93
|
self.parameters = self._get_initial_parameters(server_round=0, timeout=timeout)
|
93
|
-
log(INFO, "Evaluating initial parameters")
|
94
|
+
log(INFO, "Evaluating initial global parameters")
|
94
95
|
res = self.strategy.evaluate(0, parameters=self.parameters)
|
95
96
|
if res is not None:
|
96
97
|
log(
|
@@ -103,10 +104,11 @@ class Server:
|
|
103
104
|
history.add_metrics_centralized(server_round=0, metrics=res[1])
|
104
105
|
|
105
106
|
# Run federated learning for num_rounds
|
106
|
-
log(INFO, "FL starting")
|
107
107
|
start_time = timeit.default_timer()
|
108
108
|
|
109
109
|
for current_round in range(1, num_rounds + 1):
|
110
|
+
log(INFO, "")
|
111
|
+
log(INFO, "[ROUND %s]", current_round)
|
110
112
|
# Train model and replace previous global model
|
111
113
|
res_fit = self.fit_round(
|
112
114
|
server_round=current_round,
|
@@ -152,8 +154,7 @@ class Server:
|
|
152
154
|
# Bookkeeping
|
153
155
|
end_time = timeit.default_timer()
|
154
156
|
elapsed = end_time - start_time
|
155
|
-
|
156
|
-
return history
|
157
|
+
return history, elapsed
|
157
158
|
|
158
159
|
def evaluate_round(
|
159
160
|
self,
|
@@ -170,12 +171,11 @@ class Server:
|
|
170
171
|
client_manager=self._client_manager,
|
171
172
|
)
|
172
173
|
if not client_instructions:
|
173
|
-
log(INFO, "
|
174
|
+
log(INFO, "configure_evaluate: no clients selected, skipping evaluation")
|
174
175
|
return None
|
175
176
|
log(
|
176
177
|
INFO,
|
177
|
-
"
|
178
|
-
server_round,
|
178
|
+
"configure_evaluate: strategy sampled %s clients (out of %s)",
|
179
179
|
len(client_instructions),
|
180
180
|
self._client_manager.num_available(),
|
181
181
|
)
|
@@ -189,8 +189,7 @@ class Server:
|
|
189
189
|
)
|
190
190
|
log(
|
191
191
|
INFO,
|
192
|
-
"
|
193
|
-
server_round,
|
192
|
+
"aggregate_evaluate: received %s results and %s failures",
|
194
193
|
len(results),
|
195
194
|
len(failures),
|
196
195
|
)
|
@@ -220,12 +219,11 @@ class Server:
|
|
220
219
|
)
|
221
220
|
|
222
221
|
if not client_instructions:
|
223
|
-
log(INFO, "
|
222
|
+
log(INFO, "configure_fit: no clients selected, cancel")
|
224
223
|
return None
|
225
224
|
log(
|
226
225
|
INFO,
|
227
|
-
"
|
228
|
-
server_round,
|
226
|
+
"configure_fit: strategy sampled %s clients (out of %s)",
|
229
227
|
len(client_instructions),
|
230
228
|
self._client_manager.num_available(),
|
231
229
|
)
|
@@ -239,8 +237,7 @@ class Server:
|
|
239
237
|
)
|
240
238
|
log(
|
241
239
|
INFO,
|
242
|
-
"
|
243
|
-
server_round,
|
240
|
+
"aggregate_fit: received %s results and %s failures",
|
244
241
|
len(results),
|
245
242
|
len(failures),
|
246
243
|
)
|
@@ -275,7 +272,7 @@ class Server:
|
|
275
272
|
client_manager=self._client_manager
|
276
273
|
)
|
277
274
|
if parameters is not None:
|
278
|
-
log(INFO, "Using initial parameters provided by strategy")
|
275
|
+
log(INFO, "Using initial global parameters provided by strategy")
|
279
276
|
return parameters
|
280
277
|
|
281
278
|
# Get initial parameters from one of the clients
|
@@ -483,12 +480,19 @@ def run_fl(
|
|
483
480
|
config: ServerConfig,
|
484
481
|
) -> History:
|
485
482
|
"""Train a model on the given server and return the History object."""
|
486
|
-
hist = server.fit(
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
log(INFO, "
|
491
|
-
log(INFO, "
|
483
|
+
hist, elapsed_time = server.fit(
|
484
|
+
num_rounds=config.num_rounds, timeout=config.round_timeout
|
485
|
+
)
|
486
|
+
|
487
|
+
log(INFO, "")
|
488
|
+
log(INFO, "[SUMMARY]")
|
489
|
+
log(INFO, "Run finished %s rounds in %.2fs", config.num_rounds, elapsed_time)
|
490
|
+
for idx, line in enumerate(io.StringIO(str(hist))):
|
491
|
+
if idx == 0:
|
492
|
+
log(INFO, "%s", line.strip("\n"))
|
493
|
+
else:
|
494
|
+
log(INFO, "\t%s", line.strip("\n"))
|
495
|
+
log(INFO, "")
|
492
496
|
|
493
497
|
# Graceful shutdown
|
494
498
|
server.disconnect_all_clients(timeout=config.round_timeout)
|
flwr/server/server_app.py
CHANGED
@@ -15,8 +15,7 @@
|
|
15
15
|
"""Flower ServerApp."""
|
16
16
|
|
17
17
|
|
18
|
-
import
|
19
|
-
from typing import Callable, Optional, cast
|
18
|
+
from typing import Callable, Optional
|
20
19
|
|
21
20
|
from flwr.common import Context, RecordSet
|
22
21
|
from flwr.server.strategy import Strategy
|
@@ -132,48 +131,3 @@ class ServerApp:
|
|
132
131
|
|
133
132
|
class LoadServerAppError(Exception):
|
134
133
|
"""Error when trying to load `ServerApp`."""
|
135
|
-
|
136
|
-
|
137
|
-
def load_server_app(module_attribute_str: str) -> ServerApp:
|
138
|
-
"""Load the `ServerApp` object specified in a module attribute string.
|
139
|
-
|
140
|
-
The module/attribute string should have the form <module>:<attribute>. Valid
|
141
|
-
examples include `server:app` and `project.package.module:wrapper.app`. It
|
142
|
-
must refer to a module on the PYTHONPATH, the module needs to have the specified
|
143
|
-
attribute, and the attribute must be of type `ServerApp`.
|
144
|
-
"""
|
145
|
-
module_str, _, attributes_str = module_attribute_str.partition(":")
|
146
|
-
if not module_str:
|
147
|
-
raise LoadServerAppError(
|
148
|
-
f"Missing module in {module_attribute_str}",
|
149
|
-
) from None
|
150
|
-
if not attributes_str:
|
151
|
-
raise LoadServerAppError(
|
152
|
-
f"Missing attribute in {module_attribute_str}",
|
153
|
-
) from None
|
154
|
-
|
155
|
-
# Load module
|
156
|
-
try:
|
157
|
-
module = importlib.import_module(module_str)
|
158
|
-
except ModuleNotFoundError:
|
159
|
-
raise LoadServerAppError(
|
160
|
-
f"Unable to load module {module_str}",
|
161
|
-
) from None
|
162
|
-
|
163
|
-
# Recursively load attribute
|
164
|
-
attribute = module
|
165
|
-
try:
|
166
|
-
for attribute_str in attributes_str.split("."):
|
167
|
-
attribute = getattr(attribute, attribute_str)
|
168
|
-
except AttributeError:
|
169
|
-
raise LoadServerAppError(
|
170
|
-
f"Unable to load attribute {attributes_str} from module {module_str}",
|
171
|
-
) from None
|
172
|
-
|
173
|
-
# Check type
|
174
|
-
if not isinstance(attribute, ServerApp):
|
175
|
-
raise LoadServerAppError(
|
176
|
-
f"Attribute {attributes_str} is not of type {ServerApp}",
|
177
|
-
) from None
|
178
|
-
|
179
|
-
return cast(ServerApp, attribute)
|
flwr/server/server_config.py
CHANGED
@@ -29,3 +29,12 @@ class ServerConfig:
|
|
29
29
|
|
30
30
|
num_rounds: int = 1
|
31
31
|
round_timeout: Optional[float] = None
|
32
|
+
|
33
|
+
def __repr__(self) -> str:
|
34
|
+
"""Return the string representation of the ServerConfig."""
|
35
|
+
timeout_string = (
|
36
|
+
"no round_timeout"
|
37
|
+
if self.round_timeout is None
|
38
|
+
else f"round_timeout={self.round_timeout}s"
|
39
|
+
)
|
40
|
+
return f"num_rounds={self.num_rounds}, {timeout_string}"
|
flwr/server/strategy/fedavg.py
CHANGED
@@ -84,6 +84,8 @@ class FedAvg(Strategy):
|
|
84
84
|
Metrics aggregation function, optional.
|
85
85
|
evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn]
|
86
86
|
Metrics aggregation function, optional.
|
87
|
+
inplace : bool (default: True)
|
88
|
+
Enable (True) or disable (False) in-place aggregation of model updates.
|
87
89
|
"""
|
88
90
|
|
89
91
|
# pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
|
@@ -21,9 +21,10 @@ import traceback
|
|
21
21
|
from logging import DEBUG, ERROR, INFO, WARN
|
22
22
|
from typing import Callable, Dict, List, Optional
|
23
23
|
|
24
|
-
from flwr.client.client_app import ClientApp, LoadClientAppError
|
24
|
+
from flwr.client.client_app import ClientApp, LoadClientAppError
|
25
25
|
from flwr.client.node_state import NodeState
|
26
26
|
from flwr.common.logger import log
|
27
|
+
from flwr.common.object_ref import load_app
|
27
28
|
from flwr.common.serde import message_from_taskins, message_to_taskres
|
28
29
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
29
30
|
from flwr.server.superlink.state import StateFactory
|
@@ -305,7 +306,13 @@ def start_vce(
|
|
305
306
|
def _load() -> ClientApp:
|
306
307
|
|
307
308
|
if client_app_attr:
|
308
|
-
app: ClientApp =
|
309
|
+
app: ClientApp = load_app(client_app_attr, LoadClientAppError)
|
310
|
+
|
311
|
+
if not isinstance(app, ClientApp):
|
312
|
+
raise LoadClientAppError(
|
313
|
+
f"Attribute {client_app_attr} is not of type {ClientApp}",
|
314
|
+
) from None
|
315
|
+
|
309
316
|
if client_app:
|
310
317
|
app = client_app
|
311
318
|
return app
|
@@ -122,7 +122,8 @@ class InMemoryState(State):
|
|
122
122
|
task_res.task_id = str(task_id)
|
123
123
|
task_res.task.created_at = created_at.isoformat()
|
124
124
|
task_res.task.ttl = ttl.isoformat()
|
125
|
-
self.
|
125
|
+
with self.lock:
|
126
|
+
self.task_res_store[task_id] = task_res
|
126
127
|
|
127
128
|
# Return the new task_id
|
128
129
|
return task_id
|
@@ -132,46 +133,47 @@ class InMemoryState(State):
|
|
132
133
|
if limit is not None and limit < 1:
|
133
134
|
raise AssertionError("`limit` must be >= 1")
|
134
135
|
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
136
|
+
with self.lock:
|
137
|
+
# Find TaskRes that were not delivered yet
|
138
|
+
task_res_list: List[TaskRes] = []
|
139
|
+
for _, task_res in self.task_res_store.items():
|
140
|
+
if (
|
141
|
+
UUID(task_res.task.ancestry[0]) in task_ids
|
142
|
+
and task_res.task.delivered_at == ""
|
143
|
+
):
|
144
|
+
task_res_list.append(task_res)
|
145
|
+
if limit and len(task_res_list) == limit:
|
146
|
+
break
|
145
147
|
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
148
|
+
# Mark all of them as delivered
|
149
|
+
delivered_at = now().isoformat()
|
150
|
+
for task_res in task_res_list:
|
151
|
+
task_res.task.delivered_at = delivered_at
|
150
152
|
|
151
|
-
|
152
|
-
|
153
|
+
# Return TaskRes
|
154
|
+
return task_res_list
|
153
155
|
|
154
156
|
def delete_tasks(self, task_ids: Set[UUID]) -> None:
|
155
157
|
"""Delete all delivered TaskIns/TaskRes pairs."""
|
156
158
|
task_ins_to_be_deleted: Set[UUID] = set()
|
157
159
|
task_res_to_be_deleted: Set[UUID] = set()
|
158
160
|
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
161
|
+
with self.lock:
|
162
|
+
for task_ins_id in task_ids:
|
163
|
+
# Find the task_id of the matching task_res
|
164
|
+
for task_res_id, task_res in self.task_res_store.items():
|
165
|
+
if UUID(task_res.task.ancestry[0]) != task_ins_id:
|
166
|
+
continue
|
167
|
+
if task_res.task.delivered_at == "":
|
168
|
+
continue
|
169
|
+
|
170
|
+
task_ins_to_be_deleted.add(task_ins_id)
|
171
|
+
task_res_to_be_deleted.add(task_res_id)
|
172
|
+
|
173
|
+
for task_id in task_ins_to_be_deleted:
|
172
174
|
del self.task_ins_store[task_id]
|
173
|
-
|
174
|
-
|
175
|
+
for task_id in task_res_to_be_deleted:
|
176
|
+
del self.task_res_store[task_id]
|
175
177
|
|
176
178
|
def num_task_ins(self) -> int:
|
177
179
|
"""Calculate the number of task_ins in store.
|
flwr/server/workflow/__init__.py
CHANGED
@@ -0,0 +1,32 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Constants for default workflows."""
|
16
|
+
|
17
|
+
|
18
|
+
from __future__ import annotations
|
19
|
+
|
20
|
+
MAIN_CONFIGS_RECORD = "config"
|
21
|
+
MAIN_PARAMS_RECORD = "parameters"
|
22
|
+
|
23
|
+
|
24
|
+
class Key:
|
25
|
+
"""Constants for default workflows."""
|
26
|
+
|
27
|
+
CURRENT_ROUND = "current_round"
|
28
|
+
START_TIME = "start_time"
|
29
|
+
|
30
|
+
def __new__(cls) -> Key:
|
31
|
+
"""Prevent instantiation."""
|
32
|
+
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|