flwr 1.15.2__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/build.py +2 -0
- flwr/cli/log.py +20 -21
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +260 -86
- flwr/client/clientapp/app.py +6 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +28 -28
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/rest_client/connection.py +4 -6
- flwr/client/run_info_store.py +2 -2
- flwr/client/supernode/__init__.py +0 -2
- flwr/client/supernode/app.py +1 -11
- flwr/common/__init__.py +12 -4
- flwr/common/address.py +35 -0
- flwr/common/args.py +8 -2
- flwr/common/auth_plugin/auth_plugin.py +2 -1
- flwr/common/config.py +4 -4
- flwr/common/constant.py +16 -0
- flwr/common/context.py +4 -4
- flwr/common/event_log_plugin/__init__.py +22 -0
- flwr/common/event_log_plugin/event_log_plugin.py +60 -0
- flwr/common/grpc.py +1 -1
- flwr/common/logger.py +2 -2
- flwr/common/message.py +338 -102
- flwr/common/object_ref.py +0 -10
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/arrayrecord.py +626 -0
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +9 -18
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/quantization.py +5 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +67 -190
- flwr/common/telemetry.py +0 -10
- flwr/common/typing.py +44 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +3 -1
- flwr/server/app.py +74 -3
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +15 -12
- flwr/server/compat/app_utils.py +26 -18
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +41 -41
- flwr/server/fleet_event_log_interceptor.py +94 -0
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +48 -19
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +88 -56
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +41 -54
- flwr/server/run_serverapp.py +6 -17
- flwr/server/server_app.py +126 -33
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
- flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +33 -38
- flwr/server/superlink/linkstate/in_memory_linkstate.py +171 -132
- flwr/server/superlink/linkstate/linkstate.py +51 -64
- flwr/server/superlink/linkstate/sqlite_linkstate.py +253 -285
- flwr/server/superlink/linkstate/utils.py +171 -133
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +27 -29
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/typing.py +3 -3
- flwr/server/utils/__init__.py +2 -2
- flwr/server/utils/validator.py +53 -68
- flwr/server/workflow/default_workflows.py +52 -58
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +15 -15
- flwr/superexec/app.py +0 -14
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_event_log_interceptor.py +135 -0
- flwr/superexec/exec_grpc.py +10 -4
- flwr/superexec/exec_servicer.py +6 -6
- flwr/superexec/exec_user_auth_interceptor.py +22 -4
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/METADATA +5 -5
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/RECORD +111 -112
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -3
- flwr/client/message_handler/task_handler.py +0 -37
- flwr/common/record/parametersrecord.py +0 -204
- flwr/common/record/recordset.py +0 -202
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- flwr/proto/task_pb2.py +0 -33
- flwr/proto/task_pb2.pyi +0 -100
- flwr/proto/task_pb2_grpc.py +0 -4
- flwr/proto/task_pb2_grpc.pyi +0 -4
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
flwr/server/server_app.py
CHANGED
|
@@ -15,18 +15,18 @@
|
|
|
15
15
|
"""Flower ServerApp."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import inspect
|
|
19
|
+
from collections.abc import Iterator
|
|
20
|
+
from contextlib import contextmanager
|
|
18
21
|
from typing import Callable, Optional
|
|
19
22
|
|
|
20
23
|
from flwr.common import Context
|
|
21
|
-
from flwr.common.logger import (
|
|
22
|
-
warn_deprecated_feature_with_example,
|
|
23
|
-
warn_preview_feature,
|
|
24
|
-
)
|
|
24
|
+
from flwr.common.logger import warn_deprecated_feature_with_example
|
|
25
25
|
from flwr.server.strategy import Strategy
|
|
26
26
|
|
|
27
27
|
from .client_manager import ClientManager
|
|
28
|
-
from .compat import start_driver
|
|
29
|
-
from .driver import Driver
|
|
28
|
+
from .compat import start_grid
|
|
29
|
+
from .grid import Driver, Grid
|
|
30
30
|
from .server import Server
|
|
31
31
|
from .server_config import ServerConfig
|
|
32
32
|
from .typing import ServerAppCallable, ServerFn
|
|
@@ -44,13 +44,33 @@ SERVER_FN_USAGE_EXAMPLE = """
|
|
|
44
44
|
app = ServerApp(server_fn=server_fn)
|
|
45
45
|
"""
|
|
46
46
|
|
|
47
|
+
GRID_USAGE_EXAMPLE = """
|
|
48
|
+
app = ServerApp()
|
|
47
49
|
|
|
48
|
-
|
|
50
|
+
@app.main()
|
|
51
|
+
def main(grid: Grid, context: Context) -> None:
|
|
52
|
+
# Your existing ServerApp code ...
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
DRIVER_DEPRECATION_MSG = """
|
|
56
|
+
The `Driver` class is deprecated, it will be removed in a future release.
|
|
57
|
+
"""
|
|
58
|
+
DRIVER_EXAMPLE_MSG = """
|
|
59
|
+
Instead, use `Grid` in the signature of your `ServerApp`. For example:
|
|
60
|
+
"""
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@contextmanager
|
|
64
|
+
def _empty_lifespan(_: Context) -> Iterator[None]:
|
|
65
|
+
yield
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
49
69
|
"""Flower ServerApp.
|
|
50
70
|
|
|
51
71
|
Examples
|
|
52
72
|
--------
|
|
53
|
-
Use the
|
|
73
|
+
Use the ``ServerApp`` with an existing ``Strategy``:
|
|
54
74
|
|
|
55
75
|
>>> def server_fn(context: Context):
|
|
56
76
|
>>> server_config = ServerConfig(num_rounds=3)
|
|
@@ -62,12 +82,12 @@ class ServerApp:
|
|
|
62
82
|
>>>
|
|
63
83
|
>>> app = ServerApp(server_fn=server_fn)
|
|
64
84
|
|
|
65
|
-
Use the
|
|
85
|
+
Use the ``ServerApp`` with a custom main function:
|
|
66
86
|
|
|
67
87
|
>>> app = ServerApp()
|
|
68
88
|
>>>
|
|
69
89
|
>>> @app.main()
|
|
70
|
-
>>> def main(
|
|
90
|
+
>>> def main(grid: Grid, context: Context) -> None:
|
|
71
91
|
>>> print("ServerApp running")
|
|
72
92
|
"""
|
|
73
93
|
|
|
@@ -105,29 +125,31 @@ class ServerApp:
|
|
|
105
125
|
self._client_manager = client_manager
|
|
106
126
|
self._server_fn = server_fn
|
|
107
127
|
self._main: Optional[ServerAppCallable] = None
|
|
128
|
+
self._lifespan = _empty_lifespan
|
|
108
129
|
|
|
109
|
-
def __call__(self,
|
|
130
|
+
def __call__(self, grid: Grid, context: Context) -> None:
|
|
110
131
|
"""Execute `ServerApp`."""
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
if self.
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
132
|
+
with self._lifespan(context):
|
|
133
|
+
# Compatibility mode
|
|
134
|
+
if not self._main:
|
|
135
|
+
if self._server_fn:
|
|
136
|
+
# Execute server_fn()
|
|
137
|
+
components = self._server_fn(context)
|
|
138
|
+
self._server = components.server
|
|
139
|
+
self._config = components.config
|
|
140
|
+
self._strategy = components.strategy
|
|
141
|
+
self._client_manager = components.client_manager
|
|
142
|
+
start_grid(
|
|
143
|
+
server=self._server,
|
|
144
|
+
config=self._config,
|
|
145
|
+
strategy=self._strategy,
|
|
146
|
+
client_manager=self._client_manager,
|
|
147
|
+
grid=grid,
|
|
148
|
+
)
|
|
149
|
+
return
|
|
128
150
|
|
|
129
|
-
|
|
130
|
-
|
|
151
|
+
# New execution mode
|
|
152
|
+
self._main(grid, context)
|
|
131
153
|
|
|
132
154
|
def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
|
|
133
155
|
"""Return a decorator that registers the main fn with the server app.
|
|
@@ -137,7 +159,7 @@ class ServerApp:
|
|
|
137
159
|
>>> app = ServerApp()
|
|
138
160
|
>>>
|
|
139
161
|
>>> @app.main()
|
|
140
|
-
>>> def main(
|
|
162
|
+
>>> def main(grid: Grid, context: Context) -> None:
|
|
141
163
|
>>> print("ServerApp running")
|
|
142
164
|
"""
|
|
143
165
|
|
|
@@ -162,12 +184,20 @@ class ServerApp:
|
|
|
162
184
|
>>> app = ServerApp()
|
|
163
185
|
>>>
|
|
164
186
|
>>> @app.main()
|
|
165
|
-
>>> def main(
|
|
187
|
+
>>> def main(grid: Grid, context: Context) -> None:
|
|
166
188
|
>>> print("ServerApp running")
|
|
167
189
|
""",
|
|
168
190
|
)
|
|
169
191
|
|
|
170
|
-
|
|
192
|
+
sig = inspect.signature(main_fn)
|
|
193
|
+
param = list(sig.parameters.values())[0]
|
|
194
|
+
# Check if parameter name or the annotation should be updated
|
|
195
|
+
if param.name == "driver" or param.annotation is Driver:
|
|
196
|
+
warn_deprecated_feature_with_example(
|
|
197
|
+
deprecation_message=DRIVER_DEPRECATION_MSG,
|
|
198
|
+
example_message=DRIVER_EXAMPLE_MSG,
|
|
199
|
+
code_example=GRID_USAGE_EXAMPLE,
|
|
200
|
+
)
|
|
171
201
|
|
|
172
202
|
# Register provided function with the ServerApp object
|
|
173
203
|
self._main = main_fn
|
|
@@ -177,6 +207,69 @@ class ServerApp:
|
|
|
177
207
|
|
|
178
208
|
return main_decorator
|
|
179
209
|
|
|
210
|
+
def lifespan(
|
|
211
|
+
self,
|
|
212
|
+
) -> Callable[
|
|
213
|
+
[Callable[[Context], Iterator[None]]], Callable[[Context], Iterator[None]]
|
|
214
|
+
]:
|
|
215
|
+
"""Return a decorator that registers the lifespan fn with the server app.
|
|
216
|
+
|
|
217
|
+
The decorated function should accept a `Context` object and use `yield`
|
|
218
|
+
to define enter and exit behavior.
|
|
219
|
+
|
|
220
|
+
Examples
|
|
221
|
+
--------
|
|
222
|
+
>>> app = ServerApp()
|
|
223
|
+
>>>
|
|
224
|
+
>>> @app.lifespan()
|
|
225
|
+
>>> def lifespan(context: Context) -> None:
|
|
226
|
+
>>> # Perform initialization tasks before the app starts
|
|
227
|
+
>>> print("Initializing ServerApp")
|
|
228
|
+
>>>
|
|
229
|
+
>>> yield # ServerApp is running
|
|
230
|
+
>>>
|
|
231
|
+
>>> # Perform cleanup tasks after the app stops
|
|
232
|
+
>>> print("Cleaning up ServerApp")
|
|
233
|
+
"""
|
|
234
|
+
|
|
235
|
+
def lifespan_decorator(
|
|
236
|
+
lifespan_fn: Callable[[Context], Iterator[None]],
|
|
237
|
+
) -> Callable[[Context], Iterator[None]]:
|
|
238
|
+
"""Register the lifespan fn with the ServerApp object."""
|
|
239
|
+
|
|
240
|
+
@contextmanager
|
|
241
|
+
def decorated_lifespan(context: Context) -> Iterator[None]:
|
|
242
|
+
# Execute the code before `yield` in lifespan_fn
|
|
243
|
+
try:
|
|
244
|
+
if not isinstance(it := lifespan_fn(context), Iterator):
|
|
245
|
+
raise StopIteration
|
|
246
|
+
next(it)
|
|
247
|
+
except StopIteration:
|
|
248
|
+
raise RuntimeError(
|
|
249
|
+
"lifespan function should yield at least once."
|
|
250
|
+
) from None
|
|
251
|
+
|
|
252
|
+
try:
|
|
253
|
+
# Enter the context
|
|
254
|
+
yield
|
|
255
|
+
finally:
|
|
256
|
+
try:
|
|
257
|
+
# Execute the code after `yield` in lifespan_fn
|
|
258
|
+
next(it)
|
|
259
|
+
except StopIteration:
|
|
260
|
+
pass
|
|
261
|
+
else:
|
|
262
|
+
raise RuntimeError("lifespan function should only yield once.")
|
|
263
|
+
|
|
264
|
+
# Register provided function with the ServerApp object
|
|
265
|
+
# Ignore mypy error because of different argument names (`_` vs `context`)
|
|
266
|
+
self._lifespan = decorated_lifespan # type: ignore
|
|
267
|
+
|
|
268
|
+
# Return provided function unmodified
|
|
269
|
+
return lifespan_fn
|
|
270
|
+
|
|
271
|
+
return lifespan_decorator
|
|
272
|
+
|
|
180
273
|
|
|
181
274
|
class LoadServerAppError(Exception):
|
|
182
275
|
"""Error when trying to load `ServerApp`."""
|
flwr/server/serverapp/app.py
CHANGED
|
@@ -60,7 +60,7 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
|
60
60
|
PullServerAppInputsResponse,
|
|
61
61
|
PushServerAppOutputsRequest,
|
|
62
62
|
)
|
|
63
|
-
from flwr.server.driver.grpc_driver import GrpcDriver
|
|
63
|
+
from flwr.server.grid.grpc_grid import GrpcGrid
|
|
64
64
|
from flwr.server.run_serverapp import run as run_
|
|
65
65
|
|
|
66
66
|
|
|
@@ -106,7 +106,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
106
106
|
certificates: Optional[bytes] = None,
|
|
107
107
|
) -> None:
|
|
108
108
|
"""Run Flower ServerApp process."""
|
|
109
|
-
|
|
109
|
+
grid = GrpcGrid(
|
|
110
110
|
serverappio_service_address=serverappio_api_address,
|
|
111
111
|
root_certificates=certificates,
|
|
112
112
|
)
|
|
@@ -123,7 +123,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
123
123
|
# Pull ServerAppInputs from LinkState
|
|
124
124
|
req = PullServerAppInputsRequest()
|
|
125
125
|
log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
|
|
126
|
-
res: PullServerAppInputsResponse =
|
|
126
|
+
res: PullServerAppInputsResponse = grid._stub.PullServerAppInputs(req)
|
|
127
127
|
if not res.HasField("run"):
|
|
128
128
|
sleep(3)
|
|
129
129
|
run_status = None
|
|
@@ -135,14 +135,14 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
135
135
|
|
|
136
136
|
hash_run_id = get_sha256_hash(run.run_id)
|
|
137
137
|
|
|
138
|
-
|
|
138
|
+
grid.set_run(run.run_id)
|
|
139
139
|
|
|
140
140
|
# Start log uploader for this run
|
|
141
141
|
log_uploader = start_log_uploader(
|
|
142
142
|
log_queue=log_queue,
|
|
143
143
|
node_id=0,
|
|
144
144
|
run_id=run.run_id,
|
|
145
|
-
stub=
|
|
145
|
+
stub=grid._stub,
|
|
146
146
|
)
|
|
147
147
|
|
|
148
148
|
log(DEBUG, "[flwr-serverapp] Start FAB installation.")
|
|
@@ -173,7 +173,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
173
173
|
|
|
174
174
|
# Change status to Running
|
|
175
175
|
run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
|
|
176
|
-
|
|
176
|
+
grid._stub.UpdateRunStatus(
|
|
177
177
|
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
|
|
178
178
|
)
|
|
179
179
|
|
|
@@ -182,9 +182,9 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
182
182
|
event_details={"run-id-hash": hash_run_id},
|
|
183
183
|
)
|
|
184
184
|
|
|
185
|
-
# Load and run the ServerApp with the
|
|
185
|
+
# Load and run the ServerApp with the Grid
|
|
186
186
|
updated_context = run_(
|
|
187
|
-
|
|
187
|
+
grid=grid,
|
|
188
188
|
server_app_dir=app_path,
|
|
189
189
|
server_app_attr=server_app_attr,
|
|
190
190
|
context=context,
|
|
@@ -196,7 +196,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
196
196
|
out_req = PushServerAppOutputsRequest(
|
|
197
197
|
run_id=run.run_id, context=context_proto
|
|
198
198
|
)
|
|
199
|
-
_ =
|
|
199
|
+
_ = grid._stub.PushServerAppOutputs(out_req)
|
|
200
200
|
|
|
201
201
|
run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
|
|
202
202
|
except RunNotRunningException:
|
|
@@ -221,7 +221,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
221
221
|
# Update run status
|
|
222
222
|
if run_status:
|
|
223
223
|
run_status_proto = run_status_to_proto(run_status)
|
|
224
|
-
|
|
224
|
+
grid._stub.UpdateRunStatus(
|
|
225
225
|
UpdateRunStatusRequest(
|
|
226
226
|
run_id=run.run_id, run_status=run_status_proto
|
|
227
227
|
)
|
|
flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py
CHANGED
|
@@ -103,11 +103,11 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
103
103
|
if request.messages_list:
|
|
104
104
|
log(
|
|
105
105
|
INFO,
|
|
106
|
-
"[Fleet.PushMessages] Push
|
|
106
|
+
"[Fleet.PushMessages] Push replies from node_id=%s",
|
|
107
107
|
request.messages_list[0].metadata.src_node_id,
|
|
108
108
|
)
|
|
109
109
|
else:
|
|
110
|
-
log(INFO, "[Fleet.PushMessages] No
|
|
110
|
+
log(INFO, "[Fleet.PushMessages] No replies to push")
|
|
111
111
|
|
|
112
112
|
try:
|
|
113
113
|
res = message_handler.push_messages(
|
|
flwr/server/superlink/fleet/message_handler/message_handler.py
CHANGED
|
@@ -18,13 +18,12 @@
|
|
|
18
18
|
from typing import Optional
|
|
19
19
|
from uuid import UUID
|
|
20
20
|
|
|
21
|
+
from flwr.common import Message
|
|
21
22
|
from flwr.common.constant import Status
|
|
22
23
|
from flwr.common.serde import (
|
|
23
24
|
fab_to_proto,
|
|
24
25
|
message_from_proto,
|
|
25
|
-
message_from_taskins,
|
|
26
26
|
message_to_proto,
|
|
27
|
-
message_to_taskres,
|
|
28
27
|
user_config_to_proto,
|
|
29
28
|
)
|
|
30
29
|
from flwr.common.typing import Fab, InvalidRunStatusException
|
|
@@ -48,7 +47,6 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
48
47
|
GetRunResponse,
|
|
49
48
|
Run,
|
|
50
49
|
)
|
|
51
|
-
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
52
50
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
53
51
|
from flwr.server.superlink.linkstate import LinkState
|
|
54
52
|
from flwr.server.superlink.utils import check_abort
|
|
@@ -92,13 +90,12 @@ def pull_messages(
|
|
|
92
90
|
node = request.node # pylint: disable=no-member
|
|
93
91
|
node_id: int = node.node_id
|
|
94
92
|
|
|
95
|
-
# Retrieve
|
|
96
|
-
|
|
93
|
+
# Retrieve Message from State
|
|
94
|
+
message_list: list[Message] = state.get_message_ins(node_id=node_id, limit=1)
|
|
97
95
|
|
|
98
96
|
# Convert to Messages
|
|
99
97
|
msg_proto = []
|
|
100
|
-
for
|
|
101
|
-
msg = message_from_taskins(task_ins)
|
|
98
|
+
for msg in message_list:
|
|
102
99
|
msg_proto.append(message_to_proto(msg))
|
|
103
100
|
|
|
104
101
|
return PullMessagesResponse(messages_list=msg_proto)
|
|
@@ -108,21 +105,20 @@ def push_messages(
|
|
|
108
105
|
request: PushMessagesRequest, state: LinkState
|
|
109
106
|
) -> PushMessagesResponse:
|
|
110
107
|
"""Push Messages handler."""
|
|
111
|
-
# Convert Message
|
|
108
|
+
# Convert Message from proto
|
|
112
109
|
msg = message_from_proto(message_proto=request.messages_list[0])
|
|
113
|
-
task_res = message_to_taskres(msg)
|
|
114
110
|
|
|
115
111
|
# Abort if the run is not running
|
|
116
112
|
abort_msg = check_abort(
|
|
117
|
-
|
|
113
|
+
msg.metadata.run_id,
|
|
118
114
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
119
115
|
state,
|
|
120
116
|
)
|
|
121
117
|
if abort_msg:
|
|
122
118
|
raise InvalidRunStatusException(abort_msg)
|
|
123
119
|
|
|
124
|
-
# Store
|
|
125
|
-
message_id: Optional[UUID] = state.
|
|
120
|
+
# Store Message in State
|
|
121
|
+
message_id: Optional[UUID] = state.store_message_res(message=msg)
|
|
126
122
|
|
|
127
123
|
# Build response
|
|
128
124
|
response = PushMessagesResponse(
|
|
flwr/server/superlink/fleet/vce/backend/backend.py
CHANGED
|
@@ -21,9 +21,9 @@ from typing import Callable
|
|
|
21
21
|
from flwr.client.client_app import ClientApp
|
|
22
22
|
from flwr.common.context import Context
|
|
23
23
|
from flwr.common.message import Message
|
|
24
|
-
from flwr.common.typing import
|
|
24
|
+
from flwr.common.typing import ConfigRecordValues
|
|
25
25
|
|
|
26
|
-
BackendConfig = dict[str, dict[str,
|
|
26
|
+
BackendConfig = dict[str, dict[str, ConfigRecordValues]]
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class Backend(ABC):
|
|
@@ -45,7 +45,7 @@ class Backend(ABC):
|
|
|
45
45
|
def num_workers(self) -> int:
|
|
46
46
|
"""Return number of workers in the backend.
|
|
47
47
|
|
|
48
|
-
This is the number of
|
|
48
|
+
This is the number of Messages that can be processed concurrently.
|
|
49
49
|
"""
|
|
50
50
|
return 0
|
|
51
51
|
|
|
flwr/server/superlink/fleet/vce/backend/raybackend.py
CHANGED
|
@@ -26,7 +26,7 @@ from flwr.common.constant import PARTITION_ID_KEY
|
|
|
26
26
|
from flwr.common.context import Context
|
|
27
27
|
from flwr.common.logger import log
|
|
28
28
|
from flwr.common.message import Message
|
|
29
|
-
from flwr.common.typing import
|
|
29
|
+
from flwr.common.typing import ConfigRecordValues
|
|
30
30
|
from flwr.simulation.ray_transport.ray_actor import BasicActorPool, ClientAppActor
|
|
31
31
|
from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
|
|
32
32
|
|
|
@@ -104,7 +104,7 @@ class RayBackend(Backend):
|
|
|
104
104
|
if not ray.is_initialized():
|
|
105
105
|
ray_init_args: dict[
|
|
106
106
|
str,
|
|
107
|
-
|
|
107
|
+
ConfigRecordValues,
|
|
108
108
|
] = {}
|
|
109
109
|
|
|
110
110
|
if backend_config.get(self.init_args_key):
|
|
flwr/server/superlink/fleet/vce/vce_api.py
CHANGED
|
@@ -29,6 +29,7 @@ from typing import Callable, Optional
|
|
|
29
29
|
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
|
|
30
30
|
from flwr.client.clientapp.utils import get_load_client_app_fn
|
|
31
31
|
from flwr.client.run_info_store import DeprecatedRunInfoStore
|
|
32
|
+
from flwr.common import Message
|
|
32
33
|
from flwr.common.constant import (
|
|
33
34
|
NUM_PARTITIONS_KEY,
|
|
34
35
|
PARTITION_ID_KEY,
|
|
@@ -37,9 +38,7 @@ from flwr.common.constant import (
|
|
|
37
38
|
)
|
|
38
39
|
from flwr.common.logger import log
|
|
39
40
|
from flwr.common.message import Error
|
|
40
|
-
from flwr.common.serde import message_from_taskins, message_to_taskres
|
|
41
41
|
from flwr.common.typing import Run
|
|
42
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
43
42
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
44
43
|
|
|
45
44
|
from .backend import Backend, error_messages_backends, supported_backends
|
|
@@ -87,33 +86,33 @@ def _register_node_info_stores(
|
|
|
87
86
|
|
|
88
87
|
# pylint: disable=too-many-arguments,too-many-locals
|
|
89
88
|
def worker(
|
|
90
|
-
|
|
91
|
-
|
|
89
|
+
messageins_queue: Queue[Message],
|
|
90
|
+
messageres_queue: Queue[Message],
|
|
92
91
|
node_info_store: dict[int, DeprecatedRunInfoStore],
|
|
93
92
|
backend: Backend,
|
|
94
93
|
f_stop: threading.Event,
|
|
95
94
|
) -> None:
|
|
96
|
-
"""
|
|
95
|
+
"""Process messages from the queue, execute them, update context, and enqueue
|
|
96
|
+
replies."""
|
|
97
97
|
while not f_stop.is_set():
|
|
98
98
|
out_mssg = None
|
|
99
99
|
try:
|
|
100
100
|
# Fetch from queue with timeout. We use a timeout so
|
|
101
101
|
# the stopping event can be evaluated even when the queue is empty.
|
|
102
|
-
|
|
103
|
-
node_id =
|
|
102
|
+
message: Message = messageins_queue.get(timeout=1.0)
|
|
103
|
+
node_id = message.metadata.dst_node_id
|
|
104
104
|
|
|
105
105
|
# Retrieve context
|
|
106
|
-
context = node_info_store[node_id].retrieve_context(
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
message = message_from_taskins(task_ins)
|
|
106
|
+
context = node_info_store[node_id].retrieve_context(
|
|
107
|
+
run_id=message.metadata.run_id
|
|
108
|
+
)
|
|
110
109
|
|
|
111
110
|
# Let backend process message
|
|
112
111
|
out_mssg, updated_context = backend.process_message(message, context)
|
|
113
112
|
|
|
114
113
|
# Update Context
|
|
115
114
|
node_info_store[node_id].update_context(
|
|
116
|
-
|
|
115
|
+
message.metadata.run_id, context=updated_context
|
|
117
116
|
)
|
|
118
117
|
except Empty:
|
|
119
118
|
# An exception raised if queue.get times out
|
|
@@ -131,41 +130,37 @@ def worker(
|
|
|
131
130
|
e_code = ErrorCode.UNKNOWN
|
|
132
131
|
|
|
133
132
|
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
|
|
134
|
-
out_mssg = message.create_error_reply(
|
|
135
|
-
error=Error(code=e_code, reason=reason)
|
|
136
|
-
)
|
|
133
|
+
out_mssg = Message(Error(code=e_code, reason=reason), reply_to=message)
|
|
137
134
|
|
|
138
135
|
finally:
|
|
139
136
|
if out_mssg:
|
|
140
|
-
#
|
|
141
|
-
|
|
142
|
-
# Store TaskRes in state
|
|
143
|
-
taskres_queue.put(task_res)
|
|
137
|
+
# Store reply Messages in state
|
|
138
|
+
messageres_queue.put(out_mssg)
|
|
144
139
|
|
|
145
140
|
|
|
146
|
-
def
|
|
141
|
+
def add_messages_to_queue(
|
|
147
142
|
state: LinkState,
|
|
148
|
-
queue:
|
|
143
|
+
queue: Queue[Message],
|
|
149
144
|
nodes_mapping: NodeToPartitionMapping,
|
|
150
145
|
f_stop: threading.Event,
|
|
151
146
|
) -> None:
|
|
152
|
-
"""Put
|
|
147
|
+
"""Put Messages in the queue from the LinkState."""
|
|
153
148
|
while not f_stop.is_set():
|
|
154
149
|
for node_id in nodes_mapping.keys():
|
|
155
|
-
|
|
156
|
-
for
|
|
157
|
-
queue.put(
|
|
150
|
+
message_ins_list = state.get_message_ins(node_id=node_id, limit=1)
|
|
151
|
+
for msg in message_ins_list:
|
|
152
|
+
queue.put(msg)
|
|
158
153
|
sleep(0.1)
|
|
159
154
|
|
|
160
155
|
|
|
161
|
-
def
|
|
162
|
-
state: LinkState, queue:
|
|
156
|
+
def put_message_into_state(
|
|
157
|
+
state: LinkState, queue: Queue[Message], f_stop: threading.Event
|
|
163
158
|
) -> None:
|
|
164
|
-
"""
|
|
159
|
+
"""Store reply Messages into the LinkState from the queue."""
|
|
165
160
|
while not f_stop.is_set():
|
|
166
161
|
try:
|
|
167
|
-
|
|
168
|
-
state.
|
|
162
|
+
message_reply = queue.get(timeout=1.0)
|
|
163
|
+
state.store_message_res(message_reply)
|
|
169
164
|
except Empty:
|
|
170
165
|
# queue is empty when timeout was triggered
|
|
171
166
|
pass
|
|
@@ -181,8 +176,8 @@ def run_api(
|
|
|
181
176
|
f_stop: threading.Event,
|
|
182
177
|
) -> None:
|
|
183
178
|
"""Run the VCE."""
|
|
184
|
-
|
|
185
|
-
|
|
179
|
+
messageins_queue: Queue[Message] = Queue()
|
|
180
|
+
messageres_queue: Queue[Message] = Queue()
|
|
186
181
|
|
|
187
182
|
try:
|
|
188
183
|
|
|
@@ -196,10 +191,10 @@ def run_api(
|
|
|
196
191
|
state = state_factory.state()
|
|
197
192
|
|
|
198
193
|
extractor_th = threading.Thread(
|
|
199
|
-
target=
|
|
194
|
+
target=add_messages_to_queue,
|
|
200
195
|
args=(
|
|
201
196
|
state,
|
|
202
|
-
|
|
197
|
+
messageins_queue,
|
|
203
198
|
nodes_mapping,
|
|
204
199
|
f_stop,
|
|
205
200
|
),
|
|
@@ -207,10 +202,10 @@ def run_api(
|
|
|
207
202
|
extractor_th.start()
|
|
208
203
|
|
|
209
204
|
injector_th = threading.Thread(
|
|
210
|
-
target=
|
|
205
|
+
target=put_message_into_state,
|
|
211
206
|
args=(
|
|
212
207
|
state,
|
|
213
|
-
|
|
208
|
+
messageres_queue,
|
|
214
209
|
f_stop,
|
|
215
210
|
),
|
|
216
211
|
)
|
|
@@ -220,8 +215,8 @@ def run_api(
|
|
|
220
215
|
_ = [
|
|
221
216
|
executor.submit(
|
|
222
217
|
worker,
|
|
223
|
-
|
|
224
|
-
|
|
218
|
+
messageins_queue,
|
|
219
|
+
messageres_queue,
|
|
225
220
|
node_info_stores,
|
|
226
221
|
backend,
|
|
227
222
|
f_stop,
|