flwr-1.16.0-py3-none-any.whl → flwr-1.17.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
  2. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  3. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  4. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  11. flwr/cli/run/run.py +5 -9
  12. flwr/client/app.py +6 -4
  13. flwr/client/client_app.py +162 -99
  14. flwr/client/clientapp/app.py +2 -2
  15. flwr/client/grpc_client/connection.py +24 -21
  16. flwr/client/message_handler/message_handler.py +27 -27
  17. flwr/client/mod/__init__.py +2 -2
  18. flwr/client/mod/centraldp_mods.py +7 -7
  19. flwr/client/mod/comms_mods.py +16 -22
  20. flwr/client/mod/localdp_mod.py +4 -4
  21. flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
  22. flwr/client/run_info_store.py +2 -2
  23. flwr/common/__init__.py +12 -4
  24. flwr/common/config.py +4 -4
  25. flwr/common/constant.py +6 -6
  26. flwr/common/context.py +4 -4
  27. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  28. flwr/common/logger.py +2 -2
  29. flwr/common/message.py +327 -102
  30. flwr/common/record/__init__.py +8 -4
  31. flwr/common/record/arrayrecord.py +626 -0
  32. flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
  33. flwr/common/record/conversion_utils.py +1 -1
  34. flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
  35. flwr/common/record/recorddict.py +288 -0
  36. flwr/common/recorddict_compat.py +410 -0
  37. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  38. flwr/common/serde.py +66 -71
  39. flwr/common/typing.py +8 -8
  40. flwr/proto/exec_pb2.py +3 -3
  41. flwr/proto/exec_pb2.pyi +3 -3
  42. flwr/proto/message_pb2.py +12 -12
  43. flwr/proto/message_pb2.pyi +9 -9
  44. flwr/proto/recorddict_pb2.py +70 -0
  45. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
  46. flwr/proto/run_pb2.py +31 -31
  47. flwr/proto/run_pb2.pyi +3 -3
  48. flwr/server/__init__.py +3 -1
  49. flwr/server/app.py +56 -1
  50. flwr/server/compat/__init__.py +2 -2
  51. flwr/server/compat/app.py +11 -11
  52. flwr/server/compat/app_utils.py +16 -16
  53. flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +39 -39
  54. flwr/server/fleet_event_log_interceptor.py +94 -0
  55. flwr/server/{driver → grid}/__init__.py +8 -7
  56. flwr/server/{driver/driver.py → grid/grid.py} +47 -18
  57. flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +87 -64
  58. flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +24 -34
  59. flwr/server/run_serverapp.py +4 -4
  60. flwr/server/server_app.py +38 -18
  61. flwr/server/serverapp/app.py +10 -10
  62. flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
  63. flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
  64. flwr/server/superlink/fleet/vce/vce_api.py +1 -3
  65. flwr/server/superlink/linkstate/in_memory_linkstate.py +33 -8
  66. flwr/server/superlink/linkstate/linkstate.py +4 -4
  67. flwr/server/superlink/linkstate/sqlite_linkstate.py +61 -27
  68. flwr/server/superlink/linkstate/utils.py +93 -27
  69. flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
  70. flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
  71. flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +4 -4
  72. flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
  73. flwr/server/typing.py +3 -3
  74. flwr/server/utils/validator.py +4 -4
  75. flwr/server/workflow/default_workflows.py +48 -57
  76. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
  77. flwr/simulation/app.py +2 -2
  78. flwr/simulation/ray_transport/ray_actor.py +4 -2
  79. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  80. flwr/simulation/run_simulation.py +15 -15
  81. flwr/superexec/deployment.py +4 -4
  82. flwr/superexec/exec_event_log_interceptor.py +135 -0
  83. flwr/superexec/exec_grpc.py +10 -4
  84. flwr/superexec/exec_servicer.py +2 -2
  85. flwr/superexec/exec_user_auth_interceptor.py +18 -2
  86. flwr/superexec/executor.py +3 -3
  87. flwr/superexec/simulation.py +3 -3
  88. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/METADATA +2 -2
  89. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/RECORD +94 -92
  90. flwr/common/record/parametersrecord.py +0 -339
  91. flwr/common/record/recordset.py +0 -209
  92. flwr/common/recordset_compat.py +0 -418
  93. flwr/proto/recordset_pb2.py +0 -70
  94. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  95. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  96. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
  97. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
  98. {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -0
flwr/server/workflow/default_workflows.py CHANGED
@@ -20,15 +20,16 @@ import timeit
20
20
  from logging import INFO, WARN
21
21
  from typing import Optional, Union, cast
22
22
 
23
- import flwr.common.recordset_compat as compat
23
+ import flwr.common.recorddict_compat as compat
24
24
  from flwr.common import (
25
+ ArrayRecord,
25
26
  Code,
26
- ConfigsRecord,
27
+ ConfigRecord,
27
28
  Context,
28
29
  EvaluateRes,
29
30
  FitRes,
30
31
  GetParametersIns,
31
- ParametersRecord,
32
+ Message,
32
33
  log,
33
34
  )
34
35
  from flwr.common.constant import MessageType, MessageTypeLegacy
@@ -36,7 +37,7 @@ from flwr.common.constant import MessageType, MessageTypeLegacy
36
37
  from ..client_proxy import ClientProxy
37
38
  from ..compat.app_utils import start_update_client_manager_thread
38
39
  from ..compat.legacy_context import LegacyContext
39
- from ..driver import Driver
40
+ from ..grid import Grid
40
41
  from ..typing import Workflow
41
42
  from .constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD, Key
42
43
 
@@ -56,7 +57,7 @@ class DefaultWorkflow:
56
57
  self.fit_workflow: Workflow = fit_workflow
57
58
  self.evaluate_workflow: Workflow = evaluate_workflow
58
59
 
59
- def __call__(self, driver: Driver, context: Context) -> None:
60
+ def __call__(self, grid: Grid, context: Context) -> None:
60
61
  """Execute the workflow."""
61
62
  if not isinstance(context, LegacyContext):
62
63
  raise TypeError(
@@ -65,7 +66,7 @@ class DefaultWorkflow:
65
66
 
66
67
  # Start the thread updating nodes
67
68
  thread, f_stop, c_done = start_update_client_manager_thread(
68
- driver, context.client_manager
69
+ grid, context.client_manager
69
70
  )
70
71
 
71
72
  # Wait until the node registration done
@@ -73,13 +74,13 @@ class DefaultWorkflow:
73
74
 
74
75
  # Initialize parameters
75
76
  log(INFO, "[INIT]")
76
- default_init_params_workflow(driver, context)
77
+ default_init_params_workflow(grid, context)
77
78
 
78
79
  # Run federated learning for num_rounds
79
80
  start_time = timeit.default_timer()
80
- cfg = ConfigsRecord()
81
+ cfg = ConfigRecord()
81
82
  cfg[Key.START_TIME] = start_time
82
- context.state.configs_records[MAIN_CONFIGS_RECORD] = cfg
83
+ context.state.config_records[MAIN_CONFIGS_RECORD] = cfg
83
84
 
84
85
  for current_round in range(1, context.config.num_rounds + 1):
85
86
  log(INFO, "")
@@ -87,13 +88,13 @@ class DefaultWorkflow:
87
88
  cfg[Key.CURRENT_ROUND] = current_round
88
89
 
89
90
  # Fit round
90
- self.fit_workflow(driver, context)
91
+ self.fit_workflow(grid, context)
91
92
 
92
93
  # Centralized evaluation
93
- default_centralized_evaluation_workflow(driver, context)
94
+ default_centralized_evaluation_workflow(grid, context)
94
95
 
95
96
  # Evaluate round
96
- self.evaluate_workflow(driver, context)
97
+ self.evaluate_workflow(grid, context)
97
98
 
98
99
  # Bookkeeping and log results
99
100
  end_time = timeit.default_timer()
@@ -119,7 +120,7 @@ class DefaultWorkflow:
119
120
  thread.join()
120
121
 
121
122
 
122
- def default_init_params_workflow(driver: Driver, context: Context) -> None:
123
+ def default_init_params_workflow(grid: Grid, context: Context) -> None:
123
124
  """Execute the default workflow for parameters initialization."""
124
125
  if not isinstance(context, LegacyContext):
125
126
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
@@ -129,21 +130,19 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
129
130
  )
130
131
  if parameters is not None:
131
132
  log(INFO, "Using initial global parameters provided by strategy")
132
- paramsrecord = compat.parameters_to_parametersrecord(
133
- parameters, keep_input=True
134
- )
133
+ arr_record = compat.parameters_to_arrayrecord(parameters, keep_input=True)
135
134
  else:
136
135
  # Get initial parameters from one of the clients
137
136
  log(INFO, "Requesting initial parameters from one random client")
138
137
  random_client = context.client_manager.sample(1)[0]
139
138
  # Send GetParametersIns and get the response
140
- content = compat.getparametersins_to_recordset(GetParametersIns({}))
141
- messages = driver.send_and_receive(
139
+ content = compat.getparametersins_to_recorddict(GetParametersIns({}))
140
+ messages = grid.send_and_receive(
142
141
  [
143
- driver.create_message(
142
+ Message(
144
143
  content=content,
145
- message_type=MessageTypeLegacy.GET_PARAMETERS,
146
144
  dst_node_id=random_client.node_id,
145
+ message_type=MessageTypeLegacy.GET_PARAMETERS,
147
146
  group_id="0",
148
147
  )
149
148
  ]
@@ -152,26 +151,26 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
152
151
 
153
152
  if (
154
153
  msg.has_content()
155
- and compat._extract_status_from_recordset( # pylint: disable=W0212
154
+ and compat._extract_status_from_recorddict( # pylint: disable=W0212
156
155
  "getparametersres", msg.content
157
156
  ).code
158
157
  == Code.OK
159
158
  ):
160
159
  log(INFO, "Received initial parameters from one random client")
161
- paramsrecord = next(iter(msg.content.parameters_records.values()))
160
+ arr_record = next(iter(msg.content.array_records.values()))
162
161
  else:
163
162
  log(
164
163
  WARN,
165
164
  "Failed to receive initial parameters from the client."
166
165
  " Empty initial parameters will be used.",
167
166
  )
168
- paramsrecord = ParametersRecord()
167
+ arr_record = ArrayRecord()
169
168
 
170
- context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
169
+ context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
171
170
 
172
171
  # Evaluate initial parameters
173
172
  log(INFO, "Starting evaluation of initial global parameters")
174
- parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True)
173
+ parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
175
174
  res = context.strategy.evaluate(0, parameters=parameters)
176
175
  if res is not None:
177
176
  log(
@@ -186,19 +185,19 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
186
185
  log(INFO, "Evaluation returned no results (`None`)")
187
186
 
188
187
 
189
- def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None:
188
+ def default_centralized_evaluation_workflow(_: Grid, context: Context) -> None:
190
189
  """Execute the default workflow for centralized evaluation."""
191
190
  if not isinstance(context, LegacyContext):
192
191
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
193
192
 
194
193
  # Retrieve current_round and start_time from the context
195
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
194
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
196
195
  current_round = cast(int, cfg[Key.CURRENT_ROUND])
197
196
  start_time = cast(float, cfg[Key.START_TIME])
198
197
 
199
198
  # Centralized evaluation
200
- parameters = compat.parametersrecord_to_parameters(
201
- record=context.state.parameters_records[MAIN_PARAMS_RECORD],
199
+ parameters = compat.arrayrecord_to_parameters(
200
+ record=context.state.array_records[MAIN_PARAMS_RECORD],
202
201
  keep_input=True,
203
202
  )
204
203
  res_cen = context.strategy.evaluate(current_round, parameters=parameters)
@@ -218,20 +217,16 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
218
217
  )
219
218
 
220
219
 
221
- def default_fit_workflow( # pylint: disable=R0914
222
- driver: Driver, context: Context
223
- ) -> None:
220
+ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disable=R0914
224
221
  """Execute the default workflow for a single fit round."""
225
222
  if not isinstance(context, LegacyContext):
226
223
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
227
224
 
228
225
  # Get current_round and parameters
229
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
226
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
230
227
  current_round = cast(int, cfg[Key.CURRENT_ROUND])
231
- parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
232
- parameters = compat.parametersrecord_to_parameters(
233
- parametersrecord, keep_input=True
234
- )
228
+ arr_record = context.state.array_records[MAIN_PARAMS_RECORD]
229
+ parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
235
230
 
236
231
  # Get clients and their respective instructions from strategy
237
232
  client_instructions = context.strategy.configure_fit(
@@ -255,10 +250,10 @@ def default_fit_workflow( # pylint: disable=R0914
255
250
 
256
251
  # Build out messages
257
252
  out_messages = [
258
- driver.create_message(
259
- content=compat.fitins_to_recordset(fitins, True),
260
- message_type=MessageType.TRAIN,
253
+ Message(
254
+ content=compat.fitins_to_recorddict(fitins, True),
261
255
  dst_node_id=proxy.node_id,
256
+ message_type=MessageType.TRAIN,
262
257
  group_id=str(current_round),
263
258
  )
264
259
  for proxy, fitins in client_instructions
@@ -266,7 +261,7 @@ def default_fit_workflow( # pylint: disable=R0914
266
261
 
267
262
  # Send instructions to clients and
268
263
  # collect `fit` results from all clients participating in this round
269
- messages = list(driver.send_and_receive(out_messages))
264
+ messages = list(grid.send_and_receive(out_messages))
270
265
  del out_messages
271
266
  num_failures = len([msg for msg in messages if msg.has_error()])
272
267
 
@@ -284,7 +279,7 @@ def default_fit_workflow( # pylint: disable=R0914
284
279
  for msg in messages:
285
280
  if msg.has_content():
286
281
  proxy = node_id_to_proxy[msg.metadata.src_node_id]
287
- fitres = compat.recordset_to_fitres(msg.content, False)
282
+ fitres = compat.recorddict_to_fitres(msg.content, False)
288
283
  if fitres.status.code == Code.OK:
289
284
  results.append((proxy, fitres))
290
285
  else:
@@ -297,28 +292,24 @@ def default_fit_workflow( # pylint: disable=R0914
297
292
 
298
293
  # Update the parameters and write history
299
294
  if parameters_aggregated:
300
- paramsrecord = compat.parameters_to_parametersrecord(
301
- parameters_aggregated, True
302
- )
303
- context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
295
+ arr_record = compat.parameters_to_arrayrecord(parameters_aggregated, True)
296
+ context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
304
297
  context.history.add_metrics_distributed_fit(
305
298
  server_round=current_round, metrics=metrics_aggregated
306
299
  )
307
300
 
308
301
 
309
302
  # pylint: disable-next=R0914
310
- def default_evaluate_workflow(driver: Driver, context: Context) -> None:
303
+ def default_evaluate_workflow(grid: Grid, context: Context) -> None:
311
304
  """Execute the default workflow for a single evaluate round."""
312
305
  if not isinstance(context, LegacyContext):
313
306
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
314
307
 
315
308
  # Get current_round and parameters
316
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
309
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
317
310
  current_round = cast(int, cfg[Key.CURRENT_ROUND])
318
- parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
319
- parameters = compat.parametersrecord_to_parameters(
320
- parametersrecord, keep_input=True
321
- )
311
+ arr_record = context.state.array_records[MAIN_PARAMS_RECORD]
312
+ parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
322
313
 
323
314
  # Get clients and their respective instructions from strategy
324
315
  client_instructions = context.strategy.configure_evaluate(
@@ -341,10 +332,10 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
341
332
 
342
333
  # Build out messages
343
334
  out_messages = [
344
- driver.create_message(
345
- content=compat.evaluateins_to_recordset(evalins, True),
346
- message_type=MessageType.EVALUATE,
335
+ Message(
336
+ content=compat.evaluateins_to_recorddict(evalins, True),
347
337
  dst_node_id=proxy.node_id,
338
+ message_type=MessageType.EVALUATE,
348
339
  group_id=str(current_round),
349
340
  )
350
341
  for proxy, evalins in client_instructions
@@ -352,7 +343,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
352
343
 
353
344
  # Send instructions to clients and
354
345
  # collect `evaluate` results from all clients participating in this round
355
- messages = list(driver.send_and_receive(out_messages))
346
+ messages = list(grid.send_and_receive(out_messages))
356
347
  del out_messages
357
348
  num_failures = len([msg for msg in messages if msg.has_error()])
358
349
 
@@ -370,7 +361,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
370
361
  for msg in messages:
371
362
  if msg.has_content():
372
363
  proxy = node_id_to_proxy[msg.metadata.src_node_id]
373
- evalres = compat.recordset_to_evaluateres(msg.content)
364
+ evalres = compat.recorddict_to_evaluateres(msg.content)
374
365
  if evalres.status.code == Code.OK:
375
366
  results.append((proxy, evalres))
376
367
  else:
flwr/server/workflow/secure_aggregation/secaggplus_workflow.py CHANGED
@@ -20,15 +20,15 @@ from dataclasses import dataclass, field
20
20
  from logging import DEBUG, ERROR, INFO, WARN
21
21
  from typing import Optional, Union, cast
22
22
 
23
- import flwr.common.recordset_compat as compat
23
+ import flwr.common.recorddict_compat as compat
24
24
  from flwr.common import (
25
- ConfigsRecord,
25
+ ConfigRecord,
26
26
  Context,
27
27
  FitRes,
28
28
  Message,
29
29
  MessageType,
30
30
  NDArrays,
31
- RecordSet,
31
+ RecordDict,
32
32
  bytes_to_ndarray,
33
33
  log,
34
34
  ndarrays_to_parameters,
@@ -55,7 +55,7 @@ from flwr.common.secure_aggregation.secaggplus_constants import (
55
55
  from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
56
56
  from flwr.server.client_proxy import ClientProxy
57
57
  from flwr.server.compat.legacy_context import LegacyContext
58
- from flwr.server.driver import Driver
58
+ from flwr.server.grid import Grid
59
59
 
60
60
  from ..constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD
61
61
  from ..constant import Key as WorkflowKey
@@ -66,7 +66,7 @@ class WorkflowState: # pylint: disable=R0902
66
66
  """The state of the SecAgg+ protocol."""
67
67
 
68
68
  nid_to_proxies: dict[int, ClientProxy] = field(default_factory=dict)
69
- nid_to_fitins: dict[int, RecordSet] = field(default_factory=dict)
69
+ nid_to_fitins: dict[int, RecordDict] = field(default_factory=dict)
70
70
  sampled_node_ids: set[int] = field(default_factory=set)
71
71
  active_node_ids: set[int] = field(default_factory=set)
72
72
  num_shares: int = 0
@@ -186,7 +186,7 @@ class SecAggPlusWorkflow:
186
186
 
187
187
  self._check_init_params()
188
188
 
189
- def __call__(self, driver: Driver, context: Context) -> None:
189
+ def __call__(self, grid: Grid, context: Context) -> None:
190
190
  """Run the SecAgg+ protocol."""
191
191
  if not isinstance(context, LegacyContext):
192
192
  raise TypeError(
@@ -202,7 +202,7 @@ class SecAggPlusWorkflow:
202
202
  )
203
203
  log(INFO, "Secure aggregation commencing.")
204
204
  for step in steps:
205
- if not step(driver, context, state):
205
+ if not step(grid, context, state):
206
206
  log(INFO, "Secure aggregation halted.")
207
207
  return
208
208
  log(INFO, "Secure aggregation completed.")
@@ -279,14 +279,14 @@ class SecAggPlusWorkflow:
279
279
  return True
280
280
 
281
281
  def setup_stage( # pylint: disable=R0912, R0914, R0915
282
- self, driver: Driver, context: LegacyContext, state: WorkflowState
282
+ self, grid: Grid, context: LegacyContext, state: WorkflowState
283
283
  ) -> bool:
284
284
  """Execute the 'setup' stage."""
285
285
  # Obtain fit instructions
286
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
286
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
287
287
  current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
288
- parameters = compat.parametersrecord_to_parameters(
289
- context.state.parameters_records[MAIN_PARAMS_RECORD],
288
+ parameters = compat.arrayrecord_to_parameters(
289
+ context.state.array_records[MAIN_PARAMS_RECORD],
290
290
  keep_input=True,
291
291
  )
292
292
  proxy_fitins_lst = context.strategy.configure_fit(
@@ -303,7 +303,7 @@ class SecAggPlusWorkflow:
303
303
  )
304
304
 
305
305
  state.nid_to_fitins = {
306
- proxy.node_id: compat.fitins_to_recordset(fitins, True)
306
+ proxy.node_id: compat.fitins_to_recorddict(fitins, True)
307
307
  for proxy, fitins in proxy_fitins_lst
308
308
  }
309
309
  state.nid_to_proxies = {proxy.node_id: proxy for proxy, _ in proxy_fitins_lst}
@@ -366,14 +366,14 @@ class SecAggPlusWorkflow:
366
366
  state.sampled_node_ids = state.active_node_ids
367
367
 
368
368
  # Send setup configuration to clients
369
- cfgs_record = ConfigsRecord(sa_params_dict) # type: ignore
370
- content = RecordSet({RECORD_KEY_CONFIGS: cfgs_record})
369
+ cfg_record = ConfigRecord(sa_params_dict) # type: ignore
370
+ content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})
371
371
 
372
372
  def make(nid: int) -> Message:
373
- return driver.create_message(
373
+ return Message(
374
374
  content=content,
375
- message_type=MessageType.TRAIN,
376
375
  dst_node_id=nid,
376
+ message_type=MessageType.TRAIN,
377
377
  group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
378
378
  )
379
379
 
@@ -382,7 +382,7 @@ class SecAggPlusWorkflow:
382
382
  "[Stage 0] Sending configurations to %s clients.",
383
383
  len(state.active_node_ids),
384
384
  )
385
- msgs = driver.send_and_receive(
385
+ msgs = grid.send_and_receive(
386
386
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
387
387
  )
388
388
  state.active_node_ids = {
@@ -398,7 +398,7 @@ class SecAggPlusWorkflow:
398
398
  if msg.has_error():
399
399
  state.failures.append(Exception(msg.error))
400
400
  continue
401
- key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
401
+ key_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
402
402
  node_id = msg.metadata.src_node_id
403
403
  pk1, pk2 = key_dict[Key.PUBLIC_KEY_1], key_dict[Key.PUBLIC_KEY_2]
404
404
  state.nid_to_publickeys[node_id] = [cast(bytes, pk1), cast(bytes, pk2)]
@@ -406,22 +406,22 @@ class SecAggPlusWorkflow:
406
406
  return self._check_threshold(state)
407
407
 
408
408
  def share_keys_stage( # pylint: disable=R0914
409
- self, driver: Driver, context: LegacyContext, state: WorkflowState
409
+ self, grid: Grid, context: LegacyContext, state: WorkflowState
410
410
  ) -> bool:
411
411
  """Execute the 'share keys' stage."""
412
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
412
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
413
413
 
414
414
  def make(nid: int) -> Message:
415
415
  neighbours = state.nid_to_neighbours[nid] & state.active_node_ids
416
- cfgs_record = ConfigsRecord(
416
+ cfg_record = ConfigRecord(
417
417
  {str(nid): state.nid_to_publickeys[nid] for nid in neighbours}
418
418
  )
419
- cfgs_record[Key.STAGE] = Stage.SHARE_KEYS
420
- content = RecordSet({RECORD_KEY_CONFIGS: cfgs_record})
421
- return driver.create_message(
419
+ cfg_record[Key.STAGE] = Stage.SHARE_KEYS
420
+ content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})
421
+ return Message(
422
422
  content=content,
423
- message_type=MessageType.TRAIN,
424
423
  dst_node_id=nid,
424
+ message_type=MessageType.TRAIN,
425
425
  group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
426
426
  )
427
427
 
@@ -431,7 +431,7 @@ class SecAggPlusWorkflow:
431
431
  "[Stage 1] Forwarding public keys to %s clients.",
432
432
  len(state.active_node_ids),
433
433
  )
434
- msgs = driver.send_and_receive(
434
+ msgs = grid.send_and_receive(
435
435
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
436
436
  )
437
437
  state.active_node_ids = {
@@ -458,7 +458,7 @@ class SecAggPlusWorkflow:
458
458
  state.failures.append(Exception(msg.error))
459
459
  continue
460
460
  node_id = msg.metadata.src_node_id
461
- res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
461
+ res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
462
462
  dst_lst = cast(list[int], res_dict[Key.DESTINATION_LIST])
463
463
  ctxt_lst = cast(list[bytes], res_dict[Key.CIPHERTEXT_LIST])
464
464
  srcs += [node_id] * len(dst_lst)
@@ -476,25 +476,25 @@ class SecAggPlusWorkflow:
476
476
  return self._check_threshold(state)
477
477
 
478
478
  def collect_masked_vectors_stage(
479
- self, driver: Driver, context: LegacyContext, state: WorkflowState
479
+ self, grid: Grid, context: LegacyContext, state: WorkflowState
480
480
  ) -> bool:
481
481
  """Execute the 'collect masked vectors' stage."""
482
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
482
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
483
483
 
484
484
  # Send secret key shares to clients (plus FitIns) and collect masked vectors
485
485
  def make(nid: int) -> Message:
486
- cfgs_dict = {
486
+ cfg_dict = {
487
487
  Key.STAGE: Stage.COLLECT_MASKED_VECTORS,
488
488
  Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid],
489
489
  Key.SOURCE_LIST: state.forward_srcs[nid],
490
490
  }
491
- cfgs_record = ConfigsRecord(cfgs_dict) # type: ignore
491
+ cfg_record = ConfigRecord(cfg_dict) # type: ignore
492
492
  content = state.nid_to_fitins[nid]
493
- content.configs_records[RECORD_KEY_CONFIGS] = cfgs_record
494
- return driver.create_message(
493
+ content.config_records[RECORD_KEY_CONFIGS] = cfg_record
494
+ return Message(
495
495
  content=content,
496
- message_type=MessageType.TRAIN,
497
496
  dst_node_id=nid,
497
+ message_type=MessageType.TRAIN,
498
498
  group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
499
499
  )
500
500
 
@@ -503,7 +503,7 @@ class SecAggPlusWorkflow:
503
503
  "[Stage 2] Forwarding encrypted key shares to %s clients.",
504
504
  len(state.active_node_ids),
505
505
  )
506
- msgs = driver.send_and_receive(
506
+ msgs = grid.send_and_receive(
507
507
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
508
508
  )
509
509
  state.active_node_ids = {
@@ -524,7 +524,7 @@ class SecAggPlusWorkflow:
524
524
  if msg.has_error():
525
525
  state.failures.append(Exception(msg.error))
526
526
  continue
527
- res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
527
+ res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
528
528
  bytes_list = cast(list[bytes], res_dict[Key.MASKED_PARAMETERS])
529
529
  client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
530
530
  if masked_vector is None:
@@ -540,17 +540,17 @@ class SecAggPlusWorkflow:
540
540
  if msg.has_error():
541
541
  state.failures.append(Exception(msg.error))
542
542
  continue
543
- fitres = compat.recordset_to_fitres(msg.content, True)
543
+ fitres = compat.recorddict_to_fitres(msg.content, True)
544
544
  proxy = state.nid_to_proxies[msg.metadata.src_node_id]
545
545
  state.legacy_results.append((proxy, fitres))
546
546
 
547
547
  return self._check_threshold(state)
548
548
 
549
549
  def unmask_stage( # pylint: disable=R0912, R0914, R0915
550
- self, driver: Driver, context: LegacyContext, state: WorkflowState
550
+ self, grid: Grid, context: LegacyContext, state: WorkflowState
551
551
  ) -> bool:
552
552
  """Execute the 'unmask' stage."""
553
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
553
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
554
554
  current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
555
555
 
556
556
  # Construct active node IDs and dead node IDs
@@ -560,17 +560,17 @@ class SecAggPlusWorkflow:
560
560
  # Send secure IDs of active and dead clients and collect key shares from clients
561
561
  def make(nid: int) -> Message:
562
562
  neighbours = state.nid_to_neighbours[nid]
563
- cfgs_dict = {
563
+ cfg_dict = {
564
564
  Key.STAGE: Stage.UNMASK,
565
565
  Key.ACTIVE_NODE_ID_LIST: list(neighbours & active_nids),
566
566
  Key.DEAD_NODE_ID_LIST: list(neighbours & dead_nids),
567
567
  }
568
- cfgs_record = ConfigsRecord(cfgs_dict) # type: ignore
569
- content = RecordSet({RECORD_KEY_CONFIGS: cfgs_record})
570
- return driver.create_message(
568
+ cfg_record = ConfigRecord(cfg_dict) # type: ignore
569
+ content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})
570
+ return Message(
571
571
  content=content,
572
- message_type=MessageType.TRAIN,
573
572
  dst_node_id=nid,
573
+ message_type=MessageType.TRAIN,
574
574
  group_id=str(current_round),
575
575
  )
576
576
 
@@ -579,7 +579,7 @@ class SecAggPlusWorkflow:
579
579
  "[Stage 3] Requesting key shares from %s clients to remove masks.",
580
580
  len(state.active_node_ids),
581
581
  )
582
- msgs = driver.send_and_receive(
582
+ msgs = grid.send_and_receive(
583
583
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
584
584
  )
585
585
  state.active_node_ids = {
@@ -599,7 +599,7 @@ class SecAggPlusWorkflow:
599
599
  if msg.has_error():
600
600
  state.failures.append(Exception(msg.error))
601
601
  continue
602
- res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
602
+ res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
603
603
  nids = cast(list[int], res_dict[Key.NODE_ID_LIST])
604
604
  shares = cast(list[bytes], res_dict[Key.SHARE_LIST])
605
605
  for owner_nid, share in zip(nids, shares):
@@ -676,10 +676,8 @@ class SecAggPlusWorkflow:
676
676
 
677
677
  # Update the parameters and write history
678
678
  if parameters_aggregated:
679
- paramsrecord = compat.parameters_to_parametersrecord(
680
- parameters_aggregated, True
681
- )
682
- context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
679
+ arr_record = compat.parameters_to_arrayrecord(parameters_aggregated, True)
680
+ context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
683
681
  context.history.add_metrics_distributed_fit(
684
682
  server_round=current_round, metrics=metrics_aggregated
685
683
  )
flwr/simulation/app.py CHANGED
@@ -47,7 +47,7 @@ from flwr.common.logger import (
47
47
  stop_log_uploader,
48
48
  )
49
49
  from flwr.common.serde import (
50
- configs_record_from_proto,
50
+ config_record_from_proto,
51
51
  context_from_proto,
52
52
  context_to_proto,
53
53
  fab_from_proto,
@@ -184,7 +184,7 @@ def run_simulation_process( # pylint: disable=R0914, disable=W0212, disable=R09
184
184
  fed_opt_res: GetFederationOptionsResponse = conn._stub.GetFederationOptions(
185
185
  GetFederationOptionsRequest(run_id=run.run_id)
186
186
  )
187
- federation_options = configs_record_from_proto(
187
+ federation_options = config_record_from_proto(
188
188
  fed_opt_res.federation_options
189
189
  )
190
190
 
flwr/simulation/ray_transport/ray_actor.py CHANGED
@@ -105,8 +105,10 @@ def pool_size_from_resources(client_resources: dict[str, Union[int, float]]) ->
105
105
  if not node_resources:
106
106
  continue
107
107
 
108
- num_cpus = node_resources["CPU"]
109
- num_gpus = node_resources.get("GPU", 0) # There might not be GPU
108
+ # Fallback to zero when resource quantity is not configured on the ray node
109
+ # e.g.: node without GPU; head node set up not to run tasks (zero resources)
110
+ num_cpus = node_resources.get("CPU", 0)
111
+ num_gpus = node_resources.get("GPU", 0)
110
112
  num_actors = int(num_cpus / client_resources["num_cpus"])
111
113
 
112
114
  # If a GPU is present and client resources do require one