flwr 1.15.2__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (120)
  1. flwr/cli/build.py +2 -0
  2. flwr/cli/log.py +20 -21
  3. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
  4. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  11. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  12. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  13. flwr/cli/run/run.py +5 -9
  14. flwr/client/app.py +6 -4
  15. flwr/client/client_app.py +260 -86
  16. flwr/client/clientapp/app.py +6 -2
  17. flwr/client/grpc_client/connection.py +24 -21
  18. flwr/client/message_handler/message_handler.py +28 -28
  19. flwr/client/mod/__init__.py +2 -2
  20. flwr/client/mod/centraldp_mods.py +7 -7
  21. flwr/client/mod/comms_mods.py +16 -22
  22. flwr/client/mod/localdp_mod.py +4 -4
  23. flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
  24. flwr/client/rest_client/connection.py +4 -6
  25. flwr/client/run_info_store.py +2 -2
  26. flwr/client/supernode/__init__.py +0 -2
  27. flwr/client/supernode/app.py +1 -11
  28. flwr/common/__init__.py +12 -4
  29. flwr/common/address.py +35 -0
  30. flwr/common/args.py +8 -2
  31. flwr/common/auth_plugin/auth_plugin.py +2 -1
  32. flwr/common/config.py +4 -4
  33. flwr/common/constant.py +16 -0
  34. flwr/common/context.py +4 -4
  35. flwr/common/event_log_plugin/__init__.py +22 -0
  36. flwr/common/event_log_plugin/event_log_plugin.py +60 -0
  37. flwr/common/grpc.py +1 -1
  38. flwr/common/logger.py +2 -2
  39. flwr/common/message.py +338 -102
  40. flwr/common/object_ref.py +0 -10
  41. flwr/common/record/__init__.py +8 -4
  42. flwr/common/record/arrayrecord.py +626 -0
  43. flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
  44. flwr/common/record/conversion_utils.py +9 -18
  45. flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
  46. flwr/common/record/recorddict.py +288 -0
  47. flwr/common/recorddict_compat.py +410 -0
  48. flwr/common/secure_aggregation/quantization.py +5 -1
  49. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  50. flwr/common/serde.py +67 -190
  51. flwr/common/telemetry.py +0 -10
  52. flwr/common/typing.py +44 -8
  53. flwr/proto/exec_pb2.py +3 -3
  54. flwr/proto/exec_pb2.pyi +3 -3
  55. flwr/proto/message_pb2.py +12 -12
  56. flwr/proto/message_pb2.pyi +9 -9
  57. flwr/proto/recorddict_pb2.py +70 -0
  58. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
  59. flwr/proto/run_pb2.py +31 -31
  60. flwr/proto/run_pb2.pyi +3 -3
  61. flwr/server/__init__.py +3 -1
  62. flwr/server/app.py +74 -3
  63. flwr/server/compat/__init__.py +2 -2
  64. flwr/server/compat/app.py +15 -12
  65. flwr/server/compat/app_utils.py +26 -18
  66. flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +41 -41
  67. flwr/server/fleet_event_log_interceptor.py +94 -0
  68. flwr/server/{driver → grid}/__init__.py +8 -7
  69. flwr/server/{driver/driver.py → grid/grid.py} +48 -19
  70. flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +88 -56
  71. flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +41 -54
  72. flwr/server/run_serverapp.py +6 -17
  73. flwr/server/server_app.py +126 -33
  74. flwr/server/serverapp/app.py +10 -10
  75. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
  76. flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
  77. flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
  78. flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
  79. flwr/server/superlink/fleet/vce/vce_api.py +33 -38
  80. flwr/server/superlink/linkstate/in_memory_linkstate.py +171 -132
  81. flwr/server/superlink/linkstate/linkstate.py +51 -64
  82. flwr/server/superlink/linkstate/sqlite_linkstate.py +253 -285
  83. flwr/server/superlink/linkstate/utils.py +171 -133
  84. flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
  85. flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
  86. flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +27 -29
  87. flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
  88. flwr/server/typing.py +3 -3
  89. flwr/server/utils/__init__.py +2 -2
  90. flwr/server/utils/validator.py +53 -68
  91. flwr/server/workflow/default_workflows.py +52 -58
  92. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
  93. flwr/simulation/app.py +2 -2
  94. flwr/simulation/ray_transport/ray_actor.py +4 -2
  95. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  96. flwr/simulation/run_simulation.py +15 -15
  97. flwr/superexec/app.py +0 -14
  98. flwr/superexec/deployment.py +4 -4
  99. flwr/superexec/exec_event_log_interceptor.py +135 -0
  100. flwr/superexec/exec_grpc.py +10 -4
  101. flwr/superexec/exec_servicer.py +6 -6
  102. flwr/superexec/exec_user_auth_interceptor.py +22 -4
  103. flwr/superexec/executor.py +3 -3
  104. flwr/superexec/simulation.py +3 -3
  105. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/METADATA +5 -5
  106. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/RECORD +111 -112
  107. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -3
  108. flwr/client/message_handler/task_handler.py +0 -37
  109. flwr/common/record/parametersrecord.py +0 -204
  110. flwr/common/record/recordset.py +0 -202
  111. flwr/common/recordset_compat.py +0 -418
  112. flwr/proto/recordset_pb2.py +0 -70
  113. flwr/proto/task_pb2.py +0 -33
  114. flwr/proto/task_pb2.pyi +0 -100
  115. flwr/proto/task_pb2_grpc.py +0 -4
  116. flwr/proto/task_pb2_grpc.pyi +0 -4
  117. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  118. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  119. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
  120. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
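Most of the churn between 1.15.2 and 1.17.0 comes from renaming the record types and their accessors (ConfigsRecord → ConfigRecord, MetricsRecord → MetricRecord, ParametersRecord → ArrayRecord, RecordSet → RecordDict) and from replacing the server-side Driver with Grid; the hunks below repeat the same substitutions file after file. A minimal before/after sketch of the record renames, using only names visible in this diff (the record keys and values are made up for illustration):

    # Old names (1.15.2), as removed throughout this diff:
    #   from flwr.common import ConfigsRecord, RecordSet
    #   record = ConfigsRecord({"lr": 0.1})
    #   content = RecordSet(configs_records={"config": record})
    #   cfg = content.configs_records["config"]

    # New names (1.17.0), as added throughout this diff:
    from flwr.common import ConfigRecord, RecordDict

    record = ConfigRecord({"lr": 0.1})        # ConfigsRecord -> ConfigRecord
    content = RecordDict({"config": record})  # RecordSet(configs_records=...) -> RecordDict({...})
    cfg = content.config_records["config"]    # .configs_records -> .config_records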
flwr/server/workflow/secure_aggregation/secaggplus_workflow.py CHANGED
@@ -20,15 +20,15 @@ from dataclasses import dataclass, field
  from logging import DEBUG, ERROR, INFO, WARN
  from typing import Optional, Union, cast

- import flwr.common.recordset_compat as compat
+ import flwr.common.recorddict_compat as compat
  from flwr.common import (
- ConfigsRecord,
+ ConfigRecord,
  Context,
  FitRes,
  Message,
  MessageType,
  NDArrays,
- RecordSet,
+ RecordDict,
  bytes_to_ndarray,
  log,
  ndarrays_to_parameters,
@@ -55,7 +55,7 @@ from flwr.common.secure_aggregation.secaggplus_constants import (
  from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
  from flwr.server.client_proxy import ClientProxy
  from flwr.server.compat.legacy_context import LegacyContext
- from flwr.server.driver import Driver
+ from flwr.server.grid import Grid

  from ..constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD
  from ..constant import Key as WorkflowKey
@@ -66,7 +66,7 @@ class WorkflowState: # pylint: disable=R0902
  """The state of the SecAgg+ protocol."""

  nid_to_proxies: dict[int, ClientProxy] = field(default_factory=dict)
- nid_to_fitins: dict[int, RecordSet] = field(default_factory=dict)
+ nid_to_fitins: dict[int, RecordDict] = field(default_factory=dict)
  sampled_node_ids: set[int] = field(default_factory=set)
  active_node_ids: set[int] = field(default_factory=set)
  num_shares: int = 0
@@ -186,7 +186,7 @@ class SecAggPlusWorkflow:

  self._check_init_params()

- def __call__(self, driver: Driver, context: Context) -> None:
+ def __call__(self, grid: Grid, context: Context) -> None:
  """Run the SecAgg+ protocol."""
  if not isinstance(context, LegacyContext):
  raise TypeError(
@@ -202,7 +202,7 @@ class SecAggPlusWorkflow:
  )
  log(INFO, "Secure aggregation commencing.")
  for step in steps:
- if not step(driver, context, state):
+ if not step(grid, context, state):
  log(INFO, "Secure aggregation halted.")
  return
  log(INFO, "Secure aggregation completed.")
@@ -279,14 +279,14 @@ class SecAggPlusWorkflow:
  return True

  def setup_stage( # pylint: disable=R0912, R0914, R0915
- self, driver: Driver, context: LegacyContext, state: WorkflowState
+ self, grid: Grid, context: LegacyContext, state: WorkflowState
  ) -> bool:
  """Execute the 'setup' stage."""
  # Obtain fit instructions
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
  current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
- parameters = compat.parametersrecord_to_parameters(
- context.state.parameters_records[MAIN_PARAMS_RECORD],
+ parameters = compat.arrayrecord_to_parameters(
+ context.state.array_records[MAIN_PARAMS_RECORD],
  keep_input=True,
  )
  proxy_fitins_lst = context.strategy.configure_fit(
@@ -303,7 +303,7 @@ class SecAggPlusWorkflow:
  )

  state.nid_to_fitins = {
- proxy.node_id: compat.fitins_to_recordset(fitins, True)
+ proxy.node_id: compat.fitins_to_recorddict(fitins, True)
  for proxy, fitins in proxy_fitins_lst
  }
  state.nid_to_proxies = {proxy.node_id: proxy for proxy, _ in proxy_fitins_lst}
@@ -366,14 +366,14 @@ class SecAggPlusWorkflow:
  state.sampled_node_ids = state.active_node_ids

  # Send setup configuration to clients
- cfgs_record = ConfigsRecord(sa_params_dict) # type: ignore
- content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
+ cfg_record = ConfigRecord(sa_params_dict) # type: ignore
+ content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})

  def make(nid: int) -> Message:
- return driver.create_message(
+ return Message(
  content=content,
- message_type=MessageType.TRAIN,
  dst_node_id=nid,
+ message_type=MessageType.TRAIN,
  group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
  )

@@ -382,7 +382,7 @@ class SecAggPlusWorkflow:
  "[Stage 0] Sending configurations to %s clients.",
  len(state.active_node_ids),
  )
- msgs = driver.send_and_receive(
+ msgs = grid.send_and_receive(
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
  )
  state.active_node_ids = {
@@ -398,7 +398,7 @@ class SecAggPlusWorkflow:
  if msg.has_error():
  state.failures.append(Exception(msg.error))
  continue
- key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+ key_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
  node_id = msg.metadata.src_node_id
  pk1, pk2 = key_dict[Key.PUBLIC_KEY_1], key_dict[Key.PUBLIC_KEY_2]
  state.nid_to_publickeys[node_id] = [cast(bytes, pk1), cast(bytes, pk2)]
@@ -406,22 +406,22 @@ class SecAggPlusWorkflow:
  return self._check_threshold(state)

  def share_keys_stage( # pylint: disable=R0914
- self, driver: Driver, context: LegacyContext, state: WorkflowState
+ self, grid: Grid, context: LegacyContext, state: WorkflowState
  ) -> bool:
  """Execute the 'share keys' stage."""
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]

  def make(nid: int) -> Message:
  neighbours = state.nid_to_neighbours[nid] & state.active_node_ids
- cfgs_record = ConfigsRecord(
+ cfg_record = ConfigRecord(
  {str(nid): state.nid_to_publickeys[nid] for nid in neighbours}
  )
- cfgs_record[Key.STAGE] = Stage.SHARE_KEYS
- content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
- return driver.create_message(
+ cfg_record[Key.STAGE] = Stage.SHARE_KEYS
+ content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})
+ return Message(
  content=content,
- message_type=MessageType.TRAIN,
  dst_node_id=nid,
+ message_type=MessageType.TRAIN,
  group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
  )

@@ -431,7 +431,7 @@ class SecAggPlusWorkflow:
  "[Stage 1] Forwarding public keys to %s clients.",
  len(state.active_node_ids),
  )
- msgs = driver.send_and_receive(
+ msgs = grid.send_and_receive(
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
  )
  state.active_node_ids = {
@@ -458,7 +458,7 @@ class SecAggPlusWorkflow:
  state.failures.append(Exception(msg.error))
  continue
  node_id = msg.metadata.src_node_id
- res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+ res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
  dst_lst = cast(list[int], res_dict[Key.DESTINATION_LIST])
  ctxt_lst = cast(list[bytes], res_dict[Key.CIPHERTEXT_LIST])
  srcs += [node_id] * len(dst_lst)
@@ -476,25 +476,25 @@ class SecAggPlusWorkflow:
  return self._check_threshold(state)

  def collect_masked_vectors_stage(
- self, driver: Driver, context: LegacyContext, state: WorkflowState
+ self, grid: Grid, context: LegacyContext, state: WorkflowState
  ) -> bool:
  """Execute the 'collect masked vectors' stage."""
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]

  # Send secret key shares to clients (plus FitIns) and collect masked vectors
  def make(nid: int) -> Message:
- cfgs_dict = {
+ cfg_dict = {
  Key.STAGE: Stage.COLLECT_MASKED_VECTORS,
  Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid],
  Key.SOURCE_LIST: state.forward_srcs[nid],
  }
- cfgs_record = ConfigsRecord(cfgs_dict) # type: ignore
+ cfg_record = ConfigRecord(cfg_dict) # type: ignore
  content = state.nid_to_fitins[nid]
- content.configs_records[RECORD_KEY_CONFIGS] = cfgs_record
- return driver.create_message(
+ content.config_records[RECORD_KEY_CONFIGS] = cfg_record
+ return Message(
  content=content,
- message_type=MessageType.TRAIN,
  dst_node_id=nid,
+ message_type=MessageType.TRAIN,
  group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
  )

@@ -503,7 +503,7 @@ class SecAggPlusWorkflow:
  "[Stage 2] Forwarding encrypted key shares to %s clients.",
  len(state.active_node_ids),
  )
- msgs = driver.send_and_receive(
+ msgs = grid.send_and_receive(
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
  )
  state.active_node_ids = {
@@ -524,7 +524,7 @@ class SecAggPlusWorkflow:
  if msg.has_error():
  state.failures.append(Exception(msg.error))
  continue
- res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+ res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
  bytes_list = cast(list[bytes], res_dict[Key.MASKED_PARAMETERS])
  client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
  if masked_vector is None:
@@ -540,17 +540,17 @@ class SecAggPlusWorkflow:
  if msg.has_error():
  state.failures.append(Exception(msg.error))
  continue
- fitres = compat.recordset_to_fitres(msg.content, True)
+ fitres = compat.recorddict_to_fitres(msg.content, True)
  proxy = state.nid_to_proxies[msg.metadata.src_node_id]
  state.legacy_results.append((proxy, fitres))

  return self._check_threshold(state)

  def unmask_stage( # pylint: disable=R0912, R0914, R0915
- self, driver: Driver, context: LegacyContext, state: WorkflowState
+ self, grid: Grid, context: LegacyContext, state: WorkflowState
  ) -> bool:
  """Execute the 'unmask' stage."""
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
  current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])

  # Construct active node IDs and dead node IDs
@@ -560,17 +560,17 @@ class SecAggPlusWorkflow:
  # Send secure IDs of active and dead clients and collect key shares from clients
  def make(nid: int) -> Message:
  neighbours = state.nid_to_neighbours[nid]
- cfgs_dict = {
+ cfg_dict = {
  Key.STAGE: Stage.UNMASK,
  Key.ACTIVE_NODE_ID_LIST: list(neighbours & active_nids),
  Key.DEAD_NODE_ID_LIST: list(neighbours & dead_nids),
  }
- cfgs_record = ConfigsRecord(cfgs_dict) # type: ignore
- content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
- return driver.create_message(
+ cfg_record = ConfigRecord(cfg_dict) # type: ignore
+ content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})
+ return Message(
  content=content,
- message_type=MessageType.TRAIN,
  dst_node_id=nid,
+ message_type=MessageType.TRAIN,
  group_id=str(current_round),
  )

@@ -579,7 +579,7 @@ class SecAggPlusWorkflow:
  "[Stage 3] Requesting key shares from %s clients to remove masks.",
  len(state.active_node_ids),
  )
- msgs = driver.send_and_receive(
+ msgs = grid.send_and_receive(
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
  )
  state.active_node_ids = {
@@ -599,7 +599,7 @@ class SecAggPlusWorkflow:
  if msg.has_error():
  state.failures.append(Exception(msg.error))
  continue
- res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+ res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
  nids = cast(list[int], res_dict[Key.NODE_ID_LIST])
  shares = cast(list[bytes], res_dict[Key.SHARE_LIST])
  for owner_nid, share in zip(nids, shares):
@@ -676,10 +676,8 @@ class SecAggPlusWorkflow:

  # Update the parameters and write history
  if parameters_aggregated:
- paramsrecord = compat.parameters_to_parametersrecord(
- parameters_aggregated, True
- )
- context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
+ arr_record = compat.parameters_to_arrayrecord(parameters_aggregated, True)
+ context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
  context.history.add_metrics_distributed_fit(
  server_round=current_round, metrics=metrics_aggregated
  )
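Two related API changes run through the workflow above: the server-side handle is now a Grid instead of a Driver, and messages are built with the Message constructor rather than driver.create_message(...). A hedged sketch of the new pattern, assuming only the keyword arguments shown in the hunks above (the record key and round value are illustrative):

    from flwr.common import ConfigRecord, Message, MessageType, RecordDict
    from flwr.server.grid import Grid


    def broadcast_train(grid: Grid, node_ids: list[int], server_round: int, timeout: float):
        """Send one TRAIN message per node and collect the replies."""
        content = RecordDict({"config": ConfigRecord({"round": server_round})})
        messages = [
            Message(                      # 1.15.2 equivalent: driver.create_message(...)
                content=content,
                dst_node_id=nid,
                message_type=MessageType.TRAIN,
                group_id=str(server_round),
            )
            for nid in node_ids
        ]
        # 1.15.2 equivalent: driver.send_and_receive(...)
        return list(grid.send_and_receive(messages, timeout=timeout))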
flwr/simulation/app.py CHANGED
@@ -47,7 +47,7 @@ from flwr.common.logger import (
  stop_log_uploader,
  )
  from flwr.common.serde import (
- configs_record_from_proto,
+ config_record_from_proto,
  context_from_proto,
  context_to_proto,
  fab_from_proto,
@@ -184,7 +184,7 @@ def run_simulation_process( # pylint: disable=R0914, disable=W0212, disable=R09
  fed_opt_res: GetFederationOptionsResponse = conn._stub.GetFederationOptions(
  GetFederationOptionsRequest(run_id=run.run_id)
  )
- federation_options = configs_record_from_proto(
+ federation_options = config_record_from_proto(
  fed_opt_res.federation_options
  )

flwr/simulation/ray_transport/ray_actor.py CHANGED
@@ -105,8 +105,10 @@ def pool_size_from_resources(client_resources: dict[str, Union[int, float]]) ->
  if not node_resources:
  continue

- num_cpus = node_resources["CPU"]
- num_gpus = node_resources.get("GPU", 0) # There might not be GPU
+ # Fallback to zero when resource quantity is not configured on the ray node
+ # e.g.: node without GPU; head node set up not to run tasks (zero resources)
+ num_cpus = node_resources.get("CPU", 0)
+ num_gpus = node_resources.get("GPU", 0)
  num_actors = int(num_cpus / client_resources["num_cpus"])

  # If a GPU is present and client resources do require one
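The pool-sizing fix above guards against Ray nodes that advertise no CPU or GPU quantity (for example a head node started with zero resources). An illustrative, standalone re-implementation of the guarded arithmetic, not the library code itself:

    from typing import Union


    def actors_per_node(
        node_resources: dict[str, Union[int, float]],
        client_resources: dict[str, Union[int, float]],
    ) -> int:
        """How many client actors fit on one Ray node after the fix (illustrative)."""
        # Fall back to zero when the node does not advertise the resource,
        # e.g. a GPU-less node or a head node configured to run no tasks.
        num_cpus = node_resources.get("CPU", 0)
        num_gpus = node_resources.get("GPU", 0)
        num_actors = int(num_cpus / client_resources["num_cpus"])
        # If clients also request GPUs, let the GPU count cap the actor count further
        # (a plausible continuation; the exact library logic is outside this hunk).
        if client_resources.get("num_gpus", 0) > 0:
            num_actors = min(num_actors, int(num_gpus / client_resources["num_gpus"]))
        return num_actors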
flwr/simulation/ray_transport/ray_client_proxy.py CHANGED
@@ -23,7 +23,7 @@ from flwr import common
  from flwr.client import ClientFnExt
  from flwr.client.client_app import ClientApp
  from flwr.client.run_info_store import DeprecatedRunInfoStore
- from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
+ from flwr.common import DEFAULT_TTL, Message, Metadata, RecordDict, now
  from flwr.common.constant import (
  NUM_PARTITIONS_KEY,
  PARTITION_ID_KEY,
@@ -31,15 +31,16 @@ from flwr.common.constant import (
  MessageTypeLegacy,
  )
  from flwr.common.logger import log
- from flwr.common.recordset_compat import (
- evaluateins_to_recordset,
- fitins_to_recordset,
- getparametersins_to_recordset,
- getpropertiesins_to_recordset,
- recordset_to_evaluateres,
- recordset_to_fitres,
- recordset_to_getparametersres,
- recordset_to_getpropertiesres,
+ from flwr.common.message import make_message
+ from flwr.common.recorddict_compat import (
+ evaluateins_to_recorddict,
+ fitins_to_recorddict,
+ getparametersins_to_recorddict,
+ getpropertiesins_to_recorddict,
+ recorddict_to_evaluateres,
+ recorddict_to_fitres,
+ recorddict_to_getparametersres,
+ recorddict_to_getpropertiesres,
  )
  from flwr.server.client_proxy import ClientProxy
  from flwr.simulation.ray_transport.ray_actor import VirtualClientEngineActorPool
@@ -109,23 +110,24 @@ class RayActorClientProxy(ClientProxy):

  return out_mssg

- def _wrap_recordset_in_message(
+ def _wrap_recorddict_in_message(
  self,
- recordset: RecordSet,
+ recorddict: RecordDict,
  message_type: str,
  timeout: Optional[float],
  group_id: Optional[int],
  ) -> Message:
- """Wrap a RecordSet inside a Message."""
- return Message(
- content=recordset,
+ """Wrap a RecordDict inside a Message."""
+ return make_message(
+ content=recorddict,
  metadata=Metadata(
  run_id=0,
  message_id="",
  group_id=str(group_id) if group_id is not None else "",
  src_node_id=0,
  dst_node_id=self.node_id,
- reply_to_message="",
+ reply_to_message_id="",
+ created_at=now().timestamp(),
  ttl=timeout if timeout else DEFAULT_TTL,
  message_type=message_type,
  ),
@@ -138,9 +140,9 @@ class RayActorClientProxy(ClientProxy):
  group_id: Optional[int],
  ) -> common.GetPropertiesRes:
  """Return client's properties."""
- recordset = getpropertiesins_to_recordset(ins)
- message = self._wrap_recordset_in_message(
- recordset,
+ recorddict = getpropertiesins_to_recorddict(ins)
+ message = self._wrap_recorddict_in_message(
+ recorddict,
  message_type=MessageTypeLegacy.GET_PROPERTIES,
  timeout=timeout,
  group_id=group_id,
@@ -148,7 +150,7 @@

  message_out = self._submit_job(message, timeout)

- return recordset_to_getpropertiesres(message_out.content)
+ return recorddict_to_getpropertiesres(message_out.content)

  def get_parameters(
  self,
@@ -157,9 +159,9 @@
  group_id: Optional[int],
  ) -> common.GetParametersRes:
  """Return the current local model parameters."""
- recordset = getparametersins_to_recordset(ins)
- message = self._wrap_recordset_in_message(
- recordset,
+ recorddict = getparametersins_to_recorddict(ins)
+ message = self._wrap_recorddict_in_message(
+ recorddict,
  message_type=MessageTypeLegacy.GET_PARAMETERS,
  timeout=timeout,
  group_id=group_id,
@@ -167,17 +169,17 @@

  message_out = self._submit_job(message, timeout)

- return recordset_to_getparametersres(message_out.content, keep_input=False)
+ return recorddict_to_getparametersres(message_out.content, keep_input=False)

  def fit(
  self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
  ) -> common.FitRes:
  """Train model parameters on the locally held dataset."""
- recordset = fitins_to_recordset(
+ recorddict = fitins_to_recorddict(
  ins, keep_input=True
  ) # This must stay TRUE since ins are in-memory
- message = self._wrap_recordset_in_message(
- recordset,
+ message = self._wrap_recorddict_in_message(
+ recorddict,
  message_type=MessageType.TRAIN,
  timeout=timeout,
  group_id=group_id,
@@ -185,17 +187,17 @@

  message_out = self._submit_job(message, timeout)

- return recordset_to_fitres(message_out.content, keep_input=False)
+ return recorddict_to_fitres(message_out.content, keep_input=False)

  def evaluate(
  self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
  ) -> common.EvaluateRes:
  """Evaluate model parameters on the locally held dataset."""
- recordset = evaluateins_to_recordset(
+ recorddict = evaluateins_to_recorddict(
  ins, keep_input=True
  ) # This must stay TRUE since ins are in-memory
- message = self._wrap_recordset_in_message(
- recordset,
+ message = self._wrap_recorddict_in_message(
+ recorddict,
  message_type=MessageType.EVALUATE,
  timeout=timeout,
  group_id=group_id,
@@ -203,7 +205,7 @@

  message_out = self._submit_job(message, timeout)

- return recordset_to_evaluateres(message_out.content)
+ return recorddict_to_evaluateres(message_out.content)

  def reconnect(
  self,
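In the proxy above, Message(...) with explicit Metadata is replaced by the make_message helper, reply_to_message is renamed reply_to_message_id, and created_at is now set explicitly. A hedged sketch of that wrapping, mirroring _wrap_recorddict_in_message with placeholder field values:

    from typing import Optional

    from flwr.common import DEFAULT_TTL, Message, MessageType, Metadata, RecordDict, now
    from flwr.common.message import make_message


    def wrap_in_message(content: RecordDict, dst_node_id: int, timeout: Optional[float]) -> Message:
        """Wrap a RecordDict in a Message the way the proxy above does (illustrative)."""
        return make_message(
            content=content,
            metadata=Metadata(
                run_id=0,
                message_id="",
                group_id="",
                src_node_id=0,
                dst_node_id=dst_node_id,
                reply_to_message_id="",        # renamed from reply_to_message
                created_at=now().timestamp(),  # now set when the message is built
                ttl=timeout if timeout else DEFAULT_TTL,
                message_type=MessageType.TRAIN,
            ),
        )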
flwr/simulation/run_simulation.py CHANGED
@@ -30,7 +30,7 @@ from typing import Any, Optional
  from flwr.cli.config_utils import load_and_validate
  from flwr.cli.utils import get_sha256_hash
  from flwr.client import ClientApp
- from flwr.common import Context, EventType, RecordSet, event, log, now
+ from flwr.common import Context, EventType, RecordDict, event, log, now
  from flwr.common.config import get_fused_config_from_dir, parse_config_args
  from flwr.common.constant import RUN_ID_NUM_BYTES, Status
  from flwr.common.logger import (
@@ -39,7 +39,7 @@ from flwr.common.logger import (
  warn_deprecated_feature_with_example,
  )
  from flwr.common.typing import Run, RunStatus, UserConfig
- from flwr.server.driver import Driver, InMemoryDriver
+ from flwr.server.grid import Grid, InMemoryGrid
  from flwr.server.run_serverapp import run as _run
  from flwr.server.server_app import ServerApp
  from flwr.server.superlink.fleet import vce
@@ -168,7 +168,7 @@ def run_simulation(
  messages sent by the `ServerApp`.

  num_supernodes : int
- Number of nodes that run a ClientApp. They can be sampled by a Driver in the
+ Number of nodes that run a ClientApp. They can be sampled by a Grid in the
  ServerApp and receive a Message describing what the ClientApp should perform.

  backend_name : str (default: ray)
@@ -180,7 +180,7 @@
  for values parsed to initialisation of backend, `client_resources`
  to define the resources for clients, and `actor` to define the actor
  parameters. Values supported in <value> are those included by
- `flwr.common.typing.ConfigsRecordValues`.
+ `flwr.common.typing.ConfigRecordValues`.

  enable_tf_gpu_growth : bool (default: False)
  A boolean to indicate whether to enable GPU growth on the main thread. This is
@@ -225,7 +225,7 @@ def run_serverapp_th(
  server_app_attr: Optional[str],
  server_app: Optional[ServerApp],
  server_app_run_config: UserConfig,
- driver: Driver,
+ grid: Grid,
  app_dir: str,
  f_stop: threading.Event,
  has_exception: threading.Event,
@@ -239,7 +239,7 @@
  tf_gpu_growth: bool,
  stop_event: threading.Event,
  exception_event: threading.Event,
- _driver: Driver,
+ _grid: Grid,
  _server_app_dir: str,
  _server_app_run_config: UserConfig,
  _server_app_attr: Optional[str],
@@ -260,13 +260,13 @@
  run_id=run_id,
  node_id=0,
  node_config={},
- state=RecordSet(),
+ state=RecordDict(),
  run_config=_server_app_run_config,
  )

  # Run ServerApp
  updated_context = _run(
- driver=_driver,
+ grid=_grid,
  context=context,
  server_app_dir=_server_app_dir,
  server_app_attr=_server_app_attr,
@@ -291,7 +291,7 @@
  enable_tf_gpu_growth,
  f_stop,
  has_exception,
- driver,
+ grid,
  app_dir,
  server_app_run_config,
  server_app_attr,
@@ -333,7 +333,7 @@ def _main_loop(
  run_id=run.run_id,
  node_id=0,
  node_config=UserConfig(),
- state=RecordSet(),
+ state=RecordDict(),
  run_config=UserConfig(),
  )
  try:
@@ -347,9 +347,9 @@
  if server_app_run_config is None:
  server_app_run_config = {}

- # Initialize Driver
- driver = InMemoryDriver(state_factory=state_factory)
- driver.set_run(run_id=run.run_id)
+ # Initialize Grid
+ grid = InMemoryGrid(state_factory=state_factory)
+ grid.set_run(run_id=run.run_id)
  output_context_queue: Queue[Context] = Queue()

  # Get and run ServerApp thread
@@ -357,7 +357,7 @@
  server_app_attr=server_app_attr,
  server_app=server_app,
  server_app_run_config=server_app_run_config,
- driver=driver,
+ grid=grid,
  app_dir=app_dir,
  f_stop=f_stop,
  has_exception=server_app_thread_has_exception,
@@ -546,7 +546,7 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
  default="{}",
  help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
  "configure a backend. Values supported in <value> are those included by "
- "`flwr.common.typing.ConfigsRecordValues`. ",
+ "`flwr.common.typing.ConfigRecordValues`. ",
  )
  parser.add_argument(
  "--enable-tf-gpu-growth",
flwr/superexec/app.py CHANGED
@@ -16,26 +16,12 @@


  import argparse
- import sys
- from logging import INFO

- from flwr.common import log
  from flwr.common.object_ref import load_app, validate

  from .executor import Executor


- def run_superexec() -> None:
- """Run Flower SuperExec."""
- log(INFO, "Starting Flower SuperExec")
-
- sys.exit(
- "Manually launching the SuperExec is deprecated. Since `flwr 1.13.0` "
- "the executor service runs in the SuperLink. Launching it manually is not "
- "recommended."
- )
-
-
  def load_executor(
  args: argparse.Namespace,
  ) -> Executor:
flwr/superexec/deployment.py CHANGED
@@ -23,7 +23,7 @@ from typing import Optional
  from typing_extensions import override

  from flwr.cli.config_utils import get_fab_metadata
- from flwr.common import ConfigsRecord, Context, RecordSet
+ from flwr.common import ConfigRecord, Context, RecordDict
  from flwr.common.constant import (
  SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
  Status,
@@ -141,7 +141,7 @@ class DeploymentEngine(Executor):
  fab_id, fab_version = get_fab_metadata(fab.content)

  run_id = self.linkstate.create_run(
- fab_id, fab_version, fab_hash, override_config, ConfigsRecord()
+ fab_id, fab_version, fab_hash, override_config, ConfigRecord()
  )
  return run_id

@@ -149,7 +149,7 @@
  """Register a Context for a Run."""
  # Create an empty context for the Run
  context = Context(
- run_id=run_id, node_id=0, node_config={}, state=RecordSet(), run_config={}
+ run_id=run_id, node_id=0, node_config={}, state=RecordDict(), run_config={}
  )

  # Register the context at the LinkState
@@ -160,7 +160,7 @@
  self,
  fab_file: bytes,
  override_config: UserConfig,
- federation_options: ConfigsRecord,
+ federation_options: ConfigRecord,
  ) -> Optional[int]:
  """Start run using the Flower Deployment Engine."""
  run_id = None