flwr-1.15.1-py3-none-any.whl → flwr-1.16.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/build.py +2 -0
- flwr/cli/log.py +20 -21
- flwr/cli/new/new.py +1 -1
- flwr/cli/new/templates/app/README.baseline.md.tpl +4 -4
- flwr/cli/new/templates/app/README.md.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/client/client_app.py +147 -36
- flwr/client/clientapp/app.py +4 -0
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/rest_client/connection.py +4 -6
- flwr/client/supernode/__init__.py +0 -2
- flwr/client/supernode/app.py +1 -11
- flwr/common/address.py +35 -0
- flwr/common/args.py +8 -2
- flwr/common/auth_plugin/auth_plugin.py +2 -1
- flwr/common/constant.py +16 -0
- flwr/common/event_log_plugin/__init__.py +22 -0
- flwr/common/event_log_plugin/event_log_plugin.py +60 -0
- flwr/common/grpc.py +1 -1
- flwr/common/message.py +18 -7
- flwr/common/object_ref.py +0 -10
- flwr/common/record/conversion_utils.py +8 -17
- flwr/common/record/parametersrecord.py +151 -16
- flwr/common/record/recordset.py +95 -88
- flwr/common/secure_aggregation/quantization.py +5 -1
- flwr/common/serde.py +8 -126
- flwr/common/telemetry.py +0 -10
- flwr/common/typing.py +36 -0
- flwr/server/app.py +18 -2
- flwr/server/compat/app.py +4 -1
- flwr/server/compat/app_utils.py +10 -2
- flwr/server/compat/driver_client_proxy.py +2 -2
- flwr/server/driver/driver.py +1 -1
- flwr/server/driver/grpc_driver.py +10 -1
- flwr/server/driver/inmemory_driver.py +17 -21
- flwr/server/run_serverapp.py +2 -13
- flwr/server/server_app.py +93 -20
- flwr/server/superlink/driver/serverappio_servicer.py +27 -33
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -16
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +32 -36
- flwr/server/superlink/linkstate/in_memory_linkstate.py +140 -126
- flwr/server/superlink/linkstate/linkstate.py +47 -60
- flwr/server/superlink/linkstate/sqlite_linkstate.py +210 -282
- flwr/server/superlink/linkstate/utils.py +91 -119
- flwr/server/utils/__init__.py +2 -2
- flwr/server/utils/validator.py +53 -71
- flwr/server/workflow/default_workflows.py +4 -1
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +3 -3
- flwr/superexec/app.py +0 -14
- flwr/superexec/exec_servicer.py +4 -4
- flwr/superexec/exec_user_auth_interceptor.py +5 -3
- {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/METADATA +5 -5
- {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/RECORD +66 -69
- {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/entry_points.txt +0 -3
- flwr/client/message_handler/task_handler.py +0 -37
- flwr/proto/task_pb2.py +0 -33
- flwr/proto/task_pb2.pyi +0 -103
- flwr/proto/task_pb2_grpc.py +0 -4
- flwr/proto/task_pb2_grpc.pyi +0 -4
- {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/LICENSE +0 -0
- {flwr-1.15.1.dist-info → flwr-1.16.0.dist-info}/WHEEL +0 -0
|
@@ -23,7 +23,6 @@ from uuid import UUID
|
|
|
23
23
|
|
|
24
24
|
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
25
25
|
from flwr.common.constant import SUPERLINK_NODE_ID
|
|
26
|
-
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
27
26
|
from flwr.common.typing import Run
|
|
28
27
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
29
28
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
@@ -60,6 +59,7 @@ class InMemoryDriver(Driver):
|
|
|
60
59
|
and message.metadata.message_id == ""
|
|
61
60
|
and message.metadata.reply_to_message == ""
|
|
62
61
|
and message.metadata.ttl > 0
|
|
62
|
+
and message.metadata.delivered_at == ""
|
|
63
63
|
):
|
|
64
64
|
raise ValueError(f"Invalid message: {message}")
|
|
65
65
|
|
|
@@ -109,9 +109,9 @@ class InMemoryDriver(Driver):
|
|
|
109
109
|
)
|
|
110
110
|
return Message(metadata=metadata, content=content)
|
|
111
111
|
|
|
112
|
-
def get_node_ids(self) ->
|
|
112
|
+
def get_node_ids(self) -> Iterable[int]:
|
|
113
113
|
"""Get node IDs."""
|
|
114
|
-
return
|
|
114
|
+
return self.state.get_nodes(cast(Run, self._run).run_id)
|
|
115
115
|
|
|
116
116
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
117
117
|
"""Push messages to specified node IDs.
|
|
@@ -119,19 +119,16 @@ class InMemoryDriver(Driver):
|
|
|
119
119
|
This method takes an iterable of messages and sends each message
|
|
120
120
|
to the node specified in `dst_node_id`.
|
|
121
121
|
"""
|
|
122
|
-
|
|
122
|
+
msg_ids: list[str] = []
|
|
123
123
|
for msg in messages:
|
|
124
124
|
# Check message
|
|
125
125
|
self._check_message(msg)
|
|
126
|
-
# Convert Message to TaskIns
|
|
127
|
-
taskins = message_to_taskins(msg)
|
|
128
126
|
# Store in state
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
task_ids.append(str(task_id))
|
|
127
|
+
msg_id = self.state.store_message_ins(msg)
|
|
128
|
+
if msg_id:
|
|
129
|
+
msg_ids.append(str(msg_id))
|
|
133
130
|
|
|
134
|
-
return
|
|
131
|
+
return msg_ids
|
|
135
132
|
|
|
136
133
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
137
134
|
"""Pull messages based on message IDs.
|
|
@@ -140,17 +137,16 @@ class InMemoryDriver(Driver):
|
|
|
140
137
|
set of given message IDs.
|
|
141
138
|
"""
|
|
142
139
|
msg_ids = {UUID(msg_id) for msg_id in message_ids}
|
|
143
|
-
# Pull
|
|
144
|
-
|
|
145
|
-
#
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
UUID(task_res.task.ancestry[0]) for task_res in task_res_list
|
|
140
|
+
# Pull Messages
|
|
141
|
+
message_res_list = self.state.get_message_res(message_ids=msg_ids)
|
|
142
|
+
# Get IDs of Messages these replies are for
|
|
143
|
+
message_ins_ids_to_delete = {
|
|
144
|
+
UUID(msg_res.metadata.reply_to_message) for msg_res in message_res_list
|
|
149
145
|
}
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
return
|
|
146
|
+
# Delete
|
|
147
|
+
self.state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
|
|
148
|
+
|
|
149
|
+
return message_res_list
|
|
154
150
|
|
|
155
151
|
def send_and_receive(
|
|
156
152
|
self,
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -15,11 +15,10 @@
|
|
|
15
15
|
"""Run ServerApp."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from logging import DEBUG
|
|
18
|
+
from logging import DEBUG
|
|
19
19
|
from typing import Optional
|
|
20
20
|
|
|
21
|
-
from flwr.common import Context
|
|
22
|
-
from flwr.common.exit_handlers import register_exit_handlers
|
|
21
|
+
from flwr.common import Context
|
|
23
22
|
from flwr.common.logger import log
|
|
24
23
|
from flwr.common.object_ref import load_app
|
|
25
24
|
|
|
@@ -64,13 +63,3 @@ def run(
|
|
|
64
63
|
|
|
65
64
|
log(DEBUG, "ServerApp finished running.")
|
|
66
65
|
return context
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
def run_server_app() -> None:
|
|
70
|
-
"""Run Flower server app."""
|
|
71
|
-
event(EventType.RUN_SERVER_APP_ENTER)
|
|
72
|
-
log(
|
|
73
|
-
ERROR,
|
|
74
|
-
"The command `flower-server-app` has been replaced by `flwr run`.",
|
|
75
|
-
)
|
|
76
|
-
register_exit_handlers(event_type=EventType.RUN_SERVER_APP_LEAVE)
|
flwr/server/server_app.py
CHANGED
|
@@ -15,6 +15,8 @@
|
|
|
15
15
|
"""Flower ServerApp."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from collections.abc import Iterator
|
|
19
|
+
from contextlib import contextmanager
|
|
18
20
|
from typing import Callable, Optional
|
|
19
21
|
|
|
20
22
|
from flwr.common import Context
|
|
@@ -45,7 +47,12 @@ SERVER_FN_USAGE_EXAMPLE = """
|
|
|
45
47
|
"""
|
|
46
48
|
|
|
47
49
|
|
|
48
|
-
|
|
50
|
+
@contextmanager
|
|
51
|
+
def _empty_lifespan(_: Context) -> Iterator[None]:
|
|
52
|
+
yield
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
49
56
|
"""Flower ServerApp.
|
|
50
57
|
|
|
51
58
|
Examples
|
|
@@ -105,29 +112,31 @@ class ServerApp:
|
|
|
105
112
|
self._client_manager = client_manager
|
|
106
113
|
self._server_fn = server_fn
|
|
107
114
|
self._main: Optional[ServerAppCallable] = None
|
|
115
|
+
self._lifespan = _empty_lifespan
|
|
108
116
|
|
|
109
117
|
def __call__(self, driver: Driver, context: Context) -> None:
|
|
110
118
|
"""Execute `ServerApp`."""
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
if self.
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
119
|
+
with self._lifespan(context):
|
|
120
|
+
# Compatibility mode
|
|
121
|
+
if not self._main:
|
|
122
|
+
if self._server_fn:
|
|
123
|
+
# Execute server_fn()
|
|
124
|
+
components = self._server_fn(context)
|
|
125
|
+
self._server = components.server
|
|
126
|
+
self._config = components.config
|
|
127
|
+
self._strategy = components.strategy
|
|
128
|
+
self._client_manager = components.client_manager
|
|
129
|
+
start_driver(
|
|
130
|
+
server=self._server,
|
|
131
|
+
config=self._config,
|
|
132
|
+
strategy=self._strategy,
|
|
133
|
+
client_manager=self._client_manager,
|
|
134
|
+
driver=driver,
|
|
135
|
+
)
|
|
136
|
+
return
|
|
128
137
|
|
|
129
|
-
|
|
130
|
-
|
|
138
|
+
# New execution mode
|
|
139
|
+
self._main(driver, context)
|
|
131
140
|
|
|
132
141
|
def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
|
|
133
142
|
"""Return a decorator that registers the main fn with the server app.
|
|
@@ -177,6 +186,70 @@ class ServerApp:
|
|
|
177
186
|
|
|
178
187
|
return main_decorator
|
|
179
188
|
|
|
189
|
+
def lifespan(
|
|
190
|
+
self,
|
|
191
|
+
) -> Callable[
|
|
192
|
+
[Callable[[Context], Iterator[None]]], Callable[[Context], Iterator[None]]
|
|
193
|
+
]:
|
|
194
|
+
"""Return a decorator that registers the lifespan fn with the server app.
|
|
195
|
+
|
|
196
|
+
The decorated function should accept a `Context` object and use `yield`
|
|
197
|
+
to define enter and exit behavior.
|
|
198
|
+
|
|
199
|
+
Examples
|
|
200
|
+
--------
|
|
201
|
+
>>> app = ServerApp()
|
|
202
|
+
>>>
|
|
203
|
+
>>> @app.lifespan()
|
|
204
|
+
>>> def lifespan(context: Context) -> None:
|
|
205
|
+
>>> # Perform initialization tasks before the app starts
|
|
206
|
+
>>> print("Initializing ServerApp")
|
|
207
|
+
>>>
|
|
208
|
+
>>> yield # ServerApp is running
|
|
209
|
+
>>>
|
|
210
|
+
>>> # Perform cleanup tasks after the app stops
|
|
211
|
+
>>> print("Cleaning up ServerApp")
|
|
212
|
+
"""
|
|
213
|
+
|
|
214
|
+
def lifespan_decorator(
|
|
215
|
+
lifespan_fn: Callable[[Context], Iterator[None]]
|
|
216
|
+
) -> Callable[[Context], Iterator[None]]:
|
|
217
|
+
"""Register the lifespan fn with the ServerApp object."""
|
|
218
|
+
warn_preview_feature("ServerApp-register-lifespan-function")
|
|
219
|
+
|
|
220
|
+
@contextmanager
|
|
221
|
+
def decorated_lifespan(context: Context) -> Iterator[None]:
|
|
222
|
+
# Execute the code before `yield` in lifespan_fn
|
|
223
|
+
try:
|
|
224
|
+
if not isinstance(it := lifespan_fn(context), Iterator):
|
|
225
|
+
raise StopIteration
|
|
226
|
+
next(it)
|
|
227
|
+
except StopIteration:
|
|
228
|
+
raise RuntimeError(
|
|
229
|
+
"lifespan function should yield at least once."
|
|
230
|
+
) from None
|
|
231
|
+
|
|
232
|
+
try:
|
|
233
|
+
# Enter the context
|
|
234
|
+
yield
|
|
235
|
+
finally:
|
|
236
|
+
try:
|
|
237
|
+
# Execute the code after `yield` in lifespan_fn
|
|
238
|
+
next(it)
|
|
239
|
+
except StopIteration:
|
|
240
|
+
pass
|
|
241
|
+
else:
|
|
242
|
+
raise RuntimeError("lifespan function should only yield once.")
|
|
243
|
+
|
|
244
|
+
# Register provided function with the ServerApp object
|
|
245
|
+
# Ignore mypy error because of different argument names (`_` vs `context`)
|
|
246
|
+
self._lifespan = decorated_lifespan # type: ignore
|
|
247
|
+
|
|
248
|
+
# Return provided function unmodified
|
|
249
|
+
return lifespan_fn
|
|
250
|
+
|
|
251
|
+
return lifespan_decorator
|
|
252
|
+
|
|
180
253
|
|
|
181
254
|
class LoadServerAppError(Exception):
|
|
182
255
|
"""Error when trying to load `ServerApp`."""
|
|
@@ -22,8 +22,8 @@ from uuid import UUID
|
|
|
22
22
|
|
|
23
23
|
import grpc
|
|
24
24
|
|
|
25
|
-
from flwr.common import ConfigsRecord,
|
|
26
|
-
from flwr.common.constant import Status
|
|
25
|
+
from flwr.common import ConfigsRecord, Message
|
|
26
|
+
from flwr.common.constant import SUPERLINK_NODE_ID, Status
|
|
27
27
|
from flwr.common.logger import log
|
|
28
28
|
from flwr.common.serde import (
|
|
29
29
|
context_from_proto,
|
|
@@ -31,9 +31,7 @@ from flwr.common.serde import (
|
|
|
31
31
|
fab_from_proto,
|
|
32
32
|
fab_to_proto,
|
|
33
33
|
message_from_proto,
|
|
34
|
-
message_from_taskres,
|
|
35
34
|
message_to_proto,
|
|
36
|
-
message_to_taskins,
|
|
37
35
|
run_status_from_proto,
|
|
38
36
|
run_status_to_proto,
|
|
39
37
|
run_to_proto,
|
|
@@ -69,12 +67,11 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
|
69
67
|
PushServerAppOutputsRequest,
|
|
70
68
|
PushServerAppOutputsResponse,
|
|
71
69
|
)
|
|
72
|
-
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
|
73
70
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
74
71
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
75
72
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
76
73
|
from flwr.server.superlink.utils import abort_if
|
|
77
|
-
from flwr.server.utils.validator import
|
|
74
|
+
from flwr.server.utils.validator import validate_message
|
|
78
75
|
|
|
79
76
|
|
|
80
77
|
class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
@@ -151,9 +148,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
151
148
|
context,
|
|
152
149
|
)
|
|
153
150
|
|
|
154
|
-
# Set pushed_at (timestamp in seconds)
|
|
155
|
-
pushed_at = now().timestamp()
|
|
156
|
-
|
|
157
151
|
# Validate request and insert in State
|
|
158
152
|
_raise_if(
|
|
159
153
|
validation_error=len(request.messages_list) == 0,
|
|
@@ -164,21 +158,19 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
164
158
|
while request.messages_list:
|
|
165
159
|
message_proto = request.messages_list.pop(0)
|
|
166
160
|
message = message_from_proto(message_proto=message_proto)
|
|
167
|
-
|
|
168
|
-
task_ins.task.pushed_at = pushed_at
|
|
169
|
-
validation_errors = validate_task_ins_or_res(task_ins)
|
|
161
|
+
validation_errors = validate_message(message, is_reply_message=False)
|
|
170
162
|
_raise_if(
|
|
171
163
|
validation_error=bool(validation_errors),
|
|
172
164
|
request_name="PushMessages",
|
|
173
165
|
detail=", ".join(validation_errors),
|
|
174
166
|
)
|
|
175
167
|
_raise_if(
|
|
176
|
-
validation_error=request.run_id !=
|
|
168
|
+
validation_error=request.run_id != message.metadata.run_id,
|
|
177
169
|
request_name="PushMessages",
|
|
178
|
-
detail="`
|
|
170
|
+
detail="`Message.metadata` has mismatched `run_id`",
|
|
179
171
|
)
|
|
180
172
|
# Store
|
|
181
|
-
message_id: Optional[UUID] = state.
|
|
173
|
+
message_id: Optional[UUID] = state.store_message_ins(message=message)
|
|
182
174
|
message_ids.append(message_id)
|
|
183
175
|
|
|
184
176
|
return PushInsMessagesResponse(
|
|
@@ -204,32 +196,34 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
204
196
|
context,
|
|
205
197
|
)
|
|
206
198
|
|
|
207
|
-
# Convert each
|
|
199
|
+
# Convert each message_id str to UUID
|
|
208
200
|
message_ids: set[UUID] = {
|
|
209
201
|
UUID(message_id) for message_id in request.message_ids
|
|
210
202
|
}
|
|
211
203
|
|
|
212
204
|
# Read from state
|
|
213
|
-
|
|
205
|
+
messages_res: list[Message] = state.get_message_res(message_ids=message_ids)
|
|
214
206
|
|
|
215
|
-
#
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
task_res = task_res_list.pop(0)
|
|
219
|
-
_raise_if(
|
|
220
|
-
validation_error=request.run_id != task_res.run_id,
|
|
221
|
-
request_name="PullMessages",
|
|
222
|
-
detail="`task_res` has mismatched `run_id`",
|
|
223
|
-
)
|
|
224
|
-
message = message_from_taskres(taskres=task_res)
|
|
225
|
-
messages_list.append(message_to_proto(message))
|
|
226
|
-
|
|
227
|
-
# Delete the TaskIns/TaskRes pairs if TaskRes is found
|
|
228
|
-
task_ins_ids_to_delete = {
|
|
229
|
-
UUID(task_res.task.ancestry[0]) for task_res in task_res_list
|
|
207
|
+
# Delete the instruction Messages and their replies if found
|
|
208
|
+
message_ins_ids_to_delete = {
|
|
209
|
+
UUID(msg_res.metadata.reply_to_message) for msg_res in messages_res
|
|
230
210
|
}
|
|
231
211
|
|
|
232
|
-
state.
|
|
212
|
+
state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
|
|
213
|
+
|
|
214
|
+
# Convert Messages to proto
|
|
215
|
+
messages_list = []
|
|
216
|
+
while messages_res:
|
|
217
|
+
msg = messages_res.pop(0)
|
|
218
|
+
|
|
219
|
+
# Skip `run_id` check for SuperLink generated replies
|
|
220
|
+
if msg.metadata.src_node_id != SUPERLINK_NODE_ID:
|
|
221
|
+
_raise_if(
|
|
222
|
+
validation_error=request.run_id != msg.metadata.run_id,
|
|
223
|
+
request_name="PullMessages",
|
|
224
|
+
detail="`message.metadata` has mismatched `run_id`",
|
|
225
|
+
)
|
|
226
|
+
messages_list.append(message_to_proto(msg))
|
|
233
227
|
|
|
234
228
|
return PullResMessagesResponse(messages_list=messages_list)
|
|
235
229
|
|
|
@@ -103,11 +103,11 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
103
103
|
if request.messages_list:
|
|
104
104
|
log(
|
|
105
105
|
INFO,
|
|
106
|
-
"[Fleet.PushMessages] Push
|
|
106
|
+
"[Fleet.PushMessages] Push replies from node_id=%s",
|
|
107
107
|
request.messages_list[0].metadata.src_node_id,
|
|
108
108
|
)
|
|
109
109
|
else:
|
|
110
|
-
log(INFO, "[Fleet.PushMessages] No
|
|
110
|
+
log(INFO, "[Fleet.PushMessages] No replies to push")
|
|
111
111
|
|
|
112
112
|
try:
|
|
113
113
|
res = message_handler.push_messages(
|
|
@@ -15,17 +15,15 @@
|
|
|
15
15
|
"""Fleet API message handlers."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import time
|
|
19
18
|
from typing import Optional
|
|
20
19
|
from uuid import UUID
|
|
21
20
|
|
|
21
|
+
from flwr.common import Message
|
|
22
22
|
from flwr.common.constant import Status
|
|
23
23
|
from flwr.common.serde import (
|
|
24
24
|
fab_to_proto,
|
|
25
25
|
message_from_proto,
|
|
26
|
-
message_from_taskins,
|
|
27
26
|
message_to_proto,
|
|
28
|
-
message_to_taskres,
|
|
29
27
|
user_config_to_proto,
|
|
30
28
|
)
|
|
31
29
|
from flwr.common.typing import Fab, InvalidRunStatusException
|
|
@@ -49,7 +47,6 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
49
47
|
GetRunResponse,
|
|
50
48
|
Run,
|
|
51
49
|
)
|
|
52
|
-
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
53
50
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
54
51
|
from flwr.server.superlink.linkstate import LinkState
|
|
55
52
|
from flwr.server.superlink.utils import check_abort
|
|
@@ -93,13 +90,12 @@ def pull_messages(
|
|
|
93
90
|
node = request.node # pylint: disable=no-member
|
|
94
91
|
node_id: int = node.node_id
|
|
95
92
|
|
|
96
|
-
# Retrieve
|
|
97
|
-
|
|
93
|
+
# Retrieve Message from State
|
|
94
|
+
message_list: list[Message] = state.get_message_ins(node_id=node_id, limit=1)
|
|
98
95
|
|
|
99
96
|
# Convert to Messages
|
|
100
97
|
msg_proto = []
|
|
101
|
-
for
|
|
102
|
-
msg = message_from_taskins(task_ins)
|
|
98
|
+
for msg in message_list:
|
|
103
99
|
msg_proto.append(message_to_proto(msg))
|
|
104
100
|
|
|
105
101
|
return PullMessagesResponse(messages_list=msg_proto)
|
|
@@ -109,24 +105,20 @@ def push_messages(
|
|
|
109
105
|
request: PushMessagesRequest, state: LinkState
|
|
110
106
|
) -> PushMessagesResponse:
|
|
111
107
|
"""Push Messages handler."""
|
|
112
|
-
# Convert Message
|
|
108
|
+
# Convert Message from proto
|
|
113
109
|
msg = message_from_proto(message_proto=request.messages_list[0])
|
|
114
|
-
task_res = message_to_taskres(msg)
|
|
115
110
|
|
|
116
111
|
# Abort if the run is not running
|
|
117
112
|
abort_msg = check_abort(
|
|
118
|
-
|
|
113
|
+
msg.metadata.run_id,
|
|
119
114
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
120
115
|
state,
|
|
121
116
|
)
|
|
122
117
|
if abort_msg:
|
|
123
118
|
raise InvalidRunStatusException(abort_msg)
|
|
124
119
|
|
|
125
|
-
#
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
# Store TaskRes in State
|
|
129
|
-
message_id: Optional[UUID] = state.store_task_res(task_res=task_res)
|
|
120
|
+
# Store Message in State
|
|
121
|
+
message_id: Optional[UUID] = state.store_message_res(message=msg)
|
|
130
122
|
|
|
131
123
|
# Build response
|
|
132
124
|
response = PushMessagesResponse(
|
|
@@ -45,7 +45,7 @@ class Backend(ABC):
|
|
|
45
45
|
def num_workers(self) -> int:
|
|
46
46
|
"""Return number of workers in the backend.
|
|
47
47
|
|
|
48
|
-
This is the number of
|
|
48
|
+
This is the number of Messages that can be processed concurrently.
|
|
49
49
|
"""
|
|
50
50
|
return 0
|
|
51
51
|
|
|
@@ -29,6 +29,7 @@ from typing import Callable, Optional
|
|
|
29
29
|
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
|
|
30
30
|
from flwr.client.clientapp.utils import get_load_client_app_fn
|
|
31
31
|
from flwr.client.run_info_store import DeprecatedRunInfoStore
|
|
32
|
+
from flwr.common import Message
|
|
32
33
|
from flwr.common.constant import (
|
|
33
34
|
NUM_PARTITIONS_KEY,
|
|
34
35
|
PARTITION_ID_KEY,
|
|
@@ -37,9 +38,7 @@ from flwr.common.constant import (
|
|
|
37
38
|
)
|
|
38
39
|
from flwr.common.logger import log
|
|
39
40
|
from flwr.common.message import Error
|
|
40
|
-
from flwr.common.serde import message_from_taskins, message_to_taskres
|
|
41
41
|
from flwr.common.typing import Run
|
|
42
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
43
42
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
44
43
|
|
|
45
44
|
from .backend import Backend, error_messages_backends, supported_backends
|
|
@@ -87,33 +86,33 @@ def _register_node_info_stores(
|
|
|
87
86
|
|
|
88
87
|
# pylint: disable=too-many-arguments,too-many-locals
|
|
89
88
|
def worker(
|
|
90
|
-
|
|
91
|
-
|
|
89
|
+
messageins_queue: Queue[Message],
|
|
90
|
+
messageres_queue: Queue[Message],
|
|
92
91
|
node_info_store: dict[int, DeprecatedRunInfoStore],
|
|
93
92
|
backend: Backend,
|
|
94
93
|
f_stop: threading.Event,
|
|
95
94
|
) -> None:
|
|
96
|
-
"""
|
|
95
|
+
"""Process messages from the queue, execute them, update context, and enqueue
|
|
96
|
+
replies."""
|
|
97
97
|
while not f_stop.is_set():
|
|
98
98
|
out_mssg = None
|
|
99
99
|
try:
|
|
100
100
|
# Fetch from queue with timeout. We use a timeout so
|
|
101
101
|
# the stopping event can be evaluated even when the queue is empty.
|
|
102
|
-
|
|
103
|
-
node_id =
|
|
102
|
+
message: Message = messageins_queue.get(timeout=1.0)
|
|
103
|
+
node_id = message.metadata.dst_node_id
|
|
104
104
|
|
|
105
105
|
# Retrieve context
|
|
106
|
-
context = node_info_store[node_id].retrieve_context(
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
message = message_from_taskins(task_ins)
|
|
106
|
+
context = node_info_store[node_id].retrieve_context(
|
|
107
|
+
run_id=message.metadata.run_id
|
|
108
|
+
)
|
|
110
109
|
|
|
111
110
|
# Let backend process message
|
|
112
111
|
out_mssg, updated_context = backend.process_message(message, context)
|
|
113
112
|
|
|
114
113
|
# Update Context
|
|
115
114
|
node_info_store[node_id].update_context(
|
|
116
|
-
|
|
115
|
+
message.metadata.run_id, context=updated_context
|
|
117
116
|
)
|
|
118
117
|
except Empty:
|
|
119
118
|
# An exception raised if queue.get times out
|
|
@@ -137,36 +136,33 @@ def worker(
|
|
|
137
136
|
|
|
138
137
|
finally:
|
|
139
138
|
if out_mssg:
|
|
140
|
-
#
|
|
141
|
-
|
|
142
|
-
# Store TaskRes in state
|
|
143
|
-
task_res.task.pushed_at = time.time()
|
|
144
|
-
taskres_queue.put(task_res)
|
|
139
|
+
# Store reply Messages in state
|
|
140
|
+
messageres_queue.put(out_mssg)
|
|
145
141
|
|
|
146
142
|
|
|
147
|
-
def
|
|
143
|
+
def add_messages_to_queue(
|
|
148
144
|
state: LinkState,
|
|
149
|
-
queue:
|
|
145
|
+
queue: Queue[Message],
|
|
150
146
|
nodes_mapping: NodeToPartitionMapping,
|
|
151
147
|
f_stop: threading.Event,
|
|
152
148
|
) -> None:
|
|
153
|
-
"""Put
|
|
149
|
+
"""Put Messages in the queue from the LinkState."""
|
|
154
150
|
while not f_stop.is_set():
|
|
155
151
|
for node_id in nodes_mapping.keys():
|
|
156
|
-
|
|
157
|
-
for
|
|
158
|
-
queue.put(
|
|
152
|
+
message_ins_list = state.get_message_ins(node_id=node_id, limit=1)
|
|
153
|
+
for msg in message_ins_list:
|
|
154
|
+
queue.put(msg)
|
|
159
155
|
sleep(0.1)
|
|
160
156
|
|
|
161
157
|
|
|
162
|
-
def
|
|
163
|
-
state: LinkState, queue:
|
|
158
|
+
def put_message_into_state(
|
|
159
|
+
state: LinkState, queue: Queue[Message], f_stop: threading.Event
|
|
164
160
|
) -> None:
|
|
165
|
-
"""
|
|
161
|
+
"""Store reply Messages into the LinkState from the queue."""
|
|
166
162
|
while not f_stop.is_set():
|
|
167
163
|
try:
|
|
168
|
-
|
|
169
|
-
state.
|
|
164
|
+
message_reply = queue.get(timeout=1.0)
|
|
165
|
+
state.store_message_res(message_reply)
|
|
170
166
|
except Empty:
|
|
171
167
|
# queue is empty when timeout was triggered
|
|
172
168
|
pass
|
|
@@ -182,8 +178,8 @@ def run_api(
|
|
|
182
178
|
f_stop: threading.Event,
|
|
183
179
|
) -> None:
|
|
184
180
|
"""Run the VCE."""
|
|
185
|
-
|
|
186
|
-
|
|
181
|
+
messageins_queue: Queue[Message] = Queue()
|
|
182
|
+
messageres_queue: Queue[Message] = Queue()
|
|
187
183
|
|
|
188
184
|
try:
|
|
189
185
|
|
|
@@ -197,10 +193,10 @@ def run_api(
|
|
|
197
193
|
state = state_factory.state()
|
|
198
194
|
|
|
199
195
|
extractor_th = threading.Thread(
|
|
200
|
-
target=
|
|
196
|
+
target=add_messages_to_queue,
|
|
201
197
|
args=(
|
|
202
198
|
state,
|
|
203
|
-
|
|
199
|
+
messageins_queue,
|
|
204
200
|
nodes_mapping,
|
|
205
201
|
f_stop,
|
|
206
202
|
),
|
|
@@ -208,10 +204,10 @@ def run_api(
|
|
|
208
204
|
extractor_th.start()
|
|
209
205
|
|
|
210
206
|
injector_th = threading.Thread(
|
|
211
|
-
target=
|
|
207
|
+
target=put_message_into_state,
|
|
212
208
|
args=(
|
|
213
209
|
state,
|
|
214
|
-
|
|
210
|
+
messageres_queue,
|
|
215
211
|
f_stop,
|
|
216
212
|
),
|
|
217
213
|
)
|
|
@@ -221,8 +217,8 @@ def run_api(
|
|
|
221
217
|
_ = [
|
|
222
218
|
executor.submit(
|
|
223
219
|
worker,
|
|
224
|
-
|
|
225
|
-
|
|
220
|
+
messageins_queue,
|
|
221
|
+
messageres_queue,
|
|
226
222
|
node_info_stores,
|
|
227
223
|
backend,
|
|
228
224
|
f_stop,
|