flwr-nightly 1.17.0.dev20250320__py3-none-any.whl → 1.17.0.dev20250322__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. flwr/cli/run/run.py +5 -9
  2. flwr/client/client_app.py +10 -12
  3. flwr/client/grpc_client/connection.py +3 -3
  4. flwr/client/message_handler/message_handler.py +3 -3
  5. flwr/client/mod/__init__.py +2 -2
  6. flwr/client/mod/comms_mods.py +16 -22
  7. flwr/client/mod/secure_aggregation/secaggplus_mod.py +26 -26
  8. flwr/common/__init__.py +10 -4
  9. flwr/common/config.py +4 -4
  10. flwr/common/constant.py +1 -1
  11. flwr/common/record/__init__.py +6 -3
  12. flwr/common/record/{parametersrecord.py → arrayrecord.py} +74 -31
  13. flwr/common/record/{configsrecord.py → configrecord.py} +73 -27
  14. flwr/common/record/conversion_utils.py +1 -1
  15. flwr/common/record/{metricsrecord.py → metricrecord.py} +77 -31
  16. flwr/common/record/recorddict.py +95 -56
  17. flwr/common/recorddict_compat.py +54 -62
  18. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  19. flwr/common/serde.py +42 -43
  20. flwr/common/typing.py +8 -8
  21. flwr/proto/exec_pb2.py +30 -30
  22. flwr/proto/exec_pb2.pyi +2 -2
  23. flwr/proto/recorddict_pb2.py +29 -29
  24. flwr/proto/recorddict_pb2.pyi +33 -33
  25. flwr/proto/run_pb2.py +2 -2
  26. flwr/proto/run_pb2.pyi +2 -2
  27. flwr/server/compat/grid_client_proxy.py +1 -1
  28. flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
  29. flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
  30. flwr/server/superlink/linkstate/in_memory_linkstate.py +4 -4
  31. flwr/server/superlink/linkstate/linkstate.py +4 -4
  32. flwr/server/superlink/linkstate/sqlite_linkstate.py +7 -7
  33. flwr/server/superlink/linkstate/utils.py +9 -9
  34. flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
  35. flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
  36. flwr/server/workflow/default_workflows.py +27 -34
  37. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +32 -34
  38. flwr/simulation/app.py +2 -2
  39. flwr/simulation/ray_transport/ray_actor.py +4 -2
  40. flwr/simulation/run_simulation.py +2 -2
  41. flwr/superexec/deployment.py +3 -3
  42. flwr/superexec/exec_servicer.py +2 -2
  43. flwr/superexec/executor.py +3 -3
  44. flwr/superexec/simulation.py +2 -2
  45. {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/METADATA +1 -1
  46. {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/RECORD +49 -49
  47. {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/LICENSE +0 -0
  48. {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/WHEEL +0 -0
  49. {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/entry_points.txt +0 -0
@@ -26,7 +26,7 @@ from flwr.common.constant import PARTITION_ID_KEY
26
26
  from flwr.common.context import Context
27
27
  from flwr.common.logger import log
28
28
  from flwr.common.message import Message
29
- from flwr.common.typing import ConfigsRecordValues
29
+ from flwr.common.typing import ConfigRecordValues
30
30
  from flwr.simulation.ray_transport.ray_actor import BasicActorPool, ClientAppActor
31
31
  from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
32
32
 
@@ -104,7 +104,7 @@ class RayBackend(Backend):
104
104
  if not ray.is_initialized():
105
105
  ray_init_args: dict[
106
106
  str,
107
- ConfigsRecordValues,
107
+ ConfigRecordValues,
108
108
  ] = {}
109
109
 
110
110
  if backend_config.get(self.init_args_key):
@@ -32,7 +32,7 @@ from flwr.common.constant import (
32
32
  SUPERLINK_NODE_ID,
33
33
  Status,
34
34
  )
35
- from flwr.common.record import ConfigsRecord
35
+ from flwr.common.record import ConfigRecord
36
36
  from flwr.common.typing import Run, RunStatus, UserConfig
37
37
  from flwr.server.superlink.linkstate.linkstate import LinkState
38
38
  from flwr.server.utils import validate_message
@@ -69,7 +69,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
69
69
  # Map run_id to RunRecord
70
70
  self.run_ids: dict[int, RunRecord] = {}
71
71
  self.contexts: dict[int, Context] = {}
72
- self.federation_options: dict[int, ConfigsRecord] = {}
72
+ self.federation_options: dict[int, ConfigRecord] = {}
73
73
  self.message_ins_store: dict[UUID, Message] = {}
74
74
  self.message_res_store: dict[UUID, Message] = {}
75
75
  self.message_ins_id_to_message_res_id: dict[UUID, UUID] = {}
@@ -399,7 +399,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
399
399
  fab_version: Optional[str],
400
400
  fab_hash: Optional[str],
401
401
  override_config: UserConfig,
402
- federation_options: ConfigsRecord,
402
+ federation_options: ConfigRecord,
403
403
  ) -> int:
404
404
  """Create a new run for the specified `fab_hash`."""
405
405
  # Sample a random int64 as run_id
@@ -528,7 +528,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
528
528
 
529
529
  return pending_run_id
530
530
 
531
- def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
531
+ def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
532
532
  """Retrieve the federation options for the specified `run_id`."""
533
533
  with self.lock:
534
534
  if run_id not in self.run_ids:
@@ -20,7 +20,7 @@ from typing import Optional
20
20
  from uuid import UUID
21
21
 
22
22
  from flwr.common import Context, Message
23
- from flwr.common.record import ConfigsRecord
23
+ from flwr.common.record import ConfigRecord
24
24
  from flwr.common.typing import Run, RunStatus, UserConfig
25
25
 
26
26
 
@@ -164,7 +164,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
164
164
  fab_version: Optional[str],
165
165
  fab_hash: Optional[str],
166
166
  override_config: UserConfig,
167
- federation_options: ConfigsRecord,
167
+ federation_options: ConfigRecord,
168
168
  ) -> int:
169
169
  """Create a new run for the specified `fab_hash`."""
170
170
 
@@ -236,7 +236,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
236
236
  """
237
237
 
238
238
  @abc.abstractmethod
239
- def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
239
+ def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
240
240
  """Retrieve the federation options for the specified `run_id`.
241
241
 
242
242
  Parameters
@@ -246,7 +246,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
246
246
 
247
247
  Returns
248
248
  -------
249
- Optional[ConfigsRecord]
249
+ Optional[ConfigRecord]
250
250
  The federation options for the run if it exists; None otherwise.
251
251
  """
252
252
 
@@ -36,7 +36,7 @@ from flwr.common.constant import (
36
36
  Status,
37
37
  )
38
38
  from flwr.common.message import make_message
39
- from flwr.common.record import ConfigsRecord
39
+ from flwr.common.record import ConfigRecord
40
40
  from flwr.common.serde import (
41
41
  error_from_proto,
42
42
  error_to_proto,
@@ -55,8 +55,8 @@ from flwr.server.utils.validator import validate_message
55
55
  from .linkstate import LinkState
56
56
  from .utils import (
57
57
  check_node_availability_for_in_message,
58
- configsrecord_from_bytes,
59
- configsrecord_to_bytes,
58
+ configrecord_from_bytes,
59
+ configrecord_to_bytes,
60
60
  context_from_bytes,
61
61
  context_to_bytes,
62
62
  convert_sint64_to_uint64,
@@ -727,7 +727,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
727
727
  fab_version: Optional[str],
728
728
  fab_hash: Optional[str],
729
729
  override_config: UserConfig,
730
- federation_options: ConfigsRecord,
730
+ federation_options: ConfigRecord,
731
731
  ) -> int:
732
732
  """Create a new run for the specified `fab_id` and `fab_version`."""
733
733
  # Sample a random int64 as run_id
@@ -753,7 +753,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
753
753
  fab_version,
754
754
  fab_hash,
755
755
  override_config_json,
756
- configsrecord_to_bytes(federation_options),
756
+ configrecord_to_bytes(federation_options),
757
757
  ]
758
758
  data += [
759
759
  now().isoformat(),
@@ -911,7 +911,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
911
911
 
912
912
  return pending_run_id
913
913
 
914
- def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
914
+ def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
915
915
  """Retrieve the federation options for the specified `run_id`."""
916
916
  # Convert the uint64 value to sint64 for SQLite
917
917
  sint64_run_id = convert_uint64_to_sint64(run_id)
@@ -924,7 +924,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
924
924
  return None
925
925
 
926
926
  row = rows[0]
927
- return configsrecord_from_bytes(row["federation_options"])
927
+ return configrecord_from_bytes(row["federation_options"])
928
928
 
929
929
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
930
930
  """Acknowledge a ping received from a node, serving as a heartbeat.
@@ -19,7 +19,7 @@ from os import urandom
19
19
  from typing import Optional
20
20
  from uuid import UUID, uuid4
21
21
 
22
- from flwr.common import ConfigsRecord, Context, Error, Message, Metadata, now, serde
22
+ from flwr.common import ConfigRecord, Context, Error, Message, Metadata, now, serde
23
23
  from flwr.common.constant import (
24
24
  SUPERLINK_NODE_ID,
25
25
  ErrorCode,
@@ -32,7 +32,7 @@ from flwr.common.typing import RunStatus
32
32
 
33
33
  # pylint: disable=E0611
34
34
  from flwr.proto.message_pb2 import Context as ProtoContext
35
- from flwr.proto.recorddict_pb2 import ConfigsRecord as ProtoConfigsRecord
35
+ from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
36
36
 
37
37
  # pylint: enable=E0611
38
38
  VALID_RUN_STATUS_TRANSITIONS = {
@@ -172,15 +172,15 @@ def context_from_bytes(context_bytes: bytes) -> Context:
172
172
  return serde.context_from_proto(ProtoContext.FromString(context_bytes))
173
173
 
174
174
 
175
- def configsrecord_to_bytes(configs_record: ConfigsRecord) -> bytes:
176
- """Serialize a `ConfigsRecord` to bytes."""
177
- return serde.configs_record_to_proto(configs_record).SerializeToString()
175
+ def configrecord_to_bytes(config_record: ConfigRecord) -> bytes:
176
+ """Serialize a `ConfigRecord` to bytes."""
177
+ return serde.config_record_to_proto(config_record).SerializeToString()
178
178
 
179
179
 
180
- def configsrecord_from_bytes(configsrecord_bytes: bytes) -> ConfigsRecord:
181
- """Deserialize `ConfigsRecord` from bytes."""
182
- return serde.configs_record_from_proto(
183
- ProtoConfigsRecord.FromString(configsrecord_bytes)
180
+ def configrecord_from_bytes(configrecord_bytes: bytes) -> ConfigRecord:
181
+ """Deserialize `ConfigRecord` from bytes."""
182
+ return serde.config_record_from_proto(
183
+ ProtoConfigRecord.FromString(configrecord_bytes)
184
184
  )
185
185
 
186
186
 
@@ -22,7 +22,7 @@ from uuid import UUID
22
22
 
23
23
  import grpc
24
24
 
25
- from flwr.common import ConfigsRecord, Message
25
+ from flwr.common import ConfigRecord, Message
26
26
  from flwr.common.constant import SUPERLINK_NODE_ID, Status
27
27
  from flwr.common.logger import log
28
28
  from flwr.common.serde import (
@@ -127,7 +127,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
127
127
  request.fab_version,
128
128
  fab_hash,
129
129
  user_config_from_proto(request.override_config),
130
- ConfigsRecord(),
130
+ ConfigRecord(),
131
131
  )
132
132
  return CreateRunResponse(run_id=run_id)
133
133
 
@@ -24,7 +24,7 @@ from grpc import ServicerContext
24
24
  from flwr.common.constant import Status
25
25
  from flwr.common.logger import log
26
26
  from flwr.common.serde import (
27
- configs_record_to_proto,
27
+ config_record_to_proto,
28
28
  context_from_proto,
29
29
  context_to_proto,
30
30
  fab_to_proto,
@@ -182,5 +182,5 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
182
182
  )
183
183
  return GetFederationOptionsResponse()
184
184
  return GetFederationOptionsResponse(
185
- federation_options=configs_record_to_proto(federation_options)
185
+ federation_options=config_record_to_proto(federation_options)
186
186
  )
@@ -22,13 +22,14 @@ from typing import Optional, Union, cast
22
22
 
23
23
  import flwr.common.recorddict_compat as compat
24
24
  from flwr.common import (
25
+ ArrayRecord,
25
26
  Code,
26
- ConfigsRecord,
27
+ ConfigRecord,
27
28
  Context,
28
29
  EvaluateRes,
29
30
  FitRes,
30
31
  GetParametersIns,
31
- ParametersRecord,
32
+ Message,
32
33
  log,
33
34
  )
34
35
  from flwr.common.constant import MessageType, MessageTypeLegacy
@@ -77,9 +78,9 @@ class DefaultWorkflow:
77
78
 
78
79
  # Run federated learning for num_rounds
79
80
  start_time = timeit.default_timer()
80
- cfg = ConfigsRecord()
81
+ cfg = ConfigRecord()
81
82
  cfg[Key.START_TIME] = start_time
82
- context.state.configs_records[MAIN_CONFIGS_RECORD] = cfg
83
+ context.state.config_records[MAIN_CONFIGS_RECORD] = cfg
83
84
 
84
85
  for current_round in range(1, context.config.num_rounds + 1):
85
86
  log(INFO, "")
@@ -129,9 +130,7 @@ def default_init_params_workflow(grid: Grid, context: Context) -> None:
129
130
  )
130
131
  if parameters is not None:
131
132
  log(INFO, "Using initial global parameters provided by strategy")
132
- paramsrecord = compat.parameters_to_parametersrecord(
133
- parameters, keep_input=True
134
- )
133
+ arr_record = compat.parameters_to_arrayrecord(parameters, keep_input=True)
135
134
  else:
136
135
  # Get initial parameters from one of the clients
137
136
  log(INFO, "Requesting initial parameters from one random client")
@@ -140,10 +139,10 @@ def default_init_params_workflow(grid: Grid, context: Context) -> None:
140
139
  content = compat.getparametersins_to_recorddict(GetParametersIns({}))
141
140
  messages = grid.send_and_receive(
142
141
  [
143
- grid.create_message(
142
+ Message(
144
143
  content=content,
145
- message_type=MessageTypeLegacy.GET_PARAMETERS,
146
144
  dst_node_id=random_client.node_id,
145
+ message_type=MessageTypeLegacy.GET_PARAMETERS,
147
146
  group_id="0",
148
147
  )
149
148
  ]
@@ -158,20 +157,20 @@ def default_init_params_workflow(grid: Grid, context: Context) -> None:
158
157
  == Code.OK
159
158
  ):
160
159
  log(INFO, "Received initial parameters from one random client")
161
- paramsrecord = next(iter(msg.content.parameters_records.values()))
160
+ arr_record = next(iter(msg.content.array_records.values()))
162
161
  else:
163
162
  log(
164
163
  WARN,
165
164
  "Failed to receive initial parameters from the client."
166
165
  " Empty initial parameters will be used.",
167
166
  )
168
- paramsrecord = ParametersRecord()
167
+ arr_record = ArrayRecord()
169
168
 
170
- context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
169
+ context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
171
170
 
172
171
  # Evaluate initial parameters
173
172
  log(INFO, "Starting evaluation of initial global parameters")
174
- parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True)
173
+ parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
175
174
  res = context.strategy.evaluate(0, parameters=parameters)
176
175
  if res is not None:
177
176
  log(
@@ -192,13 +191,13 @@ def default_centralized_evaluation_workflow(_: Grid, context: Context) -> None:
192
191
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
193
192
 
194
193
  # Retrieve current_round and start_time from the context
195
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
194
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
196
195
  current_round = cast(int, cfg[Key.CURRENT_ROUND])
197
196
  start_time = cast(float, cfg[Key.START_TIME])
198
197
 
199
198
  # Centralized evaluation
200
- parameters = compat.parametersrecord_to_parameters(
201
- record=context.state.parameters_records[MAIN_PARAMS_RECORD],
199
+ parameters = compat.arrayrecord_to_parameters(
200
+ record=context.state.array_records[MAIN_PARAMS_RECORD],
202
201
  keep_input=True,
203
202
  )
204
203
  res_cen = context.strategy.evaluate(current_round, parameters=parameters)
@@ -224,12 +223,10 @@ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disab
224
223
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
225
224
 
226
225
  # Get current_round and parameters
227
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
226
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
228
227
  current_round = cast(int, cfg[Key.CURRENT_ROUND])
229
- parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
230
- parameters = compat.parametersrecord_to_parameters(
231
- parametersrecord, keep_input=True
232
- )
228
+ arr_record = context.state.array_records[MAIN_PARAMS_RECORD]
229
+ parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
233
230
 
234
231
  # Get clients and their respective instructions from strategy
235
232
  client_instructions = context.strategy.configure_fit(
@@ -253,10 +250,10 @@ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disab
253
250
 
254
251
  # Build out messages
255
252
  out_messages = [
256
- grid.create_message(
253
+ Message(
257
254
  content=compat.fitins_to_recorddict(fitins, True),
258
- message_type=MessageType.TRAIN,
259
255
  dst_node_id=proxy.node_id,
256
+ message_type=MessageType.TRAIN,
260
257
  group_id=str(current_round),
261
258
  )
262
259
  for proxy, fitins in client_instructions
@@ -295,10 +292,8 @@ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disab
295
292
 
296
293
  # Update the parameters and write history
297
294
  if parameters_aggregated:
298
- paramsrecord = compat.parameters_to_parametersrecord(
299
- parameters_aggregated, True
300
- )
301
- context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
295
+ arr_record = compat.parameters_to_arrayrecord(parameters_aggregated, True)
296
+ context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
302
297
  context.history.add_metrics_distributed_fit(
303
298
  server_round=current_round, metrics=metrics_aggregated
304
299
  )
@@ -311,12 +306,10 @@ def default_evaluate_workflow(grid: Grid, context: Context) -> None:
311
306
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
312
307
 
313
308
  # Get current_round and parameters
314
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
309
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
315
310
  current_round = cast(int, cfg[Key.CURRENT_ROUND])
316
- parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
317
- parameters = compat.parametersrecord_to_parameters(
318
- parametersrecord, keep_input=True
319
- )
311
+ arr_record = context.state.array_records[MAIN_PARAMS_RECORD]
312
+ parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
320
313
 
321
314
  # Get clients and their respective instructions from strategy
322
315
  client_instructions = context.strategy.configure_evaluate(
@@ -339,10 +332,10 @@ def default_evaluate_workflow(grid: Grid, context: Context) -> None:
339
332
 
340
333
  # Build out messages
341
334
  out_messages = [
342
- grid.create_message(
335
+ Message(
343
336
  content=compat.evaluateins_to_recorddict(evalins, True),
344
- message_type=MessageType.EVALUATE,
345
337
  dst_node_id=proxy.node_id,
338
+ message_type=MessageType.EVALUATE,
346
339
  group_id=str(current_round),
347
340
  )
348
341
  for proxy, evalins in client_instructions
@@ -22,7 +22,7 @@ from typing import Optional, Union, cast
22
22
 
23
23
  import flwr.common.recorddict_compat as compat
24
24
  from flwr.common import (
25
- ConfigsRecord,
25
+ ConfigRecord,
26
26
  Context,
27
27
  FitRes,
28
28
  Message,
@@ -283,10 +283,10 @@ class SecAggPlusWorkflow:
283
283
  ) -> bool:
284
284
  """Execute the 'setup' stage."""
285
285
  # Obtain fit instructions
286
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
286
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
287
287
  current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
288
- parameters = compat.parametersrecord_to_parameters(
289
- context.state.parameters_records[MAIN_PARAMS_RECORD],
288
+ parameters = compat.arrayrecord_to_parameters(
289
+ context.state.array_records[MAIN_PARAMS_RECORD],
290
290
  keep_input=True,
291
291
  )
292
292
  proxy_fitins_lst = context.strategy.configure_fit(
@@ -366,14 +366,14 @@ class SecAggPlusWorkflow:
366
366
  state.sampled_node_ids = state.active_node_ids
367
367
 
368
368
  # Send setup configuration to clients
369
- cfgs_record = ConfigsRecord(sa_params_dict) # type: ignore
370
- content = RecordDict({RECORD_KEY_CONFIGS: cfgs_record})
369
+ cfg_record = ConfigRecord(sa_params_dict) # type: ignore
370
+ content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})
371
371
 
372
372
  def make(nid: int) -> Message:
373
- return grid.create_message(
373
+ return Message(
374
374
  content=content,
375
- message_type=MessageType.TRAIN,
376
375
  dst_node_id=nid,
376
+ message_type=MessageType.TRAIN,
377
377
  group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
378
378
  )
379
379
 
@@ -398,7 +398,7 @@ class SecAggPlusWorkflow:
398
398
  if msg.has_error():
399
399
  state.failures.append(Exception(msg.error))
400
400
  continue
401
- key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
401
+ key_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
402
402
  node_id = msg.metadata.src_node_id
403
403
  pk1, pk2 = key_dict[Key.PUBLIC_KEY_1], key_dict[Key.PUBLIC_KEY_2]
404
404
  state.nid_to_publickeys[node_id] = [cast(bytes, pk1), cast(bytes, pk2)]
@@ -409,19 +409,19 @@ class SecAggPlusWorkflow:
409
409
  self, grid: Grid, context: LegacyContext, state: WorkflowState
410
410
  ) -> bool:
411
411
  """Execute the 'share keys' stage."""
412
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
412
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
413
413
 
414
414
  def make(nid: int) -> Message:
415
415
  neighbours = state.nid_to_neighbours[nid] & state.active_node_ids
416
- cfgs_record = ConfigsRecord(
416
+ cfg_record = ConfigRecord(
417
417
  {str(nid): state.nid_to_publickeys[nid] for nid in neighbours}
418
418
  )
419
- cfgs_record[Key.STAGE] = Stage.SHARE_KEYS
420
- content = RecordDict({RECORD_KEY_CONFIGS: cfgs_record})
421
- return grid.create_message(
419
+ cfg_record[Key.STAGE] = Stage.SHARE_KEYS
420
+ content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})
421
+ return Message(
422
422
  content=content,
423
- message_type=MessageType.TRAIN,
424
423
  dst_node_id=nid,
424
+ message_type=MessageType.TRAIN,
425
425
  group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
426
426
  )
427
427
 
@@ -458,7 +458,7 @@ class SecAggPlusWorkflow:
458
458
  state.failures.append(Exception(msg.error))
459
459
  continue
460
460
  node_id = msg.metadata.src_node_id
461
- res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
461
+ res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
462
462
  dst_lst = cast(list[int], res_dict[Key.DESTINATION_LIST])
463
463
  ctxt_lst = cast(list[bytes], res_dict[Key.CIPHERTEXT_LIST])
464
464
  srcs += [node_id] * len(dst_lst)
@@ -479,22 +479,22 @@ class SecAggPlusWorkflow:
479
479
  self, grid: Grid, context: LegacyContext, state: WorkflowState
480
480
  ) -> bool:
481
481
  """Execute the 'collect masked vectors' stage."""
482
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
482
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
483
483
 
484
484
  # Send secret key shares to clients (plus FitIns) and collect masked vectors
485
485
  def make(nid: int) -> Message:
486
- cfgs_dict = {
486
+ cfg_dict = {
487
487
  Key.STAGE: Stage.COLLECT_MASKED_VECTORS,
488
488
  Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid],
489
489
  Key.SOURCE_LIST: state.forward_srcs[nid],
490
490
  }
491
- cfgs_record = ConfigsRecord(cfgs_dict) # type: ignore
491
+ cfg_record = ConfigRecord(cfg_dict) # type: ignore
492
492
  content = state.nid_to_fitins[nid]
493
- content.configs_records[RECORD_KEY_CONFIGS] = cfgs_record
494
- return grid.create_message(
493
+ content.config_records[RECORD_KEY_CONFIGS] = cfg_record
494
+ return Message(
495
495
  content=content,
496
- message_type=MessageType.TRAIN,
497
496
  dst_node_id=nid,
497
+ message_type=MessageType.TRAIN,
498
498
  group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
499
499
  )
500
500
 
@@ -524,7 +524,7 @@ class SecAggPlusWorkflow:
524
524
  if msg.has_error():
525
525
  state.failures.append(Exception(msg.error))
526
526
  continue
527
- res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
527
+ res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
528
528
  bytes_list = cast(list[bytes], res_dict[Key.MASKED_PARAMETERS])
529
529
  client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
530
530
  if masked_vector is None:
@@ -550,7 +550,7 @@ class SecAggPlusWorkflow:
550
550
  self, grid: Grid, context: LegacyContext, state: WorkflowState
551
551
  ) -> bool:
552
552
  """Execute the 'unmask' stage."""
553
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
553
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
554
554
  current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
555
555
 
556
556
  # Construct active node IDs and dead node IDs
@@ -560,17 +560,17 @@ class SecAggPlusWorkflow:
560
560
  # Send secure IDs of active and dead clients and collect key shares from clients
561
561
  def make(nid: int) -> Message:
562
562
  neighbours = state.nid_to_neighbours[nid]
563
- cfgs_dict = {
563
+ cfg_dict = {
564
564
  Key.STAGE: Stage.UNMASK,
565
565
  Key.ACTIVE_NODE_ID_LIST: list(neighbours & active_nids),
566
566
  Key.DEAD_NODE_ID_LIST: list(neighbours & dead_nids),
567
567
  }
568
- cfgs_record = ConfigsRecord(cfgs_dict) # type: ignore
569
- content = RecordDict({RECORD_KEY_CONFIGS: cfgs_record})
570
- return grid.create_message(
568
+ cfg_record = ConfigRecord(cfg_dict) # type: ignore
569
+ content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})
570
+ return Message(
571
571
  content=content,
572
- message_type=MessageType.TRAIN,
573
572
  dst_node_id=nid,
573
+ message_type=MessageType.TRAIN,
574
574
  group_id=str(current_round),
575
575
  )
576
576
 
@@ -599,7 +599,7 @@ class SecAggPlusWorkflow:
599
599
  if msg.has_error():
600
600
  state.failures.append(Exception(msg.error))
601
601
  continue
602
- res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
602
+ res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
603
603
  nids = cast(list[int], res_dict[Key.NODE_ID_LIST])
604
604
  shares = cast(list[bytes], res_dict[Key.SHARE_LIST])
605
605
  for owner_nid, share in zip(nids, shares):
@@ -676,10 +676,8 @@ class SecAggPlusWorkflow:
676
676
 
677
677
  # Update the parameters and write history
678
678
  if parameters_aggregated:
679
- paramsrecord = compat.parameters_to_parametersrecord(
680
- parameters_aggregated, True
681
- )
682
- context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
679
+ arr_record = compat.parameters_to_arrayrecord(parameters_aggregated, True)
680
+ context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
683
681
  context.history.add_metrics_distributed_fit(
684
682
  server_round=current_round, metrics=metrics_aggregated
685
683
  )
flwr/simulation/app.py CHANGED
@@ -47,7 +47,7 @@ from flwr.common.logger import (
47
47
  stop_log_uploader,
48
48
  )
49
49
  from flwr.common.serde import (
50
- configs_record_from_proto,
50
+ config_record_from_proto,
51
51
  context_from_proto,
52
52
  context_to_proto,
53
53
  fab_from_proto,
@@ -184,7 +184,7 @@ def run_simulation_process( # pylint: disable=R0914, disable=W0212, disable=R09
184
184
  fed_opt_res: GetFederationOptionsResponse = conn._stub.GetFederationOptions(
185
185
  GetFederationOptionsRequest(run_id=run.run_id)
186
186
  )
187
- federation_options = configs_record_from_proto(
187
+ federation_options = config_record_from_proto(
188
188
  fed_opt_res.federation_options
189
189
  )
190
190
 
@@ -105,8 +105,10 @@ def pool_size_from_resources(client_resources: dict[str, Union[int, float]]) ->
105
105
  if not node_resources:
106
106
  continue
107
107
 
108
- num_cpus = node_resources["CPU"]
109
- num_gpus = node_resources.get("GPU", 0) # There might not be GPU
108
+ # Fallback to zero when resource quantity is not configured on the ray node
109
+ # e.g.: node without GPU; head node set up not to run tasks (zero resources)
110
+ num_cpus = node_resources.get("CPU", 0)
111
+ num_gpus = node_resources.get("GPU", 0)
110
112
  num_actors = int(num_cpus / client_resources["num_cpus"])
111
113
 
112
114
  # If a GPU is present and client resources do require one
@@ -180,7 +180,7 @@ def run_simulation(
180
180
  for values parsed to initialisation of backend, `client_resources`
181
181
  to define the resources for clients, and `actor` to define the actor
182
182
  parameters. Values supported in <value> are those included by
183
- `flwr.common.typing.ConfigsRecordValues`.
183
+ `flwr.common.typing.ConfigRecordValues`.
184
184
 
185
185
  enable_tf_gpu_growth : bool (default: False)
186
186
  A boolean to indicate whether to enable GPU growth on the main thread. This is
@@ -546,7 +546,7 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
546
546
  default="{}",
547
547
  help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
548
548
  "configure a backend. Values supported in <value> are those included by "
549
- "`flwr.common.typing.ConfigsRecordValues`. ",
549
+ "`flwr.common.typing.ConfigRecordValues`. ",
550
550
  )
551
551
  parser.add_argument(
552
552
  "--enable-tf-gpu-growth",
@@ -23,7 +23,7 @@ from typing import Optional
23
23
  from typing_extensions import override
24
24
 
25
25
  from flwr.cli.config_utils import get_fab_metadata
26
- from flwr.common import ConfigsRecord, Context, RecordDict
26
+ from flwr.common import ConfigRecord, Context, RecordDict
27
27
  from flwr.common.constant import (
28
28
  SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
29
29
  Status,
@@ -141,7 +141,7 @@ class DeploymentEngine(Executor):
141
141
  fab_id, fab_version = get_fab_metadata(fab.content)
142
142
 
143
143
  run_id = self.linkstate.create_run(
144
- fab_id, fab_version, fab_hash, override_config, ConfigsRecord()
144
+ fab_id, fab_version, fab_hash, override_config, ConfigRecord()
145
145
  )
146
146
  return run_id
147
147
 
@@ -160,7 +160,7 @@ class DeploymentEngine(Executor):
160
160
  self,
161
161
  fab_file: bytes,
162
162
  override_config: UserConfig,
163
- federation_options: ConfigsRecord,
163
+ federation_options: ConfigRecord,
164
164
  ) -> Optional[int]:
165
165
  """Start run using the Flower Deployment Engine."""
166
166
  run_id = None
@@ -28,7 +28,7 @@ from flwr.common.auth_plugin import ExecAuthPlugin
28
28
  from flwr.common.constant import LOG_STREAM_INTERVAL, Status, SubStatus
29
29
  from flwr.common.logger import log
30
30
  from flwr.common.serde import (
31
- configs_record_from_proto,
31
+ config_record_from_proto,
32
32
  run_to_proto,
33
33
  user_config_from_proto,
34
34
  )
@@ -79,7 +79,7 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
79
79
  run_id = self.executor.start_run(
80
80
  request.fab.content,
81
81
  user_config_from_proto(request.override_config),
82
- configs_record_from_proto(request.federation_options),
82
+ config_record_from_proto(request.federation_options),
83
83
  )
84
84
 
85
85
  if run_id is None: