flwr-nightly 1.19.0.dev20250611__py3-none-any.whl → 1.19.0.dev20250613__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/ls.py +12 -33
- flwr/cli/utils.py +18 -1
- flwr/client/grpc_rere_client/connection.py +47 -29
- flwr/client/grpc_rere_client/grpc_adapter.py +8 -0
- flwr/client/rest_client/connection.py +70 -51
- flwr/common/constant.py +4 -0
- flwr/common/inflatable.py +24 -0
- flwr/common/serde.py +2 -0
- flwr/common/typing.py +2 -0
- flwr/proto/fleet_pb2.py +12 -16
- flwr/proto/fleet_pb2.pyi +4 -19
- flwr/proto/fleet_pb2_grpc.py +34 -0
- flwr/proto/fleet_pb2_grpc.pyi +13 -0
- flwr/proto/message_pb2.py +15 -9
- flwr/proto/message_pb2.pyi +41 -0
- flwr/proto/run_pb2.py +24 -24
- flwr/proto/run_pb2.pyi +4 -1
- flwr/proto/serverappio_pb2.py +22 -26
- flwr/proto/serverappio_pb2.pyi +4 -19
- flwr/proto/serverappio_pb2_grpc.py +34 -0
- flwr/proto/serverappio_pb2_grpc.pyi +13 -0
- flwr/server/app.py +1 -0
- flwr/server/grid/grpc_grid.py +20 -9
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +33 -2
- flwr/server/superlink/fleet/rest_rere/rest_api.py +26 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +20 -3
- flwr/server/superlink/linkstate/linkstate.py +6 -2
- flwr/server/superlink/linkstate/sqlite_linkstate.py +19 -7
- flwr/server/superlink/serverappio/serverappio_servicer.py +65 -29
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -1
- flwr/server/superlink/utils.py +23 -10
- flwr/supercore/object_store/in_memory_object_store.py +160 -33
- flwr/supercore/object_store/object_store.py +54 -7
- flwr/superexec/deployment.py +6 -2
- flwr/superexec/exec_grpc.py +3 -0
- flwr/superexec/exec_servicer.py +125 -22
- flwr/superexec/executor.py +4 -0
- flwr/superexec/simulation.py +7 -1
- {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250613.dist-info}/METADATA +1 -1
- {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250613.dist-info}/RECORD +43 -43
- {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250613.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250613.dist-info}/entry_points.txt +0 -0
@@ -44,6 +44,8 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
44
44
|
SendNodeHeartbeatResponse,
|
45
45
|
)
|
46
46
|
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
47
|
+
ConfirmMessageReceivedRequest,
|
48
|
+
ConfirmMessageReceivedResponse,
|
47
49
|
ObjectIDs,
|
48
50
|
PullObjectRequest,
|
49
51
|
PullObjectResponse,
|
@@ -146,6 +148,7 @@ def push_messages(
|
|
146
148
|
msg.metadata.run_id,
|
147
149
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
148
150
|
state,
|
151
|
+
store,
|
149
152
|
)
|
150
153
|
if abort_msg:
|
151
154
|
raise InvalidRunStatusException(abort_msg)
|
@@ -165,7 +168,9 @@ def push_messages(
|
|
165
168
|
return response
|
166
169
|
|
167
170
|
|
168
|
-
def get_run(
|
171
|
+
def get_run(
|
172
|
+
request: GetRunRequest, state: LinkState, store: ObjectStore
|
173
|
+
) -> GetRunResponse:
|
169
174
|
"""Get run information."""
|
170
175
|
run = state.get_run(request.run_id)
|
171
176
|
|
@@ -177,6 +182,7 @@ def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
|
|
177
182
|
request.run_id,
|
178
183
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
179
184
|
state,
|
185
|
+
store,
|
180
186
|
)
|
181
187
|
if abort_msg:
|
182
188
|
raise InvalidRunStatusException(abort_msg)
|
@@ -193,7 +199,7 @@ def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
|
|
193
199
|
|
194
200
|
|
195
201
|
def get_fab(
|
196
|
-
request: GetFabRequest, ffs: Ffs, state: LinkState
|
202
|
+
request: GetFabRequest, ffs: Ffs, state: LinkState, store: ObjectStore
|
197
203
|
) -> GetFabResponse:
|
198
204
|
"""Get FAB."""
|
199
205
|
# Abort if the run is not running
|
@@ -201,6 +207,7 @@ def get_fab(
|
|
201
207
|
request.run_id,
|
202
208
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
203
209
|
state,
|
210
|
+
store,
|
204
211
|
)
|
205
212
|
if abort_msg:
|
206
213
|
raise InvalidRunStatusException(abort_msg)
|
@@ -220,6 +227,7 @@ def push_object(
|
|
220
227
|
request.run_id,
|
221
228
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
222
229
|
state,
|
230
|
+
store,
|
223
231
|
)
|
224
232
|
if abort_msg:
|
225
233
|
raise InvalidRunStatusException(abort_msg)
|
@@ -245,6 +253,7 @@ def pull_object(
|
|
245
253
|
request.run_id,
|
246
254
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
247
255
|
state,
|
256
|
+
store,
|
248
257
|
)
|
249
258
|
if abort_msg:
|
250
259
|
raise InvalidRunStatusException(abort_msg)
|
@@ -259,3 +268,25 @@ def pull_object(
|
|
259
268
|
object_content=content,
|
260
269
|
)
|
261
270
|
return PullObjectResponse(object_found=False, object_available=False)
|
271
|
+
|
272
|
+
|
273
|
+
def confirm_message_received(
|
274
|
+
request: ConfirmMessageReceivedRequest,
|
275
|
+
state: LinkState,
|
276
|
+
store: ObjectStore,
|
277
|
+
) -> ConfirmMessageReceivedResponse:
|
278
|
+
"""Confirm message received handler."""
|
279
|
+
abort_msg = check_abort(
|
280
|
+
request.run_id,
|
281
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
282
|
+
state,
|
283
|
+
store,
|
284
|
+
)
|
285
|
+
if abort_msg:
|
286
|
+
raise InvalidRunStatusException(abort_msg)
|
287
|
+
|
288
|
+
# Delete the message object
|
289
|
+
store.delete(request.message_object_id)
|
290
|
+
store.delete_message_descendant_ids(request.message_object_id)
|
291
|
+
|
292
|
+
return ConfirmMessageReceivedResponse()
|
@@ -39,6 +39,8 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
39
39
|
SendNodeHeartbeatResponse,
|
40
40
|
)
|
41
41
|
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
42
|
+
ConfirmMessageReceivedRequest,
|
43
|
+
ConfirmMessageReceivedResponse,
|
42
44
|
PullObjectRequest,
|
43
45
|
PullObjectResponse,
|
44
46
|
PushObjectRequest,
|
@@ -176,9 +178,10 @@ async def get_run(request: GetRunRequest) -> GetRunResponse:
|
|
176
178
|
"""GetRun."""
|
177
179
|
# Get state from app
|
178
180
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
181
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
179
182
|
|
180
183
|
# Handle message
|
181
|
-
return message_handler.get_run(request=request, state=state)
|
184
|
+
return message_handler.get_run(request=request, state=state, store=store)
|
182
185
|
|
183
186
|
|
184
187
|
@rest_request_response(GetFabRequest)
|
@@ -189,9 +192,25 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
|
|
189
192
|
|
190
193
|
# Get state from app
|
191
194
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
195
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
196
|
+
|
197
|
+
# Handle message
|
198
|
+
return message_handler.get_fab(request=request, ffs=ffs, state=state, store=store)
|
199
|
+
|
200
|
+
|
201
|
+
@rest_request_response(ConfirmMessageReceivedRequest)
|
202
|
+
async def confirm_message_received(
|
203
|
+
request: ConfirmMessageReceivedRequest,
|
204
|
+
) -> ConfirmMessageReceivedResponse:
|
205
|
+
"""Confirm message received."""
|
206
|
+
# Get state from app
|
207
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
208
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
192
209
|
|
193
210
|
# Handle message
|
194
|
-
return message_handler.
|
211
|
+
return message_handler.confirm_message_received(
|
212
|
+
request=request, state=state, store=store
|
213
|
+
)
|
195
214
|
|
196
215
|
|
197
216
|
routes = [
|
@@ -204,6 +223,11 @@ routes = [
|
|
204
223
|
Route("/api/v0/fleet/send-node-heartbeat", send_node_heartbeat, methods=["POST"]),
|
205
224
|
Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
|
206
225
|
Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
|
226
|
+
Route(
|
227
|
+
"/api/v0/fleet/confirm-message-received",
|
228
|
+
confirm_message_received,
|
229
|
+
methods=["POST"],
|
230
|
+
),
|
207
231
|
]
|
208
232
|
|
209
233
|
app: Starlette = Starlette(
|
@@ -18,6 +18,7 @@
|
|
18
18
|
import threading
|
19
19
|
import time
|
20
20
|
from bisect import bisect_right
|
21
|
+
from collections import defaultdict
|
21
22
|
from dataclasses import dataclass, field
|
22
23
|
from logging import ERROR, WARNING
|
23
24
|
from typing import Optional
|
@@ -79,6 +80,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
79
80
|
self.message_res_store: dict[str, Message] = {}
|
80
81
|
self.message_ins_id_to_message_res_id: dict[str, str] = {}
|
81
82
|
|
83
|
+
# Map flwr_aid to run_ids for O(1) reverse index lookup
|
84
|
+
self.flwr_aid_to_run_ids: dict[str, set[int]] = defaultdict(set)
|
85
|
+
|
82
86
|
self.node_public_keys: set[bytes] = set()
|
83
87
|
|
84
88
|
self.lock = threading.RLock()
|
@@ -398,6 +402,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
398
402
|
fab_hash: Optional[str],
|
399
403
|
override_config: UserConfig,
|
400
404
|
federation_options: ConfigRecord,
|
405
|
+
flwr_aid: Optional[str],
|
401
406
|
) -> int:
|
402
407
|
"""Create a new run for the specified `fab_hash`."""
|
403
408
|
# Sample a random int64 as run_id
|
@@ -421,9 +426,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
421
426
|
sub_status="",
|
422
427
|
details="",
|
423
428
|
),
|
429
|
+
flwr_aid=flwr_aid if flwr_aid else "",
|
424
430
|
),
|
425
431
|
)
|
426
432
|
self.run_ids[run_id] = run_record
|
433
|
+
# Add run_id to the flwr_aid_to_run_ids mapping if flwr_aid is provided
|
434
|
+
if flwr_aid:
|
435
|
+
self.flwr_aid_to_run_ids[flwr_aid].add(run_id)
|
427
436
|
|
428
437
|
# Record federation options. Leave empty if not passed
|
429
438
|
self.federation_options[run_id] = federation_options
|
@@ -451,9 +460,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
451
460
|
with self.lock:
|
452
461
|
return self.node_public_keys.copy()
|
453
462
|
|
454
|
-
def get_run_ids(self) -> set[int]:
|
455
|
-
"""Retrieve all run IDs.
|
463
|
+
def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
|
464
|
+
"""Retrieve all run IDs if `flwr_aid` is not specified.
|
465
|
+
|
466
|
+
Otherwise, retrieve all run IDs for the specified `flwr_aid`.
|
467
|
+
"""
|
456
468
|
with self.lock:
|
469
|
+
if flwr_aid is not None:
|
470
|
+
# Return run IDs for the specified flwr_aid
|
471
|
+
return set(self.flwr_aid_to_run_ids.get(flwr_aid, ()))
|
457
472
|
return set(self.run_ids.keys())
|
458
473
|
|
459
474
|
def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
|
@@ -463,7 +478,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
463
478
|
if they have not sent a heartbeat before `active_until`.
|
464
479
|
"""
|
465
480
|
current = now()
|
466
|
-
for record in
|
481
|
+
for record in (self.run_ids.get(run_id) for run_id in run_ids):
|
482
|
+
if record is None:
|
483
|
+
continue
|
467
484
|
with record.lock:
|
468
485
|
if record.run.status.status in (Status.STARTING, Status.RUNNING):
|
469
486
|
if record.active_until < current.timestamp():
|
@@ -164,12 +164,16 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
164
164
|
fab_hash: Optional[str],
|
165
165
|
override_config: UserConfig,
|
166
166
|
federation_options: ConfigRecord,
|
167
|
+
flwr_aid: Optional[str],
|
167
168
|
) -> int:
|
168
169
|
"""Create a new run for the specified `fab_hash`."""
|
169
170
|
|
170
171
|
@abc.abstractmethod
|
171
|
-
def get_run_ids(self) -> set[int]:
|
172
|
-
"""Retrieve all run IDs.
|
172
|
+
def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
|
173
|
+
"""Retrieve all run IDs if `flwr_aid` is not specified.
|
174
|
+
|
175
|
+
Otherwise, retrieve all run IDs for the specified `flwr_aid`.
|
176
|
+
"""
|
173
177
|
|
174
178
|
@abc.abstractmethod
|
175
179
|
def get_run(self, run_id: int) -> Optional[Run]:
|
@@ -102,7 +102,8 @@ CREATE TABLE IF NOT EXISTS run(
|
|
102
102
|
finished_at TEXT,
|
103
103
|
sub_status TEXT,
|
104
104
|
details TEXT,
|
105
|
-
federation_options BLOB
|
105
|
+
federation_options BLOB,
|
106
|
+
flwr_aid TEXT
|
106
107
|
);
|
107
108
|
"""
|
108
109
|
|
@@ -719,6 +720,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
719
720
|
fab_hash: Optional[str],
|
720
721
|
override_config: UserConfig,
|
721
722
|
federation_options: ConfigRecord,
|
723
|
+
flwr_aid: Optional[str],
|
722
724
|
) -> int:
|
723
725
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
724
726
|
# Sample a random int64 as run_id
|
@@ -735,8 +737,8 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
735
737
|
"INSERT INTO run "
|
736
738
|
"(run_id, active_until, heartbeat_interval, fab_id, fab_version, "
|
737
739
|
"fab_hash, override_config, federation_options, pending_at, "
|
738
|
-
"starting_at, running_at, finished_at, sub_status, details) "
|
739
|
-
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
|
740
|
+
"starting_at, running_at, finished_at, sub_status, details, flwr_aid) "
|
741
|
+
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
|
740
742
|
)
|
741
743
|
override_config_json = json.dumps(override_config)
|
742
744
|
data = [
|
@@ -754,6 +756,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
754
756
|
"",
|
755
757
|
"",
|
756
758
|
"",
|
759
|
+
flwr_aid or "",
|
757
760
|
]
|
758
761
|
self.query(query, tuple(data))
|
759
762
|
return uint64_run_id
|
@@ -782,10 +785,18 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
782
785
|
result: set[bytes] = {row["public_key"] for row in rows}
|
783
786
|
return result
|
784
787
|
|
785
|
-
def get_run_ids(self) -> set[int]:
|
786
|
-
"""Retrieve all run IDs.
|
787
|
-
|
788
|
-
|
788
|
+
def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
|
789
|
+
"""Retrieve all run IDs if `flwr_aid` is not specified.
|
790
|
+
|
791
|
+
Otherwise, retrieve all run IDs for the specified `flwr_aid`.
|
792
|
+
"""
|
793
|
+
if flwr_aid:
|
794
|
+
rows = self.query(
|
795
|
+
"SELECT run_id FROM run WHERE flwr_aid = ?;",
|
796
|
+
(flwr_aid,),
|
797
|
+
)
|
798
|
+
else:
|
799
|
+
rows = self.query("SELECT run_id FROM run;", ())
|
789
800
|
return {convert_sint64_to_uint64(row["run_id"]) for row in rows}
|
790
801
|
|
791
802
|
def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
|
@@ -836,6 +847,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
836
847
|
sub_status=row["sub_status"],
|
837
848
|
details=row["details"],
|
838
849
|
),
|
850
|
+
flwr_aid=row["flwr_aid"],
|
839
851
|
)
|
840
852
|
log(ERROR, "`run_id` does not exist.")
|
841
853
|
return None
|
@@ -26,6 +26,8 @@ from flwr.common.constant import SUPERLINK_NODE_ID, Status
|
|
26
26
|
from flwr.common.inflatable import (
|
27
27
|
UnexpectedObjectContentError,
|
28
28
|
get_descendant_object_ids,
|
29
|
+
get_object_tree,
|
30
|
+
no_object_id_recompute,
|
29
31
|
)
|
30
32
|
from flwr.common.logger import log
|
31
33
|
from flwr.common.serde import (
|
@@ -50,6 +52,8 @@ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
|
|
50
52
|
PushLogsResponse,
|
51
53
|
)
|
52
54
|
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
55
|
+
ConfirmMessageReceivedRequest,
|
56
|
+
ConfirmMessageReceivedResponse,
|
53
57
|
ObjectIDs,
|
54
58
|
PullObjectRequest,
|
55
59
|
PullObjectResponse,
|
@@ -107,14 +111,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
107
111
|
"""Get available nodes."""
|
108
112
|
log(DEBUG, "ServerAppIoServicer.GetNodes")
|
109
113
|
|
110
|
-
# Init state
|
111
|
-
state
|
114
|
+
# Init state and store
|
115
|
+
state = self.state_factory.state()
|
116
|
+
store = self.objectstore_factory.store()
|
112
117
|
|
113
118
|
# Abort if the run is not running
|
114
119
|
abort_if(
|
115
120
|
request.run_id,
|
116
121
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
117
122
|
state,
|
123
|
+
store,
|
118
124
|
context,
|
119
125
|
)
|
120
126
|
|
@@ -128,14 +134,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
128
134
|
"""Push a set of Messages."""
|
129
135
|
log(DEBUG, "ServerAppIoServicer.PushMessages")
|
130
136
|
|
131
|
-
# Init state
|
132
|
-
state
|
137
|
+
# Init state and store
|
138
|
+
state = self.state_factory.state()
|
139
|
+
store = self.objectstore_factory.store()
|
133
140
|
|
134
141
|
# Abort if the run is not running
|
135
142
|
abort_if(
|
136
143
|
request.run_id,
|
137
144
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
138
145
|
state,
|
146
|
+
store,
|
139
147
|
context,
|
140
148
|
)
|
141
149
|
|
@@ -146,8 +154,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
146
154
|
detail="`messages_list` must not be empty",
|
147
155
|
)
|
148
156
|
message_ids: list[Optional[str]] = []
|
149
|
-
|
150
|
-
message_proto = request.messages_list.pop(0)
|
157
|
+
for message_proto in request.messages_list:
|
151
158
|
message = message_from_proto(message_proto=message_proto)
|
152
159
|
validation_errors = validate_message(message, is_reply_message=False)
|
153
160
|
_raise_if(
|
@@ -164,9 +171,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
164
171
|
message_id: Optional[str] = state.store_message_ins(message=message)
|
165
172
|
message_ids.append(message_id)
|
166
173
|
|
167
|
-
# Init store
|
168
|
-
store = self.objectstore_factory.store()
|
169
|
-
|
170
174
|
# Store Message object to descendants mapping and preregister objects
|
171
175
|
objects_to_push = store_mapping_and_register_objects(store, request=request)
|
172
176
|
|
@@ -183,10 +187,8 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
183
187
|
"""Pull a set of Messages."""
|
184
188
|
log(DEBUG, "ServerAppIoServicer.PullMessages")
|
185
189
|
|
186
|
-
# Init state
|
187
|
-
state
|
188
|
-
|
189
|
-
# Init store
|
190
|
+
# Init state and store
|
191
|
+
state = self.state_factory.state()
|
190
192
|
store = self.objectstore_factory.store()
|
191
193
|
|
192
194
|
# Abort if the run is not running
|
@@ -194,6 +196,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
194
196
|
request.run_id,
|
195
197
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
196
198
|
state,
|
199
|
+
store,
|
197
200
|
context,
|
198
201
|
)
|
199
202
|
|
@@ -205,14 +208,15 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
205
208
|
# Register messages generated by LinkState in the Store for consistency
|
206
209
|
for msg_res in messages_res:
|
207
210
|
if msg_res.metadata.src_node_id == SUPERLINK_NODE_ID:
|
208
|
-
|
209
|
-
|
211
|
+
with no_object_id_recompute():
|
212
|
+
descendants = list(get_descendant_object_ids(msg_res))
|
213
|
+
message_obj_id = msg_res.metadata.message_id
|
210
214
|
# Store mapping
|
211
215
|
store.set_message_descendant_ids(
|
212
216
|
msg_object_id=message_obj_id, descendant_ids=descendants
|
213
217
|
)
|
214
218
|
# Preregister
|
215
|
-
store.preregister(
|
219
|
+
store.preregister(request.run_id, get_object_tree(msg_res))
|
216
220
|
|
217
221
|
# Delete the instruction Messages and their replies if found
|
218
222
|
message_ins_ids_to_delete = {
|
@@ -328,14 +332,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
328
332
|
"""Push ServerApp process outputs."""
|
329
333
|
log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
|
330
334
|
|
331
|
-
# Init state
|
335
|
+
# Init state and store
|
332
336
|
state = self.state_factory.state()
|
337
|
+
store = self.objectstore_factory.store()
|
333
338
|
|
334
339
|
# Abort if the run is not running
|
335
340
|
abort_if(
|
336
341
|
request.run_id,
|
337
342
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
338
343
|
state,
|
344
|
+
store,
|
339
345
|
context,
|
340
346
|
)
|
341
347
|
|
@@ -348,16 +354,23 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
348
354
|
"""Update the status of a run."""
|
349
355
|
log(DEBUG, "ServerAppIoServicer.UpdateRunStatus")
|
350
356
|
|
351
|
-
# Init state
|
357
|
+
# Init state and store
|
352
358
|
state = self.state_factory.state()
|
359
|
+
store = self.objectstore_factory.store()
|
353
360
|
|
354
361
|
# Abort if the run is finished
|
355
|
-
abort_if(request.run_id, [Status.FINISHED], state, context)
|
362
|
+
abort_if(request.run_id, [Status.FINISHED], state, store, context)
|
356
363
|
|
357
364
|
# Update the run status
|
358
365
|
state.update_run_status(
|
359
366
|
run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
|
360
367
|
)
|
368
|
+
|
369
|
+
# If the run is finished, delete the run from ObjectStore
|
370
|
+
if request.run_status.status == Status.FINISHED:
|
371
|
+
# Delete all objects related to the run
|
372
|
+
store.delete_objects_in_run(request.run_id)
|
373
|
+
|
361
374
|
return UpdateRunStatusResponse()
|
362
375
|
|
363
376
|
def PushLogs(
|
@@ -412,14 +425,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
412
425
|
"""Push an object to the ObjectStore."""
|
413
426
|
log(DEBUG, "ServerAppIoServicer.PushObject")
|
414
427
|
|
415
|
-
# Init state
|
416
|
-
state
|
428
|
+
# Init state and store
|
429
|
+
state = self.state_factory.state()
|
430
|
+
store = self.objectstore_factory.store()
|
417
431
|
|
418
432
|
# Abort if the run is not running
|
419
433
|
abort_if(
|
420
434
|
request.run_id,
|
421
435
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
422
436
|
state,
|
437
|
+
store,
|
423
438
|
context,
|
424
439
|
)
|
425
440
|
|
@@ -427,9 +442,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
427
442
|
# Cancel insertion in ObjectStore
|
428
443
|
context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
|
429
444
|
|
430
|
-
# Init store
|
431
|
-
store = self.objectstore_factory.store()
|
432
|
-
|
433
445
|
# Insert in store
|
434
446
|
stored = False
|
435
447
|
try:
|
@@ -449,14 +461,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
449
461
|
"""Pull an object from the ObjectStore."""
|
450
462
|
log(DEBUG, "ServerAppIoServicer.PullObject")
|
451
463
|
|
452
|
-
# Init state
|
453
|
-
state
|
464
|
+
# Init state and store
|
465
|
+
state = self.state_factory.state()
|
466
|
+
store = self.objectstore_factory.store()
|
454
467
|
|
455
468
|
# Abort if the run is not running
|
456
469
|
abort_if(
|
457
470
|
request.run_id,
|
458
471
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
459
472
|
state,
|
473
|
+
store,
|
460
474
|
context,
|
461
475
|
)
|
462
476
|
|
@@ -464,9 +478,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
464
478
|
# Cancel insertion in ObjectStore
|
465
479
|
context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
|
466
480
|
|
467
|
-
# Init store
|
468
|
-
store = self.objectstore_factory.store()
|
469
|
-
|
470
481
|
# Fetch from store
|
471
482
|
content = store.get(request.object_id)
|
472
483
|
if content is not None:
|
@@ -478,6 +489,31 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
478
489
|
)
|
479
490
|
return PullObjectResponse(object_found=False, object_available=False)
|
480
491
|
|
492
|
+
def ConfirmMessageReceived(
|
493
|
+
self, request: ConfirmMessageReceivedRequest, context: grpc.ServicerContext
|
494
|
+
) -> ConfirmMessageReceivedResponse:
|
495
|
+
"""Confirm message received."""
|
496
|
+
log(DEBUG, "ServerAppIoServicer.ConfirmMessageReceived")
|
497
|
+
|
498
|
+
# Init state and store
|
499
|
+
state = self.state_factory.state()
|
500
|
+
store = self.objectstore_factory.store()
|
501
|
+
|
502
|
+
# Abort if the run is not running
|
503
|
+
abort_if(
|
504
|
+
request.run_id,
|
505
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
506
|
+
state,
|
507
|
+
store,
|
508
|
+
context,
|
509
|
+
)
|
510
|
+
|
511
|
+
# Delete the message object
|
512
|
+
store.delete(request.message_object_id)
|
513
|
+
store.delete_message_descendant_ids(request.message_object_id)
|
514
|
+
|
515
|
+
return ConfirmMessageReceivedResponse()
|
516
|
+
|
481
517
|
|
482
518
|
def _raise_if(validation_error: bool, request_name: str, detail: str) -> None:
|
483
519
|
"""Raise a `ValueError` with a detailed message if a validation error occurs."""
|
@@ -121,6 +121,7 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
121
121
|
request.run_id,
|
122
122
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
123
123
|
state,
|
124
|
+
None,
|
124
125
|
context,
|
125
126
|
)
|
126
127
|
|
@@ -135,7 +136,7 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
135
136
|
state = self.state_factory.state()
|
136
137
|
|
137
138
|
# Abort if the run is finished
|
138
|
-
abort_if(request.run_id, [Status.FINISHED], state, context)
|
139
|
+
abort_if(request.run_id, [Status.FINISHED], state, None, context)
|
139
140
|
|
140
141
|
# Update the run status
|
141
142
|
state.update_run_status(
|
flwr/server/superlink/utils.py
CHANGED
@@ -15,11 +15,12 @@
|
|
15
15
|
"""SuperLink utilities."""
|
16
16
|
|
17
17
|
|
18
|
-
from typing import Union
|
18
|
+
from typing import Optional, Union
|
19
19
|
|
20
20
|
import grpc
|
21
21
|
|
22
22
|
from flwr.common.constant import Status, SubStatus
|
23
|
+
from flwr.common.inflatable import iterate_object_tree
|
23
24
|
from flwr.common.typing import RunStatus
|
24
25
|
from flwr.proto.fleet_pb2 import PushMessagesRequest # pylint: disable=E0611
|
25
26
|
from flwr.proto.message_pb2 import ObjectIDs # pylint: disable=E0611
|
@@ -39,6 +40,7 @@ def check_abort(
|
|
39
40
|
run_id: int,
|
40
41
|
abort_status_list: list[str],
|
41
42
|
state: LinkState,
|
43
|
+
store: Optional[ObjectStore] = None,
|
42
44
|
) -> Union[str, None]:
|
43
45
|
"""Check if the status of the provided `run_id` is in `abort_status_list`."""
|
44
46
|
run_status: RunStatus = state.get_run_status({run_id})[run_id]
|
@@ -49,6 +51,10 @@ def check_abort(
|
|
49
51
|
msg += " Stopped by user."
|
50
52
|
return msg
|
51
53
|
|
54
|
+
# Clear the objects of the run from the store if the run is finished
|
55
|
+
if store and run_status.status == Status.FINISHED:
|
56
|
+
store.delete_objects_in_run(run_id)
|
57
|
+
|
52
58
|
return None
|
53
59
|
|
54
60
|
|
@@ -62,10 +68,11 @@ def abort_if(
|
|
62
68
|
run_id: int,
|
63
69
|
abort_status_list: list[str],
|
64
70
|
state: LinkState,
|
71
|
+
store: Optional[ObjectStore],
|
65
72
|
context: grpc.ServicerContext,
|
66
73
|
) -> None:
|
67
74
|
"""Abort context if status of the provided `run_id` is in `abort_status_list`."""
|
68
|
-
msg = check_abort(run_id, abort_status_list, state)
|
75
|
+
msg = check_abort(run_id, abort_status_list, state, store)
|
69
76
|
abort_grpc_context(msg, context)
|
70
77
|
|
71
78
|
|
@@ -73,21 +80,27 @@ def store_mapping_and_register_objects(
|
|
73
80
|
store: ObjectStore, request: Union[PushInsMessagesRequest, PushMessagesRequest]
|
74
81
|
) -> dict[str, ObjectIDs]:
|
75
82
|
"""Store Message object to descendants mapping and preregister objects."""
|
83
|
+
if not request.messages_list:
|
84
|
+
return {}
|
85
|
+
|
76
86
|
objects_to_push: dict[str, ObjectIDs] = {}
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
87
|
+
|
88
|
+
# Get run_id from the first message in the list
|
89
|
+
# All messages of a request should in the same run
|
90
|
+
run_id = request.messages_list[0].metadata.run_id
|
91
|
+
|
92
|
+
for object_tree in request.message_object_trees:
|
93
|
+
all_object_ids = [obj.object_id for obj in iterate_object_tree(object_tree)]
|
94
|
+
msg_object_id, descendant_ids = all_object_ids[-1], all_object_ids[:-1]
|
82
95
|
# Store mapping
|
83
96
|
store.set_message_descendant_ids(
|
84
|
-
msg_object_id=
|
97
|
+
msg_object_id=msg_object_id, descendant_ids=descendant_ids
|
85
98
|
)
|
86
99
|
|
87
100
|
# Preregister
|
88
|
-
object_ids_just_registered = store.preregister(
|
101
|
+
object_ids_just_registered = store.preregister(run_id, object_tree)
|
89
102
|
# Keep track of objects that need to be pushed
|
90
|
-
objects_to_push[
|
103
|
+
objects_to_push[msg_object_id] = ObjectIDs(
|
91
104
|
object_ids=object_ids_just_registered
|
92
105
|
)
|
93
106
|
|