flwr 1.15.2__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/build.py +2 -0
- flwr/cli/log.py +20 -21
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +260 -86
- flwr/client/clientapp/app.py +6 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +28 -28
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/rest_client/connection.py +4 -6
- flwr/client/run_info_store.py +2 -2
- flwr/client/supernode/__init__.py +0 -2
- flwr/client/supernode/app.py +1 -11
- flwr/common/__init__.py +12 -4
- flwr/common/address.py +35 -0
- flwr/common/args.py +8 -2
- flwr/common/auth_plugin/auth_plugin.py +2 -1
- flwr/common/config.py +4 -4
- flwr/common/constant.py +16 -0
- flwr/common/context.py +4 -4
- flwr/common/event_log_plugin/__init__.py +22 -0
- flwr/common/event_log_plugin/event_log_plugin.py +60 -0
- flwr/common/grpc.py +1 -1
- flwr/common/logger.py +2 -2
- flwr/common/message.py +338 -102
- flwr/common/object_ref.py +0 -10
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/arrayrecord.py +626 -0
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +9 -18
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/quantization.py +5 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +67 -190
- flwr/common/telemetry.py +0 -10
- flwr/common/typing.py +44 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +3 -1
- flwr/server/app.py +74 -3
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +15 -12
- flwr/server/compat/app_utils.py +26 -18
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +41 -41
- flwr/server/fleet_event_log_interceptor.py +94 -0
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +48 -19
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +88 -56
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +41 -54
- flwr/server/run_serverapp.py +6 -17
- flwr/server/server_app.py +126 -33
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
- flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +33 -38
- flwr/server/superlink/linkstate/in_memory_linkstate.py +171 -132
- flwr/server/superlink/linkstate/linkstate.py +51 -64
- flwr/server/superlink/linkstate/sqlite_linkstate.py +253 -285
- flwr/server/superlink/linkstate/utils.py +171 -133
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +27 -29
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/typing.py +3 -3
- flwr/server/utils/__init__.py +2 -2
- flwr/server/utils/validator.py +53 -68
- flwr/server/workflow/default_workflows.py +52 -58
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +15 -15
- flwr/superexec/app.py +0 -14
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_event_log_interceptor.py +135 -0
- flwr/superexec/exec_grpc.py +10 -4
- flwr/superexec/exec_servicer.py +6 -6
- flwr/superexec/exec_user_auth_interceptor.py +22 -4
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/METADATA +5 -5
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/RECORD +111 -112
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -3
- flwr/client/message_handler/task_handler.py +0 -37
- flwr/common/record/parametersrecord.py +0 -204
- flwr/common/record/recordset.py +0 -202
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- flwr/proto/task_pb2.py +0 -33
- flwr/proto/task_pb2.pyi +0 -100
- flwr/proto/task_pb2_grpc.py +0 -4
- flwr/proto/task_pb2_grpc.pyi +0 -4
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
flwr/client/client_app.py
CHANGED
|
@@ -16,6 +16,8 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import inspect
|
|
19
|
+
from collections.abc import Iterator
|
|
20
|
+
from contextlib import contextmanager
|
|
19
21
|
from typing import Callable, Optional
|
|
20
22
|
|
|
21
23
|
from flwr.client.client import Client
|
|
@@ -25,10 +27,13 @@ from flwr.client.message_handler.message_handler import (
|
|
|
25
27
|
from flwr.client.mod.utils import make_ffn
|
|
26
28
|
from flwr.client.typing import ClientFnExt, Mod
|
|
27
29
|
from flwr.common import Context, Message, MessageType
|
|
28
|
-
from flwr.common.logger import warn_deprecated_feature
|
|
30
|
+
from flwr.common.logger import warn_deprecated_feature
|
|
31
|
+
from flwr.common.message import validate_message_type
|
|
29
32
|
|
|
30
33
|
from .typing import ClientAppCallable
|
|
31
34
|
|
|
35
|
+
DEFAULT_ACTION = "default"
|
|
36
|
+
|
|
32
37
|
|
|
33
38
|
def _alert_erroneous_client_fn() -> None:
|
|
34
39
|
raise ValueError(
|
|
@@ -71,6 +76,11 @@ def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFn
|
|
|
71
76
|
return client_fn
|
|
72
77
|
|
|
73
78
|
|
|
79
|
+
@contextmanager
|
|
80
|
+
def _empty_lifespan(_: Context) -> Iterator[None]:
|
|
81
|
+
yield
|
|
82
|
+
|
|
83
|
+
|
|
74
84
|
class ClientAppException(Exception):
|
|
75
85
|
"""Exception raised when an exception is raised while executing a ClientApp."""
|
|
76
86
|
|
|
@@ -95,15 +105,6 @@ class ClientApp:
|
|
|
95
105
|
>>> return FlowerClient().to_client()
|
|
96
106
|
>>>
|
|
97
107
|
>>> app = ClientApp(client_fn)
|
|
98
|
-
|
|
99
|
-
If the above code is in a Python module called `client`, it can be started as
|
|
100
|
-
follows:
|
|
101
|
-
|
|
102
|
-
>>> flower-client-app client:app --insecure
|
|
103
|
-
|
|
104
|
-
In this `client:app` example, `client` refers to the Python module `client.py` in
|
|
105
|
-
which the previous code lives in and `app` refers to the global attribute `app` that
|
|
106
|
-
points to an object of type `ClientApp`.
|
|
107
108
|
"""
|
|
108
109
|
|
|
109
110
|
def __init__(
|
|
@@ -112,6 +113,7 @@ class ClientApp:
|
|
|
112
113
|
mods: Optional[list[Mod]] = None,
|
|
113
114
|
) -> None:
|
|
114
115
|
self._mods: list[Mod] = mods if mods is not None else []
|
|
116
|
+
self._registered_funcs: dict[str, ClientAppCallable] = {}
|
|
115
117
|
|
|
116
118
|
# Create wrapper function for `handle`
|
|
117
119
|
self._call: Optional[ClientAppCallable] = None
|
|
@@ -131,129 +133,303 @@ class ClientApp:
|
|
|
131
133
|
# Wrap mods around the wrapped handle function
|
|
132
134
|
self._call = make_ffn(ffn, mods if mods is not None else [])
|
|
133
135
|
|
|
134
|
-
#
|
|
135
|
-
self.
|
|
136
|
-
self._evaluate: Optional[ClientAppCallable] = None
|
|
137
|
-
self._query: Optional[ClientAppCallable] = None
|
|
136
|
+
# Lifespan function
|
|
137
|
+
self._lifespan = _empty_lifespan
|
|
138
138
|
|
|
139
139
|
def __call__(self, message: Message, context: Context) -> Message:
|
|
140
140
|
"""Execute `ClientApp`."""
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
141
|
+
with self._lifespan(context):
|
|
142
|
+
# Execute message using `client_fn`
|
|
143
|
+
if self._call:
|
|
144
|
+
return self._call(message, context)
|
|
145
|
+
|
|
146
|
+
# Get the category and the action
|
|
147
|
+
# A valid message type is of the form "<category>" or "<category>.<action>",
|
|
148
|
+
# where <category> must be "train"/"evaluate"/"query", and <action> is a
|
|
149
|
+
# valid Python identifier
|
|
150
|
+
if not validate_message_type(message.metadata.message_type):
|
|
151
|
+
raise ValueError(
|
|
152
|
+
f"Invalid message type: {message.metadata.message_type}"
|
|
153
|
+
)
|
|
154
|
+
|
|
155
|
+
category, action = message.metadata.message_type, DEFAULT_ACTION
|
|
156
|
+
if "." in category:
|
|
157
|
+
category, action = category.split(".")
|
|
158
|
+
|
|
159
|
+
# Check if the function is registered
|
|
160
|
+
if (full_name := f"{category}.{action}") in self._registered_funcs:
|
|
161
|
+
return self._registered_funcs[full_name](message, context)
|
|
162
|
+
|
|
163
|
+
raise ValueError(f"No {category} function registered with name '{action}'")
|
|
164
|
+
|
|
165
|
+
def train(
|
|
166
|
+
self, action: str = DEFAULT_ACTION, *, mods: Optional[list[Mod]] = None
|
|
167
|
+
) -> Callable[[ClientAppCallable], ClientAppCallable]:
|
|
168
|
+
"""Register a train function with the ``ClientApp``.
|
|
169
|
+
|
|
170
|
+
Parameters
|
|
171
|
+
----------
|
|
172
|
+
action : str (default: "default")
|
|
173
|
+
The action name used to route messages. Defaults to "default".
|
|
174
|
+
mods : Optional[list[Mod]] (default: None)
|
|
175
|
+
A list of function-specific modifiers.
|
|
176
|
+
|
|
177
|
+
Returns
|
|
178
|
+
-------
|
|
179
|
+
Callable[[ClientAppCallable], ClientAppCallable]
|
|
180
|
+
A decorator that registers a train function with the ``ClientApp``.
|
|
164
181
|
|
|
165
182
|
Examples
|
|
166
183
|
--------
|
|
184
|
+
Registering a train function:
|
|
185
|
+
|
|
167
186
|
>>> app = ClientApp()
|
|
168
187
|
>>>
|
|
169
188
|
>>> @app.train()
|
|
170
189
|
>>> def train(message: Message, context: Context) -> Message:
|
|
171
|
-
>>>
|
|
172
|
-
>>>
|
|
173
|
-
>>>
|
|
174
|
-
"""
|
|
190
|
+
>>> print("Executing default train function")
|
|
191
|
+
>>> # Create and return an echo reply message
|
|
192
|
+
>>> return Message(message.content, reply_to=message)
|
|
175
193
|
|
|
176
|
-
|
|
177
|
-
"""Register the train fn with the ServerApp object."""
|
|
178
|
-
if self._call:
|
|
179
|
-
raise _registration_error(MessageType.TRAIN)
|
|
194
|
+
Registering a train function with a custom action name:
|
|
180
195
|
|
|
181
|
-
|
|
196
|
+
>>> app = ClientApp()
|
|
197
|
+
>>>
|
|
198
|
+
>>> # Messages with `message_type="train.custom_action"` will be
|
|
199
|
+
>>> # routed to this function.
|
|
200
|
+
>>> @app.train("custom_action")
|
|
201
|
+
>>> def custom_action(message: Message, context: Context) -> Message:
|
|
202
|
+
>>> print("Executing train function for custom action")
|
|
203
|
+
>>> return Message(message.content, reply_to=message)
|
|
182
204
|
|
|
183
|
-
|
|
184
|
-
# Wrap mods around the wrapped step function
|
|
185
|
-
self._train = make_ffn(train_fn, self._mods)
|
|
205
|
+
Registering a train function with a function-specific Flower Mod:
|
|
186
206
|
|
|
187
|
-
|
|
188
|
-
|
|
207
|
+
>>> from flwr.client.mod import message_size_mod
|
|
208
|
+
>>>
|
|
209
|
+
>>> app = ClientApp()
|
|
210
|
+
>>>
|
|
211
|
+
>>> # Using the `mods` argument to apply a function-specific mod.
|
|
212
|
+
>>> @app.train(mods=[message_size_mod])
|
|
213
|
+
>>> def train(message: Message, context: Context) -> Message:
|
|
214
|
+
>>> print("Executing train function with message size mod")
|
|
215
|
+
>>> # Create and return an echo reply message
|
|
216
|
+
>>> return Message(message.content, reply_to=message)
|
|
217
|
+
"""
|
|
218
|
+
return _get_decorator(self, MessageType.TRAIN, action, mods)
|
|
219
|
+
|
|
220
|
+
def evaluate(
|
|
221
|
+
self, action: str = DEFAULT_ACTION, *, mods: Optional[list[Mod]] = None
|
|
222
|
+
) -> Callable[[ClientAppCallable], ClientAppCallable]:
|
|
223
|
+
"""Register an evaluate function with the ``ClientApp``.
|
|
189
224
|
|
|
190
|
-
|
|
225
|
+
Parameters
|
|
226
|
+
----------
|
|
227
|
+
action : str (default: "default")
|
|
228
|
+
The action name used to route messages. Defaults to "default".
|
|
229
|
+
mods : Optional[list[Mod]] (default: None)
|
|
230
|
+
A list of function-specific modifiers.
|
|
191
231
|
|
|
192
|
-
|
|
193
|
-
|
|
232
|
+
Returns
|
|
233
|
+
-------
|
|
234
|
+
Callable[[ClientAppCallable], ClientAppCallable]
|
|
235
|
+
A decorator that registers an evaluate function with the ``ClientApp``.
|
|
194
236
|
|
|
195
237
|
Examples
|
|
196
238
|
--------
|
|
239
|
+
Registering an evaluate function:
|
|
240
|
+
|
|
197
241
|
>>> app = ClientApp()
|
|
198
242
|
>>>
|
|
199
243
|
>>> @app.evaluate()
|
|
200
244
|
>>> def evaluate(message: Message, context: Context) -> Message:
|
|
201
|
-
>>>
|
|
202
|
-
>>>
|
|
203
|
-
>>>
|
|
204
|
-
"""
|
|
245
|
+
>>> print("Executing default evaluate function")
|
|
246
|
+
>>> # Create and return an echo reply message
|
|
247
|
+
>>> return Message(message.content, reply_to=message)
|
|
205
248
|
|
|
206
|
-
|
|
207
|
-
"""Register the evaluate fn with the ServerApp object."""
|
|
208
|
-
if self._call:
|
|
209
|
-
raise _registration_error(MessageType.EVALUATE)
|
|
249
|
+
Registering an evaluate function with a custom action name:
|
|
210
250
|
|
|
211
|
-
|
|
251
|
+
>>> app = ClientApp()
|
|
252
|
+
>>>
|
|
253
|
+
>>> # Messages with `message_type="evaluate.custom_action"` will be
|
|
254
|
+
>>> # routed to this function.
|
|
255
|
+
>>> @app.evaluate("custom_action")
|
|
256
|
+
>>> def custom_action(message: Message, context: Context) -> Message:
|
|
257
|
+
>>> print("Executing evaluate function for custom action")
|
|
258
|
+
>>> return Message(message.content, reply_to=message)
|
|
212
259
|
|
|
213
|
-
|
|
214
|
-
# Wrap mods around the wrapped step function
|
|
215
|
-
self._evaluate = make_ffn(evaluate_fn, self._mods)
|
|
260
|
+
Registering an evaluate function with a function-specific Flower Mod:
|
|
216
261
|
|
|
217
|
-
|
|
218
|
-
|
|
262
|
+
>>> from flwr.client.mod import message_size_mod
|
|
263
|
+
>>>
|
|
264
|
+
>>> app = ClientApp()
|
|
265
|
+
>>>
|
|
266
|
+
>>> # Using the `mods` argument to apply a function-specific mod.
|
|
267
|
+
>>> @app.evaluate(mods=[message_size_mod])
|
|
268
|
+
>>> def evaluate(message: Message, context: Context) -> Message:
|
|
269
|
+
>>> print("Executing evaluate function with message size mod")
|
|
270
|
+
>>> # Create and return an echo reply message
|
|
271
|
+
>>> return Message(message.content, reply_to=message)
|
|
272
|
+
"""
|
|
273
|
+
return _get_decorator(self, MessageType.EVALUATE, action, mods)
|
|
274
|
+
|
|
275
|
+
def query(
|
|
276
|
+
self, action: str = DEFAULT_ACTION, *, mods: Optional[list[Mod]] = None
|
|
277
|
+
) -> Callable[[ClientAppCallable], ClientAppCallable]:
|
|
278
|
+
"""Register a query function with the ``ClientApp``.
|
|
219
279
|
|
|
220
|
-
|
|
280
|
+
Parameters
|
|
281
|
+
----------
|
|
282
|
+
action : str (default: "default")
|
|
283
|
+
The action name used to route messages. Defaults to "default".
|
|
284
|
+
mods : Optional[list[Mod]] (default: None)
|
|
285
|
+
A list of function-specific modifiers.
|
|
221
286
|
|
|
222
|
-
|
|
223
|
-
|
|
287
|
+
Returns
|
|
288
|
+
-------
|
|
289
|
+
Callable[[ClientAppCallable], ClientAppCallable]
|
|
290
|
+
A decorator that registers a query function with the ``ClientApp``.
|
|
224
291
|
|
|
225
292
|
Examples
|
|
226
293
|
--------
|
|
294
|
+
Registering a query function:
|
|
295
|
+
|
|
227
296
|
>>> app = ClientApp()
|
|
228
297
|
>>>
|
|
229
298
|
>>> @app.query()
|
|
230
299
|
>>> def query(message: Message, context: Context) -> Message:
|
|
231
|
-
>>>
|
|
232
|
-
>>>
|
|
233
|
-
>>>
|
|
300
|
+
>>> print("Executing default query function")
|
|
301
|
+
>>> # Create and return an echo reply message
|
|
302
|
+
>>> return Message(message.content, reply_to=message)
|
|
303
|
+
|
|
304
|
+
Registering a query function with a custom action name:
|
|
305
|
+
|
|
306
|
+
>>> app = ClientApp()
|
|
307
|
+
>>>
|
|
308
|
+
>>> # Messages with `message_type="query.custom_action"` will be
|
|
309
|
+
>>> # routed to this function.
|
|
310
|
+
>>> @app.query("custom_action")
|
|
311
|
+
>>> def custom_action(message: Message, context: Context) -> Message:
|
|
312
|
+
>>> print("Executing query function for custom action")
|
|
313
|
+
>>> return Message(message.content, reply_to=message)
|
|
314
|
+
|
|
315
|
+
Registering a query function with a function-specific Flower Mod:
|
|
316
|
+
|
|
317
|
+
>>> from flwr.client.mod import message_size_mod
|
|
318
|
+
>>>
|
|
319
|
+
>>> app = ClientApp()
|
|
320
|
+
>>>
|
|
321
|
+
>>> # Using the `mods` argument to apply a function-specific mod.
|
|
322
|
+
>>> @app.query(mods=[message_size_mod])
|
|
323
|
+
>>> def query(message: Message, context: Context) -> Message:
|
|
324
|
+
>>> print("Executing query function with message size mod")
|
|
325
|
+
>>> # Create and return an echo reply message
|
|
326
|
+
>>> return Message(message.content, reply_to=message)
|
|
234
327
|
"""
|
|
328
|
+
return _get_decorator(self, MessageType.QUERY, action, mods)
|
|
235
329
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
330
|
+
def lifespan(
|
|
331
|
+
self,
|
|
332
|
+
) -> Callable[
|
|
333
|
+
[Callable[[Context], Iterator[None]]], Callable[[Context], Iterator[None]]
|
|
334
|
+
]:
|
|
335
|
+
"""Return a decorator that registers the lifespan fn with the client app.
|
|
336
|
+
|
|
337
|
+
The decorated function should accept a `Context` object and use `yield`
|
|
338
|
+
to define enter and exit behavior.
|
|
339
|
+
|
|
340
|
+
Examples
|
|
341
|
+
--------
|
|
342
|
+
>>> app = ClientApp()
|
|
343
|
+
>>>
|
|
344
|
+
>>> @app.lifespan()
|
|
345
|
+
>>> def lifespan(context: Context) -> None:
|
|
346
|
+
>>> # Perform initialization tasks before the app starts
|
|
347
|
+
>>> print("Initializing ClientApp")
|
|
348
|
+
>>>
|
|
349
|
+
>>> yield # ClientApp is running
|
|
350
|
+
>>>
|
|
351
|
+
>>> # Perform cleanup tasks after the app stops
|
|
352
|
+
>>> print("Cleaning up ClientApp")
|
|
353
|
+
"""
|
|
240
354
|
|
|
241
|
-
|
|
355
|
+
def lifespan_decorator(
|
|
356
|
+
lifespan_fn: Callable[[Context], Iterator[None]]
|
|
357
|
+
) -> Callable[[Context], Iterator[None]]:
|
|
358
|
+
"""Register the lifespan fn with the ServerApp object."""
|
|
359
|
+
|
|
360
|
+
@contextmanager
|
|
361
|
+
def decorated_lifespan(context: Context) -> Iterator[None]:
|
|
362
|
+
# Execute the code before `yield` in lifespan_fn
|
|
363
|
+
try:
|
|
364
|
+
if not isinstance(it := lifespan_fn(context), Iterator):
|
|
365
|
+
raise StopIteration
|
|
366
|
+
next(it)
|
|
367
|
+
except StopIteration:
|
|
368
|
+
raise RuntimeError(
|
|
369
|
+
"lifespan function should yield at least once."
|
|
370
|
+
) from None
|
|
371
|
+
|
|
372
|
+
try:
|
|
373
|
+
# Enter the context
|
|
374
|
+
yield
|
|
375
|
+
finally:
|
|
376
|
+
try:
|
|
377
|
+
# Execute the code after `yield` in lifespan_fn
|
|
378
|
+
next(it)
|
|
379
|
+
except StopIteration:
|
|
380
|
+
pass
|
|
381
|
+
else:
|
|
382
|
+
raise RuntimeError("lifespan function should only yield once.")
|
|
242
383
|
|
|
243
384
|
# Register provided function with the ClientApp object
|
|
244
|
-
#
|
|
245
|
-
self.
|
|
385
|
+
# Ignore mypy error because of different argument names (`_` vs `context`)
|
|
386
|
+
self._lifespan = decorated_lifespan # type: ignore
|
|
246
387
|
|
|
247
388
|
# Return provided function unmodified
|
|
248
|
-
return
|
|
389
|
+
return lifespan_fn
|
|
249
390
|
|
|
250
|
-
return
|
|
391
|
+
return lifespan_decorator
|
|
251
392
|
|
|
252
393
|
|
|
253
394
|
class LoadClientAppError(Exception):
|
|
254
395
|
"""Error when trying to load `ClientApp`."""
|
|
255
396
|
|
|
256
397
|
|
|
398
|
+
def _get_decorator(
|
|
399
|
+
app: ClientApp, category: str, action: str, mods: Optional[list[Mod]]
|
|
400
|
+
) -> Callable[[ClientAppCallable], ClientAppCallable]:
|
|
401
|
+
"""Get the decorator for the given category and action."""
|
|
402
|
+
# pylint: disable=protected-access
|
|
403
|
+
if app._call:
|
|
404
|
+
raise _registration_error(category)
|
|
405
|
+
|
|
406
|
+
def decorator(fn: ClientAppCallable) -> ClientAppCallable:
|
|
407
|
+
|
|
408
|
+
# Check if the name is a valid Python identifier
|
|
409
|
+
if not action.isidentifier():
|
|
410
|
+
raise ValueError(
|
|
411
|
+
f"Cannot register {category} function with name '{action}'. "
|
|
412
|
+
"The name must follow Python's function naming rules."
|
|
413
|
+
)
|
|
414
|
+
|
|
415
|
+
# Check if the name is already registered
|
|
416
|
+
full_name = f"{category}.{action}" # Full name of the message type
|
|
417
|
+
if full_name in app._registered_funcs:
|
|
418
|
+
raise ValueError(
|
|
419
|
+
f"Cannot register {category} function with name '{action}'. "
|
|
420
|
+
f"A {category} function with the name '{action}' is already registered."
|
|
421
|
+
)
|
|
422
|
+
|
|
423
|
+
# Register provided function with the ClientApp object
|
|
424
|
+
app._registered_funcs[full_name] = make_ffn(fn, app._mods + (mods or []))
|
|
425
|
+
|
|
426
|
+
# Return provided function unmodified
|
|
427
|
+
return fn
|
|
428
|
+
|
|
429
|
+
# pylint: enable=protected-access
|
|
430
|
+
return decorator
|
|
431
|
+
|
|
432
|
+
|
|
257
433
|
def _registration_error(fn_name: str) -> ValueError:
|
|
258
434
|
return ValueError(
|
|
259
435
|
f"""Use either `@app.{fn_name}()` or `client_fn`, but not both.
|
|
@@ -278,8 +454,6 @@ def _registration_error(fn_name: str) -> ValueError:
|
|
|
278
454
|
>>> def {fn_name}(message: Message, context: Context) -> Message:
|
|
279
455
|
>>> print("ClientApp {fn_name} running")
|
|
280
456
|
>>> # Create and return an echo reply message
|
|
281
|
-
>>> return message.
|
|
282
|
-
>>> content=message.content()
|
|
283
|
-
>>> )
|
|
457
|
+
>>> return Message(message.content, reply_to=message)
|
|
284
458
|
""",
|
|
285
459
|
)
|
flwr/client/clientapp/app.py
CHANGED
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import argparse
|
|
19
|
+
import gc
|
|
19
20
|
import time
|
|
20
21
|
from logging import DEBUG, ERROR, INFO
|
|
21
22
|
from typing import Optional
|
|
@@ -151,8 +152,8 @@ def run_clientapp( # pylint: disable=R0914
|
|
|
151
152
|
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
|
|
152
153
|
|
|
153
154
|
# Create error message
|
|
154
|
-
reply_message =
|
|
155
|
-
|
|
155
|
+
reply_message = Message(
|
|
156
|
+
Error(code=e_code, reason=reason), reply_to=message
|
|
156
157
|
)
|
|
157
158
|
|
|
158
159
|
# Push Message and Context to SuperNode
|
|
@@ -160,6 +161,9 @@ def run_clientapp( # pylint: disable=R0914
|
|
|
160
161
|
stub=stub, token=token, message=reply_message, context=context
|
|
161
162
|
)
|
|
162
163
|
|
|
164
|
+
del client_app, message, context, run, fab, reply_message
|
|
165
|
+
gc.collect()
|
|
166
|
+
|
|
163
167
|
# Reset token to `None` to prevent flwr-clientapp from trying to pull the
|
|
164
168
|
# same inputs again
|
|
165
169
|
token = None
|
|
flwr/client/grpc_client/connection.py
CHANGED
|
@@ -28,16 +28,18 @@ from cryptography.hazmat.primitives.asymmetric import ec
|
|
|
28
28
|
from flwr.common import (
|
|
29
29
|
DEFAULT_TTL,
|
|
30
30
|
GRPC_MAX_MESSAGE_LENGTH,
|
|
31
|
-
|
|
31
|
+
ConfigRecord,
|
|
32
32
|
Message,
|
|
33
33
|
Metadata,
|
|
34
|
-
|
|
34
|
+
RecordDict,
|
|
35
|
+
now,
|
|
35
36
|
)
|
|
36
|
-
from flwr.common import
|
|
37
|
+
from flwr.common import recorddict_compat as compat
|
|
37
38
|
from flwr.common import serde
|
|
38
39
|
from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
39
40
|
from flwr.common.grpc import create_channel, on_channel_state_change
|
|
40
41
|
from flwr.common.logger import log
|
|
42
|
+
from flwr.common.message import make_message
|
|
41
43
|
from flwr.common.retry_invoker import RetryInvoker
|
|
42
44
|
from flwr.common.typing import Fab, Run
|
|
43
45
|
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
|
@@ -139,32 +141,32 @@ def grpc_connection( # pylint: disable=R0913,R0915,too-many-positional-argument
|
|
|
139
141
|
# Receive ServerMessage proto
|
|
140
142
|
proto = next(server_message_iterator)
|
|
141
143
|
|
|
142
|
-
# ServerMessage proto --> *Ins -->
|
|
144
|
+
# ServerMessage proto --> *Ins --> RecordDict
|
|
143
145
|
field = proto.WhichOneof("msg")
|
|
144
146
|
message_type = ""
|
|
145
147
|
if field == "get_properties_ins":
|
|
146
|
-
|
|
148
|
+
recorddict = compat.getpropertiesins_to_recorddict(
|
|
147
149
|
serde.get_properties_ins_from_proto(proto.get_properties_ins)
|
|
148
150
|
)
|
|
149
151
|
message_type = MessageTypeLegacy.GET_PROPERTIES
|
|
150
152
|
elif field == "get_parameters_ins":
|
|
151
|
-
|
|
153
|
+
recorddict = compat.getparametersins_to_recorddict(
|
|
152
154
|
serde.get_parameters_ins_from_proto(proto.get_parameters_ins)
|
|
153
155
|
)
|
|
154
156
|
message_type = MessageTypeLegacy.GET_PARAMETERS
|
|
155
157
|
elif field == "fit_ins":
|
|
156
|
-
|
|
158
|
+
recorddict = compat.fitins_to_recorddict(
|
|
157
159
|
serde.fit_ins_from_proto(proto.fit_ins), False
|
|
158
160
|
)
|
|
159
161
|
message_type = MessageType.TRAIN
|
|
160
162
|
elif field == "evaluate_ins":
|
|
161
|
-
|
|
163
|
+
recorddict = compat.evaluateins_to_recorddict(
|
|
162
164
|
serde.evaluate_ins_from_proto(proto.evaluate_ins), False
|
|
163
165
|
)
|
|
164
166
|
message_type = MessageType.EVALUATE
|
|
165
167
|
elif field == "reconnect_ins":
|
|
166
|
-
|
|
167
|
-
|
|
168
|
+
recorddict = RecordDict()
|
|
169
|
+
recorddict.config_records["config"] = ConfigRecord(
|
|
168
170
|
{"seconds": proto.reconnect_ins.seconds}
|
|
169
171
|
)
|
|
170
172
|
message_type = "reconnect"
|
|
@@ -175,45 +177,46 @@ def grpc_connection( # pylint: disable=R0913,R0915,too-many-positional-argument
|
|
|
175
177
|
)
|
|
176
178
|
|
|
177
179
|
# Construct Message
|
|
178
|
-
return
|
|
180
|
+
return make_message(
|
|
179
181
|
metadata=Metadata(
|
|
180
182
|
run_id=0,
|
|
181
183
|
message_id=str(uuid.uuid4()),
|
|
182
184
|
src_node_id=0,
|
|
183
185
|
dst_node_id=0,
|
|
184
|
-
|
|
186
|
+
reply_to_message_id="",
|
|
185
187
|
group_id="",
|
|
188
|
+
created_at=now().timestamp(),
|
|
186
189
|
ttl=DEFAULT_TTL,
|
|
187
190
|
message_type=message_type,
|
|
188
191
|
),
|
|
189
|
-
content=
|
|
192
|
+
content=recorddict,
|
|
190
193
|
)
|
|
191
194
|
|
|
192
195
|
def send(message: Message) -> None:
|
|
193
|
-
# Retrieve
|
|
194
|
-
|
|
196
|
+
# Retrieve RecordDict and message_type
|
|
197
|
+
recorddict = message.content
|
|
195
198
|
message_type = message.metadata.message_type
|
|
196
199
|
|
|
197
|
-
#
|
|
200
|
+
# RecordDict --> *Res --> *Res proto -> ClientMessage proto
|
|
198
201
|
if message_type == MessageTypeLegacy.GET_PROPERTIES:
|
|
199
|
-
getpropres = compat.
|
|
202
|
+
getpropres = compat.recorddict_to_getpropertiesres(recorddict)
|
|
200
203
|
msg_proto = ClientMessage(
|
|
201
204
|
get_properties_res=serde.get_properties_res_to_proto(getpropres)
|
|
202
205
|
)
|
|
203
206
|
elif message_type == MessageTypeLegacy.GET_PARAMETERS:
|
|
204
|
-
getparamres = compat.
|
|
207
|
+
getparamres = compat.recorddict_to_getparametersres(recorddict, False)
|
|
205
208
|
msg_proto = ClientMessage(
|
|
206
209
|
get_parameters_res=serde.get_parameters_res_to_proto(getparamres)
|
|
207
210
|
)
|
|
208
211
|
elif message_type == MessageType.TRAIN:
|
|
209
|
-
fitres = compat.
|
|
212
|
+
fitres = compat.recorddict_to_fitres(recorddict, False)
|
|
210
213
|
msg_proto = ClientMessage(fit_res=serde.fit_res_to_proto(fitres))
|
|
211
214
|
elif message_type == MessageType.EVALUATE:
|
|
212
|
-
evalres = compat.
|
|
215
|
+
evalres = compat.recorddict_to_evaluateres(recorddict)
|
|
213
216
|
msg_proto = ClientMessage(evaluate_res=serde.evaluate_res_to_proto(evalres))
|
|
214
217
|
elif message_type == "reconnect":
|
|
215
218
|
reason = cast(
|
|
216
|
-
Reason.ValueType,
|
|
219
|
+
Reason.ValueType, recorddict.config_records["config"]["reason"]
|
|
217
220
|
)
|
|
218
221
|
msg_proto = ClientMessage(
|
|
219
222
|
disconnect_res=ClientMessage.DisconnectRes(reason=reason)
|