flwr-nightly 1.19.0.dev20250526__py3-none-any.whl → 1.19.0.dev20250528__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/log.py +3 -3
- flwr/cli/login/login.py +3 -7
- flwr/cli/ls.py +3 -3
- flwr/cli/run/run.py +2 -6
- flwr/cli/stop.py +2 -2
- flwr/cli/utils.py +5 -4
- flwr/client/grpc_rere_client/connection.py +2 -0
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/common/auth_plugin/__init__.py +2 -0
- flwr/common/auth_plugin/auth_plugin.py +18 -0
- flwr/common/constant.py +3 -0
- flwr/common/inflatable.py +33 -2
- flwr/common/message.py +5 -1
- flwr/common/record/array.py +38 -1
- flwr/common/record/arrayrecord.py +34 -0
- flwr/common/serde.py +6 -1
- flwr/compat/client/app.py +9 -151
- flwr/proto/fleet_pb2.py +25 -13
- flwr/proto/fleet_pb2.pyi +60 -3
- flwr/proto/message_pb2.py +22 -19
- flwr/proto/message_pb2.pyi +25 -2
- flwr/proto/serverappio_pb2.py +31 -19
- flwr/proto/serverappio_pb2.pyi +60 -3
- flwr/server/app.py +44 -1
- flwr/server/grid/grpc_grid.py +2 -1
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/superlink/fleet/message_handler/message_handler.py +1 -2
- flwr/server/superlink/fleet/vce/vce_api.py +3 -0
- flwr/server/superlink/linkstate/in_memory_linkstate.py +14 -25
- flwr/server/superlink/linkstate/linkstate.py +9 -10
- flwr/server/superlink/linkstate/sqlite_linkstate.py +11 -21
- flwr/server/superlink/linkstate/utils.py +23 -23
- flwr/server/superlink/serverappio/serverappio_servicer.py +6 -10
- flwr/server/utils/validator.py +2 -2
- flwr/supercore/object_store/in_memory_object_store.py +30 -4
- flwr/supercore/object_store/object_store.py +48 -1
- flwr/superexec/exec_servicer.py +1 -2
- {flwr_nightly-1.19.0.dev20250526.dist-info → flwr_nightly-1.19.0.dev20250528.dist-info}/METADATA +1 -1
- {flwr_nightly-1.19.0.dev20250526.dist-info → flwr_nightly-1.19.0.dev20250528.dist-info}/RECORD +41 -41
- {flwr_nightly-1.19.0.dev20250526.dist-info → flwr_nightly-1.19.0.dev20250528.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.19.0.dev20250526.dist-info → flwr_nightly-1.19.0.dev20250528.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@
|
|
18
18
|
import time
|
19
19
|
from collections.abc import Iterable
|
20
20
|
from typing import Optional, cast
|
21
|
-
from uuid import
|
21
|
+
from uuid import uuid4
|
22
22
|
|
23
23
|
from flwr.common import Message, RecordDict
|
24
24
|
from flwr.common.constant import SUPERLINK_NODE_ID
|
@@ -56,7 +56,7 @@ class InMemoryGrid(Grid):
|
|
56
56
|
def _check_message(self, message: Message) -> None:
|
57
57
|
# Check if the message is valid
|
58
58
|
if not (
|
59
|
-
message.metadata.message_id
|
59
|
+
message.metadata.message_id != ""
|
60
60
|
and message.metadata.reply_to_message_id == ""
|
61
61
|
and message.metadata.ttl > 0
|
62
62
|
and message.metadata.delivered_at == ""
|
@@ -111,6 +111,7 @@ class InMemoryGrid(Grid):
|
|
111
111
|
# Populate metadata
|
112
112
|
msg.metadata.__dict__["_run_id"] = cast(Run, self._run).run_id
|
113
113
|
msg.metadata.__dict__["_src_node_id"] = self.node.node_id
|
114
|
+
msg.metadata.__dict__["_message_id"] = str(uuid4())
|
114
115
|
# Check message
|
115
116
|
self._check_message(msg)
|
116
117
|
# Store in state
|
@@ -126,12 +127,12 @@ class InMemoryGrid(Grid):
|
|
126
127
|
This method is used to collect messages from the SuperLink that correspond to a
|
127
128
|
set of given message IDs.
|
128
129
|
"""
|
129
|
-
msg_ids =
|
130
|
+
msg_ids = set(message_ids)
|
130
131
|
# Pull Messages
|
131
132
|
message_res_list = self.state.get_message_res(message_ids=msg_ids)
|
132
133
|
# Get IDs of Messages these replies are for
|
133
134
|
message_ins_ids_to_delete = {
|
134
|
-
|
135
|
+
msg_res.metadata.reply_to_message_id for msg_res in message_res_list
|
135
136
|
}
|
136
137
|
# Delete
|
137
138
|
self.state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
|
@@ -16,7 +16,6 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
from typing import Optional
|
19
|
-
from uuid import UUID
|
20
19
|
|
21
20
|
from flwr.common import Message
|
22
21
|
from flwr.common.constant import Status
|
@@ -122,7 +121,7 @@ def push_messages(
|
|
122
121
|
raise InvalidRunStatusException(abort_msg)
|
123
122
|
|
124
123
|
# Store Message in State
|
125
|
-
message_id: Optional[
|
124
|
+
message_id: Optional[str] = state.store_message_res(message=msg)
|
126
125
|
|
127
126
|
# Build response
|
128
127
|
response = PushMessagesResponse(
|
@@ -25,6 +25,7 @@ from pathlib import Path
|
|
25
25
|
from queue import Empty, Queue
|
26
26
|
from time import sleep
|
27
27
|
from typing import Callable, Optional
|
28
|
+
from uuid import uuid4
|
28
29
|
|
29
30
|
from flwr.app.error import Error
|
30
31
|
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
|
@@ -134,6 +135,8 @@ def worker(
|
|
134
135
|
|
135
136
|
finally:
|
136
137
|
if out_mssg:
|
138
|
+
# Assign a message_id
|
139
|
+
out_mssg.metadata.__dict__["_message_id"] = str(uuid4())
|
137
140
|
# Store reply Messages in state
|
138
141
|
messageres_queue.put(out_mssg)
|
139
142
|
|
@@ -21,7 +21,6 @@ from bisect import bisect_right
|
|
21
21
|
from dataclasses import dataclass, field
|
22
22
|
from logging import ERROR, WARNING
|
23
23
|
from typing import Optional
|
24
|
-
from uuid import UUID, uuid4
|
25
24
|
|
26
25
|
from flwr.common import Context, Message, log, now
|
27
26
|
from flwr.common.constant import (
|
@@ -76,15 +75,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
76
75
|
self.run_ids: dict[int, RunRecord] = {}
|
77
76
|
self.contexts: dict[int, Context] = {}
|
78
77
|
self.federation_options: dict[int, ConfigRecord] = {}
|
79
|
-
self.message_ins_store: dict[
|
80
|
-
self.message_res_store: dict[
|
81
|
-
self.message_ins_id_to_message_res_id: dict[
|
78
|
+
self.message_ins_store: dict[str, Message] = {}
|
79
|
+
self.message_res_store: dict[str, Message] = {}
|
80
|
+
self.message_ins_id_to_message_res_id: dict[str, str] = {}
|
82
81
|
|
83
82
|
self.node_public_keys: set[bytes] = set()
|
84
83
|
|
85
84
|
self.lock = threading.RLock()
|
86
85
|
|
87
|
-
def store_message_ins(self, message: Message) -> Optional[
|
86
|
+
def store_message_ins(self, message: Message) -> Optional[str]:
|
88
87
|
"""Store one Message."""
|
89
88
|
# Validate message
|
90
89
|
errors = validate_message(message, is_reply_message=False)
|
@@ -112,12 +111,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
112
111
|
)
|
113
112
|
return None
|
114
113
|
|
115
|
-
|
116
|
-
message_id = uuid4()
|
117
|
-
|
118
|
-
# Store Message
|
119
|
-
# pylint: disable-next=W0212
|
120
|
-
message.metadata._message_id = str(message_id) # type: ignore
|
114
|
+
message_id = message.metadata.message_id
|
121
115
|
with self.lock:
|
122
116
|
self.message_ins_store[message_id] = message
|
123
117
|
|
@@ -153,7 +147,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
153
147
|
return message_ins_list
|
154
148
|
|
155
149
|
# pylint: disable=R0911
|
156
|
-
def store_message_res(self, message: Message) -> Optional[
|
150
|
+
def store_message_res(self, message: Message) -> Optional[str]:
|
157
151
|
"""Store one Message."""
|
158
152
|
# Validate message
|
159
153
|
errors = validate_message(message, is_reply_message=True)
|
@@ -165,7 +159,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
165
159
|
with self.lock:
|
166
160
|
# Check if the Message it is replying to exists and is valid
|
167
161
|
msg_ins_id = res_metadata.reply_to_message_id
|
168
|
-
msg_ins = self.message_ins_store.get(
|
162
|
+
msg_ins = self.message_ins_store.get(msg_ins_id)
|
169
163
|
|
170
164
|
# Ensure that dst_node_id of original Message matches the src_node_id of
|
171
165
|
# reply Message.
|
@@ -220,22 +214,17 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
220
214
|
log(ERROR, "`metadata.run_id` is invalid")
|
221
215
|
return None
|
222
216
|
|
223
|
-
|
224
|
-
message_id = uuid4()
|
225
|
-
|
226
|
-
# Store Message
|
227
|
-
# pylint: disable-next=W0212
|
228
|
-
message.metadata._message_id = str(message_id) # type: ignore
|
217
|
+
message_id = message.metadata.message_id
|
229
218
|
with self.lock:
|
230
219
|
self.message_res_store[message_id] = message
|
231
|
-
self.message_ins_id_to_message_res_id[
|
220
|
+
self.message_ins_id_to_message_res_id[msg_ins_id] = message_id
|
232
221
|
|
233
222
|
# Return the new message_id
|
234
223
|
return message_id
|
235
224
|
|
236
|
-
def get_message_res(self, message_ids: set[
|
225
|
+
def get_message_res(self, message_ids: set[str]) -> list[Message]:
|
237
226
|
"""Get reply Messages for the given Message IDs."""
|
238
|
-
ret: dict[
|
227
|
+
ret: dict[str, Message] = {}
|
239
228
|
|
240
229
|
with self.lock:
|
241
230
|
current = time.time()
|
@@ -287,7 +276,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
287
276
|
|
288
277
|
return list(ret.values())
|
289
278
|
|
290
|
-
def delete_messages(self, message_ins_ids: set[
|
279
|
+
def delete_messages(self, message_ins_ids: set[str]) -> None:
|
291
280
|
"""Delete a Message and its reply based on provided Message IDs."""
|
292
281
|
if not message_ins_ids:
|
293
282
|
return
|
@@ -304,9 +293,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
304
293
|
)
|
305
294
|
del self.message_res_store[message_res_id]
|
306
295
|
|
307
|
-
def get_message_ids_from_run_id(self, run_id: int) -> set[
|
296
|
+
def get_message_ids_from_run_id(self, run_id: int) -> set[str]:
|
308
297
|
"""Get all instruction Message IDs for the given run_id."""
|
309
|
-
message_id_list: set[
|
298
|
+
message_id_list: set[str] = set()
|
310
299
|
with self.lock:
|
311
300
|
for message_id, message in self.message_ins_store.items():
|
312
301
|
if message.metadata.run_id == run_id:
|
@@ -17,7 +17,6 @@
|
|
17
17
|
|
18
18
|
import abc
|
19
19
|
from typing import Optional
|
20
|
-
from uuid import UUID
|
21
20
|
|
22
21
|
from flwr.common import Context, Message
|
23
22
|
from flwr.common.record import ConfigRecord
|
@@ -28,13 +27,13 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
28
27
|
"""Abstract LinkState."""
|
29
28
|
|
30
29
|
@abc.abstractmethod
|
31
|
-
def store_message_ins(self, message: Message) -> Optional[
|
30
|
+
def store_message_ins(self, message: Message) -> Optional[str]:
|
32
31
|
"""Store one Message.
|
33
32
|
|
34
33
|
Usually, the ServerAppIo API calls this to schedule instructions.
|
35
34
|
|
36
35
|
Stores the value of the `message` in the link state and, if successful,
|
37
|
-
returns the `message_id` (
|
36
|
+
returns the `message_id` (str) of the `message`. If, for any reason,
|
38
37
|
storing the `message` fails, `None` is returned.
|
39
38
|
|
40
39
|
Constraints
|
@@ -61,12 +60,12 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
61
60
|
"""
|
62
61
|
|
63
62
|
@abc.abstractmethod
|
64
|
-
def store_message_res(self, message: Message) -> Optional[
|
63
|
+
def store_message_res(self, message: Message) -> Optional[str]:
|
65
64
|
"""Store one Message.
|
66
65
|
|
67
66
|
Usually, the Fleet API calls this for Nodes returning results.
|
68
67
|
|
69
|
-
Stores the Message and, if successful, returns the `message_id` (
|
68
|
+
Stores the Message and, if successful, returns the `message_id` (str) of
|
70
69
|
the `message`. If storing the `message` fails, `None` is returned.
|
71
70
|
|
72
71
|
Constraints
|
@@ -78,7 +77,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
78
77
|
"""
|
79
78
|
|
80
79
|
@abc.abstractmethod
|
81
|
-
def get_message_res(self, message_ids: set[
|
80
|
+
def get_message_res(self, message_ids: set[str]) -> list[Message]:
|
82
81
|
"""Get reply Messages for the given Message IDs.
|
83
82
|
|
84
83
|
This method is typically called by the ServerAppIo API to obtain
|
@@ -94,7 +93,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
94
93
|
|
95
94
|
Parameters
|
96
95
|
----------
|
97
|
-
message_ids : set[
|
96
|
+
message_ids : set[str]
|
98
97
|
A set of Message IDs used to retrieve reply Messages responding to them.
|
99
98
|
|
100
99
|
Returns
|
@@ -113,18 +112,18 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
113
112
|
"""Calculate the number of reply Messages in store."""
|
114
113
|
|
115
114
|
@abc.abstractmethod
|
116
|
-
def delete_messages(self, message_ins_ids: set[
|
115
|
+
def delete_messages(self, message_ins_ids: set[str]) -> None:
|
117
116
|
"""Delete a Message and its reply based on provided Message IDs.
|
118
117
|
|
119
118
|
Parameters
|
120
119
|
----------
|
121
|
-
message_ins_ids : set[
|
120
|
+
message_ins_ids : set[str]
|
122
121
|
A set of Message IDs. For each ID in the set, the corresponding
|
123
122
|
Message and its associated reply Message will be deleted.
|
124
123
|
"""
|
125
124
|
|
126
125
|
@abc.abstractmethod
|
127
|
-
def get_message_ids_from_run_id(self, run_id: int) -> set[
|
126
|
+
def get_message_ids_from_run_id(self, run_id: int) -> set[str]:
|
128
127
|
"""Get all instruction Message IDs for the given run_id."""
|
129
128
|
|
130
129
|
@abc.abstractmethod
|
@@ -24,7 +24,6 @@ import time
|
|
24
24
|
from collections.abc import Sequence
|
25
25
|
from logging import DEBUG, ERROR, WARNING
|
26
26
|
from typing import Any, Optional, Union, cast
|
27
|
-
from uuid import UUID, uuid4
|
28
27
|
|
29
28
|
from flwr.common import Context, Message, Metadata, log, now
|
30
29
|
from flwr.common.constant import (
|
@@ -251,19 +250,15 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
251
250
|
|
252
251
|
return result
|
253
252
|
|
254
|
-
def store_message_ins(self, message: Message) -> Optional[
|
253
|
+
def store_message_ins(self, message: Message) -> Optional[str]:
|
255
254
|
"""Store one Message."""
|
256
255
|
# Validate message
|
257
256
|
errors = validate_message(message=message, is_reply_message=False)
|
258
257
|
if any(errors):
|
259
258
|
log(ERROR, errors)
|
260
259
|
return None
|
261
|
-
# Create message_id
|
262
|
-
message_id = uuid4()
|
263
260
|
|
264
261
|
# Store Message
|
265
|
-
# pylint: disable-next=W0212
|
266
|
-
message.metadata._message_id = str(message_id) # type: ignore
|
267
262
|
data = (message_to_dict(message),)
|
268
263
|
|
269
264
|
# Convert values from uint64 to sint64 for SQLite
|
@@ -303,7 +298,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
303
298
|
# This may need to be changed in the future version with more integrity checks.
|
304
299
|
self.query(query, data)
|
305
300
|
|
306
|
-
return message_id
|
301
|
+
return message.metadata.message_id
|
307
302
|
|
308
303
|
def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]:
|
309
304
|
"""Get all Messages that have not been delivered yet."""
|
@@ -366,7 +361,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
366
361
|
|
367
362
|
return result
|
368
363
|
|
369
|
-
def store_message_res(self, message: Message) -> Optional[
|
364
|
+
def store_message_res(self, message: Message) -> Optional[str]:
|
370
365
|
"""Store one Message."""
|
371
366
|
# Validate message
|
372
367
|
errors = validate_message(message=message, is_reply_message=True)
|
@@ -418,12 +413,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
418
413
|
)
|
419
414
|
return None
|
420
415
|
|
421
|
-
# Create message_id
|
422
|
-
message_id = uuid4()
|
423
|
-
|
424
416
|
# Store Message
|
425
|
-
# pylint: disable-next=W0212
|
426
|
-
message.metadata._message_id = str(message_id) # type: ignore
|
427
417
|
data = (message_to_dict(message),)
|
428
418
|
|
429
419
|
# Convert values from uint64 to sint64 for SQLite
|
@@ -442,12 +432,12 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
442
432
|
log(ERROR, "`run` is invalid")
|
443
433
|
return None
|
444
434
|
|
445
|
-
return message_id
|
435
|
+
return message.metadata.message_id
|
446
436
|
|
447
|
-
def get_message_res(self, message_ids: set[
|
437
|
+
def get_message_res(self, message_ids: set[str]) -> list[Message]:
|
448
438
|
"""Get reply Messages for the given Message IDs."""
|
449
439
|
# pylint: disable-msg=too-many-locals
|
450
|
-
ret: dict[
|
440
|
+
ret: dict[str, Message] = {}
|
451
441
|
|
452
442
|
# Verify Message IDs
|
453
443
|
current = time.time()
|
@@ -457,12 +447,12 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
457
447
|
WHERE message_id IN ({",".join(["?"] * len(message_ids))});
|
458
448
|
"""
|
459
449
|
rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
|
460
|
-
found_message_ins_dict: dict[
|
450
|
+
found_message_ins_dict: dict[str, Message] = {}
|
461
451
|
for row in rows:
|
462
452
|
convert_sint64_values_in_dict_to_uint64(
|
463
453
|
row, ["run_id", "src_node_id", "dst_node_id"]
|
464
454
|
)
|
465
|
-
found_message_ins_dict[
|
455
|
+
found_message_ins_dict[row["message_id"]] = dict_to_message(row)
|
466
456
|
|
467
457
|
ret = verify_message_ids(
|
468
458
|
inquired_message_ids=message_ids,
|
@@ -551,7 +541,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
551
541
|
result: dict[str, int] = rows[0]
|
552
542
|
return result["num"]
|
553
543
|
|
554
|
-
def delete_messages(self, message_ins_ids: set[
|
544
|
+
def delete_messages(self, message_ins_ids: set[str]) -> None:
|
555
545
|
"""Delete a Message and its reply based on provided Message IDs."""
|
556
546
|
if not message_ins_ids:
|
557
547
|
return
|
@@ -577,7 +567,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
577
567
|
self.conn.execute(query_1, data)
|
578
568
|
self.conn.execute(query_2, data)
|
579
569
|
|
580
|
-
def get_message_ids_from_run_id(self, run_id: int) -> set[
|
570
|
+
def get_message_ids_from_run_id(self, run_id: int) -> set[str]:
|
581
571
|
"""Get all instruction Message IDs for the given run_id."""
|
582
572
|
if self.conn is None:
|
583
573
|
raise AttributeError("LinkState not initialized")
|
@@ -594,7 +584,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
594
584
|
with self.conn:
|
595
585
|
rows = self.conn.execute(query, data).fetchall()
|
596
586
|
|
597
|
-
return {
|
587
|
+
return {row["message_id"] for row in rows}
|
598
588
|
|
599
589
|
def create_node(self, heartbeat_interval: float) -> int:
|
600
590
|
"""Create, store in the link state, and return `node_id`."""
|
@@ -17,7 +17,7 @@
|
|
17
17
|
|
18
18
|
from os import urandom
|
19
19
|
from typing import Optional
|
20
|
-
from uuid import
|
20
|
+
from uuid import uuid4
|
21
21
|
|
22
22
|
from flwr.common import ConfigRecord, Context, Error, Message, Metadata, now, serde
|
23
23
|
from flwr.common.constant import (
|
@@ -273,7 +273,7 @@ def create_message_error_unavailable_res_message(
|
|
273
273
|
)
|
274
274
|
|
275
275
|
|
276
|
-
def create_message_error_unavailable_ins_message(reply_to_message_id:
|
276
|
+
def create_message_error_unavailable_ins_message(reply_to_message_id: str) -> Message:
|
277
277
|
"""Error to indicate that the enquired Message had expired before reply arrived or
|
278
278
|
that it isn't found."""
|
279
279
|
metadata = Metadata(
|
@@ -281,7 +281,7 @@ def create_message_error_unavailable_ins_message(reply_to_message_id: UUID) -> M
|
|
281
281
|
message_id=str(uuid4()),
|
282
282
|
src_node_id=SUPERLINK_NODE_ID,
|
283
283
|
dst_node_id=SUPERLINK_NODE_ID,
|
284
|
-
reply_to_message_id=
|
284
|
+
reply_to_message_id=reply_to_message_id,
|
285
285
|
group_id="", # Unknown
|
286
286
|
message_type=MessageType.SYSTEM,
|
287
287
|
created_at=now().timestamp(),
|
@@ -303,18 +303,18 @@ def message_ttl_has_expired(message_metadata: Metadata, current_time: float) ->
|
|
303
303
|
|
304
304
|
|
305
305
|
def verify_message_ids(
|
306
|
-
inquired_message_ids: set[
|
307
|
-
found_message_ins_dict: dict[
|
306
|
+
inquired_message_ids: set[str],
|
307
|
+
found_message_ins_dict: dict[str, Message],
|
308
308
|
current_time: Optional[float] = None,
|
309
309
|
update_set: bool = True,
|
310
|
-
) -> dict[
|
310
|
+
) -> dict[str, Message]:
|
311
311
|
"""Verify found Messages and generate error Messages for invalid ones.
|
312
312
|
|
313
313
|
Parameters
|
314
314
|
----------
|
315
|
-
inquired_message_ids : set[
|
315
|
+
inquired_message_ids : set[str]
|
316
316
|
Set of Message IDs for which to generate error Message if invalid.
|
317
|
-
found_message_ins_dict : dict[
|
317
|
+
found_message_ins_dict : dict[str, Message]
|
318
318
|
Dictionary containing all found Message indexed by their IDs.
|
319
319
|
current_time : Optional[float] (default: None)
|
320
320
|
The current time to check for expiration. If set to `None`, the current time
|
@@ -325,7 +325,7 @@ def verify_message_ids(
|
|
325
325
|
|
326
326
|
Returns
|
327
327
|
-------
|
328
|
-
dict[
|
328
|
+
dict[str, Message]
|
329
329
|
A dictionary of error Message indexed by the corresponding ID of the message
|
330
330
|
they are a reply of.
|
331
331
|
"""
|
@@ -345,19 +345,19 @@ def verify_message_ids(
|
|
345
345
|
|
346
346
|
|
347
347
|
def verify_found_message_replies(
|
348
|
-
inquired_message_ids: set[
|
349
|
-
found_message_ins_dict: dict[
|
348
|
+
inquired_message_ids: set[str],
|
349
|
+
found_message_ins_dict: dict[str, Message],
|
350
350
|
found_message_res_list: list[Message],
|
351
351
|
current_time: Optional[float] = None,
|
352
352
|
update_set: bool = True,
|
353
|
-
) -> dict[
|
353
|
+
) -> dict[str, Message]:
|
354
354
|
"""Verify found Message replies and generate error Message for invalid ones.
|
355
355
|
|
356
356
|
Parameters
|
357
357
|
----------
|
358
|
-
inquired_message_ids : set[
|
358
|
+
inquired_message_ids : set[str]
|
359
359
|
Set of Message IDs for which to generate error Message if invalid.
|
360
|
-
found_message_ins_dict : dict[
|
360
|
+
found_message_ins_dict : dict[str, Message]
|
361
361
|
Dictionary containing all found instruction Messages indexed by their IDs.
|
362
362
|
found_message_res_list : dict[Message, Message]
|
363
363
|
List of found Message to be verified.
|
@@ -370,13 +370,13 @@ def verify_found_message_replies(
|
|
370
370
|
|
371
371
|
Returns
|
372
372
|
-------
|
373
|
-
dict[
|
373
|
+
dict[str, Message]
|
374
374
|
A dictionary of Message indexed by the corresponding Message ID.
|
375
375
|
"""
|
376
|
-
ret_dict: dict[
|
376
|
+
ret_dict: dict[str, Message] = {}
|
377
377
|
current = current_time if current_time else now().timestamp()
|
378
378
|
for message_res in found_message_res_list:
|
379
|
-
message_ins_id =
|
379
|
+
message_ins_id = message_res.metadata.reply_to_message_id
|
380
380
|
if update_set:
|
381
381
|
inquired_message_ids.remove(message_ins_id)
|
382
382
|
# Check if the reply Message has expired
|
@@ -390,21 +390,21 @@ def verify_found_message_replies(
|
|
390
390
|
|
391
391
|
|
392
392
|
def check_node_availability_for_in_message(
|
393
|
-
inquired_in_message_ids: set[
|
394
|
-
found_in_message_dict: dict[
|
393
|
+
inquired_in_message_ids: set[str],
|
394
|
+
found_in_message_dict: dict[str, Message],
|
395
395
|
node_id_to_online_until: dict[int, float],
|
396
396
|
current_time: Optional[float] = None,
|
397
397
|
update_set: bool = True,
|
398
|
-
) -> dict[
|
398
|
+
) -> dict[str, Message]:
|
399
399
|
"""Check node availability for given Message and generate error reply Message if
|
400
400
|
unavailable. A Message error indicating node unavailability will be generated for
|
401
401
|
each given Message whose destination node is offline or non-existent.
|
402
402
|
|
403
403
|
Parameters
|
404
404
|
----------
|
405
|
-
inquired_in_message_ids : set[
|
405
|
+
inquired_in_message_ids : set[str]
|
406
406
|
Set of Message IDs for which to check destination node availability.
|
407
|
-
found_in_message_dict : dict[
|
407
|
+
found_in_message_dict : dict[str, Message]
|
408
408
|
Dictionary containing all found Message indexed by their IDs.
|
409
409
|
node_id_to_online_until : dict[int, float]
|
410
410
|
Dictionary mapping node IDs to their online-until timestamps.
|
@@ -417,7 +417,7 @@ def check_node_availability_for_in_message(
|
|
417
417
|
|
418
418
|
Returns
|
419
419
|
-------
|
420
|
-
dict[
|
420
|
+
dict[str, Message]
|
421
421
|
A dictionary of error Message indexed by the corresponding Message ID.
|
422
422
|
"""
|
423
423
|
ret_dict = {}
|
@@ -18,7 +18,6 @@
|
|
18
18
|
import threading
|
19
19
|
from logging import DEBUG, INFO
|
20
20
|
from typing import Optional
|
21
|
-
from uuid import UUID
|
22
21
|
|
23
22
|
import grpc
|
24
23
|
|
@@ -140,7 +139,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
140
139
|
request_name="PushMessages",
|
141
140
|
detail="`messages_list` must not be empty",
|
142
141
|
)
|
143
|
-
message_ids: list[Optional[
|
142
|
+
message_ids: list[Optional[str]] = []
|
144
143
|
while request.messages_list:
|
145
144
|
message_proto = request.messages_list.pop(0)
|
146
145
|
message = message_from_proto(message_proto=message_proto)
|
@@ -156,7 +155,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
156
155
|
detail="`Message.metadata` has mismatched `run_id`",
|
157
156
|
)
|
158
157
|
# Store
|
159
|
-
message_id: Optional[
|
158
|
+
message_id: Optional[str] = state.store_message_ins(message=message)
|
160
159
|
message_ids.append(message_id)
|
161
160
|
|
162
161
|
return PushInsMessagesResponse(
|
@@ -182,17 +181,14 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
182
181
|
context,
|
183
182
|
)
|
184
183
|
|
185
|
-
# Convert each message_id str to UUID
|
186
|
-
message_ids: set[UUID] = {
|
187
|
-
UUID(message_id) for message_id in request.message_ids
|
188
|
-
}
|
189
|
-
|
190
184
|
# Read from state
|
191
|
-
messages_res: list[Message] = state.get_message_res(
|
185
|
+
messages_res: list[Message] = state.get_message_res(
|
186
|
+
message_ids=set(request.message_ids)
|
187
|
+
)
|
192
188
|
|
193
189
|
# Delete the instruction Messages and their replies if found
|
194
190
|
message_ins_ids_to_delete = {
|
195
|
-
|
191
|
+
msg_res.metadata.reply_to_message_id for msg_res in messages_res
|
196
192
|
}
|
197
193
|
|
198
194
|
state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
|
flwr/server/utils/validator.py
CHANGED
@@ -27,8 +27,8 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
|
|
27
27
|
validation_errors = []
|
28
28
|
metadata = message.metadata
|
29
29
|
|
30
|
-
if metadata.message_id
|
31
|
-
validation_errors.append("
|
30
|
+
if metadata.message_id == "":
|
31
|
+
validation_errors.append("empty `metadata.message_id`")
|
32
32
|
|
33
33
|
# Created/delivered/TTL/Pushed
|
34
34
|
if (
|
@@ -28,12 +28,27 @@ class InMemoryObjectStore(ObjectStore):
|
|
28
28
|
def __init__(self, verify: bool = True) -> None:
|
29
29
|
self.verify = verify
|
30
30
|
self.store: dict[str, bytes] = {}
|
31
|
+
# Mapping the Object ID of a message to the list of children object IDs
|
32
|
+
self.msg_children_objects_mapping: dict[str, list[str]] = {}
|
33
|
+
|
34
|
+
def preregister(self, object_ids: list[str]) -> list[str]:
|
35
|
+
"""Identify and preregister missing objects."""
|
36
|
+
new_objects = []
|
37
|
+
for obj_id in object_ids:
|
38
|
+
# Verify object ID format (must be a valid sha256 hash)
|
39
|
+
if not is_valid_sha256_hash(obj_id):
|
40
|
+
raise ValueError(f"Invalid object ID format: {obj_id}")
|
41
|
+
if obj_id not in self.store:
|
42
|
+
self.store[obj_id] = b""
|
43
|
+
new_objects.append(obj_id)
|
44
|
+
|
45
|
+
return new_objects
|
31
46
|
|
32
47
|
def put(self, object_id: str, object_content: bytes) -> None:
|
33
48
|
"""Put an object into the store."""
|
34
|
-
#
|
35
|
-
if not
|
36
|
-
raise
|
49
|
+
# Only allow adding the object if it has been preregistered
|
50
|
+
if object_id not in self.store:
|
51
|
+
raise KeyError(f"Object with id {object_id} was not preregistered.")
|
37
52
|
|
38
53
|
# Verify object_id and object_content match
|
39
54
|
if self.verify:
|
@@ -42,11 +57,22 @@ class InMemoryObjectStore(ObjectStore):
|
|
42
57
|
raise ValueError(f"Object ID {object_id} does not match content hash")
|
43
58
|
|
44
59
|
# Return if object is already present in the store
|
45
|
-
if object_id
|
60
|
+
if self.store[object_id] != b"":
|
46
61
|
return
|
47
62
|
|
48
63
|
self.store[object_id] = object_content
|
49
64
|
|
65
|
+
def set_message_descendant_ids(
|
66
|
+
self, msg_object_id: str, descendant_ids: list[str]
|
67
|
+
) -> None:
|
68
|
+
"""Store the mapping from a ``Message`` object ID to the object IDs of its
|
69
|
+
descendants."""
|
70
|
+
self.msg_children_objects_mapping[msg_object_id] = descendant_ids
|
71
|
+
|
72
|
+
def get_message_descendant_ids(self, msg_object_id: str) -> list[str]:
|
73
|
+
"""Retrieve the object IDs of all descendants of a given Message."""
|
74
|
+
return self.msg_children_objects_mapping[msg_object_id]
|
75
|
+
|
50
76
|
def get(self, object_id: str) -> Optional[bytes]:
|
51
77
|
"""Get an object from the store."""
|
52
78
|
return self.store.get(object_id)
|