flwr 1.15.2__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/build.py +2 -0
- flwr/cli/log.py +20 -21
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +260 -86
- flwr/client/clientapp/app.py +6 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +28 -28
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/rest_client/connection.py +4 -6
- flwr/client/run_info_store.py +2 -2
- flwr/client/supernode/__init__.py +0 -2
- flwr/client/supernode/app.py +1 -11
- flwr/common/__init__.py +12 -4
- flwr/common/address.py +35 -0
- flwr/common/args.py +8 -2
- flwr/common/auth_plugin/auth_plugin.py +2 -1
- flwr/common/config.py +4 -4
- flwr/common/constant.py +16 -0
- flwr/common/context.py +4 -4
- flwr/common/event_log_plugin/__init__.py +22 -0
- flwr/common/event_log_plugin/event_log_plugin.py +60 -0
- flwr/common/grpc.py +1 -1
- flwr/common/logger.py +2 -2
- flwr/common/message.py +338 -102
- flwr/common/object_ref.py +0 -10
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/arrayrecord.py +626 -0
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +9 -18
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/quantization.py +5 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +67 -190
- flwr/common/telemetry.py +0 -10
- flwr/common/typing.py +44 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +3 -1
- flwr/server/app.py +74 -3
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +15 -12
- flwr/server/compat/app_utils.py +26 -18
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +41 -41
- flwr/server/fleet_event_log_interceptor.py +94 -0
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +48 -19
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +88 -56
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +41 -54
- flwr/server/run_serverapp.py +6 -17
- flwr/server/server_app.py +126 -33
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
- flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +33 -38
- flwr/server/superlink/linkstate/in_memory_linkstate.py +171 -132
- flwr/server/superlink/linkstate/linkstate.py +51 -64
- flwr/server/superlink/linkstate/sqlite_linkstate.py +253 -285
- flwr/server/superlink/linkstate/utils.py +171 -133
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +27 -29
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/typing.py +3 -3
- flwr/server/utils/__init__.py +2 -2
- flwr/server/utils/validator.py +53 -68
- flwr/server/workflow/default_workflows.py +52 -58
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +15 -15
- flwr/superexec/app.py +0 -14
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_event_log_interceptor.py +135 -0
- flwr/superexec/exec_grpc.py +10 -4
- flwr/superexec/exec_servicer.py +6 -6
- flwr/superexec/exec_user_auth_interceptor.py +22 -4
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/METADATA +5 -5
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/RECORD +111 -112
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -3
- flwr/client/message_handler/task_handler.py +0 -37
- flwr/common/record/parametersrecord.py +0 -204
- flwr/common/record/recordset.py +0 -202
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- flwr/proto/task_pb2.py +0 -33
- flwr/proto/task_pb2.pyi +0 -100
- flwr/proto/task_pb2_grpc.py +0 -4
- flwr/proto/task_pb2_grpc.pyi +0 -4
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,24 +12,23 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Flower gRPC
|
|
15
|
+
"""Flower gRPC Grid."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import time
|
|
19
|
-
import warnings
|
|
20
19
|
from collections.abc import Iterable
|
|
21
|
-
from logging import DEBUG, WARNING
|
|
20
|
+
from logging import DEBUG, ERROR, WARNING
|
|
22
21
|
from typing import Optional, cast
|
|
23
22
|
|
|
24
23
|
import grpc
|
|
25
24
|
|
|
26
|
-
from flwr.common import
|
|
25
|
+
from flwr.common import Message, RecordDict
|
|
27
26
|
from flwr.common.constant import (
|
|
28
27
|
SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
|
|
29
28
|
SUPERLINK_NODE_ID,
|
|
30
29
|
)
|
|
31
30
|
from flwr.common.grpc import create_channel, on_channel_state_change
|
|
32
|
-
from flwr.common.logger import log
|
|
31
|
+
from flwr.common.logger import log, warn_deprecated_feature
|
|
33
32
|
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
|
|
34
33
|
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
|
|
35
34
|
from flwr.common.typing import Run
|
|
@@ -46,18 +45,39 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
|
46
45
|
)
|
|
47
46
|
from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
|
|
48
47
|
|
|
49
|
-
from .
|
|
48
|
+
from .grid import Grid
|
|
50
49
|
|
|
51
|
-
|
|
52
|
-
[flwr-serverapp] Error: Not connected.
|
|
50
|
+
ERROR_MESSAGE_PUSH_MESSAGES_RESOURCE_EXHAUSTED = """
|
|
53
51
|
|
|
54
|
-
|
|
55
|
-
|
|
52
|
+
[Grid.push_messages] gRPC error occurred:
|
|
53
|
+
|
|
54
|
+
The 2GB gRPC limit has been reached. Consider reducing the number of messages pushed
|
|
55
|
+
at once, or push messages individually, for example:
|
|
56
|
+
|
|
57
|
+
> msgs = [msg1, msg2, msg3]
|
|
58
|
+
> msg_ids = []
|
|
59
|
+
> for msg in msgs:
|
|
60
|
+
> msg_id = grid.push_messages([msg])
|
|
61
|
+
> msg_ids.extend(msg_id)
|
|
62
|
+
"""
|
|
63
|
+
|
|
64
|
+
ERROR_MESSAGE_PULL_MESSAGES_RESOURCE_EXHAUSTED = """
|
|
65
|
+
|
|
66
|
+
[Grid.pull_messages] gRPC error occurred:
|
|
67
|
+
|
|
68
|
+
The 2GB gRPC limit has been reached. Consider reducing the number of messages pulled
|
|
69
|
+
at once, or pull messages individually, for example:
|
|
70
|
+
|
|
71
|
+
> msgs_ids = [msg_id1, msg_id2, msg_id3]
|
|
72
|
+
> msgs = []
|
|
73
|
+
> for msg_id in msg_ids:
|
|
74
|
+
> msg = grid.pull_messages([msg_id])
|
|
75
|
+
> msgs.extend(msg)
|
|
56
76
|
"""
|
|
57
77
|
|
|
58
78
|
|
|
59
|
-
class
|
|
60
|
-
"""`
|
|
79
|
+
class GrpcGrid(Grid):
|
|
80
|
+
"""`GrpcGrid` provides an interface to the ServerAppIo API.
|
|
61
81
|
|
|
62
82
|
Parameters
|
|
63
83
|
----------
|
|
@@ -69,6 +89,8 @@ class GrpcDriver(Driver):
|
|
|
69
89
|
established to an SSL-enabled Flower server.
|
|
70
90
|
"""
|
|
71
91
|
|
|
92
|
+
_deprecation_warning_logged = False
|
|
93
|
+
|
|
72
94
|
def __init__( # pylint: disable=too-many-arguments
|
|
73
95
|
self,
|
|
74
96
|
serverappio_service_address: str = SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
|
|
@@ -81,6 +103,7 @@ class GrpcDriver(Driver):
|
|
|
81
103
|
self._channel: Optional[grpc.Channel] = None
|
|
82
104
|
self.node = Node(node_id=SUPERLINK_NODE_ID)
|
|
83
105
|
self._retry_invoker = _make_simple_grpc_retry_invoker()
|
|
106
|
+
super().__init__()
|
|
84
107
|
|
|
85
108
|
@property
|
|
86
109
|
def _is_connected(self) -> bool:
|
|
@@ -140,18 +163,15 @@ class GrpcDriver(Driver):
|
|
|
140
163
|
def _check_message(self, message: Message) -> None:
|
|
141
164
|
# Check if the message is valid
|
|
142
165
|
if not (
|
|
143
|
-
|
|
144
|
-
message.metadata.
|
|
145
|
-
and message.metadata.src_node_id == self.node.node_id
|
|
146
|
-
and message.metadata.message_id == ""
|
|
147
|
-
and message.metadata.reply_to_message == ""
|
|
166
|
+
message.metadata.message_id == ""
|
|
167
|
+
and message.metadata.reply_to_message_id == ""
|
|
148
168
|
and message.metadata.ttl > 0
|
|
149
169
|
):
|
|
150
170
|
raise ValueError(f"Invalid message: {message}")
|
|
151
171
|
|
|
152
172
|
def create_message( # pylint: disable=too-many-arguments,R0917
|
|
153
173
|
self,
|
|
154
|
-
content:
|
|
174
|
+
content: RecordDict,
|
|
155
175
|
message_type: str,
|
|
156
176
|
dst_node_id: int,
|
|
157
177
|
group_id: str,
|
|
@@ -162,30 +182,17 @@ class GrpcDriver(Driver):
|
|
|
162
182
|
This method constructs a new `Message` with given content and metadata.
|
|
163
183
|
The `run_id` and `src_node_id` will be set automatically.
|
|
164
184
|
"""
|
|
165
|
-
if
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
"
|
|
169
|
-
"
|
|
170
|
-
stacklevel=2,
|
|
185
|
+
if not GrpcGrid._deprecation_warning_logged:
|
|
186
|
+
GrpcGrid._deprecation_warning_logged = True
|
|
187
|
+
warn_deprecated_feature(
|
|
188
|
+
"`Driver.create_message` / `Grid.create_message` is deprecated."
|
|
189
|
+
"Use `Message` constructor instead."
|
|
171
190
|
)
|
|
191
|
+
return Message(content, dst_node_id, message_type, ttl=ttl, group_id=group_id)
|
|
172
192
|
|
|
173
|
-
|
|
174
|
-
metadata = Metadata(
|
|
175
|
-
run_id=cast(Run, self._run).run_id,
|
|
176
|
-
message_id="", # Will be set by the server
|
|
177
|
-
src_node_id=self.node.node_id,
|
|
178
|
-
dst_node_id=dst_node_id,
|
|
179
|
-
reply_to_message="",
|
|
180
|
-
group_id=group_id,
|
|
181
|
-
ttl=ttl_,
|
|
182
|
-
message_type=message_type,
|
|
183
|
-
)
|
|
184
|
-
return Message(metadata=metadata, content=content)
|
|
185
|
-
|
|
186
|
-
def get_node_ids(self) -> list[int]:
|
|
193
|
+
def get_node_ids(self) -> Iterable[int]:
|
|
187
194
|
"""Get node IDs."""
|
|
188
|
-
# Call
|
|
195
|
+
# Call GrpcServerAppIoStub method
|
|
189
196
|
res: GetNodesResponse = self._stub.GetNodes(
|
|
190
197
|
GetNodesRequest(run_id=cast(Run, self._run).run_id)
|
|
191
198
|
)
|
|
@@ -198,21 +205,40 @@ class GrpcDriver(Driver):
|
|
|
198
205
|
to the node specified in `dst_node_id`.
|
|
199
206
|
"""
|
|
200
207
|
# Construct Messages
|
|
208
|
+
run_id = cast(Run, self._run).run_id
|
|
201
209
|
message_proto_list: list[ProtoMessage] = []
|
|
202
210
|
for msg in messages:
|
|
211
|
+
# Populate metadata
|
|
212
|
+
msg.metadata.__dict__["_run_id"] = run_id
|
|
213
|
+
msg.metadata.__dict__["_src_node_id"] = self.node.node_id
|
|
203
214
|
# Check message
|
|
204
215
|
self._check_message(msg)
|
|
205
216
|
# Convert to proto
|
|
206
217
|
msg_proto = message_to_proto(msg)
|
|
207
218
|
# Add to list
|
|
208
219
|
message_proto_list.append(msg_proto)
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
220
|
+
|
|
221
|
+
try:
|
|
222
|
+
# Call GrpcServerAppIoStub method
|
|
223
|
+
res: PushInsMessagesResponse = self._stub.PushMessages(
|
|
224
|
+
PushInsMessagesRequest(messages_list=message_proto_list, run_id=run_id)
|
|
213
225
|
)
|
|
214
|
-
|
|
215
|
-
|
|
226
|
+
if len([msg_id for msg_id in res.message_ids if msg_id]) != len(
|
|
227
|
+
message_proto_list
|
|
228
|
+
):
|
|
229
|
+
log(
|
|
230
|
+
WARNING,
|
|
231
|
+
"Not all messages could be pushed to the SuperLink. The returned "
|
|
232
|
+
"list has `None` for those messages (the order is preserved as "
|
|
233
|
+
"passed to `push_messages`). This could be due to a malformed "
|
|
234
|
+
"message.",
|
|
235
|
+
)
|
|
236
|
+
return list(res.message_ids)
|
|
237
|
+
except grpc.RpcError as e:
|
|
238
|
+
if e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: # pylint: disable=E1101
|
|
239
|
+
log(ERROR, ERROR_MESSAGE_PUSH_MESSAGES_RESOURCE_EXHAUSTED)
|
|
240
|
+
return []
|
|
241
|
+
raise
|
|
216
242
|
|
|
217
243
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
218
244
|
"""Pull messages based on message IDs.
|
|
@@ -220,16 +246,22 @@ class GrpcDriver(Driver):
|
|
|
220
246
|
This method is used to collect messages from the SuperLink that correspond to a
|
|
221
247
|
set of given message IDs.
|
|
222
248
|
"""
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
249
|
+
try:
|
|
250
|
+
# Pull Messages
|
|
251
|
+
res: PullResMessagesResponse = self._stub.PullMessages(
|
|
252
|
+
PullResMessagesRequest(
|
|
253
|
+
message_ids=message_ids,
|
|
254
|
+
run_id=cast(Run, self._run).run_id,
|
|
255
|
+
)
|
|
228
256
|
)
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
257
|
+
# Convert Message from Protobuf representation
|
|
258
|
+
msgs = [message_from_proto(msg_proto) for msg_proto in res.messages_list]
|
|
259
|
+
return msgs
|
|
260
|
+
except grpc.RpcError as e:
|
|
261
|
+
if e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: # pylint: disable=E1101
|
|
262
|
+
log(ERROR, ERROR_MESSAGE_PULL_MESSAGES_RESOURCE_EXHAUSTED)
|
|
263
|
+
return []
|
|
264
|
+
raise
|
|
233
265
|
|
|
234
266
|
def send_and_receive(
|
|
235
267
|
self,
|
|
@@ -253,7 +285,7 @@ class GrpcDriver(Driver):
|
|
|
253
285
|
res_msgs = self.pull_messages(msg_ids)
|
|
254
286
|
ret.extend(res_msgs)
|
|
255
287
|
msg_ids.difference_update(
|
|
256
|
-
{msg.metadata.
|
|
288
|
+
{msg.metadata.reply_to_message_id for msg in res_msgs}
|
|
257
289
|
)
|
|
258
290
|
if len(msg_ids) == 0:
|
|
259
291
|
break
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,36 +12,37 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Flower in-memory
|
|
15
|
+
"""Flower in-memory Grid."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import time
|
|
19
|
-
import warnings
|
|
20
19
|
from collections.abc import Iterable
|
|
21
20
|
from typing import Optional, cast
|
|
22
21
|
from uuid import UUID
|
|
23
22
|
|
|
24
|
-
from flwr.common import
|
|
23
|
+
from flwr.common import Message, RecordDict
|
|
25
24
|
from flwr.common.constant import SUPERLINK_NODE_ID
|
|
26
|
-
from flwr.common.
|
|
25
|
+
from flwr.common.logger import warn_deprecated_feature
|
|
27
26
|
from flwr.common.typing import Run
|
|
28
27
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
29
28
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
30
29
|
|
|
31
|
-
from .
|
|
30
|
+
from .grid import Grid
|
|
32
31
|
|
|
33
32
|
|
|
34
|
-
class
|
|
35
|
-
"""`
|
|
33
|
+
class InMemoryGrid(Grid):
|
|
34
|
+
"""`InMemoryGrid` class provides an interface to the ServerAppIo API.
|
|
36
35
|
|
|
37
36
|
Parameters
|
|
38
37
|
----------
|
|
39
38
|
state_factory : StateFactory
|
|
40
|
-
A StateFactory embedding a state that this
|
|
39
|
+
A StateFactory embedding a state that this grid can interface with.
|
|
41
40
|
pull_interval : float (default=0.1)
|
|
42
41
|
Sleep duration between calls to `pull_messages`.
|
|
43
42
|
"""
|
|
44
43
|
|
|
44
|
+
_deprecation_warning_logged = False
|
|
45
|
+
|
|
45
46
|
def __init__(
|
|
46
47
|
self,
|
|
47
48
|
state_factory: LinkStateFactory,
|
|
@@ -55,11 +56,10 @@ class InMemoryDriver(Driver):
|
|
|
55
56
|
def _check_message(self, message: Message) -> None:
|
|
56
57
|
# Check if the message is valid
|
|
57
58
|
if not (
|
|
58
|
-
message.metadata.
|
|
59
|
-
and message.metadata.
|
|
60
|
-
and message.metadata.message_id == ""
|
|
61
|
-
and message.metadata.reply_to_message == ""
|
|
59
|
+
message.metadata.message_id == ""
|
|
60
|
+
and message.metadata.reply_to_message_id == ""
|
|
62
61
|
and message.metadata.ttl > 0
|
|
62
|
+
and message.metadata.delivered_at == ""
|
|
63
63
|
):
|
|
64
64
|
raise ValueError(f"Invalid message: {message}")
|
|
65
65
|
|
|
@@ -77,7 +77,7 @@ class InMemoryDriver(Driver):
|
|
|
77
77
|
|
|
78
78
|
def create_message( # pylint: disable=too-many-arguments,R0917
|
|
79
79
|
self,
|
|
80
|
-
content:
|
|
80
|
+
content: RecordDict,
|
|
81
81
|
message_type: str,
|
|
82
82
|
dst_node_id: int,
|
|
83
83
|
group_id: str,
|
|
@@ -88,30 +88,17 @@ class InMemoryDriver(Driver):
|
|
|
88
88
|
This method constructs a new `Message` with given content and metadata.
|
|
89
89
|
The `run_id` and `src_node_id` will be set automatically.
|
|
90
90
|
"""
|
|
91
|
-
if
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
"
|
|
95
|
-
"
|
|
96
|
-
stacklevel=2,
|
|
91
|
+
if not InMemoryGrid._deprecation_warning_logged:
|
|
92
|
+
InMemoryGrid._deprecation_warning_logged = True
|
|
93
|
+
warn_deprecated_feature(
|
|
94
|
+
"`Driver.create_message` / `Grid.create_message` is deprecated."
|
|
95
|
+
"Use `Message` constructor instead."
|
|
97
96
|
)
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
run_id=cast(Run, self._run).run_id,
|
|
102
|
-
message_id="", # Will be set by the server
|
|
103
|
-
src_node_id=self.node.node_id,
|
|
104
|
-
dst_node_id=dst_node_id,
|
|
105
|
-
reply_to_message="",
|
|
106
|
-
group_id=group_id,
|
|
107
|
-
ttl=ttl_,
|
|
108
|
-
message_type=message_type,
|
|
109
|
-
)
|
|
110
|
-
return Message(metadata=metadata, content=content)
|
|
111
|
-
|
|
112
|
-
def get_node_ids(self) -> list[int]:
|
|
97
|
+
return Message(content, dst_node_id, message_type, ttl=ttl, group_id=group_id)
|
|
98
|
+
|
|
99
|
+
def get_node_ids(self) -> Iterable[int]:
|
|
113
100
|
"""Get node IDs."""
|
|
114
|
-
return
|
|
101
|
+
return self.state.get_nodes(cast(Run, self._run).run_id)
|
|
115
102
|
|
|
116
103
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
117
104
|
"""Push messages to specified node IDs.
|
|
@@ -119,18 +106,19 @@ class InMemoryDriver(Driver):
|
|
|
119
106
|
This method takes an iterable of messages and sends each message
|
|
120
107
|
to the node specified in `dst_node_id`.
|
|
121
108
|
"""
|
|
122
|
-
|
|
109
|
+
msg_ids: list[str] = []
|
|
123
110
|
for msg in messages:
|
|
111
|
+
# Populate metadata
|
|
112
|
+
msg.metadata.__dict__["_run_id"] = cast(Run, self._run).run_id
|
|
113
|
+
msg.metadata.__dict__["_src_node_id"] = self.node.node_id
|
|
124
114
|
# Check message
|
|
125
115
|
self._check_message(msg)
|
|
126
|
-
# Convert Message to TaskIns
|
|
127
|
-
taskins = message_to_taskins(msg)
|
|
128
116
|
# Store in state
|
|
129
|
-
|
|
130
|
-
if
|
|
131
|
-
|
|
117
|
+
msg_id = self.state.store_message_ins(msg)
|
|
118
|
+
if msg_id:
|
|
119
|
+
msg_ids.append(str(msg_id))
|
|
132
120
|
|
|
133
|
-
return
|
|
121
|
+
return msg_ids
|
|
134
122
|
|
|
135
123
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
136
124
|
"""Pull messages based on message IDs.
|
|
@@ -139,17 +127,16 @@ class InMemoryDriver(Driver):
|
|
|
139
127
|
set of given message IDs.
|
|
140
128
|
"""
|
|
141
129
|
msg_ids = {UUID(msg_id) for msg_id in message_ids}
|
|
142
|
-
# Pull
|
|
143
|
-
|
|
144
|
-
#
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
UUID(task_res.task.ancestry[0]) for task_res in task_res_list
|
|
130
|
+
# Pull Messages
|
|
131
|
+
message_res_list = self.state.get_message_res(message_ids=msg_ids)
|
|
132
|
+
# Get IDs of Messages these replies are for
|
|
133
|
+
message_ins_ids_to_delete = {
|
|
134
|
+
UUID(msg_res.metadata.reply_to_message_id) for msg_res in message_res_list
|
|
148
135
|
}
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
return
|
|
136
|
+
# Delete
|
|
137
|
+
self.state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
|
|
138
|
+
|
|
139
|
+
return message_res_list
|
|
153
140
|
|
|
154
141
|
def send_and_receive(
|
|
155
142
|
self,
|
|
@@ -173,7 +160,7 @@ class InMemoryDriver(Driver):
|
|
|
173
160
|
res_msgs = self.pull_messages(msg_ids)
|
|
174
161
|
ret.extend(res_msgs)
|
|
175
162
|
msg_ids.difference_update(
|
|
176
|
-
{msg.metadata.
|
|
163
|
+
{msg.metadata.reply_to_message_id for msg in res_msgs}
|
|
177
164
|
)
|
|
178
165
|
if len(msg_ids) == 0:
|
|
179
166
|
break
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -15,26 +15,25 @@
|
|
|
15
15
|
"""Run ServerApp."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from logging import DEBUG
|
|
18
|
+
from logging import DEBUG
|
|
19
19
|
from typing import Optional
|
|
20
20
|
|
|
21
|
-
from flwr.common import Context
|
|
22
|
-
from flwr.common.exit_handlers import register_exit_handlers
|
|
21
|
+
from flwr.common import Context
|
|
23
22
|
from flwr.common.logger import log
|
|
24
23
|
from flwr.common.object_ref import load_app
|
|
25
24
|
|
|
26
|
-
from .
|
|
25
|
+
from .grid import Grid
|
|
27
26
|
from .server_app import LoadServerAppError, ServerApp
|
|
28
27
|
|
|
29
28
|
|
|
30
29
|
def run(
|
|
31
|
-
|
|
30
|
+
grid: Grid,
|
|
32
31
|
context: Context,
|
|
33
32
|
server_app_dir: str,
|
|
34
33
|
server_app_attr: Optional[str] = None,
|
|
35
34
|
loaded_server_app: Optional[ServerApp] = None,
|
|
36
35
|
) -> Context:
|
|
37
|
-
"""Run ServerApp with a given
|
|
36
|
+
"""Run ServerApp with a given Grid."""
|
|
38
37
|
if not (server_app_attr is None) ^ (loaded_server_app is None):
|
|
39
38
|
raise ValueError(
|
|
40
39
|
"Either `server_app_attr` or `loaded_server_app` should be set "
|
|
@@ -60,17 +59,7 @@ def run(
|
|
|
60
59
|
server_app = _load()
|
|
61
60
|
|
|
62
61
|
# Call ServerApp
|
|
63
|
-
server_app(
|
|
62
|
+
server_app(grid=grid, context=context)
|
|
64
63
|
|
|
65
64
|
log(DEBUG, "ServerApp finished running.")
|
|
66
65
|
return context
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
def run_server_app() -> None:
|
|
70
|
-
"""Run Flower server app."""
|
|
71
|
-
event(EventType.RUN_SERVER_APP_ENTER)
|
|
72
|
-
log(
|
|
73
|
-
ERROR,
|
|
74
|
-
"The command `flower-server-app` has been replaced by `flwr run`.",
|
|
75
|
-
)
|
|
76
|
-
register_exit_handlers(event_type=EventType.RUN_SERVER_APP_LEAVE)
|