flwr-nightly 1.19.0.dev20250610__py3-none-any.whl → 1.19.0.dev20250612__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/grpc_rere_client/connection.py +48 -29
- flwr/client/grpc_rere_client/grpc_adapter.py +8 -0
- flwr/client/rest_client/connection.py +138 -27
- flwr/common/auth_plugin/auth_plugin.py +6 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/inflatable.py +70 -1
- flwr/common/inflatable_grpc_utils.py +1 -1
- flwr/common/inflatable_rest_utils.py +99 -0
- flwr/common/serde.py +2 -0
- flwr/common/typing.py +5 -3
- flwr/proto/fleet_pb2.py +12 -16
- flwr/proto/fleet_pb2.pyi +4 -19
- flwr/proto/fleet_pb2_grpc.py +34 -0
- flwr/proto/fleet_pb2_grpc.pyi +13 -0
- flwr/proto/message_pb2.py +15 -9
- flwr/proto/message_pb2.pyi +41 -0
- flwr/proto/run_pb2.py +24 -24
- flwr/proto/run_pb2.pyi +4 -1
- flwr/proto/serverappio_pb2.py +22 -26
- flwr/proto/serverappio_pb2.pyi +4 -19
- flwr/proto/serverappio_pb2_grpc.py +34 -0
- flwr/proto/serverappio_pb2_grpc.pyi +13 -0
- flwr/server/fleet_event_log_interceptor.py +2 -2
- flwr/server/grid/grpc_grid.py +20 -9
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +33 -2
- flwr/server/superlink/fleet/rest_rere/rest_api.py +56 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +17 -2
- flwr/server/superlink/linkstate/linkstate.py +6 -2
- flwr/server/superlink/linkstate/sqlite_linkstate.py +19 -7
- flwr/server/superlink/serverappio/serverappio_servicer.py +65 -29
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -1
- flwr/server/superlink/utils.py +23 -10
- flwr/supercore/object_store/in_memory_object_store.py +160 -33
- flwr/supercore/object_store/object_store.py +54 -7
- flwr/superexec/deployment.py +6 -2
- flwr/superexec/exec_event_log_interceptor.py +4 -4
- flwr/superexec/exec_servicer.py +4 -1
- flwr/superexec/exec_user_auth_interceptor.py +11 -11
- flwr/superexec/executor.py +4 -0
- flwr/superexec/simulation.py +7 -1
- {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/METADATA +1 -1
- {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/RECORD +45 -44
- {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/entry_points.txt +0 -0
flwr/server/grid/grpc_grid.py
CHANGED
@@ -28,7 +28,11 @@ from flwr.common.constant import (
|
|
28
28
|
SUPERLINK_NODE_ID,
|
29
29
|
)
|
30
30
|
from flwr.common.grpc import create_channel, on_channel_state_change
|
31
|
-
from flwr.common.inflatable import
|
31
|
+
from flwr.common.inflatable import (
|
32
|
+
get_all_nested_objects,
|
33
|
+
get_object_tree,
|
34
|
+
no_object_id_recompute,
|
35
|
+
)
|
32
36
|
from flwr.common.inflatable_grpc_utils import (
|
33
37
|
make_pull_object_fn_grpc,
|
34
38
|
make_push_object_fn_grpc,
|
@@ -43,7 +47,9 @@ from flwr.common.message import remove_content_from_message
|
|
43
47
|
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
|
44
48
|
from flwr.common.serde import message_to_proto, run_from_proto
|
45
49
|
from flwr.common.typing import Run
|
46
|
-
from flwr.proto.message_pb2 import
|
50
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
51
|
+
ConfirmMessageReceivedRequest,
|
52
|
+
)
|
47
53
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
48
54
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
49
55
|
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
@@ -213,18 +219,15 @@ class GrpcGrid(Grid):
|
|
213
219
|
"""Push one message and its associated objects."""
|
214
220
|
# Compute mapping of message descendants
|
215
221
|
all_objects = get_all_nested_objects(message)
|
216
|
-
|
217
|
-
|
218
|
-
descendant_ids = all_object_ids[:-1] # All but the last object are descendants
|
222
|
+
msg_id = message.object_id
|
223
|
+
object_tree = get_object_tree(message)
|
219
224
|
|
220
225
|
# Call GrpcServerAppIoStub method
|
221
226
|
res: PushInsMessagesResponse = self._stub.PushMessages(
|
222
227
|
PushInsMessagesRequest(
|
223
228
|
messages_list=[message_to_proto(remove_content_from_message(message))],
|
224
229
|
run_id=run_id,
|
225
|
-
|
226
|
-
msg_id: ObjectIDs(object_ids=descendant_ids)
|
227
|
-
},
|
230
|
+
message_object_trees=[object_tree],
|
228
231
|
)
|
229
232
|
)
|
230
233
|
|
@@ -262,7 +265,8 @@ class GrpcGrid(Grid):
|
|
262
265
|
# Check message
|
263
266
|
self._check_message(msg)
|
264
267
|
# Try pushing message and its objects
|
265
|
-
|
268
|
+
with no_object_id_recompute():
|
269
|
+
message_ids.append(self._try_push_message(run_id, msg))
|
266
270
|
|
267
271
|
except grpc.RpcError as e:
|
268
272
|
if e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: # pylint: disable=E1101
|
@@ -308,6 +312,13 @@ class GrpcGrid(Grid):
|
|
308
312
|
run_id=run_id,
|
309
313
|
),
|
310
314
|
)
|
315
|
+
|
316
|
+
# Confirm that the message has been received
|
317
|
+
self._stub.ConfirmMessageReceived(
|
318
|
+
ConfirmMessageReceivedRequest(
|
319
|
+
node=self.node, run_id=run_id, message_object_id=msg_id
|
320
|
+
)
|
321
|
+
)
|
311
322
|
message = cast(
|
312
323
|
Message, inflate_object_from_contents(msg_id, all_object_contents)
|
313
324
|
)
|
@@ -40,6 +40,8 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
40
40
|
SendNodeHeartbeatResponse,
|
41
41
|
)
|
42
42
|
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
43
|
+
ConfirmMessageReceivedRequest,
|
44
|
+
ConfirmMessageReceivedResponse,
|
43
45
|
PullObjectRequest,
|
44
46
|
PullObjectResponse,
|
45
47
|
PushObjectRequest,
|
@@ -151,6 +153,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
151
153
|
res = message_handler.get_run(
|
152
154
|
request=request,
|
153
155
|
state=self.state_factory.state(),
|
156
|
+
store=self.objectstore_factory.store(),
|
154
157
|
)
|
155
158
|
except InvalidRunStatusException as e:
|
156
159
|
abort_grpc_context(e.message, context)
|
@@ -167,6 +170,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
167
170
|
request=request,
|
168
171
|
ffs=self.ffs_factory.ffs(),
|
169
172
|
state=self.state_factory.state(),
|
173
|
+
store=self.objectstore_factory.store(),
|
170
174
|
)
|
171
175
|
except InvalidRunStatusException as e:
|
172
176
|
abort_grpc_context(e.message, context)
|
@@ -219,3 +223,24 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
219
223
|
abort_grpc_context(e.message, context)
|
220
224
|
|
221
225
|
return res
|
226
|
+
|
227
|
+
def ConfirmMessageReceived(
|
228
|
+
self, request: ConfirmMessageReceivedRequest, context: grpc.ServicerContext
|
229
|
+
) -> ConfirmMessageReceivedResponse:
|
230
|
+
"""Confirm message received."""
|
231
|
+
log(
|
232
|
+
DEBUG,
|
233
|
+
"[Fleet.ConfirmMessageReceived] Message with ID '%s' has been received",
|
234
|
+
request.message_object_id,
|
235
|
+
)
|
236
|
+
|
237
|
+
try:
|
238
|
+
res = message_handler.confirm_message_received(
|
239
|
+
request=request,
|
240
|
+
state=self.state_factory.state(),
|
241
|
+
store=self.objectstore_factory.store(),
|
242
|
+
)
|
243
|
+
except InvalidRunStatusException as e:
|
244
|
+
abort_grpc_context(e.message, context)
|
245
|
+
|
246
|
+
return res
|
@@ -44,6 +44,8 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
44
44
|
SendNodeHeartbeatResponse,
|
45
45
|
)
|
46
46
|
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
47
|
+
ConfirmMessageReceivedRequest,
|
48
|
+
ConfirmMessageReceivedResponse,
|
47
49
|
ObjectIDs,
|
48
50
|
PullObjectRequest,
|
49
51
|
PullObjectResponse,
|
@@ -146,6 +148,7 @@ def push_messages(
|
|
146
148
|
msg.metadata.run_id,
|
147
149
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
148
150
|
state,
|
151
|
+
store,
|
149
152
|
)
|
150
153
|
if abort_msg:
|
151
154
|
raise InvalidRunStatusException(abort_msg)
|
@@ -165,7 +168,9 @@ def push_messages(
|
|
165
168
|
return response
|
166
169
|
|
167
170
|
|
168
|
-
def get_run(
|
171
|
+
def get_run(
|
172
|
+
request: GetRunRequest, state: LinkState, store: ObjectStore
|
173
|
+
) -> GetRunResponse:
|
169
174
|
"""Get run information."""
|
170
175
|
run = state.get_run(request.run_id)
|
171
176
|
|
@@ -177,6 +182,7 @@ def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
|
|
177
182
|
request.run_id,
|
178
183
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
179
184
|
state,
|
185
|
+
store,
|
180
186
|
)
|
181
187
|
if abort_msg:
|
182
188
|
raise InvalidRunStatusException(abort_msg)
|
@@ -193,7 +199,7 @@ def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
|
|
193
199
|
|
194
200
|
|
195
201
|
def get_fab(
|
196
|
-
request: GetFabRequest, ffs: Ffs, state: LinkState
|
202
|
+
request: GetFabRequest, ffs: Ffs, state: LinkState, store: ObjectStore
|
197
203
|
) -> GetFabResponse:
|
198
204
|
"""Get FAB."""
|
199
205
|
# Abort if the run is not running
|
@@ -201,6 +207,7 @@ def get_fab(
|
|
201
207
|
request.run_id,
|
202
208
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
203
209
|
state,
|
210
|
+
store,
|
204
211
|
)
|
205
212
|
if abort_msg:
|
206
213
|
raise InvalidRunStatusException(abort_msg)
|
@@ -220,6 +227,7 @@ def push_object(
|
|
220
227
|
request.run_id,
|
221
228
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
222
229
|
state,
|
230
|
+
store,
|
223
231
|
)
|
224
232
|
if abort_msg:
|
225
233
|
raise InvalidRunStatusException(abort_msg)
|
@@ -245,6 +253,7 @@ def pull_object(
|
|
245
253
|
request.run_id,
|
246
254
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
247
255
|
state,
|
256
|
+
store,
|
248
257
|
)
|
249
258
|
if abort_msg:
|
250
259
|
raise InvalidRunStatusException(abort_msg)
|
@@ -259,3 +268,25 @@ def pull_object(
|
|
259
268
|
object_content=content,
|
260
269
|
)
|
261
270
|
return PullObjectResponse(object_found=False, object_available=False)
|
271
|
+
|
272
|
+
|
273
|
+
def confirm_message_received(
|
274
|
+
request: ConfirmMessageReceivedRequest,
|
275
|
+
state: LinkState,
|
276
|
+
store: ObjectStore,
|
277
|
+
) -> ConfirmMessageReceivedResponse:
|
278
|
+
"""Confirm message received handler."""
|
279
|
+
abort_msg = check_abort(
|
280
|
+
request.run_id,
|
281
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
282
|
+
state,
|
283
|
+
store,
|
284
|
+
)
|
285
|
+
if abort_msg:
|
286
|
+
raise InvalidRunStatusException(abort_msg)
|
287
|
+
|
288
|
+
# Delete the message object
|
289
|
+
store.delete(request.message_object_id)
|
290
|
+
store.delete_message_descendant_ids(request.message_object_id)
|
291
|
+
|
292
|
+
return ConfirmMessageReceivedResponse()
|
@@ -38,6 +38,14 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
38
38
|
SendNodeHeartbeatRequest,
|
39
39
|
SendNodeHeartbeatResponse,
|
40
40
|
)
|
41
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
42
|
+
ConfirmMessageReceivedRequest,
|
43
|
+
ConfirmMessageReceivedResponse,
|
44
|
+
PullObjectRequest,
|
45
|
+
PullObjectResponse,
|
46
|
+
PushObjectRequest,
|
47
|
+
PushObjectResponse,
|
48
|
+
)
|
41
49
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
42
50
|
from flwr.server.superlink.ffs.ffs import Ffs
|
43
51
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
@@ -131,6 +139,28 @@ async def push_message(request: PushMessagesRequest) -> PushMessagesResponse:
|
|
131
139
|
return message_handler.push_messages(request=request, state=state, store=store)
|
132
140
|
|
133
141
|
|
142
|
+
@rest_request_response(PullObjectRequest)
|
143
|
+
async def pull_object(request: PullObjectRequest) -> PullObjectResponse:
|
144
|
+
"""Pull PullObject."""
|
145
|
+
# Get state from app
|
146
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
147
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
148
|
+
|
149
|
+
# Handle message
|
150
|
+
return message_handler.pull_object(request=request, state=state, store=store)
|
151
|
+
|
152
|
+
|
153
|
+
@rest_request_response(PushObjectRequest)
|
154
|
+
async def push_object(request: PushObjectRequest) -> PushObjectResponse:
|
155
|
+
"""Pull PushObject."""
|
156
|
+
# Get state from app
|
157
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
158
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
159
|
+
|
160
|
+
# Handle message
|
161
|
+
return message_handler.push_object(request=request, state=state, store=store)
|
162
|
+
|
163
|
+
|
134
164
|
@rest_request_response(SendNodeHeartbeatRequest)
|
135
165
|
async def send_node_heartbeat(
|
136
166
|
request: SendNodeHeartbeatRequest,
|
@@ -148,9 +178,10 @@ async def get_run(request: GetRunRequest) -> GetRunResponse:
|
|
148
178
|
"""GetRun."""
|
149
179
|
# Get state from app
|
150
180
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
181
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
151
182
|
|
152
183
|
# Handle message
|
153
|
-
return message_handler.get_run(request=request, state=state)
|
184
|
+
return message_handler.get_run(request=request, state=state, store=store)
|
154
185
|
|
155
186
|
|
156
187
|
@rest_request_response(GetFabRequest)
|
@@ -161,9 +192,25 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
|
|
161
192
|
|
162
193
|
# Get state from app
|
163
194
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
195
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
196
|
+
|
197
|
+
# Handle message
|
198
|
+
return message_handler.get_fab(request=request, ffs=ffs, state=state, store=store)
|
199
|
+
|
200
|
+
|
201
|
+
@rest_request_response(ConfirmMessageReceivedRequest)
|
202
|
+
async def confirm_message_received(
|
203
|
+
request: ConfirmMessageReceivedRequest,
|
204
|
+
) -> ConfirmMessageReceivedResponse:
|
205
|
+
"""Confirm message received."""
|
206
|
+
# Get state from app
|
207
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
208
|
+
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
|
164
209
|
|
165
210
|
# Handle message
|
166
|
-
return message_handler.
|
211
|
+
return message_handler.confirm_message_received(
|
212
|
+
request=request, state=state, store=store
|
213
|
+
)
|
167
214
|
|
168
215
|
|
169
216
|
routes = [
|
@@ -171,9 +218,16 @@ routes = [
|
|
171
218
|
Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
|
172
219
|
Route("/api/v0/fleet/pull-messages", pull_message, methods=["POST"]),
|
173
220
|
Route("/api/v0/fleet/push-messages", push_message, methods=["POST"]),
|
221
|
+
Route("/api/v0/fleet/pull-object", pull_object, methods=["POST"]),
|
222
|
+
Route("/api/v0/fleet/push-object", push_object, methods=["POST"]),
|
174
223
|
Route("/api/v0/fleet/send-node-heartbeat", send_node_heartbeat, methods=["POST"]),
|
175
224
|
Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
|
176
225
|
Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
|
226
|
+
Route(
|
227
|
+
"/api/v0/fleet/confirm-message-received",
|
228
|
+
confirm_message_received,
|
229
|
+
methods=["POST"],
|
230
|
+
),
|
177
231
|
]
|
178
232
|
|
179
233
|
app: Starlette = Starlette(
|
@@ -18,6 +18,7 @@
|
|
18
18
|
import threading
|
19
19
|
import time
|
20
20
|
from bisect import bisect_right
|
21
|
+
from collections import defaultdict
|
21
22
|
from dataclasses import dataclass, field
|
22
23
|
from logging import ERROR, WARNING
|
23
24
|
from typing import Optional
|
@@ -79,6 +80,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
79
80
|
self.message_res_store: dict[str, Message] = {}
|
80
81
|
self.message_ins_id_to_message_res_id: dict[str, str] = {}
|
81
82
|
|
83
|
+
# Map flwr_aid to run_ids for O(1) reverse index lookup
|
84
|
+
self.flwr_aid_to_run_ids: dict[str, set[int]] = defaultdict(set)
|
85
|
+
|
82
86
|
self.node_public_keys: set[bytes] = set()
|
83
87
|
|
84
88
|
self.lock = threading.RLock()
|
@@ -398,6 +402,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
398
402
|
fab_hash: Optional[str],
|
399
403
|
override_config: UserConfig,
|
400
404
|
federation_options: ConfigRecord,
|
405
|
+
flwr_aid: Optional[str],
|
401
406
|
) -> int:
|
402
407
|
"""Create a new run for the specified `fab_hash`."""
|
403
408
|
# Sample a random int64 as run_id
|
@@ -421,9 +426,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
421
426
|
sub_status="",
|
422
427
|
details="",
|
423
428
|
),
|
429
|
+
flwr_aid=flwr_aid if flwr_aid else "",
|
424
430
|
),
|
425
431
|
)
|
426
432
|
self.run_ids[run_id] = run_record
|
433
|
+
# Add run_id to the flwr_aid_to_run_ids mapping if flwr_aid is provided
|
434
|
+
if flwr_aid:
|
435
|
+
self.flwr_aid_to_run_ids[flwr_aid].add(run_id)
|
427
436
|
|
428
437
|
# Record federation options. Leave empty if not passed
|
429
438
|
self.federation_options[run_id] = federation_options
|
@@ -451,9 +460,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
451
460
|
with self.lock:
|
452
461
|
return self.node_public_keys.copy()
|
453
462
|
|
454
|
-
def get_run_ids(self) -> set[int]:
|
455
|
-
"""Retrieve all run IDs.
|
463
|
+
def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
|
464
|
+
"""Retrieve all run IDs if `flwr_aid` is not specified.
|
465
|
+
|
466
|
+
Otherwise, retrieve all run IDs for the specified `flwr_aid`.
|
467
|
+
"""
|
456
468
|
with self.lock:
|
469
|
+
if flwr_aid is not None:
|
470
|
+
# Return run IDs for the specified flwr_aid
|
471
|
+
return set(self.flwr_aid_to_run_ids.get(flwr_aid, ()))
|
457
472
|
return set(self.run_ids.keys())
|
458
473
|
|
459
474
|
def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
|
@@ -164,12 +164,16 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
164
164
|
fab_hash: Optional[str],
|
165
165
|
override_config: UserConfig,
|
166
166
|
federation_options: ConfigRecord,
|
167
|
+
flwr_aid: Optional[str],
|
167
168
|
) -> int:
|
168
169
|
"""Create a new run for the specified `fab_hash`."""
|
169
170
|
|
170
171
|
@abc.abstractmethod
|
171
|
-
def get_run_ids(self) -> set[int]:
|
172
|
-
"""Retrieve all run IDs.
|
172
|
+
def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
|
173
|
+
"""Retrieve all run IDs if `flwr_aid` is not specified.
|
174
|
+
|
175
|
+
Otherwise, retrieve all run IDs for the specified `flwr_aid`.
|
176
|
+
"""
|
173
177
|
|
174
178
|
@abc.abstractmethod
|
175
179
|
def get_run(self, run_id: int) -> Optional[Run]:
|
@@ -102,7 +102,8 @@ CREATE TABLE IF NOT EXISTS run(
|
|
102
102
|
finished_at TEXT,
|
103
103
|
sub_status TEXT,
|
104
104
|
details TEXT,
|
105
|
-
federation_options BLOB
|
105
|
+
federation_options BLOB,
|
106
|
+
flwr_aid TEXT
|
106
107
|
);
|
107
108
|
"""
|
108
109
|
|
@@ -719,6 +720,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
719
720
|
fab_hash: Optional[str],
|
720
721
|
override_config: UserConfig,
|
721
722
|
federation_options: ConfigRecord,
|
723
|
+
flwr_aid: Optional[str],
|
722
724
|
) -> int:
|
723
725
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
724
726
|
# Sample a random int64 as run_id
|
@@ -735,8 +737,8 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
735
737
|
"INSERT INTO run "
|
736
738
|
"(run_id, active_until, heartbeat_interval, fab_id, fab_version, "
|
737
739
|
"fab_hash, override_config, federation_options, pending_at, "
|
738
|
-
"starting_at, running_at, finished_at, sub_status, details) "
|
739
|
-
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
|
740
|
+
"starting_at, running_at, finished_at, sub_status, details, flwr_aid) "
|
741
|
+
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
|
740
742
|
)
|
741
743
|
override_config_json = json.dumps(override_config)
|
742
744
|
data = [
|
@@ -754,6 +756,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
754
756
|
"",
|
755
757
|
"",
|
756
758
|
"",
|
759
|
+
flwr_aid or "",
|
757
760
|
]
|
758
761
|
self.query(query, tuple(data))
|
759
762
|
return uint64_run_id
|
@@ -782,10 +785,18 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
782
785
|
result: set[bytes] = {row["public_key"] for row in rows}
|
783
786
|
return result
|
784
787
|
|
785
|
-
def get_run_ids(self) -> set[int]:
|
786
|
-
"""Retrieve all run IDs.
|
787
|
-
|
788
|
-
|
788
|
+
def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
|
789
|
+
"""Retrieve all run IDs if `flwr_aid` is not specified.
|
790
|
+
|
791
|
+
Otherwise, retrieve all run IDs for the specified `flwr_aid`.
|
792
|
+
"""
|
793
|
+
if flwr_aid:
|
794
|
+
rows = self.query(
|
795
|
+
"SELECT run_id FROM run WHERE flwr_aid = ?;",
|
796
|
+
(flwr_aid,),
|
797
|
+
)
|
798
|
+
else:
|
799
|
+
rows = self.query("SELECT run_id FROM run;", ())
|
789
800
|
return {convert_sint64_to_uint64(row["run_id"]) for row in rows}
|
790
801
|
|
791
802
|
def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
|
@@ -836,6 +847,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
836
847
|
sub_status=row["sub_status"],
|
837
848
|
details=row["details"],
|
838
849
|
),
|
850
|
+
flwr_aid=row["flwr_aid"],
|
839
851
|
)
|
840
852
|
log(ERROR, "`run_id` does not exist.")
|
841
853
|
return None
|