flwr-nightly 1.17.0.dev20250318__py3-none-any.whl → 1.17.0.dev20250320__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. flwr/client/app.py +6 -4
  2. flwr/client/clientapp/app.py +2 -2
  3. flwr/client/grpc_client/connection.py +23 -20
  4. flwr/client/message_handler/message_handler.py +27 -27
  5. flwr/client/mod/centraldp_mods.py +7 -7
  6. flwr/client/mod/localdp_mod.py +4 -4
  7. flwr/client/mod/secure_aggregation/secaggplus_mod.py +5 -5
  8. flwr/client/run_info_store.py +2 -2
  9. flwr/common/__init__.py +2 -0
  10. flwr/common/constant.py +2 -0
  11. flwr/common/context.py +4 -4
  12. flwr/common/logger.py +2 -2
  13. flwr/common/message.py +269 -101
  14. flwr/common/record/__init__.py +2 -1
  15. flwr/common/record/configsrecord.py +2 -2
  16. flwr/common/record/metricsrecord.py +1 -1
  17. flwr/common/record/parametersrecord.py +1 -1
  18. flwr/common/record/{recordset.py → recorddict.py} +57 -17
  19. flwr/common/{recordset_compat.py → recorddict_compat.py} +105 -105
  20. flwr/common/serde.py +33 -37
  21. flwr/proto/exec_pb2.py +32 -32
  22. flwr/proto/exec_pb2.pyi +3 -3
  23. flwr/proto/message_pb2.py +12 -12
  24. flwr/proto/message_pb2.pyi +9 -9
  25. flwr/proto/recorddict_pb2.py +70 -0
  26. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +2 -2
  27. flwr/proto/run_pb2.py +32 -32
  28. flwr/proto/run_pb2.pyi +3 -3
  29. flwr/server/__init__.py +2 -0
  30. flwr/server/compat/__init__.py +2 -2
  31. flwr/server/compat/app.py +11 -11
  32. flwr/server/compat/app_utils.py +16 -16
  33. flwr/server/compat/grid_client_proxy.py +38 -38
  34. flwr/server/grid/__init__.py +7 -6
  35. flwr/server/grid/grid.py +46 -17
  36. flwr/server/grid/grpc_grid.py +26 -33
  37. flwr/server/grid/inmemory_grid.py +19 -25
  38. flwr/server/run_serverapp.py +4 -4
  39. flwr/server/server_app.py +37 -11
  40. flwr/server/serverapp/app.py +10 -10
  41. flwr/server/superlink/fleet/vce/vce_api.py +1 -3
  42. flwr/server/superlink/linkstate/in_memory_linkstate.py +29 -4
  43. flwr/server/superlink/linkstate/sqlite_linkstate.py +54 -20
  44. flwr/server/superlink/linkstate/utils.py +77 -17
  45. flwr/server/superlink/serverappio/serverappio_servicer.py +1 -1
  46. flwr/server/typing.py +3 -3
  47. flwr/server/utils/validator.py +4 -4
  48. flwr/server/workflow/default_workflows.py +24 -26
  49. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +23 -23
  50. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  51. flwr/simulation/run_simulation.py +13 -13
  52. flwr/superexec/deployment.py +2 -2
  53. flwr/superexec/simulation.py +2 -2
  54. {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/METADATA +1 -1
  55. {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/RECORD +60 -60
  56. flwr/proto/recordset_pb2.py +0 -70
  57. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  58. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  59. {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/LICENSE +0 -0
  60. {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/WHEEL +0 -0
  61. {flwr_nightly-1.17.0.dev20250318.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/entry_points.txt +0 -0
@@ -27,19 +27,14 @@ from flwr.common.constant import (
27
27
  Status,
28
28
  SubStatus,
29
29
  )
30
+ from flwr.common.message import make_message
30
31
  from flwr.common.typing import RunStatus
31
32
 
32
33
  # pylint: disable=E0611
33
34
  from flwr.proto.message_pb2 import Context as ProtoContext
34
- from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
35
+ from flwr.proto.recorddict_pb2 import ConfigsRecord as ProtoConfigsRecord
35
36
 
36
37
  # pylint: enable=E0611
37
-
38
- NODE_UNAVAILABLE_ERROR_REASON = (
39
- "Error: Node Unavailable - The destination node is currently unavailable. "
40
- "It exceeds the time limit specified in its last ping."
41
- )
42
-
43
38
  VALID_RUN_STATUS_TRANSITIONS = {
44
39
  (Status.PENDING, Status.STARTING),
45
40
  (Status.STARTING, Status.RUNNING),
@@ -60,6 +55,10 @@ MESSAGE_UNAVAILABLE_ERROR_REASON = (
60
55
  REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
61
56
  "Error: Reply Message Unavailable - The reply message has expired."
62
57
  )
58
+ NODE_UNAVAILABLE_ERROR_REASON = (
59
+ "Error: Node Unavailable - The destination node is currently unavailable. "
60
+ "It exceeds twice the time limit specified in its last ping."
61
+ )
63
62
 
64
63
 
65
64
  def generate_rand_int_from_bytes(
@@ -237,7 +236,9 @@ def has_valid_sub_status(status: RunStatus) -> bool:
237
236
  return status.sub_status == ""
238
237
 
239
238
 
240
- def create_message_error_unavailable_res_message(ins_metadata: Metadata) -> Message:
239
+ def create_message_error_unavailable_res_message(
240
+ ins_metadata: Metadata, error_type: str
241
+ ) -> Message:
241
242
  """Generate an error Message that the SuperLink returns carrying the specified
242
243
  error."""
243
244
  current_time = now().timestamp()
@@ -247,22 +248,31 @@ def create_message_error_unavailable_res_message(ins_metadata: Metadata) -> Mess
247
248
  message_id=str(uuid4()),
248
249
  src_node_id=SUPERLINK_NODE_ID,
249
250
  dst_node_id=SUPERLINK_NODE_ID,
250
- reply_to_message=ins_metadata.message_id,
251
+ reply_to_message_id=ins_metadata.message_id,
251
252
  group_id=ins_metadata.group_id,
252
253
  message_type=ins_metadata.message_type,
254
+ created_at=current_time,
253
255
  ttl=ttl,
254
256
  )
255
257
 
256
- return Message(
258
+ return make_message(
257
259
  metadata=metadata,
258
260
  error=Error(
259
- code=ErrorCode.REPLY_MESSAGE_UNAVAILABLE,
260
- reason=REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON,
261
+ code=(
262
+ ErrorCode.REPLY_MESSAGE_UNAVAILABLE
263
+ if error_type == "msg_unavail"
264
+ else ErrorCode.NODE_UNAVAILABLE
265
+ ),
266
+ reason=(
267
+ REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON
268
+ if error_type == "msg_unavail"
269
+ else NODE_UNAVAILABLE_ERROR_REASON
270
+ ),
261
271
  ),
262
272
  )
263
273
 
264
274
 
265
- def create_message_error_unavailable_ins_message(reply_to_message: UUID) -> Message:
275
+ def create_message_error_unavailable_ins_message(reply_to_message_id: UUID) -> Message:
266
276
  """Error to indicate that the enquired Message had expired before reply arrived or
267
277
  that it isn't found."""
268
278
  metadata = Metadata(
@@ -270,13 +280,14 @@ def create_message_error_unavailable_ins_message(reply_to_message: UUID) -> Mess
270
280
  message_id=str(uuid4()),
271
281
  src_node_id=SUPERLINK_NODE_ID,
272
282
  dst_node_id=SUPERLINK_NODE_ID,
273
- reply_to_message=str(reply_to_message),
283
+ reply_to_message_id=str(reply_to_message_id),
274
284
  group_id="", # Unknown
275
285
  message_type=MessageType.SYSTEM,
286
+ created_at=now().timestamp(),
276
287
  ttl=0,
277
288
  )
278
289
 
279
- return Message(
290
+ return make_message(
280
291
  metadata=metadata,
281
292
  error=Error(
282
293
  code=ErrorCode.MESSAGE_UNAVAILABLE,
@@ -364,14 +375,63 @@ def verify_found_message_replies(
364
375
  ret_dict: dict[UUID, Message] = {}
365
376
  current = current_time if current_time else now().timestamp()
366
377
  for message_res in found_message_res_list:
367
- message_ins_id = UUID(message_res.metadata.reply_to_message)
378
+ message_ins_id = UUID(message_res.metadata.reply_to_message_id)
368
379
  if update_set:
369
380
  inquired_message_ids.remove(message_ins_id)
370
381
  # Check if the reply Message has expired
371
382
  if message_ttl_has_expired(message_res.metadata, current):
372
383
  # No need to insert the error Message
373
384
  message_res = create_message_error_unavailable_res_message(
374
- found_message_ins_dict[message_ins_id].metadata
385
+ found_message_ins_dict[message_ins_id].metadata, "msg_unavail"
375
386
  )
376
387
  ret_dict[message_ins_id] = message_res
377
388
  return ret_dict
389
+
390
+
391
+ def check_node_availability_for_in_message(
392
+ inquired_in_message_ids: set[UUID],
393
+ found_in_message_dict: dict[UUID, Message],
394
+ node_id_to_online_until: dict[int, float],
395
+ current_time: Optional[float] = None,
396
+ update_set: bool = True,
397
+ ) -> dict[UUID, Message]:
398
+ """Check node availability for given Message and generate error reply Message if
399
+ unavailable. A Message error indicating node unavailability will be generated for
400
+ each given Message whose destination node is offline or non-existent.
401
+
402
+ Parameters
403
+ ----------
404
+ inquired_in_message_ids : set[UUID]
405
+ Set of Message IDs for which to check destination node availability.
406
+ found_in_message_dict : dict[UUID, Message]
407
+ Dictionary containing all found Message indexed by their IDs.
408
+ node_id_to_online_until : dict[int, float]
409
+ Dictionary mapping node IDs to their online-until timestamps.
410
+ current_time : Optional[float] (default: None)
411
+ The current time to check for expiration. If set to `None`, the current time
412
+ will automatically be set to the current timestamp using `now().timestamp()`.
413
+ update_set : bool (default: True)
414
+ If True, the `inquired_in_message_ids` will be updated to remove invalid ones,
415
+ by default True.
416
+
417
+ Returns
418
+ -------
419
+ dict[UUID, Message]
420
+ A dictionary of error Message indexed by the corresponding Message ID.
421
+ """
422
+ ret_dict = {}
423
+ current = current_time if current_time else now().timestamp()
424
+ for in_message_id in list(inquired_in_message_ids):
425
+ in_message = found_in_message_dict[in_message_id]
426
+ node_id = in_message.metadata.dst_node_id
427
+ online_until = node_id_to_online_until.get(node_id)
428
+ # Generate a reply message containing an error reply
429
+ # if the node is offline or doesn't exist.
430
+ if online_until is None or online_until < current:
431
+ if update_set:
432
+ inquired_in_message_ids.remove(in_message_id)
433
+ reply_message = create_message_error_unavailable_res_message(
434
+ in_message.metadata, "node_unavail"
435
+ )
436
+ ret_dict[in_message_id] = reply_message
437
+ return ret_dict
@@ -206,7 +206,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
206
206
 
207
207
  # Delete the instruction Messages and their replies if found
208
208
  message_ins_ids_to_delete = {
209
- UUID(msg_res.metadata.reply_to_message) for msg_res in messages_res
209
+ UUID(msg_res.metadata.reply_to_message_id) for msg_res in messages_res
210
210
  }
211
211
 
212
212
  state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
flwr/server/typing.py CHANGED
@@ -19,9 +19,9 @@ from typing import Callable
19
19
 
20
20
  from flwr.common import Context
21
21
 
22
- from .grid import Driver
22
+ from .grid import Grid
23
23
  from .serverapp_components import ServerAppComponents
24
24
 
25
- ServerAppCallable = Callable[[Driver, Context], None]
26
- Workflow = Callable[[Driver, Context], None]
25
+ ServerAppCallable = Callable[[Grid, Context], None]
26
+ Workflow = Callable[[Grid, Context], None]
27
27
  ServerFn = Callable[[Context], ServerAppComponents]
@@ -68,8 +68,8 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
68
68
 
69
69
  # Link response to original message
70
70
  if not is_reply_message:
71
- if metadata.reply_to_message != "":
72
- validation_errors.append("`metadata.reply_to_message` MUST not be set.")
71
+ if metadata.reply_to_message_id != "":
72
+ validation_errors.append("`metadata.reply_to_message_id` MUST not be set.")
73
73
  if metadata.src_node_id != SUPERLINK_NODE_ID:
74
74
  validation_errors.append(
75
75
  f"`metadata.src_node_id` is not {SUPERLINK_NODE_ID} (SuperLink node ID)"
@@ -79,8 +79,8 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
79
79
  f"`metadata.dst_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
80
80
  )
81
81
  else:
82
- if metadata.reply_to_message == "":
83
- validation_errors.append("`metadata.reply_to_message` MUST be set.")
82
+ if metadata.reply_to_message_id == "":
83
+ validation_errors.append("`metadata.reply_to_message_id` MUST be set.")
84
84
  if metadata.src_node_id == SUPERLINK_NODE_ID:
85
85
  validation_errors.append(
86
86
  f"`metadata.src_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
@@ -20,7 +20,7 @@ import timeit
20
20
  from logging import INFO, WARN
21
21
  from typing import Optional, Union, cast
22
22
 
23
- import flwr.common.recordset_compat as compat
23
+ import flwr.common.recorddict_compat as compat
24
24
  from flwr.common import (
25
25
  Code,
26
26
  ConfigsRecord,
@@ -36,7 +36,7 @@ from flwr.common.constant import MessageType, MessageTypeLegacy
36
36
  from ..client_proxy import ClientProxy
37
37
  from ..compat.app_utils import start_update_client_manager_thread
38
38
  from ..compat.legacy_context import LegacyContext
39
- from ..grid import Driver
39
+ from ..grid import Grid
40
40
  from ..typing import Workflow
41
41
  from .constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD, Key
42
42
 
@@ -56,7 +56,7 @@ class DefaultWorkflow:
56
56
  self.fit_workflow: Workflow = fit_workflow
57
57
  self.evaluate_workflow: Workflow = evaluate_workflow
58
58
 
59
- def __call__(self, driver: Driver, context: Context) -> None:
59
+ def __call__(self, grid: Grid, context: Context) -> None:
60
60
  """Execute the workflow."""
61
61
  if not isinstance(context, LegacyContext):
62
62
  raise TypeError(
@@ -65,7 +65,7 @@ class DefaultWorkflow:
65
65
 
66
66
  # Start the thread updating nodes
67
67
  thread, f_stop, c_done = start_update_client_manager_thread(
68
- driver, context.client_manager
68
+ grid, context.client_manager
69
69
  )
70
70
 
71
71
  # Wait until the node registration done
@@ -73,7 +73,7 @@ class DefaultWorkflow:
73
73
 
74
74
  # Initialize parameters
75
75
  log(INFO, "[INIT]")
76
- default_init_params_workflow(driver, context)
76
+ default_init_params_workflow(grid, context)
77
77
 
78
78
  # Run federated learning for num_rounds
79
79
  start_time = timeit.default_timer()
@@ -87,13 +87,13 @@ class DefaultWorkflow:
87
87
  cfg[Key.CURRENT_ROUND] = current_round
88
88
 
89
89
  # Fit round
90
- self.fit_workflow(driver, context)
90
+ self.fit_workflow(grid, context)
91
91
 
92
92
  # Centralized evaluation
93
- default_centralized_evaluation_workflow(driver, context)
93
+ default_centralized_evaluation_workflow(grid, context)
94
94
 
95
95
  # Evaluate round
96
- self.evaluate_workflow(driver, context)
96
+ self.evaluate_workflow(grid, context)
97
97
 
98
98
  # Bookkeeping and log results
99
99
  end_time = timeit.default_timer()
@@ -119,7 +119,7 @@ class DefaultWorkflow:
119
119
  thread.join()
120
120
 
121
121
 
122
- def default_init_params_workflow(driver: Driver, context: Context) -> None:
122
+ def default_init_params_workflow(grid: Grid, context: Context) -> None:
123
123
  """Execute the default workflow for parameters initialization."""
124
124
  if not isinstance(context, LegacyContext):
125
125
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
@@ -137,10 +137,10 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
137
137
  log(INFO, "Requesting initial parameters from one random client")
138
138
  random_client = context.client_manager.sample(1)[0]
139
139
  # Send GetParametersIns and get the response
140
- content = compat.getparametersins_to_recordset(GetParametersIns({}))
141
- messages = driver.send_and_receive(
140
+ content = compat.getparametersins_to_recorddict(GetParametersIns({}))
141
+ messages = grid.send_and_receive(
142
142
  [
143
- driver.create_message(
143
+ grid.create_message(
144
144
  content=content,
145
145
  message_type=MessageTypeLegacy.GET_PARAMETERS,
146
146
  dst_node_id=random_client.node_id,
@@ -152,7 +152,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
152
152
 
153
153
  if (
154
154
  msg.has_content()
155
- and compat._extract_status_from_recordset( # pylint: disable=W0212
155
+ and compat._extract_status_from_recorddict( # pylint: disable=W0212
156
156
  "getparametersres", msg.content
157
157
  ).code
158
158
  == Code.OK
@@ -186,7 +186,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
186
186
  log(INFO, "Evaluation returned no results (`None`)")
187
187
 
188
188
 
189
- def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None:
189
+ def default_centralized_evaluation_workflow(_: Grid, context: Context) -> None:
190
190
  """Execute the default workflow for centralized evaluation."""
191
191
  if not isinstance(context, LegacyContext):
192
192
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
@@ -218,9 +218,7 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
218
218
  )
219
219
 
220
220
 
221
- def default_fit_workflow( # pylint: disable=R0914
222
- driver: Driver, context: Context
223
- ) -> None:
221
+ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disable=R0914
224
222
  """Execute the default workflow for a single fit round."""
225
223
  if not isinstance(context, LegacyContext):
226
224
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
@@ -255,8 +253,8 @@ def default_fit_workflow( # pylint: disable=R0914
255
253
 
256
254
  # Build out messages
257
255
  out_messages = [
258
- driver.create_message(
259
- content=compat.fitins_to_recordset(fitins, True),
256
+ grid.create_message(
257
+ content=compat.fitins_to_recorddict(fitins, True),
260
258
  message_type=MessageType.TRAIN,
261
259
  dst_node_id=proxy.node_id,
262
260
  group_id=str(current_round),
@@ -266,7 +264,7 @@ def default_fit_workflow( # pylint: disable=R0914
266
264
 
267
265
  # Send instructions to clients and
268
266
  # collect `fit` results from all clients participating in this round
269
- messages = list(driver.send_and_receive(out_messages))
267
+ messages = list(grid.send_and_receive(out_messages))
270
268
  del out_messages
271
269
  num_failures = len([msg for msg in messages if msg.has_error()])
272
270
 
@@ -284,7 +282,7 @@ def default_fit_workflow( # pylint: disable=R0914
284
282
  for msg in messages:
285
283
  if msg.has_content():
286
284
  proxy = node_id_to_proxy[msg.metadata.src_node_id]
287
- fitres = compat.recordset_to_fitres(msg.content, False)
285
+ fitres = compat.recorddict_to_fitres(msg.content, False)
288
286
  if fitres.status.code == Code.OK:
289
287
  results.append((proxy, fitres))
290
288
  else:
@@ -307,7 +305,7 @@ def default_fit_workflow( # pylint: disable=R0914
307
305
 
308
306
 
309
307
  # pylint: disable-next=R0914
310
- def default_evaluate_workflow(driver: Driver, context: Context) -> None:
308
+ def default_evaluate_workflow(grid: Grid, context: Context) -> None:
311
309
  """Execute the default workflow for a single evaluate round."""
312
310
  if not isinstance(context, LegacyContext):
313
311
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
@@ -341,8 +339,8 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
341
339
 
342
340
  # Build out messages
343
341
  out_messages = [
344
- driver.create_message(
345
- content=compat.evaluateins_to_recordset(evalins, True),
342
+ grid.create_message(
343
+ content=compat.evaluateins_to_recorddict(evalins, True),
346
344
  message_type=MessageType.EVALUATE,
347
345
  dst_node_id=proxy.node_id,
348
346
  group_id=str(current_round),
@@ -352,7 +350,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
352
350
 
353
351
  # Send instructions to clients and
354
352
  # collect `evaluate` results from all clients participating in this round
355
- messages = list(driver.send_and_receive(out_messages))
353
+ messages = list(grid.send_and_receive(out_messages))
356
354
  del out_messages
357
355
  num_failures = len([msg for msg in messages if msg.has_error()])
358
356
 
@@ -370,7 +368,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
370
368
  for msg in messages:
371
369
  if msg.has_content():
372
370
  proxy = node_id_to_proxy[msg.metadata.src_node_id]
373
- evalres = compat.recordset_to_evaluateres(msg.content)
371
+ evalres = compat.recorddict_to_evaluateres(msg.content)
374
372
  if evalres.status.code == Code.OK:
375
373
  results.append((proxy, evalres))
376
374
  else:
@@ -20,7 +20,7 @@ from dataclasses import dataclass, field
20
20
  from logging import DEBUG, ERROR, INFO, WARN
21
21
  from typing import Optional, Union, cast
22
22
 
23
- import flwr.common.recordset_compat as compat
23
+ import flwr.common.recorddict_compat as compat
24
24
  from flwr.common import (
25
25
  ConfigsRecord,
26
26
  Context,
@@ -28,7 +28,7 @@ from flwr.common import (
28
28
  Message,
29
29
  MessageType,
30
30
  NDArrays,
31
- RecordSet,
31
+ RecordDict,
32
32
  bytes_to_ndarray,
33
33
  log,
34
34
  ndarrays_to_parameters,
@@ -55,7 +55,7 @@ from flwr.common.secure_aggregation.secaggplus_constants import (
55
55
  from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
56
56
  from flwr.server.client_proxy import ClientProxy
57
57
  from flwr.server.compat.legacy_context import LegacyContext
58
- from flwr.server.grid import Driver
58
+ from flwr.server.grid import Grid
59
59
 
60
60
  from ..constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD
61
61
  from ..constant import Key as WorkflowKey
@@ -66,7 +66,7 @@ class WorkflowState: # pylint: disable=R0902
66
66
  """The state of the SecAgg+ protocol."""
67
67
 
68
68
  nid_to_proxies: dict[int, ClientProxy] = field(default_factory=dict)
69
- nid_to_fitins: dict[int, RecordSet] = field(default_factory=dict)
69
+ nid_to_fitins: dict[int, RecordDict] = field(default_factory=dict)
70
70
  sampled_node_ids: set[int] = field(default_factory=set)
71
71
  active_node_ids: set[int] = field(default_factory=set)
72
72
  num_shares: int = 0
@@ -186,7 +186,7 @@ class SecAggPlusWorkflow:
186
186
 
187
187
  self._check_init_params()
188
188
 
189
- def __call__(self, driver: Driver, context: Context) -> None:
189
+ def __call__(self, grid: Grid, context: Context) -> None:
190
190
  """Run the SecAgg+ protocol."""
191
191
  if not isinstance(context, LegacyContext):
192
192
  raise TypeError(
@@ -202,7 +202,7 @@ class SecAggPlusWorkflow:
202
202
  )
203
203
  log(INFO, "Secure aggregation commencing.")
204
204
  for step in steps:
205
- if not step(driver, context, state):
205
+ if not step(grid, context, state):
206
206
  log(INFO, "Secure aggregation halted.")
207
207
  return
208
208
  log(INFO, "Secure aggregation completed.")
@@ -279,7 +279,7 @@ class SecAggPlusWorkflow:
279
279
  return True
280
280
 
281
281
  def setup_stage( # pylint: disable=R0912, R0914, R0915
282
- self, driver: Driver, context: LegacyContext, state: WorkflowState
282
+ self, grid: Grid, context: LegacyContext, state: WorkflowState
283
283
  ) -> bool:
284
284
  """Execute the 'setup' stage."""
285
285
  # Obtain fit instructions
@@ -303,7 +303,7 @@ class SecAggPlusWorkflow:
303
303
  )
304
304
 
305
305
  state.nid_to_fitins = {
306
- proxy.node_id: compat.fitins_to_recordset(fitins, True)
306
+ proxy.node_id: compat.fitins_to_recorddict(fitins, True)
307
307
  for proxy, fitins in proxy_fitins_lst
308
308
  }
309
309
  state.nid_to_proxies = {proxy.node_id: proxy for proxy, _ in proxy_fitins_lst}
@@ -367,10 +367,10 @@ class SecAggPlusWorkflow:
367
367
 
368
368
  # Send setup configuration to clients
369
369
  cfgs_record = ConfigsRecord(sa_params_dict) # type: ignore
370
- content = RecordSet({RECORD_KEY_CONFIGS: cfgs_record})
370
+ content = RecordDict({RECORD_KEY_CONFIGS: cfgs_record})
371
371
 
372
372
  def make(nid: int) -> Message:
373
- return driver.create_message(
373
+ return grid.create_message(
374
374
  content=content,
375
375
  message_type=MessageType.TRAIN,
376
376
  dst_node_id=nid,
@@ -382,7 +382,7 @@ class SecAggPlusWorkflow:
382
382
  "[Stage 0] Sending configurations to %s clients.",
383
383
  len(state.active_node_ids),
384
384
  )
385
- msgs = driver.send_and_receive(
385
+ msgs = grid.send_and_receive(
386
386
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
387
387
  )
388
388
  state.active_node_ids = {
@@ -406,7 +406,7 @@ class SecAggPlusWorkflow:
406
406
  return self._check_threshold(state)
407
407
 
408
408
  def share_keys_stage( # pylint: disable=R0914
409
- self, driver: Driver, context: LegacyContext, state: WorkflowState
409
+ self, grid: Grid, context: LegacyContext, state: WorkflowState
410
410
  ) -> bool:
411
411
  """Execute the 'share keys' stage."""
412
412
  cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
@@ -417,8 +417,8 @@ class SecAggPlusWorkflow:
417
417
  {str(nid): state.nid_to_publickeys[nid] for nid in neighbours}
418
418
  )
419
419
  cfgs_record[Key.STAGE] = Stage.SHARE_KEYS
420
- content = RecordSet({RECORD_KEY_CONFIGS: cfgs_record})
421
- return driver.create_message(
420
+ content = RecordDict({RECORD_KEY_CONFIGS: cfgs_record})
421
+ return grid.create_message(
422
422
  content=content,
423
423
  message_type=MessageType.TRAIN,
424
424
  dst_node_id=nid,
@@ -431,7 +431,7 @@ class SecAggPlusWorkflow:
431
431
  "[Stage 1] Forwarding public keys to %s clients.",
432
432
  len(state.active_node_ids),
433
433
  )
434
- msgs = driver.send_and_receive(
434
+ msgs = grid.send_and_receive(
435
435
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
436
436
  )
437
437
  state.active_node_ids = {
@@ -476,7 +476,7 @@ class SecAggPlusWorkflow:
476
476
  return self._check_threshold(state)
477
477
 
478
478
  def collect_masked_vectors_stage(
479
- self, driver: Driver, context: LegacyContext, state: WorkflowState
479
+ self, grid: Grid, context: LegacyContext, state: WorkflowState
480
480
  ) -> bool:
481
481
  """Execute the 'collect masked vectors' stage."""
482
482
  cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
@@ -491,7 +491,7 @@ class SecAggPlusWorkflow:
491
491
  cfgs_record = ConfigsRecord(cfgs_dict) # type: ignore
492
492
  content = state.nid_to_fitins[nid]
493
493
  content.configs_records[RECORD_KEY_CONFIGS] = cfgs_record
494
- return driver.create_message(
494
+ return grid.create_message(
495
495
  content=content,
496
496
  message_type=MessageType.TRAIN,
497
497
  dst_node_id=nid,
@@ -503,7 +503,7 @@ class SecAggPlusWorkflow:
503
503
  "[Stage 2] Forwarding encrypted key shares to %s clients.",
504
504
  len(state.active_node_ids),
505
505
  )
506
- msgs = driver.send_and_receive(
506
+ msgs = grid.send_and_receive(
507
507
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
508
508
  )
509
509
  state.active_node_ids = {
@@ -540,14 +540,14 @@ class SecAggPlusWorkflow:
540
540
  if msg.has_error():
541
541
  state.failures.append(Exception(msg.error))
542
542
  continue
543
- fitres = compat.recordset_to_fitres(msg.content, True)
543
+ fitres = compat.recorddict_to_fitres(msg.content, True)
544
544
  proxy = state.nid_to_proxies[msg.metadata.src_node_id]
545
545
  state.legacy_results.append((proxy, fitres))
546
546
 
547
547
  return self._check_threshold(state)
548
548
 
549
549
  def unmask_stage( # pylint: disable=R0912, R0914, R0915
550
- self, driver: Driver, context: LegacyContext, state: WorkflowState
550
+ self, grid: Grid, context: LegacyContext, state: WorkflowState
551
551
  ) -> bool:
552
552
  """Execute the 'unmask' stage."""
553
553
  cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
@@ -566,8 +566,8 @@ class SecAggPlusWorkflow:
566
566
  Key.DEAD_NODE_ID_LIST: list(neighbours & dead_nids),
567
567
  }
568
568
  cfgs_record = ConfigsRecord(cfgs_dict) # type: ignore
569
- content = RecordSet({RECORD_KEY_CONFIGS: cfgs_record})
570
- return driver.create_message(
569
+ content = RecordDict({RECORD_KEY_CONFIGS: cfgs_record})
570
+ return grid.create_message(
571
571
  content=content,
572
572
  message_type=MessageType.TRAIN,
573
573
  dst_node_id=nid,
@@ -579,7 +579,7 @@ class SecAggPlusWorkflow:
579
579
  "[Stage 3] Requesting key shares from %s clients to remove masks.",
580
580
  len(state.active_node_ids),
581
581
  )
582
- msgs = driver.send_and_receive(
582
+ msgs = grid.send_and_receive(
583
583
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
584
584
  )
585
585
  state.active_node_ids = {