flwr-nightly 1.13.0.dev20241105__py3-none-any.whl → 1.13.0.dev20241107__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. See the package's registry page for more details.

Files changed (42):
  1. flwr/cli/run/run.py +16 -5
  2. flwr/client/app.py +10 -6
  3. flwr/client/nodestate/__init__.py +25 -0
  4. flwr/client/nodestate/in_memory_nodestate.py +38 -0
  5. flwr/client/nodestate/nodestate.py +30 -0
  6. flwr/client/nodestate/nodestate_factory.py +37 -0
  7. flwr/client/run_info_store.py +1 -0
  8. flwr/common/config.py +10 -0
  9. flwr/common/constant.py +1 -1
  10. flwr/common/context.py +9 -4
  11. flwr/common/object_ref.py +40 -33
  12. flwr/common/serde.py +2 -0
  13. flwr/proto/exec_pb2.py +14 -17
  14. flwr/proto/exec_pb2.pyi +6 -20
  15. flwr/proto/message_pb2.py +8 -8
  16. flwr/proto/message_pb2.pyi +4 -1
  17. flwr/server/app.py +140 -107
  18. flwr/server/driver/driver.py +1 -1
  19. flwr/server/driver/grpc_driver.py +2 -6
  20. flwr/server/driver/inmemory_driver.py +1 -3
  21. flwr/server/run_serverapp.py +5 -2
  22. flwr/server/serverapp/app.py +1 -1
  23. flwr/server/superlink/driver/serverappio_servicer.py +2 -0
  24. flwr/server/superlink/linkstate/in_memory_linkstate.py +15 -16
  25. flwr/server/superlink/linkstate/linkstate.py +18 -1
  26. flwr/server/superlink/linkstate/sqlite_linkstate.py +41 -21
  27. flwr/server/superlink/linkstate/utils.py +14 -30
  28. flwr/server/superlink/simulation/__init__.py +15 -0
  29. flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
  30. flwr/server/superlink/simulation/simulationio_servicer.py +132 -0
  31. flwr/simulation/__init__.py +2 -0
  32. flwr/simulation/run_simulation.py +4 -1
  33. flwr/simulation/simulationio_connection.py +86 -0
  34. flwr/superexec/deployment.py +8 -4
  35. flwr/superexec/exec_servicer.py +2 -2
  36. flwr/superexec/executor.py +4 -3
  37. flwr/superexec/simulation.py +8 -8
  38. {flwr_nightly-1.13.0.dev20241105.dist-info → flwr_nightly-1.13.0.dev20241107.dist-info}/METADATA +1 -1
  39. {flwr_nightly-1.13.0.dev20241105.dist-info → flwr_nightly-1.13.0.dev20241107.dist-info}/RECORD +42 -34
  40. {flwr_nightly-1.13.0.dev20241105.dist-info → flwr_nightly-1.13.0.dev20241107.dist-info}/LICENSE +0 -0
  41. {flwr_nightly-1.13.0.dev20241105.dist-info → flwr_nightly-1.13.0.dev20241107.dist-info}/WHEEL +0 -0
  42. {flwr_nightly-1.13.0.dev20241105.dist-info → flwr_nightly-1.13.0.dev20241107.dist-info}/entry_points.txt +0 -0
flwr/server/app.py CHANGED
@@ -47,6 +47,7 @@ from flwr.common.constant import (
47
47
  ISOLATION_MODE_SUBPROCESS,
48
48
  MISSING_EXTRA_REST,
49
49
  SERVERAPPIO_API_DEFAULT_ADDRESS,
50
+ SIMULATIONIO_API_DEFAULT_ADDRESS,
50
51
  TRANSPORT_TYPE_GRPC_ADAPTER,
51
52
  TRANSPORT_TYPE_GRPC_RERE,
52
53
  TRANSPORT_TYPE_REST,
@@ -63,6 +64,7 @@ from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
63
64
  from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
64
65
  from flwr.superexec.app import load_executor
65
66
  from flwr.superexec.exec_grpc import run_exec_api_grpc
67
+ from flwr.superexec.simulation import SimulationEngine
66
68
 
67
69
  from .client_manager import ClientManager
68
70
  from .history import History
@@ -79,6 +81,7 @@ from .superlink.fleet.grpc_bidi.grpc_server import (
79
81
  from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
80
82
  from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
81
83
  from .superlink.linkstate import LinkStateFactory
84
+ from .superlink.simulation.simulationio_grpc import run_simulationio_api_grpc
82
85
 
83
86
  DATABASE = ":flwr-in-memory-state:"
84
87
  BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
@@ -215,6 +218,7 @@ def run_superlink() -> None:
215
218
  # Parse IP addresses
216
219
  serverappio_address, _, _ = _format_address(args.serverappio_api_address)
217
220
  exec_address, _, _ = _format_address(args.exec_api_address)
221
+ simulationio_address, _, _ = _format_address(args.simulationio_api_address)
218
222
 
219
223
  # Obtain certificates
220
224
  certificates = _try_obtain_certificates(args)
@@ -225,128 +229,148 @@ def run_superlink() -> None:
225
229
  # Initialize FfsFactory
226
230
  ffs_factory = FfsFactory(args.storage_dir)
227
231
 
228
- # Start ServerAppIo API
229
- serverappio_server: grpc.Server = run_serverappio_api_grpc(
230
- address=serverappio_address,
232
+ # Start Exec API
233
+ executor = load_executor(args)
234
+ exec_server: grpc.Server = run_exec_api_grpc(
235
+ address=exec_address,
231
236
  state_factory=state_factory,
232
237
  ffs_factory=ffs_factory,
238
+ executor=executor,
233
239
  certificates=certificates,
240
+ config=parse_config_args(
241
+ [args.executor_config] if args.executor_config else args.executor_config
242
+ ),
234
243
  )
235
- grpc_servers = [serverappio_server]
244
+ grpc_servers = [exec_server]
236
245
 
237
- # Start Fleet API
238
- bckg_threads = []
239
- if not args.fleet_api_address:
240
- if args.fleet_api_type in [
241
- TRANSPORT_TYPE_GRPC_RERE,
242
- TRANSPORT_TYPE_GRPC_ADAPTER,
243
- ]:
244
- args.fleet_api_address = FLEET_API_GRPC_RERE_DEFAULT_ADDRESS
245
- elif args.fleet_api_type == TRANSPORT_TYPE_REST:
246
- args.fleet_api_address = FLEET_API_REST_DEFAULT_ADDRESS
247
-
248
- fleet_address, host, port = _format_address(args.fleet_api_address)
249
-
250
- num_workers = args.fleet_api_num_workers
251
- if num_workers != 1:
252
- log(
253
- WARN,
254
- "The Fleet API currently supports only 1 worker. "
255
- "You have specified %d workers. "
256
- "Support for multiple workers will be added in future releases. "
257
- "Proceeding with a single worker.",
258
- args.fleet_api_num_workers,
259
- )
260
- num_workers = 1
246
+ # Determine Exec plugin
247
+ # If simulation is used, don't start ServerAppIo and Fleet APIs
248
+ sim_exec = isinstance(executor, SimulationEngine)
261
249
 
262
- if args.fleet_api_type == TRANSPORT_TYPE_REST:
263
- if (
264
- importlib.util.find_spec("requests")
265
- and importlib.util.find_spec("starlette")
266
- and importlib.util.find_spec("uvicorn")
267
- ) is None:
268
- sys.exit(MISSING_EXTRA_REST)
269
-
270
- _, ssl_certfile, ssl_keyfile = (
271
- certificates if certificates is not None else (None, None, None)
272
- )
273
-
274
- fleet_thread = threading.Thread(
275
- target=_run_fleet_api_rest,
276
- args=(
277
- host,
278
- port,
279
- ssl_keyfile,
280
- ssl_certfile,
281
- state_factory,
282
- ffs_factory,
283
- num_workers,
284
- ),
285
- )
286
- fleet_thread.start()
287
- bckg_threads.append(fleet_thread)
288
- elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
289
- maybe_keys = _try_setup_node_authentication(args, certificates)
290
- interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
291
- if maybe_keys is not None:
292
- (
293
- node_public_keys,
294
- server_private_key,
295
- server_public_key,
296
- ) = maybe_keys
297
- state = state_factory.state()
298
- state.store_node_public_keys(node_public_keys)
299
- state.store_server_private_public_key(
300
- private_key_to_bytes(server_private_key),
301
- public_key_to_bytes(server_public_key),
302
- )
303
- log(
304
- INFO,
305
- "Node authentication enabled with %d known public keys",
306
- len(node_public_keys),
307
- )
308
- interceptors = [AuthenticateServerInterceptor(state)]
250
+ bckg_threads = []
309
251
 
310
- fleet_server = _run_fleet_api_grpc_rere(
311
- address=fleet_address,
252
+ if sim_exec:
253
+ simulationio_server: grpc.Server = run_simulationio_api_grpc(
254
+ address=simulationio_address,
312
255
  state_factory=state_factory,
313
256
  ffs_factory=ffs_factory,
314
257
  certificates=certificates,
315
- interceptors=interceptors,
316
258
  )
317
- grpc_servers.append(fleet_server)
318
- elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER:
319
- fleet_server = _run_fleet_api_grpc_adapter(
320
- address=fleet_address,
259
+ grpc_servers.append(simulationio_server)
260
+
261
+ else:
262
+ # Start ServerAppIo API
263
+ serverappio_server: grpc.Server = run_serverappio_api_grpc(
264
+ address=serverappio_address,
321
265
  state_factory=state_factory,
322
266
  ffs_factory=ffs_factory,
323
267
  certificates=certificates,
324
268
  )
325
- grpc_servers.append(fleet_server)
326
- else:
327
- raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
328
-
329
- # Start Exec API
330
- exec_server: grpc.Server = run_exec_api_grpc(
331
- address=exec_address,
332
- state_factory=state_factory,
333
- ffs_factory=ffs_factory,
334
- executor=load_executor(args),
335
- certificates=certificates,
336
- config=parse_config_args(
337
- [args.executor_config] if args.executor_config else args.executor_config
338
- ),
339
- )
340
- grpc_servers.append(exec_server)
269
+ grpc_servers.append(serverappio_server)
270
+
271
+ # Start Fleet API
272
+ if not args.fleet_api_address:
273
+ if args.fleet_api_type in [
274
+ TRANSPORT_TYPE_GRPC_RERE,
275
+ TRANSPORT_TYPE_GRPC_ADAPTER,
276
+ ]:
277
+ args.fleet_api_address = FLEET_API_GRPC_RERE_DEFAULT_ADDRESS
278
+ elif args.fleet_api_type == TRANSPORT_TYPE_REST:
279
+ args.fleet_api_address = FLEET_API_REST_DEFAULT_ADDRESS
280
+
281
+ fleet_address, host, port = _format_address(args.fleet_api_address)
282
+
283
+ num_workers = args.fleet_api_num_workers
284
+ if num_workers != 1:
285
+ log(
286
+ WARN,
287
+ "The Fleet API currently supports only 1 worker. "
288
+ "You have specified %d workers. "
289
+ "Support for multiple workers will be added in future releases. "
290
+ "Proceeding with a single worker.",
291
+ args.fleet_api_num_workers,
292
+ )
293
+ num_workers = 1
294
+
295
+ if args.fleet_api_type == TRANSPORT_TYPE_REST:
296
+ if (
297
+ importlib.util.find_spec("requests")
298
+ and importlib.util.find_spec("starlette")
299
+ and importlib.util.find_spec("uvicorn")
300
+ ) is None:
301
+ sys.exit(MISSING_EXTRA_REST)
302
+
303
+ _, ssl_certfile, ssl_keyfile = (
304
+ certificates if certificates is not None else (None, None, None)
305
+ )
341
306
 
342
- if args.isolation == ISOLATION_MODE_SUBPROCESS:
343
- # Scheduler thread
344
- scheduler_th = threading.Thread(
345
- target=_flwr_serverapp_scheduler,
346
- args=(state_factory, args.serverappio_api_address, args.ssl_ca_certfile),
347
- )
348
- scheduler_th.start()
349
- bckg_threads.append(scheduler_th)
307
+ fleet_thread = threading.Thread(
308
+ target=_run_fleet_api_rest,
309
+ args=(
310
+ host,
311
+ port,
312
+ ssl_keyfile,
313
+ ssl_certfile,
314
+ state_factory,
315
+ ffs_factory,
316
+ num_workers,
317
+ ),
318
+ )
319
+ fleet_thread.start()
320
+ bckg_threads.append(fleet_thread)
321
+ elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
322
+ maybe_keys = _try_setup_node_authentication(args, certificates)
323
+ interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
324
+ if maybe_keys is not None:
325
+ (
326
+ node_public_keys,
327
+ server_private_key,
328
+ server_public_key,
329
+ ) = maybe_keys
330
+ state = state_factory.state()
331
+ state.store_node_public_keys(node_public_keys)
332
+ state.store_server_private_public_key(
333
+ private_key_to_bytes(server_private_key),
334
+ public_key_to_bytes(server_public_key),
335
+ )
336
+ log(
337
+ INFO,
338
+ "Node authentication enabled with %d known public keys",
339
+ len(node_public_keys),
340
+ )
341
+ interceptors = [AuthenticateServerInterceptor(state)]
342
+
343
+ fleet_server = _run_fleet_api_grpc_rere(
344
+ address=fleet_address,
345
+ state_factory=state_factory,
346
+ ffs_factory=ffs_factory,
347
+ certificates=certificates,
348
+ interceptors=interceptors,
349
+ )
350
+ grpc_servers.append(fleet_server)
351
+ elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER:
352
+ fleet_server = _run_fleet_api_grpc_adapter(
353
+ address=fleet_address,
354
+ state_factory=state_factory,
355
+ ffs_factory=ffs_factory,
356
+ certificates=certificates,
357
+ )
358
+ grpc_servers.append(fleet_server)
359
+ else:
360
+ raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
361
+
362
+ if args.isolation == ISOLATION_MODE_SUBPROCESS:
363
+ # Scheduler thread
364
+ scheduler_th = threading.Thread(
365
+ target=_flwr_serverapp_scheduler,
366
+ args=(
367
+ state_factory,
368
+ args.serverappio_api_address,
369
+ args.ssl_ca_certfile,
370
+ ),
371
+ )
372
+ scheduler_th.start()
373
+ bckg_threads.append(scheduler_th)
350
374
 
351
375
  # Graceful shutdown
352
376
  register_exit_handlers(
@@ -361,7 +385,7 @@ def run_superlink() -> None:
361
385
  for thread in bckg_threads:
362
386
  if not thread.is_alive():
363
387
  sys.exit(1)
364
- serverappio_server.wait_for_termination(timeout=1)
388
+ exec_server.wait_for_termination(timeout=1)
365
389
 
366
390
 
367
391
  def _flwr_serverapp_scheduler(
@@ -657,6 +681,7 @@ def _parse_args_run_superlink() -> argparse.ArgumentParser:
657
681
  _add_args_serverappio_api(parser=parser)
658
682
  _add_args_fleet_api(parser=parser)
659
683
  _add_args_exec_api(parser=parser)
684
+ _add_args_simulationio_api(parser=parser)
660
685
 
661
686
  return parser
662
687
 
@@ -790,3 +815,11 @@ def _add_args_exec_api(parser: argparse.ArgumentParser) -> None:
790
815
  "For example:\n\n`--executor-config 'verbose=true "
791
816
  'root-certificates="certificates/superlink-ca.crt"\'`',
792
817
  )
818
+
819
+
820
+ def _add_args_simulationio_api(parser: argparse.ArgumentParser) -> None:
821
+ parser.add_argument(
822
+ "--simulationio-api-address",
823
+ help="SimulationIo API (gRPC) server address (IPv4, IPv6, or a domain name).",
824
+ default=SIMULATIONIO_API_DEFAULT_ADDRESS,
825
+ )
flwr/server/driver/driver.py CHANGED
@@ -27,7 +27,7 @@ class Driver(ABC):
27
27
  """Abstract base Driver class for the ServerAppIo API."""
28
28
 
29
29
  @abstractmethod
30
- def init_run(self, run_id: int) -> None:
30
+ def set_run(self, run_id: int) -> None:
31
31
  """Request a run to the SuperLink with a given `run_id`.
32
32
 
33
33
  If a Run with the specified `run_id` exists, a local Run
flwr/server/driver/grpc_driver.py CHANGED
@@ -112,12 +112,8 @@ class GrpcDriver(Driver):
112
112
  channel.close()
113
113
  log(DEBUG, "[Driver] Disconnected")
114
114
 
115
- def init_run(self, run_id: int) -> None:
116
- """Initialize the run."""
117
- # Check if is initialized
118
- if self._run is not None:
119
- return
120
-
115
+ def set_run(self, run_id: int) -> None:
116
+ """Set the run."""
121
117
  # Get the run info
122
118
  req = GetRunRequest(run_id=run_id)
123
119
  res: GetRunResponse = self._stub.GetRun(req)
flwr/server/driver/inmemory_driver.py CHANGED
@@ -62,10 +62,8 @@ class InMemoryDriver(Driver):
62
62
  ):
63
63
  raise ValueError(f"Invalid message: {message}")
64
64
 
65
- def init_run(self, run_id: int) -> None:
65
+ def set_run(self, run_id: int) -> None:
66
66
  """Initialize the run."""
67
- if self._run is not None:
68
- return
69
67
  run = self.state.get_run(run_id)
70
68
  if run is None:
71
69
  raise RuntimeError(f"Cannot find the run with ID: {run_id}")
flwr/server/run_serverapp.py CHANGED
@@ -174,7 +174,7 @@ def run_server_app() -> None:
174
174
  root_certificates=root_certificates,
175
175
  )
176
176
  flwr_dir = get_flwr_dir(args.flwr_dir)
177
- driver.init_run(args.run_id)
177
+ driver.set_run(args.run_id)
178
178
  run_ = driver.run
179
179
  if not run_.fab_hash:
180
180
  raise ValueError("FAB hash not provided.")
@@ -188,6 +188,7 @@ def run_server_app() -> None:
188
188
 
189
189
  app_path = str(get_project_dir(fab_id, fab_version, run_.fab_hash, flwr_dir))
190
190
  config = get_project_config(app_path)
191
+ run_id = run_.run_id
191
192
  else:
192
193
  # User provided `app_dir`, but not `--run-id`
193
194
  # Create run if run_id is not provided
@@ -203,7 +204,8 @@ def run_server_app() -> None:
203
204
  req = CreateRunRequest(fab_id=fab_id, fab_version=fab_version)
204
205
  res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212
205
206
  # Fetch full `Run` using `run_id`
206
- driver.init_run(res.run_id) # pylint: disable=W0212
207
+ driver.set_run(res.run_id) # pylint: disable=W0212
208
+ run_id = res.run_id
207
209
 
208
210
  # Obtain server app reference and the run config
209
211
  server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
@@ -221,6 +223,7 @@ def run_server_app() -> None:
221
223
 
222
224
  # Initialize Context
223
225
  context = Context(
226
+ run_id=run_id,
224
227
  node_id=0,
225
228
  node_config={},
226
229
  state=RecordSet(),
flwr/server/serverapp/app.py CHANGED
@@ -189,7 +189,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
189
189
  run = run_from_proto(res.run)
190
190
  fab = fab_from_proto(res.fab)
191
191
 
192
- driver.init_run(run.run_id)
192
+ driver.set_run(run.run_id)
193
193
 
194
194
  # Start log uploader for this run
195
195
  log_uploader = start_log_uploader(
flwr/server/superlink/driver/serverappio_servicer.py CHANGED
@@ -23,6 +23,7 @@ from uuid import UUID
23
23
 
24
24
  import grpc
25
25
 
26
+ from flwr.common import ConfigsRecord
26
27
  from flwr.common.constant import Status
27
28
  from flwr.common.logger import log
28
29
  from flwr.common.serde import (
@@ -112,6 +113,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
112
113
  request.fab_version,
113
114
  fab_hash,
114
115
  user_config_from_proto(request.override_config),
116
+ ConfigsRecord(),
115
117
  )
116
118
  return CreateRunResponse(run_id=run_id)
117
119
 
flwr/server/superlink/linkstate/in_memory_linkstate.py CHANGED
@@ -30,6 +30,7 @@ from flwr.common.constant import (
30
30
  RUN_ID_NUM_BYTES,
31
31
  Status,
32
32
  )
33
+ from flwr.common.record import ConfigsRecord
33
34
  from flwr.common.typing import Run, RunStatus, UserConfig
34
35
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
35
36
  from flwr.server.superlink.linkstate.linkstate import LinkState
@@ -39,7 +40,6 @@ from .utils import (
39
40
  generate_rand_int_from_bytes,
40
41
  has_valid_sub_status,
41
42
  is_valid_transition,
42
- make_node_unavailable_taskres,
43
43
  )
44
44
 
45
45
 
@@ -69,6 +69,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
69
69
  # Map run_id to RunRecord
70
70
  self.run_ids: dict[int, RunRecord] = {}
71
71
  self.contexts: dict[int, Context] = {}
72
+ self.federation_options: dict[int, ConfigsRecord] = {}
72
73
  self.task_ins_store: dict[UUID, TaskIns] = {}
73
74
  self.task_res_store: dict[UUID, TaskRes] = {}
74
75
 
@@ -255,21 +256,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
255
256
  task_res_list.append(task_res)
256
257
  replied_task_ids.add(reply_to)
257
258
 
258
- # Check if the node is offline
259
- for task_id in task_ids - replied_task_ids:
260
- task_ins = self.task_ins_store.get(task_id)
261
- if task_ins is None:
262
- continue
263
- node_id = task_ins.task.consumer.node_id
264
- online_until, _ = self.node_ids[node_id]
265
- # Generate a TaskRes containing an error reply if the node is offline.
266
- if online_until < time.time():
267
- err_taskres = make_node_unavailable_taskres(
268
- ref_taskins=task_ins,
269
- )
270
- self.task_res_store[UUID(err_taskres.task_id)] = err_taskres
271
- task_res_list.append(err_taskres)
272
-
273
259
  # Mark all of them as delivered
274
260
  delivered_at = now().isoformat()
275
261
  for task_res in task_res_list:
@@ -378,12 +364,14 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
378
364
  """Retrieve stored `node_id` filtered by `node_public_keys`."""
379
365
  return self.public_key_to_node_id.get(node_public_key)
380
366
 
367
+ # pylint: disable=too-many-arguments,too-many-positional-arguments
381
368
  def create_run(
382
369
  self,
383
370
  fab_id: Optional[str],
384
371
  fab_version: Optional[str],
385
372
  fab_hash: Optional[str],
386
373
  override_config: UserConfig,
374
+ federation_options: ConfigsRecord,
387
375
  ) -> int:
388
376
  """Create a new run for the specified `fab_hash`."""
389
377
  # Sample a random int64 as run_id
@@ -407,6 +395,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
407
395
  pending_at=now().isoformat(),
408
396
  )
409
397
  self.run_ids[run_id] = run_record
398
+
399
+ # Record federation options. Leave empty if not passed
400
+ self.federation_options[run_id] = federation_options
410
401
  return run_id
411
402
  log(ERROR, "Unexpected run creation failure.")
412
403
  return 0
@@ -514,6 +505,14 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
514
505
 
515
506
  return pending_run_id
516
507
 
508
+ def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
509
+ """Retrieve the federation options for the specified `run_id`."""
510
+ with self.lock:
511
+ if run_id not in self.run_ids:
512
+ log(ERROR, "`run_id` is invalid")
513
+ return None
514
+ return self.federation_options[run_id]
515
+
517
516
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
518
517
  """Acknowledge a ping received from a node, serving as a heartbeat."""
519
518
  with self.lock:
flwr/server/superlink/linkstate/linkstate.py CHANGED
@@ -20,6 +20,7 @@ from typing import Optional
20
20
  from uuid import UUID
21
21
 
22
22
  from flwr.common import Context
23
+ from flwr.common.record import ConfigsRecord
23
24
  from flwr.common.typing import Run, RunStatus, UserConfig
24
25
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
25
26
 
@@ -152,12 +153,13 @@ class LinkState(abc.ABC): # pylint: disable=R0904
152
153
  """Retrieve stored `node_id` filtered by `node_public_keys`."""
153
154
 
154
155
  @abc.abstractmethod
155
- def create_run(
156
+ def create_run( # pylint: disable=too-many-arguments,too-many-positional-arguments
156
157
  self,
157
158
  fab_id: Optional[str],
158
159
  fab_version: Optional[str],
159
160
  fab_hash: Optional[str],
160
161
  override_config: UserConfig,
162
+ federation_options: ConfigsRecord,
161
163
  ) -> int:
162
164
  """Create a new run for the specified `fab_hash`."""
163
165
 
@@ -227,6 +229,21 @@ class LinkState(abc.ABC): # pylint: disable=R0904
227
229
  there is no Run pending.
228
230
  """
229
231
 
232
+ @abc.abstractmethod
233
+ def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
234
+ """Retrieve the federation options for the specified `run_id`.
235
+
236
+ Parameters
237
+ ----------
238
+ run_id : int
239
+ The identifier of the run.
240
+
241
+ Returns
242
+ -------
243
+ Optional[ConfigsRecord]
244
+ The federation options for the run if it exists; None otherwise.
245
+ """
246
+
230
247
  @abc.abstractmethod
231
248
  def store_server_private_public_key(
232
249
  self, private_key: bytes, public_key: bytes
flwr/server/superlink/linkstate/sqlite_linkstate.py CHANGED
@@ -33,6 +33,7 @@ from flwr.common.constant import (
33
33
  RUN_ID_NUM_BYTES,
34
34
  Status,
35
35
  )
36
+ from flwr.common.record import ConfigsRecord
36
37
  from flwr.common.typing import Run, RunStatus, UserConfig
37
38
 
38
39
  # pylint: disable=E0611
@@ -45,6 +46,8 @@ from flwr.server.utils.validator import validate_task_ins_or_res
45
46
 
46
47
  from .linkstate import LinkState
47
48
  from .utils import (
49
+ configsrecord_from_bytes,
50
+ configsrecord_to_bytes,
48
51
  context_from_bytes,
49
52
  context_to_bytes,
50
53
  convert_sint64_to_uint64,
@@ -54,7 +57,6 @@ from .utils import (
54
57
  generate_rand_int_from_bytes,
55
58
  has_valid_sub_status,
56
59
  is_valid_transition,
57
- make_node_unavailable_taskres,
58
60
  )
59
61
 
60
62
  SQL_CREATE_TABLE_NODE = """
@@ -95,7 +97,8 @@ CREATE TABLE IF NOT EXISTS run(
95
97
  running_at TEXT,
96
98
  finished_at TEXT,
97
99
  sub_status TEXT,
98
- details TEXT
100
+ details TEXT,
101
+ federation_options BLOB
99
102
  );
100
103
  """
101
104
 
@@ -636,20 +639,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
636
639
  data = {f"id_{i}": str(node_id) for i, node_id in enumerate(offline_node_ids)}
637
640
  task_ins_rows = self.query(query, data)
638
641
 
639
- # Make TaskRes containing node unavailabe error
640
- for row in task_ins_rows:
641
- for row in rows:
642
- # Convert values from sint64 to uint64
643
- convert_sint64_values_in_dict_to_uint64(
644
- row, ["run_id", "producer_node_id", "consumer_node_id"]
645
- )
646
-
647
- task_ins = dict_to_task_ins(row)
648
- err_taskres = make_node_unavailable_taskres(
649
- ref_taskins=task_ins,
650
- )
651
- result.append(err_taskres)
652
-
653
642
  return result
654
643
 
655
644
  def num_task_ins(self) -> int:
@@ -810,12 +799,14 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
810
799
  return uint64_node_id
811
800
  return None
812
801
 
802
+ # pylint: disable=too-many-arguments,too-many-positional-arguments
813
803
  def create_run(
814
804
  self,
815
805
  fab_id: Optional[str],
816
806
  fab_version: Optional[str],
817
807
  fab_hash: Optional[str],
818
808
  override_config: UserConfig,
809
+ federation_options: ConfigsRecord,
819
810
  ) -> int:
820
811
  """Create a new run for the specified `fab_id` and `fab_version`."""
821
812
  # Sample a random int64 as run_id
@@ -830,15 +821,29 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
830
821
  if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
831
822
  query = (
832
823
  "INSERT INTO run "
833
- "(run_id, fab_id, fab_version, fab_hash, override_config, pending_at, "
834
- "starting_at, running_at, finished_at, sub_status, details)"
835
- "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
824
+ "(run_id, fab_id, fab_version, fab_hash, override_config, "
825
+ "federation_options, pending_at, starting_at, running_at, finished_at, "
826
+ "sub_status, details) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
836
827
  )
837
828
  if fab_hash:
838
829
  fab_id, fab_version = "", ""
839
830
  override_config_json = json.dumps(override_config)
840
- data = [sint64_run_id, fab_id, fab_version, fab_hash, override_config_json]
841
- data += [now().isoformat(), "", "", "", "", ""]
831
+ data = [
832
+ sint64_run_id,
833
+ fab_id,
834
+ fab_version,
835
+ fab_hash,
836
+ override_config_json,
837
+ configsrecord_to_bytes(federation_options),
838
+ ]
839
+ data += [
840
+ now().isoformat(),
841
+ "",
842
+ "",
843
+ "",
844
+ "",
845
+ "",
846
+ ]
842
847
  self.query(query, tuple(data))
843
848
  return uint64_run_id
844
849
  log(ERROR, "Unexpected run creation failure.")
@@ -1003,6 +1008,21 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1003
1008
 
1004
1009
  return pending_run_id
1005
1010
 
1011
+ def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
1012
+ """Retrieve the federation options for the specified `run_id`."""
1013
+ # Convert the uint64 value to sint64 for SQLite
1014
+ sint64_run_id = convert_uint64_to_sint64(run_id)
1015
+ query = "SELECT federation_options FROM run WHERE run_id = ?;"
1016
+ rows = self.query(query, (sint64_run_id,))
1017
+
1018
+ # Check if the run_id exists
1019
+ if not rows:
1020
+ log(ERROR, "`run_id` is invalid")
1021
+ return None
1022
+
1023
+ row = rows[0]
1024
+ return configsrecord_from_bytes(row["federation_options"])
1025
+
1006
1026
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
1007
1027
  """Acknowledge a ping received from a node, serving as a heartbeat."""
1008
1028
  sint64_node_id = convert_uint64_to_sint64(node_id)