flwr 1.17.0__py3-none-any.whl → 1.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +1 -1
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/cli/__init__.py +1 -1
- flwr/cli/app.py +21 -2
- flwr/cli/build.py +83 -58
- flwr/cli/cli_user_auth_interceptor.py +1 -1
- flwr/cli/config_utils.py +53 -17
- flwr/cli/example.py +1 -1
- flwr/cli/install.py +1 -1
- flwr/cli/log.py +4 -4
- flwr/cli/login/__init__.py +1 -1
- flwr/cli/login/login.py +15 -8
- flwr/cli/ls.py +16 -37
- flwr/cli/new/__init__.py +1 -1
- flwr/cli/new/new.py +4 -4
- flwr/cli/new/templates/__init__.py +1 -1
- flwr/cli/new/templates/app/__init__.py +1 -1
- flwr/cli/new/templates/app/code/__init__.py +1 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +1 -1
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +4 -4
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
- flwr/cli/new/templates/app/code/task.sklearn.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -17
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/__init__.py +1 -1
- flwr/cli/run/run.py +11 -19
- flwr/cli/stop.py +3 -3
- flwr/cli/utils.py +42 -17
- flwr/client/__init__.py +3 -3
- flwr/client/client.py +1 -1
- flwr/client/client_app.py +140 -138
- flwr/client/clientapp/__init__.py +1 -8
- flwr/client/clientapp/utils.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +1 -1
- flwr/client/grpc_adapter_client/connection.py +5 -5
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +1 -1
- flwr/client/grpc_rere_client/connection.py +131 -61
- flwr/client/grpc_rere_client/grpc_adapter.py +35 -7
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/__init__.py +1 -1
- flwr/client/mod/centraldp_mods.py +1 -1
- flwr/client/mod/comms_mods.py +39 -20
- flwr/client/mod/localdp_mod.py +6 -6
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secagg_mod.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +174 -68
- flwr/client/run_info_store.py +1 -1
- flwr/client/typing.py +1 -1
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +3 -3
- flwr/common/address.py +1 -1
- flwr/common/args.py +1 -1
- flwr/common/auth_plugin/__init__.py +3 -1
- flwr/common/auth_plugin/auth_plugin.py +30 -4
- flwr/common/config.py +1 -1
- flwr/common/constant.py +37 -8
- flwr/common/context.py +1 -1
- flwr/common/date.py +1 -1
- flwr/common/differential_privacy.py +1 -1
- flwr/common/differential_privacy_constants.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit/exit.py +6 -6
- flwr/common/exit_handlers.py +31 -1
- flwr/common/grpc.py +1 -1
- flwr/common/heartbeat.py +165 -0
- flwr/common/inflatable.py +290 -0
- flwr/common/inflatable_grpc_utils.py +99 -0
- flwr/common/inflatable_rest_utils.py +99 -0
- flwr/common/inflatable_utils.py +341 -0
- flwr/common/logger.py +1 -1
- flwr/common/message.py +137 -252
- flwr/common/object_ref.py +1 -1
- flwr/common/parameter.py +1 -1
- flwr/common/pyproject.py +1 -1
- flwr/common/record/__init__.py +3 -2
- flwr/common/record/array.py +323 -0
- flwr/common/record/arrayrecord.py +121 -243
- flwr/common/record/configrecord.py +71 -16
- flwr/common/record/conversion_utils.py +2 -2
- flwr/common/record/metricrecord.py +71 -20
- flwr/common/record/recorddict.py +207 -90
- flwr/common/record/typeddict.py +1 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/retry_invoker.py +15 -11
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +52 -30
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +60 -184
- flwr/common/serde_utils.py +175 -0
- flwr/common/telemetry.py +2 -2
- flwr/common/typing.py +6 -4
- flwr/common/version.py +1 -1
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/{client → compat/client}/app.py +71 -211
- flwr/{client → compat/client}/grpc_client/__init__.py +1 -1
- flwr/{client → compat/client}/grpc_client/connection.py +13 -13
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/server/app.py +174 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/proto/__init__.py +1 -1
- flwr/proto/fleet_pb2.py +32 -27
- flwr/proto/fleet_pb2.pyi +49 -35
- flwr/proto/fleet_pb2_grpc.py +117 -13
- flwr/proto/fleet_pb2_grpc.pyi +47 -6
- flwr/proto/heartbeat_pb2.py +33 -0
- flwr/proto/heartbeat_pb2.pyi +66 -0
- flwr/proto/heartbeat_pb2_grpc.py +4 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +28 -11
- flwr/proto/message_pb2.pyi +125 -0
- flwr/proto/recorddict_pb2.py +16 -28
- flwr/proto/recorddict_pb2.pyi +46 -64
- flwr/proto/run_pb2.py +24 -32
- flwr/proto/run_pb2.pyi +4 -52
- flwr/proto/serverappio_pb2.py +32 -23
- flwr/proto/serverappio_pb2.pyi +45 -3
- flwr/proto/serverappio_pb2_grpc.py +138 -34
- flwr/proto/serverappio_pb2_grpc.pyi +54 -13
- flwr/proto/simulationio_pb2.py +12 -11
- flwr/proto/simulationio_pb2_grpc.py +35 -0
- flwr/proto/simulationio_pb2_grpc.pyi +14 -0
- flwr/server/__init__.py +2 -2
- flwr/server/app.py +69 -187
- flwr/server/client_manager.py +1 -1
- flwr/server/client_proxy.py +1 -1
- flwr/server/compat/__init__.py +1 -1
- flwr/server/compat/app.py +1 -1
- flwr/server/compat/app_utils.py +51 -29
- flwr/server/compat/legacy_context.py +1 -1
- flwr/server/criterion.py +1 -1
- flwr/server/fleet_event_log_interceptor.py +2 -2
- flwr/server/grid/grid.py +3 -3
- flwr/server/grid/grpc_grid.py +104 -34
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/history.py +1 -1
- flwr/server/run_serverapp.py +1 -1
- flwr/server/server.py +1 -1
- flwr/server/server_app.py +65 -58
- flwr/server/server_config.py +1 -1
- flwr/server/serverapp/__init__.py +1 -1
- flwr/server/serverapp/app.py +19 -1
- flwr/server/serverapp_components.py +1 -1
- flwr/server/strategy/__init__.py +1 -1
- flwr/server/strategy/aggregate.py +1 -1
- flwr/server/strategy/bulyan.py +2 -2
- flwr/server/strategy/dp_adaptive_clipping.py +17 -17
- flwr/server/strategy/dp_fixed_clipping.py +17 -17
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fault_tolerant_fedavg.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedtrimmedavg.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +3 -2
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/strategy/strategy.py +1 -1
- flwr/server/superlink/__init__.py +1 -1
- flwr/server/superlink/ffs/__init__.py +3 -1
- flwr/server/superlink/ffs/disk_ffs.py +1 -1
- flwr/server/superlink/ffs/ffs.py +1 -1
- flwr/server/superlink/ffs/ffs_factory.py +1 -1
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +14 -4
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +13 -13
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -1
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +136 -19
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -12
- flwr/server/superlink/fleet/vce/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +7 -4
- flwr/server/superlink/linkstate/__init__.py +1 -1
- flwr/server/superlink/linkstate/in_memory_linkstate.py +139 -44
- flwr/server/superlink/linkstate/linkstate.py +54 -21
- flwr/server/superlink/linkstate/linkstate_factory.py +1 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +150 -56
- flwr/server/superlink/linkstate/utils.py +34 -30
- flwr/server/superlink/serverappio/serverappio_grpc.py +3 -0
- flwr/server/superlink/serverappio/serverappio_servicer.py +211 -57
- flwr/server/superlink/simulation/__init__.py +1 -1
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
- flwr/server/superlink/utils.py +45 -3
- flwr/server/typing.py +1 -1
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +3 -3
- flwr/server/workflow/__init__.py +1 -1
- flwr/server/workflow/constant.py +1 -1
- flwr/server/workflow/default_workflows.py +1 -1
- flwr/server/workflow/secure_aggregation/__init__.py +1 -1
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -1
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -1
- flwr/serverapp/__init__.py +15 -0
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/app.py +18 -1
- flwr/simulation/legacy_app.py +1 -1
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/ray_transport/utils.py +1 -1
- flwr/simulation/run_simulation.py +2 -2
- flwr/simulation/simulationio_connection.py +1 -1
- flwr/supercore/__init__.py +15 -0
- flwr/supercore/object_store/__init__.py +24 -0
- flwr/supercore/object_store/in_memory_object_store.py +229 -0
- flwr/supercore/object_store/object_store.py +192 -0
- flwr/supercore/object_store/object_store_factory.py +44 -0
- flwr/superexec/__init__.py +1 -1
- flwr/superexec/app.py +1 -1
- flwr/superexec/deployment.py +7 -3
- flwr/superexec/exec_event_log_interceptor.py +4 -4
- flwr/superexec/exec_grpc.py +8 -4
- flwr/superexec/exec_servicer.py +126 -24
- flwr/superexec/exec_user_auth_interceptor.py +38 -9
- flwr/superexec/executor.py +5 -1
- flwr/superexec/simulation.py +8 -2
- flwr/superlink/__init__.py +15 -0
- flwr/{client/supernode → supernode}/__init__.py +1 -8
- flwr/{client/nodestate/nodestate.py → supernode/cli/__init__.py} +8 -15
- flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +4 -13
- flwr/supernode/cli/flwr_clientapp.py +81 -0
- flwr/{client → supernode}/nodestate/__init__.py +1 -1
- flwr/supernode/nodestate/in_memory_nodestate.py +190 -0
- flwr/supernode/nodestate/nodestate.py +212 -0
- flwr/{client → supernode}/nodestate/nodestate_factory.py +1 -1
- flwr/supernode/runtime/__init__.py +15 -0
- flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +26 -57
- flwr/supernode/servicer/__init__.py +15 -0
- flwr/supernode/servicer/clientappio/__init__.py +24 -0
- flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +1 -1
- flwr/supernode/start_client_internal.py +491 -0
- {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/METADATA +6 -5
- flwr-1.19.0.dist-info/RECORD +365 -0
- {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/WHEEL +1 -1
- {flwr-1.17.0.dist-info → flwr-1.19.0.dist-info}/entry_points.txt +2 -2
- flwr/client/heartbeat.py +0 -74
- flwr/client/nodestate/in_memory_nodestate.py +0 -38
- flwr-1.17.0.dist-info/LICENSE +0 -202
- flwr-1.17.0.dist-info/RECORD +0 -333
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -24,25 +24,23 @@ import time
|
|
|
24
24
|
from collections.abc import Sequence
|
|
25
25
|
from logging import DEBUG, ERROR, WARNING
|
|
26
26
|
from typing import Any, Optional, Union, cast
|
|
27
|
-
from uuid import UUID, uuid4
|
|
28
27
|
|
|
29
28
|
from flwr.common import Context, Message, Metadata, log, now
|
|
30
29
|
from flwr.common.constant import (
|
|
30
|
+
HEARTBEAT_MAX_INTERVAL,
|
|
31
|
+
HEARTBEAT_PATIENCE,
|
|
31
32
|
MESSAGE_TTL_TOLERANCE,
|
|
32
33
|
NODE_ID_NUM_BYTES,
|
|
33
|
-
|
|
34
|
+
RUN_FAILURE_DETAILS_NO_HEARTBEAT,
|
|
34
35
|
RUN_ID_NUM_BYTES,
|
|
35
36
|
SUPERLINK_NODE_ID,
|
|
36
37
|
Status,
|
|
38
|
+
SubStatus,
|
|
37
39
|
)
|
|
38
40
|
from flwr.common.message import make_message
|
|
39
41
|
from flwr.common.record import ConfigRecord
|
|
40
|
-
from flwr.common.serde import
|
|
41
|
-
|
|
42
|
-
error_to_proto,
|
|
43
|
-
recorddict_from_proto,
|
|
44
|
-
recorddict_to_proto,
|
|
45
|
-
)
|
|
42
|
+
from flwr.common.serde import recorddict_from_proto, recorddict_to_proto
|
|
43
|
+
from flwr.common.serde_utils import error_from_proto, error_to_proto
|
|
46
44
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
47
45
|
|
|
48
46
|
# pylint: disable=E0611
|
|
@@ -74,7 +72,7 @@ SQL_CREATE_TABLE_NODE = """
|
|
|
74
72
|
CREATE TABLE IF NOT EXISTS node(
|
|
75
73
|
node_id INTEGER UNIQUE,
|
|
76
74
|
online_until REAL,
|
|
77
|
-
|
|
75
|
+
heartbeat_interval REAL,
|
|
78
76
|
public_key BLOB
|
|
79
77
|
);
|
|
80
78
|
"""
|
|
@@ -92,6 +90,8 @@ CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
|
|
|
92
90
|
SQL_CREATE_TABLE_RUN = """
|
|
93
91
|
CREATE TABLE IF NOT EXISTS run(
|
|
94
92
|
run_id INTEGER UNIQUE,
|
|
93
|
+
active_until REAL,
|
|
94
|
+
heartbeat_interval REAL,
|
|
95
95
|
fab_id TEXT,
|
|
96
96
|
fab_version TEXT,
|
|
97
97
|
fab_hash TEXT,
|
|
@@ -102,7 +102,8 @@ CREATE TABLE IF NOT EXISTS run(
|
|
|
102
102
|
finished_at TEXT,
|
|
103
103
|
sub_status TEXT,
|
|
104
104
|
details TEXT,
|
|
105
|
-
federation_options BLOB
|
|
105
|
+
federation_options BLOB,
|
|
106
|
+
flwr_aid TEXT
|
|
106
107
|
);
|
|
107
108
|
"""
|
|
108
109
|
|
|
@@ -250,19 +251,15 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
250
251
|
|
|
251
252
|
return result
|
|
252
253
|
|
|
253
|
-
def store_message_ins(self, message: Message) -> Optional[
|
|
254
|
+
def store_message_ins(self, message: Message) -> Optional[str]:
|
|
254
255
|
"""Store one Message."""
|
|
255
256
|
# Validate message
|
|
256
257
|
errors = validate_message(message=message, is_reply_message=False)
|
|
257
258
|
if any(errors):
|
|
258
259
|
log(ERROR, errors)
|
|
259
260
|
return None
|
|
260
|
-
# Create message_id
|
|
261
|
-
message_id = uuid4()
|
|
262
261
|
|
|
263
262
|
# Store Message
|
|
264
|
-
# pylint: disable-next=W0212
|
|
265
|
-
message.metadata._message_id = str(message_id) # type: ignore
|
|
266
263
|
data = (message_to_dict(message),)
|
|
267
264
|
|
|
268
265
|
# Convert values from uint64 to sint64 for SQLite
|
|
@@ -302,7 +299,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
302
299
|
# This may need to be changed in the future version with more integrity checks.
|
|
303
300
|
self.query(query, data)
|
|
304
301
|
|
|
305
|
-
return message_id
|
|
302
|
+
return message.metadata.message_id
|
|
306
303
|
|
|
307
304
|
def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]:
|
|
308
305
|
"""Get all Messages that have not been delivered yet."""
|
|
@@ -365,7 +362,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
365
362
|
|
|
366
363
|
return result
|
|
367
364
|
|
|
368
|
-
def store_message_res(self, message: Message) -> Optional[
|
|
365
|
+
def store_message_res(self, message: Message) -> Optional[str]:
|
|
369
366
|
"""Store one Message."""
|
|
370
367
|
# Validate message
|
|
371
368
|
errors = validate_message(message=message, is_reply_message=True)
|
|
@@ -417,12 +414,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
417
414
|
)
|
|
418
415
|
return None
|
|
419
416
|
|
|
420
|
-
# Create message_id
|
|
421
|
-
message_id = uuid4()
|
|
422
|
-
|
|
423
417
|
# Store Message
|
|
424
|
-
# pylint: disable-next=W0212
|
|
425
|
-
message.metadata._message_id = str(message_id) # type: ignore
|
|
426
418
|
data = (message_to_dict(message),)
|
|
427
419
|
|
|
428
420
|
# Convert values from uint64 to sint64 for SQLite
|
|
@@ -441,12 +433,12 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
441
433
|
log(ERROR, "`run` is invalid")
|
|
442
434
|
return None
|
|
443
435
|
|
|
444
|
-
return message_id
|
|
436
|
+
return message.metadata.message_id
|
|
445
437
|
|
|
446
|
-
def get_message_res(self, message_ids: set[
|
|
438
|
+
def get_message_res(self, message_ids: set[str]) -> list[Message]:
|
|
447
439
|
"""Get reply Messages for the given Message IDs."""
|
|
448
440
|
# pylint: disable-msg=too-many-locals
|
|
449
|
-
ret: dict[
|
|
441
|
+
ret: dict[str, Message] = {}
|
|
450
442
|
|
|
451
443
|
# Verify Message IDs
|
|
452
444
|
current = time.time()
|
|
@@ -456,12 +448,12 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
456
448
|
WHERE message_id IN ({",".join(["?"] * len(message_ids))});
|
|
457
449
|
"""
|
|
458
450
|
rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
|
|
459
|
-
found_message_ins_dict: dict[
|
|
451
|
+
found_message_ins_dict: dict[str, Message] = {}
|
|
460
452
|
for row in rows:
|
|
461
453
|
convert_sint64_values_in_dict_to_uint64(
|
|
462
454
|
row, ["run_id", "src_node_id", "dst_node_id"]
|
|
463
455
|
)
|
|
464
|
-
found_message_ins_dict[
|
|
456
|
+
found_message_ins_dict[row["message_id"]] = dict_to_message(row)
|
|
465
457
|
|
|
466
458
|
ret = verify_message_ids(
|
|
467
459
|
inquired_message_ids=message_ids,
|
|
@@ -550,7 +542,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
550
542
|
result: dict[str, int] = rows[0]
|
|
551
543
|
return result["num"]
|
|
552
544
|
|
|
553
|
-
def delete_messages(self, message_ins_ids: set[
|
|
545
|
+
def delete_messages(self, message_ins_ids: set[str]) -> None:
|
|
554
546
|
"""Delete a Message and its reply based on provided Message IDs."""
|
|
555
547
|
if not message_ins_ids:
|
|
556
548
|
return
|
|
@@ -576,7 +568,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
576
568
|
self.conn.execute(query_1, data)
|
|
577
569
|
self.conn.execute(query_2, data)
|
|
578
570
|
|
|
579
|
-
def get_message_ids_from_run_id(self, run_id: int) -> set[
|
|
571
|
+
def get_message_ids_from_run_id(self, run_id: int) -> set[str]:
|
|
580
572
|
"""Get all instruction Message IDs for the given run_id."""
|
|
581
573
|
if self.conn is None:
|
|
582
574
|
raise AttributeError("LinkState not initialized")
|
|
@@ -593,9 +585,9 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
593
585
|
with self.conn:
|
|
594
586
|
rows = self.conn.execute(query, data).fetchall()
|
|
595
587
|
|
|
596
|
-
return {
|
|
588
|
+
return {row["message_id"] for row in rows}
|
|
597
589
|
|
|
598
|
-
def create_node(self,
|
|
590
|
+
def create_node(self, heartbeat_interval: float) -> int:
|
|
599
591
|
"""Create, store in the link state, and return `node_id`."""
|
|
600
592
|
# Sample a random uint64 as node_id
|
|
601
593
|
uint64_node_id = generate_rand_int_from_bytes(
|
|
@@ -607,18 +599,18 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
607
599
|
|
|
608
600
|
query = (
|
|
609
601
|
"INSERT INTO node "
|
|
610
|
-
"(node_id, online_until,
|
|
602
|
+
"(node_id, online_until, heartbeat_interval, public_key) "
|
|
611
603
|
"VALUES (?, ?, ?, ?)"
|
|
612
604
|
)
|
|
613
605
|
|
|
614
|
-
# Mark the node online util time.time() +
|
|
606
|
+
# Mark the node online util time.time() + heartbeat_interval
|
|
615
607
|
try:
|
|
616
608
|
self.query(
|
|
617
609
|
query,
|
|
618
610
|
(
|
|
619
611
|
sint64_node_id,
|
|
620
|
-
time.time() +
|
|
621
|
-
|
|
612
|
+
time.time() + heartbeat_interval,
|
|
613
|
+
heartbeat_interval,
|
|
622
614
|
b"", # Initialize with an empty public key
|
|
623
615
|
),
|
|
624
616
|
)
|
|
@@ -728,6 +720,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
728
720
|
fab_hash: Optional[str],
|
|
729
721
|
override_config: UserConfig,
|
|
730
722
|
federation_options: ConfigRecord,
|
|
723
|
+
flwr_aid: Optional[str],
|
|
731
724
|
) -> int:
|
|
732
725
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
733
726
|
# Sample a random int64 as run_id
|
|
@@ -742,26 +735,28 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
742
735
|
if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
|
|
743
736
|
query = (
|
|
744
737
|
"INSERT INTO run "
|
|
745
|
-
"(run_id,
|
|
746
|
-
"
|
|
747
|
-
"
|
|
738
|
+
"(run_id, active_until, heartbeat_interval, fab_id, fab_version, "
|
|
739
|
+
"fab_hash, override_config, federation_options, pending_at, "
|
|
740
|
+
"starting_at, running_at, finished_at, sub_status, details, flwr_aid) "
|
|
741
|
+
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
|
|
748
742
|
)
|
|
749
743
|
override_config_json = json.dumps(override_config)
|
|
750
744
|
data = [
|
|
751
745
|
sint64_run_id,
|
|
746
|
+
0, # The `active_until` is not used until the run is started
|
|
747
|
+
0, # This `heartbeat_interval` is not used until the run is started
|
|
752
748
|
fab_id,
|
|
753
749
|
fab_version,
|
|
754
750
|
fab_hash,
|
|
755
751
|
override_config_json,
|
|
756
752
|
configrecord_to_bytes(federation_options),
|
|
757
|
-
]
|
|
758
|
-
data += [
|
|
759
753
|
now().isoformat(),
|
|
760
754
|
"",
|
|
761
755
|
"",
|
|
762
756
|
"",
|
|
763
757
|
"",
|
|
764
758
|
"",
|
|
759
|
+
flwr_aid or "",
|
|
765
760
|
]
|
|
766
761
|
self.query(query, tuple(data))
|
|
767
762
|
return uint64_run_id
|
|
@@ -790,14 +785,47 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
790
785
|
result: set[bytes] = {row["public_key"] for row in rows}
|
|
791
786
|
return result
|
|
792
787
|
|
|
793
|
-
def get_run_ids(self) -> set[int]:
|
|
794
|
-
"""Retrieve all run IDs.
|
|
795
|
-
|
|
796
|
-
|
|
788
|
+
def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
|
|
789
|
+
"""Retrieve all run IDs if `flwr_aid` is not specified.
|
|
790
|
+
|
|
791
|
+
Otherwise, retrieve all run IDs for the specified `flwr_aid`.
|
|
792
|
+
"""
|
|
793
|
+
if flwr_aid:
|
|
794
|
+
rows = self.query(
|
|
795
|
+
"SELECT run_id FROM run WHERE flwr_aid = ?;",
|
|
796
|
+
(flwr_aid,),
|
|
797
|
+
)
|
|
798
|
+
else:
|
|
799
|
+
rows = self.query("SELECT run_id FROM run;", ())
|
|
797
800
|
return {convert_sint64_to_uint64(row["run_id"]) for row in rows}
|
|
798
801
|
|
|
802
|
+
def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
|
|
803
|
+
"""Check if any runs are no longer active.
|
|
804
|
+
|
|
805
|
+
Marks runs with status 'starting' or 'running' as failed
|
|
806
|
+
if they have not sent a heartbeat before `active_until`.
|
|
807
|
+
"""
|
|
808
|
+
sint_run_ids = [convert_uint64_to_sint64(run_id) for run_id in run_ids]
|
|
809
|
+
query = "UPDATE run SET finished_at = ?, sub_status = ?, details = ? "
|
|
810
|
+
query += "WHERE starting_at != '' AND finished_at = '' AND active_until < ?"
|
|
811
|
+
query += f" AND run_id IN ({','.join(['?'] * len(run_ids))});"
|
|
812
|
+
current = now()
|
|
813
|
+
self.query(
|
|
814
|
+
query,
|
|
815
|
+
(
|
|
816
|
+
current.isoformat(),
|
|
817
|
+
SubStatus.FAILED,
|
|
818
|
+
RUN_FAILURE_DETAILS_NO_HEARTBEAT,
|
|
819
|
+
current.timestamp(),
|
|
820
|
+
*sint_run_ids,
|
|
821
|
+
),
|
|
822
|
+
)
|
|
823
|
+
|
|
799
824
|
def get_run(self, run_id: int) -> Optional[Run]:
|
|
800
825
|
"""Retrieve information about the run with the specified `run_id`."""
|
|
826
|
+
# Check if runs are still active
|
|
827
|
+
self._check_and_tag_inactive_run(run_ids={run_id})
|
|
828
|
+
|
|
801
829
|
# Convert the uint64 value to sint64 for SQLite
|
|
802
830
|
sint64_run_id = convert_uint64_to_sint64(run_id)
|
|
803
831
|
query = "SELECT * FROM run WHERE run_id = ?;"
|
|
@@ -819,12 +847,16 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
819
847
|
sub_status=row["sub_status"],
|
|
820
848
|
details=row["details"],
|
|
821
849
|
),
|
|
850
|
+
flwr_aid=row["flwr_aid"],
|
|
822
851
|
)
|
|
823
852
|
log(ERROR, "`run_id` does not exist.")
|
|
824
853
|
return None
|
|
825
854
|
|
|
826
855
|
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
|
|
827
856
|
"""Retrieve the statuses for the specified runs."""
|
|
857
|
+
# Check if runs are still active
|
|
858
|
+
self._check_and_tag_inactive_run(run_ids=run_ids)
|
|
859
|
+
|
|
828
860
|
# Convert the uint64 value to sint64 for SQLite
|
|
829
861
|
sint64_run_ids = (convert_uint64_to_sint64(run_id) for run_id in set(run_ids))
|
|
830
862
|
query = f"SELECT * FROM run WHERE run_id IN ({','.join(['?'] * len(run_ids))});"
|
|
@@ -842,6 +874,9 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
842
874
|
|
|
843
875
|
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
|
|
844
876
|
"""Update the status of the run with the specified `run_id`."""
|
|
877
|
+
# Check if runs are still active
|
|
878
|
+
self._check_and_tag_inactive_run(run_ids={run_id})
|
|
879
|
+
|
|
845
880
|
# Convert the uint64 value to sint64 for SQLite
|
|
846
881
|
sint64_run_id = convert_uint64_to_sint64(run_id)
|
|
847
882
|
query = "SELECT * FROM run WHERE run_id = ?;"
|
|
@@ -879,9 +914,22 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
879
914
|
return False
|
|
880
915
|
|
|
881
916
|
# Update the status
|
|
882
|
-
query = "UPDATE run SET %s= ?, sub_status = ?, details =
|
|
917
|
+
query = "UPDATE run SET %s= ?, sub_status = ?, details = ?, "
|
|
918
|
+
query += "active_until = ?, heartbeat_interval = ? "
|
|
883
919
|
query += "WHERE run_id = ?;"
|
|
884
920
|
|
|
921
|
+
# Prepare data for query
|
|
922
|
+
# Initialize heartbeat_interval and active_until
|
|
923
|
+
# when switching to starting or running
|
|
924
|
+
current = now()
|
|
925
|
+
if new_status.status in (Status.STARTING, Status.RUNNING):
|
|
926
|
+
heartbeat_interval = HEARTBEAT_MAX_INTERVAL
|
|
927
|
+
active_until = current.timestamp() + heartbeat_interval
|
|
928
|
+
else:
|
|
929
|
+
heartbeat_interval = 0
|
|
930
|
+
active_until = 0
|
|
931
|
+
|
|
932
|
+
# Determine the timestamp field based on the new status
|
|
885
933
|
timestamp_fld = ""
|
|
886
934
|
if new_status.status == Status.STARTING:
|
|
887
935
|
timestamp_fld = "starting_at"
|
|
@@ -891,10 +939,12 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
891
939
|
timestamp_fld = "finished_at"
|
|
892
940
|
|
|
893
941
|
data = (
|
|
894
|
-
|
|
942
|
+
current.isoformat(),
|
|
895
943
|
new_status.sub_status,
|
|
896
944
|
new_status.details,
|
|
897
|
-
|
|
945
|
+
active_until,
|
|
946
|
+
heartbeat_interval,
|
|
947
|
+
convert_uint64_to_sint64(run_id),
|
|
898
948
|
)
|
|
899
949
|
self.query(query % timestamp_fld, data)
|
|
900
950
|
return True
|
|
@@ -926,11 +976,15 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
926
976
|
row = rows[0]
|
|
927
977
|
return configrecord_from_bytes(row["federation_options"])
|
|
928
978
|
|
|
929
|
-
def
|
|
930
|
-
|
|
979
|
+
def acknowledge_node_heartbeat(
|
|
980
|
+
self, node_id: int, heartbeat_interval: float
|
|
981
|
+
) -> bool:
|
|
982
|
+
"""Acknowledge a heartbeat received from a node, serving as a heartbeat.
|
|
931
983
|
|
|
932
|
-
|
|
933
|
-
|
|
984
|
+
A node is considered online as long as it sends heartbeats within
|
|
985
|
+
the tolerated interval: HEARTBEAT_PATIENCE × heartbeat_interval.
|
|
986
|
+
HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before
|
|
987
|
+
the node is marked as offline.
|
|
934
988
|
"""
|
|
935
989
|
sint64_node_id = convert_uint64_to_sint64(node_id)
|
|
936
990
|
|
|
@@ -939,18 +993,58 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
939
993
|
if not self.query(query, (sint64_node_id,)):
|
|
940
994
|
return False
|
|
941
995
|
|
|
942
|
-
# Update `online_until` and `
|
|
943
|
-
query =
|
|
996
|
+
# Update `online_until` and `heartbeat_interval` for the given `node_id`
|
|
997
|
+
query = (
|
|
998
|
+
"UPDATE node SET online_until = ?, heartbeat_interval = ? WHERE node_id = ?"
|
|
999
|
+
)
|
|
944
1000
|
self.query(
|
|
945
1001
|
query,
|
|
946
1002
|
(
|
|
947
|
-
time.time() +
|
|
948
|
-
|
|
1003
|
+
time.time() + HEARTBEAT_PATIENCE * heartbeat_interval,
|
|
1004
|
+
heartbeat_interval,
|
|
949
1005
|
sint64_node_id,
|
|
950
1006
|
),
|
|
951
1007
|
)
|
|
952
1008
|
return True
|
|
953
1009
|
|
|
1010
|
+
def acknowledge_app_heartbeat(self, run_id: int, heartbeat_interval: float) -> bool:
|
|
1011
|
+
"""Acknowledge a heartbeat received from a ServerApp for a given run.
|
|
1012
|
+
|
|
1013
|
+
A run with status `"running"` is considered alive as long as it sends heartbeats
|
|
1014
|
+
within the tolerated interval: HEARTBEAT_PATIENCE × heartbeat_interval.
|
|
1015
|
+
HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before the run is
|
|
1016
|
+
marked as `"completed:failed"`.
|
|
1017
|
+
"""
|
|
1018
|
+
# Check if runs are still active
|
|
1019
|
+
self._check_and_tag_inactive_run(run_ids={run_id})
|
|
1020
|
+
|
|
1021
|
+
# Search for the run
|
|
1022
|
+
sint_run_id = convert_uint64_to_sint64(run_id)
|
|
1023
|
+
query = "SELECT * FROM run WHERE run_id = ?;"
|
|
1024
|
+
rows = self.query(query, (sint_run_id,))
|
|
1025
|
+
|
|
1026
|
+
if not rows:
|
|
1027
|
+
log(ERROR, "`run_id` is invalid")
|
|
1028
|
+
return False
|
|
1029
|
+
|
|
1030
|
+
# Check if the run is of status "running"/"starting"
|
|
1031
|
+
row = rows[0]
|
|
1032
|
+
status = determine_run_status(row)
|
|
1033
|
+
if status not in (Status.RUNNING, Status.STARTING):
|
|
1034
|
+
log(
|
|
1035
|
+
ERROR,
|
|
1036
|
+
'Cannot acknowledge heartbeat for run with status "%s"',
|
|
1037
|
+
status,
|
|
1038
|
+
)
|
|
1039
|
+
return False
|
|
1040
|
+
|
|
1041
|
+
# Update the `active_until` and `heartbeat_interval` for the given run
|
|
1042
|
+
active_until = now().timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval
|
|
1043
|
+
query = "UPDATE run SET active_until = ?, heartbeat_interval = ? "
|
|
1044
|
+
query += "WHERE run_id = ?"
|
|
1045
|
+
self.query(query, (active_until, heartbeat_interval, sint_run_id))
|
|
1046
|
+
return True
|
|
1047
|
+
|
|
954
1048
|
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
|
|
955
1049
|
"""Get the context for the specified `run_id`."""
|
|
956
1050
|
# Retrieve context if any
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -17,10 +17,10 @@
|
|
|
17
17
|
|
|
18
18
|
from os import urandom
|
|
19
19
|
from typing import Optional
|
|
20
|
-
from uuid import UUID, uuid4
|
|
21
20
|
|
|
22
21
|
from flwr.common import ConfigRecord, Context, Error, Message, Metadata, now, serde
|
|
23
22
|
from flwr.common.constant import (
|
|
23
|
+
HEARTBEAT_PATIENCE,
|
|
24
24
|
SUPERLINK_NODE_ID,
|
|
25
25
|
ErrorCode,
|
|
26
26
|
MessageType,
|
|
@@ -56,8 +56,8 @@ REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
|
|
56
56
|
"Error: Reply Message Unavailable - The reply message has expired."
|
|
57
57
|
)
|
|
58
58
|
NODE_UNAVAILABLE_ERROR_REASON = (
|
|
59
|
-
"Error: Node Unavailable
|
|
60
|
-
"
|
|
59
|
+
"Error: Node Unavailable — The destination node failed to report a heartbeat "
|
|
60
|
+
f"within {HEARTBEAT_PATIENCE} × its expected interval."
|
|
61
61
|
)
|
|
62
62
|
|
|
63
63
|
|
|
@@ -245,7 +245,7 @@ def create_message_error_unavailable_res_message(
|
|
|
245
245
|
ttl = max(ins_metadata.ttl - (current_time - ins_metadata.created_at), 0)
|
|
246
246
|
metadata = Metadata(
|
|
247
247
|
run_id=ins_metadata.run_id,
|
|
248
|
-
message_id=
|
|
248
|
+
message_id="",
|
|
249
249
|
src_node_id=SUPERLINK_NODE_ID,
|
|
250
250
|
dst_node_id=SUPERLINK_NODE_ID,
|
|
251
251
|
reply_to_message_id=ins_metadata.message_id,
|
|
@@ -255,7 +255,7 @@ def create_message_error_unavailable_res_message(
|
|
|
255
255
|
ttl=ttl,
|
|
256
256
|
)
|
|
257
257
|
|
|
258
|
-
|
|
258
|
+
msg = make_message(
|
|
259
259
|
metadata=metadata,
|
|
260
260
|
error=Error(
|
|
261
261
|
code=(
|
|
@@ -270,30 +270,34 @@ def create_message_error_unavailable_res_message(
|
|
|
270
270
|
),
|
|
271
271
|
),
|
|
272
272
|
)
|
|
273
|
+
msg.metadata.__dict__["_message_id"] = msg.object_id
|
|
274
|
+
return msg
|
|
273
275
|
|
|
274
276
|
|
|
275
|
-
def create_message_error_unavailable_ins_message(reply_to_message_id:
|
|
277
|
+
def create_message_error_unavailable_ins_message(reply_to_message_id: str) -> Message:
|
|
276
278
|
"""Error to indicate that the enquired Message had expired before reply arrived or
|
|
277
279
|
that it isn't found."""
|
|
278
280
|
metadata = Metadata(
|
|
279
281
|
run_id=0, # Unknown
|
|
280
|
-
message_id=
|
|
282
|
+
message_id="",
|
|
281
283
|
src_node_id=SUPERLINK_NODE_ID,
|
|
282
284
|
dst_node_id=SUPERLINK_NODE_ID,
|
|
283
|
-
reply_to_message_id=
|
|
285
|
+
reply_to_message_id=reply_to_message_id,
|
|
284
286
|
group_id="", # Unknown
|
|
285
287
|
message_type=MessageType.SYSTEM,
|
|
286
288
|
created_at=now().timestamp(),
|
|
287
289
|
ttl=0,
|
|
288
290
|
)
|
|
289
291
|
|
|
290
|
-
|
|
292
|
+
msg = make_message(
|
|
291
293
|
metadata=metadata,
|
|
292
294
|
error=Error(
|
|
293
295
|
code=ErrorCode.MESSAGE_UNAVAILABLE,
|
|
294
296
|
reason=MESSAGE_UNAVAILABLE_ERROR_REASON,
|
|
295
297
|
),
|
|
296
298
|
)
|
|
299
|
+
msg.metadata.__dict__["_message_id"] = msg.object_id
|
|
300
|
+
return msg
|
|
297
301
|
|
|
298
302
|
|
|
299
303
|
def message_ttl_has_expired(message_metadata: Metadata, current_time: float) -> bool:
|
|
@@ -302,18 +306,18 @@ def message_ttl_has_expired(message_metadata: Metadata, current_time: float) ->
|
|
|
302
306
|
|
|
303
307
|
|
|
304
308
|
def verify_message_ids(
|
|
305
|
-
inquired_message_ids: set[
|
|
306
|
-
found_message_ins_dict: dict[
|
|
309
|
+
inquired_message_ids: set[str],
|
|
310
|
+
found_message_ins_dict: dict[str, Message],
|
|
307
311
|
current_time: Optional[float] = None,
|
|
308
312
|
update_set: bool = True,
|
|
309
|
-
) -> dict[
|
|
313
|
+
) -> dict[str, Message]:
|
|
310
314
|
"""Verify found Messages and generate error Messages for invalid ones.
|
|
311
315
|
|
|
312
316
|
Parameters
|
|
313
317
|
----------
|
|
314
|
-
inquired_message_ids : set[
|
|
318
|
+
inquired_message_ids : set[str]
|
|
315
319
|
Set of Message IDs for which to generate error Message if invalid.
|
|
316
|
-
found_message_ins_dict : dict[
|
|
320
|
+
found_message_ins_dict : dict[str, Message]
|
|
317
321
|
Dictionary containing all found Message indexed by their IDs.
|
|
318
322
|
current_time : Optional[float] (default: None)
|
|
319
323
|
The current time to check for expiration. If set to `None`, the current time
|
|
@@ -324,7 +328,7 @@ def verify_message_ids(
|
|
|
324
328
|
|
|
325
329
|
Returns
|
|
326
330
|
-------
|
|
327
|
-
dict[
|
|
331
|
+
dict[str, Message]
|
|
328
332
|
A dictionary of error Message indexed by the corresponding ID of the message
|
|
329
333
|
they are a reply of.
|
|
330
334
|
"""
|
|
@@ -344,19 +348,19 @@ def verify_message_ids(
|
|
|
344
348
|
|
|
345
349
|
|
|
346
350
|
def verify_found_message_replies(
|
|
347
|
-
inquired_message_ids: set[
|
|
348
|
-
found_message_ins_dict: dict[
|
|
351
|
+
inquired_message_ids: set[str],
|
|
352
|
+
found_message_ins_dict: dict[str, Message],
|
|
349
353
|
found_message_res_list: list[Message],
|
|
350
354
|
current_time: Optional[float] = None,
|
|
351
355
|
update_set: bool = True,
|
|
352
|
-
) -> dict[
|
|
356
|
+
) -> dict[str, Message]:
|
|
353
357
|
"""Verify found Message replies and generate error Message for invalid ones.
|
|
354
358
|
|
|
355
359
|
Parameters
|
|
356
360
|
----------
|
|
357
|
-
inquired_message_ids : set[
|
|
361
|
+
inquired_message_ids : set[str]
|
|
358
362
|
Set of Message IDs for which to generate error Message if invalid.
|
|
359
|
-
found_message_ins_dict : dict[
|
|
363
|
+
found_message_ins_dict : dict[str, Message]
|
|
360
364
|
Dictionary containing all found instruction Messages indexed by their IDs.
|
|
361
365
|
found_message_res_list : list[Message]
|
|
362
366
|
List of found Message to be verified.
|
|
@@ -369,13 +373,13 @@ def verify_found_message_replies(
|
|
|
369
373
|
|
|
370
374
|
Returns
|
|
371
375
|
-------
|
|
372
|
-
dict[
|
|
376
|
+
dict[str, Message]
|
|
373
377
|
A dictionary of Message indexed by the corresponding Message ID.
|
|
374
378
|
"""
|
|
375
|
-
ret_dict: dict[
|
|
379
|
+
ret_dict: dict[str, Message] = {}
|
|
376
380
|
current = current_time if current_time else now().timestamp()
|
|
377
381
|
for message_res in found_message_res_list:
|
|
378
|
-
message_ins_id =
|
|
382
|
+
message_ins_id = message_res.metadata.reply_to_message_id
|
|
379
383
|
if update_set:
|
|
380
384
|
inquired_message_ids.remove(message_ins_id)
|
|
381
385
|
# Check if the reply Message has expired
|
|
@@ -389,21 +393,21 @@ def verify_found_message_replies(
|
|
|
389
393
|
|
|
390
394
|
|
|
391
395
|
def check_node_availability_for_in_message(
|
|
392
|
-
inquired_in_message_ids: set[
|
|
393
|
-
found_in_message_dict: dict[
|
|
396
|
+
inquired_in_message_ids: set[str],
|
|
397
|
+
found_in_message_dict: dict[str, Message],
|
|
394
398
|
node_id_to_online_until: dict[int, float],
|
|
395
399
|
current_time: Optional[float] = None,
|
|
396
400
|
update_set: bool = True,
|
|
397
|
-
) -> dict[
|
|
401
|
+
) -> dict[str, Message]:
|
|
398
402
|
"""Check node availability for given Message and generate error reply Message if
|
|
399
403
|
unavailable. A Message error indicating node unavailability will be generated for
|
|
400
404
|
each given Message whose destination node is offline or non-existent.
|
|
401
405
|
|
|
402
406
|
Parameters
|
|
403
407
|
----------
|
|
404
|
-
inquired_in_message_ids : set[
|
|
408
|
+
inquired_in_message_ids : set[str]
|
|
405
409
|
Set of Message IDs for which to check destination node availability.
|
|
406
|
-
found_in_message_dict : dict[
|
|
410
|
+
found_in_message_dict : dict[str, Message]
|
|
407
411
|
Dictionary containing all found Message indexed by their IDs.
|
|
408
412
|
node_id_to_online_until : dict[int, float]
|
|
409
413
|
Dictionary mapping node IDs to their online-until timestamps.
|
|
@@ -416,7 +420,7 @@ def check_node_availability_for_in_message(
|
|
|
416
420
|
|
|
417
421
|
Returns
|
|
418
422
|
-------
|
|
419
|
-
dict[
|
|
423
|
+
dict[str, Message]
|
|
420
424
|
A dictionary of error Message indexed by the corresponding Message ID.
|
|
421
425
|
"""
|
|
422
426
|
ret_dict = {}
|
|
@@ -28,6 +28,7 @@ from flwr.proto.serverappio_pb2_grpc import ( # pylint: disable=E0611
|
|
|
28
28
|
)
|
|
29
29
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
30
30
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
31
|
+
from flwr.supercore.object_store import ObjectStoreFactory
|
|
31
32
|
|
|
32
33
|
from .serverappio_servicer import ServerAppIoServicer
|
|
33
34
|
|
|
@@ -36,6 +37,7 @@ def run_serverappio_api_grpc(
|
|
|
36
37
|
address: str,
|
|
37
38
|
state_factory: LinkStateFactory,
|
|
38
39
|
ffs_factory: FfsFactory,
|
|
40
|
+
objectstore_factory: ObjectStoreFactory,
|
|
39
41
|
certificates: Optional[tuple[bytes, bytes, bytes]],
|
|
40
42
|
) -> grpc.Server:
|
|
41
43
|
"""Run ServerAppIo API (gRPC, request-response)."""
|
|
@@ -43,6 +45,7 @@ def run_serverappio_api_grpc(
|
|
|
43
45
|
serverappio_servicer: grpc.Server = ServerAppIoServicer(
|
|
44
46
|
state_factory=state_factory,
|
|
45
47
|
ffs_factory=ffs_factory,
|
|
48
|
+
objectstore_factory=objectstore_factory,
|
|
46
49
|
)
|
|
47
50
|
serverappio_add_servicer_to_server_fn = add_ServerAppIoServicer_to_server
|
|
48
51
|
serverappio_grpc_server = generic_create_grpc_server(
|