flwr 1.15.2__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/build.py +2 -0
- flwr/cli/log.py +20 -21
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +260 -86
- flwr/client/clientapp/app.py +6 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +28 -28
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/rest_client/connection.py +4 -6
- flwr/client/run_info_store.py +2 -2
- flwr/client/supernode/__init__.py +0 -2
- flwr/client/supernode/app.py +1 -11
- flwr/common/__init__.py +12 -4
- flwr/common/address.py +35 -0
- flwr/common/args.py +8 -2
- flwr/common/auth_plugin/auth_plugin.py +2 -1
- flwr/common/config.py +4 -4
- flwr/common/constant.py +16 -0
- flwr/common/context.py +4 -4
- flwr/common/event_log_plugin/__init__.py +22 -0
- flwr/common/event_log_plugin/event_log_plugin.py +60 -0
- flwr/common/grpc.py +1 -1
- flwr/common/logger.py +2 -2
- flwr/common/message.py +338 -102
- flwr/common/object_ref.py +0 -10
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/arrayrecord.py +626 -0
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +9 -18
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/quantization.py +5 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +67 -190
- flwr/common/telemetry.py +0 -10
- flwr/common/typing.py +44 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +3 -1
- flwr/server/app.py +74 -3
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +15 -12
- flwr/server/compat/app_utils.py +26 -18
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +41 -41
- flwr/server/fleet_event_log_interceptor.py +94 -0
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +48 -19
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +88 -56
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +41 -54
- flwr/server/run_serverapp.py +6 -17
- flwr/server/server_app.py +126 -33
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
- flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +33 -38
- flwr/server/superlink/linkstate/in_memory_linkstate.py +171 -132
- flwr/server/superlink/linkstate/linkstate.py +51 -64
- flwr/server/superlink/linkstate/sqlite_linkstate.py +253 -285
- flwr/server/superlink/linkstate/utils.py +171 -133
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +27 -29
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/typing.py +3 -3
- flwr/server/utils/__init__.py +2 -2
- flwr/server/utils/validator.py +53 -68
- flwr/server/workflow/default_workflows.py +52 -58
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +15 -15
- flwr/superexec/app.py +0 -14
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_event_log_interceptor.py +135 -0
- flwr/superexec/exec_grpc.py +10 -4
- flwr/superexec/exec_servicer.py +6 -6
- flwr/superexec/exec_user_auth_interceptor.py +22 -4
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/METADATA +5 -5
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/RECORD +111 -112
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -3
- flwr/client/message_handler/task_handler.py +0 -37
- flwr/common/record/parametersrecord.py +0 -204
- flwr/common/record/recordset.py +0 -202
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- flwr/proto/task_pb2.py +0 -33
- flwr/proto/task_pb2.pyi +0 -100
- flwr/proto/task_pb2_grpc.py +0 -4
- flwr/proto/task_pb2_grpc.pyi +0 -4
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
|
@@ -15,29 +15,26 @@
|
|
|
15
15
|
"""Utility functions for State."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from logging import ERROR
|
|
19
18
|
from os import urandom
|
|
20
|
-
from typing import Optional
|
|
19
|
+
from typing import Optional
|
|
21
20
|
from uuid import UUID, uuid4
|
|
22
21
|
|
|
23
|
-
from flwr.common import
|
|
24
|
-
from flwr.common.constant import
|
|
22
|
+
from flwr.common import ConfigRecord, Context, Error, Message, Metadata, now, serde
|
|
23
|
+
from flwr.common.constant import (
|
|
24
|
+
SUPERLINK_NODE_ID,
|
|
25
|
+
ErrorCode,
|
|
26
|
+
MessageType,
|
|
27
|
+
Status,
|
|
28
|
+
SubStatus,
|
|
29
|
+
)
|
|
30
|
+
from flwr.common.message import make_message
|
|
25
31
|
from flwr.common.typing import RunStatus
|
|
26
32
|
|
|
27
33
|
# pylint: disable=E0611
|
|
28
|
-
from flwr.proto.error_pb2 import Error
|
|
29
34
|
from flwr.proto.message_pb2 import Context as ProtoContext
|
|
30
|
-
from flwr.proto.
|
|
31
|
-
from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
|
|
32
|
-
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
|
35
|
+
from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
|
|
33
36
|
|
|
34
37
|
# pylint: enable=E0611
|
|
35
|
-
|
|
36
|
-
NODE_UNAVAILABLE_ERROR_REASON = (
|
|
37
|
-
"Error: Node Unavailable - The destination node is currently unavailable. "
|
|
38
|
-
"It exceeds the time limit specified in its last ping."
|
|
39
|
-
)
|
|
40
|
-
|
|
41
38
|
VALID_RUN_STATUS_TRANSITIONS = {
|
|
42
39
|
(Status.PENDING, Status.STARTING),
|
|
43
40
|
(Status.STARTING, Status.RUNNING),
|
|
@@ -58,6 +55,10 @@ MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
|
|
58
55
|
REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
|
59
56
|
"Error: Reply Message Unavailable - The reply message has expired."
|
|
60
57
|
)
|
|
58
|
+
NODE_UNAVAILABLE_ERROR_REASON = (
|
|
59
|
+
"Error: Node Unavailable - The destination node is currently unavailable. "
|
|
60
|
+
"It exceeds twice the time limit specified in its last ping."
|
|
61
|
+
)
|
|
61
62
|
|
|
62
63
|
|
|
63
64
|
def generate_rand_int_from_bytes(
|
|
@@ -171,15 +172,15 @@ def context_from_bytes(context_bytes: bytes) -> Context:
|
|
|
171
172
|
return serde.context_from_proto(ProtoContext.FromString(context_bytes))
|
|
172
173
|
|
|
173
174
|
|
|
174
|
-
def
|
|
175
|
-
"""Serialize a `
|
|
176
|
-
return serde.
|
|
175
|
+
def configrecord_to_bytes(config_record: ConfigRecord) -> bytes:
|
|
176
|
+
"""Serialize a `ConfigRecord` to bytes."""
|
|
177
|
+
return serde.config_record_to_proto(config_record).SerializeToString()
|
|
177
178
|
|
|
178
179
|
|
|
179
|
-
def
|
|
180
|
-
"""Deserialize `
|
|
181
|
-
return serde.
|
|
182
|
-
|
|
180
|
+
def configrecord_from_bytes(configrecord_bytes: bytes) -> ConfigRecord:
|
|
181
|
+
"""Deserialize `ConfigRecord` from bytes."""
|
|
182
|
+
return serde.config_record_from_proto(
|
|
183
|
+
ProtoConfigRecord.FromString(configrecord_bytes)
|
|
183
184
|
)
|
|
184
185
|
|
|
185
186
|
|
|
@@ -235,165 +236,202 @@ def has_valid_sub_status(status: RunStatus) -> bool:
|
|
|
235
236
|
return status.sub_status == ""
|
|
236
237
|
|
|
237
238
|
|
|
238
|
-
def
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
taskins_id : Union[str, UUID]
|
|
244
|
-
The ID of the unavailable TaskIns.
|
|
245
|
-
|
|
246
|
-
Returns
|
|
247
|
-
-------
|
|
248
|
-
TaskRes
|
|
249
|
-
A TaskRes with an error code MESSAGE_UNAVAILABLE to indicate that the
|
|
250
|
-
inquired TaskIns ID cannot be found (due to non-existence or expiration).
|
|
251
|
-
"""
|
|
239
|
+
def create_message_error_unavailable_res_message(
|
|
240
|
+
ins_metadata: Metadata, error_type: str
|
|
241
|
+
) -> Message:
|
|
242
|
+
"""Generate an error Message that the SuperLink returns carrying the specified
|
|
243
|
+
error."""
|
|
252
244
|
current_time = now().timestamp()
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
245
|
+
ttl = max(ins_metadata.ttl - (current_time - ins_metadata.created_at), 0)
|
|
246
|
+
metadata = Metadata(
|
|
247
|
+
run_id=ins_metadata.run_id,
|
|
248
|
+
message_id=str(uuid4()),
|
|
249
|
+
src_node_id=SUPERLINK_NODE_ID,
|
|
250
|
+
dst_node_id=SUPERLINK_NODE_ID,
|
|
251
|
+
reply_to_message_id=ins_metadata.message_id,
|
|
252
|
+
group_id=ins_metadata.group_id,
|
|
253
|
+
message_type=ins_metadata.message_type,
|
|
254
|
+
created_at=current_time,
|
|
255
|
+
ttl=ttl,
|
|
256
|
+
)
|
|
257
|
+
|
|
258
|
+
return make_message(
|
|
259
|
+
metadata=metadata,
|
|
260
|
+
error=Error(
|
|
261
|
+
code=(
|
|
262
|
+
ErrorCode.REPLY_MESSAGE_UNAVAILABLE
|
|
263
|
+
if error_type == "msg_unavail"
|
|
264
|
+
else ErrorCode.NODE_UNAVAILABLE
|
|
265
|
+
),
|
|
266
|
+
reason=(
|
|
267
|
+
REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON
|
|
268
|
+
if error_type == "msg_unavail"
|
|
269
|
+
else NODE_UNAVAILABLE_ERROR_REASON
|
|
268
270
|
),
|
|
269
271
|
),
|
|
270
272
|
)
|
|
271
273
|
|
|
272
274
|
|
|
273
|
-
def
|
|
274
|
-
"""
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
275
|
+
def create_message_error_unavailable_ins_message(reply_to_message_id: UUID) -> Message:
|
|
276
|
+
"""Error to indicate that the enquired Message had expired before reply arrived or
|
|
277
|
+
that it isn't found."""
|
|
278
|
+
metadata = Metadata(
|
|
279
|
+
run_id=0, # Unknown
|
|
280
|
+
message_id=str(uuid4()),
|
|
281
|
+
src_node_id=SUPERLINK_NODE_ID,
|
|
282
|
+
dst_node_id=SUPERLINK_NODE_ID,
|
|
283
|
+
reply_to_message_id=str(reply_to_message_id),
|
|
284
|
+
group_id="", # Unknown
|
|
285
|
+
message_type=MessageType.SYSTEM,
|
|
286
|
+
created_at=now().timestamp(),
|
|
287
|
+
ttl=0,
|
|
288
|
+
)
|
|
280
289
|
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
"""
|
|
287
|
-
current_time = now().timestamp()
|
|
288
|
-
ttl = ref_taskins.task.ttl - (current_time - ref_taskins.task.created_at)
|
|
289
|
-
if ttl < 0:
|
|
290
|
-
log(ERROR, "Creating TaskRes for TaskIns that exceeds its TTL.")
|
|
291
|
-
ttl = 0
|
|
292
|
-
return TaskRes(
|
|
293
|
-
task_id=str(uuid4()),
|
|
294
|
-
group_id=ref_taskins.group_id,
|
|
295
|
-
run_id=ref_taskins.run_id,
|
|
296
|
-
task=Task(
|
|
297
|
-
# This function is only called by SuperLink, and thus it's the producer.
|
|
298
|
-
producer=Node(node_id=SUPERLINK_NODE_ID),
|
|
299
|
-
consumer=Node(node_id=SUPERLINK_NODE_ID),
|
|
300
|
-
created_at=current_time,
|
|
301
|
-
ttl=ttl,
|
|
302
|
-
ancestry=[ref_taskins.task_id],
|
|
303
|
-
task_type=ref_taskins.task.task_type,
|
|
304
|
-
error=Error(
|
|
305
|
-
code=ErrorCode.REPLY_MESSAGE_UNAVAILABLE,
|
|
306
|
-
reason=REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON,
|
|
307
|
-
),
|
|
290
|
+
return make_message(
|
|
291
|
+
metadata=metadata,
|
|
292
|
+
error=Error(
|
|
293
|
+
code=ErrorCode.MESSAGE_UNAVAILABLE,
|
|
294
|
+
reason=MESSAGE_UNAVAILABLE_ERROR_REASON,
|
|
308
295
|
),
|
|
309
296
|
)
|
|
310
297
|
|
|
311
298
|
|
|
312
|
-
def
|
|
313
|
-
"""Check if the
|
|
314
|
-
return
|
|
299
|
+
def message_ttl_has_expired(message_metadata: Metadata, current_time: float) -> bool:
|
|
300
|
+
"""Check if the Message has expired."""
|
|
301
|
+
return message_metadata.ttl + message_metadata.created_at < current_time
|
|
315
302
|
|
|
316
303
|
|
|
317
|
-
def
|
|
318
|
-
|
|
319
|
-
|
|
304
|
+
def verify_message_ids(
|
|
305
|
+
inquired_message_ids: set[UUID],
|
|
306
|
+
found_message_ins_dict: dict[UUID, Message],
|
|
320
307
|
current_time: Optional[float] = None,
|
|
321
308
|
update_set: bool = True,
|
|
322
|
-
) -> dict[UUID,
|
|
323
|
-
"""Verify found
|
|
309
|
+
) -> dict[UUID, Message]:
|
|
310
|
+
"""Verify found Messages and generate error Messages for invalid ones.
|
|
324
311
|
|
|
325
312
|
Parameters
|
|
326
313
|
----------
|
|
327
|
-
|
|
328
|
-
Set of
|
|
329
|
-
|
|
330
|
-
Dictionary containing all found
|
|
314
|
+
inquired_message_ids : set[UUID]
|
|
315
|
+
Set of Message IDs for which to generate error Message if invalid.
|
|
316
|
+
found_message_ins_dict : dict[UUID, Message]
|
|
317
|
+
Dictionary containing all found Message indexed by their IDs.
|
|
331
318
|
current_time : Optional[float] (default: None)
|
|
332
319
|
The current time to check for expiration. If set to `None`, the current time
|
|
333
320
|
will automatically be set to the current timestamp using `now().timestamp()`.
|
|
334
321
|
update_set : bool (default: True)
|
|
335
|
-
If True, the `
|
|
322
|
+
If True, the `inquired_message_ids` will be updated to remove invalid ones,
|
|
336
323
|
by default True.
|
|
337
324
|
|
|
338
325
|
Returns
|
|
339
326
|
-------
|
|
340
|
-
dict[UUID,
|
|
341
|
-
A dictionary of error
|
|
327
|
+
dict[UUID, Message]
|
|
328
|
+
A dictionary of error Message indexed by the corresponding ID of the message
|
|
329
|
+
they are a reply of.
|
|
342
330
|
"""
|
|
343
331
|
ret_dict = {}
|
|
344
332
|
current = current_time if current_time else now().timestamp()
|
|
345
|
-
for
|
|
346
|
-
# Generate error
|
|
347
|
-
|
|
348
|
-
if
|
|
333
|
+
for message_id in list(inquired_message_ids):
|
|
334
|
+
# Generate error message if the inquired message doesn't exist or has expired
|
|
335
|
+
message_ins = found_message_ins_dict.get(message_id)
|
|
336
|
+
if message_ins is None or message_ttl_has_expired(
|
|
337
|
+
message_ins.metadata, current
|
|
338
|
+
):
|
|
349
339
|
if update_set:
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
ret_dict[
|
|
340
|
+
inquired_message_ids.remove(message_id)
|
|
341
|
+
message_res = create_message_error_unavailable_ins_message(message_id)
|
|
342
|
+
ret_dict[message_id] = message_res
|
|
353
343
|
return ret_dict
|
|
354
344
|
|
|
355
345
|
|
|
356
|
-
def
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
346
|
+
def verify_found_message_replies(
|
|
347
|
+
inquired_message_ids: set[UUID],
|
|
348
|
+
found_message_ins_dict: dict[UUID, Message],
|
|
349
|
+
found_message_res_list: list[Message],
|
|
360
350
|
current_time: Optional[float] = None,
|
|
361
351
|
update_set: bool = True,
|
|
362
|
-
) -> dict[UUID,
|
|
363
|
-
"""Verify found
|
|
352
|
+
) -> dict[UUID, Message]:
|
|
353
|
+
"""Verify found Message replies and generate error Message for invalid ones.
|
|
364
354
|
|
|
365
355
|
Parameters
|
|
366
356
|
----------
|
|
367
|
-
|
|
368
|
-
Set of
|
|
369
|
-
|
|
370
|
-
Dictionary containing all found
|
|
371
|
-
|
|
372
|
-
List of found
|
|
357
|
+
inquired_message_ids : set[UUID]
|
|
358
|
+
Set of Message IDs for which to generate error Message if invalid.
|
|
359
|
+
found_message_ins_dict : dict[UUID, Message]
|
|
360
|
+
Dictionary containing all found instruction Messages indexed by their IDs.
|
|
361
|
+
found_message_res_list : dict[Message, Message]
|
|
362
|
+
List of found Message to be verified.
|
|
373
363
|
current_time : Optional[float] (default: None)
|
|
374
364
|
The current time to check for expiration. If set to `None`, the current time
|
|
375
365
|
will automatically be set to the current timestamp using `now().timestamp()`.
|
|
376
366
|
update_set : bool (default: True)
|
|
377
|
-
If True, the `
|
|
378
|
-
that have a
|
|
367
|
+
If True, the `inquired_message_ids` will be updated to remove ones
|
|
368
|
+
that have a reply Message, by default True.
|
|
379
369
|
|
|
380
370
|
Returns
|
|
381
371
|
-------
|
|
382
|
-
dict[UUID,
|
|
383
|
-
A dictionary of
|
|
372
|
+
dict[UUID, Message]
|
|
373
|
+
A dictionary of Message indexed by the corresponding Message ID.
|
|
384
374
|
"""
|
|
385
|
-
ret_dict: dict[UUID,
|
|
375
|
+
ret_dict: dict[UUID, Message] = {}
|
|
386
376
|
current = current_time if current_time else now().timestamp()
|
|
387
|
-
for
|
|
388
|
-
|
|
377
|
+
for message_res in found_message_res_list:
|
|
378
|
+
message_ins_id = UUID(message_res.metadata.reply_to_message_id)
|
|
389
379
|
if update_set:
|
|
390
|
-
|
|
391
|
-
# Check if the
|
|
392
|
-
if
|
|
393
|
-
# No need to insert the error
|
|
394
|
-
|
|
395
|
-
|
|
380
|
+
inquired_message_ids.remove(message_ins_id)
|
|
381
|
+
# Check if the reply Message has expired
|
|
382
|
+
if message_ttl_has_expired(message_res.metadata, current):
|
|
383
|
+
# No need to insert the error Message
|
|
384
|
+
message_res = create_message_error_unavailable_res_message(
|
|
385
|
+
found_message_ins_dict[message_ins_id].metadata, "msg_unavail"
|
|
386
|
+
)
|
|
387
|
+
ret_dict[message_ins_id] = message_res
|
|
388
|
+
return ret_dict
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
def check_node_availability_for_in_message(
|
|
392
|
+
inquired_in_message_ids: set[UUID],
|
|
393
|
+
found_in_message_dict: dict[UUID, Message],
|
|
394
|
+
node_id_to_online_until: dict[int, float],
|
|
395
|
+
current_time: Optional[float] = None,
|
|
396
|
+
update_set: bool = True,
|
|
397
|
+
) -> dict[UUID, Message]:
|
|
398
|
+
"""Check node availability for given Message and generate error reply Message if
|
|
399
|
+
unavailable. A Message error indicating node unavailability will be generated for
|
|
400
|
+
each given Message whose destination node is offline or non-existent.
|
|
401
|
+
|
|
402
|
+
Parameters
|
|
403
|
+
----------
|
|
404
|
+
inquired_in_message_ids : set[UUID]
|
|
405
|
+
Set of Message IDs for which to check destination node availability.
|
|
406
|
+
found_in_message_dict : dict[UUID, Message]
|
|
407
|
+
Dictionary containing all found Message indexed by their IDs.
|
|
408
|
+
node_id_to_online_until : dict[int, float]
|
|
409
|
+
Dictionary mapping node IDs to their online-until timestamps.
|
|
410
|
+
current_time : Optional[float] (default: None)
|
|
411
|
+
The current time to check for expiration. If set to `None`, the current time
|
|
412
|
+
will automatically be set to the current timestamp using `now().timestamp()`.
|
|
413
|
+
update_set : bool (default: True)
|
|
414
|
+
If True, the `inquired_in_message_ids` will be updated to remove invalid ones,
|
|
415
|
+
by default True.
|
|
416
|
+
|
|
417
|
+
Returns
|
|
418
|
+
-------
|
|
419
|
+
dict[UUID, Message]
|
|
420
|
+
A dictionary of error Message indexed by the corresponding Message ID.
|
|
421
|
+
"""
|
|
422
|
+
ret_dict = {}
|
|
423
|
+
current = current_time if current_time else now().timestamp()
|
|
424
|
+
for in_message_id in list(inquired_in_message_ids):
|
|
425
|
+
in_message = found_in_message_dict[in_message_id]
|
|
426
|
+
node_id = in_message.metadata.dst_node_id
|
|
427
|
+
online_until = node_id_to_online_until.get(node_id)
|
|
428
|
+
# Generate a reply message containing an error reply
|
|
429
|
+
# if the node is offline or doesn't exist.
|
|
430
|
+
if online_until is None or online_until < current:
|
|
431
|
+
if update_set:
|
|
432
|
+
inquired_in_message_ids.remove(in_message_id)
|
|
433
|
+
reply_message = create_message_error_unavailable_res_message(
|
|
434
|
+
in_message.metadata, "node_unavail"
|
|
396
435
|
)
|
|
397
|
-
|
|
398
|
-
ret_dict[taskins_id] = taskres
|
|
436
|
+
ret_dict[in_message_id] = reply_message
|
|
399
437
|
return ret_dict
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -22,8 +22,8 @@ from uuid import UUID
|
|
|
22
22
|
|
|
23
23
|
import grpc
|
|
24
24
|
|
|
25
|
-
from flwr.common import
|
|
26
|
-
from flwr.common.constant import Status
|
|
25
|
+
from flwr.common import ConfigRecord, Message
|
|
26
|
+
from flwr.common.constant import SUPERLINK_NODE_ID, Status
|
|
27
27
|
from flwr.common.logger import log
|
|
28
28
|
from flwr.common.serde import (
|
|
29
29
|
context_from_proto,
|
|
@@ -31,9 +31,7 @@ from flwr.common.serde import (
|
|
|
31
31
|
fab_from_proto,
|
|
32
32
|
fab_to_proto,
|
|
33
33
|
message_from_proto,
|
|
34
|
-
message_from_taskres,
|
|
35
34
|
message_to_proto,
|
|
36
|
-
message_to_taskins,
|
|
37
35
|
run_status_from_proto,
|
|
38
36
|
run_status_to_proto,
|
|
39
37
|
run_to_proto,
|
|
@@ -69,12 +67,11 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
|
69
67
|
PushServerAppOutputsRequest,
|
|
70
68
|
PushServerAppOutputsResponse,
|
|
71
69
|
)
|
|
72
|
-
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
|
73
70
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
74
71
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
75
72
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
76
73
|
from flwr.server.superlink.utils import abort_if
|
|
77
|
-
from flwr.server.utils.validator import
|
|
74
|
+
from flwr.server.utils.validator import validate_message
|
|
78
75
|
|
|
79
76
|
|
|
80
77
|
class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
@@ -130,7 +127,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
130
127
|
request.fab_version,
|
|
131
128
|
fab_hash,
|
|
132
129
|
user_config_from_proto(request.override_config),
|
|
133
|
-
|
|
130
|
+
ConfigRecord(),
|
|
134
131
|
)
|
|
135
132
|
return CreateRunResponse(run_id=run_id)
|
|
136
133
|
|
|
@@ -161,20 +158,19 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
161
158
|
while request.messages_list:
|
|
162
159
|
message_proto = request.messages_list.pop(0)
|
|
163
160
|
message = message_from_proto(message_proto=message_proto)
|
|
164
|
-
|
|
165
|
-
validation_errors = validate_task_ins_or_res(task_ins)
|
|
161
|
+
validation_errors = validate_message(message, is_reply_message=False)
|
|
166
162
|
_raise_if(
|
|
167
163
|
validation_error=bool(validation_errors),
|
|
168
164
|
request_name="PushMessages",
|
|
169
165
|
detail=", ".join(validation_errors),
|
|
170
166
|
)
|
|
171
167
|
_raise_if(
|
|
172
|
-
validation_error=request.run_id !=
|
|
168
|
+
validation_error=request.run_id != message.metadata.run_id,
|
|
173
169
|
request_name="PushMessages",
|
|
174
|
-
detail="`
|
|
170
|
+
detail="`Message.metadata` has mismatched `run_id`",
|
|
175
171
|
)
|
|
176
172
|
# Store
|
|
177
|
-
message_id: Optional[UUID] = state.
|
|
173
|
+
message_id: Optional[UUID] = state.store_message_ins(message=message)
|
|
178
174
|
message_ids.append(message_id)
|
|
179
175
|
|
|
180
176
|
return PushInsMessagesResponse(
|
|
@@ -200,32 +196,34 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
200
196
|
context,
|
|
201
197
|
)
|
|
202
198
|
|
|
203
|
-
# Convert each
|
|
199
|
+
# Convert each message_id str to UUID
|
|
204
200
|
message_ids: set[UUID] = {
|
|
205
201
|
UUID(message_id) for message_id in request.message_ids
|
|
206
202
|
}
|
|
207
203
|
|
|
208
204
|
# Read from state
|
|
209
|
-
|
|
205
|
+
messages_res: list[Message] = state.get_message_res(message_ids=message_ids)
|
|
210
206
|
|
|
211
|
-
# Delete the
|
|
212
|
-
|
|
213
|
-
UUID(
|
|
207
|
+
# Delete the instruction Messages and their replies if found
|
|
208
|
+
message_ins_ids_to_delete = {
|
|
209
|
+
UUID(msg_res.metadata.reply_to_message_id) for msg_res in messages_res
|
|
214
210
|
}
|
|
215
211
|
|
|
216
|
-
state.
|
|
212
|
+
state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
|
|
217
213
|
|
|
218
|
-
# Convert to
|
|
214
|
+
# Convert Messages to proto
|
|
219
215
|
messages_list = []
|
|
220
|
-
while
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
216
|
+
while messages_res:
|
|
217
|
+
msg = messages_res.pop(0)
|
|
218
|
+
|
|
219
|
+
# Skip `run_id` check for SuperLink generated replies
|
|
220
|
+
if msg.metadata.src_node_id != SUPERLINK_NODE_ID:
|
|
221
|
+
_raise_if(
|
|
222
|
+
validation_error=request.run_id != msg.metadata.run_id,
|
|
223
|
+
request_name="PullMessages",
|
|
224
|
+
detail="`message.metadata` has mismatched `run_id`",
|
|
225
|
+
)
|
|
226
|
+
messages_list.append(message_to_proto(msg))
|
|
229
227
|
|
|
230
228
|
return PullResMessagesResponse(messages_list=messages_list)
|
|
231
229
|
|
|
@@ -24,7 +24,7 @@ from grpc import ServicerContext
|
|
|
24
24
|
from flwr.common.constant import Status
|
|
25
25
|
from flwr.common.logger import log
|
|
26
26
|
from flwr.common.serde import (
|
|
27
|
-
|
|
27
|
+
config_record_to_proto,
|
|
28
28
|
context_from_proto,
|
|
29
29
|
context_to_proto,
|
|
30
30
|
fab_to_proto,
|
|
@@ -182,5 +182,5 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
182
182
|
)
|
|
183
183
|
return GetFederationOptionsResponse()
|
|
184
184
|
return GetFederationOptionsResponse(
|
|
185
|
-
federation_options=
|
|
185
|
+
federation_options=config_record_to_proto(federation_options)
|
|
186
186
|
)
|
flwr/server/typing.py
CHANGED
|
@@ -19,9 +19,9 @@ from typing import Callable
|
|
|
19
19
|
|
|
20
20
|
from flwr.common import Context
|
|
21
21
|
|
|
22
|
-
from .
|
|
22
|
+
from .grid import Grid
|
|
23
23
|
from .serverapp_components import ServerAppComponents
|
|
24
24
|
|
|
25
|
-
ServerAppCallable = Callable[[
|
|
26
|
-
Workflow = Callable[[
|
|
25
|
+
ServerAppCallable = Callable[[Grid, Context], None]
|
|
26
|
+
Workflow = Callable[[Grid, Context], None]
|
|
27
27
|
ServerFn = Callable[[Context], ServerAppComponents]
|
flwr/server/utils/__init__.py
CHANGED
|
@@ -16,9 +16,9 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from .tensorboard import tensorboard as tensorboard
|
|
19
|
-
from .validator import
|
|
19
|
+
from .validator import validate_message as validate_message
|
|
20
20
|
|
|
21
21
|
__all__ = [
|
|
22
22
|
"tensorboard",
|
|
23
|
-
"
|
|
23
|
+
"validate_message",
|
|
24
24
|
]
|