flwr-nightly 1.17.0.dev20250318__py3-none-any.whl → 1.17.0.dev20250320__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/app.py +6 -4
- flwr/client/clientapp/app.py +2 -2
- flwr/client/grpc_client/connection.py +23 -20
- flwr/client/message_handler/message_handler.py +27 -27
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +5 -5
- flwr/client/run_info_store.py +2 -2
- flwr/common/__init__.py +2 -0
- flwr/common/constant.py +2 -0
- flwr/common/context.py +4 -4
- flwr/common/logger.py +2 -2
- flwr/common/message.py +269 -101
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/configsrecord.py +2 -2
- flwr/common/record/metricsrecord.py +1 -1
- flwr/common/record/parametersrecord.py +1 -1
- flwr/common/record/{recordset.py → recorddict.py} +57 -17
- flwr/common/{recordset_compat.py → recorddict_compat.py} +105 -105
- flwr/common/serde.py +33 -37
- flwr/proto/exec_pb2.py +32 -32
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +2 -2
- flwr/proto/run_pb2.py +32 -32
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +2 -0
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +11 -11
- flwr/server/compat/app_utils.py +16 -16
- flwr/server/compat/grid_client_proxy.py +38 -38
- flwr/server/grid/__init__.py +7 -6
- flwr/server/grid/grid.py +46 -17
- flwr/server/grid/grpc_grid.py +26 -33
- flwr/server/grid/inmemory_grid.py +19 -25
- flwr/server/run_serverapp.py +4 -4
- flwr/server/server_app.py +37 -11
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/fleet/vce/vce_api.py +1 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +29 -4
- flwr/server/superlink/linkstate/sqlite_linkstate.py +54 -20
- flwr/server/superlink/linkstate/utils.py +77 -17
- flwr/server/superlink/serverappio/serverappio_servicer.py +1 -1
- flwr/server/typing.py +3 -3
- flwr/server/utils/validator.py +4 -4
- flwr/server/workflow/default_workflows.py +24 -26
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +23 -23
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +13 -13
- flwr/superexec/deployment.py +2 -2
- flwr/superexec/simulation.py +2 -2
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/METADATA +1 -1
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/RECORD +60 -60
- flwr/proto/recordset_pb2.py +0 -70
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/entry_points.txt +0 -0
@@ -12,7 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""Flower in-memory
|
15
|
+
"""Flower in-memory Grid."""
|
16
16
|
|
17
17
|
|
18
18
|
import time
|
@@ -20,22 +20,23 @@ from collections.abc import Iterable
|
|
20
20
|
from typing import Optional, cast
|
21
21
|
from uuid import UUID
|
22
22
|
|
23
|
-
from flwr.common import
|
23
|
+
from flwr.common import Message, RecordDict
|
24
24
|
from flwr.common.constant import SUPERLINK_NODE_ID
|
25
|
+
from flwr.common.logger import warn_deprecated_feature
|
25
26
|
from flwr.common.typing import Run
|
26
27
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
27
28
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
28
29
|
|
29
|
-
from .grid import
|
30
|
+
from .grid import Grid
|
30
31
|
|
31
32
|
|
32
|
-
class
|
33
|
-
"""`
|
33
|
+
class InMemoryGrid(Grid):
|
34
|
+
"""`InMemoryGrid` class provides an interface to the ServerAppIo API.
|
34
35
|
|
35
36
|
Parameters
|
36
37
|
----------
|
37
38
|
state_factory : StateFactory
|
38
|
-
A StateFactory embedding a state that this
|
39
|
+
A StateFactory embedding a state that this grid can interface with.
|
39
40
|
pull_interval : float (default=0.1)
|
40
41
|
Sleep duration between calls to `pull_messages`.
|
41
42
|
"""
|
@@ -53,10 +54,8 @@ class InMemoryDriver(Driver):
|
|
53
54
|
def _check_message(self, message: Message) -> None:
|
54
55
|
# Check if the message is valid
|
55
56
|
if not (
|
56
|
-
message.metadata.
|
57
|
-
and message.metadata.
|
58
|
-
and message.metadata.message_id == ""
|
59
|
-
and message.metadata.reply_to_message == ""
|
57
|
+
message.metadata.message_id == ""
|
58
|
+
and message.metadata.reply_to_message_id == ""
|
60
59
|
and message.metadata.ttl > 0
|
61
60
|
and message.metadata.delivered_at == ""
|
62
61
|
):
|
@@ -76,7 +75,7 @@ class InMemoryDriver(Driver):
|
|
76
75
|
|
77
76
|
def create_message( # pylint: disable=too-many-arguments,R0917
|
78
77
|
self,
|
79
|
-
content:
|
78
|
+
content: RecordDict,
|
80
79
|
message_type: str,
|
81
80
|
dst_node_id: int,
|
82
81
|
group_id: str,
|
@@ -87,19 +86,11 @@ class InMemoryDriver(Driver):
|
|
87
86
|
This method constructs a new `Message` with given content and metadata.
|
88
87
|
The `run_id` and `src_node_id` will be set automatically.
|
89
88
|
"""
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
run_id=cast(Run, self._run).run_id,
|
94
|
-
message_id="", # Will be set by the server
|
95
|
-
src_node_id=self.node.node_id,
|
96
|
-
dst_node_id=dst_node_id,
|
97
|
-
reply_to_message="",
|
98
|
-
group_id=group_id,
|
99
|
-
ttl=ttl_,
|
100
|
-
message_type=message_type,
|
89
|
+
warn_deprecated_feature(
|
90
|
+
"`Driver.create_message` / `Grid.create_message` is deprecated."
|
91
|
+
"Use `Message` constructor instead."
|
101
92
|
)
|
102
|
-
return Message(
|
93
|
+
return Message(content, dst_node_id, message_type, ttl=ttl, group_id=group_id)
|
103
94
|
|
104
95
|
def get_node_ids(self) -> Iterable[int]:
|
105
96
|
"""Get node IDs."""
|
@@ -113,6 +104,9 @@ class InMemoryDriver(Driver):
|
|
113
104
|
"""
|
114
105
|
msg_ids: list[str] = []
|
115
106
|
for msg in messages:
|
107
|
+
# Populate metadata
|
108
|
+
msg.metadata.__dict__["_run_id"] = cast(Run, self._run).run_id
|
109
|
+
msg.metadata.__dict__["_src_node_id"] = self.node.node_id
|
116
110
|
# Check message
|
117
111
|
self._check_message(msg)
|
118
112
|
# Store in state
|
@@ -133,7 +127,7 @@ class InMemoryDriver(Driver):
|
|
133
127
|
message_res_list = self.state.get_message_res(message_ids=msg_ids)
|
134
128
|
# Get IDs of Messages these replies are for
|
135
129
|
message_ins_ids_to_delete = {
|
136
|
-
UUID(msg_res.metadata.
|
130
|
+
UUID(msg_res.metadata.reply_to_message_id) for msg_res in message_res_list
|
137
131
|
}
|
138
132
|
# Delete
|
139
133
|
self.state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
|
@@ -162,7 +156,7 @@ class InMemoryDriver(Driver):
|
|
162
156
|
res_msgs = self.pull_messages(msg_ids)
|
163
157
|
ret.extend(res_msgs)
|
164
158
|
msg_ids.difference_update(
|
165
|
-
{msg.metadata.
|
159
|
+
{msg.metadata.reply_to_message_id for msg in res_msgs}
|
166
160
|
)
|
167
161
|
if len(msg_ids) == 0:
|
168
162
|
break
|
flwr/server/run_serverapp.py
CHANGED
@@ -22,18 +22,18 @@ from flwr.common import Context
|
|
22
22
|
from flwr.common.logger import log
|
23
23
|
from flwr.common.object_ref import load_app
|
24
24
|
|
25
|
-
from .grid import
|
25
|
+
from .grid import Grid
|
26
26
|
from .server_app import LoadServerAppError, ServerApp
|
27
27
|
|
28
28
|
|
29
29
|
def run(
|
30
|
-
|
30
|
+
grid: Grid,
|
31
31
|
context: Context,
|
32
32
|
server_app_dir: str,
|
33
33
|
server_app_attr: Optional[str] = None,
|
34
34
|
loaded_server_app: Optional[ServerApp] = None,
|
35
35
|
) -> Context:
|
36
|
-
"""Run ServerApp with a given
|
36
|
+
"""Run ServerApp with a given Grid."""
|
37
37
|
if not (server_app_attr is None) ^ (loaded_server_app is None):
|
38
38
|
raise ValueError(
|
39
39
|
"Either `server_app_attr` or `loaded_server_app` should be set "
|
@@ -59,7 +59,7 @@ def run(
|
|
59
59
|
server_app = _load()
|
60
60
|
|
61
61
|
# Call ServerApp
|
62
|
-
server_app(
|
62
|
+
server_app(grid=grid, context=context)
|
63
63
|
|
64
64
|
log(DEBUG, "ServerApp finished running.")
|
65
65
|
return context
|
flwr/server/server_app.py
CHANGED
@@ -15,6 +15,7 @@
|
|
15
15
|
"""Flower ServerApp."""
|
16
16
|
|
17
17
|
|
18
|
+
import inspect
|
18
19
|
from collections.abc import Iterator
|
19
20
|
from contextlib import contextmanager
|
20
21
|
from typing import Callable, Optional
|
@@ -24,8 +25,8 @@ from flwr.common.logger import warn_deprecated_feature_with_example
|
|
24
25
|
from flwr.server.strategy import Strategy
|
25
26
|
|
26
27
|
from .client_manager import ClientManager
|
27
|
-
from .compat import
|
28
|
-
from .grid import Driver
|
28
|
+
from .compat import start_grid
|
29
|
+
from .grid import Driver, Grid
|
29
30
|
from .server import Server
|
30
31
|
from .server_config import ServerConfig
|
31
32
|
from .typing import ServerAppCallable, ServerFn
|
@@ -43,6 +44,21 @@ SERVER_FN_USAGE_EXAMPLE = """
|
|
43
44
|
app = ServerApp(server_fn=server_fn)
|
44
45
|
"""
|
45
46
|
|
47
|
+
GRID_USAGE_EXAMPLE = """
|
48
|
+
app = ServerApp()
|
49
|
+
|
50
|
+
@app.main()
|
51
|
+
def main(grid: Grid, context: Context) -> None:
|
52
|
+
# Your existing ServerApp code ...
|
53
|
+
"""
|
54
|
+
|
55
|
+
DRIVER_DEPRECATION_MSG = """
|
56
|
+
The `Driver` class is deprecated, it will be removed in a future release.
|
57
|
+
"""
|
58
|
+
DRIVER_EXAMPLE_MSG = """
|
59
|
+
Instead, use `Grid` in the signature of your `ServerApp`. For example:
|
60
|
+
"""
|
61
|
+
|
46
62
|
|
47
63
|
@contextmanager
|
48
64
|
def _empty_lifespan(_: Context) -> Iterator[None]:
|
@@ -54,7 +70,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
54
70
|
|
55
71
|
Examples
|
56
72
|
--------
|
57
|
-
Use the
|
73
|
+
Use the ``ServerApp`` with an existing ``Strategy``:
|
58
74
|
|
59
75
|
>>> def server_fn(context: Context):
|
60
76
|
>>> server_config = ServerConfig(num_rounds=3)
|
@@ -66,12 +82,12 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
66
82
|
>>>
|
67
83
|
>>> app = ServerApp(server_fn=server_fn)
|
68
84
|
|
69
|
-
Use the
|
85
|
+
Use the ``ServerApp`` with a custom main function:
|
70
86
|
|
71
87
|
>>> app = ServerApp()
|
72
88
|
>>>
|
73
89
|
>>> @app.main()
|
74
|
-
>>> def main(
|
90
|
+
>>> def main(grid: Grid, context: Context) -> None:
|
75
91
|
>>> print("ServerApp running")
|
76
92
|
"""
|
77
93
|
|
@@ -111,7 +127,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
111
127
|
self._main: Optional[ServerAppCallable] = None
|
112
128
|
self._lifespan = _empty_lifespan
|
113
129
|
|
114
|
-
def __call__(self,
|
130
|
+
def __call__(self, grid: Grid, context: Context) -> None:
|
115
131
|
"""Execute `ServerApp`."""
|
116
132
|
with self._lifespan(context):
|
117
133
|
# Compatibility mode
|
@@ -123,17 +139,17 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
123
139
|
self._config = components.config
|
124
140
|
self._strategy = components.strategy
|
125
141
|
self._client_manager = components.client_manager
|
126
|
-
|
142
|
+
start_grid(
|
127
143
|
server=self._server,
|
128
144
|
config=self._config,
|
129
145
|
strategy=self._strategy,
|
130
146
|
client_manager=self._client_manager,
|
131
|
-
|
147
|
+
grid=grid,
|
132
148
|
)
|
133
149
|
return
|
134
150
|
|
135
151
|
# New execution mode
|
136
|
-
self._main(
|
152
|
+
self._main(grid, context)
|
137
153
|
|
138
154
|
def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
|
139
155
|
"""Return a decorator that registers the main fn with the server app.
|
@@ -143,7 +159,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
143
159
|
>>> app = ServerApp()
|
144
160
|
>>>
|
145
161
|
>>> @app.main()
|
146
|
-
>>> def main(
|
162
|
+
>>> def main(grid: Grid, context: Context) -> None:
|
147
163
|
>>> print("ServerApp running")
|
148
164
|
"""
|
149
165
|
|
@@ -168,11 +184,21 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
168
184
|
>>> app = ServerApp()
|
169
185
|
>>>
|
170
186
|
>>> @app.main()
|
171
|
-
>>> def main(
|
187
|
+
>>> def main(grid: Grid, context: Context) -> None:
|
172
188
|
>>> print("ServerApp running")
|
173
189
|
""",
|
174
190
|
)
|
175
191
|
|
192
|
+
sig = inspect.signature(main_fn)
|
193
|
+
param = list(sig.parameters.values())[0]
|
194
|
+
# Check if parameter name or the annotation should be updated
|
195
|
+
if param.name == "driver" or param.annotation is Driver:
|
196
|
+
warn_deprecated_feature_with_example(
|
197
|
+
deprecation_message=DRIVER_DEPRECATION_MSG,
|
198
|
+
example_message=DRIVER_EXAMPLE_MSG,
|
199
|
+
code_example=GRID_USAGE_EXAMPLE,
|
200
|
+
)
|
201
|
+
|
176
202
|
# Register provided function with the ServerApp object
|
177
203
|
self._main = main_fn
|
178
204
|
|
flwr/server/serverapp/app.py
CHANGED
@@ -60,7 +60,7 @@ from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
60
60
|
PullServerAppInputsResponse,
|
61
61
|
PushServerAppOutputsRequest,
|
62
62
|
)
|
63
|
-
from flwr.server.grid.grpc_grid import
|
63
|
+
from flwr.server.grid.grpc_grid import GrpcGrid
|
64
64
|
from flwr.server.run_serverapp import run as run_
|
65
65
|
|
66
66
|
|
@@ -106,7 +106,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
106
106
|
certificates: Optional[bytes] = None,
|
107
107
|
) -> None:
|
108
108
|
"""Run Flower ServerApp process."""
|
109
|
-
|
109
|
+
grid = GrpcGrid(
|
110
110
|
serverappio_service_address=serverappio_api_address,
|
111
111
|
root_certificates=certificates,
|
112
112
|
)
|
@@ -123,7 +123,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
123
123
|
# Pull ServerAppInputs from LinkState
|
124
124
|
req = PullServerAppInputsRequest()
|
125
125
|
log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
|
126
|
-
res: PullServerAppInputsResponse =
|
126
|
+
res: PullServerAppInputsResponse = grid._stub.PullServerAppInputs(req)
|
127
127
|
if not res.HasField("run"):
|
128
128
|
sleep(3)
|
129
129
|
run_status = None
|
@@ -135,14 +135,14 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
135
135
|
|
136
136
|
hash_run_id = get_sha256_hash(run.run_id)
|
137
137
|
|
138
|
-
|
138
|
+
grid.set_run(run.run_id)
|
139
139
|
|
140
140
|
# Start log uploader for this run
|
141
141
|
log_uploader = start_log_uploader(
|
142
142
|
log_queue=log_queue,
|
143
143
|
node_id=0,
|
144
144
|
run_id=run.run_id,
|
145
|
-
stub=
|
145
|
+
stub=grid._stub,
|
146
146
|
)
|
147
147
|
|
148
148
|
log(DEBUG, "[flwr-serverapp] Start FAB installation.")
|
@@ -173,7 +173,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
173
173
|
|
174
174
|
# Change status to Running
|
175
175
|
run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
|
176
|
-
|
176
|
+
grid._stub.UpdateRunStatus(
|
177
177
|
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
|
178
178
|
)
|
179
179
|
|
@@ -182,9 +182,9 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
182
182
|
event_details={"run-id-hash": hash_run_id},
|
183
183
|
)
|
184
184
|
|
185
|
-
# Load and run the ServerApp with the
|
185
|
+
# Load and run the ServerApp with the Grid
|
186
186
|
updated_context = run_(
|
187
|
-
|
187
|
+
grid=grid,
|
188
188
|
server_app_dir=app_path,
|
189
189
|
server_app_attr=server_app_attr,
|
190
190
|
context=context,
|
@@ -196,7 +196,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
196
196
|
out_req = PushServerAppOutputsRequest(
|
197
197
|
run_id=run.run_id, context=context_proto
|
198
198
|
)
|
199
|
-
_ =
|
199
|
+
_ = grid._stub.PushServerAppOutputs(out_req)
|
200
200
|
|
201
201
|
run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
|
202
202
|
except RunNotRunningException:
|
@@ -221,7 +221,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
221
221
|
# Update run status
|
222
222
|
if run_status:
|
223
223
|
run_status_proto = run_status_to_proto(run_status)
|
224
|
-
|
224
|
+
grid._stub.UpdateRunStatus(
|
225
225
|
UpdateRunStatusRequest(
|
226
226
|
run_id=run.run_id, run_status=run_status_proto
|
227
227
|
)
|
@@ -130,9 +130,7 @@ def worker(
|
|
130
130
|
e_code = ErrorCode.UNKNOWN
|
131
131
|
|
132
132
|
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
|
133
|
-
out_mssg = message
|
134
|
-
error=Error(code=e_code, reason=reason)
|
135
|
-
)
|
133
|
+
out_mssg = Message(Error(code=e_code, reason=reason), reply_to=message)
|
136
134
|
|
137
135
|
finally:
|
138
136
|
if out_mssg:
|
@@ -27,6 +27,7 @@ from flwr.common import Context, Message, log, now
|
|
27
27
|
from flwr.common.constant import (
|
28
28
|
MESSAGE_TTL_TOLERANCE,
|
29
29
|
NODE_ID_NUM_BYTES,
|
30
|
+
PING_PATIENCE,
|
30
31
|
RUN_ID_NUM_BYTES,
|
31
32
|
SUPERLINK_NODE_ID,
|
32
33
|
Status,
|
@@ -37,6 +38,7 @@ from flwr.server.superlink.linkstate.linkstate import LinkState
|
|
37
38
|
from flwr.server.utils import validate_message
|
38
39
|
|
39
40
|
from .utils import (
|
41
|
+
check_node_availability_for_in_message,
|
40
42
|
generate_rand_int_from_bytes,
|
41
43
|
has_valid_sub_status,
|
42
44
|
is_valid_transition,
|
@@ -156,7 +158,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
156
158
|
res_metadata = message.metadata
|
157
159
|
with self.lock:
|
158
160
|
# Check if the Message it is replying to exists and is valid
|
159
|
-
msg_ins_id = res_metadata.
|
161
|
+
msg_ins_id = res_metadata.reply_to_message_id
|
160
162
|
msg_ins = self.message_ins_store.get(UUID(msg_ins_id))
|
161
163
|
|
162
164
|
# Ensure that dst_node_id of original Message matches the src_node_id of
|
@@ -232,13 +234,28 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
232
234
|
with self.lock:
|
233
235
|
current = time.time()
|
234
236
|
|
235
|
-
# Verify
|
237
|
+
# Verify Message IDs
|
236
238
|
ret = verify_message_ids(
|
237
239
|
inquired_message_ids=message_ids,
|
238
240
|
found_message_ins_dict=self.message_ins_store,
|
239
241
|
current_time=current,
|
240
242
|
)
|
241
243
|
|
244
|
+
# Check node availability
|
245
|
+
dst_node_ids = {
|
246
|
+
self.message_ins_store[message_id].metadata.dst_node_id
|
247
|
+
for message_id in message_ids
|
248
|
+
}
|
249
|
+
tmp_ret_dict = check_node_availability_for_in_message(
|
250
|
+
inquired_in_message_ids=message_ids,
|
251
|
+
found_in_message_dict=self.message_ins_store,
|
252
|
+
node_id_to_online_until={
|
253
|
+
node_id: self.node_ids[node_id][0] for node_id in dst_node_ids
|
254
|
+
},
|
255
|
+
current_time=current,
|
256
|
+
)
|
257
|
+
ret.update(tmp_ret_dict)
|
258
|
+
|
242
259
|
# Find all reply Messages
|
243
260
|
message_res_found: list[Message] = []
|
244
261
|
for message_id in message_ids:
|
@@ -317,6 +334,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
317
334
|
log(ERROR, "Unexpected node registration failure.")
|
318
335
|
return 0
|
319
336
|
|
337
|
+
# Mark the node online until time.time() + ping_interval
|
320
338
|
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
|
321
339
|
return node_id
|
322
340
|
|
@@ -519,10 +537,17 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
519
537
|
return self.federation_options[run_id]
|
520
538
|
|
521
539
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
522
|
-
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
540
|
+
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
541
|
+
|
542
|
+
It allows for one missed ping (within a PING_PATIENCE * ping_interval window) before
|
543
|
+
marking the node as offline, where PING_PATIENCE = 2 by default.
|
544
|
+
"""
|
523
545
|
with self.lock:
|
524
546
|
if node_id in self.node_ids:
|
525
|
-
self.node_ids[node_id] = (
|
547
|
+
self.node_ids[node_id] = (
|
548
|
+
time.time() + PING_PATIENCE * ping_interval,
|
549
|
+
ping_interval,
|
550
|
+
)
|
526
551
|
return True
|
527
552
|
return False
|
528
553
|
|
@@ -30,28 +30,31 @@ from flwr.common import Context, Message, Metadata, log, now
|
|
30
30
|
from flwr.common.constant import (
|
31
31
|
MESSAGE_TTL_TOLERANCE,
|
32
32
|
NODE_ID_NUM_BYTES,
|
33
|
+
PING_PATIENCE,
|
33
34
|
RUN_ID_NUM_BYTES,
|
34
35
|
SUPERLINK_NODE_ID,
|
35
36
|
Status,
|
36
37
|
)
|
38
|
+
from flwr.common.message import make_message
|
37
39
|
from flwr.common.record import ConfigsRecord
|
38
40
|
from flwr.common.serde import (
|
39
41
|
error_from_proto,
|
40
42
|
error_to_proto,
|
41
|
-
|
42
|
-
|
43
|
+
recorddict_from_proto,
|
44
|
+
recorddict_to_proto,
|
43
45
|
)
|
44
46
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
45
47
|
|
46
48
|
# pylint: disable=E0611
|
47
49
|
from flwr.proto.error_pb2 import Error as ProtoError
|
48
|
-
from flwr.proto.
|
50
|
+
from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
|
49
51
|
|
50
52
|
# pylint: enable=E0611
|
51
53
|
from flwr.server.utils.validator import validate_message
|
52
54
|
|
53
55
|
from .linkstate import LinkState
|
54
56
|
from .utils import (
|
57
|
+
check_node_availability_for_in_message,
|
55
58
|
configsrecord_from_bytes,
|
56
59
|
configsrecord_to_bytes,
|
57
60
|
context_from_bytes,
|
@@ -129,7 +132,7 @@ CREATE TABLE IF NOT EXISTS message_ins(
|
|
129
132
|
run_id INTEGER,
|
130
133
|
src_node_id INTEGER,
|
131
134
|
dst_node_id INTEGER,
|
132
|
-
|
135
|
+
reply_to_message_id TEXT,
|
133
136
|
created_at REAL,
|
134
137
|
delivered_at TEXT,
|
135
138
|
ttl REAL,
|
@@ -148,7 +151,7 @@ CREATE TABLE IF NOT EXISTS message_res(
|
|
148
151
|
run_id INTEGER,
|
149
152
|
src_node_id INTEGER,
|
150
153
|
dst_node_id INTEGER,
|
151
|
-
|
154
|
+
reply_to_message_id TEXT,
|
152
155
|
created_at REAL,
|
153
156
|
delivered_at TEXT,
|
154
157
|
ttl REAL,
|
@@ -371,7 +374,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
371
374
|
return None
|
372
375
|
|
373
376
|
res_metadata = message.metadata
|
374
|
-
msg_ins_id = res_metadata.
|
377
|
+
msg_ins_id = res_metadata.reply_to_message_id
|
375
378
|
msg_ins = self.get_valid_message_ins(msg_ins_id)
|
376
379
|
if msg_ins is None:
|
377
380
|
log(
|
@@ -442,6 +445,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
442
445
|
|
443
446
|
def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
|
444
447
|
"""Get reply Messages for the given Message IDs."""
|
448
|
+
# pylint: disable-msg=too-many-locals
|
445
449
|
ret: dict[UUID, Message] = {}
|
446
450
|
|
447
451
|
# Verify Message IDs
|
@@ -465,11 +469,34 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
465
469
|
current_time=current,
|
466
470
|
)
|
467
471
|
|
472
|
+
# Check node availability
|
473
|
+
dst_node_ids: set[int] = set()
|
474
|
+
for message_id in message_ids:
|
475
|
+
in_message = found_message_ins_dict[message_id]
|
476
|
+
sint_node_id = convert_uint64_to_sint64(in_message.metadata.dst_node_id)
|
477
|
+
dst_node_ids.add(sint_node_id)
|
478
|
+
query = f"""
|
479
|
+
SELECT node_id, online_until
|
480
|
+
FROM node
|
481
|
+
WHERE node_id IN ({",".join(["?"] * len(dst_node_ids))});
|
482
|
+
"""
|
483
|
+
rows = self.query(query, tuple(dst_node_ids))
|
484
|
+
tmp_ret_dict = check_node_availability_for_in_message(
|
485
|
+
inquired_in_message_ids=message_ids,
|
486
|
+
found_in_message_dict=found_message_ins_dict,
|
487
|
+
node_id_to_online_until={
|
488
|
+
convert_sint64_to_uint64(row["node_id"]): row["online_until"]
|
489
|
+
for row in rows
|
490
|
+
},
|
491
|
+
current_time=current,
|
492
|
+
)
|
493
|
+
ret.update(tmp_ret_dict)
|
494
|
+
|
468
495
|
# Find all reply Messages
|
469
496
|
query = f"""
|
470
497
|
SELECT *
|
471
498
|
FROM message_res
|
472
|
-
WHERE
|
499
|
+
WHERE reply_to_message_id IN ({",".join(["?"] * len(message_ids))})
|
473
500
|
AND delivered_at = "";
|
474
501
|
"""
|
475
502
|
rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
|
@@ -542,7 +569,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
542
569
|
# Delete reply Message
|
543
570
|
query_2 = f"""
|
544
571
|
DELETE FROM message_res
|
545
|
-
WHERE
|
572
|
+
WHERE reply_to_message_id IN ({placeholders});
|
546
573
|
"""
|
547
574
|
|
548
575
|
with self.conn:
|
@@ -584,6 +611,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
584
611
|
"VALUES (?, ?, ?, ?)"
|
585
612
|
)
|
586
613
|
|
614
|
+
# Mark the node online until time.time() + ping_interval
|
587
615
|
try:
|
588
616
|
self.query(
|
589
617
|
query,
|
@@ -899,7 +927,11 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
899
927
|
return configsrecord_from_bytes(row["federation_options"])
|
900
928
|
|
901
929
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
902
|
-
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
930
|
+
"""Acknowledge a ping received from a node, serving as a heartbeat.
|
931
|
+
|
932
|
+
It allows for one missed ping (within a PING_PATIENCE * ping_interval window) before
|
933
|
+
marking the node as offline, where PING_PATIENCE = 2 by default.
|
934
|
+
"""
|
903
935
|
sint64_node_id = convert_uint64_to_sint64(node_id)
|
904
936
|
|
905
937
|
# Check if the node exists in the `node` table
|
@@ -909,7 +941,14 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
909
941
|
|
910
942
|
# Update `online_until` and `ping_interval` for the given `node_id`
|
911
943
|
query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?"
|
912
|
-
self.query(
|
944
|
+
self.query(
|
945
|
+
query,
|
946
|
+
(
|
947
|
+
time.time() + PING_PATIENCE * ping_interval,
|
948
|
+
ping_interval,
|
949
|
+
sint64_node_id,
|
950
|
+
),
|
951
|
+
)
|
913
952
|
return True
|
914
953
|
|
915
954
|
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
|
@@ -1026,7 +1065,7 @@ def message_to_dict(message: Message) -> dict[str, Any]:
|
|
1026
1065
|
"run_id": message.metadata.run_id,
|
1027
1066
|
"src_node_id": message.metadata.src_node_id,
|
1028
1067
|
"dst_node_id": message.metadata.dst_node_id,
|
1029
|
-
"
|
1068
|
+
"reply_to_message_id": message.metadata.reply_to_message_id,
|
1030
1069
|
"created_at": message.metadata.created_at,
|
1031
1070
|
"delivered_at": message.metadata.delivered_at,
|
1032
1071
|
"ttl": message.metadata.ttl,
|
@@ -1036,7 +1075,7 @@ def message_to_dict(message: Message) -> dict[str, Any]:
|
|
1036
1075
|
}
|
1037
1076
|
|
1038
1077
|
if message.has_content():
|
1039
|
-
result["content"] =
|
1078
|
+
result["content"] = recorddict_to_proto(message.content).SerializeToString()
|
1040
1079
|
else:
|
1041
1080
|
result["error"] = error_to_proto(message.error).SerializeToString()
|
1042
1081
|
|
@@ -1047,20 +1086,15 @@ def dict_to_message(message_dict: dict[str, Any]) -> Message:
|
|
1047
1086
|
"""Transform dict to Message."""
|
1048
1087
|
content, error = None, None
|
1049
1088
|
if (b_content := message_dict.pop("content")) is not None:
|
1050
|
-
content =
|
1089
|
+
content = recorddict_from_proto(ProtoRecordDict.FromString(b_content))
|
1051
1090
|
if (b_error := message_dict.pop("error")) is not None:
|
1052
1091
|
error = error_from_proto(ProtoError.FromString(b_error))
|
1053
1092
|
|
1054
1093
|
# Metadata constructor doesn't allow passing created_at. We set it later
|
1055
1094
|
metadata = Metadata(
|
1056
|
-
**{
|
1057
|
-
k: v
|
1058
|
-
for k, v in message_dict.items()
|
1059
|
-
if k not in ["created_at", "delivered_at"]
|
1060
|
-
}
|
1095
|
+
**{k: v for k, v in message_dict.items() if k not in ["delivered_at"]}
|
1061
1096
|
)
|
1062
|
-
msg =
|
1063
|
-
msg.metadata.__dict__["_created_at"] = message_dict["created_at"]
|
1097
|
+
msg = make_message(metadata=metadata, content=content, error=error)
|
1064
1098
|
msg.metadata.delivered_at = message_dict["delivered_at"]
|
1065
1099
|
return msg
|
1066
1100
|
|