flwr-nightly 1.8.0.dev20240309__py3-none-any.whl → 1.8.0.dev20240311__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- flwr/cli/flower_toml.py +4 -48
- flwr/cli/new/new.py +6 -3
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -3
- flwr/cli/new/templates/app/pyproject.toml.tpl +1 -1
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +2 -2
- flwr/cli/utils.py +14 -1
- flwr/client/app.py +39 -5
- flwr/client/client_app.py +1 -47
- flwr/client/mod/__init__.py +2 -1
- flwr/client/mod/secure_aggregation/__init__.py +2 -0
- flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +73 -57
- flwr/common/grpc.py +3 -3
- flwr/common/logger.py +78 -15
- flwr/common/object_ref.py +140 -0
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -5
- flwr/common/secure_aggregation/secaggplus_constants.py +7 -6
- flwr/common/secure_aggregation/secaggplus_utils.py +15 -15
- flwr/server/compat/app.py +2 -1
- flwr/server/driver/grpc_driver.py +4 -4
- flwr/server/history.py +22 -15
- flwr/server/run_serverapp.py +22 -4
- flwr/server/server.py +27 -23
- flwr/server/server_app.py +1 -47
- flwr/server/server_config.py +9 -0
- flwr/server/strategy/fedavg.py +2 -0
- flwr/server/superlink/fleet/vce/vce_api.py +9 -2
- flwr/server/superlink/state/in_memory_state.py +34 -32
- flwr/server/workflow/__init__.py +3 -0
- flwr/server/workflow/constant.py +32 -0
- flwr/server/workflow/default_workflows.py +52 -57
- flwr/server/workflow/secure_aggregation/__init__.py +24 -0
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +676 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/RECORD +39 -33
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/entry_points.txt +0 -0
@@ -19,6 +19,7 @@ from __future__ import annotations
|
|
19
19
|
|
20
20
|
RECORD_KEY_STATE = "secaggplus_state"
|
21
21
|
RECORD_KEY_CONFIGS = "secaggplus_configs"
|
22
|
+
RATIO_QUANTIZATION_RANGE = 1073741824 # 1 << 30
|
22
23
|
|
23
24
|
|
24
25
|
class Stage:
|
@@ -26,9 +27,9 @@ class Stage:
|
|
26
27
|
|
27
28
|
SETUP = "setup"
|
28
29
|
SHARE_KEYS = "share_keys"
|
29
|
-
|
30
|
+
COLLECT_MASKED_VECTORS = "collect_masked_vectors"
|
30
31
|
UNMASK = "unmask"
|
31
|
-
_stages = (SETUP, SHARE_KEYS,
|
32
|
+
_stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_VECTORS, UNMASK)
|
32
33
|
|
33
34
|
@classmethod
|
34
35
|
def all(cls) -> tuple[str, str, str, str]:
|
@@ -45,12 +46,12 @@ class Key:
|
|
45
46
|
|
46
47
|
STAGE = "stage"
|
47
48
|
SAMPLE_NUMBER = "sample_num"
|
48
|
-
SECURE_ID = "secure_id"
|
49
49
|
SHARE_NUMBER = "share_num"
|
50
50
|
THRESHOLD = "threshold"
|
51
51
|
CLIPPING_RANGE = "clipping_range"
|
52
52
|
TARGET_RANGE = "target_range"
|
53
53
|
MOD_RANGE = "mod_range"
|
54
|
+
MAX_WEIGHT = "max_weight"
|
54
55
|
PUBLIC_KEY_1 = "pk1"
|
55
56
|
PUBLIC_KEY_2 = "pk2"
|
56
57
|
DESTINATION_LIST = "dsts"
|
@@ -58,9 +59,9 @@ class Key:
|
|
58
59
|
SOURCE_LIST = "srcs"
|
59
60
|
PARAMETERS = "params"
|
60
61
|
MASKED_PARAMETERS = "masked_params"
|
61
|
-
|
62
|
-
|
63
|
-
|
62
|
+
ACTIVE_NODE_ID_LIST = "active_nids"
|
63
|
+
DEAD_NODE_ID_LIST = "dead_nids"
|
64
|
+
NODE_ID_LIST = "nids"
|
64
65
|
SHARE_LIST = "shares"
|
65
66
|
|
66
67
|
def __new__(cls) -> Key:
|
@@ -23,16 +23,16 @@ from flwr.common.typing import NDArrayInt
|
|
23
23
|
|
24
24
|
|
25
25
|
def share_keys_plaintext_concat(
|
26
|
-
|
26
|
+
src_node_id: int, dst_node_id: int, b_share: bytes, sk_share: bytes
|
27
27
|
) -> bytes:
|
28
28
|
"""Combine arguments to bytes.
|
29
29
|
|
30
30
|
Parameters
|
31
31
|
----------
|
32
|
-
|
33
|
-
the
|
34
|
-
|
35
|
-
the
|
32
|
+
src_node_id : int
|
33
|
+
the node ID of the source.
|
34
|
+
dst_node_id : int
|
35
|
+
the node ID of the destination.
|
36
36
|
b_share : bytes
|
37
37
|
the private key share of the source sent to the destination.
|
38
38
|
sk_share : bytes
|
@@ -45,8 +45,8 @@ def share_keys_plaintext_concat(
|
|
45
45
|
"""
|
46
46
|
return b"".join(
|
47
47
|
[
|
48
|
-
int.to_bytes(
|
49
|
-
int.to_bytes(
|
48
|
+
int.to_bytes(src_node_id, 8, "little", signed=True),
|
49
|
+
int.to_bytes(dst_node_id, 8, "little", signed=True),
|
50
50
|
int.to_bytes(len(b_share), 4, "little"),
|
51
51
|
b_share,
|
52
52
|
sk_share,
|
@@ -64,21 +64,21 @@ def share_keys_plaintext_separate(plaintext: bytes) -> Tuple[int, int, bytes, by
|
|
64
64
|
|
65
65
|
Returns
|
66
66
|
-------
|
67
|
-
|
68
|
-
the
|
69
|
-
|
70
|
-
the
|
67
|
+
src_node_id : int
|
68
|
+
the node ID of the source.
|
69
|
+
dst_node_id : int
|
70
|
+
the node ID of the destination.
|
71
71
|
b_share : bytes
|
72
72
|
the private key share of the source sent to the destination.
|
73
73
|
sk_share : bytes
|
74
74
|
the secret key share of the source sent to the destination.
|
75
75
|
"""
|
76
76
|
src, dst, mark = (
|
77
|
-
int.from_bytes(plaintext[:
|
78
|
-
int.from_bytes(plaintext[
|
79
|
-
int.from_bytes(plaintext[
|
77
|
+
int.from_bytes(plaintext[:8], "little", signed=True),
|
78
|
+
int.from_bytes(plaintext[8:16], "little", signed=True),
|
79
|
+
int.from_bytes(plaintext[16:20], "little"),
|
80
80
|
)
|
81
|
-
ret = (src, dst, plaintext[
|
81
|
+
ret = (src, dst, plaintext[20 : 20 + mark], plaintext[20 + mark :])
|
82
82
|
return ret
|
83
83
|
|
84
84
|
|
flwr/server/compat/app.py
CHANGED
@@ -127,9 +127,10 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
127
127
|
)
|
128
128
|
log(
|
129
129
|
INFO,
|
130
|
-
"Starting Flower
|
130
|
+
"Starting Flower ServerApp, config: %s",
|
131
131
|
initialized_config,
|
132
132
|
)
|
133
|
+
log(INFO, "")
|
133
134
|
|
134
135
|
# Start the thread updating nodes
|
135
136
|
thread, f_stop = start_update_client_manager_thread(
|
@@ -15,7 +15,7 @@
|
|
15
15
|
"""Flower driver service client."""
|
16
16
|
|
17
17
|
|
18
|
-
from logging import
|
18
|
+
from logging import DEBUG, ERROR, WARNING
|
19
19
|
from typing import Optional
|
20
20
|
|
21
21
|
import grpc
|
@@ -70,19 +70,19 @@ class GrpcDriver:
|
|
70
70
|
root_certificates=self.root_certificates,
|
71
71
|
)
|
72
72
|
self.stub = DriverStub(self.channel)
|
73
|
-
log(
|
73
|
+
log(DEBUG, "[Driver] Connected to %s", self.driver_service_address)
|
74
74
|
|
75
75
|
def disconnect(self) -> None:
|
76
76
|
"""Disconnect from the Driver API."""
|
77
77
|
event(EventType.DRIVER_DISCONNECT)
|
78
78
|
if self.channel is None or self.stub is None:
|
79
|
-
log(
|
79
|
+
log(DEBUG, "Already disconnected")
|
80
80
|
return
|
81
81
|
channel = self.channel
|
82
82
|
self.channel = None
|
83
83
|
self.stub = None
|
84
84
|
channel.close()
|
85
|
-
log(
|
85
|
+
log(DEBUG, "[Driver] Disconnected")
|
86
86
|
|
87
87
|
def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
|
88
88
|
"""Request for run ID."""
|
flwr/server/history.py
CHANGED
@@ -15,6 +15,7 @@
|
|
15
15
|
"""Training history."""
|
16
16
|
|
17
17
|
|
18
|
+
import pprint
|
18
19
|
from functools import reduce
|
19
20
|
from typing import Dict, List, Tuple
|
20
21
|
|
@@ -90,29 +91,35 @@ class History:
|
|
90
91
|
"""
|
91
92
|
rep = ""
|
92
93
|
if self.losses_distributed:
|
93
|
-
rep += "History (loss, distributed):\n" +
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
94
|
+
rep += "History (loss, distributed):\n" + pprint.pformat(
|
95
|
+
reduce(
|
96
|
+
lambda a, b: a + b,
|
97
|
+
[
|
98
|
+
f"\tround {server_round}: {loss}\n"
|
99
|
+
for server_round, loss in self.losses_distributed
|
100
|
+
],
|
101
|
+
)
|
99
102
|
)
|
100
103
|
if self.losses_centralized:
|
101
|
-
rep += "History (loss, centralized):\n" +
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
104
|
+
rep += "History (loss, centralized):\n" + pprint.pformat(
|
105
|
+
reduce(
|
106
|
+
lambda a, b: a + b,
|
107
|
+
[
|
108
|
+
f"\tround {server_round}: {loss}\n"
|
109
|
+
for server_round, loss in self.losses_centralized
|
110
|
+
],
|
111
|
+
)
|
107
112
|
)
|
108
113
|
if self.metrics_distributed_fit:
|
109
|
-
rep += "History (metrics, distributed, fit):\n" +
|
114
|
+
rep += "History (metrics, distributed, fit):\n" + pprint.pformat(
|
110
115
|
self.metrics_distributed_fit
|
111
116
|
)
|
112
117
|
if self.metrics_distributed:
|
113
|
-
rep += "History (metrics, distributed, evaluate):\n" +
|
118
|
+
rep += "History (metrics, distributed, evaluate):\n" + pprint.pformat(
|
114
119
|
self.metrics_distributed
|
115
120
|
)
|
116
121
|
if self.metrics_centralized:
|
117
|
-
rep += "History (metrics, centralized):\n" +
|
122
|
+
rep += "History (metrics, centralized):\n" + pprint.pformat(
|
123
|
+
self.metrics_centralized
|
124
|
+
)
|
118
125
|
return rep
|
flwr/server/run_serverapp.py
CHANGED
@@ -17,15 +17,16 @@
|
|
17
17
|
|
18
18
|
import argparse
|
19
19
|
import sys
|
20
|
-
from logging import DEBUG, WARN
|
20
|
+
from logging import DEBUG, INFO, WARN
|
21
21
|
from pathlib import Path
|
22
22
|
from typing import Optional
|
23
23
|
|
24
24
|
from flwr.common import Context, EventType, RecordSet, event
|
25
|
-
from flwr.common.logger import log
|
25
|
+
from flwr.common.logger import log, update_console_handler
|
26
|
+
from flwr.common.object_ref import load_app
|
26
27
|
|
27
28
|
from .driver.driver import Driver
|
28
|
-
from .server_app import
|
29
|
+
from .server_app import LoadServerAppError, ServerApp
|
29
30
|
|
30
31
|
|
31
32
|
def run(
|
@@ -47,7 +48,13 @@ def run(
|
|
47
48
|
# Load ServerApp if needed
|
48
49
|
def _load() -> ServerApp:
|
49
50
|
if server_app_attr:
|
50
|
-
server_app: ServerApp =
|
51
|
+
server_app: ServerApp = load_app(server_app_attr, LoadServerAppError)
|
52
|
+
|
53
|
+
if not isinstance(server_app, ServerApp):
|
54
|
+
raise LoadServerAppError(
|
55
|
+
f"Attribute {server_app_attr} is not of type {ServerApp}",
|
56
|
+
) from None
|
57
|
+
|
51
58
|
if loaded_server_app:
|
52
59
|
server_app = loaded_server_app
|
53
60
|
return server_app
|
@@ -69,6 +76,12 @@ def run_server_app() -> None:
|
|
69
76
|
|
70
77
|
args = _parse_args_run_server_app().parse_args()
|
71
78
|
|
79
|
+
update_console_handler(
|
80
|
+
level=DEBUG if args.verbose else INFO,
|
81
|
+
timestamps=args.verbose,
|
82
|
+
colored=True,
|
83
|
+
)
|
84
|
+
|
72
85
|
# Obtain certificates
|
73
86
|
if args.insecure:
|
74
87
|
if args.root_certificates is not None:
|
@@ -146,6 +159,11 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
|
|
146
159
|
help="Run the server app without HTTPS. By default, the app runs with "
|
147
160
|
"HTTPS enabled. Use this flag only if you understand the risks.",
|
148
161
|
)
|
162
|
+
parser.add_argument(
|
163
|
+
"--verbose",
|
164
|
+
action="store_true",
|
165
|
+
help="Set the logging to `DEBUG`.",
|
166
|
+
)
|
149
167
|
parser.add_argument(
|
150
168
|
"--root-certificates",
|
151
169
|
metavar="ROOT_CERT",
|
flwr/server/server.py
CHANGED
@@ -16,6 +16,7 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
import concurrent.futures
|
19
|
+
import io
|
19
20
|
import timeit
|
20
21
|
from logging import INFO, WARN
|
21
22
|
from typing import Dict, List, Optional, Tuple, Union
|
@@ -83,14 +84,14 @@ class Server:
|
|
83
84
|
return self._client_manager
|
84
85
|
|
85
86
|
# pylint: disable=too-many-locals
|
86
|
-
def fit(self, num_rounds: int, timeout: Optional[float]) -> History:
|
87
|
+
def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]:
|
87
88
|
"""Run federated averaging for a number of rounds."""
|
88
89
|
history = History()
|
89
90
|
|
90
91
|
# Initialize parameters
|
91
|
-
log(INFO, "
|
92
|
+
log(INFO, "[INIT]")
|
92
93
|
self.parameters = self._get_initial_parameters(server_round=0, timeout=timeout)
|
93
|
-
log(INFO, "Evaluating initial parameters")
|
94
|
+
log(INFO, "Evaluating initial global parameters")
|
94
95
|
res = self.strategy.evaluate(0, parameters=self.parameters)
|
95
96
|
if res is not None:
|
96
97
|
log(
|
@@ -103,10 +104,11 @@ class Server:
|
|
103
104
|
history.add_metrics_centralized(server_round=0, metrics=res[1])
|
104
105
|
|
105
106
|
# Run federated learning for num_rounds
|
106
|
-
log(INFO, "FL starting")
|
107
107
|
start_time = timeit.default_timer()
|
108
108
|
|
109
109
|
for current_round in range(1, num_rounds + 1):
|
110
|
+
log(INFO, "")
|
111
|
+
log(INFO, "[ROUND %s]", current_round)
|
110
112
|
# Train model and replace previous global model
|
111
113
|
res_fit = self.fit_round(
|
112
114
|
server_round=current_round,
|
@@ -152,8 +154,7 @@ class Server:
|
|
152
154
|
# Bookkeeping
|
153
155
|
end_time = timeit.default_timer()
|
154
156
|
elapsed = end_time - start_time
|
155
|
-
|
156
|
-
return history
|
157
|
+
return history, elapsed
|
157
158
|
|
158
159
|
def evaluate_round(
|
159
160
|
self,
|
@@ -170,12 +171,11 @@ class Server:
|
|
170
171
|
client_manager=self._client_manager,
|
171
172
|
)
|
172
173
|
if not client_instructions:
|
173
|
-
log(INFO, "
|
174
|
+
log(INFO, "configure_evaluate: no clients selected, skipping evaluation")
|
174
175
|
return None
|
175
176
|
log(
|
176
177
|
INFO,
|
177
|
-
"
|
178
|
-
server_round,
|
178
|
+
"configure_evaluate: strategy sampled %s clients (out of %s)",
|
179
179
|
len(client_instructions),
|
180
180
|
self._client_manager.num_available(),
|
181
181
|
)
|
@@ -189,8 +189,7 @@ class Server:
|
|
189
189
|
)
|
190
190
|
log(
|
191
191
|
INFO,
|
192
|
-
"
|
193
|
-
server_round,
|
192
|
+
"aggregate_evaluate: received %s results and %s failures",
|
194
193
|
len(results),
|
195
194
|
len(failures),
|
196
195
|
)
|
@@ -220,12 +219,11 @@ class Server:
|
|
220
219
|
)
|
221
220
|
|
222
221
|
if not client_instructions:
|
223
|
-
log(INFO, "
|
222
|
+
log(INFO, "configure_fit: no clients selected, cancel")
|
224
223
|
return None
|
225
224
|
log(
|
226
225
|
INFO,
|
227
|
-
"
|
228
|
-
server_round,
|
226
|
+
"configure_fit: strategy sampled %s clients (out of %s)",
|
229
227
|
len(client_instructions),
|
230
228
|
self._client_manager.num_available(),
|
231
229
|
)
|
@@ -239,8 +237,7 @@ class Server:
|
|
239
237
|
)
|
240
238
|
log(
|
241
239
|
INFO,
|
242
|
-
"
|
243
|
-
server_round,
|
240
|
+
"aggregate_fit: received %s results and %s failures",
|
244
241
|
len(results),
|
245
242
|
len(failures),
|
246
243
|
)
|
@@ -275,7 +272,7 @@ class Server:
|
|
275
272
|
client_manager=self._client_manager
|
276
273
|
)
|
277
274
|
if parameters is not None:
|
278
|
-
log(INFO, "Using initial parameters provided by strategy")
|
275
|
+
log(INFO, "Using initial global parameters provided by strategy")
|
279
276
|
return parameters
|
280
277
|
|
281
278
|
# Get initial parameters from one of the clients
|
@@ -483,12 +480,19 @@ def run_fl(
|
|
483
480
|
config: ServerConfig,
|
484
481
|
) -> History:
|
485
482
|
"""Train a model on the given server and return the History object."""
|
486
|
-
hist = server.fit(
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
log(INFO, "
|
491
|
-
log(INFO, "
|
483
|
+
hist, elapsed_time = server.fit(
|
484
|
+
num_rounds=config.num_rounds, timeout=config.round_timeout
|
485
|
+
)
|
486
|
+
|
487
|
+
log(INFO, "")
|
488
|
+
log(INFO, "[SUMMARY]")
|
489
|
+
log(INFO, "Run finished %s rounds in %.2fs", config.num_rounds, elapsed_time)
|
490
|
+
for idx, line in enumerate(io.StringIO(str(hist))):
|
491
|
+
if idx == 0:
|
492
|
+
log(INFO, "%s", line.strip("\n"))
|
493
|
+
else:
|
494
|
+
log(INFO, "\t%s", line.strip("\n"))
|
495
|
+
log(INFO, "")
|
492
496
|
|
493
497
|
# Graceful shutdown
|
494
498
|
server.disconnect_all_clients(timeout=config.round_timeout)
|
flwr/server/server_app.py
CHANGED
@@ -15,8 +15,7 @@
|
|
15
15
|
"""Flower ServerApp."""
|
16
16
|
|
17
17
|
|
18
|
-
import
|
19
|
-
from typing import Callable, Optional, cast
|
18
|
+
from typing import Callable, Optional
|
20
19
|
|
21
20
|
from flwr.common import Context, RecordSet
|
22
21
|
from flwr.server.strategy import Strategy
|
@@ -132,48 +131,3 @@ class ServerApp:
|
|
132
131
|
|
133
132
|
class LoadServerAppError(Exception):
|
134
133
|
"""Error when trying to load `ServerApp`."""
|
135
|
-
|
136
|
-
|
137
|
-
def load_server_app(module_attribute_str: str) -> ServerApp:
|
138
|
-
"""Load the `ServerApp` object specified in a module attribute string.
|
139
|
-
|
140
|
-
The module/attribute string should have the form <module>:<attribute>. Valid
|
141
|
-
examples include `server:app` and `project.package.module:wrapper.app`. It
|
142
|
-
must refer to a module on the PYTHONPATH, the module needs to have the specified
|
143
|
-
attribute, and the attribute must be of type `ServerApp`.
|
144
|
-
"""
|
145
|
-
module_str, _, attributes_str = module_attribute_str.partition(":")
|
146
|
-
if not module_str:
|
147
|
-
raise LoadServerAppError(
|
148
|
-
f"Missing module in {module_attribute_str}",
|
149
|
-
) from None
|
150
|
-
if not attributes_str:
|
151
|
-
raise LoadServerAppError(
|
152
|
-
f"Missing attribute in {module_attribute_str}",
|
153
|
-
) from None
|
154
|
-
|
155
|
-
# Load module
|
156
|
-
try:
|
157
|
-
module = importlib.import_module(module_str)
|
158
|
-
except ModuleNotFoundError:
|
159
|
-
raise LoadServerAppError(
|
160
|
-
f"Unable to load module {module_str}",
|
161
|
-
) from None
|
162
|
-
|
163
|
-
# Recursively load attribute
|
164
|
-
attribute = module
|
165
|
-
try:
|
166
|
-
for attribute_str in attributes_str.split("."):
|
167
|
-
attribute = getattr(attribute, attribute_str)
|
168
|
-
except AttributeError:
|
169
|
-
raise LoadServerAppError(
|
170
|
-
f"Unable to load attribute {attributes_str} from module {module_str}",
|
171
|
-
) from None
|
172
|
-
|
173
|
-
# Check type
|
174
|
-
if not isinstance(attribute, ServerApp):
|
175
|
-
raise LoadServerAppError(
|
176
|
-
f"Attribute {attributes_str} is not of type {ServerApp}",
|
177
|
-
) from None
|
178
|
-
|
179
|
-
return cast(ServerApp, attribute)
|
flwr/server/server_config.py
CHANGED
@@ -29,3 +29,12 @@ class ServerConfig:
|
|
29
29
|
|
30
30
|
num_rounds: int = 1
|
31
31
|
round_timeout: Optional[float] = None
|
32
|
+
|
33
|
+
def __repr__(self) -> str:
|
34
|
+
"""Return the string representation of the ServerConfig."""
|
35
|
+
timeout_string = (
|
36
|
+
"no round_timeout"
|
37
|
+
if self.round_timeout is None
|
38
|
+
else f"round_timeout={self.round_timeout}s"
|
39
|
+
)
|
40
|
+
return f"num_rounds={self.num_rounds}, {timeout_string}"
|
flwr/server/strategy/fedavg.py
CHANGED
@@ -84,6 +84,8 @@ class FedAvg(Strategy):
|
|
84
84
|
Metrics aggregation function, optional.
|
85
85
|
evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn]
|
86
86
|
Metrics aggregation function, optional.
|
87
|
+
inplace : bool (default: True)
|
88
|
+
Enable (True) or disable (False) in-place aggregation of model updates.
|
87
89
|
"""
|
88
90
|
|
89
91
|
# pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
|
@@ -21,9 +21,10 @@ import traceback
|
|
21
21
|
from logging import DEBUG, ERROR, INFO, WARN
|
22
22
|
from typing import Callable, Dict, List, Optional
|
23
23
|
|
24
|
-
from flwr.client.client_app import ClientApp, LoadClientAppError
|
24
|
+
from flwr.client.client_app import ClientApp, LoadClientAppError
|
25
25
|
from flwr.client.node_state import NodeState
|
26
26
|
from flwr.common.logger import log
|
27
|
+
from flwr.common.object_ref import load_app
|
27
28
|
from flwr.common.serde import message_from_taskins, message_to_taskres
|
28
29
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
29
30
|
from flwr.server.superlink.state import StateFactory
|
@@ -305,7 +306,13 @@ def start_vce(
|
|
305
306
|
def _load() -> ClientApp:
|
306
307
|
|
307
308
|
if client_app_attr:
|
308
|
-
app: ClientApp =
|
309
|
+
app: ClientApp = load_app(client_app_attr, LoadClientAppError)
|
310
|
+
|
311
|
+
if not isinstance(app, ClientApp):
|
312
|
+
raise LoadClientAppError(
|
313
|
+
f"Attribute {client_app_attr} is not of type {ClientApp}",
|
314
|
+
) from None
|
315
|
+
|
309
316
|
if client_app:
|
310
317
|
app = client_app
|
311
318
|
return app
|
@@ -122,7 +122,8 @@ class InMemoryState(State):
|
|
122
122
|
task_res.task_id = str(task_id)
|
123
123
|
task_res.task.created_at = created_at.isoformat()
|
124
124
|
task_res.task.ttl = ttl.isoformat()
|
125
|
-
self.
|
125
|
+
with self.lock:
|
126
|
+
self.task_res_store[task_id] = task_res
|
126
127
|
|
127
128
|
# Return the new task_id
|
128
129
|
return task_id
|
@@ -132,46 +133,47 @@ class InMemoryState(State):
|
|
132
133
|
if limit is not None and limit < 1:
|
133
134
|
raise AssertionError("`limit` must be >= 1")
|
134
135
|
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
136
|
+
with self.lock:
|
137
|
+
# Find TaskRes that were not delivered yet
|
138
|
+
task_res_list: List[TaskRes] = []
|
139
|
+
for _, task_res in self.task_res_store.items():
|
140
|
+
if (
|
141
|
+
UUID(task_res.task.ancestry[0]) in task_ids
|
142
|
+
and task_res.task.delivered_at == ""
|
143
|
+
):
|
144
|
+
task_res_list.append(task_res)
|
145
|
+
if limit and len(task_res_list) == limit:
|
146
|
+
break
|
145
147
|
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
148
|
+
# Mark all of them as delivered
|
149
|
+
delivered_at = now().isoformat()
|
150
|
+
for task_res in task_res_list:
|
151
|
+
task_res.task.delivered_at = delivered_at
|
150
152
|
|
151
|
-
|
152
|
-
|
153
|
+
# Return TaskRes
|
154
|
+
return task_res_list
|
153
155
|
|
154
156
|
def delete_tasks(self, task_ids: Set[UUID]) -> None:
|
155
157
|
"""Delete all delivered TaskIns/TaskRes pairs."""
|
156
158
|
task_ins_to_be_deleted: Set[UUID] = set()
|
157
159
|
task_res_to_be_deleted: Set[UUID] = set()
|
158
160
|
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
161
|
+
with self.lock:
|
162
|
+
for task_ins_id in task_ids:
|
163
|
+
# Find the task_id of the matching task_res
|
164
|
+
for task_res_id, task_res in self.task_res_store.items():
|
165
|
+
if UUID(task_res.task.ancestry[0]) != task_ins_id:
|
166
|
+
continue
|
167
|
+
if task_res.task.delivered_at == "":
|
168
|
+
continue
|
169
|
+
|
170
|
+
task_ins_to_be_deleted.add(task_ins_id)
|
171
|
+
task_res_to_be_deleted.add(task_res_id)
|
172
|
+
|
173
|
+
for task_id in task_ins_to_be_deleted:
|
172
174
|
del self.task_ins_store[task_id]
|
173
|
-
|
174
|
-
|
175
|
+
for task_id in task_res_to_be_deleted:
|
176
|
+
del self.task_res_store[task_id]
|
175
177
|
|
176
178
|
def num_task_ins(self) -> int:
|
177
179
|
"""Calculate the number of task_ins in store.
|
flwr/server/workflow/__init__.py
CHANGED
@@ -0,0 +1,32 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Constants for default workflows."""
|
16
|
+
|
17
|
+
|
18
|
+
from __future__ import annotations
|
19
|
+
|
20
|
+
MAIN_CONFIGS_RECORD = "config"
|
21
|
+
MAIN_PARAMS_RECORD = "parameters"
|
22
|
+
|
23
|
+
|
24
|
+
class Key:
|
25
|
+
"""Constants for default workflows."""
|
26
|
+
|
27
|
+
CURRENT_ROUND = "current_round"
|
28
|
+
START_TIME = "start_time"
|
29
|
+
|
30
|
+
def __new__(cls) -> Key:
|
31
|
+
"""Prevent instantiation."""
|
32
|
+
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|