flwr-nightly 1.17.0.dev20250320__py3-none-any.whl → 1.17.0.dev20250322__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49):
  1. flwr/cli/run/run.py +5 -9
  2. flwr/client/client_app.py +10 -12
  3. flwr/client/grpc_client/connection.py +3 -3
  4. flwr/client/message_handler/message_handler.py +3 -3
  5. flwr/client/mod/__init__.py +2 -2
  6. flwr/client/mod/comms_mods.py +16 -22
  7. flwr/client/mod/secure_aggregation/secaggplus_mod.py +26 -26
  8. flwr/common/__init__.py +10 -4
  9. flwr/common/config.py +4 -4
  10. flwr/common/constant.py +1 -1
  11. flwr/common/record/__init__.py +6 -3
  12. flwr/common/record/{parametersrecord.py → arrayrecord.py} +74 -31
  13. flwr/common/record/{configsrecord.py → configrecord.py} +73 -27
  14. flwr/common/record/conversion_utils.py +1 -1
  15. flwr/common/record/{metricsrecord.py → metricrecord.py} +77 -31
  16. flwr/common/record/recorddict.py +95 -56
  17. flwr/common/recorddict_compat.py +54 -62
  18. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  19. flwr/common/serde.py +42 -43
  20. flwr/common/typing.py +8 -8
  21. flwr/proto/exec_pb2.py +30 -30
  22. flwr/proto/exec_pb2.pyi +2 -2
  23. flwr/proto/recorddict_pb2.py +29 -29
  24. flwr/proto/recorddict_pb2.pyi +33 -33
  25. flwr/proto/run_pb2.py +2 -2
  26. flwr/proto/run_pb2.pyi +2 -2
  27. flwr/server/compat/grid_client_proxy.py +1 -1
  28. flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
  29. flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
  30. flwr/server/superlink/linkstate/in_memory_linkstate.py +4 -4
  31. flwr/server/superlink/linkstate/linkstate.py +4 -4
  32. flwr/server/superlink/linkstate/sqlite_linkstate.py +7 -7
  33. flwr/server/superlink/linkstate/utils.py +9 -9
  34. flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
  35. flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
  36. flwr/server/workflow/default_workflows.py +27 -34
  37. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +32 -34
  38. flwr/simulation/app.py +2 -2
  39. flwr/simulation/ray_transport/ray_actor.py +4 -2
  40. flwr/simulation/run_simulation.py +2 -2
  41. flwr/superexec/deployment.py +3 -3
  42. flwr/superexec/exec_servicer.py +2 -2
  43. flwr/superexec/executor.py +3 -3
  44. flwr/superexec/simulation.py +2 -2
  45. {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/METADATA +1 -1
  46. {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/RECORD +49 -49
  47. {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/LICENSE +0 -0
  48. {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/WHEEL +0 -0
  49. {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/entry_points.txt +0 -0
flwr/cli/run/run.py CHANGED
@@ -35,15 +35,11 @@ from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
35
35
  from flwr.common.config import (
36
36
  flatten_dict,
37
37
  parse_config_args,
38
- user_config_to_configsrecord,
38
+ user_config_to_configrecord,
39
39
  )
40
40
  from flwr.common.constant import CliOutputFormat
41
41
  from flwr.common.logger import print_json_error, redirect_output, restore_output
42
- from flwr.common.serde import (
43
- configs_record_to_proto,
44
- fab_to_proto,
45
- user_config_to_proto,
46
- )
42
+ from flwr.common.serde import config_record_to_proto, fab_to_proto, user_config_to_proto
47
43
  from flwr.common.typing import Fab
48
44
  from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
49
45
  from flwr.proto.exec_pb2_grpc import ExecStub
@@ -171,14 +167,14 @@ def _run_with_exec_api(
171
167
 
172
168
  fab = Fab(fab_hash, content)
173
169
 
174
- # Construct a `ConfigsRecord` out of a flattened `UserConfig`
170
+ # Construct a `ConfigRecord` out of a flattened `UserConfig`
175
171
  fed_conf = flatten_dict(federation_config.get("options", {}))
176
- c_record = user_config_to_configsrecord(fed_conf)
172
+ c_record = user_config_to_configrecord(fed_conf)
177
173
 
178
174
  req = StartRunRequest(
179
175
  fab=fab_to_proto(fab),
180
176
  override_config=user_config_to_proto(parse_config_args(config_overrides)),
181
- federation_options=configs_record_to_proto(c_record),
177
+ federation_options=config_record_to_proto(c_record),
182
178
  )
183
179
  with unauthenticated_exc_handler():
184
180
  res = stub.StartRun(req)
flwr/client/client_app.py CHANGED
@@ -189,7 +189,7 @@ class ClientApp:
189
189
  >>> def train(message: Message, context: Context) -> Message:
190
190
  >>> print("Executing default train function")
191
191
  >>> # Create and return an echo reply message
192
- >>> return message.create_reply(content=message.content)
192
+ >>> return Message(message.content, reply_to=message)
193
193
 
194
194
  Registering a train function with a custom action name:
195
195
 
@@ -200,7 +200,7 @@ class ClientApp:
200
200
  >>> @app.train("custom_action")
201
201
  >>> def custom_action(message: Message, context: Context) -> Message:
202
202
  >>> print("Executing train function for custom action")
203
- >>> return message.create_reply(content=message.content)
203
+ >>> return Message(message.content, reply_to=message)
204
204
 
205
205
  Registering a train function with a function-specific Flower Mod:
206
206
 
@@ -213,7 +213,7 @@ class ClientApp:
213
213
  >>> def train(message: Message, context: Context) -> Message:
214
214
  >>> print("Executing train function with message size mod")
215
215
  >>> # Create and return an echo reply message
216
- >>> return message.create_reply(content=message.content)
216
+ >>> return Message(message.content, reply_to=message)
217
217
  """
218
218
  return _get_decorator(self, MessageType.TRAIN, action, mods)
219
219
 
@@ -244,7 +244,7 @@ class ClientApp:
244
244
  >>> def evaluate(message: Message, context: Context) -> Message:
245
245
  >>> print("Executing default evaluate function")
246
246
  >>> # Create and return an echo reply message
247
- >>> return message.create_reply(content=message.content)
247
+ >>> return Message(message.content, reply_to=message)
248
248
 
249
249
  Registering an evaluate function with a custom action name:
250
250
 
@@ -255,7 +255,7 @@ class ClientApp:
255
255
  >>> @app.evaluate("custom_action")
256
256
  >>> def custom_action(message: Message, context: Context) -> Message:
257
257
  >>> print("Executing evaluate function for custom action")
258
- >>> return message.create_reply(content=message.content)
258
+ >>> return Message(message.content, reply_to=message)
259
259
 
260
260
  Registering an evaluate function with a function-specific Flower Mod:
261
261
 
@@ -268,7 +268,7 @@ class ClientApp:
268
268
  >>> def evaluate(message: Message, context: Context) -> Message:
269
269
  >>> print("Executing evaluate function with message size mod")
270
270
  >>> # Create and return an echo reply message
271
- >>> return message.create_reply(content=message.content)
271
+ >>> return Message(message.content, reply_to=message)
272
272
  """
273
273
  return _get_decorator(self, MessageType.EVALUATE, action, mods)
274
274
 
@@ -299,7 +299,7 @@ class ClientApp:
299
299
  >>> def query(message: Message, context: Context) -> Message:
300
300
  >>> print("Executing default query function")
301
301
  >>> # Create and return an echo reply message
302
- >>> return message.create_reply(content=message.content)
302
+ >>> return Message(message.content, reply_to=message)
303
303
 
304
304
  Registering a query function with a custom action name:
305
305
 
@@ -310,7 +310,7 @@ class ClientApp:
310
310
  >>> @app.query("custom_action")
311
311
  >>> def custom_action(message: Message, context: Context) -> Message:
312
312
  >>> print("Executing query function for custom action")
313
- >>> return message.create_reply(content=message.content)
313
+ >>> return Message(message.content, reply_to=message)
314
314
 
315
315
  Registering a query function with a function-specific Flower Mod:
316
316
 
@@ -323,7 +323,7 @@ class ClientApp:
323
323
  >>> def query(message: Message, context: Context) -> Message:
324
324
  >>> print("Executing query function with message size mod")
325
325
  >>> # Create and return an echo reply message
326
- >>> return message.create_reply(content=message.content)
326
+ >>> return Message(message.content, reply_to=message)
327
327
  """
328
328
  return _get_decorator(self, MessageType.QUERY, action, mods)
329
329
 
@@ -454,8 +454,6 @@ def _registration_error(fn_name: str) -> ValueError:
454
454
  >>> def {fn_name}(message: Message, context: Context) -> Message:
455
455
  >>> print("ClientApp {fn_name} running")
456
456
  >>> # Create and return an echo reply message
457
- >>> return message.create_reply(
458
- >>> content=message.content()
459
- >>> )
457
+ >>> return Message(message.content, reply_to=message)
460
458
  """,
461
459
  )
@@ -28,7 +28,7 @@ from cryptography.hazmat.primitives.asymmetric import ec
28
28
  from flwr.common import (
29
29
  DEFAULT_TTL,
30
30
  GRPC_MAX_MESSAGE_LENGTH,
31
- ConfigsRecord,
31
+ ConfigRecord,
32
32
  Message,
33
33
  Metadata,
34
34
  RecordDict,
@@ -166,7 +166,7 @@ def grpc_connection( # pylint: disable=R0913,R0915,too-many-positional-argument
166
166
  message_type = MessageType.EVALUATE
167
167
  elif field == "reconnect_ins":
168
168
  recorddict = RecordDict()
169
- recorddict.configs_records["config"] = ConfigsRecord(
169
+ recorddict.config_records["config"] = ConfigRecord(
170
170
  {"seconds": proto.reconnect_ins.seconds}
171
171
  )
172
172
  message_type = "reconnect"
@@ -216,7 +216,7 @@ def grpc_connection( # pylint: disable=R0913,R0915,too-many-positional-argument
216
216
  msg_proto = ClientMessage(evaluate_res=serde.evaluate_res_to_proto(evalres))
217
217
  elif message_type == "reconnect":
218
218
  reason = cast(
219
- Reason.ValueType, recorddict.configs_records["config"]["reason"]
219
+ Reason.ValueType, recorddict.config_records["config"]["reason"]
220
220
  )
221
221
  msg_proto = ClientMessage(
222
222
  disconnect_res=ClientMessage.DisconnectRes(reason=reason)
@@ -26,7 +26,7 @@ from flwr.client.client import (
26
26
  )
27
27
  from flwr.client.numpy_client import NumPyClient
28
28
  from flwr.client.typing import ClientFnExt
29
- from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordDict, log
29
+ from flwr.common import ConfigRecord, Context, Message, Metadata, RecordDict, log
30
30
  from flwr.common.constant import MessageType, MessageTypeLegacy
31
31
  from flwr.common.recorddict_compat import (
32
32
  evaluateres_to_recorddict,
@@ -72,7 +72,7 @@ def handle_control_message(message: Message) -> tuple[Optional[Message], int]:
72
72
  if message.metadata.message_type == "reconnect":
73
73
  # Retrieve ReconnectIns from RecordDict
74
74
  recorddict = message.content
75
- seconds = cast(int, recorddict.configs_records["config"]["seconds"])
75
+ seconds = cast(int, recorddict.config_records["config"]["seconds"])
76
76
  # Construct ReconnectIns and call _reconnect
77
77
  disconnect_msg, sleep_duration = _reconnect(
78
78
  ServerMessage.ReconnectIns(seconds=seconds)
@@ -80,7 +80,7 @@ def handle_control_message(message: Message) -> tuple[Optional[Message], int]:
80
80
  # Store DisconnectRes in RecordDict
81
81
  reason = cast(int, disconnect_msg.disconnect_res.reason)
82
82
  recorddict = RecordDict()
83
- recorddict.configs_records["config"] = ConfigsRecord({"reason": reason})
83
+ recorddict.config_records["config"] = ConfigRecord({"reason": reason})
84
84
  out_message = Message(recorddict, reply_to=message)
85
85
  # Return Message and sleep duration
86
86
  return out_message, sleep_duration
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
19
- from .comms_mods import message_size_mod, parameters_size_mod
19
+ from .comms_mods import arrays_size_mod, message_size_mod
20
20
  from .localdp_mod import LocalDpMod
21
21
  from .secure_aggregation import secagg_mod, secaggplus_mod
22
22
  from .utils import make_ffn
@@ -24,10 +24,10 @@ from .utils import make_ffn
24
24
  __all__ = [
25
25
  "LocalDpMod",
26
26
  "adaptiveclipping_mod",
27
+ "arrays_size_mod",
27
28
  "fixedclipping_mod",
28
29
  "make_ffn",
29
30
  "message_size_mod",
30
- "parameters_size_mod",
31
31
  "secagg_mod",
32
32
  "secaggplus_mod",
33
33
  ]
@@ -34,47 +34,41 @@ def message_size_mod(
34
34
  """
35
35
  message_size_in_bytes = 0
36
36
 
37
- for p_record in msg.content.parameters_records.values():
38
- message_size_in_bytes += p_record.count_bytes()
39
-
40
- for c_record in msg.content.configs_records.values():
41
- message_size_in_bytes += c_record.count_bytes()
42
-
43
- for m_record in msg.content.metrics_records.values():
44
- message_size_in_bytes += m_record.count_bytes()
37
+ for record in msg.content.values():
38
+ message_size_in_bytes += record.count_bytes()
45
39
 
46
40
  log(INFO, "Message size: %i bytes", message_size_in_bytes)
47
41
 
48
42
  return call_next(msg, ctxt)
49
43
 
50
44
 
51
- def parameters_size_mod(
45
+ def arrays_size_mod(
52
46
  msg: Message, ctxt: Context, call_next: ClientAppCallable
53
47
  ) -> Message:
54
- """Parameters size mod.
48
+ """Arrays size mod.
55
49
 
56
- This mod logs the number of parameters transmitted in the message as well as their
57
- size in bytes.
50
+ This mod logs the number of array elements transmitted in ``ArrayRecord``s of
51
+ the message as well as their sizes in bytes.
58
52
  """
59
53
  model_size_stats = {}
60
- parameters_size_in_bytes = 0
61
- for record_name, p_record in msg.content.parameters_records.items():
62
- p_record_bytes = p_record.count_bytes()
63
- parameters_size_in_bytes += p_record_bytes
64
- parameter_count = 0
65
- for array in p_record.values():
66
- parameter_count += (
54
+ arrays_size_in_bytes = 0
55
+ for record_name, arr_record in msg.content.array_records.items():
56
+ arr_record_bytes = arr_record.count_bytes()
57
+ arrays_size_in_bytes += arr_record_bytes
58
+ element_count = 0
59
+ for array in arr_record.values():
60
+ element_count += (
67
61
  int(np.prod(array.shape)) if array.shape else array.numpy().size
68
62
  )
69
63
 
70
64
  model_size_stats[f"{record_name}"] = {
71
- "parameters": parameter_count,
72
- "bytes": p_record_bytes,
65
+ "elements": element_count,
66
+ "bytes": arr_record_bytes,
73
67
  }
74
68
 
75
69
  if model_size_stats:
76
70
  log(INFO, model_size_stats)
77
71
 
78
- log(INFO, "Total parameters transmitted: %i bytes", parameters_size_in_bytes)
72
+ log(INFO, "Total array elements transmitted: %i bytes", arrays_size_in_bytes)
79
73
 
80
74
  return call_next(msg, ctxt)
@@ -22,7 +22,7 @@ from typing import Any, cast
22
22
 
23
23
  from flwr.client.typing import ClientAppCallable
24
24
  from flwr.common import (
25
- ConfigsRecord,
25
+ ConfigRecord,
26
26
  Context,
27
27
  Message,
28
28
  Parameters,
@@ -63,7 +63,7 @@ from flwr.common.secure_aggregation.secaggplus_utils import (
63
63
  share_keys_plaintext_concat,
64
64
  share_keys_plaintext_separate,
65
65
  )
66
- from flwr.common.typing import ConfigsRecordValues
66
+ from flwr.common.typing import ConfigRecordValues
67
67
 
68
68
 
69
69
  @dataclass
@@ -97,7 +97,7 @@ class SecAggPlusState:
97
97
  ss2_dict: dict[int, bytes] = field(default_factory=dict)
98
98
  public_keys_dict: dict[int, tuple[bytes, bytes]] = field(default_factory=dict)
99
99
 
100
- def __init__(self, **kwargs: ConfigsRecordValues) -> None:
100
+ def __init__(self, **kwargs: ConfigRecordValues) -> None:
101
101
  for k, v in kwargs.items():
102
102
  if k.endswith(":V"):
103
103
  continue
@@ -115,7 +115,7 @@ class SecAggPlusState:
115
115
  new_v = dict(zip(keys, values))
116
116
  self.__setattr__(k, new_v)
117
117
 
118
- def to_dict(self) -> dict[str, ConfigsRecordValues]:
118
+ def to_dict(self) -> dict[str, ConfigRecordValues]:
119
119
  """Convert the state to a dictionary."""
120
120
  ret = vars(self)
121
121
  for k in list(ret.keys()):
@@ -144,13 +144,13 @@ def secaggplus_mod(
144
144
  return call_next(msg, ctxt)
145
145
 
146
146
  # Retrieve local state
147
- if RECORD_KEY_STATE not in ctxt.state.configs_records:
148
- ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord({})
149
- state_dict = ctxt.state.configs_records[RECORD_KEY_STATE]
147
+ if RECORD_KEY_STATE not in ctxt.state.config_records:
148
+ ctxt.state.config_records[RECORD_KEY_STATE] = ConfigRecord({})
149
+ state_dict = ctxt.state.config_records[RECORD_KEY_STATE]
150
150
  state = SecAggPlusState(**state_dict)
151
151
 
152
152
  # Retrieve incoming configs
153
- configs = msg.content.configs_records[RECORD_KEY_CONFIGS]
153
+ configs = msg.content.config_records[RECORD_KEY_CONFIGS]
154
154
 
155
155
  # Check the validity of the next stage
156
156
  check_stage(state.current_stage, configs)
@@ -175,27 +175,27 @@ def secaggplus_mod(
175
175
  res = _collect_masked_vectors(
176
176
  state, configs, fitres.num_examples, fitres.parameters
177
177
  )
178
- for p_record in out_content.parameters_records.values():
179
- p_record.clear()
178
+ for arr_record in out_content.array_records.values():
179
+ arr_record.clear()
180
180
  elif state.current_stage == Stage.UNMASK:
181
181
  res = _unmask(state, configs)
182
182
  else:
183
183
  raise ValueError(f"Unknown SecAgg/SecAgg+ stage: {state.current_stage}")
184
184
 
185
185
  # Save state
186
- ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord(state.to_dict())
186
+ ctxt.state.config_records[RECORD_KEY_STATE] = ConfigRecord(state.to_dict())
187
187
 
188
188
  # Return message
189
- out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False)
189
+ out_content.config_records[RECORD_KEY_CONFIGS] = ConfigRecord(res, False)
190
190
  return Message(out_content, reply_to=msg)
191
191
 
192
192
 
193
- def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
193
+ def check_stage(current_stage: str, configs: ConfigRecord) -> None:
194
194
  """Check the validity of the next stage."""
195
195
  # Check the existence of Config.STAGE
196
196
  if Key.STAGE not in configs:
197
197
  raise KeyError(
198
- f"The required key '{Key.STAGE}' is missing from the ConfigsRecord."
198
+ f"The required key '{Key.STAGE}' is missing from the ConfigRecord."
199
199
  )
200
200
 
201
201
  # Check the value type of the Config.STAGE
@@ -223,7 +223,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
223
223
 
224
224
 
225
225
  # pylint: disable-next=too-many-branches
226
- def check_configs(stage: str, configs: ConfigsRecord) -> None:
226
+ def check_configs(stage: str, configs: ConfigRecord) -> None:
227
227
  """Check the validity of the configs."""
228
228
  # Check configs for the setup stage
229
229
  if stage == Stage.SETUP:
@@ -239,7 +239,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
239
239
  if key not in configs:
240
240
  raise KeyError(
241
241
  f"Stage {Stage.SETUP}: the required key '{key}' is "
242
- "missing from the ConfigsRecord."
242
+ "missing from the ConfigRecord."
243
243
  )
244
244
  # Bool is a subclass of int in Python,
245
245
  # so `isinstance(v, int)` will return True even if v is a boolean.
@@ -272,7 +272,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
272
272
  raise KeyError(
273
273
  f"Stage {Stage.COLLECT_MASKED_VECTORS}: "
274
274
  f"the required key '{key}' is "
275
- "missing from the ConfigsRecord."
275
+ "missing from the ConfigRecord."
276
276
  )
277
277
  if not isinstance(configs[key], list) or any(
278
278
  elm
@@ -295,7 +295,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
295
295
  raise KeyError(
296
296
  f"Stage {Stage.UNMASK}: "
297
297
  f"the required key '{key}' is "
298
- "missing from the ConfigsRecord."
298
+ "missing from the ConfigRecord."
299
299
  )
300
300
  if not isinstance(configs[key], list) or any(
301
301
  elm
@@ -313,8 +313,8 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
313
313
 
314
314
 
315
315
  def _setup(
316
- state: SecAggPlusState, configs: ConfigsRecord
317
- ) -> dict[str, ConfigsRecordValues]:
316
+ state: SecAggPlusState, configs: ConfigRecord
317
+ ) -> dict[str, ConfigRecordValues]:
318
318
  # Assigning parameter values to object fields
319
319
  sec_agg_param_dict = configs
320
320
  state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
@@ -349,8 +349,8 @@ def _setup(
349
349
 
350
350
  # pylint: disable-next=too-many-locals
351
351
  def _share_keys(
352
- state: SecAggPlusState, configs: ConfigsRecord
353
- ) -> dict[str, ConfigsRecordValues]:
352
+ state: SecAggPlusState, configs: ConfigRecord
353
+ ) -> dict[str, ConfigRecordValues]:
354
354
  named_bytes_tuples = cast(dict[str, tuple[bytes, bytes]], configs)
355
355
  key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()}
356
356
  log(DEBUG, "Node %d: starting stage 1...", state.nid)
@@ -412,10 +412,10 @@ def _share_keys(
412
412
  # pylint: disable-next=too-many-locals
413
413
  def _collect_masked_vectors(
414
414
  state: SecAggPlusState,
415
- configs: ConfigsRecord,
415
+ configs: ConfigRecord,
416
416
  num_examples: int,
417
417
  updated_parameters: Parameters,
418
- ) -> dict[str, ConfigsRecordValues]:
418
+ ) -> dict[str, ConfigRecordValues]:
419
419
  log(DEBUG, "Node %d: starting stage 2...", state.nid)
420
420
  available_clients: list[int] = []
421
421
  ciphertexts = cast(list[bytes], configs[Key.CIPHERTEXT_LIST])
@@ -498,8 +498,8 @@ def _collect_masked_vectors(
498
498
 
499
499
 
500
500
  def _unmask(
501
- state: SecAggPlusState, configs: ConfigsRecord
502
- ) -> dict[str, ConfigsRecordValues]:
501
+ state: SecAggPlusState, configs: ConfigRecord
502
+ ) -> dict[str, ConfigRecordValues]:
503
503
  log(DEBUG, "Node %d: starting stage 3...", state.nid)
504
504
 
505
505
  active_nids = cast(list[int], configs[Key.ACTIVE_NODE_ID_LIST])
flwr/common/__init__.py CHANGED
@@ -31,7 +31,10 @@ from .parameter import ndarray_to_bytes as ndarray_to_bytes
31
31
  from .parameter import ndarrays_to_parameters as ndarrays_to_parameters
32
32
  from .parameter import parameters_to_ndarrays as parameters_to_ndarrays
33
33
  from .record import Array as Array
34
+ from .record import ArrayRecord as ArrayRecord
35
+ from .record import ConfigRecord as ConfigRecord
34
36
  from .record import ConfigsRecord as ConfigsRecord
37
+ from .record import MetricRecord as MetricRecord
35
38
  from .record import MetricsRecord as MetricsRecord
36
39
  from .record import ParametersRecord as ParametersRecord
37
40
  from .record import RecordDict as RecordDict
@@ -42,7 +45,7 @@ from .telemetry import event as event
42
45
  from .typing import ClientMessage as ClientMessage
43
46
  from .typing import Code as Code
44
47
  from .typing import Config as Config
45
- from .typing import ConfigsRecordValues as ConfigsRecordValues
48
+ from .typing import ConfigRecordValues as ConfigRecordValues
46
49
  from .typing import DisconnectRes as DisconnectRes
47
50
  from .typing import EvaluateIns as EvaluateIns
48
51
  from .typing import EvaluateRes as EvaluateRes
@@ -52,9 +55,9 @@ from .typing import GetParametersIns as GetParametersIns
52
55
  from .typing import GetParametersRes as GetParametersRes
53
56
  from .typing import GetPropertiesIns as GetPropertiesIns
54
57
  from .typing import GetPropertiesRes as GetPropertiesRes
58
+ from .typing import MetricRecordValues as MetricRecordValues
55
59
  from .typing import Metrics as Metrics
56
60
  from .typing import MetricsAggregationFn as MetricsAggregationFn
57
- from .typing import MetricsRecordValues as MetricsRecordValues
58
61
  from .typing import NDArray as NDArray
59
62
  from .typing import NDArrays as NDArrays
60
63
  from .typing import Parameters as Parameters
@@ -66,11 +69,13 @@ from .typing import Status as Status
66
69
 
67
70
  __all__ = [
68
71
  "Array",
72
+ "ArrayRecord",
69
73
  "ClientMessage",
70
74
  "Code",
71
75
  "Config",
76
+ "ConfigRecord",
77
+ "ConfigRecordValues",
72
78
  "ConfigsRecord",
73
- "ConfigsRecordValues",
74
79
  "Context",
75
80
  "DEFAULT_TTL",
76
81
  "DisconnectRes",
@@ -89,10 +94,11 @@ __all__ = [
89
94
  "MessageType",
90
95
  "MessageTypeLegacy",
91
96
  "Metadata",
97
+ "MetricRecord",
98
+ "MetricRecordValues",
92
99
  "Metrics",
93
100
  "MetricsAggregationFn",
94
101
  "MetricsRecord",
95
- "MetricsRecordValues",
96
102
  "NDArray",
97
103
  "NDArrays",
98
104
  "Parameters",
flwr/common/config.py CHANGED
@@ -34,7 +34,7 @@ from flwr.common.constant import (
34
34
  )
35
35
  from flwr.common.typing import Run, UserConfig, UserConfigValue
36
36
 
37
- from . import ConfigsRecord, object_ref
37
+ from . import ConfigRecord, object_ref
38
38
 
39
39
  T_dict = TypeVar("T_dict", bound=dict[str, Any]) # pylint: disable=invalid-name
40
40
 
@@ -260,9 +260,9 @@ def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]:
260
260
  )
261
261
 
262
262
 
263
- def user_config_to_configsrecord(config: UserConfig) -> ConfigsRecord:
264
- """Construct a `ConfigsRecord` out of a `UserConfig`."""
265
- c_record = ConfigsRecord()
263
+ def user_config_to_configrecord(config: UserConfig) -> ConfigRecord:
264
+ """Construct a `ConfigRecord` out of a `UserConfig`."""
265
+ c_record = ConfigRecord()
266
266
  for k, v in config.items():
267
267
  c_record[k] = v
268
268
 
flwr/common/constant.py CHANGED
@@ -121,7 +121,7 @@ TIMESTAMP_HEADER = "flwr-timestamp"
121
121
  TIMESTAMP_TOLERANCE = 10 # General tolerance for timestamp verification
122
122
  SYSTEM_TIME_TOLERANCE = 5 # Allowance for system time drift
123
123
 
124
- # Constants for ParametersRecord
124
+ # Constants for ArrayRecord
125
125
  GC_THRESHOLD = 200_000_000 # 200 MB
126
126
 
127
127
 
@@ -15,15 +15,18 @@
15
15
  """Record APIs."""
16
16
 
17
17
 
18
- from .configsrecord import ConfigsRecord
18
+ from .arrayrecord import Array, ArrayRecord, ParametersRecord
19
+ from .configrecord import ConfigRecord, ConfigsRecord
19
20
  from .conversion_utils import array_from_numpy
20
- from .metricsrecord import MetricsRecord
21
- from .parametersrecord import Array, ParametersRecord
21
+ from .metricrecord import MetricRecord, MetricsRecord
22
22
  from .recorddict import RecordDict, RecordSet
23
23
 
24
24
  __all__ = [
25
25
  "Array",
26
+ "ArrayRecord",
27
+ "ConfigRecord",
26
28
  "ConfigsRecord",
29
+ "MetricRecord",
27
30
  "MetricsRecord",
28
31
  "ParametersRecord",
29
32
  "RecordDict",