flwr-nightly 1.19.0.dev20250610__py3-none-any.whl → 1.19.0.dev20250612__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. flwr/client/grpc_rere_client/connection.py +48 -29
  2. flwr/client/grpc_rere_client/grpc_adapter.py +8 -0
  3. flwr/client/rest_client/connection.py +138 -27
  4. flwr/common/auth_plugin/auth_plugin.py +6 -4
  5. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  6. flwr/common/inflatable.py +70 -1
  7. flwr/common/inflatable_grpc_utils.py +1 -1
  8. flwr/common/inflatable_rest_utils.py +99 -0
  9. flwr/common/serde.py +2 -0
  10. flwr/common/typing.py +5 -3
  11. flwr/proto/fleet_pb2.py +12 -16
  12. flwr/proto/fleet_pb2.pyi +4 -19
  13. flwr/proto/fleet_pb2_grpc.py +34 -0
  14. flwr/proto/fleet_pb2_grpc.pyi +13 -0
  15. flwr/proto/message_pb2.py +15 -9
  16. flwr/proto/message_pb2.pyi +41 -0
  17. flwr/proto/run_pb2.py +24 -24
  18. flwr/proto/run_pb2.pyi +4 -1
  19. flwr/proto/serverappio_pb2.py +22 -26
  20. flwr/proto/serverappio_pb2.pyi +4 -19
  21. flwr/proto/serverappio_pb2_grpc.py +34 -0
  22. flwr/proto/serverappio_pb2_grpc.pyi +13 -0
  23. flwr/server/fleet_event_log_interceptor.py +2 -2
  24. flwr/server/grid/grpc_grid.py +20 -9
  25. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -0
  26. flwr/server/superlink/fleet/message_handler/message_handler.py +33 -2
  27. flwr/server/superlink/fleet/rest_rere/rest_api.py +56 -2
  28. flwr/server/superlink/linkstate/in_memory_linkstate.py +17 -2
  29. flwr/server/superlink/linkstate/linkstate.py +6 -2
  30. flwr/server/superlink/linkstate/sqlite_linkstate.py +19 -7
  31. flwr/server/superlink/serverappio/serverappio_servicer.py +65 -29
  32. flwr/server/superlink/simulation/simulationio_servicer.py +2 -1
  33. flwr/server/superlink/utils.py +23 -10
  34. flwr/supercore/object_store/in_memory_object_store.py +160 -33
  35. flwr/supercore/object_store/object_store.py +54 -7
  36. flwr/superexec/deployment.py +6 -2
  37. flwr/superexec/exec_event_log_interceptor.py +4 -4
  38. flwr/superexec/exec_servicer.py +4 -1
  39. flwr/superexec/exec_user_auth_interceptor.py +11 -11
  40. flwr/superexec/executor.py +4 -0
  41. flwr/superexec/simulation.py +7 -1
  42. {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/METADATA +1 -1
  43. {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/RECORD +45 -44
  44. {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/WHEEL +0 -0
  45. {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/entry_points.txt +0 -0
@@ -28,7 +28,11 @@ from flwr.common.constant import (
28
28
  SUPERLINK_NODE_ID,
29
29
  )
30
30
  from flwr.common.grpc import create_channel, on_channel_state_change
31
- from flwr.common.inflatable import get_all_nested_objects
31
+ from flwr.common.inflatable import (
32
+ get_all_nested_objects,
33
+ get_object_tree,
34
+ no_object_id_recompute,
35
+ )
32
36
  from flwr.common.inflatable_grpc_utils import (
33
37
  make_pull_object_fn_grpc,
34
38
  make_push_object_fn_grpc,
@@ -43,7 +47,9 @@ from flwr.common.message import remove_content_from_message
43
47
  from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
44
48
  from flwr.common.serde import message_to_proto, run_from_proto
45
49
  from flwr.common.typing import Run
46
- from flwr.proto.message_pb2 import ObjectIDs # pylint: disable=E0611
50
+ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
51
+ ConfirmMessageReceivedRequest,
52
+ )
47
53
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
48
54
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
49
55
  from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
@@ -213,18 +219,15 @@ class GrpcGrid(Grid):
213
219
  """Push one message and its associated objects."""
214
220
  # Compute mapping of message descendants
215
221
  all_objects = get_all_nested_objects(message)
216
- all_object_ids = list(all_objects.keys())
217
- msg_id = all_object_ids[-1] # Last object is the message itself
218
- descendant_ids = all_object_ids[:-1] # All but the last object are descendants
222
+ msg_id = message.object_id
223
+ object_tree = get_object_tree(message)
219
224
 
220
225
  # Call GrpcServerAppIoStub method
221
226
  res: PushInsMessagesResponse = self._stub.PushMessages(
222
227
  PushInsMessagesRequest(
223
228
  messages_list=[message_to_proto(remove_content_from_message(message))],
224
229
  run_id=run_id,
225
- msg_to_descendant_mapping={
226
- msg_id: ObjectIDs(object_ids=descendant_ids)
227
- },
230
+ message_object_trees=[object_tree],
228
231
  )
229
232
  )
230
233
 
@@ -262,7 +265,8 @@ class GrpcGrid(Grid):
262
265
  # Check message
263
266
  self._check_message(msg)
264
267
  # Try pushing message and its objects
265
- message_ids.append(self._try_push_message(run_id, msg))
268
+ with no_object_id_recompute():
269
+ message_ids.append(self._try_push_message(run_id, msg))
266
270
 
267
271
  except grpc.RpcError as e:
268
272
  if e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: # pylint: disable=E1101
@@ -308,6 +312,13 @@ class GrpcGrid(Grid):
308
312
  run_id=run_id,
309
313
  ),
310
314
  )
315
+
316
+ # Confirm that the message has been received
317
+ self._stub.ConfirmMessageReceived(
318
+ ConfirmMessageReceivedRequest(
319
+ node=self.node, run_id=run_id, message_object_id=msg_id
320
+ )
321
+ )
311
322
  message = cast(
312
323
  Message, inflate_object_from_contents(msg_id, all_object_contents)
313
324
  )
@@ -40,6 +40,8 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
40
40
  SendNodeHeartbeatResponse,
41
41
  )
42
42
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
43
+ ConfirmMessageReceivedRequest,
44
+ ConfirmMessageReceivedResponse,
43
45
  PullObjectRequest,
44
46
  PullObjectResponse,
45
47
  PushObjectRequest,
@@ -151,6 +153,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
151
153
  res = message_handler.get_run(
152
154
  request=request,
153
155
  state=self.state_factory.state(),
156
+ store=self.objectstore_factory.store(),
154
157
  )
155
158
  except InvalidRunStatusException as e:
156
159
  abort_grpc_context(e.message, context)
@@ -167,6 +170,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
167
170
  request=request,
168
171
  ffs=self.ffs_factory.ffs(),
169
172
  state=self.state_factory.state(),
173
+ store=self.objectstore_factory.store(),
170
174
  )
171
175
  except InvalidRunStatusException as e:
172
176
  abort_grpc_context(e.message, context)
@@ -219,3 +223,24 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
219
223
  abort_grpc_context(e.message, context)
220
224
 
221
225
  return res
226
+
227
+ def ConfirmMessageReceived(
228
+ self, request: ConfirmMessageReceivedRequest, context: grpc.ServicerContext
229
+ ) -> ConfirmMessageReceivedResponse:
230
+ """Confirm message received."""
231
+ log(
232
+ DEBUG,
233
+ "[Fleet.ConfirmMessageReceived] Message with ID '%s' has been received",
234
+ request.message_object_id,
235
+ )
236
+
237
+ try:
238
+ res = message_handler.confirm_message_received(
239
+ request=request,
240
+ state=self.state_factory.state(),
241
+ store=self.objectstore_factory.store(),
242
+ )
243
+ except InvalidRunStatusException as e:
244
+ abort_grpc_context(e.message, context)
245
+
246
+ return res
@@ -44,6 +44,8 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
44
44
  SendNodeHeartbeatResponse,
45
45
  )
46
46
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
47
+ ConfirmMessageReceivedRequest,
48
+ ConfirmMessageReceivedResponse,
47
49
  ObjectIDs,
48
50
  PullObjectRequest,
49
51
  PullObjectResponse,
@@ -146,6 +148,7 @@ def push_messages(
146
148
  msg.metadata.run_id,
147
149
  [Status.PENDING, Status.STARTING, Status.FINISHED],
148
150
  state,
151
+ store,
149
152
  )
150
153
  if abort_msg:
151
154
  raise InvalidRunStatusException(abort_msg)
@@ -165,7 +168,9 @@ def push_messages(
165
168
  return response
166
169
 
167
170
 
168
- def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
171
+ def get_run(
172
+ request: GetRunRequest, state: LinkState, store: ObjectStore
173
+ ) -> GetRunResponse:
169
174
  """Get run information."""
170
175
  run = state.get_run(request.run_id)
171
176
 
@@ -177,6 +182,7 @@ def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
177
182
  request.run_id,
178
183
  [Status.PENDING, Status.STARTING, Status.FINISHED],
179
184
  state,
185
+ store,
180
186
  )
181
187
  if abort_msg:
182
188
  raise InvalidRunStatusException(abort_msg)
@@ -193,7 +199,7 @@ def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
193
199
 
194
200
 
195
201
  def get_fab(
196
- request: GetFabRequest, ffs: Ffs, state: LinkState # pylint: disable=W0613
202
+ request: GetFabRequest, ffs: Ffs, state: LinkState, store: ObjectStore
197
203
  ) -> GetFabResponse:
198
204
  """Get FAB."""
199
205
  # Abort if the run is not running
@@ -201,6 +207,7 @@ def get_fab(
201
207
  request.run_id,
202
208
  [Status.PENDING, Status.STARTING, Status.FINISHED],
203
209
  state,
210
+ store,
204
211
  )
205
212
  if abort_msg:
206
213
  raise InvalidRunStatusException(abort_msg)
@@ -220,6 +227,7 @@ def push_object(
220
227
  request.run_id,
221
228
  [Status.PENDING, Status.STARTING, Status.FINISHED],
222
229
  state,
230
+ store,
223
231
  )
224
232
  if abort_msg:
225
233
  raise InvalidRunStatusException(abort_msg)
@@ -245,6 +253,7 @@ def pull_object(
245
253
  request.run_id,
246
254
  [Status.PENDING, Status.STARTING, Status.FINISHED],
247
255
  state,
256
+ store,
248
257
  )
249
258
  if abort_msg:
250
259
  raise InvalidRunStatusException(abort_msg)
@@ -259,3 +268,25 @@ def pull_object(
259
268
  object_content=content,
260
269
  )
261
270
  return PullObjectResponse(object_found=False, object_available=False)
271
+
272
+
273
+ def confirm_message_received(
274
+ request: ConfirmMessageReceivedRequest,
275
+ state: LinkState,
276
+ store: ObjectStore,
277
+ ) -> ConfirmMessageReceivedResponse:
278
+ """Confirm message received handler."""
279
+ abort_msg = check_abort(
280
+ request.run_id,
281
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
282
+ state,
283
+ store,
284
+ )
285
+ if abort_msg:
286
+ raise InvalidRunStatusException(abort_msg)
287
+
288
+ # Delete the message object
289
+ store.delete(request.message_object_id)
290
+ store.delete_message_descendant_ids(request.message_object_id)
291
+
292
+ return ConfirmMessageReceivedResponse()
@@ -38,6 +38,14 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
38
38
  SendNodeHeartbeatRequest,
39
39
  SendNodeHeartbeatResponse,
40
40
  )
41
+ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
42
+ ConfirmMessageReceivedRequest,
43
+ ConfirmMessageReceivedResponse,
44
+ PullObjectRequest,
45
+ PullObjectResponse,
46
+ PushObjectRequest,
47
+ PushObjectResponse,
48
+ )
41
49
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
42
50
  from flwr.server.superlink.ffs.ffs import Ffs
43
51
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
@@ -131,6 +139,28 @@ async def push_message(request: PushMessagesRequest) -> PushMessagesResponse:
131
139
  return message_handler.push_messages(request=request, state=state, store=store)
132
140
 
133
141
 
142
+ @rest_request_response(PullObjectRequest)
143
+ async def pull_object(request: PullObjectRequest) -> PullObjectResponse:
144
+ """Pull PullObject."""
145
+ # Get state from app
146
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
147
+ store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
148
+
149
+ # Handle message
150
+ return message_handler.pull_object(request=request, state=state, store=store)
151
+
152
+
153
+ @rest_request_response(PushObjectRequest)
154
+ async def push_object(request: PushObjectRequest) -> PushObjectResponse:
155
+ """Pull PushObject."""
156
+ # Get state from app
157
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
158
+ store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
159
+
160
+ # Handle message
161
+ return message_handler.push_object(request=request, state=state, store=store)
162
+
163
+
134
164
  @rest_request_response(SendNodeHeartbeatRequest)
135
165
  async def send_node_heartbeat(
136
166
  request: SendNodeHeartbeatRequest,
@@ -148,9 +178,10 @@ async def get_run(request: GetRunRequest) -> GetRunResponse:
148
178
  """GetRun."""
149
179
  # Get state from app
150
180
  state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
181
+ store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
151
182
 
152
183
  # Handle message
153
- return message_handler.get_run(request=request, state=state)
184
+ return message_handler.get_run(request=request, state=state, store=store)
154
185
 
155
186
 
156
187
  @rest_request_response(GetFabRequest)
@@ -161,9 +192,25 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
161
192
 
162
193
  # Get state from app
163
194
  state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
195
+ store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
196
+
197
+ # Handle message
198
+ return message_handler.get_fab(request=request, ffs=ffs, state=state, store=store)
199
+
200
+
201
+ @rest_request_response(ConfirmMessageReceivedRequest)
202
+ async def confirm_message_received(
203
+ request: ConfirmMessageReceivedRequest,
204
+ ) -> ConfirmMessageReceivedResponse:
205
+ """Confirm message received."""
206
+ # Get state from app
207
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
208
+ store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
164
209
 
165
210
  # Handle message
166
- return message_handler.get_fab(request=request, ffs=ffs, state=state)
211
+ return message_handler.confirm_message_received(
212
+ request=request, state=state, store=store
213
+ )
167
214
 
168
215
 
169
216
  routes = [
@@ -171,9 +218,16 @@ routes = [
171
218
  Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
172
219
  Route("/api/v0/fleet/pull-messages", pull_message, methods=["POST"]),
173
220
  Route("/api/v0/fleet/push-messages", push_message, methods=["POST"]),
221
+ Route("/api/v0/fleet/pull-object", pull_object, methods=["POST"]),
222
+ Route("/api/v0/fleet/push-object", push_object, methods=["POST"]),
174
223
  Route("/api/v0/fleet/send-node-heartbeat", send_node_heartbeat, methods=["POST"]),
175
224
  Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
176
225
  Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
226
+ Route(
227
+ "/api/v0/fleet/confirm-message-received",
228
+ confirm_message_received,
229
+ methods=["POST"],
230
+ ),
177
231
  ]
178
232
 
179
233
  app: Starlette = Starlette(
@@ -18,6 +18,7 @@
18
18
  import threading
19
19
  import time
20
20
  from bisect import bisect_right
21
+ from collections import defaultdict
21
22
  from dataclasses import dataclass, field
22
23
  from logging import ERROR, WARNING
23
24
  from typing import Optional
@@ -79,6 +80,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
79
80
  self.message_res_store: dict[str, Message] = {}
80
81
  self.message_ins_id_to_message_res_id: dict[str, str] = {}
81
82
 
83
+ # Map flwr_aid to run_ids for O(1) reverse index lookup
84
+ self.flwr_aid_to_run_ids: dict[str, set[int]] = defaultdict(set)
85
+
82
86
  self.node_public_keys: set[bytes] = set()
83
87
 
84
88
  self.lock = threading.RLock()
@@ -398,6 +402,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
398
402
  fab_hash: Optional[str],
399
403
  override_config: UserConfig,
400
404
  federation_options: ConfigRecord,
405
+ flwr_aid: Optional[str],
401
406
  ) -> int:
402
407
  """Create a new run for the specified `fab_hash`."""
403
408
  # Sample a random int64 as run_id
@@ -421,9 +426,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
421
426
  sub_status="",
422
427
  details="",
423
428
  ),
429
+ flwr_aid=flwr_aid if flwr_aid else "",
424
430
  ),
425
431
  )
426
432
  self.run_ids[run_id] = run_record
433
+ # Add run_id to the flwr_aid_to_run_ids mapping if flwr_aid is provided
434
+ if flwr_aid:
435
+ self.flwr_aid_to_run_ids[flwr_aid].add(run_id)
427
436
 
428
437
  # Record federation options. Leave empty if not passed
429
438
  self.federation_options[run_id] = federation_options
@@ -451,9 +460,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
451
460
  with self.lock:
452
461
  return self.node_public_keys.copy()
453
462
 
454
- def get_run_ids(self) -> set[int]:
455
- """Retrieve all run IDs."""
463
+ def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
464
+ """Retrieve all run IDs if `flwr_aid` is not specified.
465
+
466
+ Otherwise, retrieve all run IDs for the specified `flwr_aid`.
467
+ """
456
468
  with self.lock:
469
+ if flwr_aid is not None:
470
+ # Return run IDs for the specified flwr_aid
471
+ return set(self.flwr_aid_to_run_ids.get(flwr_aid, ()))
457
472
  return set(self.run_ids.keys())
458
473
 
459
474
  def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
@@ -164,12 +164,16 @@ class LinkState(abc.ABC): # pylint: disable=R0904
164
164
  fab_hash: Optional[str],
165
165
  override_config: UserConfig,
166
166
  federation_options: ConfigRecord,
167
+ flwr_aid: Optional[str],
167
168
  ) -> int:
168
169
  """Create a new run for the specified `fab_hash`."""
169
170
 
170
171
  @abc.abstractmethod
171
- def get_run_ids(self) -> set[int]:
172
- """Retrieve all run IDs."""
172
+ def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
173
+ """Retrieve all run IDs if `flwr_aid` is not specified.
174
+
175
+ Otherwise, retrieve all run IDs for the specified `flwr_aid`.
176
+ """
173
177
 
174
178
  @abc.abstractmethod
175
179
  def get_run(self, run_id: int) -> Optional[Run]:
@@ -102,7 +102,8 @@ CREATE TABLE IF NOT EXISTS run(
102
102
  finished_at TEXT,
103
103
  sub_status TEXT,
104
104
  details TEXT,
105
- federation_options BLOB
105
+ federation_options BLOB,
106
+ flwr_aid TEXT
106
107
  );
107
108
  """
108
109
 
@@ -719,6 +720,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
719
720
  fab_hash: Optional[str],
720
721
  override_config: UserConfig,
721
722
  federation_options: ConfigRecord,
723
+ flwr_aid: Optional[str],
722
724
  ) -> int:
723
725
  """Create a new run for the specified `fab_id` and `fab_version`."""
724
726
  # Sample a random int64 as run_id
@@ -735,8 +737,8 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
735
737
  "INSERT INTO run "
736
738
  "(run_id, active_until, heartbeat_interval, fab_id, fab_version, "
737
739
  "fab_hash, override_config, federation_options, pending_at, "
738
- "starting_at, running_at, finished_at, sub_status, details) "
739
- "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
740
+ "starting_at, running_at, finished_at, sub_status, details, flwr_aid) "
741
+ "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
740
742
  )
741
743
  override_config_json = json.dumps(override_config)
742
744
  data = [
@@ -754,6 +756,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
754
756
  "",
755
757
  "",
756
758
  "",
759
+ flwr_aid or "",
757
760
  ]
758
761
  self.query(query, tuple(data))
759
762
  return uint64_run_id
@@ -782,10 +785,18 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
782
785
  result: set[bytes] = {row["public_key"] for row in rows}
783
786
  return result
784
787
 
785
- def get_run_ids(self) -> set[int]:
786
- """Retrieve all run IDs."""
787
- query = "SELECT run_id FROM run;"
788
- rows = self.query(query)
788
+ def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
789
+ """Retrieve all run IDs if `flwr_aid` is not specified.
790
+
791
+ Otherwise, retrieve all run IDs for the specified `flwr_aid`.
792
+ """
793
+ if flwr_aid:
794
+ rows = self.query(
795
+ "SELECT run_id FROM run WHERE flwr_aid = ?;",
796
+ (flwr_aid,),
797
+ )
798
+ else:
799
+ rows = self.query("SELECT run_id FROM run;", ())
789
800
  return {convert_sint64_to_uint64(row["run_id"]) for row in rows}
790
801
 
791
802
  def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
@@ -836,6 +847,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
836
847
  sub_status=row["sub_status"],
837
848
  details=row["details"],
838
849
  ),
850
+ flwr_aid=row["flwr_aid"],
839
851
  )
840
852
  log(ERROR, "`run_id` does not exist.")
841
853
  return None