flwr-nightly 1.9.0.dev20240507__py3-none-any.whl → 1.9.0.dev20240520__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/new/new.py +4 -0
- flwr/cli/new/templates/app/code/client.hf.py.tpl +55 -0
- flwr/cli/new/templates/app/code/client.jax.py.tpl +55 -0
- flwr/cli/new/templates/app/code/server.hf.py.tpl +17 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +12 -0
- flwr/cli/new/templates/app/code/task.hf.py.tpl +87 -0
- flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +31 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +28 -0
- flwr/client/mod/comms_mods.py +4 -4
- flwr/client/mod/localdp_mod.py +1 -2
- flwr/server/__init__.py +0 -2
- flwr/server/app.py +7 -1
- flwr/server/compat/app.py +6 -57
- flwr/server/driver/__init__.py +3 -2
- flwr/server/driver/inmemory_driver.py +181 -0
- flwr/server/history.py +20 -20
- flwr/server/server.py +11 -7
- flwr/server/strategy/dp_adaptive_clipping.py +2 -4
- flwr/server/strategy/dp_fixed_clipping.py +2 -4
- flwr/server/superlink/driver/driver_servicer.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +11 -3
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/server/workflow/default_workflows.py +67 -22
- flwr/simulation/run_simulation.py +7 -34
- {flwr_nightly-1.9.0.dev20240507.dist-info → flwr_nightly-1.9.0.dev20240520.dist-info}/METADATA +2 -1
- {flwr_nightly-1.9.0.dev20240507.dist-info → flwr_nightly-1.9.0.dev20240520.dist-info}/RECORD +30 -21
- {flwr_nightly-1.9.0.dev20240507.dist-info → flwr_nightly-1.9.0.dev20240520.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240507.dist-info → flwr_nightly-1.9.0.dev20240520.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.9.0.dev20240507.dist-info → flwr_nightly-1.9.0.dev20240520.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower in-memory Driver."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import time
|
|
19
|
+
import warnings
|
|
20
|
+
from typing import Iterable, List, Optional
|
|
21
|
+
from uuid import UUID
|
|
22
|
+
|
|
23
|
+
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
24
|
+
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
25
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
26
|
+
from flwr.server.superlink.state import StateFactory
|
|
27
|
+
|
|
28
|
+
from .driver import Driver
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class InMemoryDriver(Driver):
    """`InMemoryDriver` class provides an interface to the Driver API.

    Instead of talking to a SuperLink over the network, this driver reads and
    writes `TaskIns`/`TaskRes` directly on the `State` obtained from the given
    `StateFactory`.

    Parameters
    ----------
    state_factory : StateFactory
        A StateFactory embedding a state that this driver can interface with.
    fab_id : str (default: None)
        The identifier of the FAB used in the run.
    fab_version : str (default: None)
        The version of the FAB used in the run.
    """

    # Seconds to wait between two polls of the state in `send_and_receive`.
    _PULL_INTERVAL: float = 3.0

    def __init__(
        self,
        state_factory: StateFactory,
        fab_id: Optional[str] = None,
        fab_version: Optional[str] = None,
    ) -> None:
        # The run is created lazily on first use (see `_get_run_id`).
        self.run_id: Optional[int] = None
        self.fab_id = fab_id if fab_id is not None else ""
        self.fab_version = fab_version if fab_version is not None else ""
        # Node used as `src_node_id` for every message this driver creates.
        self.node = Node(node_id=0, anonymous=True)
        self.state = state_factory.state()

    def _check_message(self, message: Message) -> None:
        """Raise `ValueError` unless `message` can be pushed by this driver.

        A pushable message belongs to this driver's run, originates from this
        driver's node, has no message ID yet, is not a reply, and has a strictly
        positive TTL.
        """
        if not (
            message.metadata.run_id == self.run_id
            and message.metadata.src_node_id == self.node.node_id
            and message.metadata.message_id == ""
            and message.metadata.reply_to_message == ""
            and message.metadata.ttl > 0
        ):
            raise ValueError(f"Invalid message: {message}")

    def _get_run_id(self) -> int:
        """Return run_id.

        If unset, create a new run.
        """
        if self.run_id is None:
            self.run_id = self.state.create_run(
                fab_id=self.fab_id, fab_version=self.fab_version
            )
        return self.run_id

    def create_message(  # pylint: disable=too-many-arguments
        self,
        content: RecordSet,
        message_type: str,
        dst_node_id: int,
        group_id: str,
        ttl: Optional[float] = None,
    ) -> Message:
        """Create a new message with specified parameters.

        This method constructs a new `Message` with given content and metadata.
        The `run_id` and `src_node_id` will be set automatically.

        Parameters
        ----------
        content : RecordSet
            Payload of the message.
        message_type : str
            Type of the message.
        dst_node_id : int
            ID of the node the message is addressed to.
        group_id : str
            Group identifier attached to the message metadata.
        ttl : Optional[float] (default: None)
            Time-to-live of the message. `DEFAULT_TTL` is used when None.

        Returns
        -------
        Message
            A message with an empty `message_id` and `reply_to_message`, ready to
            be passed to `push_messages`.
        """
        run_id = self._get_run_id()
        if ttl:
            warnings.warn(
                "A custom TTL was set, but note that the SuperLink does not enforce "
                "the TTL yet. The SuperLink will start enforcing the TTL in a future "
                "version of Flower.",
                stacklevel=2,
            )
        ttl_ = DEFAULT_TTL if ttl is None else ttl

        metadata = Metadata(
            run_id=run_id,
            message_id="",  # Will be set by the server
            src_node_id=self.node.node_id,
            dst_node_id=dst_node_id,
            reply_to_message="",
            group_id=group_id,
            ttl=ttl_,
            message_type=message_type,
        )
        return Message(metadata=metadata, content=content)

    def get_node_ids(self) -> List[int]:
        """Get node IDs."""
        run_id = self._get_run_id()
        return list(self.state.get_nodes(run_id))

    def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
        """Push messages to specified node IDs.

        This method takes an iterable of messages and sends each message
        to the node specified in `dst_node_id`.

        Returns
        -------
        Iterable[str]
            IDs of the stored tasks. A message is not represented when the
            state returns no ID for it.

        Raises
        ------
        ValueError
            If a message fails validation (see `_check_message`).
        """
        task_ids: List[str] = []
        for msg in messages:
            # Check message
            self._check_message(msg)
            # Convert Message to TaskIns
            taskins = message_to_taskins(msg)
            # Store in state
            taskins.task.pushed_at = time.time()
            task_id = self.state.store_task_ins(taskins)
            # NOTE(review): a falsy ID presumably means the state rejected the
            # TaskIns; such messages are silently skipped.
            if task_id:
                task_ids.append(str(task_id))

        return task_ids

    def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
        """Pull messages based on message IDs.

        This method is used to collect messages from the SuperLink that correspond to a
        set of given message IDs.

        Returns
        -------
        Iterable[Message]
            The replies available so far. An empty list is returned when no
            message IDs are given.
        """
        msg_ids = {UUID(msg_id) for msg_id in message_ids}
        if not msg_ids:
            # Nothing requested: don't query the state with `limit=0`.
            return []
        # Pull TaskRes
        task_res_list = self.state.get_task_res(task_ids=msg_ids, limit=len(msg_ids))
        # Delete tasks in state
        self.state.delete_tasks(msg_ids)
        # Convert TaskRes to Message
        return [message_from_taskres(taskres) for taskres in task_res_list]

    def send_and_receive(
        self,
        messages: Iterable[Message],
        *,
        timeout: Optional[float] = None,
    ) -> Iterable[Message]:
        """Push messages to specified node IDs and pull the reply messages.

        This method sends a list of messages to their destination node IDs and then
        waits for the replies. It continues to pull replies until either all replies are
        received or the specified timeout duration is exceeded.

        Parameters
        ----------
        messages : Iterable[Message]
            Messages to push.
        timeout : Optional[float] (default: None)
            Maximum number of seconds to wait for replies. When None, wait until
            every reply has been received.

        Returns
        -------
        Iterable[Message]
            The replies received before all replies arrived or the timeout
            expired, whichever came first.
        """
        # Push messages
        msg_ids = set(self.push_messages(messages))

        # Pull messages. A monotonic clock keeps the deadline immune to
        # wall-clock adjustments.
        deadline = None if timeout is None else time.monotonic() + timeout
        ret: List[Message] = []
        while msg_ids and (deadline is None or time.monotonic() < deadline):
            res_msgs = self.pull_messages(msg_ids)
            ret.extend(res_msgs)
            msg_ids.difference_update(
                {msg.metadata.reply_to_message for msg in res_msgs}
            )
            if not msg_ids:
                break
            # Sleep before the next poll, but never past the caller's deadline.
            sleep_for = self._PULL_INTERVAL
            if deadline is not None:
                sleep_for = min(sleep_for, max(deadline - time.monotonic(), 0.0))
            time.sleep(sleep_for)
        return ret
|
flwr/server/history.py
CHANGED
|
@@ -91,32 +91,32 @@ class History:
|
|
|
91
91
|
"""
|
|
92
92
|
rep = ""
|
|
93
93
|
if self.losses_distributed:
|
|
94
|
-
rep += "History (loss, distributed):\n" +
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
],
|
|
101
|
-
)
|
|
94
|
+
rep += "History (loss, distributed):\n" + reduce(
|
|
95
|
+
lambda a, b: a + b,
|
|
96
|
+
[
|
|
97
|
+
f"\tround {server_round}: {loss}\n"
|
|
98
|
+
for server_round, loss in self.losses_distributed
|
|
99
|
+
],
|
|
102
100
|
)
|
|
103
101
|
if self.losses_centralized:
|
|
104
|
-
rep += "History (loss, centralized):\n" +
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
],
|
|
111
|
-
)
|
|
102
|
+
rep += "History (loss, centralized):\n" + reduce(
|
|
103
|
+
lambda a, b: a + b,
|
|
104
|
+
[
|
|
105
|
+
f"\tround {server_round}: {loss}\n"
|
|
106
|
+
for server_round, loss in self.losses_centralized
|
|
107
|
+
],
|
|
112
108
|
)
|
|
113
109
|
if self.metrics_distributed_fit:
|
|
114
|
-
rep +=
|
|
115
|
-
|
|
110
|
+
rep += (
|
|
111
|
+
"History (metrics, distributed, fit):\n"
|
|
112
|
+
+ pprint.pformat(self.metrics_distributed_fit)
|
|
113
|
+
+ "\n"
|
|
116
114
|
)
|
|
117
115
|
if self.metrics_distributed:
|
|
118
|
-
rep +=
|
|
119
|
-
|
|
116
|
+
rep += (
|
|
117
|
+
"History (metrics, distributed, evaluate):\n"
|
|
118
|
+
+ pprint.pformat(self.metrics_distributed)
|
|
119
|
+
+ "\n"
|
|
120
120
|
)
|
|
121
121
|
if self.metrics_centralized:
|
|
122
122
|
rep += "History (metrics, centralized):\n" + pprint.pformat(
|
flwr/server/server.py
CHANGED
|
@@ -282,7 +282,14 @@ class Server:
|
|
|
282
282
|
get_parameters_res = random_client.get_parameters(
|
|
283
283
|
ins=ins, timeout=timeout, group_id=server_round
|
|
284
284
|
)
|
|
285
|
-
|
|
285
|
+
if get_parameters_res.status.code == Code.OK:
|
|
286
|
+
log(INFO, "Received initial parameters from one random client")
|
|
287
|
+
else:
|
|
288
|
+
log(
|
|
289
|
+
WARN,
|
|
290
|
+
"Failed to receive initial parameters from the client."
|
|
291
|
+
" Empty initial parameters will be used.",
|
|
292
|
+
)
|
|
286
293
|
return get_parameters_res.parameters
|
|
287
294
|
|
|
288
295
|
|
|
@@ -486,12 +493,9 @@ def run_fl(
|
|
|
486
493
|
|
|
487
494
|
log(INFO, "")
|
|
488
495
|
log(INFO, "[SUMMARY]")
|
|
489
|
-
log(INFO, "Run finished %s
|
|
490
|
-
for
|
|
491
|
-
|
|
492
|
-
log(INFO, "%s", line.strip("\n"))
|
|
493
|
-
else:
|
|
494
|
-
log(INFO, "\t%s", line.strip("\n"))
|
|
496
|
+
log(INFO, "Run finished %s round(s) in %.2fs", config.num_rounds, elapsed_time)
|
|
497
|
+
for line in io.StringIO(str(hist)):
|
|
498
|
+
log(INFO, "\t%s", line.strip("\n"))
|
|
495
499
|
log(INFO, "")
|
|
496
500
|
|
|
497
501
|
# Graceful shutdown
|
|
@@ -234,8 +234,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
|
|
|
234
234
|
)
|
|
235
235
|
log(
|
|
236
236
|
INFO,
|
|
237
|
-
"aggregate_fit: central DP noise with "
|
|
238
|
-
"standard deviation: %.4f added to parameters.",
|
|
237
|
+
"aggregate_fit: central DP noise with %.4f stdev added",
|
|
239
238
|
compute_stdv(
|
|
240
239
|
self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
|
|
241
240
|
),
|
|
@@ -425,8 +424,7 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
|
|
|
425
424
|
)
|
|
426
425
|
log(
|
|
427
426
|
INFO,
|
|
428
|
-
"aggregate_fit: central DP noise with "
|
|
429
|
-
"standard deviation: %.4f added to parameters.",
|
|
427
|
+
"aggregate_fit: central DP noise with %.4f stdev added",
|
|
430
428
|
compute_stdv(
|
|
431
429
|
self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
|
|
432
430
|
),
|
|
@@ -180,8 +180,7 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
|
|
|
180
180
|
|
|
181
181
|
log(
|
|
182
182
|
INFO,
|
|
183
|
-
"aggregate_fit: central DP noise with "
|
|
184
|
-
"standard deviation: %.4f added to parameters.",
|
|
183
|
+
"aggregate_fit: central DP noise with %.4f stdev added",
|
|
185
184
|
compute_stdv(
|
|
186
185
|
self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
|
|
187
186
|
),
|
|
@@ -338,8 +337,7 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
|
|
|
338
337
|
)
|
|
339
338
|
log(
|
|
340
339
|
INFO,
|
|
341
|
-
"aggregate_fit: central DP noise with "
|
|
342
|
-
"standard deviation: %.4f added to parameters.",
|
|
340
|
+
"aggregate_fit: central DP noise with %.4f stdev added",
|
|
343
341
|
compute_stdv(
|
|
344
342
|
self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
|
|
345
343
|
),
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import time
|
|
19
|
-
from logging import DEBUG
|
|
19
|
+
from logging import DEBUG
|
|
20
20
|
from typing import List, Optional, Set
|
|
21
21
|
from uuid import UUID
|
|
22
22
|
|
|
@@ -62,7 +62,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
62
62
|
self, request: CreateRunRequest, context: grpc.ServicerContext
|
|
63
63
|
) -> CreateRunResponse:
|
|
64
64
|
"""Create run ID."""
|
|
65
|
-
log(
|
|
65
|
+
log(DEBUG, "DriverServicer.CreateRun")
|
|
66
66
|
state: State = self.state_factory.state()
|
|
67
67
|
run_id = state.create_run(request.fab_id, request.fab_version)
|
|
68
68
|
return CreateRunResponse(run_id=run_id)
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""Ray backend for the Fleet API using the Simulation Engine."""
|
|
16
16
|
|
|
17
17
|
import pathlib
|
|
18
|
-
from logging import DEBUG, ERROR,
|
|
18
|
+
from logging import DEBUG, ERROR, WARNING
|
|
19
19
|
from typing import Callable, Dict, List, Tuple, Union
|
|
20
20
|
|
|
21
21
|
import ray
|
|
@@ -45,7 +45,7 @@ class RayBackend(Backend):
|
|
|
45
45
|
work_dir: str,
|
|
46
46
|
) -> None:
|
|
47
47
|
"""Prepare RayBackend by initialising Ray and creating the ActorPool."""
|
|
48
|
-
log(
|
|
48
|
+
log(DEBUG, "Initialising: %s", self.__class__.__name__)
|
|
49
49
|
log(DEBUG, "Backend config: %s", backend_config)
|
|
50
50
|
|
|
51
51
|
if not pathlib.Path(work_dir).exists():
|
|
@@ -55,7 +55,15 @@ class RayBackend(Backend):
|
|
|
55
55
|
runtime_env = (
|
|
56
56
|
self._configure_runtime_env(work_dir=work_dir) if work_dir else None
|
|
57
57
|
)
|
|
58
|
-
|
|
58
|
+
|
|
59
|
+
if backend_config.get("mute_logging", False):
|
|
60
|
+
init_ray(
|
|
61
|
+
logging_level=WARNING, log_to_driver=False, runtime_env=runtime_env
|
|
62
|
+
)
|
|
63
|
+
elif backend_config.get("silent", False):
|
|
64
|
+
init_ray(logging_level=WARNING, log_to_driver=True, runtime_env=runtime_env)
|
|
65
|
+
else:
|
|
66
|
+
init_ray(runtime_env=runtime_env)
|
|
59
67
|
|
|
60
68
|
# Validate client resources
|
|
61
69
|
self.client_resources_key = "client_resources"
|
|
@@ -46,7 +46,7 @@ def _register_nodes(
|
|
|
46
46
|
for i in range(num_nodes):
|
|
47
47
|
node_id = state.create_node(ping_interval=PING_MAX_INTERVAL)
|
|
48
48
|
nodes_mapping[node_id] = i
|
|
49
|
-
log(
|
|
49
|
+
log(DEBUG, "Registered %i nodes", len(nodes_mapping))
|
|
50
50
|
return nodes_mapping
|
|
51
51
|
|
|
52
52
|
|
|
@@ -17,13 +17,23 @@
|
|
|
17
17
|
|
|
18
18
|
import io
|
|
19
19
|
import timeit
|
|
20
|
-
from logging import INFO
|
|
21
|
-
from typing import Optional, cast
|
|
20
|
+
from logging import INFO, WARN
|
|
21
|
+
from typing import List, Optional, Tuple, Union, cast
|
|
22
22
|
|
|
23
23
|
import flwr.common.recordset_compat as compat
|
|
24
|
-
from flwr.common import
|
|
24
|
+
from flwr.common import (
|
|
25
|
+
Code,
|
|
26
|
+
ConfigsRecord,
|
|
27
|
+
Context,
|
|
28
|
+
EvaluateRes,
|
|
29
|
+
FitRes,
|
|
30
|
+
GetParametersIns,
|
|
31
|
+
ParametersRecord,
|
|
32
|
+
log,
|
|
33
|
+
)
|
|
25
34
|
from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
26
35
|
|
|
36
|
+
from ..client_proxy import ClientProxy
|
|
27
37
|
from ..compat.app_utils import start_update_client_manager_thread
|
|
28
38
|
from ..compat.legacy_context import LegacyContext
|
|
29
39
|
from ..driver import Driver
|
|
@@ -88,7 +98,12 @@ class DefaultWorkflow:
|
|
|
88
98
|
hist = context.history
|
|
89
99
|
log(INFO, "")
|
|
90
100
|
log(INFO, "[SUMMARY]")
|
|
91
|
-
log(
|
|
101
|
+
log(
|
|
102
|
+
INFO,
|
|
103
|
+
"Run finished %s round(s) in %.2fs",
|
|
104
|
+
context.config.num_rounds,
|
|
105
|
+
elapsed,
|
|
106
|
+
)
|
|
92
107
|
for idx, line in enumerate(io.StringIO(str(hist))):
|
|
93
108
|
if idx == 0:
|
|
94
109
|
log(INFO, "%s", line.strip("\n"))
|
|
@@ -130,9 +145,24 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
|
130
145
|
)
|
|
131
146
|
]
|
|
132
147
|
)
|
|
133
|
-
log(INFO, "Received initial parameters from one random client")
|
|
134
148
|
msg = list(messages)[0]
|
|
135
|
-
|
|
149
|
+
|
|
150
|
+
if (
|
|
151
|
+
msg.has_content()
|
|
152
|
+
and compat._extract_status_from_recordset( # pylint: disable=W0212
|
|
153
|
+
"getparametersres", msg.content
|
|
154
|
+
).code
|
|
155
|
+
== Code.OK
|
|
156
|
+
):
|
|
157
|
+
log(INFO, "Received initial parameters from one random client")
|
|
158
|
+
paramsrecord = next(iter(msg.content.parameters_records.values()))
|
|
159
|
+
else:
|
|
160
|
+
log(
|
|
161
|
+
WARN,
|
|
162
|
+
"Failed to receive initial parameters from the client."
|
|
163
|
+
" Empty initial parameters will be used.",
|
|
164
|
+
)
|
|
165
|
+
paramsrecord = ParametersRecord()
|
|
136
166
|
|
|
137
167
|
context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
|
|
138
168
|
|
|
@@ -244,14 +274,20 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
|
244
274
|
)
|
|
245
275
|
|
|
246
276
|
# Aggregate training results
|
|
247
|
-
results = [
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
277
|
+
results: List[Tuple[ClientProxy, FitRes]] = []
|
|
278
|
+
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = []
|
|
279
|
+
for msg in messages:
|
|
280
|
+
if msg.has_content():
|
|
281
|
+
proxy = node_id_to_proxy[msg.metadata.src_node_id]
|
|
282
|
+
fitres = compat.recordset_to_fitres(msg.content, False)
|
|
283
|
+
if fitres.status.code == Code.OK:
|
|
284
|
+
results.append((proxy, fitres))
|
|
285
|
+
else:
|
|
286
|
+
failures.append((proxy, fitres))
|
|
287
|
+
else:
|
|
288
|
+
failures.append(Exception(msg.error))
|
|
289
|
+
|
|
290
|
+
aggregated_result = context.strategy.aggregate_fit(current_round, results, failures)
|
|
255
291
|
parameters_aggregated, metrics_aggregated = aggregated_result
|
|
256
292
|
|
|
257
293
|
# Update the parameters and write history
|
|
@@ -265,6 +301,7 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
|
265
301
|
)
|
|
266
302
|
|
|
267
303
|
|
|
304
|
+
# pylint: disable-next=R0914
|
|
268
305
|
def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
269
306
|
"""Execute the default workflow for a single evaluate round."""
|
|
270
307
|
if not isinstance(context, LegacyContext):
|
|
@@ -323,14 +360,22 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
|
323
360
|
)
|
|
324
361
|
|
|
325
362
|
# Aggregate the evaluation results
|
|
326
|
-
results = [
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
363
|
+
results: List[Tuple[ClientProxy, EvaluateRes]] = []
|
|
364
|
+
failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = []
|
|
365
|
+
for msg in messages:
|
|
366
|
+
if msg.has_content():
|
|
367
|
+
proxy = node_id_to_proxy[msg.metadata.src_node_id]
|
|
368
|
+
evalres = compat.recordset_to_evaluateres(msg.content)
|
|
369
|
+
if evalres.status.code == Code.OK:
|
|
370
|
+
results.append((proxy, evalres))
|
|
371
|
+
else:
|
|
372
|
+
failures.append((proxy, evalres))
|
|
373
|
+
else:
|
|
374
|
+
failures.append(Exception(msg.error))
|
|
375
|
+
|
|
376
|
+
aggregated_result = context.strategy.aggregate_evaluate(
|
|
377
|
+
current_round, results, failures
|
|
378
|
+
)
|
|
334
379
|
|
|
335
380
|
loss_aggregated, metrics_aggregated = aggregated_result
|
|
336
381
|
|
|
@@ -24,16 +24,13 @@ from logging import DEBUG, ERROR, INFO, WARNING
|
|
|
24
24
|
from time import sleep
|
|
25
25
|
from typing import Dict, Optional
|
|
26
26
|
|
|
27
|
-
import grpc
|
|
28
|
-
|
|
29
27
|
from flwr.client import ClientApp
|
|
30
28
|
from flwr.common import EventType, event, log
|
|
31
29
|
from flwr.common.logger import set_logger_propagation, update_console_handler
|
|
32
30
|
from flwr.common.typing import ConfigsRecordValues
|
|
33
|
-
from flwr.server.driver import Driver,
|
|
31
|
+
from flwr.server.driver import Driver, InMemoryDriver
|
|
34
32
|
from flwr.server.run_serverapp import run
|
|
35
33
|
from flwr.server.server_app import ServerApp
|
|
36
|
-
from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc
|
|
37
34
|
from flwr.server.superlink.fleet import vce
|
|
38
35
|
from flwr.server.superlink.state import StateFactory
|
|
39
36
|
from flwr.simulation.ray_transport.utils import (
|
|
@@ -56,7 +53,6 @@ def run_simulation_from_cli() -> None:
|
|
|
56
53
|
backend_name=args.backend,
|
|
57
54
|
backend_config=backend_config_dict,
|
|
58
55
|
app_dir=args.app_dir,
|
|
59
|
-
driver_api_address=args.driver_api_address,
|
|
60
56
|
enable_tf_gpu_growth=args.enable_tf_gpu_growth,
|
|
61
57
|
verbose_logging=args.verbose,
|
|
62
58
|
)
|
|
@@ -177,7 +173,6 @@ def _main_loop(
|
|
|
177
173
|
num_supernodes: int,
|
|
178
174
|
backend_name: str,
|
|
179
175
|
backend_config_stream: str,
|
|
180
|
-
driver_api_address: str,
|
|
181
176
|
app_dir: str,
|
|
182
177
|
enable_tf_gpu_growth: bool,
|
|
183
178
|
client_app: Optional[ClientApp] = None,
|
|
@@ -194,21 +189,11 @@ def _main_loop(
|
|
|
194
189
|
# Initialize StateFactory
|
|
195
190
|
state_factory = StateFactory(":flwr-in-memory-state:")
|
|
196
191
|
|
|
197
|
-
# Start Driver API
|
|
198
|
-
driver_server: grpc.Server = run_driver_api_grpc(
|
|
199
|
-
address=driver_api_address,
|
|
200
|
-
state_factory=state_factory,
|
|
201
|
-
certificates=None,
|
|
202
|
-
)
|
|
203
|
-
|
|
204
192
|
f_stop = asyncio.Event()
|
|
205
193
|
serverapp_th = None
|
|
206
194
|
try:
|
|
207
195
|
# Initialize Driver
|
|
208
|
-
driver =
|
|
209
|
-
driver_service_address=driver_api_address,
|
|
210
|
-
root_certificates=None,
|
|
211
|
-
)
|
|
196
|
+
driver = InMemoryDriver(state_factory)
|
|
212
197
|
|
|
213
198
|
# Get and run ServerApp thread
|
|
214
199
|
serverapp_th = run_serverapp_th(
|
|
@@ -239,9 +224,6 @@ def _main_loop(
|
|
|
239
224
|
raise RuntimeError("An error was encountered. Ending simulation.") from ex
|
|
240
225
|
|
|
241
226
|
finally:
|
|
242
|
-
# Stop Driver
|
|
243
|
-
driver_server.stop(grace=0)
|
|
244
|
-
driver.close()
|
|
245
227
|
# Trigger stop event
|
|
246
228
|
f_stop.set()
|
|
247
229
|
|
|
@@ -262,7 +244,6 @@ def _run_simulation(
|
|
|
262
244
|
client_app_attr: Optional[str] = None,
|
|
263
245
|
server_app_attr: Optional[str] = None,
|
|
264
246
|
app_dir: str = "",
|
|
265
|
-
driver_api_address: str = "0.0.0.0:9091",
|
|
266
247
|
enable_tf_gpu_growth: bool = False,
|
|
267
248
|
verbose_logging: bool = False,
|
|
268
249
|
) -> None:
|
|
@@ -302,9 +283,6 @@ def _run_simulation(
|
|
|
302
283
|
Add specified directory to the PYTHONPATH and load `ClientApp` from there.
|
|
303
284
|
(Default: current working directory.)
|
|
304
285
|
|
|
305
|
-
driver_api_address : str (default: "0.0.0.0:9091")
|
|
306
|
-
Driver API (gRPC) server address (IPv4, IPv6, or a domain name)
|
|
307
|
-
|
|
308
286
|
enable_tf_gpu_growth : bool (default: False)
|
|
309
287
|
A boolean to indicate whether to enable GPU growth on the main thread. This is
|
|
310
288
|
desirable if you make use of a TensorFlow model on your `ServerApp` while
|
|
@@ -317,13 +295,15 @@ def _run_simulation(
|
|
|
317
295
|
When diabled, only INFO, WARNING and ERROR log messages will be shown. If
|
|
318
296
|
enabled, DEBUG-level logs will be displayed.
|
|
319
297
|
"""
|
|
298
|
+
if backend_config is None:
|
|
299
|
+
backend_config = {}
|
|
300
|
+
|
|
320
301
|
# Set logging level
|
|
321
302
|
logger = logging.getLogger("flwr")
|
|
322
303
|
if verbose_logging:
|
|
323
304
|
update_console_handler(level=DEBUG, timestamps=True, colored=True)
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
backend_config = {}
|
|
305
|
+
else:
|
|
306
|
+
backend_config["silent"] = True
|
|
327
307
|
|
|
328
308
|
if enable_tf_gpu_growth:
|
|
329
309
|
# Check that Backend config has also enabled using GPU growth
|
|
@@ -340,7 +320,6 @@ def _run_simulation(
|
|
|
340
320
|
num_supernodes,
|
|
341
321
|
backend_name,
|
|
342
322
|
backend_config_stream,
|
|
343
|
-
driver_api_address,
|
|
344
323
|
app_dir,
|
|
345
324
|
enable_tf_gpu_growth,
|
|
346
325
|
client_app,
|
|
@@ -397,12 +376,6 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
|
|
|
397
376
|
required=True,
|
|
398
377
|
help="Number of simulated SuperNodes.",
|
|
399
378
|
)
|
|
400
|
-
parser.add_argument(
|
|
401
|
-
"--driver-api-address",
|
|
402
|
-
default="0.0.0.0:9091",
|
|
403
|
-
type=str,
|
|
404
|
-
help="For example: `server:app` or `project.package.module:wrapper.app`",
|
|
405
|
-
)
|
|
406
379
|
parser.add_argument(
|
|
407
380
|
"--backend",
|
|
408
381
|
default="ray",
|
{flwr_nightly-1.9.0.dev20240507.dist-info → flwr_nightly-1.9.0.dev20240520.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: flwr-nightly
|
|
3
|
-
Version: 1.9.0.
|
|
3
|
+
Version: 1.9.0.dev20240520
|
|
4
4
|
Summary: Flower: A Friendly Federated Learning Framework
|
|
5
5
|
Home-page: https://flower.ai
|
|
6
6
|
License: Apache-2.0
|
|
@@ -202,6 +202,7 @@ Other [examples](https://github.com/adap/flower/tree/main/examples):
|
|
|
202
202
|
- [Comprehensive Flower+XGBoost](https://github.com/adap/flower/tree/main/examples/xgboost-comprehensive)
|
|
203
203
|
- [Flower through Docker Compose and with Grafana dashboard](https://github.com/adap/flower/tree/main/examples/flower-via-docker-compose)
|
|
204
204
|
- [Flower with KaplanMeierFitter from the lifelines library](https://github.com/adap/flower/tree/main/examples/federated-kaplna-meier-fitter)
|
|
205
|
+
- [Sample Level Privacy with Opacus](https://github.com/adap/flower/tree/main/examples/opacus)
|
|
205
206
|
|
|
206
207
|
## Community
|
|
207
208
|
|